@@ -91,15 +91,33 @@ class UnytStateContainer:
9191 dims (tuple of str): The names of the dimensions.
9292 """
9393
94- def __init__ (self , data , dims ):
95- if not isinstance (data , (unyt .unyt_array , np .ndarray )):
94+ def __init__ (self , data , dims = None , attrs = None ):
95+ if not isinstance (
96+ data , (unyt .unyt_array , np .ndarray , float , int , np .floating , np .integer )
97+ ):
9698 # Ideally strict, but helpful to be flexible if easy.
9799 # For now, strict to match design.
98100 raise TypeError (
99101 f"Data must be a unyt.unyt_array or numpy.ndarray, got { type (data )} "
100102 )
103+
104+ if (
105+ attrs is not None
106+ and "units" in attrs
107+ and not isinstance (data , unyt .unyt_array )
108+ ):
109+ try :
110+ sanitized_units = attrs ["units" ].replace ("^" , "**" ).replace (" " , "*" )
111+ data = unyt .unyt_array (data , sanitized_units )
112+ except Exception :
113+ pass
114+
101115 self .data = data
102- self .dims = tuple (dims )
116+ self .dims = tuple (dims ) if dims is not None else ()
117+
118+ def rename (self , name_dict ):
119+ new_dims = [name_dict .get (d , d ) for d in self .dims ]
120+ return UnytStateContainer (self .data , new_dims )
103121
104122 def __repr__ (self ):
105123 return f"UnytStateContainer(data={ self .data } , dims={ self .dims } )"
@@ -123,6 +141,68 @@ def attrs(self):
123141 return {"units" : str (self .data .units )}
124142 return {}
125143
144+ class _LocIndexer :
145+ def __init__ (self , container ):
146+ self .container = container
147+
148+ def __getitem__ (self , key ):
149+ if isinstance (key , dict ):
150+ slices = []
151+ for dim in self .container .dims :
152+ if dim in key :
153+ slices .append (key [dim ])
154+ else :
155+ slices .append (slice (None ))
156+ return self .container [tuple (slices )]
157+ return self .container [key ]
158+
159+ def __setitem__ (self , key , value ):
160+ if isinstance (key , dict ):
161+ slices = []
162+ for dim in self .container .dims :
163+ if dim in key :
164+ slices .append (key [dim ])
165+ else :
166+ slices .append (slice (None ))
167+ self .container [tuple (slices )] = value
168+ else :
169+ self .container [key ] = value
170+
171+ @property
172+ def loc (self ):
173+ return self ._LocIndexer (self )
174+
175+ def transpose (self , * dims ):
176+ if len (dims ) == 1 and isinstance (dims [0 ], (tuple , list )):
177+ dims = dims [0 ]
178+ # map dim names to axis indices
179+ perm = [self .dims .index (dim ) for dim in dims ]
180+ return UnytStateContainer (self .data .transpose (perm ), dims )
181+
182+ def __getitem__ (self , key ):
183+ sliced_data = self .data [key ]
184+ if isinstance (sliced_data , (unyt .unyt_array , np .ndarray )):
185+ if sliced_data .ndim == len (self .dims ):
186+ return UnytStateContainer (sliced_data , self .dims )
187+ else :
188+ # Provide dummy dimensions or just truncate if we cannot infer dropped dimension easily
189+ # Truncate to match ndim for now
190+ return UnytStateContainer (sliced_data , self .dims [: sliced_data .ndim ])
191+ return sliced_data
192+
193+ def __eq__ (self , other ):
194+ if isinstance (other , UnytStateContainer ):
195+ return self .data == other .data
196+ return self .data == other
197+
198+ def __setitem__ (self , key , value ):
199+ if isinstance (value , UnytStateContainer ):
200+ self .data [key ] = value .data
201+ elif isinstance (value , unyt .unyt_array ) and hasattr (self .data , "units" ):
202+ self .data [key ] = value .to (self .data .units )
203+ else :
204+ self .data [key ] = value
205+
126206 def to_units (self , units ):
127207 if not isinstance (self .data , unyt .unyt_array ):
128208 return self
@@ -398,3 +478,6 @@ def get_shape(self, state_value):
398478 if not isinstance (state_value , UnytStateContainer ):
399479 raise TypeError (f"Expected UnytStateContainer, got { type (state_value )} " )
400480 return state_value .data .shape
481+
482+ def get_container_type (self ):
483+ return UnytStateContainer
0 commit comments