Skip to content

Commit 96292fe

Browse files
siliataiderdpiparo
authored andcommitted
[python] UHI Refactoring: Slicing Logic, Equality Operator, and Negative Indexing
1 parent d486882 commit 96292fe

File tree

1 file changed

+54
-23
lines changed
  • bindings/pyroot/pythonizations/python/ROOT/_pythonization

1 file changed

+54
-23
lines changed

Diff for: bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py

+54-23
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from __future__ import annotations
1111

1212
import enum
13+
import types
1314
from abc import ABC, abstractmethod
1415
from contextlib import contextmanager
1516
from typing import Any, Callable, Iterator, Tuple, Union
16-
import types
1717

1818
"""
1919
Implementation of the module level helper functions for the UHI
@@ -155,6 +155,9 @@ def _process_index_for_axis(self, index, axis):
155155
return _get_axis_len(self, axis) if index is len else index(self, axis)
156156

157157
if isinstance(index, int):
158+
# -1 index returns the last valid bin
159+
if index == -1:
160+
return _overflow(self, axis) - 1
158161
# Shift the indices by 1 to align with the UHI convention,
159162
# where 0 corresponds to the first bin, unlike ROOT where 0 represents underflow and 1 is the first bin.
160163
index = index + 1
@@ -166,7 +169,7 @@ def _process_index_for_axis(self, index, axis):
166169
raise index
167170

168171

169-
def _compute_uhi_index(self, index, axis):
172+
def _compute_uhi_index(self, index, axis, include_flow_bins=True):
170173
"""Convert tag functors to valid bin indices."""
171174
if isinstance(index, _rebin) or index is _sum:
172175
index = slice(None, None, index)
@@ -175,13 +178,13 @@ def _compute_uhi_index(self, index, axis):
175178
return _process_index_for_axis(self, index, axis)
176179

177180
if isinstance(index, slice):
178-
start, stop = _resolve_slice_indices(self, index, axis)
181+
start, stop = _resolve_slice_indices(self, index, axis, include_flow_bins)
179182
return slice(start, stop, index.step)
180183

181184
raise TypeError(f"Unsupported index type: {type(index).__name__}")
182185

183186

184-
def _compute_common_index(self, index):
187+
def _compute_common_index(self, index, include_flow_bins=True):
185188
"""Normalize and expand the index to match the histogram dimension."""
186189
dim = self.GetDimension()
187190
if isinstance(index, dict):
@@ -209,19 +212,27 @@ def _compute_common_index(self, index):
209212
if len(index) != dim:
210213
raise IndexError(f"Expected {dim} indices, got {len(index)}")
211214

212-
return [_compute_uhi_index(self, idx, axis) for axis, idx in enumerate(index)]
215+
return [_compute_uhi_index(self, idx, axis, include_flow_bins) for axis, idx in enumerate(index)]
213216

214217

215218
def _setbin(self, index, value):
216219
"""Set the bin content for a specific bin index"""
217220
self.SetBinContent(index, value)
218221

219222

220-
def _resolve_slice_indices(self, index, axis):
223+
def _resolve_slice_indices(self, index, axis, include_flow_bins=True):
221224
"""Resolve slice start and stop indices for a given axis"""
222225
start, stop = index.start, index.stop
223-
start = _process_index_for_axis(self, start, axis) if start is not None else _underflow(self, axis)
224-
stop = _process_index_for_axis(self, stop, axis) if stop is not None else _overflow(self, axis) + 1
226+
start = (
227+
_process_index_for_axis(self, start, axis)
228+
if start is not None
229+
else _underflow(self, axis) + (0 if include_flow_bins else 1)
230+
)
231+
stop = (
232+
_process_index_for_axis(self, stop, axis)
233+
if stop is not None
234+
else _overflow(self, axis) + (1 if include_flow_bins else 0)
235+
)
225236
if start < _underflow(self, axis) or stop > (_overflow(self, axis) + 1) or start > stop:
226237
raise IndexError(f"Slice indices {start, stop} out of range for axis {axis}")
227238
return start, stop
@@ -251,15 +262,15 @@ def _get_processed_slices(self, index):
251262
if len(index) != self.GetDimension():
252263
raise IndexError(f"Expected {self.GetDimension()} indices, got {len(index)}")
253264
processed_slices, out_of_range_indices, actions = [], [], [None] * self.GetDimension()
254-
for i, idx in enumerate(index):
255-
axis_bins = range(_get_axis(self, i).GetNbins() + 2)
265+
for axis, idx in enumerate(index):
266+
axis_bins = range(_overflow(self, axis) + 1)
256267
if isinstance(idx, slice):
257268
slice_range = range(idx.start, idx.stop)
258269
processed_slices.append(slice_range)
259270
uflow = [b for b in axis_bins if b < idx.start]
260271
oflow = [b for b in axis_bins if b >= idx.stop]
261272
out_of_range_indices.append((uflow, oflow))
262-
actions[i] = idx.step
273+
actions[axis] = idx.step
263274
else:
264275
processed_slices.append([idx])
265276

@@ -288,7 +299,8 @@ def _get_slice_indices(slices):
288299
"""
289300
import numpy as np
290301

291-
return np.array(np.meshgrid(*slices)).T.reshape(-1, len(slices))
302+
grids = np.meshgrid(*slices, indexing="ij")
303+
return np.array(grids).reshape(len(slices), -1).T
292304

293305

294306
def _set_flow_bins(self, target_hist, out_of_range_indices):
@@ -347,14 +359,22 @@ def _slice_get(self, index):
347359
return _apply_actions(target_hist, actions)
348360

349361

350-
def _slice_set(self, index, value):
362+
def _slice_set(self, index, unprocessed_index, value):
351363
"""
352364
This method modifies the histogram by updating the bin contents for the
353365
specified slice. It supports assigning a scalar value to all bins or
354366
assigning an array of values, provided the array's shape matches the slice.
355367
"""
356368
import numpy as np
357369

370+
# Depending on the shape of the array provided, we can set or not the flow bins
371+
# Setting with a scalar does not set the flow bins
372+
include_flow_bins = not (
373+
(isinstance(value, np.ndarray) and value.shape == _shape(self, include_flow_bins=False)) or np.isscalar(value)
374+
)
375+
if not include_flow_bins:
376+
index = _compute_common_index(self, unprocessed_index, include_flow_bins=False)
377+
358378
processed_slices, _, actions = _get_processed_slices(self, index)
359379
slice_indices = _get_slice_indices(processed_slices)
360380
if isinstance(value, np.ndarray):
@@ -377,25 +397,36 @@ def _slice_set(self, index, value):
377397

378398

379399
def _getitem(self, index):
380-
index = _compute_common_index(self, index)
381-
if all(isinstance(i, int) for i in index):
382-
return self.GetBinContent(*index)
400+
uhi_index = _compute_common_index(self, index)
401+
if all(isinstance(i, int) for i in uhi_index):
402+
return self.GetBinContent(*uhi_index)
383403

384-
if any(isinstance(i, slice) for i in index):
385-
return _slice_get(self, index)
404+
if any(isinstance(i, slice) for i in uhi_index):
405+
return _slice_get(self, uhi_index)
386406

387407

388408
def _setitem(self, index, value):
389-
index = _compute_common_index(self, index)
390-
if all(isinstance(i, int) for i in index):
391-
_setbin(self, self.GetBin(*index), value)
392-
elif any(isinstance(i, slice) for i in index):
393-
_slice_set(self, index, value)
409+
uhi_index = _compute_common_index(self, index)
410+
if all(isinstance(i, int) for i in uhi_index):
411+
_setbin(self, self.GetBin(*uhi_index), value)
412+
elif any(isinstance(i, slice) for i in uhi_index):
413+
_slice_set(self, uhi_index, index, value)
414+
415+
416+
def _eq(self, other):
417+
import numpy as np
418+
419+
return (
420+
isinstance(other, type(self))
421+
and _shape(self) == _shape(other)
422+
and np.array_equal(_values_default(self), _values_default(other))
423+
)
394424

395425

396426
def _add_indexing_features(klass: Any) -> None:
397427
klass.__getitem__ = _getitem
398428
klass.__setitem__ = _setitem
429+
klass.__eq__ = _eq
399430

400431

401432
"""

0 commit comments

Comments
 (0)