1414from .settings import config
1515
1616
17- mxClassID = dict (
18- (
19- # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
20- ("mxUNKNOWN_CLASS" , None ),
21- ("mxCELL_CLASS" , None ),
22- ("mxSTRUCT_CLASS" , None ),
23- ("mxLOGICAL_CLASS" , np .dtype ("bool" )),
24- ("mxCHAR_CLASS" , np .dtype ("c" )),
25- ("mxVOID_CLASS" , np .dtype ("O" )),
26- ("mxDOUBLE_CLASS" , np .dtype ("float64" )),
27- ("mxSINGLE_CLASS" , np .dtype ("float32" )),
28- ("mxINT8_CLASS" , np .dtype ("int8" )),
29- ("mxUINT8_CLASS" , np .dtype ("uint8" )),
30- ("mxINT16_CLASS" , np .dtype ("int16" )),
31- ("mxUINT16_CLASS" , np .dtype ("uint16" )),
32- ("mxINT32_CLASS" , np .dtype ("int32" )),
33- ("mxUINT32_CLASS" , np .dtype ("uint32" )),
34- ("mxINT64_CLASS" , np .dtype ("int64" )),
35- ("mxUINT64_CLASS" , np .dtype ("uint64" )),
36- ("mxFUNCTION_CLASS" , None ),
37- )
38- )
39-
40- rev_class_id = {dtype : i for i , dtype in enumerate (mxClassID .values ())}
41- dtype_list = list (mxClassID .values ())
42- type_names = list (mxClassID )
17+ deserialize_lookup = {
18+ 0 : {"dtype" : None , "scalar_type" : "UNKNOWN" },
19+ 1 : {"dtype" : None , "scalar_type" : "CELL" },
20+ 2 : {"dtype" : None , "scalar_type" : "STRUCT" },
21+ 3 : {"dtype" : np .dtype ("bool" ), "scalar_type" : "LOGICAL" },
22+ 4 : {"dtype" : np .dtype ("c" ), "scalar_type" : "CHAR" },
23+ 5 : {"dtype" : np .dtype ("O" ), "scalar_type" : "VOID" },
24+ 6 : {"dtype" : np .dtype ("float64" ), "scalar_type" : "DOUBLE" },
25+ 7 : {"dtype" : np .dtype ("float32" ), "scalar_type" : "SINGLE" },
26+ 8 : {"dtype" : np .dtype ("int8" ), "scalar_type" : "INT8" },
27+ 9 : {"dtype" : np .dtype ("uint8" ), "scalar_type" : "UINT8" },
28+ 10 : {"dtype" : np .dtype ("int16" ), "scalar_type" : "INT16" },
29+ 11 : {"dtype" : np .dtype ("uint16" ), "scalar_type" : "UINT16" },
30+ 12 : {"dtype" : np .dtype ("int32" ), "scalar_type" : "INT32" },
31+ 13 : {"dtype" : np .dtype ("uint32" ), "scalar_type" : "UINT32" },
32+ 14 : {"dtype" : np .dtype ("int64" ), "scalar_type" : "INT64" },
33+ 15 : {"dtype" : np .dtype ("uint64" ), "scalar_type" : "UINT64" },
34+ 16 : {"dtype" : None , "scalar_type" : "FUNCTION" },
35+ 65_536 : {"dtype" : np .dtype ("datetime64[Y]" ), "scalar_type" : "DATETIME64[Y]" },
36+ 65_537 : {"dtype" : np .dtype ("datetime64[M]" ), "scalar_type" : "DATETIME64[M]" },
37+ 65_538 : {"dtype" : np .dtype ("datetime64[W]" ), "scalar_type" : "DATETIME64[W]" },
38+ 65_539 : {"dtype" : np .dtype ("datetime64[D]" ), "scalar_type" : "DATETIME64[D]" },
39+ 65_540 : {"dtype" : np .dtype ("datetime64[h]" ), "scalar_type" : "DATETIME64[h]" },
40+ 65_541 : {"dtype" : np .dtype ("datetime64[m]" ), "scalar_type" : "DATETIME64[m]" },
41+ 65_542 : {"dtype" : np .dtype ("datetime64[s]" ), "scalar_type" : "DATETIME64[s]" },
42+ 65_543 : {"dtype" : np .dtype ("datetime64[ms]" ), "scalar_type" : "DATETIME64[ms]" },
43+ 65_544 : {"dtype" : np .dtype ("datetime64[us]" ), "scalar_type" : "DATETIME64[us]" },
44+ 65_545 : {"dtype" : np .dtype ("datetime64[ns]" ), "scalar_type" : "DATETIME64[ns]" },
45+ 65_546 : {"dtype" : np .dtype ("datetime64[ps]" ), "scalar_type" : "DATETIME64[ps]" },
46+ 65_547 : {"dtype" : np .dtype ("datetime64[fs]" ), "scalar_type" : "DATETIME64[fs]" },
47+ 65_548 : {"dtype" : np .dtype ("datetime64[as]" ), "scalar_type" : "DATETIME64[as]" },
48+ }
49+ serialize_lookup = {
50+ v ["dtype" ]: {"type_id" : k , "scalar_type" : v ["scalar_type" ]}
51+ for k , v in deserialize_lookup .items ()
52+ if v ["dtype" ] is not None
53+ }
54+
4355
4456compression = {b"ZL123\0 " : zlib .decompress }
4557
@@ -176,7 +188,7 @@ def pack_blob(self, obj):
176188 return self .pack_float (obj )
177189 if isinstance (obj , np .ndarray ) and obj .dtype .fields :
178190 return self .pack_recarray (np .array (obj ))
179- if isinstance (obj , np .number ):
191+ if isinstance (obj , ( np .number , np . datetime64 ) ):
180192 return self .pack_array (np .array (obj ))
181193 if isinstance (obj , (bool , np .bool_ )):
182194 return self .pack_array (np .array (obj ))
@@ -211,14 +223,18 @@ def read_array(self):
211223 shape = self .read_value (count = n_dims )
212224 n_elem = np .prod (shape , dtype = int )
213225 dtype_id , is_complex = self .read_value ("uint32" , 2 )
214- dtype = dtype_list [dtype_id ]
215226
216- if type_names [dtype_id ] == "mxVOID_CLASS" :
227+ # Get dtype from type id
228+ dtype = deserialize_lookup [dtype_id ]["dtype" ]
229+
230+ # Check if name is void
231+ if deserialize_lookup [dtype_id ]["scalar_type" ] == "VOID" :
217232 data = np .array (
218233 list (self .read_blob (self .read_value ()) for _ in range (n_elem )),
219234 dtype = np .dtype ("O" ),
220235 )
221- elif type_names [dtype_id ] == "mxCHAR_CLASS" :
236+ # Check if name is char
237+ elif deserialize_lookup [dtype_id ]["scalar_type" ] == "CHAR" :
222238 # compensate for MATLAB packing of char arrays
223239 data = self .read_value (dtype , count = 2 * n_elem )
224240 data = data [::2 ].astype ("U1" )
@@ -240,6 +256,8 @@ def pack_array(self, array):
240256 """
241257 Serialize an np.ndarray into bytes. Scalars are encoded with ndim=0.
242258 """
259+ if "datetime64" in array .dtype .name :
260+ self .set_dj0 ()
243261 blob = (
244262 b"A"
245263 + np .uint64 (array .ndim ).tobytes ()
@@ -248,22 +266,26 @@ def pack_array(self, array):
248266 is_complex = np .iscomplexobj (array )
249267 if is_complex :
250268 array , imaginary = np .real (array ), np .imag (array )
251- type_id = (
252- rev_class_id [array .dtype ]
253- if array .dtype .char != "U"
254- else rev_class_id [np .dtype ("O" )]
255- )
256- if dtype_list [type_id ] is None :
257- raise DataJointError ("Type %s is ambiguous or unknown" % array .dtype )
269+ try :
270+ type_id = serialize_lookup [array .dtype ]["type_id" ]
271+ except KeyError :
272+ # U is for unicode string
273+ if array .dtype .char == "U" :
274+ type_id = serialize_lookup [np .dtype ("O" )]["type_id" ]
275+ else :
276+ raise DataJointError (f"Type { array .dtype } is ambiguous or unknown" )
258277
259278 blob += np .array ([type_id , is_complex ], dtype = np .uint32 ).tobytes ()
260- if type_names [type_id ] == "mxVOID_CLASS" : # array of dtype('O')
279+ if (
280+ array .dtype .char == "U"
281+ or serialize_lookup [array .dtype ]["scalar_type" ] == "VOID"
282+ ):
261283 blob += b"" .join (
262284 len_u64 (it ) + it
263285 for it in (self .pack_blob (e ) for e in array .flatten (order = "F" ))
264286 )
265287 self .set_dj0 () # not supported by original mym
266- elif type_names [ type_id ] == "mxCHAR_CLASS" : # array of dtype('c')
288+ elif serialize_lookup [ array . dtype ][ "scalar_type" ] == "CHAR" :
267289 blob += (
268290 array .view (np .uint8 ).astype (np .uint16 ).tobytes ()
269291 ) # convert to 16-bit chars for MATLAB
0 commit comments