@@ -93,15 +93,33 @@ class UnytStateContainer:
9393 dims (tuple of str): The names of the dimensions.
9494 """
9595
96- def __init__ (self , data , dims ):
97- if not isinstance (data , (unyt .unyt_array , np .ndarray )):
96+ def __init__ (self , data , dims = None , attrs = None ):
97+ if not isinstance (
98+ data , (unyt .unyt_array , np .ndarray , float , int , np .floating , np .integer )
99+ ):
98100 # Ideally strict, but helpful to be flexible if easy.
99101 # For now, strict to match design.
100102 raise TypeError (
101103 f"Data must be a unyt.unyt_array or numpy.ndarray, got { type (data )} "
102104 )
105+
106+ if (
107+ attrs is not None
108+ and "units" in attrs
109+ and not isinstance (data , unyt .unyt_array )
110+ ):
111+ try :
112+ sanitized_units = attrs ["units" ].replace ("^" , "**" ).replace (" " , "*" )
113+ data = unyt .unyt_array (data , sanitized_units )
114+ except Exception :
115+ pass
116+
103117 self .data = data
104- self .dims = tuple (dims )
118+ self .dims = tuple (dims ) if dims is not None else ()
119+
120+ def rename (self , name_dict ):
121+ new_dims = [name_dict .get (d , d ) for d in self .dims ]
122+ return UnytStateContainer (self .data , new_dims )
105123
106124 def __repr__ (self ):
107125 return f"UnytStateContainer(data={ self .data } , dims={ self .dims } )"
@@ -131,6 +149,68 @@ def attrs(self):
131149 return {"units" : str (self .data .units )}
132150 return {}
133151
152+ class _LocIndexer :
153+ def __init__ (self , container ):
154+ self .container = container
155+
156+ def __getitem__ (self , key ):
157+ if isinstance (key , dict ):
158+ slices = []
159+ for dim in self .container .dims :
160+ if dim in key :
161+ slices .append (key [dim ])
162+ else :
163+ slices .append (slice (None ))
164+ return self .container [tuple (slices )]
165+ return self .container [key ]
166+
167+ def __setitem__ (self , key , value ):
168+ if isinstance (key , dict ):
169+ slices = []
170+ for dim in self .container .dims :
171+ if dim in key :
172+ slices .append (key [dim ])
173+ else :
174+ slices .append (slice (None ))
175+ self .container [tuple (slices )] = value
176+ else :
177+ self .container [key ] = value
178+
179+ @property
180+ def loc (self ):
181+ return self ._LocIndexer (self )
182+
183+ def transpose (self , * dims ):
184+ if len (dims ) == 1 and isinstance (dims [0 ], (tuple , list )):
185+ dims = dims [0 ]
186+ # map dim names to axis indices
187+ perm = [self .dims .index (dim ) for dim in dims ]
188+ return UnytStateContainer (self .data .transpose (perm ), dims )
189+
190+ def __getitem__ (self , key ):
191+ sliced_data = self .data [key ]
192+ if isinstance (sliced_data , (unyt .unyt_array , np .ndarray )):
193+ if sliced_data .ndim == len (self .dims ):
194+ return UnytStateContainer (sliced_data , self .dims )
195+ else :
196+ # Provide dummy dimensions or just truncate if we cannot infer dropped dimension easily
197+ # Truncate to match ndim for now
198+ return UnytStateContainer (sliced_data , self .dims [: sliced_data .ndim ])
199+ return sliced_data
200+
201+ def __eq__ (self , other ):
202+ if isinstance (other , UnytStateContainer ):
203+ return self .data == other .data
204+ return self .data == other
205+
206+ def __setitem__ (self , key , value ):
207+ if isinstance (value , UnytStateContainer ):
208+ self .data [key ] = value .data
209+ elif isinstance (value , unyt .unyt_array ) and hasattr (self .data , "units" ):
210+ self .data [key ] = value .to (self .data .units )
211+ else :
212+ self .data [key ] = value
213+
134214 def to_units (self , units ):
135215 if not isinstance (self .data , unyt .unyt_array ):
136216 return self
0 commit comments