10
10
from __future__ import annotations
11
11
12
12
import enum
13
+ import types
13
14
from abc import ABC , abstractmethod
14
15
from contextlib import contextmanager
15
16
from typing import Any , Callable , Iterator , Tuple , Union
16
- import types
17
17
18
18
"""
19
19
Implementation of the module level helper functions for the UHI
@@ -155,6 +155,9 @@ def _process_index_for_axis(self, index, axis):
155
155
return _get_axis_len (self , axis ) if index is len else index (self , axis )
156
156
157
157
if isinstance (index , int ):
158
+ # -1 index returns the last valid bin
159
+ if index == - 1 :
160
+ return _overflow (self , axis ) - 1
158
161
# Shift the indices by 1 to align with the UHI convention,
159
162
# where 0 corresponds to the first bin, unlike ROOT where 0 represents underflow and 1 is the first bin.
160
163
index = index + 1
@@ -166,7 +169,7 @@ def _process_index_for_axis(self, index, axis):
166
169
raise index
167
170
168
171
169
- def _compute_uhi_index (self , index , axis ):
172
+ def _compute_uhi_index (self , index , axis , include_flow_bins = True ):
170
173
"""Convert tag functors to valid bin indices."""
171
174
if isinstance (index , _rebin ) or index is _sum :
172
175
index = slice (None , None , index )
@@ -175,13 +178,13 @@ def _compute_uhi_index(self, index, axis):
175
178
return _process_index_for_axis (self , index , axis )
176
179
177
180
if isinstance (index , slice ):
178
- start , stop = _resolve_slice_indices (self , index , axis )
181
+ start , stop = _resolve_slice_indices (self , index , axis , include_flow_bins )
179
182
return slice (start , stop , index .step )
180
183
181
184
raise TypeError (f"Unsupported index type: { type (index ).__name__ } " )
182
185
183
186
184
- def _compute_common_index (self , index ):
187
+ def _compute_common_index (self , index , include_flow_bins = True ):
185
188
"""Normalize and expand the index to match the histogram dimension."""
186
189
dim = self .GetDimension ()
187
190
if isinstance (index , dict ):
@@ -209,19 +212,27 @@ def _compute_common_index(self, index):
209
212
if len (index ) != dim :
210
213
raise IndexError (f"Expected { dim } indices, got { len (index )} " )
211
214
212
- return [_compute_uhi_index (self , idx , axis ) for axis , idx in enumerate (index )]
215
+ return [_compute_uhi_index (self , idx , axis , include_flow_bins ) for axis , idx in enumerate (index )]
213
216
214
217
215
218
def _setbin (self , index , value ):
216
219
"""Set the bin content for a specific bin index"""
217
220
self .SetBinContent (index , value )
218
221
219
222
220
- def _resolve_slice_indices (self , index , axis ):
223
+ def _resolve_slice_indices (self , index , axis , include_flow_bins = True ):
221
224
"""Resolve slice start and stop indices for a given axis"""
222
225
start , stop = index .start , index .stop
223
- start = _process_index_for_axis (self , start , axis ) if start is not None else _underflow (self , axis )
224
- stop = _process_index_for_axis (self , stop , axis ) if stop is not None else _overflow (self , axis ) + 1
226
+ start = (
227
+ _process_index_for_axis (self , start , axis )
228
+ if start is not None
229
+ else _underflow (self , axis ) + (0 if include_flow_bins else 1 )
230
+ )
231
+ stop = (
232
+ _process_index_for_axis (self , stop , axis )
233
+ if stop is not None
234
+ else _overflow (self , axis ) + (1 if include_flow_bins else 0 )
235
+ )
225
236
if start < _underflow (self , axis ) or stop > (_overflow (self , axis ) + 1 ) or start > stop :
226
237
raise IndexError (f"Slice indices { start , stop } out of range for axis { axis } " )
227
238
return start , stop
@@ -251,15 +262,15 @@ def _get_processed_slices(self, index):
251
262
if len (index ) != self .GetDimension ():
252
263
raise IndexError (f"Expected { self .GetDimension ()} indices, got { len (index )} " )
253
264
processed_slices , out_of_range_indices , actions = [], [], [None ] * self .GetDimension ()
254
- for i , idx in enumerate (index ):
255
- axis_bins = range (_get_axis (self , i ). GetNbins () + 2 )
265
+ for axis , idx in enumerate (index ):
266
+ axis_bins = range (_overflow (self , axis ) + 1 )
256
267
if isinstance (idx , slice ):
257
268
slice_range = range (idx .start , idx .stop )
258
269
processed_slices .append (slice_range )
259
270
uflow = [b for b in axis_bins if b < idx .start ]
260
271
oflow = [b for b in axis_bins if b >= idx .stop ]
261
272
out_of_range_indices .append ((uflow , oflow ))
262
- actions [i ] = idx .step
273
+ actions [axis ] = idx .step
263
274
else :
264
275
processed_slices .append ([idx ])
265
276
@@ -288,7 +299,8 @@ def _get_slice_indices(slices):
288
299
"""
289
300
import numpy as np
290
301
291
- return np .array (np .meshgrid (* slices )).T .reshape (- 1 , len (slices ))
302
+ grids = np .meshgrid (* slices , indexing = "ij" )
303
+ return np .array (grids ).reshape (len (slices ), - 1 ).T
292
304
293
305
294
306
def _set_flow_bins (self , target_hist , out_of_range_indices ):
@@ -347,14 +359,22 @@ def _slice_get(self, index):
347
359
return _apply_actions (target_hist , actions )
348
360
349
361
350
- def _slice_set (self , index , value ):
362
+ def _slice_set (self , index , unprocessed_index , value ):
351
363
"""
352
364
This method modifies the histogram by updating the bin contents for the
353
365
specified slice. It supports assigning a scalar value to all bins or
354
366
assigning an array of values, provided the array's shape matches the slice.
355
367
"""
356
368
import numpy as np
357
369
370
+ # Depending on the shape of the array provided, we can set or not the flow bins
371
+ # Setting with a scalar does not set the flow bins
372
+ include_flow_bins = not (
373
+ (isinstance (value , np .ndarray ) and value .shape == _shape (self , include_flow_bins = False )) or np .isscalar (value )
374
+ )
375
+ if not include_flow_bins :
376
+ index = _compute_common_index (self , unprocessed_index , include_flow_bins = False )
377
+
358
378
processed_slices , _ , actions = _get_processed_slices (self , index )
359
379
slice_indices = _get_slice_indices (processed_slices )
360
380
if isinstance (value , np .ndarray ):
@@ -377,25 +397,36 @@ def _slice_set(self, index, value):
377
397
378
398
379
399
def _getitem (self , index ):
380
- index = _compute_common_index (self , index )
381
- if all (isinstance (i , int ) for i in index ):
382
- return self .GetBinContent (* index )
400
+ uhi_index = _compute_common_index (self , index )
401
+ if all (isinstance (i , int ) for i in uhi_index ):
402
+ return self .GetBinContent (* uhi_index )
383
403
384
- if any (isinstance (i , slice ) for i in index ):
385
- return _slice_get (self , index )
404
+ if any (isinstance (i , slice ) for i in uhi_index ):
405
+ return _slice_get (self , uhi_index )
386
406
387
407
388
408
def _setitem (self , index , value ):
389
- index = _compute_common_index (self , index )
390
- if all (isinstance (i , int ) for i in index ):
391
- _setbin (self , self .GetBin (* index ), value )
392
- elif any (isinstance (i , slice ) for i in index ):
393
- _slice_set (self , index , value )
409
+ uhi_index = _compute_common_index (self , index )
410
+ if all (isinstance (i , int ) for i in uhi_index ):
411
+ _setbin (self , self .GetBin (* uhi_index ), value )
412
+ elif any (isinstance (i , slice ) for i in uhi_index ):
413
+ _slice_set (self , uhi_index , index , value )
414
+
415
+
416
+ def _eq (self , other ):
417
+ import numpy as np
418
+
419
+ return (
420
+ isinstance (other , type (self ))
421
+ and _shape (self ) == _shape (other )
422
+ and np .array_equal (_values_default (self ), _values_default (other ))
423
+ )
394
424
395
425
396
426
def _add_indexing_features (klass : Any ) -> None :
397
427
klass .__getitem__ = _getitem
398
428
klass .__setitem__ = _setitem
429
+ klass .__eq__ = _eq
399
430
400
431
401
432
"""
0 commit comments