Skip to content

Commit d0d05b2

Browse files
authored
Also improve scalar readers by caching (#261)
Signed-off-by: Christian Vetter <christian.vetter@here.com>
1 parent 70b9050 commit d0d05b2

3 files changed

Lines changed: 144 additions & 61 deletions

File tree

flatdata-py/flatdata/lib/data_access.py

Lines changed: 106 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,61 +9,71 @@
99
# Lookup table indexed by field width in bits: _SIGN_BITS[n] == 1 << (n - 1),
# i.e. the mask of the most significant (sign) bit of an n-bit value, for
# n in 1..64. Index 0 holds 0 so that the width can be used as the index directly.
_SIGN_BITS = [0] + [(1 << (bits - 1)) for bits in range(1, 65)]
1010

1111

12-
def read_value(data, offset_bits, num_bits, is_signed):
13-
offset_bytes, offset_extra_bits = divmod(offset_bits, 8)
14-
total_bytes = (num_bits + 7) // 8
15-
16-
if num_bits == 1:
17-
return int((data[offset_bytes] & (1 << offset_extra_bits)) != 0)
18-
19-
result = int.from_bytes(data[offset_bytes: offset_bytes + total_bytes], byteorder="little")
20-
result >>= offset_extra_bits
21-
if (total_bytes * 8 - offset_extra_bits) < num_bits:
22-
remainder = data[offset_bytes + total_bytes]
23-
result |= remainder << (total_bytes * 8 - offset_extra_bits)
12+
def make_field_reader(offset_bits, num_bits, is_signed):
    """Create a reader closure specialised for one bit-packed structure field.

    The returned callable has the signature ``reader(data, pos)``: ``data`` is
    the byte buffer and ``pos`` the byte position at which the structure
    starts. Everything that depends only on the field layout (byte range, bit
    shift, masks, sign handling) is computed once here, so each call of the
    closure only slices, shifts and masks.

    Args:
        offset_bits: Bit offset of the field from the start of the structure.
        num_bits: Width of the field in bits.
        is_signed: If true, the field is decoded as a two's-complement value.
            Single-bit fields are always returned as 0 or 1.

    Returns:
        A function ``reader(data, pos)`` returning the field value as an int.
    """
    first_byte, shift = divmod(offset_bits, 8)
    width_bytes = (num_bits + 7) // 8
    stop_byte = first_byte + width_bytes
    value_mask = (1 << num_bits) - 1
    # Number of field bits left in the ``width_bytes``-byte slice once the
    # leading ``shift`` bits are dropped. When that is fewer than ``num_bits``
    # the remaining high bits live in the byte at ``stop_byte``.
    bits_in_slice = width_bytes * 8 - shift
    spills_over = bits_in_slice < num_bits

    if num_bits == 1:
        single_bit = 1 << shift

        def reader(data, pos):
            return int((data[pos + first_byte] & single_bit) != 0)

        return reader

    if is_signed:
        # Looked up only for signed fields; unsigned readers never touch it.
        top_bit = _SIGN_BITS[num_bits]
        magnitude_mask = top_bit - 1

    # One closure per (layout, signedness) combination so that the hot path
    # contains no branches.
    if spills_over:
        if is_signed:
            def reader(data, pos):
                raw = int.from_bytes(
                    data[pos + first_byte: pos + stop_byte], byteorder="little")
                raw >>= shift
                raw |= data[pos + stop_byte] << bits_in_slice
                raw &= value_mask
                # Two's complement: subtract the weight of the sign bit.
                return (raw & magnitude_mask) - (raw & top_bit)
        else:
            def reader(data, pos):
                raw = int.from_bytes(
                    data[pos + first_byte: pos + stop_byte], byteorder="little")
                raw >>= shift
                raw |= data[pos + stop_byte] << bits_in_slice
                return raw & value_mask
    elif shift:
        if is_signed:
            def reader(data, pos):
                raw = (int.from_bytes(
                    data[pos + first_byte: pos + stop_byte],
                    byteorder="little") >> shift) & value_mask
                return (raw & magnitude_mask) - (raw & top_bit)
        else:
            def reader(data, pos):
                return (int.from_bytes(
                    data[pos + first_byte: pos + stop_byte],
                    byteorder="little") >> shift) & value_mask
    else:
        if is_signed:
            def reader(data, pos):
                raw = int.from_bytes(
                    data[pos + first_byte: pos + stop_byte],
                    byteorder="little") & value_mask
                return (raw & magnitude_mask) - (raw & top_bit)
        else:
            def reader(data, pos):
                return int.from_bytes(
                    data[pos + first_byte: pos + stop_byte],
                    byteorder="little") & value_mask
    return reader
6777

6878

6979
def read_field_vectorized(raw_bytes_2d, field_offset_bits, field_width_bits, is_signed):
@@ -110,3 +120,49 @@ def read_field_vectorized(raw_bytes_2d, field_offset_bits, field_width_bits, is_
110120
result = np.where(result & sign_bit, signed, result.astype(np.int64))
111121

112122
return result
123+
124+
125+
def read_value(data, offset_bits, num_bits, is_signed):
    """Decode one bit-packed value located ``offset_bits`` bits into ``data``.

    Thin one-shot helper on top of :func:`make_field_reader`: it builds a
    reader for the requested layout and applies it at byte position 0. When
    the same field is read many times, build the reader once with
    ``make_field_reader`` and call that instead.
    """
    return make_field_reader(offset_bits, num_bits, is_signed)(data, 0)
134+
135+
136+
def write_value(data, offset_bits, num_bits, is_signed, value):
    """Store ``value`` in a bit-packed field of ``data`` in place.

    Only the ``num_bits`` bits starting at ``offset_bits`` are modified; every
    other bit of the touched bytes keeps its previous content.

    Args:
        data: Mutable byte buffer supporting ``len``, slicing and item
            assignment (e.g. ``bytearray``).
        offset_bits: Bit offset of the field from the start of ``data``.
        num_bits: Width of the field in bits, at most 64.
        is_signed: Accepted for interface compatibility. The stored pattern is
            the low ``num_bits`` bits of ``value`` in two's complement either
            way, so negative values need no special handling.
        value: Integer to store. For single-bit fields the bit is set only
            when ``value == 1`` and cleared otherwise.

    Raises:
        AssertionError: If ``num_bits`` exceeds 64.
        IndexError: If the field does not fit inside ``data``.
    """
    assert num_bits <= 64, 'Number of bits to write is greater than 64'

    offset_bytes, offset_extra_bits = divmod(offset_bits, 8)

    if num_bits == 1:
        if value == 1:
            data[offset_bytes] |= 1 << offset_extra_bits
        else:
            data[offset_bytes] &= ~(1 << offset_extra_bits)
        return

    # Number of bytes the field touches, counting the partially used first
    # and last bytes.
    span = (offset_extra_bits + num_bits + 7) // 8
    end_byte = offset_bytes + span
    if end_byte > len(data):
        # Fail before modifying anything rather than after a partial write.
        raise IndexError("bit field extends past the end of the buffer")

    # Read-modify-write over the whole span: clear exactly the field's bits,
    # then merge in the new value. Masking with the in-byte offset (not the
    # absolute bit offset) keeps neighbouring fields intact for any offset.
    field_mask = ((1 << num_bits) - 1) << offset_extra_bits
    current = int.from_bytes(data[offset_bytes:end_byte], byteorder="little")
    updated = (current & ~field_mask) | ((value << offset_extra_bits) & field_mask)

    for index, byte in enumerate(updated.to_bytes(span, byteorder="little"), offset_bytes):
        data[index] = byte

flatdata-py/flatdata/lib/structure.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,33 @@
22
import json
33
import numpy as np
44

5-
from .data_access import read_value
5+
from .data_access import make_field_reader
66

77
FieldSignature = namedtuple(
88
"FieldSignature", ["offset", "width", "is_signed", "dtype"])
99

1010

1111
class Structure:
1212
__slots__ = ('_mem', '_pos')
13+
_READERS = {}
14+
15+
def __init_subclass__(cls, **kwargs):
    """Pre-build the field readers of a subclass that declares ``_FIELDS``.

    Only a ``_FIELDS`` mapping defined directly on the subclass is
    considered; a subclass without its own ``_FIELDS`` keeps whatever
    ``_READERS`` it inherits.
    """
    super().__init_subclass__(**kwargs)
    own_fields = cls.__dict__.get('_FIELDS')
    if own_fields is None:
        return
    # One reader closure per field, keyed by field name.
    cls._READERS = {
        field_name: make_field_reader(
            signature.offset, signature.width, signature.is_signed)
        for field_name, signature in own_fields.items()
    }
1321

1422
def __init__(self, mem, pos):
1523
self._mem = mem
1624
self._pos = pos
1725

1826
def __getattr__(self, name):
1927
try:
20-
field = self._FIELDS[name]
28+
reader = self._READERS[name]
2129
except KeyError:
2230
raise AttributeError("Field %s not found in structure" % name)
23-
return self._get_value(field)
24-
25-
def _get_value(self, field):
26-
return read_value(self._mem, self._pos * 8 + field.offset, field.width, field.is_signed)
31+
return reader(self._mem, self._pos)
2732

2833
def __dir__(self):
2934
return self._FIELD_KEYS
@@ -33,20 +38,24 @@ def __iter__(self):
3338
yield getattr(self, name)
3439

3540
def as_dict(self):
36-
return {name: self._get_value(field) for name, field in self._FIELDS.items()}
41+
mem, pos = self._mem, self._pos
42+
return {name: reader(mem, pos) for name, reader in self._READERS.items()}
3743

3844
def as_list(self):
39-
return [self._get_value(field) for field in self._FIELDS.values()]
45+
mem, pos = self._mem, self._pos
46+
return [reader(mem, pos) for reader in self._READERS.values()]
4047

4148
def as_tuple(self):
42-
return tuple(self._get_value(field) for field in self._FIELDS.values())
49+
mem, pos = self._mem, self._pos
50+
return tuple(reader(mem, pos) for reader in self._READERS.values())
4351

4452
@classmethod
4553
def dtype(cls):
4654
return [(name, np.dtype(field.dtype)) for name, field in cls._FIELDS.items()]
4755

4856
def as_nparray(self):
49-
return np.array([tuple(self._get_value(field) for name, field in self._FIELDS.items())],
57+
mem, pos = self._mem, self._pos
58+
return np.array([tuple(reader(mem, pos) for reader in self._READERS.values())],
5059
dtype=self.dtype())
5160

5261
def schema(self):

flatdata-py/tests/test_data_access.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from flatdata.lib.data_access import read_value, write_value
2+
from flatdata.lib.data_access import read_value, write_value, make_field_reader
33

44

55
def test_reader():
@@ -2264,3 +2264,21 @@ def _test_writer(buffer, offset, num_bits, is_signed, expected):
22642264
_test_writer(b"\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 3, 16, True, 8192)
22652265
_test_writer(b"\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 3, 16, True, 16384)
22662266
_test_writer(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 3, 2, True, 0)
2267+
2268+
2269+
def test_make_field_reader_with_nonzero_pos():
    """Reader closures must produce correct results at arbitrary byte positions.

    The same 9-byte pattern is placed at byte 0 and at byte 10 of a buffer.
    Every reader must return the same value at both positions, and that value
    must match an independent decode of the pattern, so a reader that ignores
    its input cannot pass.
    """
    struct_bytes = b'\xab\xcd\xef\x12\x98\x76\x54\x32\x10'
    data = bytearray(20)
    data[0:9] = struct_bytes
    data[10:19] = struct_bytes
    # Reference decode: the whole pattern as one little-endian integer.
    pattern = int.from_bytes(struct_bytes, byteorder="little")

    for offset_bits in [0, 3, 8, 13]:
        for num_bits in [1, 5, 8, 16, 32, 64]:
            if offset_bits + num_bits > len(struct_bytes) * 8:
                continue
            raw = (pattern >> offset_bits) & ((1 << num_bits) - 1)
            for is_signed in [False, True]:
                expected = raw
                # Single-bit fields are read as 0/1 even when signed; wider
                # signed fields use two's complement.
                if is_signed and num_bits > 1 and raw >> (num_bits - 1):
                    expected = raw - (1 << num_bits)
                reader = make_field_reader(offset_bits, num_bits, is_signed)
                assert reader(data, 0) == reader(data, 10), (
                    f"offset={offset_bits}, width={num_bits}, signed={is_signed}: "
                    f"pos=0 got {reader(data, 0)}, pos=10 got {reader(data, 10)}")
                assert reader(data, 0) == expected, (
                    f"offset={offset_bits}, width={num_bits}, signed={is_signed}: "
                    f"expected {expected}, got {reader(data, 0)}")

0 commit comments

Comments
 (0)