Skip to content

Commit 4ded101

Browse files
committed
Implemented iterate method
1 parent 61fc792 commit 4ded101

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

src/uproot/models/RNTuple.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,55 @@ def _arrays(
133133
)[entry_start:entry_stop]
134134

135135

136+
def _num_entries_for(in_ntuple, target_num_bytes, filter_name):
    """
    Estimate how many entries add up to roughly ``target_num_bytes``.

    The estimate sums ``locator.num_bytes`` over every page of the selected
    columns in all clusters, then scales the total number of entries by
    ``target_num_bytes / total_bytes``.

    Args:
        in_ntuple: Object exposing ``ntuple`` (with ``num_entries``,
            ``cluster_summaries`` and ``page_list_envelopes``) and
            ``to_akform()``.
        target_num_bytes (int): Desired size of one batch, in bytes.
        filter_name: Column selection, passed to ``select_columns`` on the
            Awkward form.

    Returns:
        int: Estimated number of entries per batch; never less than 1.
    """
    # TODO: part of this is also done in _arrays, so we should refactor this
    # TODO: there might be a better way to estimate the number of entries
    ntuple = in_ntuple.ntuple
    total_entries = ntuple.num_entries

    clusters = ntuple.cluster_summaries
    cluster_starts = numpy.array([c.num_first_entry for c in clusters])

    # The estimate always covers the full entry range [0, total_entries].
    # Clamp the first index at 0: when there are no clusters, searchsorted
    # returns 0 and the resulting -1 would index the page list from its end
    # instead of skipping the loop.
    start_cluster_idx = max(
        int(numpy.searchsorted(cluster_starts, 0, side="right")) - 1, 0
    )
    stop_cluster_idx = int(
        numpy.searchsorted(cluster_starts, total_entries, side="right")
    )

    form = in_ntuple.to_akform().select_columns(
        filter_name, prune_unions_and_records=False
    )
    # NOTE(review): _recursive_find presumably appends the form keys of the
    # selected columns to target_cols in place -- confirm against its definition.
    target_cols = []
    _recursive_find(form, target_cols)

    total_bytes = 0
    for key in target_cols:
        # Only keys naming a plain column are counted; the integer after the
        # first "-" is used as the column index into each cluster's page list.
        if "column" in key and "union" not in key:
            key_nr = int(key.split("-")[1])
            for cluster in range(start_cluster_idx, stop_cluster_idx):
                pages = ntuple.page_list_envelopes.pagelinklist[cluster][key_nr]
                total_bytes += sum(page.locator.num_bytes for page in pages)

    if total_bytes == 0:
        # Nothing to scale by; fall back to the smallest usable step.
        return 1
    # A step of zero (or less) would be unusable as a range() step.
    return max(1, int(round(target_num_bytes * total_entries / total_bytes)))
172+
173+
174+
def _regularize_step_size(in_ntuple, step_size, filter_name):
    """
    Convert ``step_size`` into a number of entries.

    Args:
        in_ntuple: Object handed to ``_num_entries_for`` when a memory size
            has to be converted into an entry count.
        step_size: Either an integer (returned unchanged) or a value that
            ``uproot._util.memory_size`` can interpret, such as ``"100 MB"``.
        filter_name: Column selection handed to ``_num_entries_for``.

    Returns:
        The integer ``step_size`` itself, or the entry count estimated by
        ``_num_entries_for`` for the requested number of bytes.
    """
    if not uproot._util.isint(step_size):
        # Not an entry count: interpret it as a memory size. The second
        # argument is the error text handed to ``memory_size``.
        byte_budget = uproot._util.memory_size(
            step_size,
            "number of entries or memory size string with units "
            f"(such as '100 MB') required, not {step_size!r}",
        )
        return _num_entries_for(in_ntuple, byte_budget, filter_name)
    return step_size
183+
184+
136185
class Model_ROOT_3a3a_Experimental_3a3a_RNTuple(uproot.model.Model):
137186
"""
138187
A versionless :doc:`uproot.model.Model` for ``ROOT::Experimental::RNTuple``.
@@ -742,6 +791,13 @@ def arrays(
742791
array_cache=array_cache,
743792
)
744793

794+
def iterate(self, filter_name="*", *args, step_size="100 MB", **kwargs):
    """
    Yield arrays for consecutive entry ranges of ``step_size`` entries.

    Args:
        filter_name: Column selection. It is used to size the batches when
            ``step_size`` is a memory size, and it is forwarded to
            ``self.arrays`` so that the data read matches the data the
            batch size was estimated for.
        *args: Extra positional arguments forwarded to ``self.arrays``.
        step_size (int or str): Number of entries per batch, or a memory
            size string with units such as ``"100 MB"``.
        **kwargs: Extra keyword arguments forwarded to ``self.arrays``.

    Yields:
        The result of ``self.arrays`` for each range
        ``[start, start + step_size)``, with ``start`` running from 0 up to
        ``self.num_entries``.
    """
    step_size = _regularize_step_size(self, step_size, filter_name)
    for start in range(0, self.num_entries, step_size):
        # Forward filter_name: previously it only influenced the step size
        # and every column was read regardless of the selection.
        # NOTE(review): assumes ``arrays`` accepts ``filter_name`` as a
        # keyword argument -- confirm against its signature.
        yield self.arrays(
            *args,
            filter_name=filter_name,
            entry_start=start,
            entry_stop=start + step_size,
            **kwargs,
        )
800+
745801

746802
# Supporting function and classes
747803
def _split_switch_bits(content):
@@ -1215,6 +1271,13 @@ def __array__(self, *args, **kwargs):
12151271
else:
12161272
return numpy.array(out, *args, **kwargs)
12171273

1274+
def iterate(self, filter_name="*", *args, step_size="100 MB", **kwargs):
    """
    Yield this object's array for consecutive entry ranges.

    Args:
        filter_name: Only used to size the batches when ``step_size`` is a
            memory size; it is not passed on to ``self.array``.
        *args: Extra positional arguments forwarded to ``self.array``.
        step_size (int or str): Number of entries per batch, or a memory
            size string with units such as ``"100 MB"``.
        **kwargs: Extra keyword arguments forwarded to ``self.array``.

    Yields:
        The result of ``self.array`` for each range
        ``[first, first + entries_per_step)``, with ``first`` running from 0
        up to ``self.ntuple.num_entries``.
    """
    entries_per_step = _regularize_step_size(self, step_size, filter_name)
    batch_starts = range(0, self.ntuple.num_entries, entries_per_step)
    for first in batch_starts:
        last = first + entries_per_step
        yield self.array(*args, entry_start=first, entry_stop=last, **kwargs)
1280+
12181281

12191282
uproot.classes["ROOT::Experimental::RNTuple"] = (
12201283
Model_ROOT_3a3a_Experimental_3a3a_RNTuple

tests/test_1250_rntuple_improvements.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,49 @@ def test_array_methods():
3737
obj = f["RNT:CollectionTree"]
3838
jets = obj["AntiKt4TruthDressedWZJetsAux:"].arrays()
3939
assert len(jets[0].pt) == 5
40+
41+
42+
def test_iterate():
    """
    Check ``iterate`` on a whole RNTuple and on a single field.

    Both an integer ``step_size`` and a memory-size string are exercised.
    The full sequence of chunk lengths is asserted, so the test fails if
    ``iterate`` yields too few, too many, or no chunks at all.
    """
    filename = skhep_testdata.data_path(
        "Run2012BC_DoubleMuParked_Muons_rntuple_1000evts.root"
    )
    expected_pt = [10.763696670532227, 15.736522674560547]
    expected_charge = [-1, -1]

    with uproot.open(filename) as f:
        obj = f["Events"]

        # Whole ntuple, fixed number of entries per chunk.
        lengths = []
        for i, arrays in enumerate(obj.iterate(step_size=100)):
            lengths.append(len(arrays))
            if i == 0:
                assert arrays["Muon_pt"][0].tolist() == expected_pt
                assert arrays["Muon_charge"][0].tolist() == expected_charge
        assert lengths == [100] * 10

        # Whole ntuple, chunk size derived from a memory budget.
        lengths = []
        for i, arrays in enumerate(obj.iterate(step_size="10 kB")):
            lengths.append(len(arrays))
            if i == 0:
                assert arrays["Muon_pt"][0].tolist() == expected_pt
                assert arrays["Muon_charge"][0].tolist() == expected_charge
        assert lengths == [363, 363, 274]

        Muon_pt = obj["Muon_pt"]

        # Single field, fixed number of entries per chunk.
        lengths = []
        for i, arrays in enumerate(Muon_pt.iterate(step_size=100)):
            lengths.append(len(arrays))
            if i == 0:
                assert arrays[0].tolist() == expected_pt
        assert lengths == [100] * 10

        # Single field, chunk size derived from a memory budget.
        lengths = []
        for i, arrays in enumerate(Muon_pt.iterate(step_size="5 kB")):
            lengths.append(len(arrays))
            if i == 0:
                assert arrays[0].tolist() == expected_pt
        assert lengths == [611, 389]

0 commit comments

Comments
 (0)