Chapter 12: Advanced Data Structures - When Arrays and Trees Aren’t Enough

Answering “What’s the sum from index 47 to 891?” in Microseconds

“Bad programmers worry about the code. Good programmers worry about data structures and their relationships.” - Linus Torvalds

“The difference between a great programmer and a wizard is knowing the ADVANCED data structures.” - Every competitive programmer ever

12.1 Introduction: Beyond the Basics

You know arrays, linked lists, hash tables, and binary search trees. These are the bread and butter of programming. But what happens when you need to:

  • Find the sum of elements from index 1000 to 5000 in an array that changes frequently (try doing that in O(log n)!)
  • Support “undo” in your application without storing copies of everything (time travel, anyone?)
  • Represent a billion-element bit vector in just megabytes (not gigabytes!)
  • Keep your CPU’s caches working for you without knowing how big they are (algorithms that adapt to any cache size!)

These problems require advanced data structures—clever ways of organizing data that unlock operations you didn’t think were possible.

What makes a data structure “advanced”?

  1. Non-obvious structure: Not immediately intuitive, but brilliant once you understand it
  2. Extreme efficiency: Asymptotically or practically faster than standard approaches
  3. Novel capabilities: Enable operations that seem impossible
  4. Elegant design: Simple ideas that compose beautifully

In this chapter, we’ll explore data structures that competitive programmers swear by, that power database indices, that make version control systems possible, and that squeeze the last drop of performance from modern hardware.

Real-world impact:

  • Segment trees: Used in competitive programming, computational geometry, graphics
  • Fenwick trees: Database range queries, real-time analytics
  • Persistent structures: Git, functional programming languages, undo systems
  • Succinct structures: Bioinformatics (massive genomes), web search (graph compression)
  • Cache-oblivious algorithms: High-performance computing, databases

Let’s dive into the wonderful world of advanced data structures!

12.2 Segment Trees: Range Queries on Steroids

12.2.1 The Problem

You have an array of n elements and need to:

  1. Query: Find the sum (or min, max, etc.) of elements from index L to R
  2. Update: Change the value at a specific index

Naive solutions:

  • Store the plain array: O(n) queries (iterate through the range), O(1) updates (just set the value)
  • Store prefix sums: O(1) queries, but O(n) updates
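A minimal sketch of the prefix-sum option makes the trade-off concrete (the `PrefixSums` class is an illustrative name, not part of this chapter's code): a query is a single subtraction, but changing one element shifts every later prefix.

```python
class PrefixSums:
    """Prefix sums: O(1) range-sum queries, O(n) point updates (illustrative)."""

    def __init__(self, arr):
        self.arr = list(arr)
        # prefix[i] holds sum(arr[0:i]), so prefix[0] == 0
        self.prefix = [0] * (len(self.arr) + 1)
        for i, x in enumerate(self.arr):
            self.prefix[i + 1] = self.prefix[i] + x

    def query(self, L, R):
        """Sum of arr[L..R] inclusive: one subtraction, O(1)."""
        return self.prefix[R + 1] - self.prefix[L]

    def update(self, index, value):
        """Set arr[index] = value; every later prefix shifts, so O(n)."""
        delta = value - self.arr[index]
        self.arr[index] = value
        for i in range(index + 1, len(self.prefix)):
            self.prefix[i] += delta


ps = PrefixSums([1, 3, 5, 7, 9, 11])
print(ps.query(1, 4))   # 3 + 5 + 7 + 9 = 24
ps.update(2, 10)        # rewrites prefix[3..6]
print(ps.query(1, 4))   # 3 + 10 + 7 + 9 = 29
```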

Can we do better? Yes! O(log n) for both operations!

12.2.2 The Segment Tree Idea

A segment tree is a binary tree where:

  • Each leaf represents a single element
  • Each internal node represents a range (segment) of elements
  • Each node stores the aggregate (sum, min, max, etc.) of its segment

Array: [1, 3, 5, 7, 9, 11]

Segment Tree (storing sums):

                [0-5: 36]
               /           \
        [0-2: 9]            [3-5: 27]
        /      \            /        \
    [0-1: 4]  [2:5]    [3-4: 16]   [5:11]
    /    \              /     \
[0:1]   [1:3]       [3:7]    [4:9]

Key insight: Any range [L, R] can be broken into O(log n) nodes in the tree!

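Example: Query the sum from index 1 to 4:

  • The range breaks into three nodes: [1:3] + [2:5] + [3-4:16]
  • Total: 3 + 5 + 16 = 24
  • Adding the four leaves [1:3] + [2:5] + [3:7] + [4:9] gives the same 24 but touches more nodes; the point of the tree is that [3-4] is already precomputed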
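The decomposition in this example falls out of the same three-way case split (no overlap, full overlap, partial overlap) that the query code uses. The helper below, `canonical_segments`, is a hypothetical sketch that records node ranges instead of summing values, assuming the tree splits each range at `mid = (start + end) // 2`.

```python
def canonical_segments(n, L, R):
    """Return the (start, end) ranges of the nodes a query on [L, R] uses,
    for a segment tree over [0, n-1] split at mid = (start + end) // 2."""
    segments = []

    def visit(start, end):
        if R < start or end < L:        # no overlap: skip this subtree
            return
        if L <= start and end <= R:     # fully inside: use the stored aggregate
            segments.append((start, end))
            return
        mid = (start + end) // 2        # partial overlap: split and recurse
        visit(start, mid)
        visit(mid + 1, end)

    visit(0, n - 1)
    return segments


print(canonical_segments(6, 1, 4))   # [(1, 1), (2, 2), (3, 4)]
print(canonical_segments(6, 0, 5))   # [(0, 5)]
```

The returned ranges are disjoint and together cover exactly [L, R], and at most two of them come from any one level of the tree, which is where the O(log n) bound comes from.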

12.2.3 Building a Segment Tree

class SegmentTree:
    """
    Segment Tree for range queries and point updates.
    Supports the associative operations sum, min, max, and GCD.
    """
    
    def __init__(self, arr, operation='sum'):
        """
        Build segment tree from array.
        
        Args:
            arr: Input array
            operation: one of 'sum', 'min', 'max', 'gcd'
        
        Time: O(n)
        Space: O(n)
        """
        self.n = len(arr)
        self.arr = arr.copy()
        
        # Tree needs 4*n space (worst case)
        self.tree = [0] * (4 * self.n)
        
        # Set operation and identity
        if operation == 'sum':
            self.op = lambda a, b: a + b
            self.identity = 0
        elif operation == 'min':
            self.op = min
            self.identity = float('inf')
        elif operation == 'max':
            self.op = max
            self.identity = float('-inf')
        elif operation == 'gcd':
            import math
            self.op = math.gcd
            self.identity = 0
        else:
            raise ValueError(f"Unknown operation: {operation}")
        
        # Build the tree
        self._build(0, 0, self.n - 1)
    
    def _build(self, node, start, end):
        """
        Build segment tree recursively.
        
        Args:
            node: Current node index in tree
            start, end: Range [start, end] this node represents
        """
        if start == end:
            # Leaf node
            self.tree[node] = self.arr[start]
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
            
            # Build left and right subtrees
            self._build(left_child, start, mid)
            self._build(right_child, mid + 1, end)
            
            # Internal node = combine children
            self.tree[node] = self.op(
                self.tree[left_child],
                self.tree[right_child]
            )
    
    def query(self, L, R):
        """
        Query range [L, R].
        
        Time: O(log n)
        """
        return self._query(0, 0, self.n - 1, L, R)
    
    def _query(self, node, start, end, L, R):
        """
        Recursive query helper.
        
        Cases:
        1. [start, end] completely outside [L, R] → return identity
        2. [start, end] completely inside [L, R] → return tree[node]
        3. [start, end] partially overlaps [L, R] → recurse on children
        """
        # No overlap
        if R < start or end < L:
            return self.identity
        
        # Complete overlap
        if L <= start and end <= R:
            return self.tree[node]
        
        # Partial overlap
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        
        left_result = self._query(left_child, start, mid, L, R)
        right_result = self._query(right_child, mid + 1, end, L, R)
        
        return self.op(left_result, right_result)
    
    def update(self, index, value):
        """
        Update element at index to value.
        
        Time: O(log n)
        """
        self.arr[index] = value
        self._update(0, 0, self.n - 1, index, value)
    
    def _update(self, node, start, end, index, value):
        """Recursive update helper."""
        if start == end:
            # Leaf node
            self.tree[node] = value
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
            
            if index <= mid:
                self._update(left_child, start, mid, index, value)
            else:
                self._update(right_child, mid + 1, end, index, value)
            
            # Update current node
            self.tree[node] = self.op(
                self.tree[left_child],
                self.tree[right_child]
            )
    
    def __str__(self):
        """String representation."""
        return f"SegmentTree({self.arr})"

12.2.4 Step-by-Step Example

Let’s build a segment tree for array [1, 3, 5, 7, 9, 11] and query sum from index 1 to 4:

def example_segment_tree_trace():
    """Trace segment tree operations step by step."""
    print("=== Segment Tree Example ===\n")
    
    arr = [1, 3, 5, 7, 9, 11]
    print(f"Array: {arr}")
    
    # Build tree
    st = SegmentTree(arr, operation='sum')
    print("\nSegment tree built!")
    
    # Visualize tree structure
    print("\nTree structure (node: [range] = value):")
    print("Level 0: [0-5] = 36")
    print("Level 1: [0-2] = 9,  [3-5] = 27")
    print("Level 2: [0-1] = 4,  [2] = 5,  [3-4] = 16,  [5] = 11")
    print("Level 3: [0] = 1,  [1] = 3,  [3] = 7,  [4] = 9")
    
    # Query
    L, R = 1, 4
    result = st.query(L, R)
    print(f"\nQuery sum({L}, {R}):")
    print(f"  Elements: {arr[L:R+1]}")
    print(f"  Sum: {sum(arr[L:R+1])}")
    print(f"  Segment tree result: {result}")
    print("  ✓ Correct!" if result == sum(arr[L:R+1]) else "  ✗ Wrong!")
    
    # Update
    print(f"\nUpdate index 2 from {arr[2]} to 10")
    st.update(2, 10)
    
    # Query again
    result_after = st.query(L, R)
    print(f"\nQuery sum({L}, {R}) after update:")
    print(f"  New result: {result_after}")
    print(f"  Expected: 3 + 10 + 7 + 9 = {3 + 10 + 7 + 9}")
    print("  ✓ Correct!" if result_after == 3 + 10 + 7 + 9 else "  ✗ Wrong!")

if __name__ == "__main__":
    example_segment_tree_trace()

12.2.5 Lazy Propagation: Range Updates

What if we want to update an entire range [L, R] at once?

Naive approach: Update each of the k elements in the range individually → O(k log n), which is O(n log n) for a range spanning the whole array

Lazy propagation: Defer updates until necessary → O(log n)!

Key idea: Mark nodes as “lazy” and propagate updates only when needed.
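Why is deferring enough to reach O(log n)? This sketch simulates only the recursion of a range update (the function `lazy_update_footprint` is our own illustration, assuming the same midpoint split as the segment tree code in this chapter). A fully covered node gets a tag and is never descended into, so even a range of nearly a million elements touches at most a few nodes per level of the tree.

```python
import math


def lazy_update_footprint(n, L, R):
    """Simulate the recursion of a lazy range update on [L, R] over a
    segment tree for [0, n-1]; return (visited, tagged, covered)."""
    visited = tagged = covered = 0

    def walk(start, end):
        nonlocal visited, tagged, covered
        visited += 1
        if R < start or end < L:        # no overlap: nothing to do
            return
        if L <= start and end <= R:     # full overlap: tag and stop descending
            tagged += 1
            covered += end - start + 1
            return
        mid = (start + end) // 2        # partial overlap: recurse into children
        walk(start, mid)
        walk(mid + 1, end)

    walk(0, n - 1)
    return visited, tagged, covered


n = 1_000_000
visited, tagged, covered = lazy_update_footprint(n, 1, n - 2)
print(f"elements updated: {covered:,}")
print(f"nodes visited:    {visited}")
print(f"lazy tags placed: {tagged}")
print(f"tree height:      {math.ceil(math.log2(n))}")
```

The tagged nodes cover every element of the range exactly once; the actual additions to their descendants happen later, only when a query or another update passes through them.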

class LazySegmentTree:
    """
    Segment tree with lazy propagation for range updates.
    
    Supports:
    - Range query: O(log n)
    - Range update: O(log n)
    """
    
    def __init__(self, arr):
        """Initialize with array."""
        self.n = len(arr)
        self.arr = arr.copy()
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)  # Lazy propagation array
        self._build(0, 0, self.n - 1)
    
    def _build(self, node, start, end):
        """Build tree."""
        if start == end:
            self.tree[node] = self.arr[start]
        else:
            mid = (start + end) // 2
            left = 2 * node + 1
            right = 2 * node + 2
            self._build(left, start, mid)
            self._build(right, mid + 1, end)
            self.tree[node] = self.tree[left] + self.tree[right]
    
    def _push(self, node, start, end):
        """
        Push lazy value down to children.
        This is where the magic happens!
        """
        if self.lazy[node] != 0:
            # Apply lazy value to current node
            self.tree[node] += (end - start + 1) * self.lazy[node]
            
            # If not a leaf, propagate to children
            if start != end:
                left = 2 * node + 1
                right = 2 * node + 2
                self.lazy[left] += self.lazy[node]
                self.lazy[right] += self.lazy[node]
            
            # Clear lazy value
            self.lazy[node] = 0
    
    def update_range(self, L, R, value):
        """
        Add value to all elements in range [L, R].
        
        Time: O(log n)
        """
        self._update_range(0, 0, self.n - 1, L, R, value)
    
    def _update_range(self, node, start, end, L, R, value):
        """Recursive range update with lazy propagation."""
        # Push pending updates
        self._push(node, start, end)
        
        # No overlap
        if R < start or end < L:
            return
        
        # Complete overlap
        if L <= start and end <= R:
            # Mark as lazy and defer
            self.lazy[node] += value
            self._push(node, start, end)
            return
        
        # Partial overlap - recurse
        mid = (start + end) // 2
        left = 2 * node + 1
        right = 2 * node + 2
        
        self._update_range(left, start, mid, L, R, value)
        self._update_range(right, mid + 1, end, L, R, value)
        
        # Push children before reading
        self._push(left, start, mid)
        self._push(right, mid + 1, end)
        
        # Update current node
        self.tree[node] = self.tree[left] + self.tree[right]
    
    def query_range(self, L, R):
        """
        Query sum of range [L, R].
        
        Time: O(log n)
        """
        return self._query_range(0, 0, self.n - 1, L, R)
    
    def _query_range(self, node, start, end, L, R):
        """Recursive range query."""
        # Push pending updates
        self._push(node, start, end)
        
        # No overlap
        if R < start or end < L:
            return 0
        
        # Complete overlap
        if L <= start and end <= R:
            return self.tree[node]
        
        # Partial overlap
        mid = (start + end) // 2
        left = 2 * node + 1
        right = 2 * node + 2
        
        left_sum = self._query_range(left, start, mid, L, R)
        right_sum = self._query_range(right, mid + 1, end, L, R)
        
        return left_sum + right_sum

def example_lazy_propagation():
    """Demonstrate lazy propagation."""
    print("\n=== Lazy Propagation Example ===\n")
    
    arr = [1, 2, 3, 4, 5, 6, 7, 8]
    print(f"Array: {arr}")
    
    st = LazySegmentTree(arr)
    
    # Query initial sum
    L, R = 2, 5
    print(f"\nInitial sum({L}, {R}) = {st.query_range(L, R)}")
    print(f"  Expected: {sum(arr[L:R+1])} ✓")
    
    # Range update
    print(f"\nAdd 10 to range [1, 4]")
    st.update_range(1, 4, 10)
    
    # Query after update
    result = st.query_range(L, R)
    print(f"\nsum({L}, {R}) after update = {result}")
    
    # Manual calculation
    new_arr = arr.copy()
    for i in range(1, 5):
        new_arr[i] += 10
    expected = sum(new_arr[L:R+1])
    print(f"  Expected: {expected}")
    print(f"  ✓ Correct!" if result == expected else "  ✗ Wrong!")

if __name__ == "__main__":
    example_lazy_propagation()

12.2.6 Applications and Variants

Common applications:

  1. Range sum/min/max queries with updates
  2. Interval scheduling problems
  3. Computational geometry (sweep line algorithms)
  4. Graphics (collision detection)

Variants:

class SegmentTreeVariants:
    """Various segment tree applications."""
    
    @staticmethod
    def range_minimum_query(arr):
        """
        Build RMQ segment tree.
        Query min in O(log n), update in O(log n).
        """
        return SegmentTree(arr, operation='min')
    
    @staticmethod
    def range_gcd_query(arr):
        """
        Query GCD of range in O(log n).
        Useful for: finding common factors
        """
        return SegmentTree(arr, operation='gcd')
    
    @staticmethod
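    def count_elements_less_than(arr, threshold):
        """
        Build a tree that counts elements < threshold in any range [L, R].
        Each leaf stores 1 if its element is below threshold, else 0, and
        internal nodes store sums, so tree.query(L, R) returns the count.
        Note: updates must pass the 0/1 indicator, not the raw value.
        Works for a fixed threshold; per-query thresholds need a different
        structure (e.g., a merge sort tree).
        """
        indicators = [1 if x < threshold else 0 for x in arr]
        return SegmentTree(indicators, operation='sum')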

def example_segment_tree_applications():
    """Demonstrate various segment tree applications."""
    print("\n=== Segment Tree Applications ===\n")
    
    arr = [12, 7, 5, 15, 3, 9, 11, 18]
    
    # Range Minimum Query
    print("1. Range Minimum Query:")
    rmq = SegmentTree(arr, operation='min')
    L, R = 2, 6
    result = rmq.query(L, R)
    print(f"   min({L}, {R}) = {result}")
    print(f"   Elements: {arr[L:R+1]}")
    print(f"   Expected: {min(arr[L:R+1])}\n")
    
    # Range Maximum Query
    print("2. Range Maximum Query:")
    rmaxq = SegmentTree(arr, operation='max')
    result = rmaxq.query(L, R)
    print(f"   max({L}, {R}) = {result}")
    print(f"   Expected: {max(arr[L:R+1])}\n")
    
    # Range GCD Query
    print("3. Range GCD Query:")
    import math
    gcd_tree = SegmentTree(arr, operation='gcd')
    result = gcd_tree.query(L, R)
    print(f"   gcd({L}, {R}) = {result}")
    expected_gcd = arr[L]
    for i in range(L+1, R+1):
        expected_gcd = math.gcd(expected_gcd, arr[i])
    print(f"   Expected: {expected_gcd} ✓")

if __name__ == "__main__":
    example_segment_tree_applications()

12.3 Fenwick Trees: Elegant Simplicity

12.3.1 The Inspiration

Segment trees are powerful, but they’re a bit heavy: 4n space, a recursive call for every node visited, and somewhat involved code.

Fenwick Trees (also called Binary Indexed Trees, or BIT) solve the same problem with:

  • ✅ Much simpler code (~10 lines!)
  • ✅ Better cache performance
  • ✅ Only n space (vs 4n)
  • ✅ Same O(log n) operations

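The catch? Fenwick trees are less flexible than segment trees. A range query is computed as the difference of two prefix queries, so they mainly suit operations that can be undone, such as sum (or XOR), rather than min or max.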
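Recovering a range from two prefixes needs an inverse: subtraction undoes addition, and XOR undoes itself. Min has no inverse, as this small counterexample sketch shows (the arrays and the `prefix` helper are our own illustration): two arrays share identical prefix minima but have different minima over the same range.

```python
from itertools import accumulate
from operator import add, xor


def prefix(arr, op):
    """Running prefix results: prefix(arr, op)[i] = op applied over arr[0..i]."""
    return list(accumulate(arr, op))


a = [1, 5, 2]
b = [1, 3, 4]

# Sum and XOR are invertible: range [1, 2] = prefix[2] "minus" prefix[0]
sums = prefix(a, add)
print(sums[2] - sums[0])            # 5 + 2 = 7
xors = prefix(a, xor)
print(xors[2] ^ xors[0])            # 5 ^ 2 = 7

# Min is not: both arrays have prefix minima [1, 1, 1] ...
print(prefix(a, min), prefix(b, min))
# ... yet their minima over [1, 2] differ, so prefix minima cannot recover them
print(min(a[1:3]), min(b[1:3]))     # 2 3
```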

12.3.2 The Brilliant Idea

The key insight: represent cumulative sums using binary representation of indices!

Example: an array of size 8, using 1-indexed positions 1-8

Index (binary):    0001   0010   0011   0100   0101   0110   0111   1000
BIT stores:         [1]  [1-2]    [3]  [1-4]    [5]  [5-6]    [7]  [1-8]

BIT[1] = arr[1]
BIT[2] = arr[1] + arr[2]
BIT[3] = arr[3]
BIT[4] = arr[1] + arr[2] + arr[3] + arr[4]
...

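Pattern: BIT[i] stores the sum of the 2^k elements ending at i, where 2^k is the value of the rightmost set bit of i (so k is the number of trailing zeros in i’s binary form). For example, 6 = 110 has rightmost set bit 2, so BIT[6] covers [5-6].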

Why is this brilliant?

To compute the prefix sum up to index i:

  1. Add BIT[i]
  2. Remove the rightmost set bit from i
  3. Add BIT[new i]
  4. Repeat until i = 0

Example: Sum up to index 7 (binary: 111)

  • Add BIT[111 = 7] (covers [7])
  • Remove rightmost bit: 110 (6)
  • Add BIT[110 = 6] (covers [5-6])
  • Remove rightmost bit: 100 (4)
  • Add BIT[100 = 4] (covers [1-4])
  • Remove rightmost bit: 000 (0). Done! Sum = BIT[7] + BIT[6] + BIT[4]
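Both the pattern and the walk can be checked directly. This sketch (the helpers `lowbit`, `build_from_definition`, and `prefix_sum_walk` are our own names) fills BIT[i] straight from the definition and then runs the prefix walk; it illustrates the indexing and is not meant as an efficient build.

```python
def lowbit(i):
    """Value of the rightmost set bit of i (e.g. 6 = 0b110 -> 2)."""
    return i & -i


def build_from_definition(arr):
    """BIT[i] = sum of positions i - lowbit(i) + 1 .. i (1-indexed).
    Direct construction from the definition, for illustration only."""
    bit = [0] * (len(arr) + 1)
    for i in range(1, len(arr) + 1):
        start = i - lowbit(i) + 1
        bit[i] = sum(arr[start - 1:i])    # arr itself is a 0-indexed list
    return bit


def prefix_sum_walk(bit, i):
    """Sum of the first i elements: add BIT[i], strip the rightmost bit, repeat."""
    total, visited = 0, []
    while i > 0:
        total += bit[i]
        visited.append(i)
        i -= lowbit(i)
    return total, visited


arr = [1, 2, 3, 4, 5, 6, 7, 8]
bit = build_from_definition(arr)
print(bit[1:])                   # [1, 3, 3, 10, 5, 11, 7, 36]
print(prefix_sum_walk(bit, 7))   # (28, [7, 6, 4])
```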

12.3.3 Implementation

class FenwickTree:
    """
    Fenwick Tree (Binary Indexed Tree).
    
    Supports:
    - Prefix sum query: O(log n)
    - Point update: O(log n)
    - Range query: O(log n)
    
    Space: O(n), with n + 1 slots (vs. 4n for the segment tree above)
    """
    
    def __init__(self, n):
        """
        Initialize Fenwick tree.
        
        Args:
            n: Size of array (1-indexed internally)
        """
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed
    
    @classmethod
    def from_array(cls, arr):
        """Build from array (0-indexed)."""
        ft = cls(len(arr))
        for i, val in enumerate(arr):
            ft.update(i, val)
        return ft
    
    def update(self, index, delta):
        """
        Add delta to element at index.
        
        Args:
            index: 0-indexed position
            delta: Value to add
        
        Time: O(log n)
        """
        index += 1  # Convert to 1-indexed
        
        while index <= self.n:
            self.tree[index] += delta
            # Move to next index that needs updating
            # Add rightmost set bit
            index += index & (-index)
    
    def prefix_sum(self, index):
        """
        Get sum of elements from 0 to index (inclusive).
        
        Args:
            index: 0-indexed position
        
        Returns:
            Sum of arr[0:index+1]
        
        Time: O(log n)
        """
        index += 1  # Convert to 1-indexed
        total = 0
        
        while index > 0:
            total += self.tree[index]
            # Remove rightmost set bit
            index -= index & (-index)
        
        return total
    
    def range_sum(self, left, right):
        """
        Get sum of elements from left to right (inclusive).
        
        Time: O(log n)
        """
        if left > 0:
            return self.prefix_sum(right) - self.prefix_sum(left - 1)
        else:
            return self.prefix_sum(right)
    
    def set(self, index, value):
        """
        Set element at index to value.
        
        Time: O(log n)
        """
        current = self.range_sum(index, index)
        delta = value - current
        self.update(index, delta)
    
    def __str__(self):
        """String representation."""
        return f"FenwickTree(n={self.n})"

def visualize_fenwick_tree():
    """Visualize how Fenwick tree works."""
    print("=== Fenwick Tree Visualization ===\n")
    
    arr = [1, 2, 3, 4, 5, 6, 7, 8]
    print(f"Array: {arr}\n")
    
    ft = FenwickTree.from_array(arr)
    
    print("Fenwick Tree structure:")
    print("Index (1-indexed) | Binary | Covers | Value")
    print("-" * 50)
    
    for i in range(1, len(arr) + 1):
        binary = format(i, '03b')
        # Find rightmost set bit
        rightmost = i & (-i)
        start = i - rightmost + 1
        print(f"{i:8d} | {binary:6s} | [{start:2d}-{i:2d}] | {ft.tree[i]:5.0f}")
    
    print("\n" + "=" * 50)
    print("Queries:")
    print("=" * 50)
    
    # Demonstrate prefix sum
    for i in [3, 5, 7]:
        result = ft.prefix_sum(i)
        expected = sum(arr[:i+1])
        print(f"\nPrefix sum to index {i}:")
        print(f"  Result: {result}")
        print(f"  Expected: {expected}")
        
        # Show which nodes were accessed
        index = i + 1
        nodes = []
        while index > 0:
            nodes.append(index)
            index -= index & (-index)
        print(f"  Nodes accessed: {nodes}")
    
    # Demonstrate update
    print("\n" + "=" * 50)
    print("Update index 3 by +10")
    print("=" * 50)
    
    index = 3
    delta = 10
    
    # Show which nodes get updated
    idx = index + 1
    nodes = []
    temp_idx = idx
    while temp_idx <= ft.n:
        nodes.append(temp_idx)
        temp_idx += temp_idx & (-temp_idx)
    
    print(f"Nodes to update: {nodes}")
    
    ft.update(index, delta)
    
    result = ft.prefix_sum(7)
    expected = sum(arr[:8]) + delta
    print(f"\nAfter update, prefix_sum(7) = {result}")
    print(f"Expected: {expected}")

if __name__ == "__main__":
    visualize_fenwick_tree()
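`from_array` calls `update` once per element, so building takes O(n log n). A well-known alternative builds the same array in O(n): each slot starts as its own element and, once complete, adds itself into the single parent slot i + (i & -i). This is a standalone sketch (`fenwick_build_linear` is a hypothetical name); its output uses the same 1-indexed layout as `FenwickTree.tree`, with an unused 0 at index 0.

```python
def fenwick_build_linear(arr):
    """Build a 1-indexed Fenwick array from a 0-indexed list in O(n)."""
    n = len(arr)
    tree = [0] + list(arr)              # tree[i] starts as arr[i - 1]
    for i in range(1, n + 1):
        parent = i + (i & -i)           # the next node whose range contains i's range
        if parent <= n:
            tree[parent] += tree[i]     # tree[i] is already complete when we reach it
    return tree


print(fenwick_build_linear([1, 2, 3, 4, 5, 6, 7, 8])[1:])
# [1, 3, 3, 10, 5, 11, 7, 36]
```

Scanning left to right works because every child index is smaller than its parent, so a slot has received all of its contributions before it pushes its own total upward.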

12.3.4 The Magic of Bit Manipulation

The core operations use a beautiful bit trick:

def explain_bit_tricks():
    """Explain the bit manipulation tricks in Fenwick trees."""
    print("=== Fenwick Tree Bit Tricks ===\n")
    
    print("Key operation: index & (-index)")
    print("This isolates the rightmost set bit!\n")
    
    examples = [1, 2, 3, 4, 5, 6, 7, 8, 12]
    
    print(f"{'Index':>6} {'Binary':>8} {'-Index':>8} {'& Result':>10} {'Range Size':>12}")
    print("-" * 56)
    
    for i in examples:
        binary = format(i, '08b')
        neg_binary = format(-i & 0xFF, '08b')
        result = i & (-i)
        
        print(f"{i:6d} {binary:>8s} {neg_binary:>8s} {result:10d} {result:12d}")
    
    print("\nExplanation:")
    print("  -index in two's complement flips all bits and adds 1")
    print("  ANDing with original gives rightmost set bit")
    print("  This tells us how many elements this node covers!")
    
    print("\n" + "=" * 56)
    print("Update operation: index += index & (-index)")
    print("Moves to next index that needs updating\n")
    
    index = 5
    print(f"Start at index {index} ({format(index, '08b')})")
    for step in range(4):
        if index > 16:
            break
        print(f"  Step {step + 1}: index = {index:2d} ({format(index, '08b')})")
        index += index & (-index)
    
    print("\n" + "=" * 56)
    print("Query operation: index -= index & (-index)")
    print("Moves to previous relevant node\n")
    
    index = 7
    print(f"Start at index {index} ({format(index, '08b')})")
    step = 0
    while index > 0:
        print(f"  Step {step + 1}: index = {index:2d} ({format(index, '08b')})")
        index -= index & (-index)
        step += 1

if __name__ == "__main__":
    explain_bit_tricks()

12.3.5 2D Fenwick Tree

Fenwick trees extend beautifully to 2D!
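The rectangle query of a 2D Fenwick tree rests on inclusion-exclusion over prefix rectangles. That identity has nothing Fenwick-specific about it, so this sketch checks it on plain, static 2D prefix sums (`prefix_2d` and `rect_sum` are illustrative names, and the matrix is just a 4×4 test input).

```python
def prefix_2d(matrix):
    """P[i][j] = sum of matrix[0..i-1][0..j-1], with a row/column of zero padding."""
    rows, cols = len(matrix), len(matrix[0])
    P = [[0] * (cols + 1) for _ in range(rows + 1)]
    for i in range(rows):
        for j in range(cols):
            P[i + 1][j + 1] = matrix[i][j] + P[i][j + 1] + P[i + 1][j] - P[i][j]
    return P


def rect_sum(P, r1, c1, r2, c2):
    """Sum over rows r1..r2 and columns c1..c2 (inclusive, 0-indexed)."""
    # whole prefix, minus the strip above, minus the strip to the left,
    # plus the top-left corner that was subtracted twice
    return P[r2 + 1][c2 + 1] - P[r1][c2 + 1] - P[r2 + 1][c1] + P[r1][c1]


matrix = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
P = prefix_2d(matrix)
print(rect_sum(P, 1, 1, 2, 2))   # 6 + 7 + 10 + 11 = 34
```

A 2D Fenwick tree uses the same four-term formula; the difference is that its prefix queries stay O(log n × log m) even after updates, whereas this static table must be rebuilt.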

class FenwickTree2D:
    """
    2D Fenwick Tree for rectangle sum queries.
    
    Useful for: image processing, computational geometry
    """
    
    def __init__(self, rows, cols):
        """Initialize 2D Fenwick tree."""
        self.rows = rows
        self.cols = cols
        self.tree = [[0] * (cols + 1) for _ in range(rows + 1)]
    
    def update(self, row, col, delta):
        """
        Add delta to cell (row, col).
        
        Time: O(log n × log m)
        """
        row += 1  # Convert to 1-indexed
        col += 1
        
        r = row
        while r <= self.rows:
            c = col
            while c <= self.cols:
                self.tree[r][c] += delta
                c += c & (-c)
            r += r & (-r)
    
    def prefix_sum(self, row, col):
        """
        Sum of rectangle from (0,0) to (row, col).
        
        Time: O(log n × log m)
        """
        row += 1  # Convert to 1-indexed
        col += 1
        
        total = 0
        r = row
        while r > 0:
            c = col
            while c > 0:
                total += self.tree[r][c]
                c -= c & (-c)
            r -= r & (-r)
        
        return total
    
    def rectangle_sum(self, r1, c1, r2, c2):
        """
        Sum of rectangle from (r1,c1) to (r2,c2).
        
        Uses inclusion-exclusion:
        sum(r1,c1,r2,c2) = sum(0,0,r2,c2) - sum(0,0,r1-1,c2)
                          - sum(0,0,r2,c1-1) + sum(0,0,r1-1,c1-1)
        
        Time: O(log n × log m)
        """
        total = self.prefix_sum(r2, c2)
        
        if r1 > 0:
            total -= self.prefix_sum(r1 - 1, c2)
        if c1 > 0:
            total -= self.prefix_sum(r2, c1 - 1)
        if r1 > 0 and c1 > 0:
            total += self.prefix_sum(r1 - 1, c1 - 1)
        
        return total

def example_2d_fenwick():
    """Demonstrate 2D Fenwick tree."""
    print("\n=== 2D Fenwick Tree ===\n")
    
    # Create 4x4 matrix
    matrix = [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ]
    
    print("Matrix:")
    for row in matrix:
        print("  ", row)
    
    # Build 2D Fenwick tree
    ft2d = FenwickTree2D(4, 4)
    for i in range(4):
        for j in range(4):
            ft2d.update(i, j, matrix[i][j])
    
    # Query rectangle sum
    r1, c1, r2, c2 = 1, 1, 2, 2
    result = ft2d.rectangle_sum(r1, c1, r2, c2)
    
    print(f"\nRectangle sum from ({r1},{c1}) to ({r2},{c2}):")
    print(f"  Elements: {[matrix[i][c1:c2+1] for i in range(r1, r2+1)]}")
    
    expected = sum(matrix[i][j] for i in range(r1, r2+1) for j in range(c1, c2+1))
    print(f"  Expected: {expected}")
    print(f"  Result: {result}")
    print(f"  ✓ Correct!" if result == expected else "  ✗ Wrong!")

if __name__ == "__main__":
    example_2d_fenwick()

12.3.6 Fenwick Tree vs Segment Tree

def compare_fenwick_vs_segment():
    """Compare Fenwick tree and Segment tree."""
    import time
    import random
    
    print("\n=== Fenwick vs Segment Tree ===\n")
    
    n = 100000
    arr = [random.randint(1, 100) for _ in range(n)]
    
    print(f"Array size: {n:,}")
    print(f"Number of operations: {n // 10:,}\n")
    
    # Build both structures
    print("Building data structures...")
    
    start = time.time()
    fenwick = FenwickTree.from_array(arr)
    fenwick_build_time = time.time() - start
    
    start = time.time()
    segment = SegmentTree(arr, operation='sum')
    segment_build_time = time.time() - start
    
    print(f"  Fenwick build: {fenwick_build_time:.4f}s")
    print(f"  Segment build: {segment_build_time:.4f}s")
    
    # Benchmark queries
    num_queries = n // 10
    queries = [(random.randint(0, n-100), random.randint(0, n-1)) 
               for _ in range(num_queries)]
    
    print(f"\nPerforming {num_queries:,} range queries...")
    
    start = time.time()
    for l, r in queries:
        if l > r:
            l, r = r, l
        _ = fenwick.range_sum(l, r)
    fenwick_query_time = time.time() - start
    
    start = time.time()
    for l, r in queries:
        if l > r:
            l, r = r, l
        _ = segment.query(l, r)
    segment_query_time = time.time() - start
    
    print(f"  Fenwick queries: {fenwick_query_time:.4f}s")
    print(f"  Segment queries: {segment_query_time:.4f}s")
    print(f"  Speedup: {segment_query_time/fenwick_query_time:.2f}x")
    
    # Benchmark updates
    num_updates = n // 10
    updates = [(random.randint(0, n-1), random.randint(-10, 10)) 
               for _ in range(num_updates)]
    
    print(f"\nPerforming {num_updates:,} updates...")
    
    start = time.time()
    for idx, delta in updates:
        fenwick.update(idx, delta)
    fenwick_update_time = time.time() - start
    
    start = time.time()
    for idx, val in updates:
        segment.update(idx, segment.arr[idx] + val)
    segment_update_time = time.time() - start
    
    print(f"  Fenwick updates: {fenwick_update_time:.4f}s")
    print(f"  Segment updates: {segment_update_time:.4f}s")
    print(f"  Speedup: {segment_update_time/fenwick_update_time:.2f}x")
    
    print("\n" + "=" * 60)
    print("Summary:")
    print("  Fenwick Tree: Faster, simpler, less memory")
    print("  Segment Tree: More flexible, supports more operations")

if __name__ == "__main__":
    compare_fenwick_vs_segment()

12.4 Persistent Data Structures: Time Travel!

12.4.1 The Problem

Imagine you’re building a text editor with unlimited undo/redo. How do you store every version efficiently?

Naive approach: Copy entire data structure for each version → O(n) space per version!

Persistent data structures: Share common parts between versions → O(1) or O(log n) space per version!

Key idea: When you “modify” a persistent structure, you create a new version that shares most data with the old version.
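The smallest illustration of this sharing is an immutable linked list, sketched here with Python tuples (our choice for brevity, not something the text prescribes). "Adding" to the front creates one new cell that points at the old list, so both versions coexist and the shared tail is stored only once.

```python
# An immutable singly linked list: None is empty, (head, tail) is a cell.
v0 = (2, (3, None))        # version 0: [2, 3]
v1 = (1, v0)               # version 1: [1, 2, 3], one new cell
v2 = (9, v0)               # version 2: [9, 2, 3], branches from version 0


def to_list(cell):
    """Walk the cells of one version into a regular Python list."""
    out = []
    while cell is not None:
        head, cell = cell
        out.append(head)
    return out


print(to_list(v0), to_list(v1), to_list(v2))   # [2, 3] [1, 2, 3] [9, 2, 3]
print(v1[1] is v0 and v2[1] is v0)             # True: the tail is shared, not copied
```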

12.4.2 Persistent Array

class PersistentArray:
    """
    Persistent array using path copying.
    
    Each update creates a new version in O(log n) time and space.
    All versions remain accessible!
    """
    
    class Node:
        """Tree node representing array segment."""
        def __init__(self, value=None, left=None, right=None):
            self.value = value
            self.left = left
            self.right = right
    
    def __init__(self, arr=None, size=None):
        """
        Initialize persistent array.
        
        Args:
            arr: Initial array (optional)
            size: Size if creating empty array
        """
        if arr is not None:
            self.size = len(arr)
            self.root = self._build(arr, 0, len(arr) - 1)
        else:
            self.size = size or 0
            self.root = self._build_empty(0, size - 1) if size else None
    
    def _build(self, arr, left, right):
        """Build tree from array."""
        if left == right:
            return self.Node(value=arr[left])
        
        mid = (left + right) // 2
        return self.Node(
            left=self._build(arr, left, mid),
            right=self._build(arr, mid + 1, right)
        )
    
    def _build_empty(self, left, right):
        """Build tree with zeros."""
        if left == right:
            return self.Node(value=0)
        
        mid = (left + right) // 2
        return self.Node(
            left=self._build_empty(left, mid),
            right=self._build_empty(mid + 1, right)
        )
    
    def get(self, index):
        """
        Get value at index.
        
        Time: O(log n)
        """
        return self._get(self.root, 0, self.size - 1, index)
    
    def _get(self, node, left, right, index):
        """Recursive get helper."""
        if left == right:
            return node.value
        
        mid = (left + right) // 2
        if index <= mid:
            return self._get(node.left, left, mid, index)
        else:
            return self._get(node.right, mid + 1, right, index)
    
    def set(self, index, value):
        """
        Create new version with updated value.
        
        Returns: New PersistentArray (old one unchanged!)
        Time: O(log n)
        Space: O(log n) new nodes
        """
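        # Bypass __init__: PersistentArray(size=...) would build a full
        # tree of zeros (O(n)) only to throw it away.
        new_arr = object.__new__(type(self))
        new_arr.size = self.size
        new_arr.root = self._set(self.root, 0, self.size - 1, index, value)
        return new_arr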
    
    def _set(self, node, left, right, index, value):
        """
        Recursive set helper - creates new nodes on path.
        
        This is PATH COPYING: we only copy nodes on the path
        from root to the updated leaf!
        """
        if left == right:
            return self.Node(value=value)
        
        mid = (left + right) // 2
        
        if index <= mid:
            # Rebuild the left path; reuse the right subtree unchanged
            return self.Node(
                left=self._set(node.left, left, mid, index, value),
                right=node.right  # SHARE this subtree!
            )
        else:
            # Rebuild the right path; share the untouched left subtree
            return self.Node(
                left=node.left,  # SHARE this subtree!
                right=self._set(node.right, mid + 1, right, index, value)
            )
    
    def to_list(self):
        """Convert to regular list (for debugging)."""
        result = []
        self._to_list(self.root, result)
        return result
    
    def _to_list(self, node, result):
        """Recursive conversion to list."""
        if node is None:
            return
        if node.left is None and node.right is None:
            result.append(node.value)
        else:
            self._to_list(node.left, result)
            self._to_list(node.right, result)

def example_persistent_array():
    """Demonstrate persistent array."""
    print("=== Persistent Array ===\n")
    
    # Create initial version
    v0 = PersistentArray([1, 2, 3, 4, 5])
    print(f"Version 0: {v0.to_list()}")
    
    # Create version 1: change index 2
    v1 = v0.set(2, 10)
    print(f"Version 1: {v1.to_list()}")
    print(f"Version 0: {v0.to_list()} (unchanged!)")
    
    # Create version 2: change index 4
    v2 = v1.set(4, 20)
    print(f"Version 2: {v2.to_list()}")
    
    # Create version 3: change index 0
    v3 = v2.set(0, 30)
    print(f"Version 3: {v3.to_list()}")
    
    print("\n" + "=" * 50)
    print("All versions still accessible:")
    print(f"  v0: {v0.to_list()}")
    print(f"  v1: {v1.to_list()}")
    print(f"  v2: {v2.to_list()}")
    print(f"  v3: {v3.to_list()}")
    
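    print("\n" + "=" * 50)
    print("Space efficiency:")
    def count_nodes(*roots):
        """Count distinct tree nodes reachable from the given version roots."""
        seen, stack = set(), list(roots)
        while stack:
            node = stack.pop()
            if node is not None and id(node) not in seen:
                seen.add(id(node))
                stack.extend([node.left, node.right])
        return len(seen)
    print(f"  Array size: {v0.size}, versions: 4")
    print(f"  Naive space (4 full trees): {4 * count_nodes(v0.root)} nodes")
    print(f"  Actual space (shared paths): {count_nodes(v0.root, v1.root, v2.root, v3.root)} nodes")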

if __name__ == "__main__":
    example_persistent_array()

20.4.3 12.4.3 Persistent Segment Tree

class PersistentSegmentTree:
    """
    Persistent segment tree for range queries with history.
    
    Each update creates new version in O(log n) time/space.
    Perfect for: time-travel queries, version control
    """
    
    class Node:
        """Segment tree node."""
        def __init__(self, value=0, left=None, right=None):
            self.value = value
            self.left = left
            self.right = right
    
    def __init__(self, arr=None, n=None):
        """Initialize from array or size."""
        if arr is not None:
            self.n = len(arr)
            self.root = self._build(arr, 0, self.n - 1)
        else:
            self.n = n
            self.root = self._build_empty(0, n - 1)
    
    def _build(self, arr, left, right):
        """Build initial tree."""
        if left == right:
            return self.Node(value=arr[left])
        
        mid = (left + right) // 2
        left_child = self._build(arr, left, mid)
        right_child = self._build(arr, mid + 1, right)
        
        return self.Node(
            value=left_child.value + right_child.value,
            left=left_child,
            right=right_child
        )
    
    def _build_empty(self, left, right):
        """Build empty tree."""
        if left == right:
            return self.Node(value=0)
        
        mid = (left + right) // 2
        return self.Node(
            left=self._build_empty(left, mid),
            right=self._build_empty(mid + 1, right)
        )
    
    def query(self, L, R, root=None):
        """
        Query sum in range [L, R].
        
        Args:
            L, R: Range to query
            root: Specific version's root (default: current)
        
        Time: O(log n)
        """
        if root is None:
            root = self.root
        return self._query(root, 0, self.n - 1, L, R)
    
    def _query(self, node, left, right, L, R):
        """Recursive query."""
        if R < left or right < L:
            return 0
        if L <= left and right <= R:
            return node.value
        
        mid = (left + right) // 2
        return (self._query(node.left, left, mid, L, R) +
                self._query(node.right, mid + 1, right, L, R))
    
    def update(self, index, value):
        """
        Create new version with updated value.
        
        Returns: New PersistentSegmentTree
        Time: O(log n)
        Space: O(log n) new nodes
        """
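        # Shallow-copy the wrapper; calling the constructor would rebuild
        # an O(n) empty tree on every update, defeating path copying.
        import copy
        new_tree = copy.copy(self)
        new_tree.root = self._update(self.root, 0, self.n - 1, index, value)
        return new_tree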
    
    def _update(self, node, left, right, index, value):
        """Recursive update with path copying."""
        if left == right:
            return self.Node(value=value)
        
        mid = (left + right) // 2
        
        if index <= mid:
            new_left = self._update(node.left, left, mid, index, value)
            new_node = self.Node(
                value=new_left.value + node.right.value,
                left=new_left,
                right=node.right  # Share!
            )
        else:
            new_right = self._update(node.right, mid + 1, right, index, value)
            new_node = self.Node(
                value=node.left.value + new_right.value,
                left=node.left,  # Share!
                right=new_right
            )
        
        return new_node

def example_version_control():
    """Simulate version control system using persistent structures."""
    print("\n=== Version Control with Persistent Segment Tree ===\n")
    
    # Initial code: character frequencies
    code = [5, 3, 7, 2, 9, 1, 4, 8]
    print(f"Initial code: {code}")
    
    versions = {}
    versions[0] = PersistentSegmentTree(code)
    
    print("\nVersion history:")
    print(f"  v0: Initial commit")
    
    # Version 1: Change index 2
    versions[1] = versions[0].update(2, 15)
    print(f"  v1: Update index 2: 7 → 15")
    
    # Version 2: Change index 5
    versions[2] = versions[1].update(5, 10)
    print(f"  v2: Update index 5: 1 → 10")
    
    # Version 3: Branch from v1!
    versions[3] = versions[1].update(4, 20)
    print(f"  v3: Branch from v1, update index 4: 9 → 20")
    
    # Query different versions
    print("\n" + "=" * 50)
    print("Time-travel queries:")
    print("=" * 50)
    
    L, R = 1, 5
    for v in [0, 1, 2, 3]:
        result = versions[v].query(L, R)
        print(f"  v{v}: sum({L}, {R}) = {result}")
    
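    print("\n" + "=" * 50)
    print("Space efficiency:")
    import math
    n = len(code)
    full_tree = 2 * n - 1                    # nodes in one segment tree
    path = math.ceil(math.log2(n)) + 1       # nodes copied per update (at most)
    print(f"  Array size: {n}")
    print(f"  Versions: 4")
    print(f"  Naive space: {4 * full_tree} nodes (4 full trees)")
    print(f"  Persistent space: ≤ {full_tree + 3 * path} nodes (1 tree + 3 copied paths)")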

if __name__ == "__main__":
    example_version_control()
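Keeping every version pays off when versions are indexed by position. A minimal sketch, under our own assumptions: version i is built from version i-1 by inserting arr[i-1] into a segment tree over the value range, so "version r+1 minus version l" counts how many of arr[l..r] fall in a value range. The helpers `insert`, `count_range`, `build_versions`, and `count_in` are hypothetical names for this illustration, not part of the PersistentSegmentTree class above, and a missing child (None) stands for an all-zero subtree instead of a prebuilt empty tree.

```python
class Node:
    """Segment tree node holding a count; never mutated after creation."""
    __slots__ = ("count", "left", "right")

    def __init__(self, count=0, left=None, right=None):
        self.count = count
        self.left = left
        self.right = right


def insert(node, lo, hi, value):
    """Return a new root with one more occurrence of value (path copying).

    None stands for an all-zero subtree, so no empty tree is prebuilt.
    """
    count = (node.count if node else 0) + 1
    if lo == hi:
        return Node(count)
    mid = (lo + hi) // 2
    left = node.left if node else None
    right = node.right if node else None
    if value <= mid:
        left = insert(left, lo, mid, value)
    else:
        right = insert(right, mid + 1, hi, value)
    return Node(count, left, right)  # the untouched child is shared, not copied


def count_range(node, lo, hi, qlo, qhi):
    """Number of inserted values v with qlo <= v <= qhi in this version."""
    if node is None or qhi < lo or hi < qlo:
        return 0
    if qlo <= lo and hi <= qhi:
        return node.count
    mid = (lo + hi) // 2
    return (count_range(node.left, lo, mid, qlo, qhi) +
            count_range(node.right, mid + 1, hi, qlo, qhi))


def build_versions(arr, max_value):
    """roots[i] covers the first i elements of arr (values in [0, max_value])."""
    roots = [None]
    for x in arr:
        roots.append(insert(roots[-1], 0, max_value, x))
    return roots


def count_in(roots, max_value, l, r, qlo, qhi):
    """How many of arr[l..r] lie in [qlo, qhi]: the difference of two versions."""
    return (count_range(roots[r + 1], 0, max_value, qlo, qhi) -
            count_range(roots[l], 0, max_value, qlo, qhi))


arr = [5, 3, 7, 2, 9, 1, 4, 8]
roots = build_versions(arr, max_value=9)
print(count_in(roots, 9, 1, 5, 2, 7))  # values of [3, 7, 2, 9, 1] in [2, 7] → 3
```

Each insert allocates O(log V) nodes for a value range of size V, and each query touches two versions, so one range-count costs O(log V).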

20.4.4 12.4.4 Applications of Persistence

class ApplicationExamples:
    """Real-world applications of persistent data structures."""
    
    @staticmethod
    def text_editor_with_history():
        """
        Text editor with unlimited undo/redo.
        Each edit creates new version in O(log n).
        """
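        class TextEditor:
            def __init__(self, text):
                # text must be non-empty: PersistentArray needs at least one element
                self.versions = [PersistentArray(list(text))]
                self.current_version = 0
            
            def insert(self, pos, char):
                """Overwrite the character at pos (a simplified 'insert').

                A true insert changes the length, which this fixed-size
                PersistentArray does not support.
                """
                # A new edit after undo discards the redo branch
                del self.versions[self.current_version + 1:]
                new_version = self.versions[self.current_version].set(pos, char)
                self.versions.append(new_version)
                self.current_version += 1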
            
            def undo(self):
                """Undo last edit."""
                if self.current_version > 0:
                    self.current_version -= 1
            
            def redo(self):
                """Redo edit."""
                if self.current_version < len(self.versions) - 1:
                    self.current_version += 1
            
            def get_text(self):
                """Get current text."""
                return ''.join(map(str, self.versions[self.current_version].to_list()))
        
        return TextEditor
    
    @staticmethod
    def database_with_snapshots():
        """
        Database that supports querying past states.
        """
        class Database:
            def __init__(self, size):
                self.versions = {}
                self.versions[0] = PersistentSegmentTree(n=size)
                self.current_time = 0
            
            def set(self, index, value):
                """Update value at current time."""
                self.current_time += 1
                self.versions[self.current_time] = \
                    self.versions[self.current_time - 1].update(index, value)
            
            def query(self, L, R, time=None):
                """Query range at specific time."""
                if time is None:
                    time = self.current_time
                return self.versions[time].query(L, R)
            
            def snapshot(self):
                """Create snapshot (free with persistence!)."""
                return self.current_time
        
        return Database

def example_text_editor():
    """Demonstrate text editor with unlimited undo."""
    print("\n=== Text Editor with History ===\n")
    
    TextEditor = ApplicationExamples.text_editor_with_history()
    editor = TextEditor("hello")
    
    print(f"Initial: '{editor.get_text()}'")
    
    editor.insert(0, 'H')
    print(f"After insert: '{editor.get_text()}'")
    
    editor.undo()
    print(f"After undo: '{editor.get_text()}'")
    
    editor.redo()
    print(f"After redo: '{editor.get_text()}'")

if __name__ == "__main__":
    example_text_editor()

20.5 12.5 Succinct Data Structures: Data Compression on Steroids

20.5.1 12.5.1 The Problem

You have a billion-element bit vector (10⁹ bits = 125 MB). You need to support:

  • access(i): get the bit at position i
  • rank(i): count the 1s in positions 0..i
  • select(k): find the position of the k-th 1

Naive: store the raw bits → 125 MB with O(1) access, but rank and select must scan, so they cost O(n).

Succinct: store just n + o(n) bits (barely more than the data itself!) and still answer all three operations in O(1) time!
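Before building any index, it helps to pin the three operations down exactly. The sketch below gives plain O(n) reference versions, assuming the conventions used by the class in the next subsection: rank(i) counts 1s in bits[0..i] inclusive, and select(k) counts k from 1 and returns -1 when there are fewer than k ones. The function names are illustrative, not a library API.

```python
def access(bits, i):
    """Bit at position i."""
    return bits[i]


def rank(bits, i):
    """Number of 1s in bits[0..i], inclusive."""
    return sum(bits[:i + 1])


def select(bits, k):
    """Position of the k-th 1 (k counts from 1), or -1 if there is none."""
    seen = 0
    for pos, b in enumerate(bits):
        if b:
            seen += 1
            if seen == k:
                return pos
    return -1


bits = [1, 1, 0, 1, 0, 1, 1, 0]
print(rank(bits, 4))    # → 3
print(select(bits, 4))  # → 5
```

rank and select invert each other on the 1 positions: rank(select(k)) = k for every valid k, a handy property for testing any faster implementation against these.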

20.5.2 12.5.2 Succinct Bit Vector

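class SuccinctBitVector:
    """
    Succinct bit vector with rank/select support.
    
    Theory: n + o(n) bits with O(1) rank (and O(1) select with extra
    o(n)-bit tables). This teaching version stores Python lists of ints,
    so it illustrates the layout rather than the space bound; here rank
    is O(log n) and select is O(log² n).
    
    Uses two-level indexing:
    - Superblocks: every log²(n) bits (absolute count of 1s before it)
    - Blocks: every log(n) bits (count of 1s since the superblock start)
    """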
    
    def __init__(self, bits):
        """
        Build succinct bit vector.
        
        Args:
            bits: List or string of 0s and 1s
        """
        if isinstance(bits, str):
            self.bits = [int(b) for b in bits]
        else:
            self.bits = bits
        
        self.n = len(self.bits)
        
        # Choose block sizes
        import math
        self.log_n = max(1, int(math.log2(self.n + 1)))
        self.block_size = max(1, self.log_n)
        self.superblock_size = max(1, self.log_n * self.log_n)
        
        # Build rank structures
        self._build_rank_structures()
    
    def _build_rank_structures(self):
        """Build auxiliary structures for fast rank queries."""
        # Superblock ranks: cumulative count at each superblock
        self.superblock_ranks = []
        
        # Block ranks: count within superblock
        self.block_ranks = []
        
        cumulative = 0
        
        for i in range(0, self.n, self.superblock_size):
            self.superblock_ranks.append(cumulative)
            
            # Process blocks within this superblock
            superblock_count = 0
            for j in range(i, min(i + self.superblock_size, self.n), self.block_size):
                self.block_ranks.append(superblock_count)
                
                # Count bits in this block
                block_count = sum(self.bits[j:min(j + self.block_size, self.n)])
                superblock_count += block_count
            
            cumulative += superblock_count
    
    def access(self, i):
        """
        Get bit at position i.
        
        Time: O(1)
        """
        return self.bits[i]
    
    def rank(self, i):
        """
        Count number of 1s in bits[0:i+1].
        
        Time: O(log n) here, because the final in-block count scans up to
        log n bits; a popcount instruction or lookup table makes it O(1).
        """
        if i < 0:
            return 0
        if i >= self.n:
            i = self.n - 1
        
        # Find superblock
        superblock_idx = i // self.superblock_size
        superblock_rank = self.superblock_ranks[superblock_idx] if superblock_idx < len(self.superblock_ranks) else 0
        
        # Find block within superblock
        block_idx = i // self.block_size
        block_rank = self.block_ranks[block_idx] if block_idx < len(self.block_ranks) else 0
        
        # Count remaining bits
        block_start = (i // self.block_size) * self.block_size
        remaining = sum(self.bits[block_start:i+1])
        
        return superblock_rank + block_rank + remaining
    
    def select(self, k):
        """
        Find position of k-th 1 (1-indexed).
        
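        Time: O(log n) rank calls via binary search, so O(log² n) here;
        extra o(n)-bit select tables can make it O(1)
        """
        # rank(n - 1) is the total number of 1s (avoids an O(n) sum per call)
        if k <= 0 or k > self.rank(self.n - 1):
            return -1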
        
        # Binary search
        left, right = 0, self.n - 1
        
        while left < right:
            mid = (left + right) // 2
            rank_mid = self.rank(mid)
            
            if rank_mid < k:
                left = mid + 1
            else:
                right = mid
        
        return left if self.rank(left) >= k else -1
    
    def __str__(self):
        """String representation."""
        return ''.join(map(str, self.bits))

def example_succinct_bit_vector():
    """Demonstrate succinct bit vector."""
    print("=== Succinct Bit Vector ===\n")
    
    # Create bit vector
    bits = "11010110101100001110"
    bv = SuccinctBitVector(bits)
    
    print(f"Bit vector: {bits}")
    print(f"Length: {len(bits)} bits\n")
    
    # Test access
    print("Access operations:")
    for i in [0, 5, 10, 15]:
        print(f"  bit[{i}] = {bv.access(i)}")
    
    # Test rank
    print("\nRank operations (count 1s up to position):")
    for i in [3, 7, 11, 15, 19]:
        rank = bv.rank(i)
        expected = sum(int(b) for b in bits[:i+1])
        print(f"  rank({i}) = {rank} (expected {expected}) {'✓' if rank == expected else '✗'}")
    
    # Test select
    print("\nSelect operations (find k-th 1):")
    for k in [1, 3, 5, 8]:
        pos = bv.select(k)
        print(f"  select({k}) = position {pos}")
        if pos >= 0:
            print(f"    Verification: bit[{pos}] = {bv.access(pos)}, rank({pos}) = {bv.rank(pos)}")
    
    # Space analysis
    print("\n" + "=" * 50)
    print("Space analysis:")
    print(f"  Data: {len(bits)} bits")
    print(f"  Superblocks: {len(bv.superblock_ranks)} integers")
    print(f"  Blocks: {len(bv.block_ranks)} integers")
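    # Assuming 32-bit counters, the index dwarfs a 20-bit toy input. A real
    # succinct index stores each counter in only the bits it needs, which is
    # what makes the overhead o(n) for large n.
    overhead_bits = (len(bv.superblock_ranks) + len(bv.block_ranks)) * 32
    print(f"  Overhead: {overhead_bits} bits = {100 * overhead_bits / len(bits):.0f}% of the data")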

if __name__ == "__main__":
    example_succinct_bit_vector()
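The rank method above still sums up to a block's worth of bits one at a time. Practical implementations instead pack bits into machine words and count the tail with a population-count instruction. The sketch below imitates that with Python integers standing in for 64-bit words; `PackedRank` is a hypothetical name, it keeps one cumulative count per word (a single-level index, simpler than the two-level scheme above), and `bin(x).count("1")` stands in for a hardware popcount (`int.bit_count()` does the same on Python 3.10+).

```python
class PackedRank:
    """Rank over bits packed into 64-bit words (sketch)."""

    WORD = 64

    def __init__(self, bits):
        self.n = len(bits)
        self.words = []
        for start in range(0, self.n, self.WORD):
            word = 0
            for offset, b in enumerate(bits[start:start + self.WORD]):
                if b:
                    word |= 1 << offset  # vector bit i -> bit (i % 64) of its word
            self.words.append(word)
        # cum[w] = number of 1s in words[0..w-1]
        self.cum = [0]
        for word in self.words:
            self.cum.append(self.cum[-1] + bin(word).count("1"))

    def rank(self, i):
        """Number of 1s in bits[0..i], inclusive (same convention as above)."""
        if i < 0 or self.n == 0:
            return 0
        i = min(i, self.n - 1)
        w, offset = divmod(i, self.WORD)
        mask = (1 << (offset + 1)) - 1  # keep bits 0..offset of word w
        return self.cum[w] + bin(self.words[w] & mask).count("1")


pr = PackedRank([1, 1, 0, 1, 0, 1, 1, 0])
print(pr.rank(4))  # → 3
```

Each query is now two list lookups plus one popcount, i.e. O(1). If every counter took a full 64 bits, this single-level index would double the space, which is exactly why the two-level scheme above keeps small per-block counters under larger superblock counts.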

20.5.3 12.5.3 Wavelet Tree

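A wavelet tree is a succinct structure for storing sequences over a general alphabet. It supports:

  • access(i): get the element at position i
  • rank(c, i): count occurrences of c in positions 0..i
  • select(c, k): find the position of the k-th occurrence of c

All three take O(log σ) time, where σ is the alphabet size (the number of distinct symbols), given O(1) rank and select on the underlying bit vectors. The implementation below covers access and rank.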
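Every wavelet-tree operation repeats one step: split the alphabet in half, record a 0 or 1 per position, and map a position to its index in the child that holds that element. The sketch below shows that single level with brute-force counting; `split_level` and `child_index` are illustrative helper names, and rank here is a plain sum rather than a succinct bit vector.

```python
def split_level(seq):
    """One wavelet-tree level: the bitmap plus the two child sequences."""
    alphabet = sorted(set(seq))
    left_half = set(alphabet[:len(alphabet) // 2])
    bitmap = [0 if x in left_half else 1 for x in seq]
    left = [x for x in seq if x in left_half]        # original order preserved
    right = [x for x in seq if x not in left_half]
    return bitmap, left, right


def child_index(bitmap, i):
    """Map position i of this node to (side, index) in the child holding seq[i]."""
    ones = sum(bitmap[:i + 1])      # rank1(i), inclusive
    if bitmap[i] == 0:
        return "left", i - ones     # (number of 0s in bitmap[0..i]) - 1
    return "right", ones - 1        # (number of 1s in bitmap[0..i]) - 1


seq = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
bitmap, left, right = split_level(seq)
print(bitmap)                   # → [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
print(child_index(bitmap, 4))   # → ('right', 1)
```

Applying child_index repeatedly until a single-symbol node is reached is exactly access(i); rank(c, i) instead follows whichever side contains c.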

class WaveletTree:
    """
    Wavelet tree for sequence queries.
    
    Supports access/rank on general sequences (not just bits).
    Space: about n log σ bits of bitmaps (plus each bitmap's rank index)
    Time: O(log σ) bitmap rank calls per operation
    """
    
    class Node:
        """Wavelet tree node."""
        def __init__(self, alphabet, sequence=None):
            self.alphabet = sorted(set(alphabet))
            self.bitmap = None
            self.left = None
            self.right = None
            
            if sequence and len(self.alphabet) > 1:
                self._build(sequence)
        
        def _build(self, sequence):
            """Build node recursively."""
            if len(self.alphabet) == 1:
                return
            
            # Split alphabet
            mid = len(self.alphabet) // 2
            left_alphabet = self.alphabet[:mid]
            right_alphabet = self.alphabet[mid:]
            
            # Create bitmap: 0 if element in left half, 1 if in right half
            self.bitmap = SuccinctBitVector([
                0 if elem in left_alphabet else 1
                for elem in sequence
            ])
            
            # Partition sequence
            left_seq = [e for e in sequence if e in left_alphabet]
            right_seq = [e for e in sequence if e in right_alphabet]
            
            # Build children
            if left_seq:
                self.left = WaveletTree.Node(left_alphabet, left_seq)
            if right_seq:
                self.right = WaveletTree.Node(right_alphabet, right_seq)
    
    def __init__(self, sequence):
        """Build wavelet tree from sequence."""
        self.sequence = list(sequence)  # plain copy, kept only for __str__
        self.n = len(self.sequence)
        alphabet = sorted(set(self.sequence))
        self.root = self.Node(alphabet, self.sequence)
    
    def access(self, i):
        """
        Get element at position i.
        
        Time: O(log σ)
        """
        node = self.root
        
        while node and len(node.alphabet) > 1:
            bit = node.bitmap.access(i)
            ones = node.bitmap.rank(i)  # 1s in bitmap[0..i]
            if bit == 0:
                # Go left: new index = (number of 0s in bitmap[0..i]) - 1
                i = i - ones
                node = node.left
            else:
                # Go right: new index = (number of 1s in bitmap[0..i]) - 1
                i = ones - 1
                node = node.right
        
        return node.alphabet[0] if node else None
    
    def rank(self, char, i):
        """
        Count occurrences of char in sequence[0:i+1].
        
        Time: O(log σ)
        """
        if char not in self.root.alphabet or i < 0:
            return 0
        i = min(i, self.n - 1)
        node = self.root
        
        while node and len(node.alphabet) > 1:
            mid = len(node.alphabet) // 2
            ones = node.bitmap.rank(i)  # 1s in bitmap[0..i]
            
            if char in node.alphabet[:mid]:
                # Go left: index of the last 0 at or before i
                i = i - ones
                node = node.left
            else:
                # Go right: index of the last 1 at or before i
                i = ones - 1
                node = node.right
            
            if i < 0:
                return 0  # nothing from char's half occurs up to position i
        
        return i + 1
    
    def __str__(self):
        """String representation."""
        return f"WaveletTree({self.sequence})"

def example_wavelet_tree():
    """Demonstrate wavelet tree."""
    print("\n=== Wavelet Tree ===\n")
    
    sequence = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
    print(f"Sequence: {sequence}\n")
    
    wt = WaveletTree(sequence)
    
    # Test access
    print("Access operations:")
    for i in [0, 3, 7]:
        result = wt.access(i)
        expected = sequence[i]
        print(f"  access({i}) = {result} (expected {expected}) {'✓' if result == expected else '✗'}")
    
    # Test rank
    print("\nRank operations:")
    for char in [1, 3, 5]:
        for i in [4, 7, 9]:
            result = wt.rank(char, i)
            expected = sequence[:i+1].count(char)
            print(f"  rank({char}, {i}) = {result} (expected {expected}) {'✓' if result == expected else '✗'}")

if __name__ == "__main__":
    example_wavelet_tree()

20.5.4 12.5.4 Applications of Succinct Structures
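The SuccinctGraph class below groups destinations by source and records where each vertex's edges start; that offsets array is the usual compressed sparse row (CSR) layout. A sketch of building it with counts and prefix sums, plus an O(log deg) edge test by binary search inside one vertex's sorted neighbor range, follows. `build_csr`, `neighbors`, and `has_edge` are illustrative names; `bisect_left` with `lo`/`hi` bounds is the standard-library function.

```python
from bisect import bisect_left


def build_csr(edges, n_vertices):
    """Return (offsets, destinations) for a directed graph, neighbors sorted."""
    edges = sorted(edges)                    # by source, then destination
    offsets = [0] * (n_vertices + 1)
    for src, _ in edges:
        offsets[src + 1] += 1                # out-degree counts, shifted by one
    for v in range(n_vertices):
        offsets[v + 1] += offsets[v]         # prefix sums -> start positions
    destinations = [dst for _, dst in edges]
    return offsets, destinations


def neighbors(offsets, destinations, v):
    return destinations[offsets[v]:offsets[v + 1]]


def has_edge(offsets, destinations, u, v):
    """Binary search inside u's sorted neighbor range."""
    lo, hi = offsets[u], offsets[u + 1]
    k = bisect_left(destinations, v, lo, hi)
    return k < hi and destinations[k] == v


edges = [(0, 2), (2, 1), (0, 1), (3, 0), (2, 3)]
offsets, dests = build_csr(edges, 4)
print(offsets)                       # → [0, 2, 2, 4, 5]
print(neighbors(offsets, dests, 2))  # → [1, 3]
```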

class SuccinctApplications:
    """Real-world applications of succinct data structures."""
    
    @staticmethod
    def dna_sequence_index():
        """
        Index DNA sequences succinctly.
        DNA has 4-letter alphabet {A, C, G, T}.
        Using wavelet tree: 2 bits per base + overhead.
        """
        class DNAIndex:
            def __init__(self, sequence):
                self.wt = WaveletTree(list(sequence))
                self.sequence = sequence
            
            def count_base(self, base, start, end):
                """Count occurrences of base in range."""
                if start > 0:
                    return self.wt.rank(base, end) - self.wt.rank(base, start - 1)
                return self.wt.rank(base, end)
            
            def gc_content(self, start, end):
                """Calculate GC content in range."""
                g_count = self.count_base('G', start, end)
                c_count = self.count_base('C', start, end)
                total = end - start + 1
                return (g_count + c_count) / total * 100
        
        return DNAIndex
    
    @staticmethod
    def compressed_graph():
        """
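        Represent graphs compactly.
        Specialized web-graph compressors get down to a few bits per edge;
        this sketch only shows the sorted, offset-indexed layout they start from.
        """
        class SuccinctGraph:
            def __init__(self, edges, n_vertices):
                """
                Store graph as the sequence of edge destinations, grouped by source.
                The wavelet tree over the destinations supports rank-style queries
                (e.g. wt.rank(v, len(self.destinations) - 1) is v's in-degree);
                neighbors() below reads the destination list directly.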
                """
                self.n = n_vertices
                
                # Sort edges by source
                edges_sorted = sorted(edges)
                
                # Store destinations
                self.destinations = [dst for src, dst in edges_sorted]
                self.wt = WaveletTree(self.destinations)
                
                # Store offsets (where each vertex's edges start)
                self.offsets = [0]
                current_src = 0
                for i, (src, dst) in enumerate(edges_sorted):
                    while current_src < src:
                        self.offsets.append(i)
                        current_src += 1
                while current_src < n_vertices:
                    self.offsets.append(len(edges_sorted))
                    current_src += 1
            
            def neighbors(self, v):
                """Get neighbors of vertex v."""
                start = self.offsets[v]
                end = self.offsets[v + 1] - 1
                return [self.destinations[i] for i in range(start, end + 1)]
            
            def has_edge(self, u, v):
                """Check if edge (u, v) exists."""
                return v in self.neighbors(u)
        
        return SuccinctGraph

def example_dna_index():
    """Demonstrate DNA sequence indexing."""
    print("\n=== DNA Sequence Index ===\n")
    
    DNAIndex = SuccinctApplications.dna_sequence_index()
    
    # Sample DNA sequence
    dna = "ACGTACGTTAGCTAGCTAGCTAGCTACGTACGTT"
    print(f"DNA sequence: {dna}")
    print(f"Length: {len(dna)} bases\n")
    
    index = DNAIndex(dna)
    
    # Query base counts
    print("Base counts in range [10, 20]:")
    for base in ['A', 'C', 'G', 'T']:
        count = index.count_base(base, 10, 20)
        print(f"  {base}: {count}")
    
    # Calculate GC content
    gc = index.gc_content(0, len(dna) - 1)
    print(f"\nOverall GC content: {gc:.1f}%")
    
    # Space analysis
    print("\n" + "=" * 50)
    print("Space analysis:")
    print(f"  Original: {len(dna)} characters = {len(dna) * 8} bits")
    print(f"  Succinct: ~{len(dna) * 2} bits (2 bits per base)")
    print(f"  Compression: {len(dna) * 8 / (len(dna) * 2):.1f}x")

if __name__ == "__main__":
    example_dna_index()
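The "2 bits per base" figure is easy to check directly: with a four-letter alphabet each base fits in two bits. The sketch below packs a DNA string into one Python integer and reads bases back; the A=0, C=1, G=2, T=3 code is an arbitrary choice for illustration, and this is plain bit packing, not a wavelet tree (so it gives access but not rank).

```python
CODE = {"A": 0, "C": 1, "G": 2, "T": 3}
BASES = "ACGT"


def pack(dna):
    """Pack a DNA string into an int, 2 bits per base (base i at bits 2i, 2i+1)."""
    packed = 0
    for i, base in enumerate(dna):
        packed |= CODE[base] << (2 * i)
    return packed


def base_at(packed, i):
    """Random access to one base without unpacking the rest."""
    return BASES[(packed >> (2 * i)) & 0b11]


def unpack(packed, length):
    """Inverse of pack; length is needed because trailing 'A's are zero bits."""
    return "".join(base_at(packed, i) for i in range(length))


dna = "ACGTACGTTAGCTAGCTAGCTAGCTACGTACGTT"
packed = pack(dna)
print(packed.bit_length() <= 2 * len(dna))  # → True
print(base_at(packed, 8))                   # → T
```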

20.6 12.6 Cache-Oblivious Algorithms: Automatically Efficient

20.6.1 12.6.1 The Cache Problem

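Modern CPUs have multiple cache levels (L1, L2, L3). An L1 hit takes a few cycles, while a trip to main memory takes on the order of 100 ns, roughly two orders of magnitude slower!

Traditional approach: tune an algorithm (block or tile sizes) for one specific cache size → the tuning does not carry over to machines with different caches!

Cache-oblivious approach: design algorithms that never mention the cache size or line size, yet incur an asymptotically optimal number of cache misses for every cache size at once (in the ideal-cache model)!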

20.6.2 12.6.2 Cache-Oblivious Matrix Transpose

import numpy as np

class CacheObliviousAlgorithms:
    """
    Cache-oblivious algorithms that adapt to any cache size.
    """
    
    @staticmethod
    def matrix_transpose_naive(A):
        """
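        Naive matrix transpose.
        Poor cache performance: each write B[j, i] walks down a column
        of B, landing on a different cache line every time.
        
        Time: O(nm)
        Cache misses: up to one per write, i.e. O(nm), once the matrix is
        larger than the cache (vs. the optimal O(nm / L), L = elements per line)
        """
        n, m = A.shape
        B = np.zeros((m, n), dtype=A.dtype)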
        
        for i in range(n):
            for j in range(m):
                B[j, i] = A[i, j]  # Bad: B accessed column-wise!
        
        return B
    
    @staticmethod
    def matrix_transpose_cache_oblivious(A, B=None, i0=0, j0=0, n=None, m=None):
        """
        Cache-oblivious matrix transpose using divide-and-conquer.
        
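        Key idea: recursively halve the larger dimension. At some depth the
        subproblem fits in cache, whatever the cache size, so the code never
        needs to know it.
        
        Time: O(nm)
        Cache misses: O(1 + nm / L), where L = elements per cache line
        (optimal; assumes a "tall" cache of M ≥ L² elements)
        """
        if B is None:
            n, m = A.shape
            B = np.zeros((m, n), dtype=A.dtype)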
            return CacheObliviousAlgorithms.matrix_transpose_cache_oblivious(
                A, B, 0, 0, n, m
            )
        
        if n is None:
            n, m = A.shape
        
        # Base case: the cutoff only amortizes recursion overhead; unlike a
        # blocked algorithm, it does not need to match any cache size
        if n * m <= 64:
            for i in range(n):
                for j in range(m):
                    B[j0 + j, i0 + i] = A[i0 + i, j0 + j]
            return B
        
        # Divide: split along larger dimension
        if n >= m:
            # Split rows
            mid = n // 2
            CacheObliviousAlgorithms.matrix_transpose_cache_oblivious(
                A, B, i0, j0, mid, m
            )
            CacheObliviousAlgorithms.matrix_transpose_cache_oblivious(
                A, B, i0 + mid, j0, n - mid, m
            )
        else:
            # Split columns
            mid = m // 2
            CacheObliviousAlgorithms.matrix_transpose_cache_oblivious(
                A, B, i0, j0, n, mid
            )
            CacheObliviousAlgorithms.matrix_transpose_cache_oblivious(
                A, B, i0, j0 + mid, n, m - mid
            )
        
        return B
    
    @staticmethod
    def matrix_multiply_cache_oblivious(A, B, C=None, i0=0, j0=0, k0=0, n=None, m=None, p=None):
        """
        Cache-oblivious matrix multiplication.
        
        Multiplies A (n×m) by B (m×p) → C (n×p)
        
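        Time: O(nmp)
        Cache misses: O(n + m + p + (nm + mp + np) / L + nmp / (L√M)),
        where L = elements per cache line and M = cache size; optimal for
        the standard triple-loop algorithm (assuming a tall cache)
        """
        if C is None:
            n, m = A.shape
            m2, p = B.shape
            assert m == m2, "Incompatible dimensions"
            C = np.zeros((n, p), dtype=np.result_type(A, B))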
            return CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, 0, 0, 0, n, m, p
            )
        
        if n is None:
            n, m = A.shape
            m2, p = B.shape
        
        # Base case
        if n * m * p <= 64:
            for i in range(n):
                for j in range(p):
                    for k in range(m):
                        C[i0 + i, j0 + j] += A[i0 + i, k0 + k] * B[k0 + k, j0 + j]
            return C
        
        # Divide along largest dimension
        if n >= m and n >= p:
            mid = n // 2
            CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, i0, j0, k0, mid, m, p
            )
            CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, i0 + mid, j0, k0, n - mid, m, p
            )
        elif m >= n and m >= p:
            mid = m // 2
            CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, i0, j0, k0, n, mid, p
            )
            CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, i0, j0, k0 + mid, n, m - mid, p
            )
        else:
            mid = p // 2
            CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, i0, j0, k0, n, m, mid
            )
            CacheObliviousAlgorithms.matrix_multiply_cache_oblivious(
                A, B, C, i0, j0 + mid, k0, n, m, p - mid
            )
        
        return C

def benchmark_cache_oblivious():
    """Benchmark cache-oblivious algorithms."""
    import time
    
    print("\n=== Cache-Oblivious Algorithms ===\n")
    
    sizes = [128, 256, 512, 1024]
    
    print("Matrix Transpose Benchmark:")
    print(f"{'Size':>6} {'Naive (ms)':>12} {'Cache-Oblivious (ms)':>22} {'Speedup':>10}")
    print("-" * 62)
    
    for n in sizes:
        A = np.random.randn(n, n)
        
        # Naive (perf_counter is the right clock for short intervals)
        start = time.perf_counter()
        B1 = CacheObliviousAlgorithms.matrix_transpose_naive(A)
        naive_time = (time.perf_counter() - start) * 1000
        
        # Cache-oblivious
        start = time.perf_counter()
        B2 = CacheObliviousAlgorithms.matrix_transpose_cache_oblivious(A)
        co_time = (time.perf_counter() - start) * 1000
        
        # Verify correctness
        assert np.allclose(B1, B2), "Results don't match!"
        
        speedup = naive_time / co_time
        print(f"{n:6d} {naive_time:12.2f} {co_time:22.2f} {speedup:10.2f}x")
    
    print("\n" + "=" * 62)
    print("Caveat: in pure Python, per-element interpreter overhead dominates,")
    print("so these timings say little about cache behavior. The recursive")
    print("version's miss savings show up in compiled code on large matrices.")

if __name__ == "__main__":
    benchmark_cache_oblivious()
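Because Python timings mostly measure the interpreter, the cache effect is easier to see by counting misses directly. The sketch below replays the memory addresses both transposes touch through a small simulated cache. It assumes a fully associative LRU cache (`lines` lines of `line_words` words each, a simplification of real hardware), row-major n×n matrices with A at addresses [0, n²) and B right after it, and a recursive cutoff of 16 elements; `LRUCache`, `naive_trace`, `recursive_trace`, and `count_misses` are hypothetical helpers for this illustration.

```python
from collections import OrderedDict


class LRUCache:
    """Fully associative LRU cache that counts misses (simplified model)."""

    def __init__(self, lines, line_words):
        self.lines = lines
        self.line_words = line_words
        self.resident = OrderedDict()
        self.misses = 0

    def touch(self, address):
        line = address // self.line_words
        if line in self.resident:
            self.resident.move_to_end(line)        # mark most recently used
        else:
            self.misses += 1
            self.resident[line] = True
            if len(self.resident) > self.lines:
                self.resident.popitem(last=False)  # evict least recently used


def naive_trace(n):
    """Addresses touched by the row-by-row transpose B[j][i] = A[i][j]."""
    for i in range(n):
        for j in range(n):
            yield i * n + j              # read A[i][j]
            yield n * n + j * n + i      # write B[j][i]


def recursive_trace(n, cutoff=16):
    """Addresses touched by the divide-and-conquer transpose."""
    def rec(i0, j0, rows, cols):
        if rows * cols <= cutoff:
            for i in range(i0, i0 + rows):
                for j in range(j0, j0 + cols):
                    yield i * n + j
                    yield n * n + j * n + i
        elif rows >= cols:
            half = rows // 2
            yield from rec(i0, j0, half, cols)
            yield from rec(i0 + half, j0, rows - half, cols)
        else:
            half = cols // 2
            yield from rec(i0, j0, rows, half)
            yield from rec(i0, j0 + half, rows, cols - half)
    yield from rec(0, 0, n, n)


def count_misses(trace, lines=32, line_words=8):
    cache = LRUCache(lines, line_words)
    for address in trace:
        cache.touch(address)
    return cache.misses


n = 64
print("naive:    ", count_misses(naive_trace(n)))
print("recursive:", count_misses(recursive_trace(n)))
```

With these parameters the naive order misses on every write to B (4096 misses, plus 512 for reading A), while the recursive order reaches the compulsory minimum of one miss per cache line of A and B, 2 · 64² / 8 = 1024.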

20.6.3 12.6.3 Van Emde Boas Layout

class VanEmdeBoasLayout:
    """
    Van Emde Boas layout for cache-oblivious binary trees.
    
    Instead of storing tree level-by-level or in-order,
    recursively divide into subtrees.
    """
    
    @staticmethod
    def build_layout(n):
        """
        Build VEB layout for a complete binary tree with n = 2^h - 1 nodes.
        
        Nodes are named by their 0-based level-order (BFS) index, so node k
        has children 2k + 1 and 2k + 2.
        
        Returns: list where layout[i] = node stored at position i
        """
        h = (n + 1).bit_length() - 1
        if n <= 0 or (1 << h) - 1 != n:
            raise ValueError("n must be 2^h - 1 (a complete binary tree)")
        layout = []
        VanEmdeBoasLayout._build_recursive(layout, 1, h)
        return [k - 1 for k in layout]  # 1-based heap ids -> 0-based
    
    @staticmethod
    def _build_recursive(layout, root, height):
        """
        Append the subtree of the given height rooted at `root`
        (1-based heap index, children 2*root and 2*root + 1) in VEB order.
        """
        if height == 1:
            layout.append(root)
            return
        
        # Cut at half height: one top tree and 2^top_h bottom trees
        top_h = height // 2
        bottom_h = height - top_h
        
        # Store the top tree first, recursively
        VanEmdeBoasLayout._build_recursive(layout, root, top_h)
        
        # Then each bottom tree contiguously, left to right; their roots
        # are root's descendants top_h levels down: root * 2^top_h + j
        for j in range(1 << top_h):
            VanEmdeBoasLayout._build_recursive(
                layout, (root << top_h) + j, bottom_h
            )
    
    @staticmethod
    def visualize_layout(n):
        """Visualize VEB layout."""
        layout = VanEmdeBoasLayout.build_layout(n)
        
        print(f"\nVan Emde Boas Layout for {n} nodes:")
        print(f"Storage order: {layout}")
        
        # Compare with level-order
        level_order = list(range(n))
        print(f"Level order:   {level_order}")
        
        print("\nAdvantage: Subtrees stored contiguously → better cache performance!")

if __name__ == "__main__":
    VanEmdeBoasLayout.visualize_layout(15)
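The layout's payoff can be measured without hardware: count how many distinct blocks of consecutive array slots a root-to-leaf search touches. The sketch below assumes 1-based heap numbering (children of k are 2k and 2k+1), a perfect tree of height 8 (255 nodes), and blocks of 15 slots standing in for cache lines; `veb_order` and `max_blocks_on_path` are hypothetical helpers, and `veb_order` cuts at half height with the smaller half on top, which is one common convention.

```python
def veb_order(height):
    """Heap indices (1-based) of a perfect binary tree of this height, in vEB order."""
    order = []

    def rec(root, h):
        if h == 1:
            order.append(root)
            return
        top_h = h // 2
        bottom_h = h - top_h
        rec(root, top_h)                       # top tree first
        for j in range(1 << top_h):            # then each bottom tree
            rec((root << top_h) + j, bottom_h)

    rec(1, height)
    return order


def max_blocks_on_path(position, height, block):
    """Worst-case number of distinct blocks touched on a root-to-leaf path."""
    worst = 0
    for leaf in range(1 << (height - 1), 1 << height):
        node, blocks = leaf, set()
        while node >= 1:
            blocks.add(position[node] // block)
            node //= 2
        worst = max(worst, len(blocks))
    return worst


height, block = 8, 15
n = (1 << height) - 1
veb_position = {node: pos for pos, node in enumerate(veb_order(height))}
level_position = {node: node - 1 for node in range(1, n + 1)}
print("level order:", max_blocks_on_path(level_position, height, block))  # → 5
print("vEB order:  ", max_blocks_on_path(veb_position, height, block))    # → 2
```

In level order only the top four levels share a block, so the last four levels each cost a new block; in vEB order the top tree and every bottom tree occupy exactly one 15-slot block each.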

20.7 12.7 Chapter Project: Advanced Data Structure Library

Let’s build a comprehensive library!

20.7.1 12.7.1 Project Structure

AdvancedDataStructures/
├── advds/
│   ├── __init__.py
│   ├── trees/
│   │   ├── __init__.py
│   │   ├── segment_tree.py
│   │   ├── fenwick_tree.py
│   │   └── persistent_tree.py
│   ├── succinct/
│   │   ├── __init__.py
│   │   ├── bit_vector.py
│   │   ├── wavelet_tree.py
│   │   └── rank_select.py
│   ├── cache_oblivious/
│   │   ├── __init__.py
│   │   ├── matrix_ops.py
│   │   └── layouts.py
│   ├── applications/
│   │   ├── __init__.py
│   │   ├── range_queries.py
│   │   ├── version_control.py
│   │   └── dna_index.py
│   └── benchmarks/
│       ├── __init__.py
│       └── performance.py
├── tests/
├── examples/
├── docs/
└── setup.py

20.7.2 12.7.2 Unified Interface

# advds/__init__.py
"""
Advanced Data Structures Library

Provides:
- Segment Trees: Range queries with updates
- Fenwick Trees: Efficient prefix sums
- Persistent Structures: Time-travel data structures
- Succinct Structures: Space-efficient representations
- Cache-Oblivious: Automatically cache-efficient algorithms
"""

__version__ = "1.0.0"

from .trees import SegmentTree, LazySegmentTree, FenwickTree, FenwickTree2D
from .trees import PersistentArray, PersistentSegmentTree
from .succinct import SuccinctBitVector, WaveletTree
from .cache_oblivious import CacheObliviousOps
from .applications import RangeQuerySolver, VersionControl, DNAIndex

__all__ = [
    # Range query structures
    'SegmentTree',
    'LazySegmentTree',
    'FenwickTree',
    'FenwickTree2D',
    
    # Persistent structures
    'PersistentArray',
    'PersistentSegmentTree',
    
    # Succinct structures
    'SuccinctBitVector',
    'WaveletTree',
    
    # Cache-oblivious
    'CacheObliviousOps',
    
    # Applications
    'RangeQuerySolver',
    'VersionControl',
    'DNAIndex',
]

20.7.3 12.7.3 Range Query Solver Application

# advds/applications/range_queries.py
"""
Unified interface for range query problems.
"""

from ..trees import SegmentTree, FenwickTree, LazySegmentTree

class RangeQuerySolver:
    """
    High-level interface for range query problems.
    Picks a backend from the requested query and update types.
    """
    
    def __init__(self, arr, query_type='sum', update_type='point'):
        """
        Initialize range query solver.
        
        Args:
            arr: Initial array
            query_type: 'sum', 'min', 'max', 'gcd'
            update_type: 'point' or 'range'
        """
        self.arr = arr
        self.n = len(arr)
        self.query_type = query_type
        self.update_type = update_type
        
        # Choose data structure
        if update_type == 'point' and query_type == 'sum':
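            # Fenwick tree: the smallest and simplest option for point updates + sums
            self.ds = FenwickTree.from_array(arr)
            self.backend = 'fenwick'
        elif update_type == 'range':
            # Range updates need lazy propagation. This assumes LazySegmentTree
            # supports range-add with sum queries; other query types would
            # need a lazy tree built for that operation.
            self.ds = LazySegmentTree(arr)
            self.backend = 'lazy_segment'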
        else:
            # General segment tree
            self.ds = SegmentTree(arr, operation=query_type)
            self.backend = 'segment'
    
    def query(self, left, right):
        """Query range [left, right]."""
        if self.backend == 'fenwick':
            return self.ds.range_sum(left, right)
        else:
            if self.backend == 'lazy_segment':
                return self.ds.query_range(left, right)
            return self.ds.query(left, right)
    
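    def update(self, index=None, value=None, left=None, right=None):
        """
        Update element or range.
        
        Point update: update(index, value)                 sets arr[index] = value
        Range update: update(left=L, right=R, value=delta)  adds delta to arr[L..R]
        """
        if left is not None and right is not None:
            # Range update
            if self.backend != 'lazy_segment':
                raise ValueError("Range updates require LazySegmentTree")
            self.ds.update_range(left, right, value)
        else:
            if index is None or value is None:
                raise ValueError("Point update needs both index and value")
            # Point update (assignment semantics)
            if self.backend == 'fenwick':
                # Fenwick trees store deltas, so convert the assignment to an add
                current = self.ds.range_sum(index, index)
                self.ds.update(index, value - current)
            elif self.backend == 'lazy_segment':
                # Express the assignment as a range-add over [index, index]
                current = self.ds.query_range(index, index)
                self.ds.update_range(index, index, value - current)
            else:
                self.ds.update(index, value)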
    
    def get_structure_info(self):
        """Get information about chosen data structure."""
        info = {
            'backend': self.backend,
            'size': self.n,
            'query_type': self.query_type,
            'update_type': self.update_type
        }
        
        if self.backend == 'fenwick':
            info['space'] = f"O(n): about {self.n} elements"
            info['query_time'] = "O(log n)"
            info['update_time'] = "O(log n)"
        elif self.backend == 'segment':
            info['space'] = f"O(n): up to {4 * self.n} nodes"
            info['query_time'] = "O(log n)"
            info['update_time'] = "O(log n)"
        else:  # lazy_segment
            info['space'] = f"O(n): up to {4 * self.n} nodes plus lazy tags"
            info['query_time'] = "O(log n)"
            info['update_time'] = "O(log n) per range"
        
        return info

def example_range_query_solver():
    """Demonstrate automatic data structure selection."""
    print("=== Range Query Solver ===\n")
    
    arr = [1, 3, 5, 7, 9, 11, 13, 15]
    
    # Scenario 1: Point updates, sum queries
    print("Scenario 1: Point updates + sum queries")
    solver1 = RangeQuerySolver(arr, query_type='sum', update_type='point')
    info1 = solver1.get_structure_info()
    print(f"  Chosen: {info1['backend']}")
    print(f"  Query time: {info1['query_time']}")
    print(f"  Example: sum(2, 5) = {solver1.query(2, 5)}\n")
    
    # Scenario 2: Point updates, min queries
    print("Scenario 2: Point updates + min queries")
    solver2 = RangeQuerySolver(arr, query_type='min', update_type='point')
    info2 = solver2.get_structure_info()
    print(f"  Chosen: {info2['backend']}")
    print(f"  Query time: {info2['query_time']}")
    print(f"  Example: min(2, 5) = {solver2.query(2, 5)}\n")
    
    # Scenario 3: Range updates, sum queries
    print("Scenario 3: Range updates + sum queries")
    solver3 = RangeQuerySolver(arr, query_type='sum', update_type='range')
    info3 = solver3.get_structure_info()
    print(f"  Chosen: {info3['backend']}")
    print(f"  Update time: {info3['update_time']}")
    solver3.update(left=2, right=5, value=10)
    print(f"  After adding 10 to [2,5]: sum(2,5) = {solver3.query(2, 5)}")

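if __name__ == "__main__":
    # This module uses relative imports, so run it as part of the package:
    #   python -m advds.applications.range_queries
    example_range_query_solver()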
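Because the solver turns an assignment into a Fenwick delta (`value - current`), a quick differential test against a plain list is a cheap way to gain confidence in that conversion. The sketch is self-contained: `MiniFenwick` and `differential_test` are hypothetical stand-ins that mirror only the calls the solver makes (`from_array`, `update(i, delta)`, inclusive `range_sum(l, r)`); they are not the chapter's `FenwickTree`.

```python
import random


class MiniFenwick:
    """Minimal 0-indexed Fenwick tree (1-indexed internally)."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)

    @classmethod
    def from_array(cls, arr):
        """O(n) build: each node pushes its finished partial sum to its parent."""
        ft = cls(len(arr))
        for i, x in enumerate(arr, start=1):
            ft.tree[i] += x
            parent = i + (i & -i)
            if parent <= ft.n:
                ft.tree[parent] += ft.tree[i]
        return ft

    def update(self, i, delta):
        """Add delta to arr[i]."""
        i += 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i            # move to the next node covering i

    def prefix_sum(self, i):
        """Sum of arr[0..i-1]."""
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i            # drop the lowest set bit
        return s

    def range_sum(self, l, r):
        """Sum of arr[l..r], inclusive."""
        return self.prefix_sum(r + 1) - self.prefix_sum(l)


def differential_test(n=200, ops=2000, seed=0):
    """Check assignment-via-delta updates against a plain Python list."""
    rng = random.Random(seed)
    ref = [rng.randint(1, 100) for _ in range(n)]
    ft = MiniFenwick.from_array(ref)
    for _ in range(ops):
        if rng.random() < 0.5:
            i, v = rng.randrange(n), rng.randint(1, 100)
            # Assignment arr[i] = v becomes an add of (v - current)
            ft.update(i, v - ft.range_sum(i, i))
            ref[i] = v
        else:
            l, r = sorted((rng.randrange(n), rng.randrange(n)))
            assert ft.range_sum(l, r) == sum(ref[l:r + 1])
    return True


if __name__ == "__main__":
    print("Fenwick matches brute force:", differential_test())
```

The linear-time `from_array` works because, by the time the loop reaches node `i`, all of its children (which have smaller indices) have already added into it, so its partial sum can be pushed once to its parent at `i + (i & -i)`.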

20.7.4 12.7.4 Comprehensive Benchmarking Suite

# advds/benchmarks/performance.py
"""
Benchmarks for the range-query structures (segment tree vs. Fenwick tree).
"""

import time
import numpy as np
import matplotlib.pyplot as plt
from ..trees import SegmentTree, FenwickTree

class Benchmarks:
    """Performance benchmarking suite."""
    
    @staticmethod
    def benchmark_range_queries():
        """Benchmark range query structures."""
        print("=== Range Query Benchmarks ===\n")
        
        sizes = [1000, 5000, 10000, 50000, 100000]
        
        results = {
            'segment_tree': {'build': [], 'query': [], 'update': []},
            'fenwick_tree': {'build': [], 'query': [], 'update': []}
        }
        
        for n in sizes:
            arr = np.random.randint(1, 100, n).tolist()
            num_ops = 1000  # fixed count, so every size matches the plot titles
            
            print(f"Size: {n:,}")
            
            # One shared workload for both structures: (l, r) pairs with l <= r,
            # and (index, new_value) assignments
            queries = [tuple(sorted(np.random.randint(0, n, 2).tolist()))
                       for _ in range(num_ops)]
            updates = [(int(np.random.randint(0, n)), int(np.random.randint(1, 100)))
                       for _ in range(num_ops)]
            
            # Segment Tree (perf_counter is a monotonic, high-resolution clock)
            start = time.perf_counter()
            st = SegmentTree(arr, operation='sum')
            build_time = time.perf_counter() - start
            results['segment_tree']['build'].append(build_time * 1000)
            
            # Query benchmark
            start = time.perf_counter()
            for l, r in queries:
                _ = st.query(l, r)
            query_time = time.perf_counter() - start
            results['segment_tree']['query'].append(query_time * 1000)
            
            # Update benchmark
            start = time.perf_counter()
            for idx, val in updates:
                st.update(idx, val)
            update_time = time.perf_counter() - start
            results['segment_tree']['update'].append(update_time * 1000)
            
            print(f"  Segment Tree: build={build_time*1000:.2f}ms, "
                  f"query={query_time*1000:.2f}ms, update={update_time*1000:.2f}ms")
            
            # Fenwick Tree
            start = time.perf_counter()
            ft = FenwickTree.from_array(arr)
            build_time = time.perf_counter() - start
            results['fenwick_tree']['build'].append(build_time * 1000)
            
            start = time.perf_counter()
            for l, r in queries:
                _ = ft.range_sum(l, r)
            query_time = time.perf_counter() - start
            results['fenwick_tree']['query'].append(query_time * 1000)
            
            start = time.perf_counter()
            for idx, val in updates:
                # Fenwick trees add deltas, so an assignment costs an extra query
                current = ft.range_sum(idx, idx)
                ft.update(idx, val - current)
            update_time = time.perf_counter() - start
            results['fenwick_tree']['update'].append(update_time * 1000)
            
            print(f"  Fenwick Tree:  build={build_time*1000:.2f}ms, "
                  f"query={query_time*1000:.2f}ms, update={update_time*1000:.2f}ms\n")
        
        return results, sizes
    
    @staticmethod
    def plot_results(results, sizes):
        """Plot benchmark results."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        operations = ['build', 'query', 'update']
        titles = ['Build Time', 'Query Time (1000 ops)', 'Update Time (1000 ops)']
        
        for idx, (op, title) in enumerate(zip(operations, titles)):
            ax = axes[idx]
            
            for ds_name, ds_results in results.items():
                ax.plot(sizes, ds_results[op], marker='o', label=ds_name)
            
            ax.set_xlabel('Array Size')
            ax.set_ylabel('Time (ms)')
            ax.set_title(title)
            ax.legend()
            ax.grid(True)
        
        plt.tight_layout()
        plt.savefig('range_query_benchmark.png', dpi=150)
        plt.close()
        
        print("✓ Benchmark plot saved to 'range_query_benchmark.png'")
    
    @staticmethod
    def run_all():
        """Run all benchmarks."""
        results, sizes = Benchmarks.benchmark_range_queries()
        Benchmarks.plot_results(results, sizes)

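if __name__ == "__main__":
    # This module uses relative imports, so run it as part of the package:
    #   python -m advds.benchmarks.performance
    Benchmarks.run_all()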
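The suite above takes a single clock reading per phase, which is sensitive to OS scheduling and garbage-collection pauses. A common refinement is to repeat each measurement and keep the fastest run. The sketch below uses only the standard library's `timeit.repeat`; `best_time_per_op` is a hypothetical helper, and the prefix-sum comparison is just a stand-in workload, not part of the library.

```python
import random
import timeit


def best_time_per_op(fn, ops, repeat=5):
    """Call fn(*op) for every op in `ops`, `repeat` times over, and return
    the fastest run's average seconds per operation."""
    def run():
        for op in ops:
            fn(*op)
    totals = timeit.repeat(run, repeat=repeat, number=1)
    return min(totals) / len(ops)


if __name__ == "__main__":
    n = 100_000
    arr = [random.randint(1, 100) for _ in range(n)]
    prefix = [0]
    for x in arr:
        prefix.append(prefix[-1] + x)      # static prefix sums
    queries = [tuple(sorted(random.sample(range(n), 2))) for _ in range(1000)]

    naive = best_time_per_op(lambda l, r: sum(arr[l:r + 1]), queries)
    fast = best_time_per_op(lambda l, r: prefix[r + 1] - prefix[l], queries)
    print(f"naive slice sum: {naive * 1e6:.2f} us/query")
    print(f"prefix sums:     {fast * 1e6:.2f} us/query")
```

Taking the minimum rather than the mean is deliberate: background noise only ever adds time, so the fastest run is the closest estimate of the code's own cost. Reporting time per operation also keeps results comparable when runs perform different numbers of operations.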

20.8 12.8 Summary and Key Takeaways

Core Data Structures:

  1. Segment Trees: O(log n) range queries with updates; very flexible
  2. Fenwick Trees: Simpler, usually faster in practice, and lighter on memory, but less flexible
  3. Persistent Structures: Time travel at O(log n) extra space per update
  4. Succinct Structures: n + o(n) bits of space with O(1) or O(log σ) operations
  5. Cache-Oblivious: Efficient at every cache size without being tuned to it

When to Use What:

  • Simple range sums: Fenwick tree (simplest, fastest)
  • Complex range queries (min, max, GCD): Segment tree
  • Range updates: Lazy segment tree
  • Need history/undo: Persistent structures
  • Massive data, tight memory: Succinct structures
  • Unknown cache sizes: Cache-oblivious algorithms

Key Insights:

  • Bit manipulation enables elegant solutions (Fenwick tree)
  • Path copying makes persistence cheap
  • Divide-and-conquer adapts to cache automatically
  • Auxiliary structures trade space for query speed
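The path-copying insight fits in a few lines. This is a minimal sketch assuming a persistent array stored as a balanced binary tree with values in the leaves; `Node`, `build`, `set_value`, and `get` are hypothetical names for illustration, not the chapter's `PersistentArray` API. Each update allocates only the nodes on one root-to-leaf path (at most ceil(log2 n) + 1 of them) and shares everything else with the previous version.

```python
class Node:
    """Tree node that is never mutated after creation; `created` counts allocations."""
    __slots__ = ("left", "right", "value")
    created = 0

    def __init__(self, left=None, right=None, value=0):
        Node.created += 1
        self.left, self.right, self.value = left, right, value


def build(arr, lo, hi):
    """Balanced tree over arr[lo..hi]; values live in the leaves."""
    if lo == hi:
        return Node(value=arr[lo])
    mid = (lo + hi) // 2
    return Node(build(arr, lo, mid), build(arr, mid + 1, hi))


def set_value(node, lo, hi, i, value):
    """Return the root of a NEW version with arr[i] = value.
    Only the root-to-leaf path is copied; every other subtree is shared."""
    if lo == hi:
        return Node(value=value)
    mid = (lo + hi) // 2
    if i <= mid:
        return Node(set_value(node.left, lo, mid, i, value), node.right)
    return Node(node.left, set_value(node.right, mid + 1, hi, i, value))


def get(node, lo, hi, i):
    """Read arr[i] from any version in O(log n)."""
    while lo != hi:
        mid = (lo + hi) // 2
        if i <= mid:
            node, hi = node.left, mid
        else:
            node, lo = node.right, mid + 1
    return node.value


if __name__ == "__main__":
    n = 1024
    versions = [build(list(range(n)), 0, n - 1)]
    before = Node.created
    for k in range(5):
        versions.append(set_value(versions[-1], 0, n - 1, k, -k))
    print("new nodes per update:", (Node.created - before) // 5)  # 11 = log2(1024) + 1
    print("version 0, index 4:", get(versions[0], 0, n - 1, 4))
    print("version 5, index 4:", get(versions[-1], 0, n - 1, 4))
```

Old versions stay valid because no existing node is ever modified: a "new" array is just a new root whose untouched halves point back into the old tree.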

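Real-World Impact:

  • Databases: Range queries, time-travel queries
  • Genomics: Succinct indices for huge genomes
  • Version control: Git's immutable, content-addressed object store shares unchanged files and trees between commits, the same structural sharing that persistent structures rely on
  • High-performance computing: Cache-oblivious algorithms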

20.9 12.9 Exercises

20.9.1 Understanding

  1. Segment Tree: Prove that a segment tree over n elements has exactly 2n - 1 nodes, and that a heap-indexed array of size 4n is always large enough to store it.

  2. Fenwick Magic: Explain why index & (-index) isolates the rightmost (lowest) set bit of index in two's-complement arithmetic.

  3. Persistence: Calculate the space usage of a persistent array of n elements after k updates when each update uses path copying.

20.9.2 Implementation

  1. 2D Segment Tree: Implement a 2D segment tree for rectangle queries.

  2. Persistent Stack: Implement a persistent stack with O(1) push/pop and access to all versions.

  3. Succinct Tree: Implement a succinct representation of a binary tree using 2n + o(n) bits.

20.9.3 Applications

  1. Skyline Problem: Use a segment tree to solve the skyline problem in O(n log n).

  2. Version Control: Build a simple version control system using persistent data structures.

  3. Genome Indexing: Use succinct structures (for example, a wavelet tree) to index and query a large genome.

20.9.4 Advanced

  1. Dynamic Connectivity: Implement dynamic connectivity queries using persistent Union-Find.

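  2. Fractional Cascading: Apply fractional cascading to a merge-sort tree (a segment tree whose nodes store sorted lists) to reduce range-reporting queries from O(log² n + k) to O(log n + k), where k is the output size.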

  3. Cache-Oblivious B-Tree: Implement a cache-oblivious search tree.

20.10 12.10 Further Reading

Classic Papers:

  • Bentley (1980): “Multidimensional Divide-and-Conquer”
  • Fenwick (1994): “A New Data Structure for Cumulative Frequency Tables”
  • Driscoll et al. (1989): “Making Data Structures Persistent”
  • Jacobson (1989): “Space-efficient Static Trees and Graphs”

Books and Notes:

  • Okasaki: “Purely Functional Data Structures”
  • Navarro: “Compact Data Structures”
  • Demaine: “Cache-Oblivious Algorithms and Data Structures” (lecture notes)

Online Resources:

  • CP-Algorithms: Comprehensive tutorials
  • Topcoder tutorials: Practical competitive programming
  • SDSL Library: Succinct Data Structure Library (C++)


You’ve now mastered the advanced data structures that separate good programmers from great ones! These structures power databases, enable version control systems, compress massive datasets, and squeeze every last drop of performance from modern hardware.

Next: We’ll explore how these algorithms are applied in real research and industry!