Skip to content

Commit 8ece70c

Browse files
committed
add new file data_structures/binary_tree/segment_tree_node.py
1 parent 0a3a965 commit 8ece70c

File tree

1 file changed

+204
-0
lines changed

1 file changed

+204
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
class Node():
2+
def __init__(self, start, end):
3+
# Initializes a segment tree node with start and end indices
4+
self.start = start
5+
self.end = end
6+
self.value = None
7+
self.left = None
8+
self.right = None
9+
10+
11+
class SegmentTree():
12+
def __init__(self, nums, mode='max'):
13+
"""
14+
Initializes the Segment Tree.
15+
:param nums: List of integers to build the tree from.
16+
:param mode: Operation mode of the tree ('max' or 'sum').
17+
"""
18+
self.siz = len(nums)
19+
self.mode = mode
20+
if mode not in {'max', 'sum'}:
21+
self.mode = 'max' # Default to max if invalid mode is given
22+
23+
# Build the tree from the input list
24+
self.root = self.build(0, self.siz - 1, nums)
25+
26+
def build(self, start, end, nums):
27+
"""
28+
Recursively builds the segment tree.
29+
:param start: Start index of the segment.
30+
:param end: End index of the segment.
31+
:param nums: Original input array.
32+
:return: Root node of the constructed subtree.
33+
"""
34+
if start > end:
35+
return None
36+
37+
if start == end:
38+
# Leaf node
39+
n = Node(start, end)
40+
n.value = nums[start]
41+
return n
42+
43+
mid = (start + end) // 2
44+
root = Node(start, end)
45+
root.left = self.build(start, mid, nums)
46+
root.right = self.build(mid + 1, end, nums)
47+
48+
# Set the value according to the mode
49+
if self.mode == 'max':
50+
root.value = max(root.left.value, root.right.value)
51+
else:
52+
root.value = root.left.value + root.right.value
53+
54+
return root
55+
56+
def max_in_range(self, start_index, end_index):
57+
"""
58+
Queries the maximum value in a given range.
59+
Only works in 'max' mode.
60+
"""
61+
if self.mode == 'sum':
62+
raise Exception('Current Segment Tree doesn\'t support finding max')
63+
64+
if start_index > end_index or start_index < 0 or end_index >= self.siz:
65+
raise Exception('Invalid index')
66+
67+
return self.query(self.root, start_index, end_index, 0, self.siz - 1)
68+
69+
def sum_in_range(self, start_index, end_index):
70+
"""
71+
Queries the sum of values in a given range.
72+
Only works in 'sum' mode.
73+
"""
74+
if self.mode == 'max':
75+
raise Exception('Current Segment Tree doesn\'t support summing')
76+
77+
if start_index > end_index or start_index < 0 or end_index >= self.siz:
78+
raise Exception('Invalid index')
79+
80+
return self.query(self.root, start_index, end_index, 0, self.siz - 1)
81+
82+
def query(self, node, start_index, end_index, start, end):
83+
"""
84+
Recursively queries a value (max or sum) in a given range.
85+
:param node: Current node in the tree.
86+
:param start_index: Query start index.
87+
:param end_index: Query end index.
88+
:param start: Node's segment start.
89+
:param end: Node's segment end.
90+
:return: Result of query in the range.
91+
"""
92+
# Complete overlap
93+
if start_index <= start and end <= end_index:
94+
return node.value
95+
96+
mid = (start + end) // 2
97+
98+
if end_index <= mid:
99+
# Entire range is in the left child
100+
return self.query(node.left, start_index, end_index, start, mid)
101+
elif start_index > mid:
102+
# Entire range is in the right child
103+
return self.query(node.right, start_index, end_index, mid + 1, end)
104+
else:
105+
# Range spans both children
106+
if self.mode == 'max':
107+
return max(
108+
self.query(node.left, start_index, end_index, start, mid),
109+
self.query(node.right, start_index, end_index, mid + 1, end)
110+
)
111+
else:
112+
return (
113+
self.query(node.left, start_index, end_index, start, mid) +
114+
self.query(node.right, start_index, end_index, mid + 1, end)
115+
)
116+
117+
def update(self, index, new_value):
118+
"""
119+
Updates a value at a specific index in the segment tree.
120+
:param index: Index to update.
121+
:param new_value: New value to set.
122+
"""
123+
if index < 0 or index >= self.siz:
124+
raise Exception('Invalid index')
125+
126+
self.modify(self.root, index, new_value, 0, self.siz - 1)
127+
128+
def modify(self, node, index, new_value, start, end):
129+
"""
130+
Recursively updates the tree to reflect a change at a specific index.
131+
:param node: Current node being processed.
132+
:param index: Index to update.
133+
:param new_value: New value to assign.
134+
:param start: Start index of node's segment.
135+
:param end: End index of node's segment.
136+
"""
137+
if start == end:
138+
node.value = new_value
139+
return
140+
141+
mid = (start + end) // 2
142+
143+
if index <= mid:
144+
self.modify(node.left, index, new_value, start, mid)
145+
else:
146+
self.modify(node.right, index, new_value, mid + 1, end)
147+
148+
# Recompute current node's value after update
149+
if self.mode == 'max':
150+
node.value = max(node.left.value, node.right.value)
151+
else:
152+
node.value = node.left.value + node.right.value
153+
154+
155+
"""
156+
nums = [1, 3, 5, 7, 9, 11]
157+
158+
st_max = SegmentTree(nums, mode='max')
159+
print(st_max.max_in_range(1, 3)) # Expected: 7 (max of [3,5,7])
160+
print(st_max.max_in_range(0, 5)) # Expected: 11
161+
st_max.update(3, 10) # nums[3] = 10
162+
print(st_max.max_in_range(1, 4)) # Expected: 10
163+
164+
try:
165+
st_max.sum_in_range(0, 2) # Should raise exception
166+
except Exception as e:
167+
print(e) # Expected: Current Segment Tree doesn't support summing
168+
169+
try:
170+
st_max.max_in_range(3, 2) # Should raise exception
171+
except Exception as e:
172+
print(e) # Expected: Invalid index
173+
174+
try:
175+
st_max.max_in_range(1, 200) # Should raise exception
176+
except Exception as e:
177+
print(e) # Expected: Invalid index
178+
179+
180+
st_sum = SegmentTree(nums, mode='sum')
181+
print(st_sum.sum_in_range(1, 3)) # Expected: 15 (3+5+7)
182+
print(st_sum.sum_in_range(0, 5)) # Expected: 36 (sum of all elements)
183+
print(st_sum.sum_in_range(1, 3)) # Expected: 15 (3+5+7)
184+
print(st_sum.sum_in_range(0, 5)) # Expected: 36 (sum of all elements)
185+
186+
try:
187+
st_sum.max_in_range(0, 2) # Should raise exception
188+
except Exception as e:
189+
print(e) # Expected: Current Segment Tree doesn't support finding max
190+
191+
try:
192+
st_sum.sum_in_range(3, 2) # Should raise exception
193+
except Exception as e:
194+
print(e) # Expected: Invalid index
195+
196+
try:
197+
st_sum.sum_in_range(1, 200) # Should raise exception
198+
except Exception as e:
199+
print(e) # Expected: Invalid index
200+
201+
202+
st_invalid = SegmentTree(nums, mode='unknown') # Should default to 'max'
203+
print(st_invalid.mode) # Expected: 'max'
204+
"""

0 commit comments

Comments
 (0)