Skip to content

Commit e55614e

Browse files
committed
feat: add Emanuel V3 optimizations, UnytBackend benchmarks, and test stabilization
- Implement Numba optimizations for EmanuelConvectionPythonV3. - Add comprehensive benchmarks for Numba vs. Backend overhead (benchmark_numba_x_backend.py). - Add UnytBackend support to Emanuel V3 for 3.5x speedup over DataArrayBackend. - Stabilize test suite by adding constant reset to pytest conftest.py. - Add Emanuel V3 parity tests and record benchmark results.
1 parent 300a7d6 commit e55614e

14 files changed

Lines changed: 4891 additions & 156 deletions
Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,58 @@
11
import time
22
from datetime import timedelta
3+
34
import numpy as np
45
import sympl
56
from sympl._core.backend import DataArrayBackend
7+
68
import climt
7-
from climt import EmanuelConvection, EmanuelConvectionPython, get_default_state, get_grid
9+
from climt import (
10+
EmanuelConvection,
11+
EmanuelConvectionPythonV3,
12+
get_default_state,
13+
get_grid,
14+
)
15+
816

917
def run_benchmark():
1018
print("Benchmarking Emanuel Convection: Fortran vs Pure Python...")
11-
19+
1220
sympl.set_backend(DataArrayBackend())
1321
grid = get_grid(nx=32, ny=32, nz=30)
14-
22+
1523
# Initialize components
1624
conv_fortran = EmanuelConvection()
17-
conv_python = EmanuelConvectionPython()
18-
25+
conv_python = EmanuelConvectionPythonV3()
26+
1927
# Initial state
2028
state = get_default_state([conv_fortran], grid_state=grid)
2129
# Ensure specific humidity is reasonable for convection
22-
state['specific_humidity'].values[:] = 0.01
23-
state['air_temperature'].values[:] = 290.0
24-
30+
state["specific_humidity"].values[:] = 0.04
31+
state["air_temperature"].values[:] = 290.0
32+
2533
timestep = timedelta(minutes=20)
26-
34+
2735
# Run Fortran version
2836
start = time.perf_counter()
2937
t_fort, d_fort = conv_fortran(state, timestep)
3038
dur_fort = time.perf_counter() - start
3139
print(f" Fortran Duration: {dur_fort:.4f} s")
32-
40+
3341
# Run Python version
3442
start = time.perf_counter()
3543
t_py, d_py = conv_python(state, timestep)
3644
dur_py = time.perf_counter() - start
3745
print(f" Python Duration: {dur_py:.4f} s")
38-
46+
3947
# Compare outputs
4048
print("\nVerifying Outputs...")
4149
t_vars = ["air_temperature", "specific_humidity"]
4250
for var in t_vars:
4351
diff = np.abs(t_fort[var].values - t_py[var].values)
4452
print(f" {var} max diff: {np.max(diff):.2e}")
4553

46-
print(f"\nSpeedup (Fortran/Python): {dur_fort/dur_py:.2f}x")
54+
print(f"\nSpeedup (Fortran/Python): {dur_fort / dur_py:.2f}x")
55+
4756

4857
if __name__ == "__main__":
4958
run_benchmark()
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import time
2+
from datetime import timedelta
3+
4+
import numpy as np
5+
import sympl
6+
from sympl._core.backend import DataArrayBackend
7+
8+
import climt
9+
from climt import (
10+
EmanuelConvection,
11+
EmanuelConvectionPythonV3,
12+
RRTMGLongwave,
13+
RRTMGShortwave,
14+
SimplePhysics,
15+
SlabSurface,
16+
UnytBackend,
17+
get_default_state,
18+
)
19+
20+
21+
def run_benchmark(convection_cls, backend_obj, iterations=1000):
22+
"""
23+
Runs a radiative-convective equilibrium simulation for a given backend and convection scheme.
24+
Based on examples/gmd_radiative_convective_python_emanuelv3.py
25+
"""
26+
sympl.set_backend(backend_obj)
27+
28+
timestep = timedelta(minutes=5)
29+
if isinstance(backend_obj, UnytBackend):
30+
from climt import UnytTimeDelta
31+
32+
timestep = UnytTimeDelta(timestep)
33+
34+
# Initialize components
35+
convection = convection_cls()
36+
radiation_sw = RRTMGShortwave()
37+
radiation_lw = RRTMGLongwave()
38+
slab = SlabSurface()
39+
simple_physics = SimplePhysics()
40+
41+
# Get default state
42+
state = get_default_state(
43+
[simple_physics, convection, radiation_lw, radiation_sw, slab]
44+
)
45+
46+
# Helper to set values across backends
47+
def set_values(obj, val):
48+
if hasattr(obj, "values"):
49+
obj.values[:] = val
50+
elif hasattr(obj, "data"):
51+
obj.data[:] = val
52+
else:
53+
obj[:] = val
54+
55+
# Set initial conditions as in the example
56+
set_values(state["surface_albedo_for_direct_shortwave"], 0.5)
57+
set_values(state["surface_albedo_for_direct_near_infrared"], 0.5)
58+
set_values(state["surface_albedo_for_diffuse_shortwave"], 0.5)
59+
set_values(state["zenith_angle"], np.pi / 2.5)
60+
set_values(state["surface_temperature"], 300.0)
61+
set_values(state["ocean_mixed_layer_thickness"], 5)
62+
set_values(state["area_type"], "sea")
63+
64+
time_stepper = sympl.AdamsBashforth([convection, radiation_lw, radiation_sw, slab])
65+
66+
# --- Warmup Phase ---
67+
# We run a full iteration of the loop logic to trigger any JIT compilation (Numba/RRTMG)
68+
convection.current_time_step = timestep
69+
diagnostics, state = time_stepper(state, timestep)
70+
state.update(diagnostics)
71+
diagnostics, new_state = simple_physics(state, timestep)
72+
state.update(diagnostics)
73+
state.update(new_state)
74+
state["time"] += timestep
75+
set_values(state["eastward_wind"], 3.0)
76+
77+
# --- Timed Phase ---
78+
start_time = time.perf_counter()
79+
80+
for i in range(iterations):
81+
# Update convection timestep (required by Emanuel scheme in climt)
82+
convection.current_time_step = timestep
83+
84+
# Step the main components
85+
diagnostics, state = time_stepper(state, timestep)
86+
state.update(diagnostics)
87+
88+
# Step simple physics
89+
diagnostics, new_state = simple_physics(state, timestep)
90+
state.update(diagnostics)
91+
state.update(new_state)
92+
93+
# Update time and boundary conditions
94+
state["time"] += timestep
95+
set_values(state["eastward_wind"], 3.0)
96+
97+
end_time = time.perf_counter()
98+
return end_time - start_time
99+
100+
101+
if __name__ == "__main__":
102+
iterations = 1000
103+
print(f"Benchmarking Emanuel Convection backends ({iterations} iterations)...")
104+
print("=" * 80)
105+
print(f"{'Configuration':<45} | {'Duration (s)':<15} | {'ms/step':<10}")
106+
print("-" * 80)
107+
108+
results = {}
109+
110+
configs = [
111+
("Fortran + DataArray", EmanuelConvection, DataArrayBackend()),
112+
("Fortran + Unyt", EmanuelConvection, UnytBackend()),
113+
("V3 Python + DataArray", EmanuelConvectionPythonV3, DataArrayBackend()),
114+
("V3 Python + Unyt", EmanuelConvectionPythonV3, UnytBackend()),
115+
]
116+
117+
for label, conv_cls, backend in configs:
118+
try:
119+
duration = run_benchmark(conv_cls, backend, iterations=iterations)
120+
ms_per_step = (duration / iterations) * 1000
121+
print(f"{label:<45} | {duration:<15.4f} | {ms_per_step:<10.2f}")
122+
results[label] = duration
123+
except Exception as e:
124+
print(f"{label:<45} | {'FAILED':<15} | N/A")
125+
print(f"Error: {e}")
126+
127+
print("=" * 80)
128+
if "Fortran + DataArray" in results and "Fortran + Unyt" in results:
129+
f_da = results["Fortran + DataArray"]
130+
f_unyt = results["Fortran + Unyt"]
131+
print(f"Fortran Backend Speedup (Unyt/DA): {f_da / f_unyt:.2f}x")
132+
133+
if "V3 Python + DataArray" in results and "V3 Python + Unyt" in results:
134+
v3_da = results["V3 Python + DataArray"]
135+
v3_unyt = results["V3 Python + Unyt"]
136+
print(f"V3 Python Backend Speedup (Unyt/DA): {v3_da / v3_unyt:.2f}x")
137+
138+
if "Fortran + Unyt" in results and "V3 Python + Unyt" in results:
139+
fortran_unyt = results["Fortran + Unyt"]
140+
v3_unyt = results["V3 Python + Unyt"]
141+
print(f"V3 (Unyt) vs Fortran (Unyt): {fortran_unyt / v3_unyt:.2f}x")
142+
print("=" * 80)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
Benchmark: Held-Suarez GCM with GFS dynamical core.
3+
4+
Runs 300 time steps of the Held-Suarez test case using UnytBackend,
5+
with and without Numba JIT. Each config runs in a subprocess to
6+
ensure clean JIT state.
7+
8+
Usage:
9+
python benchmarks/benchmark_held_suarez_gcm.py
10+
python benchmarks/benchmark_held_suarez_gcm.py --steps 100
11+
"""
12+
import argparse
13+
import json
14+
import os
15+
import subprocess
16+
import sys
17+
import time
18+
19+
20+
WORKER_SCRIPT = '''
21+
import time, json, os
22+
import numpy as np
23+
import sympl
24+
from datetime import timedelta
25+
26+
import climt
27+
from climt import HeldSuarez, UnytBackend, get_default_state, get_grid
28+
from gfs_dynamical_core import GFSDynamicalCore
29+
30+
STEPS = {steps}
31+
32+
sympl.set_backend(UnytBackend())
33+
34+
model_time_step = timedelta(seconds=600)
35+
36+
held_suarez = HeldSuarez()
37+
dycore = GFSDynamicalCore([held_suarez])
38+
grid = get_grid(nx=128, ny=62)
39+
my_state = get_default_state([dycore], grid_state=grid)
40+
41+
my_state['eastward_wind'].values[:] = np.random.randn(
42+
*my_state['eastward_wind'].shape
43+
)
44+
45+
# Warmup (1 step, includes JIT compilation)
46+
t_warmup_start = time.perf_counter()
47+
diag, output = dycore(my_state, model_time_step)
48+
my_state.update(output)
49+
my_state['time'] += model_time_step
50+
t_warmup = time.perf_counter() - t_warmup_start
51+
52+
# Timed run
53+
t0 = time.perf_counter()
54+
for i in range(STEPS):
55+
diag, output = dycore(my_state, model_time_step)
56+
my_state.update(output)
57+
my_state['time'] += model_time_step
58+
elapsed = time.perf_counter() - t0
59+
60+
numba_mode = "OFF" if os.environ.get("NUMBA_DISABLE_JIT") == "1" else "ON"
61+
print(json.dumps({{
62+
"elapsed": elapsed,
63+
"warmup": t_warmup,
64+
"steps": STEPS,
65+
"per_step": elapsed / STEPS,
66+
"numba": numba_mode,
67+
}}))
68+
'''
69+
70+
71+
def run_config(steps, disable_jit):
72+
env = os.environ.copy()
73+
if disable_jit:
74+
env["NUMBA_DISABLE_JIT"] = "1"
75+
else:
76+
env.pop("NUMBA_DISABLE_JIT", None)
77+
78+
script = WORKER_SCRIPT.format(steps=steps)
79+
result = subprocess.run(
80+
[sys.executable, "-c", script],
81+
capture_output=True,
82+
text=True,
83+
env=env,
84+
timeout=1800,
85+
)
86+
if result.returncode != 0:
87+
print(f"FAILED (exit {result.returncode}):", file=sys.stderr)
88+
for line in result.stderr.strip().split("\n")[-5:]:
89+
print(f" {line}", file=sys.stderr)
90+
return None
91+
92+
# GFS core may print debug lines to stdout; find the JSON line
93+
for line in reversed(result.stdout.strip().split("\n")):
94+
line = line.strip()
95+
if line.startswith("{"):
96+
try:
97+
return json.loads(line)
98+
except json.JSONDecodeError:
99+
continue
100+
print(f"No JSON in output: {result.stdout[:300]}", file=sys.stderr)
101+
return None
102+
103+
104+
def main():
105+
parser = argparse.ArgumentParser(
106+
description="Held-Suarez GCM benchmark (UnytBackend, Numba on/off)"
107+
)
108+
parser.add_argument("--steps", type=int, default=300, help="Timesteps to run")
109+
args = parser.parse_args()
110+
111+
print(f"Held-Suarez GCM benchmark: {args.steps} steps, 128x62 grid, UnytBackend")
112+
print("=" * 70)
113+
114+
for label, disable_jit in [("Numba OFF", True), ("Numba ON", False)]:
115+
print(f"\nRunning {label}...", flush=True)
116+
t_wall_start = time.perf_counter()
117+
data = run_config(args.steps, disable_jit)
118+
t_wall = time.perf_counter() - t_wall_start
119+
120+
if data:
121+
print(f" Warmup (1 step): {data['warmup']:.2f}s")
122+
print(f" Timed ({data['steps']} steps): {data['elapsed']:.2f}s")
123+
print(f" Per step: {data['per_step']:.4f}s")
124+
print(f" Wall clock total: {t_wall:.1f}s")
125+
else:
126+
print(f" FAILED")
127+
128+
print()
129+
130+
131+
if __name__ == "__main__":
132+
main()

0 commit comments

Comments
 (0)