
Commit b67e16d

Merge pull request #8 from hpc4cmb/slicing
Correctly reshape decompressed array data from [] notation.
2 parents 9d098a5 + 0cf1900 commit b67e16d
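
The fix below in src/flacarray/array.py makes bracket ([]) slicing of a FlacArray return data with the same shape as the equivalent NumPy slice. A minimal usage sketch of the intended behavior, assuming the module path shown in this diff (src/flacarray/array.py) and int32 input like the new test uses; the stream axis still only accepts contiguous, stride-1 selections:

    import numpy as np
    from flacarray.array import FlacArray

    data = np.arange(4 * 3 * 100, dtype=np.int32).reshape((4, 3, 100))
    farray = FlacArray.from_array(data)

    # After this change, [] access returns an array whose shape matches the
    # equivalent numpy slice (the stream axis still requires a contiguous,
    # stride-1 selection).
    view = farray[1, slice(0, 2), slice(10, 20)]
    print(view.shape, data[1, 0:2, 10:20].shape)  # (2, 10) (2, 10)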

File tree: 4 files changed, +161 -57 lines changed

.github/workflows/test.yml
Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ jobs:
           && conda create --yes -n test python==${{ matrix.python }} \
           && conda activate test \
           && conda install --yes --file packaging/conda_build_requirements.txt
+          if test ${{ matrix.python }} = "3.9"; then conda install libxcrypt; fi

       - name: Install
         run: |

docs/docs/install.md
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ create an environment called "flacarray". First create the env with all
 dependencies and activate it (FIXME, add a requirements file for dev):

     conda create -n flacarray \
-        c_compiler numpy libflac cython meson-python pkgconfig
+        c-compiler numpy libflac cython meson-python pkgconfig

     conda activate flacarray


src/flacarray/array.py
Lines changed: 104 additions & 51 deletions

@@ -71,6 +71,7 @@ def __init__(
         shape=None,
         global_shape=None,
         compressed=None,
+        dtype=None,
         stream_starts=None,
         stream_nbytes=None,
         stream_offsets=None,
@@ -84,6 +85,7 @@ def __init__(
             self._shape = copy.deepcopy(other._shape)
             self._global_shape = copy.deepcopy(other._global_shape)
             self._compressed = copy.deepcopy(other._compressed)
+            self._dtype = np.dtype(other._dtype)
             self._stream_starts = copy.deepcopy(other._stream_starts)
             self._stream_nbytes = copy.deepcopy(other._stream_nbytes)
             self._stream_offsets = copy.deepcopy(other._stream_offsets)
@@ -97,6 +99,7 @@ def __init__(
             self._shape = shape
             self._global_shape = global_shape
             self._compressed = compressed
+            self._dtype = np.dtype(dtype)
             self._stream_starts = stream_starts
             self._stream_nbytes = stream_nbytes
             self._stream_offsets = stream_offsets
@@ -120,19 +123,23 @@ def _init_params(self):
         else:
             self._global_leading_shape = self._global_shape[:-1]
             self._global_nstreams = np.prod(self._global_leading_shape)
-        # For reference, record the type of the original data.
-        if self._stream_offsets is not None:
-            if self._stream_gains is not None:
-                # This is floating point data
-                if self._stream_gains.dtype == np.dtype(np.float64):
-                    self._typestr = "float64"
-                else:
-                    self._typestr = "float32"
-            else:
-                # This is int64 data
-                self._typestr = "int64"
+        # For reference, record the type string of the original data.
+        self._typestr = self._dtype_str(self._dtype)
+
+    @staticmethod
+    def _dtype_str(dt):
+        if dt == np.dtype(np.float64):
+            return "float64"
+        elif dt == np.dtype(np.float32):
+            return "float32"
+        elif dt == np.dtype(np.int64):
+            return "int64"
+        elif dt == np.dtype(np.int32):
+            return "int32"
         else:
-            self._typestr = "int32"
+            msg = f"Unsupported dtype '{dt}'"
+            raise RuntimeError(msg)
+        return None

     # Shapes of decompressed array

@@ -233,63 +240,109 @@ def mpi_dist(self):
         """The range of the leading dimension assigned to each MPI process."""
         return self._mpi_dist

+    @property
+    def dtype(self):
+        """The dtype of the uncompressed array."""
+        return self._dtype
+
     def _keep_view(self, key):
         if len(key) != len(self._leading_shape):
             raise ValueError("view size does not match leading dimensions")
         view = np.zeros(self._leading_shape, dtype=bool)
         view[key] = True
         return view

-    def __getitem__(self, key):
+    def _slice_nelem(self, slc, dim):
+        start, stop, step = slc.indices(dim)
+        nslc = (stop - start) // step
+        if nslc < 0:
+            nslc = 0
+        return nslc
+
+    def __getitem__(self, raw_key):
         """Decompress a slice of data on the fly."""
         first = None
         last = None
         keep = None
-        if isinstance(key, tuple):
-            # We are slicing on multiple dimensions
-            if len(key) == len(self._shape):
-                # Slicing on the sample dimension too
-                keep = self._keep_view(key[:-1])
-                samp_key = key[-1]
-                if isinstance(samp_key, slice):
+        ndim = len(self._shape)
+        output_shape = list()
+        sample_shape = (self._shape[-1],)
+        if isinstance(raw_key, tuple):
+            key = raw_key
+        else:
+            key = (raw_key,)
+        keep_slice = list()
+        for axis, axkey in enumerate(key):
+            if axis < ndim - 1:
+                # One of the leading dimensions
+                keep_slice.append(axkey)
+                if not isinstance(axkey, (int, np.integer)):
+                    # Some kind of slice, do not compress this dimension. Compute
+                    # the number of elements in the output shape.
+                    nslc = self._slice_nelem(axkey, self._shape[axis])
+                    output_shape.append(nslc)
+            else:
+                # This is the sample axis. Special handling to ensure that the
+                # selected samples are contiguous.
+                if isinstance(axkey, slice):
                     # A slice
-                    if samp_key.step is not None and samp_key.step != 1:
-                        raise ValueError("Only stride==1 supported on stream slices")
-                    first = samp_key.start
-                    last = samp_key.stop
-                elif isinstance(samp_key, (int, np.integer)):
+                    if (axkey.step is not None and axkey.step != 1):
+                        msg = "Only stride==1 supported on stream slices"
+                        raise ValueError(msg)
+                    if (
+                        axkey.start is not None
+                        and axkey.stop is not None
+                        and axkey.stop < axkey.start
+                    ):
+                        msg = "Only increasing slices supported on streams"
+                        raise ValueError(msg)
+                    first = axkey.start
+                    last = axkey.stop
+                    if first is None or first < 0:
+                        first = 0
+                    if first > self._shape[-1] - 1:
+                        first = self._shape[-1] - 1
+                    if last is None or last > self._shape[-1]:
+                        last = self._shape[-1]
+                    if last < 1:
+                        last = 1
+                    sample_shape = (last - first,)
+                elif isinstance(axkey, (int, np.integer)):
                     # Just a scalar
-                    first = samp_key
-                    last = samp_key + 1
+                    first = axkey
+                    last = axkey + 1
+                    sample_shape = ()
                 else:
                     raise ValueError(
                         "Only contiguous slices supported on the stream dimension"
                     )
-            else:
-                # Only slicing the leading dimensions
-                vw = list(key)
-                vw.extend(
-                    [slice(None) for x in range(len(self._leading_shape) - len(key))]
-                )
-                keep = self._keep_view(tuple(vw))
-        else:
-            # We are slicing / indexing only the leading dimension
-            vw = [slice(None) for x in range(len(self._leading_shape))]
-            vw[0] = key
-            keep = self._keep_view(tuple(vw))
-
-        arr, _ = array_decompress_slice(
-            self._compressed,
-            self._stream_size,
-            self._stream_starts,
-            self._stream_nbytes,
-            stream_offsets=self._stream_offsets,
-            stream_gains=self._stream_gains,
-            keep=keep,
-            first_stream_sample=first,
-            last_stream_sample=last,
+        keep_slice.extend(
+            [slice(None) for x in range(len(self._leading_shape) - len(key))]
         )
-        return arr
+        output_shape.extend(
+            [x for x in self._leading_shape[len(key):]]
        )
+        keep = self._keep_view(tuple(keep_slice))
+        output_shape = tuple(output_shape)
+        full_shape = output_shape + sample_shape
+        n_total = np.prod(full_shape)
+        if n_total == 0:
+            # At least one dimension was zero, return empty array
+            return np.zeros(full_shape, dtype=self._dtype)
+        else:
+            # We have some samples
+            arr, strm_indices = array_decompress_slice(
+                self._compressed,
+                self._stream_size,
+                self._stream_starts,
+                self._stream_nbytes,
+                stream_offsets=self._stream_offsets,
+                stream_gains=self._stream_gains,
+                keep=keep,
+                first_stream_sample=first,
+                last_stream_sample=last,
+            )
+            return arr.reshape(full_shape)

     def __delitem__(self, key):
         raise RuntimeError("Cannot delete individual streams")

src/flacarray/tests/array.py
Lines changed: 55 additions & 5 deletions

@@ -35,7 +35,9 @@ def test_helpers(self):
             .astype(np.int32)
         )

-        comp_i32, starts_i32, nbytes_i32, off_i32, gain_i32 = array_compress(data_i32, level=5)
+        comp_i32, starts_i32, nbytes_i32, off_i32, gain_i32 = array_compress(
+            data_i32, level=5
+        )
         self.assertTrue(off_i32 is None)
         self.assertTrue(gain_i32 is None)

@@ -59,7 +61,9 @@ def test_helpers(self):
             low=-(2**30), high=2**29, size=flatsize, dtype=np.int64
         ).reshape(data_shape)

-        comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress(data_i64, level=5)
+        comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress(
+            data_i64, level=5
+        )
         self.assertTrue(gain_i64 is None)

         check_i64 = array_decompress(
@@ -85,7 +89,9 @@ def test_helpers(self):
         ).reshape(data_shape)

         try:
-            comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress(data_i64, level=5)
+            comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress(
+                data_i64, level=5
+            )
             print("Failed to catch truncation of int64 data")
             self.assertTrue(False)
         except RuntimeError:
@@ -94,7 +100,9 @@ def test_helpers(self):
         # float32 data

         data_f32 = create_fake_data(data_shape, 1.0).astype(np.float32)
-        comp_f32, starts_f32, nbytes_f32, off_f32, gain_f32 = array_compress(data_f32, level=5)
+        comp_f32, starts_f32, nbytes_f32, off_f32, gain_f32 = array_compress(
+            data_f32, level=5
+        )
         check_f32 = array_decompress(
             comp_f32,
             data_shape[-1],
@@ -123,7 +131,9 @@ def test_helpers(self):

         data_f64 = create_fake_data(data_shape, 1.0)

-        comp_f64, starts_f64, nbytes_f64, off_f64, gain_f64 = array_compress(data_f64, level=5)
+        comp_f64, starts_f64, nbytes_f64, off_f64, gain_f64 = array_compress(
+            data_f64, level=5
+        )
         check_f64 = array_decompress(
             comp_f64,
             data_shape[-1],
@@ -163,3 +173,43 @@ def test_array_memory(self):
         self.assertTrue(
             np.allclose(check_slc_f64, data_f64[:, :, first:last], rtol=1e-5, atol=1e-5)
         )
+
+    def test_slicing_shape(self):
+        data_shape = (4, 3, 10, 100)
+        flatsize = np.prod(data_shape)
+        rng = np.random.default_rng()
+        data_i32 = (
+            rng.integers(low=-(2**27), high=2**30, size=flatsize, dtype=np.int32)
+            .reshape(data_shape)
+            .astype(np.int32)
+        )
+
+        farray = FlacArray.from_array(data_i32)
+
+        # Try some slices and verify expected result shape.
+        for dslc in [
+            (slice(0)),
+            (slice(1, 3)),
+            (slice(3, 1)),
+            (slice(3, 1, -1)),
+            (1, 2, 5, 50),
+            (1, 2, 5),
+            (2, slice(0, 1, 1), slice(0, 1, 1), slice(None)),
+            (1, slice(1, 3, 1), slice(6, 8, 1), 50),
+            (slice(1, 3, 1), 2, slice(6, 8, 1), slice(60, 80, 1)),
+            (2, 1, slice(2, 8, 2), slice(80, 120, 1)),
+            (2, 1, slice(2, 8, 2), slice(80, None)),
+            (2, 1, slice(2, 8, 2), slice(None, 10)),
+        ]:
+            # Slice of the original numpy array
+            check = data_i32[dslc]
+            # Slice of the FlacArray
+            fcheck = farray[dslc]
+
+            # Compare the shapes
+            if fcheck.shape != check.shape:
+                print(
+                    f"Array[{dslc}] shape: {fcheck.shape} != {check.shape}",
+                    flush=True,
+                )
+                raise RuntimeError("Failed slice shape check")
