Skip to content

Commit e88d870

Browse files
authored
Merge pull request #202 from CliMT/develop
Make sure tests are passing
2 parents 492159b + df4dd90 commit e88d870

9 files changed

Lines changed: 493 additions & 247 deletions

File tree

climt/_components/slab_surface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from sympl import TendencyComponent, initialize_numpy_arrays_with_properties
21
import numpy as np
2+
from sympl import TendencyComponent, initialize_numpy_arrays_with_properties
33

44

55
class SlabSurface(TendencyComponent):

climt/_core/initialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from datetime import datetime
23

34
import numpy as np
@@ -9,7 +10,6 @@
910
get_constant,
1011
set_constant,
1112
)
12-
import sys
1313

1414
if sys.version_info < (3, 9):
1515
import importlib_resources

climt/_core/unyt_backend.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,33 @@ class UnytStateContainer:
9191
dims (tuple of str): The names of the dimensions.
9292
"""
9393

94-
def __init__(self, data, dims):
95-
if not isinstance(data, (unyt.unyt_array, np.ndarray)):
94+
def __init__(self, data, dims=None, attrs=None):
95+
if not isinstance(
96+
data, (unyt.unyt_array, np.ndarray, float, int, np.floating, np.integer)
97+
):
9698
# Ideally strict, but helpful to be flexible if easy.
9799
# For now, strict to match design.
98100
raise TypeError(
99101
f"Data must be a unyt.unyt_array or numpy.ndarray, got {type(data)}"
100102
)
103+
104+
if (
105+
attrs is not None
106+
and "units" in attrs
107+
and not isinstance(data, unyt.unyt_array)
108+
):
109+
try:
110+
sanitized_units = attrs["units"].replace("^", "**").replace(" ", "*")
111+
data = unyt.unyt_array(data, sanitized_units)
112+
except Exception:
113+
pass
114+
101115
self.data = data
102-
self.dims = tuple(dims)
116+
self.dims = tuple(dims) if dims is not None else ()
117+
118+
def rename(self, name_dict):
119+
new_dims = [name_dict.get(d, d) for d in self.dims]
120+
return UnytStateContainer(self.data, new_dims)
103121

104122
def __repr__(self):
105123
return f"UnytStateContainer(data={self.data}, dims={self.dims})"
@@ -123,6 +141,68 @@ def attrs(self):
123141
return {"units": str(self.data.units)}
124142
return {}
125143

144+
class _LocIndexer:
145+
def __init__(self, container):
146+
self.container = container
147+
148+
def __getitem__(self, key):
149+
if isinstance(key, dict):
150+
slices = []
151+
for dim in self.container.dims:
152+
if dim in key:
153+
slices.append(key[dim])
154+
else:
155+
slices.append(slice(None))
156+
return self.container[tuple(slices)]
157+
return self.container[key]
158+
159+
def __setitem__(self, key, value):
160+
if isinstance(key, dict):
161+
slices = []
162+
for dim in self.container.dims:
163+
if dim in key:
164+
slices.append(key[dim])
165+
else:
166+
slices.append(slice(None))
167+
self.container[tuple(slices)] = value
168+
else:
169+
self.container[key] = value
170+
171+
@property
172+
def loc(self):
173+
return self._LocIndexer(self)
174+
175+
def transpose(self, *dims):
176+
if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
177+
dims = dims[0]
178+
# map dim names to axis indices
179+
perm = [self.dims.index(dim) for dim in dims]
180+
return UnytStateContainer(self.data.transpose(perm), dims)
181+
182+
def __getitem__(self, key):
183+
sliced_data = self.data[key]
184+
if isinstance(sliced_data, (unyt.unyt_array, np.ndarray)):
185+
if sliced_data.ndim == len(self.dims):
186+
return UnytStateContainer(sliced_data, self.dims)
187+
else:
188+
# Provide dummy dimensions or just truncate if we cannot infer dropped dimension easily
189+
# Truncate to match ndim for now
190+
return UnytStateContainer(sliced_data, self.dims[: sliced_data.ndim])
191+
return sliced_data
192+
193+
def __eq__(self, other):
194+
if isinstance(other, UnytStateContainer):
195+
return self.data == other.data
196+
return self.data == other
197+
198+
def __setitem__(self, key, value):
199+
if isinstance(value, UnytStateContainer):
200+
self.data[key] = value.data
201+
elif isinstance(value, unyt.unyt_array) and hasattr(self.data, "units"):
202+
self.data[key] = value.to(self.data.units)
203+
else:
204+
self.data[key] = value
205+
126206
def to_units(self, units):
127207
if not isinstance(self.data, unyt.unyt_array):
128208
return self
@@ -398,3 +478,6 @@ def get_shape(self, state_value):
398478
if not isinstance(state_value, UnytStateContainer):
399479
raise TypeError(f"Expected UnytStateContainer, got {type(state_value)}")
400480
return state_value.data.shape
481+
482+
def get_container_type(self):
483+
return UnytStateContainer

climt/_core/util.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from sympl import jit, DataArray
2-
import numpy as np
31
import functools
42

3+
import numpy as np
4+
from sympl import DataArray, jit
5+
56

67
def ensure_contiguous_state(func):
78
@functools.wraps(func)
@@ -173,14 +174,14 @@ def calculate_q_sat(surf_temp, surf_press, Rd, Rv):
173174
return eps * sat_vap_press / (surf_press - (1 - eps) * sat_vap_press)
174175

175176

176-
@jit(nopython=True)
177+
# @jit(nopython=True)
177178
def bolton_q_sat(T, p, Rd, Rh2O):
178179
es = 611.2 * np.exp(17.67 * (T - 273.15) / (T - 29.65))
179180
epsilon = Rd / Rh2O
180181
return epsilon * es / (p - (1 - epsilon) * es)
181182

182183

183-
@jit(nopython=True)
184+
# @jit(nopython=True)
184185
def bolton_dqsat_dT(T, Lv, Rh2O, q_sat):
185186
"""Uses the assumptions of equation 12 in Reed and Jablonowski, 2012. In
186187
particular, assumes d(qsat)/dT is approximately epsilon/p*d(es)/dT"""

rad_conv_eq_unyt.nc

-122 KB
Binary file not shown.
184 Bytes
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)