@@ -71,6 +71,7 @@ def __init__(
7171 shape = None ,
7272 global_shape = None ,
7373 compressed = None ,
74+ dtype = None ,
7475 stream_starts = None ,
7576 stream_nbytes = None ,
7677 stream_offsets = None ,
@@ -84,6 +85,7 @@ def __init__(
8485 self ._shape = copy .deepcopy (other ._shape )
8586 self ._global_shape = copy .deepcopy (other ._global_shape )
8687 self ._compressed = copy .deepcopy (other ._compressed )
88+ self ._dtype = np .dtype (other ._dtype )
8789 self ._stream_starts = copy .deepcopy (other ._stream_starts )
8890 self ._stream_nbytes = copy .deepcopy (other ._stream_nbytes )
8991 self ._stream_offsets = copy .deepcopy (other ._stream_offsets )
@@ -97,6 +99,7 @@ def __init__(
9799 self ._shape = shape
98100 self ._global_shape = global_shape
99101 self ._compressed = compressed
102+ self ._dtype = np .dtype (dtype )
100103 self ._stream_starts = stream_starts
101104 self ._stream_nbytes = stream_nbytes
102105 self ._stream_offsets = stream_offsets
@@ -120,19 +123,23 @@ def _init_params(self):
120123 else :
121124 self ._global_leading_shape = self ._global_shape [:- 1 ]
122125 self ._global_nstreams = np .prod (self ._global_leading_shape )
123- # For reference, record the type of the original data.
124- if self ._stream_offsets is not None :
125- if self ._stream_gains is not None :
126- # This is floating point data
127- if self ._stream_gains .dtype == np .dtype (np .float64 ):
128- self ._typestr = "float64"
129- else :
130- self ._typestr = "float32"
131- else :
132- # This is int64 data
133- self ._typestr = "int64"
126+ # For reference, record the type string of the original data.
127+ self ._typestr = self ._dtype_str (self ._dtype )
128+
129+ @staticmethod
130+ def _dtype_str (dt ):
131+ if dt == np .dtype (np .float64 ):
132+ return "float64"
133+ elif dt == np .dtype (np .float32 ):
134+ return "float32"
135+ elif dt == np .dtype (np .int64 ):
136+ return "int64"
137+ elif dt == np .dtype (np .int32 ):
138+ return "int32"
134139 else :
135- self ._typestr = "int32"
140+ msg = f"Unsupported dtype '{ dt } '"
141+ raise RuntimeError (msg )
142+ return None
136143
137144 # Shapes of decompressed array
138145
@@ -233,63 +240,109 @@ def mpi_dist(self):
233240 """The range of the leading dimension assigned to each MPI process."""
234241 return self ._mpi_dist
235242
243+ @property
244+ def dtype (self ):
245+ """The dtype of the uncompressed array."""
246+ return self ._dtype
247+
236248 def _keep_view (self , key ):
237249 if len (key ) != len (self ._leading_shape ):
238250 raise ValueError ("view size does not match leading dimensions" )
239251 view = np .zeros (self ._leading_shape , dtype = bool )
240252 view [key ] = True
241253 return view
242254
243- def __getitem__ (self , key ):
255+ def _slice_nelem (self , slc , dim ):
256+ start , stop , step = slc .indices (dim )
257+ nslc = (stop - start ) // step
258+ if nslc < 0 :
259+ nslc = 0
260+ return nslc
261+
262+ def __getitem__ (self , raw_key ):
244263 """Decompress a slice of data on the fly."""
245264 first = None
246265 last = None
247266 keep = None
248- if isinstance (key , tuple ):
249- # We are slicing on multiple dimensions
250- if len (key ) == len (self ._shape ):
251- # Slicing on the sample dimension too
252- keep = self ._keep_view (key [:- 1 ])
253- samp_key = key [- 1 ]
254- if isinstance (samp_key , slice ):
267+ ndim = len (self ._shape )
268+ output_shape = list ()
269+ sample_shape = (self ._shape [- 1 ],)
270+ if isinstance (raw_key , tuple ):
271+ key = raw_key
272+ else :
273+ key = (raw_key ,)
274+ keep_slice = list ()
275+ for axis , axkey in enumerate (key ):
276+ if axis < ndim - 1 :
277+ # One of the leading dimensions
278+ keep_slice .append (axkey )
279+ if not isinstance (axkey , (int , np .integer )):
280+ # Some kind of slice, do not compress this dimension. Compute
281+ # the number of elements in the output shape.
282+ nslc = self ._slice_nelem (axkey , self ._shape [axis ])
283+ output_shape .append (nslc )
284+ else :
285+ # This is the sample axis. Special handling to ensure that the
286+ # selected samples are contiguous.
287+ if isinstance (axkey , slice ):
255288 # A slice
256- if samp_key .step is not None and samp_key .step != 1 :
257- raise ValueError ("Only stride==1 supported on stream slices" )
258- first = samp_key .start
259- last = samp_key .stop
260- elif isinstance (samp_key , (int , np .integer )):
289+ if (axkey .step is not None and axkey .step != 1 ):
290+ msg = "Only stride==1 supported on stream slices"
291+ raise ValueError (msg )
292+ if (
293+ axkey .start is not None
294+ and axkey .stop is not None
295+ and axkey .stop < axkey .start
296+ ):
297+ msg = "Only increasing slices supported on streams"
298+ raise ValueError (msg )
299+ first = axkey .start
300+ last = axkey .stop
301+ if first is None or first < 0 :
302+ first = 0
303+ if first > self ._shape [- 1 ] - 1 :
304+ first = self ._shape [- 1 ] - 1
305+ if last is None or last > self ._shape [- 1 ]:
306+ last = self ._shape [- 1 ]
307+ if last < 1 :
308+ last = 1
309+ sample_shape = (last - first ,)
310+ elif isinstance (axkey , (int , np .integer )):
261311 # Just a scalar
262- first = samp_key
263- last = samp_key + 1
312+ first = axkey
313+ last = axkey + 1
314+ sample_shape = ()
264315 else :
265316 raise ValueError (
266317 "Only contiguous slices supported on the stream dimension"
267318 )
268- else :
269- # Only slicing the leading dimensions
270- vw = list (key )
271- vw .extend (
272- [slice (None ) for x in range (len (self ._leading_shape ) - len (key ))]
273- )
274- keep = self ._keep_view (tuple (vw ))
275- else :
276- # We are slicing / indexing only the leading dimension
277- vw = [slice (None ) for x in range (len (self ._leading_shape ))]
278- vw [0 ] = key
279- keep = self ._keep_view (tuple (vw ))
280-
281- arr , _ = array_decompress_slice (
282- self ._compressed ,
283- self ._stream_size ,
284- self ._stream_starts ,
285- self ._stream_nbytes ,
286- stream_offsets = self ._stream_offsets ,
287- stream_gains = self ._stream_gains ,
288- keep = keep ,
289- first_stream_sample = first ,
290- last_stream_sample = last ,
319+ keep_slice .extend (
320+ [slice (None ) for x in range (len (self ._leading_shape ) - len (key ))]
291321 )
292- return arr
322+ output_shape .extend (
323+ [x for x in self ._leading_shape [len (key ):]]
324+ )
325+ keep = self ._keep_view (tuple (keep_slice ))
326+ output_shape = tuple (output_shape )
327+ full_shape = output_shape + sample_shape
328+ n_total = np .prod (full_shape )
329+ if n_total == 0 :
330+ # At least one dimension was zero, return empty array
331+ return np .zeros (full_shape , dtype = self ._dtype )
332+ else :
333+ # We have some samples
334+ arr , strm_indices = array_decompress_slice (
335+ self ._compressed ,
336+ self ._stream_size ,
337+ self ._stream_starts ,
338+ self ._stream_nbytes ,
339+ stream_offsets = self ._stream_offsets ,
340+ stream_gains = self ._stream_gains ,
341+ keep = keep ,
342+ first_stream_sample = first ,
343+ last_stream_sample = last ,
344+ )
345+ return arr .reshape (full_shape )
293346
294347 def __delitem__ (self , key ):
295348 raise RuntimeError ("Cannot delete individual streams" )
0 commit comments