Skip to content

Commit 4644db3

Browse files
committed
Remove commented out code in example run script
1 parent 8280d60 commit 4644db3

File tree

2 files changed

+216
-30
lines changed

2 files changed

+216
-30
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# %%
2+
"""
3+
This script compares regridding results between two methods:
4+
1. `xesmf` (xarray-based regridding library).
5+
2. `regrid2` (cdms2-based regridding library).
6+
7+
Key Steps:
8+
1. Load datasets.
9+
2. Perform regridding using `xesmf` with unsorted and sorted latitude bounds.
10+
3. Perform regridding using `regrid2`.
11+
4. Compare statistical differences in results.
12+
13+
Findings:
14+
- Regridding results differ between `xesmf` and `regrid2` due to algorithmic differences.
15+
- `xesmf` depends on having coordinates and coordinate bounds aligned.
16+
- Statistical differences (e.g., min, max, mean, sum, std) highlight sensitivity to grid preparation and implementation.
17+
18+
conda create -n xcdat_cdat latest python xcdat=0.8.0 cdms2=3.1.5 ipykernel
19+
conda activate xcdat_cdat
20+
"""
21+
22+
# %%
23+
import cdms2
24+
import numpy as np
25+
import pandas as pd
26+
from regrid2 import Regridder
27+
from regrid2.horizontal import extractBounds
28+
29+
30+
def print_stats(*arrays, labels=None):
31+
"""Prints statistical comparison of multiple arrays using a pandas DataFrame."""
32+
if labels is None:
33+
labels = [f"Array {i + 1}" for i in range(len(arrays))]
34+
elif len(labels) != len(arrays):
35+
raise ValueError("Number of labels must match the number of arrays.")
36+
37+
stats = {
38+
"Min": [np.min(arr) for arr in arrays],
39+
"Max": [np.max(arr) for arr in arrays],
40+
"Mean": [np.mean(arr) for arr in arrays],
41+
"Sum": [np.sum(arr) for arr in arrays],
42+
"Std": [np.std(arr) for arr in arrays],
43+
}
44+
45+
# Create a DataFrame from the stats dictionary
46+
df = pd.DataFrame(stats, index=labels)
47+
48+
# Print the DataFrame
49+
print("\nStatistical Comparison:")
50+
print(df)
51+
52+
#%%
53+
def make_lat_descending(var):
54+
lat = var.getLatitude()
55+
lat_index = next(i for i, ax in enumerate(var.getAxisList()) if ax.id == lat.id)
56+
57+
# Reverse latitude values
58+
lat_vals = lat[:][::-1]
59+
lat_reversed = cdms2.createAxis(lat_vals)
60+
lat_reversed.id = lat.id
61+
lat_reversed.units = lat.units
62+
lat_reversed.designateLatitude()
63+
64+
# Reverse data along latitude axis
65+
slicer = [slice(None)] * var.ndim
66+
slicer[lat_index] = slice(None, None, -1)
67+
data_reversed = var[tuple(slicer)]
68+
69+
# Replace the latitude axis in the axis list
70+
new_axes = list(var.getAxisList())
71+
new_axes[lat_index] = lat_reversed
72+
73+
# Create new variable with updated latitude axis
74+
var_reversed = cdms2.createVariable(data_reversed, axes=new_axes, id=var.id)
75+
76+
return var_reversed
77+
78+
def drop_bounds(var, axis_ids=("latitude",)):
79+
"""
80+
Returns a copy of `var` with bounds removed from specified axes.
81+
"""
82+
axes = []
83+
for ax in var.getAxisList():
84+
ax_copy = cdms2.createAxis(ax[:])
85+
ax_copy.id = ax.id
86+
ax_copy.units = getattr(ax, "units", "")
87+
if ax.id.lower() in axis_ids or ax.isLatitude() or ax.isLongitude():
88+
ax_copy.setBounds(None)
89+
axes.append(ax_copy)
90+
91+
new_var = cdms2.createVariable(var[:], axes=axes, id=var.id)
92+
return new_var
93+
94+
# %%
95+
# 1. CDAT + Regrid2 (ascending latitude, descending latitude bounds) -- -- default values, automatically sorted
96+
# --------------------------------------------------------------------
97+
# Convert xarray datasets to cdms2 variables
98+
with (
99+
cdms2.open(
100+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
101+
) as f_a,
102+
cdms2.open(
103+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
104+
) as f_b,
105+
):
106+
var_a1 = f_a("PRECT")
107+
var_b1 = f_b("PRECT")
108+
109+
# Create regridder using regrid2
110+
misaligned1 = Regridder(var_b1.getGrid(), var_a1.getGrid())(var_b1)
111+
112+
#%%
113+
# 2. CDAT + Regrid2 (descending latitude, ascending latitude bounds)
114+
# --------------------------------------------------------------------
115+
# Convert xarray datasets to cdms2 variables
116+
with (
117+
cdms2.open(
118+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
119+
) as f_a,
120+
cdms2.open(
121+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
122+
) as f_b,
123+
):
124+
var_a2 = f_a("PRECT")
125+
var_b2 = f_b("PRECT")
126+
127+
var_a2 = make_lat_descending(var_a2)
128+
var_b2 = make_lat_descending(var_b2)
129+
130+
131+
# Create regridder using regrid2
132+
aligned = Regridder(var_b2.getGrid(), var_a2.getGrid())(var_b2)
133+
134+
135+
# %%
136+
# 3. CDAT + Regrid2 (ascending latitude, no latitude bounds)
137+
# --------------------------------------------------------------------
138+
# Convert xarray datasets to cdms2 variables
139+
with (
140+
cdms2.open(
141+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
142+
) as f_a,
143+
cdms2.open(
144+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
145+
) as f_b,
146+
):
147+
var_a3 = f_a("PRECT")
148+
var_b3 = f_b("PRECT")
149+
150+
var_a3 = drop_bounds(var_a3)
151+
var_b3 = drop_bounds(var_a3)
152+
153+
154+
# Create regridder using regrid2
155+
no_bnds1 = Regridder(var_b3.getGrid(), var_a3.getGrid())(var_b3)
156+
157+
158+
# %%
159+
# 4. CDAT + Regrid2 (ascending latitude, no latitude bounds)
160+
# --------------------------------------------------------------------
161+
# Convert xarray datasets to cdms2 variables
162+
with (
163+
cdms2.open(
164+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/test.nc"
165+
) as f_a,
166+
cdms2.open(
167+
"/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/ref.nc"
168+
) as f_b,
169+
):
170+
var_a4 = f_a("PRECT")
171+
var_b4 = f_b("PRECT")
172+
173+
var_a4 = make_lat_descending(var_a4)
174+
var_b4 = make_lat_descending(var_a4)
175+
176+
var_a4 = drop_bounds(var_a4)
177+
var_b4 = drop_bounds(var_a4)
178+
179+
180+
# Create regridder using regrid2
181+
no_bnds2 = Regridder(var_b4.getGrid(), var_a4.getGrid())(var_b4)
182+
183+
184+
# %%
185+
# Compare statistics
186+
# ----------------------------------------------------
187+
print_stats(
188+
misaligned1,
189+
aligned,
190+
no_bnds1,
191+
no_bnds2,
192+
labels=[
193+
"asc lat, desc lat_bnds",
194+
"desc lat, desc lat_bnds",
195+
"asc lat, no lat_bnds",
196+
"desc lat, no lat_bnds",
197+
],
198+
)
Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
11
import os
22
from e3sm_diags.parameter.core_parameter import CoreParameter
33
from e3sm_diags.parameter.diurnal_cycle_parameter import DiurnalCycleParameter
4-
#from e3sm_diags.parameter.enso_diags_parameter import EnsoDiagsParameter
5-
#from e3sm_diags.parameter.qbo_parameter import QboParameter
6-
#from e3sm_diags.parameter.streamflow_parameter import StreamflowParameter
7-
#from e3sm_diags.parameter.tc_analysis_parameter import TCAnalysisParameter
8-
from e3sm_diags.parameter.tropical_subseasonal_parameter import TropicalSubseasonalParameter
4+
from e3sm_diags.parameter.tropical_subseasonal_parameter import (
5+
TropicalSubseasonalParameter,
6+
)
97
from e3sm_diags.run import runner
108

119
param = CoreParameter()
1210

13-
param.reference_data_path = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology'
11+
param.reference_data_path = (
12+
"/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology"
13+
)
1414

15-
test_base_path = '/pscratch/sd/c/chengzhu/ne256pg2_ne256pg2.F20TR-SCREAMv1.rainfrac1.spanc1000.auto2700.acc150.n0128'
16-
ref_climo = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology/'
17-
ref_ts = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series'
18-
param.test_data_path = f'{test_base_path}/rgr/climo'
19-
param.test_name = '1ma_ne30pg2.AVERAGE.nmonths_x1'
20-
param.seasons = ["ANN","DJF","MAM","JJA","SON"]
21-
#param.save_netcdf = True
15+
test_base_path = "/pscratch/sd/c/chengzhu/ne256pg2_ne256pg2.F20TR-SCREAMv1.rainfrac1.spanc1000.auto2700.acc150.n0128"
16+
ref_climo = "/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology/"
17+
ref_ts = "/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series"
18+
param.test_data_path = f"{test_base_path}/rgr/climo"
19+
param.test_name = "1ma_ne30pg2.AVERAGE.nmonths_x1"
20+
param.seasons = ["ANN", "DJF", "MAM", "JJA", "SON"]
21+
# param.save_netcdf = True
2222

23-
prefix = '/global/cfs/cdirs/e3sm/www/zhang40/tests/eamxx'
24-
param.results_dir = os.path.join(prefix, 'eamxx_ne256_0520_trop')
23+
prefix = "/global/cfs/cdirs/e3sm/www/zhang40/tests/eamxx"
24+
param.results_dir = os.path.join(prefix, "eamxx_ne256_0520_trop")
2525
params = [param]
2626

2727
trop_param = TropicalSubseasonalParameter()
28-
trop_param.test_data_path = f'{test_base_path}/rgr/ts_daily'
28+
trop_param.test_data_path = f"{test_base_path}/rgr/ts_daily"
2929
trop_param.short_test_name = "Daily_3hi_ne30pg2"
3030
trop_param.test_start_yr = 1996
3131
trop_param.test_end_yr = 2001
@@ -38,7 +38,7 @@
3838
params.append(trop_param)
3939

4040
dc_param = DiurnalCycleParameter()
41-
dc_param.test_data_path = f'{test_base_path}/rgr/climo_diurnal_3hrly'
41+
dc_param.test_data_path = f"{test_base_path}/rgr/climo_diurnal_3hrly"
4242
dc_param.test_name = "3hi_ne30pg2.INSTANT.nhours_x3"
4343
dc_param.short_test_name = "3hi_ne30pg2"
4444
# Plotting diurnal cycle amplitude on different scales. Default is True
@@ -49,18 +49,6 @@
4949

5050
params.append(dc_param)
5151

52-
runner.sets_to_run = [
53-
#"lat_lon",
54-
#"zonal_mean_xy",
55-
#"zonal_mean_2d",
56-
#"zonal_mean_2d_stratosphere",
57-
#"polar",
58-
#"cosp_histogram",
59-
#"meridional_mean_2d",
60-
#"annual_cycle_zonal_mean",
61-
"tropical_subseasonal",
62-
#"diurnal_cycle",
63-
]
52+
runner.sets_to_run = ["tropical_subseasonal"]
6453

6554
runner.run_diags(params)
66-

0 commit comments

Comments
 (0)