Skip to content

Commit 4ee1a2d

Browse files
authored
feat: support reading of custom-length RNTuple floats and suppressed columns (#1347)
* Started implementing reading of quantized and truncated floats
* Added support for suppressed columns
* Added tests
* Cleaner reading of floats with 1, 2, or 3 bytes
* Only support little-endian systems
* Fixed bug with Numpy 1
* Improved tests
* Fixed tests for Numpy 2
1 parent 2ba58f2 commit 4ee1a2d

File tree

3 files changed

+225
-3
lines changed

3 files changed

+225
-3
lines changed

src/uproot/const.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@
242242
rntuple_col_type_to_num_dict["splitindex32"],
243243
rntuple_col_type_to_num_dict["splitindex64"],
244244
)
245+
rntuple_custom_float_types = (
246+
rntuple_col_type_to_num_dict["real32trunc"],
247+
rntuple_col_type_to_num_dict["real32quant"],
248+
)
245249

246250

247251
class RNTupleLocatorType(IntEnum):

src/uproot/models/RNTuple.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import struct
9+
import sys
910
from collections import defaultdict
1011
from itertools import accumulate
1112

@@ -241,6 +242,11 @@ def read_members(self, chunk, cursor, context, file):
241242
f"""memberwise serialization of {type(self).__name__}
242243
in file {self.file.file_path}"""
243244
)
245+
# Probably no one will encounter this, but just in case something doesn't work correctly
246+
if sys.byteorder != "little":
247+
raise NotImplementedError(
248+
"RNTuple reading is only supported on little-endian systems"
249+
)
244250

245251
(
246252
self._members["fVersionEpoch"],
@@ -524,6 +530,8 @@ def base_col_form(self, cr, col_id, parameters=None, cardinality=False):
524530
dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
525531
if dt_str == "bit":
526532
dt_str = "bool"
533+
elif dtype_byte in uproot.const.rntuple_custom_float_types:
534+
dt_str = "float32"
527535
return ak.forms.NumpyForm(
528536
dt_str,
529537
form_key=form_key,
@@ -546,6 +554,8 @@ def col_form(self, field_id):
546554
)
547555

548556
rel_crs = self._column_records_dict[cfid]
557+
# for this part we can use the default (zeroth) representation
558+
rel_crs = [c for c in rel_crs if c.repr_idx == 0]
549559

550560
if len(rel_crs) == 1: # base case
551561
cardinality = "RNTupleCardinality" in self.field_records[field_id].type_name
@@ -673,9 +683,14 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split):
673683
context = {}
674684
# bool in RNTuple is always stored as bits
675685
isbit = dtype_str == "bit"
676-
len_divider = 8 if isbit else 1
677686
num_elements = len(destination)
678-
num_elements_toread = int(numpy.ceil(num_elements / len_divider))
687+
if isbit:
688+
num_elements_toread = int(numpy.ceil(num_elements / 8))
689+
elif dtype_str in ("real32trunc", "real32quant"):
690+
num_elements_toread = int(numpy.ceil((num_elements * 4 * nbits) / 32))
691+
dtype = numpy.dtype("uint8")
692+
else:
693+
num_elements_toread = num_elements
679694
uncomp_size = num_elements_toread * dtype.itemsize
680695
decomp_chunk, cursor = self.read_locator(loc, uncomp_size, context)
681696
content = cursor.array(
@@ -722,6 +737,23 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split):
722737
.reshape(-1, 8)[:, ::-1]
723738
.reshape(-1)
724739
)
740+
elif dtype_str in ("real32trunc", "real32quant"):
741+
if nbits == 32:
742+
content = content.view(numpy.uint32)
743+
elif nbits % 8 == 0:
744+
new_content = numpy.zeros((num_elements, 4), numpy.uint8)
745+
nbytes = nbits // 8
746+
new_content[:, :nbytes] = content.reshape(-1, nbytes)
747+
content = new_content.view(numpy.uint32).reshape(-1)
748+
else:
749+
ak = uproot.extras.awkward()
750+
vm = ak.forth.ForthMachine32(
751+
f"""input x output y uint32 {num_elements} x #{nbits}bit-> y"""
752+
)
753+
vm.run({"x": content})
754+
content = vm["y"]
755+
if dtype_str == "real32trunc":
756+
content <<= 32 - nbits
725757

726758
# needed to chop off extra bits in case we used `unpackbits`
727759
destination[:] = content[:num_elements]
@@ -754,6 +786,10 @@ def read_col_pages(
754786

755787
def read_col_page(self, ncol, cluster_i):
756788
linklist = self.page_list_envelopes.pagelinklist[cluster_i]
789+
# Check if the column is suppressed and pick the non-suppressed one if so
790+
if ncol < len(linklist) and linklist[ncol].suppressed:
791+
rel_crs = self._column_records_dict[self.column_records[ncol].field_id]
792+
ncol = next(cr.idx for cr in rel_crs if not linklist[cr.idx].suppressed)
757793
pagelist = linklist[ncol].pages if ncol < len(linklist) else []
758794
dtype_byte = self.column_records[ncol].type
759795
dtype_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
@@ -762,14 +798,20 @@ def read_col_page(self, ncol, cluster_i):
762798
dtype = numpy.dtype([("index", "int64"), ("tag", "int32")])
763799
elif dtype_str == "bit":
764800
dtype = numpy.dtype("bool")
801+
elif dtype_byte in uproot.const.rntuple_custom_float_types:
802+
dtype = numpy.dtype("uint32") # for easier bit manipulation
765803
else:
766804
dtype = numpy.dtype(dtype_str)
767805
res = numpy.empty(total_len, dtype)
768806
split = dtype_byte in uproot.const.rntuple_split_types
769807
zigzag = dtype_byte in uproot.const.rntuple_zigzag_types
770808
delta = dtype_byte in uproot.const.rntuple_delta_types
771809
index = dtype_byte in uproot.const.rntuple_index_types
772-
nbits = uproot.const.rntuple_col_num_to_size_dict[dtype_byte]
810+
nbits = (
811+
self.column_records[ncol].nbits
812+
if ncol < len(self.column_records)
813+
else uproot.const.rntuple_col_num_to_size_dict[dtype_byte]
814+
)
773815
tracker = 0
774816
cumsum = 0
775817
for page_desc in pagelist:
@@ -789,6 +831,15 @@ def read_col_page(self, ncol, cluster_i):
789831
res = _from_zigzag(res)
790832
elif delta:
791833
res = numpy.cumsum(res)
834+
elif dtype_str == "real32trunc":
835+
res = res.view(numpy.float32)
836+
elif dtype_str == "real32quant" and ncol < len(self.column_records):
837+
min_value = self.column_records[ncol].min_value
838+
max_value = self.column_records[ncol].max_value
839+
res = min_value + res.astype(numpy.float32) * (max_value - min_value) / (
840+
(1 << nbits) - 1
841+
)
842+
res = res.astype(numpy.float32)
792843
return res
793844

794845
def arrays(
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE
2+
3+
import skhep_testdata
4+
import numpy as np
5+
6+
import uproot
7+
8+
9+
def truncate_float(value, bits):
    """Return ``value`` as a float32 keeping only its ``bits`` leading bits.

    The value is converted to ``numpy.float32``, its 32-bit pattern is
    reinterpreted as an unsigned integer, the lowest ``32 - bits`` bits are
    cleared, and the pattern is reinterpreted as ``numpy.float32`` again.

    Args:
        value: Number convertible to ``numpy.float32``.
        bits: Number of most significant bits of the pattern to keep (1-32).

    Returns:
        The truncated value as a ``numpy.float32``.

    Raises:
        ValueError: If ``bits`` is not between 1 and 32.
    """
    if not 1 <= bits <= 32:
        raise ValueError(f"bits must be between 1 and 32, got {bits}")
    # Build the mask with plain Python integers and convert it once: shifting
    # a NumPy uint32 scalar promotes differently across NumPy major versions,
    # which would otherwise require casting the result back afterwards.
    mask = np.uint32((0xFFFFFFFF << (32 - int(bits))) & 0xFFFFFFFF)
    pattern = np.float32(value).view(np.uint32)
    # uint32 & uint32 stays uint32, so the view below is always 4 bytes wide.
    return (pattern & mask).view(np.float32)
13+
14+
15+
def quantize_float(value, bits, min, max):
    """Round ``value`` to the nearest of ``2**bits`` evenly spaced levels.

    The levels span the closed interval ``[min, max]``. ``value`` is mapped
    to an integer level index in ``0 .. 2**bits - 1``, rounded with
    ``numpy.round``, and mapped back to a float.

    Args:
        value: Number to quantize; must lie within ``[min, max]``.
        bits: Number of bits of the integer level index (at least 1).
        min: Lower edge of the range (converted to ``numpy.float32``).
        max: Upper edge of the range (converted to ``numpy.float32``).

    Returns:
        The quantized value as a ``numpy.float32``.

    Raises:
        ValueError: If ``bits`` is less than 1, if ``min`` is not strictly
            less than ``max``, or if ``value`` is outside ``[min, max]``.
    """
    # NOTE: the parameter names shadow the ``min``/``max`` builtins; they are
    # kept unchanged so existing positional and keyword callers keep working.
    min = np.float32(min)
    max = np.float32(max)
    if bits < 1:
        raise ValueError(f"bits must be at least 1, got {bits}")
    # An empty or inverted range would divide by zero (or flip the scale)
    # below and silently produce NaN/inf, so reject it up front.
    if not min < max:
        raise ValueError(f"Invalid range [{min}, {max}]: min must be less than max")
    if value < min or value > max:
        raise ValueError(f"Value {value} is out of range [{min}, {max}]")
    # Highest level index; shared by the forward and the inverse mapping.
    levels = (1 << bits) - 1
    scaled_value = (value - min) * levels / (max - min)
    int_value = np.round(scaled_value)
    quantized_float = min + int_value * (max - min) / levels
    return quantized_float.astype(np.float32)
24+
25+
26+
def test_custom_floats():
    """Compare truncated and quantized float columns with reference values.

    Reads the ``ntuple`` object from ``test_float_types_rntuple_v1-0-0-0.root``
    and checks its first four entries. For every entry the ``trunc10``,
    ``trunc16``, ``trunc24`` and ``trunc31`` fields must equal
    ``truncate_float`` of the reference value exactly, and the ``quant1``,
    ``quant8``, ``quant16``, ``quant20``, ``quant24``, ``quant25`` and
    ``quant32`` fields must be close to ``quantize_float`` of the reference
    value over the range ``[-2.0, 3.0]``.
    """
    filename = skhep_testdata.data_path("test_float_types_rntuple_v1-0-0-0.root")
    with uproot.open(filename) as f:
        obj = f["ntuple"]

        arrays = obj.arrays()

        min_value = -2.0
        max_value = 3.0

        trunc_bits = (10, 16, 24, 31)
        quant_bits = (1, 8, 16, 20, 24, 25, 32)

        # One row per entry: (reference for trunc* fields, reference for
        # quant* fields). Entry 1 uses two different references; presumably
        # because its truncated value lies outside the quantization range.
        references = (
            (1.23456789, 1.23456789),
            (1.4660155e13, 1.6666666),
            (-6.2875986e-22, -6.2875986e-22),
            (-1.9060668, -1.9060668),
        )

        # Extra numpy.isclose keyword arguments for individual comparisons,
        # keyed by (entry index, bit width); all others use the defaults.
        isclose_overrides = {(2, 25): {"atol": 2e-07}}

        for index, (trunc_value, quant_value) in enumerate(references):
            entry = arrays[index]

            for bits in trunc_bits:
                name = f"trunc{bits}"
                assert entry[name] == truncate_float(
                    trunc_value, bits
                ), f"entry {index}, field {name}"

            for bits in quant_bits:
                name = f"quant{bits}"
                expected = quantize_float(quant_value, bits, min_value, max_value)
                options = isclose_overrides.get((index, bits), {})
                assert np.isclose(
                    entry[name], expected, **options
                ), f"entry {index}, field {name}"
150+
151+
152+
def test_multiple_representations():
    """Check reading a column whose representation is suppressed in a cluster.

    Opens ``test_multiple_representations_rntuple_v1-0-0-0.root``, verifies
    that the page link list has three clusters and that column 0 is suppressed
    only in the middle one, then checks that the ``real`` field still reads
    back as ``[1, 2, 3]``.
    """
    path = skhep_testdata.data_path(
        "test_multiple_representations_rntuple_v1-0-0-0.root"
    )
    with uproot.open(path) as root_file:
        ntuple = root_file["ntuple"]

        cluster_links = ntuple.page_list_envelopes.pagelinklist
        assert len(cluster_links) == 3

        # The zeroth representation is active in clusters 0 and 2, but not in cluster 1
        zeroth_suppressed = [bool(cluster[0].suppressed) for cluster in cluster_links]
        assert zeroth_suppressed == [False, True, False]

        values = ntuple.arrays()

        assert np.allclose(values["real"], [1, 2, 3])

0 commit comments

Comments
 (0)