diff --git a/docs/docs/install.md b/docs/docs/install.md index 6d02f68..24a6dca 100644 --- a/docs/docs/install.md +++ b/docs/docs/install.md @@ -41,7 +41,7 @@ create an environment called "flacarray". First create the env with all dependencies and activate it (FIXME, add a requirements file for dev): conda create -n flacarray \ - c_compiler numpy libflac cython meson-python pkgconfig + c-compiler numpy libflac cython meson-python pkgconfig conda activate flacarray diff --git a/src/flacarray/array.py b/src/flacarray/array.py index 01bf672..c5a1d7f 100644 --- a/src/flacarray/array.py +++ b/src/flacarray/array.py @@ -240,45 +240,85 @@ def _keep_view(self, key): view[key] = True return view + def _slice_nelem(self, slc, dim): + islc = slc.indices(dim) + nslc = (islc[1] - islc[0]) // islc[2] + if islc[1] < islc[0]: + nslc += 1 + return nslc + def __getitem__(self, key): """Decompress a slice of data on the fly.""" first = None last = None keep = None + ndim = len(self._shape) + output_shape = list() + sample_shape = (self._shape[-1],) if isinstance(key, tuple): # We are slicing on multiple dimensions - if len(key) == len(self._shape): - # Slicing on the sample dimension too - keep = self._keep_view(key[:-1]) - samp_key = key[-1] - if isinstance(samp_key, slice): - # A slice - if samp_key.step is not None and samp_key.step != 1: - raise ValueError("Only stride==1 supported on stream slices") - first = samp_key.start - last = samp_key.stop - elif isinstance(samp_key, (int, np.integer)): - # Just a scalar - first = samp_key - last = samp_key + 1 + keep_slice = list() + for axis, axkey in enumerate(key): + if axis < ndim - 1: + # One of the leading dimensions + keep_slice.append(axkey) + if not isinstance(axkey, (int, np.integer)): + # Some kind of slice, do not compress this dimension. Compute + # the number of elements in the output shape. + nslc = self._slice_nelem(axkey, self._shape[axis]) + output_shape.append(nslc) else: - raise ValueError( - "Only contiguous slices supported on the stream dimension" - ) - else: - # Only slicing the leading dimensions - vw = list(key) - vw.extend( - [slice(None) for x in range(len(self._leading_shape) - len(key))] - ) - keep = self._keep_view(tuple(vw)) + # This is the sample axis. Special handling to ensure that the + # selected samples are contiguous. + if isinstance(axkey, slice): + # A slice + if (axkey.step is not None and axkey.step != 1): + msg = "Only stride==1 supported on stream slices" + raise ValueError(msg) + if ( + axkey.start is not None + and axkey.stop is not None + and axkey.stop < axkey.start + ): + msg = "Only increasing slices supported on streams" + raise ValueError(msg) + first = axkey.start + last = axkey.stop + if first is None or first < 0: + first = 0 + if first > self._shape[-1] - 1: + first = self._shape[-1] - 1 + if last is None or last > self._shape[-1]: + last = self._shape[-1] + if last < 1: + last = 1 + sample_shape = (last - first,) + elif isinstance(axkey, (int, np.integer)): + # Just a scalar + first = axkey + last = axkey + 1 + sample_shape = () + else: + raise ValueError( + "Only contiguous slices supported on the stream dimension" + ) + keep_slice.extend( + [slice(None) for x in range(len(self._leading_shape) - len(key))] + ) else: # We are slicing / indexing only the leading dimension - vw = [slice(None) for x in range(len(self._leading_shape))] - vw[0] = key - keep = self._keep_view(tuple(vw)) - - arr, _ = array_decompress_slice( + keep_slice = [slice(None) for x in range(len(self._leading_shape))] + keep_slice[0] = key + if not isinstance(key, (int, np.integer)): + # Some kind of slice, do not compress this dimension. Compute + # the number of elements in the output shape. + nslc = self._slice_nelem(key, self._shape[0]) + output_shape.append(nslc) + + keep = self._keep_view(tuple(keep_slice)) + output_shape = tuple(output_shape) + + arr, strm_indices = array_decompress_slice( self._compressed, self._stream_size, self._stream_starts, @@ -289,7 +329,8 @@ def __getitem__(self, key): first_stream_sample=first, last_stream_sample=last, ) - return arr + full_shape = output_shape + sample_shape + return arr.reshape(full_shape) def __delitem__(self, key): raise RuntimeError("Cannot delete individual streams") diff --git a/src/flacarray/tests/array.py b/src/flacarray/tests/array.py index c9ba363..a8999d3 100644 --- a/src/flacarray/tests/array.py +++ b/src/flacarray/tests/array.py @@ -35,7 +35,9 @@ def test_helpers(self): .astype(np.int32) ) - comp_i32, starts_i32, nbytes_i32, off_i32, gain_i32 = array_compress(data_i32, level=5) + comp_i32, starts_i32, nbytes_i32, off_i32, gain_i32 = array_compress( + data_i32, level=5 + ) self.assertTrue(off_i32 is None) self.assertTrue(gain_i32 is None) @@ -59,7 +61,9 @@ def test_helpers(self): low=-(2**30), high=2**29, size=flatsize, dtype=np.int64 ).reshape(data_shape) - comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress(data_i64, level=5) + comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress( + data_i64, level=5 + ) self.assertTrue(gain_i64 is None) check_i64 = array_decompress( @@ -85,7 +89,9 @@ def test_helpers(self): ).reshape(data_shape) try: - comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress(data_i64, level=5) + comp_i64, starts_i64, nbytes_i64, off_i64, gain_i64 = array_compress( + data_i64, level=5 + ) print("Failed to catch truncation of int64 data") self.assertTrue(False) except RuntimeError: @@ -94,7 +100,9 @@ def test_helpers(self): # float32 data data_f32 = create_fake_data(data_shape, 1.0).astype(np.float32) - comp_f32, starts_f32, nbytes_f32, off_f32, gain_f32 = array_compress(data_f32, level=5) + comp_f32, starts_f32, nbytes_f32, off_f32, gain_f32 = array_compress( + data_f32, level=5 + ) check_f32 = array_decompress( comp_f32, data_shape[-1], @@ -123,7 +131,9 @@ def test_helpers(self): data_f64 = create_fake_data(data_shape, 1.0) - comp_f64, starts_f64, nbytes_f64, off_f64, gain_f64 = array_compress(data_f64, level=5) + comp_f64, starts_f64, nbytes_f64, off_f64, gain_f64 = array_compress( + data_f64, level=5 + ) check_f64 = array_decompress( comp_f64, data_shape[-1], @@ -163,3 +173,39 @@ def test_array_memory(self): self.assertTrue( np.allclose(check_slc_f64, data_f64[:, :, first:last], rtol=1e-5, atol=1e-5) ) + + def test_slicing_shape(self): + data_shape = (4, 3, 10, 100) + flatsize = np.prod(data_shape) + rng = np.random.default_rng() + data_i32 = ( + rng.integers(low=-(2**27), high=2**30, size=flatsize, dtype=np.int32) + .reshape(data_shape) + .astype(np.int32) + ) + + farray = FlacArray.from_array(data_i32) + + # Try some slices and verify expected result shape. + for dslc in [ + (1, 2, 5, 50), + (1, 2, 5), + (2, slice(0, 1, 1), slice(0, 1, 1), slice(None)), + (1, slice(1, 3, 1), slice(6, 8, 1), 50), + (slice(1, 3, 1), 2, slice(6, 8, 1), slice(60, 80, 1)), + (2, 1, slice(2, 8, 2), slice(80, 120, 1)), + (2, 1, slice(2, 8, 2), slice(80, None)), + (2, 1, slice(2, 8, 2), slice(None, 10)), + ]: + # Slice of the original numpy array + check = data_i32[dslc] + # Slice of the FlacArray + fcheck = farray[dslc] + + # Compare the shapes + if fcheck.shape != check.shape: + print( + f"Array[{dslc}] shape: {fcheck.shape} != {check.shape}", + flush=True, + ) + raise RuntimeError("Failed slice shape check")