# segment tree

## intro
| Aspect              | Prefix Sum                               | Difference Array                               | Segment Tree                                                  |
| ------------------- | ---------------------------------------- | ---------------------------------------------- | ------------------------------------------------------------- |
| **Primary Purpose** | range sum queries                        | range updates                                  | range queries, and range updates                              |
| **Operation Time**  | Sum Query: `O(1)`                        | Update: `O(1)`                                 | Query: `O(logC)` or `O(logn)`; Update: `O(logC)` or `O(logn)` |
| **Reconstruction**  | get original array from diff of elements | get original array from prefix sum of elements | N/A                                                           |
| **Use Case**        | static arrays                            | cumulative updates                             | interval-based manipulations                                  |
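- a segment tree is a binary tree in which each node covers an interval and stores an aggregate (sum, count, max, ...) over it
- the tree-based (node + pointer) implementation is easier to understand than a flat-array one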
- build the tree in `O(n)`
    - or build it dynamically: create nodes only when update() and query() reach them, costing `O(logC)` per operation (C is the size of the index range, e.g. the max value we use as an index)
- point modify `O(logC)` or `O(logn)`
- range query `O(logC)` or `O(logn)`
    - sum
    - count
    - max
    - other related aggregations
- range modify `O(logC)` or `O(logn)` when using lazy propagation
    - can have multiple lazy variables, one per kind of pending range operation (e.g. add, assign)
- push_down(), push_up(), update(), and query() are implemented differently depending on the type of range query and range modification

```python
# sum type
# non dynamic build
# non range modification
# non lazy propagation
class Node:
    """A segment-tree node covering the closed index interval [start, end]."""

    def __init__(self, start, end):
        self.start = start
        self.end = end
        # Split point: left child covers [start, mid], right child [mid + 1, end].
        self.mid = (start + end) // 2
        self.left = None
        self.right = None
        # Sum of nums[start..end].
        self.val = 0


class SegmentTree:
    """Sum segment tree over a fixed list, supporting point assignment.

    Build is O(n); update() and query() are O(log n).
    """

    def __init__(self, nums):
        """Build the tree over ``nums``; an empty sequence yields an empty tree."""
        self.size = len(nums)

        def build(start, end):
            node = Node(start, end)
            if start == end:
                node.val = nums[start]
                return node
            node.left = build(start, node.mid)
            node.right = build(node.mid + 1, end)
            node.val = node.left.val + node.right.val
            return node

        # build(0, -1) would never reach a leaf, so the empty case is guarded.
        self.root = build(0, self.size - 1) if self.size > 0 else None

    def update(self, index, val):
        """Set position ``index`` to ``val`` and refresh sums on the path to it.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(nums) - 1]``.
        """
        # Without this check an out-of-range index silently lands on leaf 0
        # or leaf n-1 and corrupts it.
        if not 0 <= index < self.size:
            raise IndexError("segment tree index out of range")

        def helper(node):
            if node.start == node.end:
                node.val = val
                return
            if index <= node.mid:
                helper(node.left)
            else:
                helper(node.right)
            node.val = node.left.val + node.right.val

        helper(self.root)

    def query(self, left, right):
        """Return the sum of positions ``left..right`` (inclusive).

        The range is clipped to the array bounds; an empty or fully
        out-of-range interval sums to 0.
        """
        def helper(node):
            # No overlap contributes nothing; this also stops the descent
            # before it could reach a leaf's missing (None) children.
            if node.end < left or right < node.start:
                return 0
            if left <= node.start and node.end <= right:
                return node.val
            return helper(node.left) + helper(node.right)

        if self.root is None:
            return 0
        return helper(self.root)

# time O(n) for initialize and O(logn) for update() and query()
# space O(n), due to segment tree
# using segment tree and sum type segment tree
```

```python
# sum type
# with dynamic build
# with range modification
# with lazy propagation
class Node:
    """A lazily created segment-tree node covering the closed interval [start, end]."""

    def __init__(self, start, end):
        self.start = start
        self.end = end
        # Left child covers [start, mid], right child [mid + 1, end].
        self.mid = (start + end) // 2
        self.left = None
        self.right = None
        # Sum over [start, end], already including this node's pending lazy add.
        self.val = 0
        # Per-element add that has not yet been pushed to the children.
        self.lazy = 0


class SegmentTree:
    """Dynamically built sum segment tree over [start, end] with range add.

    Children are allocated only when push_down() visits a node, so each
    update()/query() creates at most O(log C) new nodes, where C is the size
    of the covered range.
    """

    def __init__(self, start, end):
        self.root = Node(start, end)

    def push_down(self, node):
        """Create missing children and hand the node's pending add down to them."""
        if node.left is None:
            node.left = Node(node.start, node.mid)
        if node.right is None:
            node.right = Node(node.mid + 1, node.end)
        if node.lazy == 0:
            return
        for child in (node.left, node.right):
            child.val += node.lazy * (child.end - child.start + 1)
            child.lazy += node.lazy
        node.lazy = 0

    def push_up(self, node):
        """Recompute the node's sum from its two children."""
        node.val = node.left.val + node.right.val

    def update(self, node, start, end, add):
        """Add ``add`` to every position in [start, end] within node's interval."""
        # No overlap: nothing to do. Without this guard, a range outside the
        # root recurses past a leaf, and push_down keeps creating children
        # with start > end until RecursionError.
        if node.end < start or end < node.start:
            return
        if start <= node.start and node.end <= end:
            node.val += add * (node.end - node.start + 1)
            node.lazy += add
            return
        self.push_down(node)
        if start <= node.mid:
            self.update(node.left, start, end, add)
        if end >= node.mid + 1:
            self.update(node.right, start, end, add)
        self.push_up(node)

    def query(self, node, start, end):
        """Return the sum over [start, end] intersected with node's interval."""
        # No overlap contributes 0 (same runaway-recursion guard as update()).
        if node.end < start or end < node.start:
            return 0
        if start <= node.start and node.end <= end:
            return node.val
        self.push_down(node)
        res = 0
        if start <= node.mid:
            res += self.query(node.left, start, end)
        if end >= node.mid + 1:
            res += self.query(node.right, start, end)
        return res

# time O(1) for initialize and O(logC) for others
# space O(min(n * logC, C)) after n operations, since each operation creates at most O(logC) nodes
# using segment tree and sum type segment tree
```

## pattern
- use sum type segment tree
- use count type segment tree
- use max type segment tree