Skip to content

Commit ff8bb82

Browse files
fix pre-commit and tests from #262 (#294)
* fix pre-commit errors * fix tests
1 parent eba53fe commit ff8bb82

File tree

3 files changed

+122
-37
lines changed

3 files changed

+122
-37
lines changed

e3sm_to_cmip/cmor_handlers/_formulas.py

Lines changed: 114 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,18 @@ def mmrbc(ds: xr.Dataset) -> xr.DataArray:
134134
mmrbc = Mass_bc or bc_a1 + bc_a3 + bc_a4 + bc_c1 + bc_c3 + bc_c4
135135
"""
136136

137-
if all(key in ds.data_vars for key in ["bc_a1", "bc_a3", "bc_a4",
138-
"bc_c1", "bc_c3", "bc_c4"]):
139-
result = ( ds["bc_a1"] + ds["bc_a3"] + ds["bc_a4"] +
140-
ds["bc_c1"] + ds["bc_c3"] + ds["bc_c4"] )
137+
if all(
138+
key in ds.data_vars
139+
for key in ["bc_a1", "bc_a3", "bc_a4", "bc_c1", "bc_c3", "bc_c4"]
140+
):
141+
result = (
142+
ds["bc_a1"]
143+
+ ds["bc_a3"]
144+
+ ds["bc_a4"]
145+
+ ds["bc_c1"]
146+
+ ds["bc_c3"]
147+
+ ds["bc_c4"]
148+
)
141149
elif "Mass_bc" in ds:
142150
result = ds["Mass_bc"]
143151
else:
@@ -155,10 +163,8 @@ def mmrdust(ds: xr.Dataset) -> xr.DataArray:
155163
mmrdust = Mass_dst or dst_a1 + dst_a3 + dst_c1 + dst_c3
156164
"""
157165

158-
if all(key in ds.data_vars for key in ["dst_a1", "dst_a3",
159-
"dst_c1", "dst_c3"]):
160-
result = ( ds["dst_a1"] + ds["dst_a3"] +
161-
ds["dst_c1"] + ds["dst_c3"] )
166+
if all(key in ds.data_vars for key in ["dst_a1", "dst_a3", "dst_c1", "dst_c3"]):
167+
result = ds["dst_a1"] + ds["dst_a3"] + ds["dst_c1"] + ds["dst_c3"]
162168
elif "Mass_dst" in ds:
163169
result = ds["Mass_dst"]
164170
else:
@@ -177,21 +183,44 @@ def mmroa(ds: xr.Dataset) -> xr.DataArray:
177183
soa_a1 + soa_a2 + soa_a3 + soa_c1 + soa_c2 + soa_c3
178184
"""
179185

180-
if all(key in ds.data_vars for key in ["pom_a1", "pom_a3", "pom_a4",
181-
"pom_c1", "pom_c3", "pom_c4",
182-
"soa_a1", "soa_a2", "soa_a3",
183-
"soa_c1", "soa_c2", "soa_c3"]):
184-
result = ( ds["pom_a1"] + ds["pom_a3"] + ds["pom_a4"] +
185-
ds["pom_c1"] + ds["pom_c3"] + ds["pom_c4"] +
186-
ds["soa_a1"] + ds["soa_a2"] + ds["soa_a3"] +
187-
ds["soa_c1"] + ds["soa_c2"] + ds["soa_c3"] )
186+
if all(
187+
key in ds.data_vars
188+
for key in [
189+
"pom_a1",
190+
"pom_a3",
191+
"pom_a4",
192+
"pom_c1",
193+
"pom_c3",
194+
"pom_c4",
195+
"soa_a1",
196+
"soa_a2",
197+
"soa_a3",
198+
"soa_c1",
199+
"soa_c2",
200+
"soa_c3",
201+
]
202+
):
203+
result = (
204+
ds["pom_a1"]
205+
+ ds["pom_a3"]
206+
+ ds["pom_a4"]
207+
+ ds["pom_c1"]
208+
+ ds["pom_c3"]
209+
+ ds["pom_c4"]
210+
+ ds["soa_a1"]
211+
+ ds["soa_a2"]
212+
+ ds["soa_a3"]
213+
+ ds["soa_c1"]
214+
+ ds["soa_c2"]
215+
+ ds["soa_c3"]
216+
)
188217
elif all(key in ds.data_vars for key in ["Mass_pom", "Mass_soa"]):
189218
result = ds["Mass_pom"] + ds["Mass_soa"]
190219
else:
191220
raise KeyError(
192221
"No formula could be applied for 'mmroa'. Check the handler entry for 'mmroa' "
193222
"and input file(s) contain either 'pom_a1', 'pom_a3', 'pom_a4', "
194-
"'pom_c1', 'pom_c3', 'pom_c4', 'soa_a1', 'soa_a2', 'soa_a3', "
223+
"'pom_c1', 'pom_c3', 'pom_c4', 'soa_a1', 'soa_a2', 'soa_a3', "
195224
"'soa_c1', 'soa_c2', 'soa_c3', or 'Mass_pom' and 'Mass_soa'."
196225
)
197226

@@ -203,10 +232,18 @@ def mmrsoa(ds: xr.Dataset) -> xr.DataArray:
203232
mmrsoa = Mass_soa or soa_a1 + soa_a2 + soa_a3 + soa_c1 + soa_c2 + soa_c3
204233
"""
205234

206-
if all(key in ds.data_vars for key in ["soa_a1", "soa_a2", "soa_a3",
207-
"soa_c1", "soa_c2", "soa_c3"]):
208-
result = ( ds["soa_a1"] + ds["soa_a2"] + ds["soa_a3"] +
209-
ds["soa_c1"] + ds["soa_c2"] + ds["soa_c3"] )
235+
if all(
236+
key in ds.data_vars
237+
for key in ["soa_a1", "soa_a2", "soa_a3", "soa_c1", "soa_c2", "soa_c3"]
238+
):
239+
result = (
240+
ds["soa_a1"]
241+
+ ds["soa_a2"]
242+
+ ds["soa_a3"]
243+
+ ds["soa_c1"]
244+
+ ds["soa_c2"]
245+
+ ds["soa_c3"]
246+
)
210247
elif "Mass_soa" in ds:
211248
result = ds["Mass_soa"]
212249
else:
@@ -224,10 +261,18 @@ def mmrss(ds: xr.Dataset) -> xr.DataArray:
224261
mmrss = Mass_ncl or ncl_a1 + ncl_a2 + ncl_a3 + ncl_c1 + ncl_c2 + ncl_c3
225262
"""
226263

227-
if all(key in ds.data_vars for key in ["ncl_a1", "ncl_a2", "ncl_a3",
228-
"ncl_c1", "ncl_c2", "ncl_c3"]):
229-
result = ( ds["ncl_a1"] + ds["ncl_a2"] + ds["ncl_a3"] +
230-
ds["ncl_c1"] + ds["ncl_c2"] + ds["ncl_c3"] )
264+
if all(
265+
key in ds.data_vars
266+
for key in ["ncl_a1", "ncl_a2", "ncl_a3", "ncl_c1", "ncl_c2", "ncl_c3"]
267+
):
268+
result = (
269+
ds["ncl_a1"]
270+
+ ds["ncl_a2"]
271+
+ ds["ncl_a3"]
272+
+ ds["ncl_c1"]
273+
+ ds["ncl_c2"]
274+
+ ds["ncl_c3"]
275+
)
231276
elif "Mass_ncl" in ds:
232277
result = ds["Mass_ncl"]
233278
else:
@@ -245,15 +290,50 @@ def mmrso4(ds: xr.Dataset) -> xr.DataArray:
245290
mmrso4 = Mass_so4 or so4_a1+so4_c1+so4_a2+so4_c2+so4_a3+so4_c3 for MAM5
246291
mmrso4 = Mass_so4 or so4_a1+so4_c1+so4_a2+so4_c2+so4_a3+so4_c3 for MAM4
247292
"""
248-
249-
if all(key in ds.data_vars for key in ["so4_a1", "so4_a2", "so4_a3", "so4_a5",
250-
"so4_c1", "so4_c2", "so4_c3", "so4_c5"]):
251-
result = ( ds["so4_a1"] + ds["so4_a2"] + ds["so4_a3"] + ds["so4_a5"] +
252-
ds["so4_c1"] + ds["so4_c2"] + ds["so4_c3"] + ds["so4_c5"] ) * 96.0636 / 115.10734
253-
elif all(key in ds.data_vars for key in ["so4_a1", "so4_a2", "so4_a3",
254-
"so4_c1", "so4_c2", "so4_c3"]):
255-
result = ( ds["so4_a1"] + ds["so4_a2"] + ds["so4_a3"] +
256-
ds["so4_c1"] + ds["so4_c2"] + ds["so4_c3"] ) * 96.0636 / 115.10734
293+
294+
if all(
295+
key in ds.data_vars
296+
for key in [
297+
"so4_a1",
298+
"so4_a2",
299+
"so4_a3",
300+
"so4_a5",
301+
"so4_c1",
302+
"so4_c2",
303+
"so4_c3",
304+
"so4_c5",
305+
]
306+
):
307+
result = (
308+
(
309+
ds["so4_a1"]
310+
+ ds["so4_a2"]
311+
+ ds["so4_a3"]
312+
+ ds["so4_a5"]
313+
+ ds["so4_c1"]
314+
+ ds["so4_c2"]
315+
+ ds["so4_c3"]
316+
+ ds["so4_c5"]
317+
)
318+
* 96.0636
319+
/ 115.10734
320+
)
321+
elif all(
322+
key in ds.data_vars
323+
for key in ["so4_a1", "so4_a2", "so4_a3", "so4_c1", "so4_c2", "so4_c3"]
324+
):
325+
result = (
326+
(
327+
ds["so4_a1"]
328+
+ ds["so4_a2"]
329+
+ ds["so4_a3"]
330+
+ ds["so4_c1"]
331+
+ ds["so4_c2"]
332+
+ ds["so4_c3"]
333+
)
334+
* 96.0636
335+
/ 115.10734
336+
)
257337
elif "Mass_so4" in ds:
258338
result = ds["Mass_so4"] * 96.0636 / 115.10734
259339
else:

e3sm_to_cmip/mpas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def _compute_moc_time_series(
728728
lat_bnds = np.zeros((len(lat) - 1, 2))
729729
lat_bnds[:, 0] = lat[0:-1]
730730
lat_bnds[:, 1] = lat[1:]
731-
lat = 0.5 * (lat_bnds[:, 0] + lat_bnds[:, 1])
731+
lat = 0.5 * (lat_bnds[:, 0] + lat_bnds[:, 1]) # type: ignore
732732

733733
lat_bnds = xarray.DataArray(lat_bnds, dims=("lat", "nbnd")) # type: ignore
734734
lat = xarray.DataArray(lat, dims=("lat",)) # type: ignore

tests/cmor_handlers/test__formulas.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,17 @@ def test_mmrbc():
142142
ds = xr.Dataset(
143143
data_vars={
144144
"bc_a1": _dummy_dataarray(),
145+
"bc_a3": _dummy_dataarray(),
145146
"bc_a4": _dummy_dataarray(),
146147
"bc_c1": _dummy_dataarray(),
148+
"bc_c3": _dummy_dataarray(),
147149
"bc_c4": _dummy_dataarray(),
148150
}
149151
)
150152

151153
result = mmrbc(ds)
152154
expected = xr.DataArray(
153-
dims=["lat", "lon"], data=np.array([[0, 4, 8], [0, 4, 8], [0, 4, 8]])
155+
dims=["lat", "lon"], data=np.array([[0, 6, 12], [0, 6, 12], [0, 6, 12]])
154156
)
155157
xr.testing.assert_allclose(result, expected)
156158

@@ -173,7 +175,10 @@ def test_mmrso4():
173175

174176
result = mmrso4(ds)
175177
expected = xr.DataArray(
176-
dims=["lat", "lon"], data=np.array([[0, 6, 12], [0, 6, 12], [0, 6, 12]])
178+
dims=["lat", "lon"],
179+
data=np.array(
180+
[[0, 5.00734, 10.01468], [0, 5.00734, 10.01468], [0, 5.00734, 10.01468]]
181+
),
177182
)
178183
xr.testing.assert_allclose(result, expected)
179184

0 commit comments

Comments
 (0)