Skip to content

Commit 92df5a6

Browse files
Saransh-cpppre-commit-ci[bot]henryiii
authored
feat: support variable rebinning (#913)
* feat: support full UHI for rebinning * One type of axis at time * style: pre-commit fixes * Revert accidental deletion * Better code quality * fix: partial fix Signed-off-by: Henry Schreiner <[email protected]> * Refactor and move the rebinning logic in the loop * Fix repr test * Update src/boost_histogram/tag.py * Fix rebinning logic * Add logic for updating bin contents * make it work for nd hists * fix: support callable, add validation Signed-off-by: Henry Schreiner <[email protected]> * fix: the result of group_mapping() should be checked for None Signed-off-by: Henry Schreiner <[email protected]> --------- Signed-off-by: Henry Schreiner <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Henry Schreiner <[email protected]>
1 parent 5d04fa0 commit 92df5a6

File tree

5 files changed

+153
-24
lines changed

5 files changed

+153
-24
lines changed

.gitmodules

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
[submodule "pybind11"]
2-
path = extern/pybind11
3-
url = ../../pybind/pybind11.git
4-
[submodule "extern/boosthistogram"]
1+
[submodule "extern/histogram"]
52
path = extern/histogram
63
url = ../../boostorg/histogram.git
74
[submodule "extern/core"]

src/boost_histogram/_internal/hist.py

+53-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from boost_histogram import _core
2929

3030
from .axestuple import AxesTuple
31-
from .axis import Axis
31+
from .axis import Axis, Variable
3232
from .enum import Kind
3333
from .storage import Double, Storage
3434
from .typing import Accumulator, ArrayLike, CppHistogram
@@ -827,6 +827,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
827827
slices: list[_core.algorithm.reduce_command] = []
828828
pick_each: dict[int, int] = {}
829829
pick_set: dict[int, list[int]] = {}
830+
reduced: CppHistogram | None = None
830831

831832
# Compute needed slices and projections
832833
for i, ind in enumerate(indexes):
@@ -855,38 +856,82 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
855856
# This ensures that callable start/stop are handled
856857
start, stop = self.axes[i]._process_loc(ind.start, ind.stop)
857858

859+
groups = []
858860
if ind != slice(None):
859861
merge = 1
860862
if ind.step is not None:
861-
if hasattr(ind.step, "factor"):
863+
if getattr(ind.step, "factor", None) is not None:
862864
merge = ind.step.factor
865+
elif (
866+
hasattr(ind.step, "group_mapping")
867+
and (tmp_groups := ind.step.group_mapping(self.axes[i]))
868+
is not None
869+
):
870+
groups = tmp_groups
863871
elif callable(ind.step):
864872
if ind.step is sum:
865873
integrations.add(i)
866874
else:
867-
raise RuntimeError("Full UHI not supported yet")
875+
raise NotImplementedError
868876

869877
if ind.start is not None or ind.stop is not None:
870878
slices.append(
871879
_core.algorithm.slice(
872880
i, start, stop, _core.algorithm.slice_mode.crop
873881
)
874882
)
875-
continue
883+
if len(groups) == 0:
884+
continue
876885
else:
877886
raise IndexError(
878887
"The third argument to a slice must be rebin or projection"
879888
)
880889

881890
assert isinstance(start, int)
882891
assert isinstance(stop, int)
883-
slices.append(_core.algorithm.slice_and_rebin(i, start, stop, merge))
892+
# rebinning with factor
893+
if len(groups) == 0:
894+
slices.append(
895+
_core.algorithm.slice_and_rebin(i, start, stop, merge)
896+
)
897+
# rebinning with groups
898+
elif len(groups) != 0:
899+
if not reduced:
900+
reduced = self._hist
901+
axes = [reduced.axis(x) for x in range(reduced.rank())]
902+
reduced_view = reduced.view(flow=True)
903+
new_axes_indices = [axes[i].edges[0]]
904+
905+
j = 0
906+
for group in groups:
907+
new_axes_indices += [axes[i].edges[j + group]]
908+
j += group
909+
910+
variable_axis = Variable(
911+
new_axes_indices, metadata=axes[i].metadata
912+
)
913+
axes[i] = variable_axis._ax
914+
915+
logger.debug("Axes: %s", axes)
916+
917+
new_reduced = reduced.__class__(axes)
918+
new_view = new_reduced.view(flow=True)
919+
920+
j = 1
921+
for new_j, group in enumerate(groups):
922+
for _ in range(group):
923+
pos = [slice(None)] * (i)
924+
new_view[(*pos, new_j + 1, ...)] += reduced_view[ # type: ignore[arg-type]
925+
(*pos, j, ...) # type: ignore[arg-type]
926+
]
927+
j += 1
928+
929+
reduced = new_reduced
884930

885931
# Will be updated below
886-
if slices or pick_set or pick_each or integrations:
932+
if (slices or pick_set or pick_each or integrations) and not reduced:
887933
reduced = self._hist
888-
else:
889-
logger.debug("Reduce actions are all empty, just making a copy")
934+
elif not reduced:
890935
reduced = copy.copy(self._hist)
891936

892937
if pick_each:

src/boost_histogram/tag.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
import copy
66
from builtins import sum
7-
from typing import TypeVar
7+
from typing import TYPE_CHECKING, Sequence, TypeVar
8+
9+
if TYPE_CHECKING:
10+
from uhi.typing.plottable import PlottableAxis
811

912
from ._internal.typing import AxisLike
1013

@@ -108,12 +111,40 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002
108111

109112

110113
class rebin:
111-
__slots__ = ("factor",)
112-
113-
def __init__(self, value: int) -> None:
114-
self.factor = value
114+
__slots__ = (
115+
"factor",
116+
"groups",
117+
)
118+
119+
def __init__(
120+
self,
121+
factor: int | None = None,
122+
*,
123+
groups: Sequence[int] | None = None,
124+
) -> None:
125+
if not sum(i is None for i in [factor, groups]) == 1:
126+
raise ValueError("Exactly one, a factor or groups should be provided")
127+
self.factor = factor
128+
self.groups = groups
115129

116130
def __repr__(self) -> str:
117-
return f"{self.__class__.__name__}({self.factor})"
118-
119-
# TODO: Add __call__ to support UHI
131+
repr_str = f"{self.__class__.__name__}"
132+
args: dict[str, int | Sequence[int] | None] = {
133+
"factor": self.factor,
134+
"groups": self.groups,
135+
}
136+
for k, v in args.items():
137+
if v is not None:
138+
return_str = f"{repr_str}({k}={v})"
139+
break
140+
return return_str
141+
142+
def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
143+
if self.groups is not None:
144+
if sum(self.groups) != len(axis):
145+
msg = f"The sum of the groups ({sum(self.groups)}) must be equal to the number of bins in the axis ({len(axis)})"
146+
raise ValueError(msg)
147+
return self.groups
148+
if self.factor is not None:
149+
return [self.factor] * len(axis)
150+
raise ValueError("No rebinning factor or groups provided")

tests/test_histogram.py

+59-3
Original file line numberDiff line numberDiff line change
@@ -632,13 +632,17 @@ def test_shrink_1d():
632632

633633
def test_rebin_1d():
634634
h = bh.Histogram(bh.axis.Regular(20, 1, 5))
635-
h.fill(1.1)
635+
h.fill([1.1, 2.2, 3.3, 4.4])
636636

637637
hs = h[{0: slice(None, None, bh.rebin(4))}]
638-
assert_array_equal(hs.view(), [1, 0, 0, 0, 0])
638+
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])
639639

640640
hs = h[{0: bh.rebin(4)}]
641-
assert_array_equal(hs.view(), [1, 0, 0, 0, 0])
641+
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])
642+
643+
hs = h[{0: bh.rebin(groups=[1, 2, 3, 14])}]
644+
assert_array_equal(hs.view(), [1, 0, 0, 3])
645+
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
642646

643647

644648
def test_shrink_rebin_1d():
@@ -659,8 +663,60 @@ def test_rebin_nd():
659663
assert h[{1: s[:: bh.rebin(2)]}].axes.size == (20, 15, 40)
660664
assert h[{2: s[:: bh.rebin(2)]}].axes.size == (20, 30, 20)
661665

666+
assert h[{0: s[:: bh.rebin(groups=[1, 2, 17])]}].axes.size == (3, 30, 40)
667+
assert h[{1: s[:: bh.rebin(groups=[1, 2, 27])]}].axes.size == (20, 3, 40)
668+
assert h[{2: s[:: bh.rebin(groups=[1, 2, 37])]}].axes.size == (20, 30, 3)
669+
assert np.all(
670+
np.isclose(
671+
h[{0: s[:: bh.rebin(groups=[1, 2, 17])]}].axes[0].edges,
672+
[1.0, 1.1, 1.3, 3.0],
673+
)
674+
)
675+
assert np.all(
676+
np.isclose(
677+
h[{1: s[:: bh.rebin(groups=[1, 2, 27])]}].axes[1].edges,
678+
[1.0, 1.06666667, 1.2, 3.0],
679+
)
680+
)
681+
assert np.all(
682+
np.isclose(
683+
h[{2: s[:: bh.rebin(groups=[1, 2, 37])]}].axes[2].edges,
684+
[1.0, 1.05, 1.15, 3.0],
685+
)
686+
)
687+
662688
assert h[{0: s[:: bh.rebin(2)], 2: s[:: bh.rebin(2)]}].axes.size == (10, 30, 20)
663689

690+
assert h[
691+
{0: s[:: bh.rebin(groups=[1, 2, 17])], 2: s[:: bh.rebin(groups=[1, 2, 37])]}
692+
].axes.size == (3, 30, 3)
693+
assert np.all(
694+
np.isclose(
695+
h[
696+
{
697+
0: s[:: bh.rebin(groups=[1, 2, 17])],
698+
2: s[:: bh.rebin(groups=[1, 2, 37])],
699+
}
700+
]
701+
.axes[0]
702+
.edges,
703+
[1.0, 1.1, 1.3, 3],
704+
)
705+
)
706+
assert np.all(
707+
np.isclose(
708+
h[
709+
{
710+
0: s[:: bh.rebin(groups=[1, 2, 17])],
711+
2: s[:: bh.rebin(groups=[1, 2, 37])],
712+
}
713+
]
714+
.axes[2]
715+
.edges,
716+
[1.0, 1.05, 1.15, 3.0],
717+
)
718+
)
719+
664720
assert h[{1: s[:: bh.sum]}].axes.size == (20, 40)
665721
assert h[{1: bh.sum}].axes.size == (20, 40)
666722

tests/test_histogram_indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def test_repr():
240240
assert repr(bh.overflow + 1) == "overflow + 1"
241241
assert repr(bh.overflow - 1) == "overflow - 1"
242242

243-
assert repr(bh.rebin(2)) == "rebin(2)"
243+
assert repr(bh.rebin(2)) == "rebin(factor=2)"
244244

245245

246246
# Was broken in 0.6.1

0 commit comments

Comments
 (0)