1
+ class Node ():
2
+ def __init__ (self , start , end ):
3
+ # Initializes a segment tree node with start and end indices
4
+ self .start = start
5
+ self .end = end
6
+ self .value = None
7
+ self .left = None
8
+ self .right = None
9
+
10
+
11
+ class SegmentTree ():
12
+ def __init__ (self , nums , mode = 'max' ):
13
+ """
14
+ Initializes the Segment Tree.
15
+ :param nums: List of integers to build the tree from.
16
+ :param mode: Operation mode of the tree ('max' or 'sum').
17
+ """
18
+ self .siz = len (nums )
19
+ self .mode = mode
20
+ if mode not in {'max' , 'sum' }:
21
+ self .mode = 'max' # Default to max if invalid mode is given
22
+
23
+ # Build the tree from the input list
24
+ self .root = self .build (0 , self .siz - 1 , nums )
25
+
26
+ def build (self , start , end , nums ):
27
+ """
28
+ Recursively builds the segment tree.
29
+ :param start: Start index of the segment.
30
+ :param end: End index of the segment.
31
+ :param nums: Original input array.
32
+ :return: Root node of the constructed subtree.
33
+ """
34
+ if start > end :
35
+ return None
36
+
37
+ if start == end :
38
+ # Leaf node
39
+ n = Node (start , end )
40
+ n .value = nums [start ]
41
+ return n
42
+
43
+ mid = (start + end ) // 2
44
+ root = Node (start , end )
45
+ root .left = self .build (start , mid , nums )
46
+ root .right = self .build (mid + 1 , end , nums )
47
+
48
+ # Set the value according to the mode
49
+ if self .mode == 'max' :
50
+ root .value = max (root .left .value , root .right .value )
51
+ else :
52
+ root .value = root .left .value + root .right .value
53
+
54
+ return root
55
+
56
+ def max_in_range (self , start_index , end_index ):
57
+ """
58
+ Queries the maximum value in a given range.
59
+ Only works in 'max' mode.
60
+ """
61
+ if self .mode == 'sum' :
62
+ raise Exception ('Current Segment Tree doesn\' t support finding max' )
63
+
64
+ if start_index > end_index or start_index < 0 or end_index >= self .siz :
65
+ raise Exception ('Invalid index' )
66
+
67
+ return self .query (self .root , start_index , end_index , 0 , self .siz - 1 )
68
+
69
+ def sum_in_range (self , start_index , end_index ):
70
+ """
71
+ Queries the sum of values in a given range.
72
+ Only works in 'sum' mode.
73
+ """
74
+ if self .mode == 'max' :
75
+ raise Exception ('Current Segment Tree doesn\' t support summing' )
76
+
77
+ if start_index > end_index or start_index < 0 or end_index >= self .siz :
78
+ raise Exception ('Invalid index' )
79
+
80
+ return self .query (self .root , start_index , end_index , 0 , self .siz - 1 )
81
+
82
+ def query (self , node , start_index , end_index , start , end ):
83
+ """
84
+ Recursively queries a value (max or sum) in a given range.
85
+ :param node: Current node in the tree.
86
+ :param start_index: Query start index.
87
+ :param end_index: Query end index.
88
+ :param start: Node's segment start.
89
+ :param end: Node's segment end.
90
+ :return: Result of query in the range.
91
+ """
92
+ # Complete overlap
93
+ if start_index <= start and end <= end_index :
94
+ return node .value
95
+
96
+ mid = (start + end ) // 2
97
+
98
+ if end_index <= mid :
99
+ # Entire range is in the left child
100
+ return self .query (node .left , start_index , end_index , start , mid )
101
+ elif start_index > mid :
102
+ # Entire range is in the right child
103
+ return self .query (node .right , start_index , end_index , mid + 1 , end )
104
+ else :
105
+ # Range spans both children
106
+ if self .mode == 'max' :
107
+ return max (
108
+ self .query (node .left , start_index , end_index , start , mid ),
109
+ self .query (node .right , start_index , end_index , mid + 1 , end )
110
+ )
111
+ else :
112
+ return (
113
+ self .query (node .left , start_index , end_index , start , mid ) +
114
+ self .query (node .right , start_index , end_index , mid + 1 , end )
115
+ )
116
+
117
+ def update (self , index , new_value ):
118
+ """
119
+ Updates a value at a specific index in the segment tree.
120
+ :param index: Index to update.
121
+ :param new_value: New value to set.
122
+ """
123
+ if index < 0 or index >= self .siz :
124
+ raise Exception ('Invalid index' )
125
+
126
+ self .modify (self .root , index , new_value , 0 , self .siz - 1 )
127
+
128
+ def modify (self , node , index , new_value , start , end ):
129
+ """
130
+ Recursively updates the tree to reflect a change at a specific index.
131
+ :param node: Current node being processed.
132
+ :param index: Index to update.
133
+ :param new_value: New value to assign.
134
+ :param start: Start index of node's segment.
135
+ :param end: End index of node's segment.
136
+ """
137
+ if start == end :
138
+ node .value = new_value
139
+ return
140
+
141
+ mid = (start + end ) // 2
142
+
143
+ if index <= mid :
144
+ self .modify (node .left , index , new_value , start , mid )
145
+ else :
146
+ self .modify (node .right , index , new_value , mid + 1 , end )
147
+
148
+ # Recompute current node's value after update
149
+ if self .mode == 'max' :
150
+ node .value = max (node .left .value , node .right .value )
151
+ else :
152
+ node .value = node .left .value + node .right .value
153
+
154
+
155
+ """
156
+ nums = [1, 3, 5, 7, 9, 11]
157
+
158
+ st_max = SegmentTree(nums, mode='max')
159
+ print(st_max.max_in_range(1, 3)) # Expected: 7 (max of [3,5,7])
160
+ print(st_max.max_in_range(0, 5)) # Expected: 11
161
+ st_max.update(3, 10) # nums[3] = 10
162
+ print(st_max.max_in_range(1, 4)) # Expected: 10
163
+
164
+ try:
165
+ st_max.sum_in_range(0, 2) # Should raise exception
166
+ except Exception as e:
167
+ print(e) # Expected: Current Segment Tree doesn't support summing
168
+
169
+ try:
170
+ st_max.max_in_range(3, 2) # Should raise exception
171
+ except Exception as e:
172
+ print(e) # Expected: Invalid index
173
+
174
+ try:
175
+ st_max.max_in_range(1, 200) # Should raise exception
176
+ except Exception as e:
177
+ print(e) # Expected: Invalid index
178
+
179
+
180
+ st_sum = SegmentTree(nums, mode='sum')
181
+ print(st_sum.sum_in_range(1, 3)) # Expected: 15 (3+5+7)
182
+ print(st_sum.sum_in_range(0, 5)) # Expected: 36 (sum of all elements)
183
+ print(st_sum.sum_in_range(1, 3)) # Expected: 15 (3+5+7)
184
+ print(st_sum.sum_in_range(0, 5)) # Expected: 36 (sum of all elements)
185
+
186
+ try:
187
+ st_sum.max_in_range(0, 2) # Should raise exception
188
+ except Exception as e:
189
+ print(e) # Expected: Current Segment Tree doesn't support finding max
190
+
191
+ try:
192
+ st_sum.sum_in_range(3, 2) # Should raise exception
193
+ except Exception as e:
194
+ print(e) # Expected: Invalid index
195
+
196
+ try:
197
+ st_sum.sum_in_range(1, 200) # Should raise exception
198
+ except Exception as e:
199
+ print(e) # Expected: Invalid index
200
+
201
+
202
+ st_invalid = SegmentTree(nums, mode='unknown') # Should default to 'max'
203
+ print(st_invalid.mode) # Expected: 'max'
204
+ """
0 commit comments