
Commit ed0fe5e

new unit test and all pass

1 parent 1c7ea6c · commit ed0fe5e
6 files changed: +344 -83 lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion

@@ -14,5 +14,5 @@ disable=
 max-line-length=100

 [TYPECHECK]
-ignored-modules=xarray,xesmf,cmor,numpy,geocat.comp,yaml,cftime
+ignored-modules=xarray,xesmf,cmor,numpy,geocat.comp,yaml,cftime,pytest
 ignored-classes=cmor
cmip7_prep/cmor_writer.py

Lines changed: 2 additions & 1 deletion

@@ -124,6 +124,7 @@ def _define_axes(self, ds: xr.Dataset, vdef: Any) -> list[int]:
         time = ds["time"]
         t_units = time.attrs.get("units", "days since 1850-01-01")
         cal = time.attrs.get("calendar", time.encoding.get("calendar", "noleap"))
+
         tvals = _encode_time_to_num(time, t_units, cal)

         # Optional bounds: try common names or CF 'bounds' attribute
@@ -138,7 +139,7 @@ def _define_axes(self, ds: xr.Dataset, vdef: Any) -> list[int]:
                 t_bnds = _encode_time_to_num(tb, t_units, cal)
             except ValueError:
                 t_bnds = None
-
+        print(f"ds = {ds} tb = {tb} tbnds is {t_bnds}")
         axes.append(
             cmor.axis(
                 table_entry="time",
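
Note: `_encode_time_to_num` is defined elsewhere in cmor_writer.py and is not part of this diff. A minimal sketch of what such a helper typically does, assuming cftime-backed time values (the name and behavior here are illustrative, not the repository's actual implementation):

import cftime
import numpy as np
import xarray as xr

def encode_time_to_num(time: xr.DataArray, units: str, calendar: str) -> np.ndarray:
    """Hypothetical stand-in for _encode_time_to_num: datetimes -> numeric offsets."""
    # cftime.date2num accepts cftime/datetime objects for any CF calendar
    vals = np.asarray(time.values).ravel()
    nums = cftime.date2num(vals.tolist(), units=units, calendar=calendar)
    return np.asarray(nums, dtype="f8").reshape(time.shape)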

cmip7_prep/data/cmor_dataset.json

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 {
+  "calendar": "noleap",
   "outpath": "out",
   "tracking_prefix": "hdl:21.14100",
   "product": "output",

cmip7_prep/regrid.py

Lines changed: 174 additions & 42 deletions

@@ -119,6 +119,45 @@ def _get_src_shape(m: xr.Dataset) -> Tuple[int, int]:
     raise ValueError("Cannot infer source grid size from weight file.")


+def _dst_latlon_1d_from_map(mapfile: Path) -> tuple[np.ndarray, np.ndarray]:
+    """Return 1D (lat_out, lon_out) arrays for the destination grid from the map file."""
+    with _open_nc(mapfile) as m:
+        # prefer 2-D centers, then reshape → 1-D
+        lat2d = _read_array(m, "yc_b", "lat_b", "dst_grid_center_lat", "yc", "lat")
+        lon2d = _read_array(m, "xc_b", "lon_b", "dst_grid_center_lon", "xc", "lon")
+        if lat2d is not None and lon2d is not None:
+            size = int(np.asarray(lat2d).size)
+            # try dims if present, else infer 180x360 for 1° grid
+            if "dst_grid_dims" in m:
+                dims = np.asarray(m["dst_grid_dims"]).ravel().astype(int)
+                if dims.size >= 2:
+                    ny, nx = int(dims[-2]), int(dims[-1])
+                else:
+                    ny, nx = 180, 360
+            else:
+                if size == 180 * 360:
+                    ny, nx = 180, 360
+                else:
+                    ny = int(round(np.sqrt(size)))
+                    nx = size // ny
+            lat2d = np.asarray(lat2d).reshape(ny, nx)
+            lon2d = np.asarray(lon2d).reshape(ny, nx)
+            lat1d = lat2d[:, 0].astype("f8")
+            lon1d = lon2d[0, :].astype("f8")
+            return lat1d, lon1d
+
+        # fallback: 1-D already present
+        lat1d = _read_array(m, "lat", "yc")
+        lon1d = _read_array(m, "lon", "xc")
+        if lat1d is not None and lon1d is not None:
+            return np.asarray(lat1d, dtype="f8"), np.asarray(lon1d, dtype="f8")
+
+    # last resort: fabricate standard 1°
+    lat = np.linspace(-89.5, 89.5, 180, dtype="f8")
+    lon = np.arange(360, dtype="f8") + 0.5
+    return lat, lon
+
+
 def _get_dst_latlon_1d(m: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]:
     """Return 1D dest lat, lon arrays from weight file.

@@ -258,14 +297,14 @@ def _ensure_ncol_last(da: xr.DataArray) -> Tuple[xr.DataArray, Tuple[str, ...]]:
     return da.transpose(*non_spatial, "ncol"), non_spatial


-def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
-    """Normalize 2-D dims to ('lat','lon') if they came out as ('y','x')."""
-    dim_map = {}
-    if "y" in da.dims:
-        dim_map["y"] = "lat"
-    if "x" in da.dims:
-        dim_map["x"] = "lon"
-    return da.rename(dim_map) if dim_map else da
+# def _rename_xy_to_latlon(da: xr.DataArray) -> xr.DataArray:
+#     """Normalize 2-D dims to ('lat','lon') if they came out as ('y','x')."""
+#     dim_map = {}
+#     if "y" in da.dims:
+#         dim_map["y"] = "lat"
+#     if "x" in da.dims:
+#         dim_map["x"] = "lon"
+#     return da.rename(dim_map) if dim_map else da


 # -------------------------
@@ -305,8 +344,8 @@ def regrid_to_1deg(
     if varname not in ds_in:
         raise KeyError(f"{varname!r} not in dataset.")

-    da = ds_in[varname]
-    da2, non_spatial = _ensure_ncol_last(da)
+    var_da = ds_in[varname]  # always a DataArray
+    da2, non_spatial = _ensure_ncol_last(var_da)

     # cast to save memory
     if dtype is not None and str(da2.dtype) != dtype:
@@ -337,42 +376,135 @@
     if "time" in da2_2d.dims and output_time_chunk:
         kwargs["output_chunks"] = {"time": output_time_chunk}

-    out = regridder(da2_2d, **kwargs)  # -> (*non_spatial, y/x or lat/lon)
-    out = _rename_xy_to_latlon(out)
-
-    if keep_attrs:
-        out.attrs.update(da.attrs)
-
-    for c in non_spatial:
-        if c in ds_in.coords and c in out.dims:
-            out = out.assign_coords({c: ds_in[c]})
-
-    if "lat" in out.coords:
-        out["lat"].attrs.setdefault("units", "degrees_north")
-        out["lat"].attrs.setdefault("standard_name", "latitude")
-    if "lon" in out.coords:
-        out["lon"].attrs.setdefault("units", "degrees_east")
-        out["lon"].attrs.setdefault("standard_name", "longitude")
+    out = regridder(da2_2d, **kwargs)  # current call that returns (*non_spatial, ?, ?)
+
+    # --- NEW: robust lat/lon assignment based on destination grid lengths ---
+    lat1d, lon1d = _dst_latlon_1d_from_map(spec.path)
+    ny, nx = len(lat1d), len(lon1d)
+
+    # find the last two dims that came from xESMF
+    spatial_dims = [d for d in out.dims if d not in non_spatial]
+    if len(spatial_dims) < 2:
+        raise ValueError(f"Unexpected output dims {out.dims}; need two spatial dims.")
+
+    da, db = spatial_dims[-2], spatial_dims[-1]
+    na, nb = out.sizes[da], out.sizes[db]
+
+    # Decide mapping by comparing lengths to (ny, nx)
+    if na == ny and nb == nx:
+        out = out.rename({da: "lat", db: "lon"})
+    elif na == nx and nb == ny:
+        out = out.rename({da: "lon", db: "lat"})
+    else:
+        # Heuristic fallback: pick the dim whose size matches 180 as lat
+        if {na, nb} == {ny, nx}:
+            # covered above; should not reach here
+            pass
+        else:
+            # choose the one closer to 180 as lat
+            choose_lat = da if abs(na - 180) <= abs(nb - 180) else db
+            choose_lon = db if choose_lat == da else da
+            out = out.rename({choose_lat: "lat", choose_lon: "lon"})
+
+    # assign canonical 1-D coords
+    out = out.assign_coords(lat=("lat", lat1d), lon=("lon", lon1d))
+
+    try:
+        out = out.transpose(*non_spatial, "lat", "lon")
+    except ValueError:
+        # fallback if non_spatial is empty
+        out = out.transpose("lat", "lon")
+    if keep_attrs and hasattr(var_da, "attrs"):
+        out.attrs.update(var_da.attrs)

     return out


-def regrid_mask_or_area(
-    da_in: xr.DataArray,
+def regrid_to_1deg_ds(
+    ds_in: xr.Dataset,
+    varname: str,
     *,
+    time_from: xr.Dataset | None = None,
+    method: Optional[str] = None,
     conservative_map: Optional[Path] = None,
-) -> xr.DataArray:
-    """Regrid a mask or cell-area field using conservative weights."""
-    if "ncol" not in da_in.dims:
-        raise ValueError("Expected 'ncol' in dims for mask/area regridding.")
-    if "time" in da_in.dims:
-        da_in = da_in.transpose("time", "ncol", ...)
-
-    spec = MapSpec(
-        "conservative", Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP
-    )
-    regridder = _RegridderCache.get(spec.path, spec.method_label)
+    bilinear_map: Optional[Path] = None,
+    keep_attrs: bool = True,
+    dtype: str | None = "float32",
+    output_time_chunk: int | None = 12,
+) -> xr.Dataset:
+    """Regrid `varname` and return a Dataset containing the regridded variable.

-    out = regridder(da_in)
-    out = _rename_xy_to_latlon(out)
-    return out
+    Parameters mirror regrid_to_1deg, but this function:
+      - returns an xr.Dataset({varname: DataArray})
+      - if `time_from` is provided, copies 'time' and its bounds into the dataset
+    """
+    da = regrid_to_1deg(
+        ds_in,
+        varname,
+        method=method,
+        conservative_map=conservative_map,
+        bilinear_map=bilinear_map,
+        keep_attrs=keep_attrs,
+        dtype=dtype,
+        output_time_chunk=output_time_chunk,
+    )
+    ds_out = xr.Dataset({varname: da})
+    if time_from is not None:
+        ds_out = _attach_time_and_bounds(ds_out, time_from)
+    return ds_out
+
+
+# def regrid_mask_or_area(
+#     da_in: xr.DataArray,
+#     *,
+#     conservative_map: Optional[Path] = None,
+# ) -> xr.DataArray:
+#     """Regrid a mask or cell-area field using conservative weights."""
+#     if "ncol" not in da_in.dims:
+#         raise ValueError("Expected 'ncol' in dims for mask/area regridding.")
+#     if "time" in da_in.dims:
+#         da_in = da_in.transpose("time", "ncol", ...)
+#
+#     spec = MapSpec(
+#         "conservative", Path(conservative_map) if conservative_map else DEFAULT_CONS_MAP
+#     )
+#     regridder = _RegridderCache.get(spec.path, spec.method_label)
+#
+#     out = regridder(da_in)
+#     out = _rename_xy_to_latlon(out)
+#     return out
+
+# --- convenience: carry time + bounds from a source dataset into an output dataset ---
+
+
+def _attach_time_and_bounds(ds_out: xr.Dataset, time_from: xr.Dataset) -> xr.Dataset:
+    """Return ds_out with 'time' coord and existing time bounds copied from time_from.
+
+    - Looks for the bounds variable via time.attrs['bounds'], or 'time_bounds' / 'time_bnds'.
+    - Aligns bounds along time and ensures dims are (time, nbnd) before attaching.
+    - Adds bounds as a DATA VARIABLE (not a coord) so 'nbnd' can be created.
+    """
+    if "time" in time_from:
+        ds_out = ds_out.assign_coords(time=time_from["time"])
+
+    # locate bounds
+    bname = time_from["time"].attrs.get("bounds")
+    tb = None
+    if isinstance(bname, str) and bname in time_from:
+        tb = time_from[bname]
+    elif "time_bounds" in time_from:
+        bname, tb = "time_bounds", time_from["time_bounds"]
+    elif "time_bnds" in time_from:
+        bname, tb = "time_bnds", time_from["time_bnds"]
+
+    if tb is not None:
+        # align length to ds_out
+        if tb.sizes.get("time") != ds_out.sizes.get("time"):
+            tb = tb.reindex(time=ds_out["time"])
+        # ensure (time, nbnd) ordering
+        if tb.dims[0] != "time":
+            other = next(d for d in tb.dims if d != "time")
+            tb = tb.transpose("time", other)
+        ds_out[bname] = tb  # data variable (nbnd is created if needed)
+        ds_out["time"].attrs["bounds"] = bname
+
+    return ds_out

scripts/testone.py

Lines changed: 25 additions & 39 deletions

@@ -1,60 +1,46 @@
 from pathlib import Path
 import xarray as xr
 from cmip7_prep.mapping_compat import Mapping
-from cmip7_prep.regrid import regrid_to_1deg
+from cmip7_prep.pipeline import realize_regrid_prepare
 from cmip7_prep.cmor_writer import CmorSession

-mapping = Mapping.from_packaged_default()  # uses cmip7_prep/mapping/cesm_to_cmip7.yaml
-var = "tas"  # pick a mapped Amon var
-cfg = mapping.get_cfg(var)
-cesm_vars = []
-if cfg.get("source"):  # direct 1:1 mapping
-    cesm_vars.append(cfg["source"])
-if cfg.get("raw_variables"):  # identity or formula inputs
-    cesm_vars.extend(cfg["raw_variables"])
-cesm_vars = sorted(set(cesm_vars))  # unique, ordered
+# 0) Load mapping (uses packaged data/cesm_to_cmip7.yaml by default)
+mapping = Mapping.from_packaged_default()
+cmip_var = "tas"
+
+# 1) Only open the files required for this CMIP var
+cfg = mapping.get_cfg(cmip_var)
+cesm_vars = sorted(
+    {*(cfg.get("raw_variables") or []), *([cfg["source"]] if cfg.get("source") else [])}
+)

-# 2) Build the file list by variable name match
 basedir = Path(
     "/glade/derecho/scratch/cmip7/archive/timeseries/b.e30_beta06.B1850C_LTso.ne30_t232_wgx3.192.wrkflw.1_32/atm/hist_monthly/"
 )
 files = sorted({str(p) for v in cesm_vars for p in basedir.glob(f"*{v}*.nc")})
-
 if not files:
-    raise FileNotFoundError(f"No files matched for {cesm_vars} in {basedir}")
-
-ds = xr.open_mfdataset(files, combine="by_coords", use_cftime=True, parallel=True)
+    raise FileNotFoundError(f"No files for {cesm_vars}")

-ds = ds.isel(time=slice(0, 12))
-
-da = mapping.realize(ds, var)
-# 3) Choose a time chunk that is a MULTIPLE of the stored chunk size
-t_stored = None
-chunksizes = da.encoding.get("chunksizes")
-if chunksizes and "time" in da.dims:
-    t_axis = da.dims.index("time")
-    try:
-        t_stored = int(chunksizes[t_axis])
-    except Exception:
-        t_stored = None
-
-# If the file is chunked 1-month along time (common), this picks 12 months per task.
-# If stored time-chunk is 24, this picks 24 or 48, etc.—always a multiple → no warning.
-if t_stored:
-    time_chunk = t_stored * 12
-else:
-    time_chunk = 12  # fallback; may warn on some datasets
-da = da.chunk({"time": time_chunk})
+ds_native = xr.open_mfdataset(
+    files, combine="by_coords", use_cftime=True, parallel=True
+)

-da_rg = regrid_to_1deg(
-    xr.Dataset({var: da}), var, output_time_chunk=time_chunk, dtype="float32"
+# 2) One call: realize → chunk → regrid → carry time+bounds
+ds_cmor = realize_regrid_prepare(
+    mapping,
+    ds_native,
+    cmip_var,
+    time_chunk=12,
+    regrid_kwargs={"output_time_chunk": 12, "dtype": "float32"},
 )

+# 3) CMOR write
 with CmorSession(
     tables_path="/glade/work/cmip7/e3sm_to_cmip/cmip6-cmor-tables/Tables"
 ) as cm:
     vdef = type(
-        "VDef", (), {"name": var, "realm": "Amon", "units": cfg.get("units", "")}
+        "VDef", (), {"name": cmip_var, "realm": "Amon", "units": cfg.get("units", "")}
     )()
-    cm.write_variable(xr.Dataset({var: da_rg}), var, vdef, outdir=Path("out_smoke"))
+    cm.write_variable(ds_cmor, cmip_var, vdef, outdir=Path("out_smoke"))
+
 print("ok")
