Skip to content

Segment Tree

# segment tree / range sum query problem
import math

"""
Given an integer array nums, handle multiple queries of the following types:

    Update the value of an element in nums.
    Calculate the sum of the elements of nums between indices left and right inclusive where left <= right.
Input
["NumArray", "sumRange", "update", "sumRange"]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
Output
[null, 9, null, 8]

 0 1 2 3 4
[1,2,3,4,5]


Brute force : Range Sum : O(NQ), Update : O(1)
Segment Tree : Range Sum : O(Q), Update : O(N)
Instead of calculating sum for N numbers Q number of times it'll take O(N * Q) : not efficient -> keep precomputed sum
[1,3,6,10,15] now for sum(1,3) we can find sum(left-1) - sum(right) = sum(0) - sum(3) = abs(1 - 20) = 19 -> O(Q)
When we are updating lets say arr[1] = 4 we need to update array and precomuted sum array : new arr = [1,4,3,4,5], precomp = [1,5,8,12,17] Each update takes O(N) time

Segment Tree using array (Using FBT) -:
    if curr_node = i: left_child = 2*i + 1, right_child = 2*i + 2, parent = floor((i-1)/2) # same as FBT
    Once you reach leaf node -> return sum, for internal node -> sum = left + right

    [1,2,5,6,7,9]
     /        \
    [1,2,5]    [6,7,9]

    Root node will store sum of all
    Sum of [6,7,9] will be stored in 2*0 + 2 = 2nd index
    Sum of [1,2,5] will be stored in 2*0 + 1 = 1st index
    Similarily, sum of [1,2] will be stored in 2*1 + 1 = 3rd index

Final array looks like this-:
 0   1  2   3 
[30, 8, 22, 3...]

Number of leaf nodes in FBT : n, number of internal nodes : n-1, ht(h) = ceil(log(n+1)), size of segment tree array : n = 2^h - 1 

Range Sum Query after making Segment Tree of array :: 3 cases -:
Total overlap, no overlap, partial overlap (of ranges)
When no overlap, return 0 value
Complete overlap, return the value
Partial overlap, make call to left and right side and return left+right

Update : update original array (O(1)) and update segment tree (same cases) if out of range : return; takes log(N) time

Segment Trees are useful only when updates are frequent. 

"""


# utility function to get middle index
def getMid(start, end):
    return start + (end - start) // 2


# recursive function to get sum of values in given range of array
"""
st = pointer to segment tree
si = index of current node in segment tree (0 is passed as it's the root)
ss and se = start and end index of segment represented by current node
qs and qe = start and end index of query range
"""


def getSum_helper(st, ss, se, qs, qe, si):
    if qs <= ss and qe >= se:
        return st[si]  # segment of node is part of given range
    if se < qs or ss > qe:
        return 0  # segment of node is outside given range
    # if part of segment overlaps with given range -> split again
    mid = getMid(ss, se)
    return getSum_helper(st, ss, mid, qs, qe, 2 * si + 1) + getSum_helper(
        st, mid + 1, se, qs, qe, 2 * si + 2
    )


# update
"""
i = index of element to be updated (in input array)
diff = value to be added to all nodes in seg tree which have i in range 
"""


def update_helper(st, ss, se, i, diff, si):
    if i < ss or i > se:
        return  # input index lies outside of range of this segment
    # in range
    st[si] = st[si] + diff
    # if not leaf node, update all internal nodes by splitting
    if se != ss:
        mid = getMid(ss, se)
        update_helper(st, ss, mid, i, diff, 2 * si + 1)
        update_helper(st, mid + 1, se, i, diff, 2 * si + 2)


# now update in array
def update(arr, st, n, i, new_val):
    if i < 0 or i > n - 1:
        return
    diff = new_val - arr[i]  # difference between new val and old val
    arr[i] = new_val
    update_helper(st, 0, n - 1, i, diff, 0)  # index to be updated is 0 (ROOT)


def getSum(st, n, qs, qe):
    if qs < 0 or qe > n - 1 or qs > qe:
        return -1
    return getSum_helper(st, 0, n - 1, qs, qe, 0)


def constructSegTree_helper(
    arr, ss, se, st, si
):  # si = index of current node in segment tree st
    # if only 1 element in array : store in current node of seg tree and return
    if ss == se:
        st[si] = arr[ss]
        return arr[ss]
    # if more -> split
    mid = getMid(ss, se)
    st[si] = constructSegTree_helper(
        arr, ss, mid, st, si * 2 + 1
    ) + constructSegTree_helper(arr, mid + 1, se, st, si * 2 + 2)
    return st[si]


def constructSegTree(arr, n):
    # ht
    x = math.ceil(math.log2(n))
    max_size = 2 * pow(2, x) - 1
    st = [0] * max_size
    constructSegTree_helper(arr, 0, n - 1, st, 0)
    return st


def main():
    arr = [1, 3, 5, 7, 9, 11]
    n = len(arr)
    st = constructSegTree(arr, n)
    print("Sum of values in given range = ", getSum(st, n, 1, 3))  # 15
    update(arr, st, n, 1, 10)  # 3 updated to 10
    print("Updated sum of values in range range = ", getSum(st, n, 1, 3))  # 22


main()