Skip to content

Commit 9af3545

Browse files
committed
Update scripts
1 parent 895e3dd commit 9af3545

File tree

2 files changed

+199
-1
lines changed

2 files changed

+199
-1
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# %%
2+
"""
3+
This script compares regridding results between two methods:
4+
1. `xesmf` (xarray-based regridding library).
5+
2. `regrid2` (cdms2-based regridding library).
6+
7+
Key Steps:
8+
1. Load datasets.
9+
2. Perform regridding using `xesmf` with unsorted and sorted latitude bounds.
10+
3. Perform regridding using `regrid2`.
11+
4. Compare statistical differences in results.
12+
13+
Findings:
14+
- Regridding results differ between `xesmf` and `regrid2` due to algorithmic differences.
15+
- `xesmf` depends on having coordinates and coordinate bounds aligned.
16+
- Statistical differences (e.g., min, max, mean, sum, std) highlight sensitivity to grid preparation and implementation.
17+
18+
conda create -n xcdat_cdat latest python xcdat=0.8.0 cdms2=3.1.5 ipykernel
19+
conda activate xcdat_cdat
20+
"""
21+
22+
# %%
23+
import cdms2
24+
import numpy as np
25+
import pandas as pd
26+
from regrid2 import Regridder
27+
from regrid2.horizontal import extractBounds
28+
29+
30+
def print_stats(*arrays, labels=None):
31+
"""Prints statistical comparison of multiple arrays using a pandas DataFrame."""
32+
if labels is None:
33+
labels = [f"Array {i + 1}" for i in range(len(arrays))]
34+
elif len(labels) != len(arrays):
35+
raise ValueError("Number of labels must match the number of arrays.")
36+
37+
stats = {
38+
"Min": [np.min(arr) for arr in arrays],
39+
"Max": [np.max(arr) for arr in arrays],
40+
"Mean": [np.mean(arr) for arr in arrays],
41+
"Sum": [np.sum(arr) for arr in arrays],
42+
"Std": [np.std(arr) for arr in arrays],
43+
}
44+
45+
# Create a DataFrame from the stats dictionary
46+
df = pd.DataFrame(stats, index=labels)
47+
48+
# Print the DataFrame
49+
print("\nStatistical Comparison:")
50+
print(df)
51+
52+
#%%
53+
def make_lat_descending(var):
54+
lat = var.getLatitude()
55+
lat_index = next(i for i, ax in enumerate(var.getAxisList()) if ax.id == lat.id)
56+
57+
# Reverse latitude values
58+
lat_vals = lat[:][::-1]
59+
lat_reversed = cdms2.createAxis(lat_vals)
60+
lat_reversed.id = lat.id
61+
lat_reversed.units = lat.units
62+
lat_reversed.designateLatitude()
63+
64+
# Reverse data along latitude axis
65+
slicer = [slice(None)] * var.ndim
66+
slicer[lat_index] = slice(None, None, -1)
67+
data_reversed = var[tuple(slicer)]
68+
69+
# Replace the latitude axis in the axis list
70+
new_axes = list(var.getAxisList())
71+
new_axes[lat_index] = lat_reversed
72+
73+
# Create new variable with updated latitude axis
74+
var_reversed = cdms2.createVariable(data_reversed, axes=new_axes, id=var.id)
75+
76+
return var_reversed
77+
78+
def drop_bounds(var, axis_ids=("latitude",)):
79+
"""
80+
Returns a copy of `var` with bounds removed from specified axes.
81+
"""
82+
axes = []
83+
for ax in var.getAxisList():
84+
ax_copy = cdms2.createAxis(ax[:])
85+
ax_copy.id = ax.id
86+
ax_copy.units = getattr(ax, "units", "")
87+
if ax.id.lower() in axis_ids or ax.isLatitude() or ax.isLongitude():
88+
ax_copy.setBounds(None)
89+
axes.append(ax_copy)
90+
91+
new_var = cdms2.createVariable(var[:], axes=axes, id=var.id)
92+
return new_var
93+
94+
# %%
95+
# 1. CDAT + Regrid2 (ascending latitude, descending latitude bounds) -- -- default values, automatically sorted
96+
# --------------------------------------------------------------------
97+
# Convert xarray datasets to cdms2 variables
98+
with (
99+
cdms2.open(
100+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
101+
) as f_a,
102+
cdms2.open(
103+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
104+
) as f_b,
105+
):
106+
var_a1 = f_a("PRECT")
107+
var_b1 = f_b("PRECT")
108+
109+
# Create regridder using regrid2
110+
misaligned1 = Regridder(var_b1.getGrid(), var_a1.getGrid())(var_b1)
111+
112+
#%%
113+
# 2. CDAT + Regrid2 (descending latitude, ascending latitude bounds)
114+
# --------------------------------------------------------------------
115+
# Convert xarray datasets to cdms2 variables
116+
with (
117+
cdms2.open(
118+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
119+
) as f_a,
120+
cdms2.open(
121+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
122+
) as f_b,
123+
):
124+
var_a2 = f_a("PRECT")
125+
var_b2 = f_b("PRECT")
126+
127+
var_a2 = make_lat_descending(var_a2)
128+
var_b2 = make_lat_descending(var_b2)
129+
130+
131+
# Create regridder using regrid2
132+
aligned = Regridder(var_b2.getGrid(), var_a2.getGrid())(var_b2)
133+
134+
135+
# %%
136+
# 3. CDAT + Regrid2 (ascending latitude, no latitude bounds)
137+
# --------------------------------------------------------------------
138+
# Convert xarray datasets to cdms2 variables
139+
with (
140+
cdms2.open(
141+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
142+
) as f_a,
143+
cdms2.open(
144+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
145+
) as f_b,
146+
):
147+
var_a3 = f_a("PRECT")
148+
var_b3 = f_b("PRECT")
149+
150+
var_a3 = drop_bounds(var_a3)
151+
var_b3 = drop_bounds(var_a3)
152+
153+
154+
# Create regridder using regrid2
155+
no_bnds1 = Regridder(var_b3.getGrid(), var_a3.getGrid())(var_b3)
156+
157+
158+
# %%
159+
# 4. CDAT + Regrid2 (ascending latitude, no latitude bounds)
160+
# --------------------------------------------------------------------
161+
# Convert xarray datasets to cdms2 variables
162+
with (
163+
cdms2.open(
164+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
165+
) as f_a,
166+
cdms2.open(
167+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
168+
) as f_b,
169+
):
170+
var_a4 = f_a("PRECT")
171+
var_b4 = f_b("PRECT")
172+
173+
var_a4 = make_lat_descending(var_a4)
174+
var_b4 = make_lat_descending(var_a4)
175+
176+
var_a4 = drop_bounds(var_a4)
177+
var_b4 = drop_bounds(var_a4)
178+
179+
180+
# Create regridder using regrid2
181+
no_bnds2 = Regridder(var_b4.getGrid(), var_a4.getGrid())(var_b4)
182+
183+
184+
# %%
185+
# Compare statistics
186+
# ----------------------------------------------------
187+
print_stats(
188+
misaligned1,
189+
aligned,
190+
no_bnds1,
191+
no_bnds2,
192+
labels=[
193+
"asc lat, desc lat_bnds",
194+
"desc lat, desc lat_bnds",
195+
"asc lat, no lat_bnds",
196+
"desc lat, no lat_bnds",
197+
],
198+
)

auxiliary_tools/debug/945-xesmf-diffs/compare_xesmf_xcdat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def print_stats(*arrays, labels=None):
7474
)
7575

7676
# 6. Asc lat, no lat_bnds
77-
ds_b6 = ds_b.copy(deep=True).sortby("lat", ascending=True)
77+
ds_b6 = ds_b.copy(deep=True).sortby("lat", descending=True)
7878
ds_b6 = ds_b6.drop("lat_bnds")
7979
nobounds2 = ds_b6.regridder.horizontal(
8080
"PRECT", output_grid_xesmf, tool='xesmf', method='conservative_normed'

0 commit comments

Comments
 (0)