Skip to content

Commit 300a7d6

Browse files
JoyMonteiroclaude
andcommitted
Add benchmark script measuring Numba speedups for all optimized components
Reports Python vs Numba timing for 8192 columns × 30 levels using ._pyfunc (py_func) on each @njit kernel. Results: 32x–124x speedups across HeldSuarez, GrayLW, Frierson tau, GSC, DryConv, Berger, SlabSurface, and Instellation components. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7700802 commit 300a7d6

1 file changed

Lines changed: 273 additions & 0 deletions

File tree

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
Benchmark numba speedup for each optimized component.
3+
4+
Timings are wall-clock seconds for N iterations over a 8192-column,
5+
30-level grid. The "Python" path calls the raw Python function via
6+
.py_func on the @njit-decorated kernel (same code, no JIT).
7+
"""
8+
import time
9+
import numpy as np
10+
import sympl
11+
import climt
12+
from climt import get_grid, get_default_state
13+
from datetime import timedelta
14+
15+
NCOL = 8192
16+
NLEV = 30
17+
TIMESTEP = timedelta(minutes=10)
18+
19+
20+
def wall(fn, iters):
21+
fn() # warm-up
22+
t0 = time.perf_counter()
23+
for _ in range(iters):
24+
fn()
25+
return time.perf_counter() - t0
26+
27+
28+
def report(name, t_py, t_nb, iters, ncol):
29+
speedup = t_py / t_nb
30+
print(f" Python : {t_py:.3f}s ({t_py/iters*1000:.2f} ms/call)")
31+
print(f" Numba : {t_nb:.3f}s ({t_nb/iters*1000:.2f} ms/call)")
32+
print(f" Speedup: {speedup:.1f}x")
33+
34+
35+
# ---------------------------------------------------------------------------
36+
def bench_held_suarez():
37+
from climt import HeldSuarez
38+
from climt._components.held_suarez import _held_suarez_kernel_np
39+
40+
ITERS = 50
41+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
42+
sympl.set_backend(sympl.DataArrayBackend())
43+
comp = HeldSuarez()
44+
state = get_default_state([comp], grid_state=grid)
45+
state['eastward_wind'].values[:] = 10.0
46+
state['air_temperature'].values[:] = 300.0
47+
comp(state) # trigger JIT compile
48+
49+
print(f"\nHeldSuarez ({NCOL} cols, {ITERS} iters)")
50+
t_nb = wall(lambda: comp(state), ITERS)
51+
t_py = wall(lambda: comp(state) if _held_suarez_kernel_np.py_func else None, ITERS)
52+
53+
# Time pure Python kernel directly
54+
t = state['air_temperature'].values.reshape(NLEV, -1)
55+
u = state['eastward_wind'].values.reshape(NLEV, -1)
56+
v = state['northward_wind'].values.reshape(NLEV, -1)
57+
p = state['air_pressure'].values.reshape(NLEV, -1)
58+
ps = state['surface_air_pressure'].values.reshape(-1)
59+
lat = state['latitude'].values.reshape(-1)
60+
params = comp._params
61+
62+
t_py = wall(lambda: _held_suarez_kernel_np.py_func(u, v, t, p, ps, lat, params), ITERS)
63+
t_nb2 = wall(lambda: _held_suarez_kernel_np(u, v, t, p, ps, lat, params), ITERS)
64+
report("HeldSuarez", t_py, t_nb2, ITERS, NCOL)
65+
66+
67+
def bench_gray_radiation():
68+
from climt import GrayLongwaveRadiation, Frierson06LongwaveOpticalDepth
69+
from climt._components.radiation import _gray_lw_kernel_np, _frierson_tau_kernel_np
70+
71+
ITERS = 50
72+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
73+
sympl.set_backend(sympl.DataArrayBackend())
74+
tau_comp = Frierson06LongwaveOpticalDepth()
75+
lw_comp = GrayLongwaveRadiation()
76+
state = get_default_state([tau_comp, lw_comp], grid_state=grid)
77+
state.update(tau_comp(state))
78+
lw_comp(state) # trigger JIT
79+
80+
from sympl import get_constant
81+
sigma = float(get_constant("stefan_boltzmann_constant", "W/m^2/K^4"))
82+
g = float(get_constant("gravitational_acceleration", "m/s^2"))
83+
cpd = float(get_constant("heat_capacity_of_dry_air_at_constant_pressure", "J/kg/K"))
84+
85+
t_flat = np.ascontiguousarray(state['air_temperature'].values.reshape(NLEV, -1))
86+
tau_flat = np.ascontiguousarray(state['longwave_optical_depth_on_interface_levels'].values.reshape(NLEV+1, -1))
87+
ts_flat = np.ascontiguousarray(state['surface_temperature'].values.reshape(-1))
88+
pint_flat = np.ascontiguousarray(state['air_pressure_on_interface_levels'].values.reshape(NLEV+1, -1))
89+
90+
print(f"\nGrayLongwaveRadiation ({NCOL} cols, {ITERS} iters)")
91+
t_py = wall(lambda: _gray_lw_kernel_np.py_func(t_flat, pint_flat, ts_flat, tau_flat, sigma), ITERS)
92+
t_nb = wall(lambda: _gray_lw_kernel_np(t_flat, pint_flat, ts_flat, tau_flat, sigma), ITERS)
93+
report("GrayLW", t_py, t_nb, ITERS, NCOL)
94+
95+
lat_flat = np.ascontiguousarray(state['latitude'].values.reshape(-1))
96+
ps_flat = np.ascontiguousarray(state['surface_air_pressure'].values.reshape(-1))
97+
tau0e, tau0p, fl = tau_comp._tau0e, tau_comp._tau0p, tau_comp._fl
98+
99+
print(f"\nFrierson06LongwaveOpticalDepth ({NCOL} cols, {ITERS} iters)")
100+
t_py = wall(lambda: _frierson_tau_kernel_np.py_func(lat_flat, pint_flat, ps_flat, tau0e, tau0p, fl), ITERS)
101+
t_nb = wall(lambda: _frierson_tau_kernel_np(lat_flat, pint_flat, ps_flat, tau0e, tau0p, fl), ITERS)
102+
report("Frierson tau", t_py, t_nb, ITERS, NCOL)
103+
104+
105+
def bench_gsc():
106+
from climt import GridScaleCondensation
107+
from climt._components.grid_scale_condensation import _gsc_kernel_np
108+
109+
ITERS = 20
110+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
111+
sympl.set_backend(sympl.DataArrayBackend())
112+
comp = GridScaleCondensation()
113+
state = get_default_state([comp], grid_state=grid)
114+
state['specific_humidity'].values[:] = 0.02
115+
comp(state, TIMESTEP) # trigger JIT
116+
117+
t_flat = np.ascontiguousarray(state['air_temperature'].values.reshape(NLEV, -1))
118+
q_flat = np.ascontiguousarray(state['specific_humidity'].values.reshape(NLEV, -1))
119+
p_flat = np.ascontiguousarray(state['air_pressure'].values.reshape(NLEV, -1))
120+
pint_flat = np.ascontiguousarray(state['air_pressure_on_interface_levels'].values.reshape(NLEV+1, -1))
121+
params = comp._params
122+
123+
print(f"\nGridScaleCondensation ({NCOL} cols, {ITERS} iters)")
124+
t_py = wall(lambda: _gsc_kernel_np.py_func(t_flat, q_flat, p_flat, pint_flat, params), ITERS)
125+
t_nb = wall(lambda: _gsc_kernel_np(t_flat, q_flat, p_flat, pint_flat, params), ITERS)
126+
report("GSC", t_py, t_nb, ITERS, NCOL)
127+
128+
129+
def bench_dry_convection():
130+
from climt import DryConvectiveAdjustment
131+
from climt._components.dry_convection.component import _dry_adj_kernel_np, DryAdjParams
132+
from sympl import get_constant
133+
134+
ITERS = 10
135+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
136+
sympl.set_backend(sympl.DataArrayBackend())
137+
comp = DryConvectiveAdjustment()
138+
state = get_default_state([comp], grid_state=grid)
139+
state['air_temperature'].values[:5] += 20.0
140+
comp(state, TIMESTEP) # trigger JIT
141+
142+
t_flat = np.ascontiguousarray(state['air_temperature'].values.reshape(NLEV, -1))
143+
q_flat = np.ascontiguousarray(state['specific_humidity'].values.reshape(NLEV, -1))
144+
p_flat = np.ascontiguousarray(state['air_pressure'].values.reshape(NLEV, -1))
145+
pint_flat = np.ascontiguousarray(state['air_pressure_on_interface_levels'].values.reshape(NLEV+1, -1))
146+
params = DryAdjParams(
147+
Cpd=float(get_constant("heat_capacity_of_dry_air_at_constant_pressure", "J/kg/degK")),
148+
Cvap=float(get_constant("heat_capacity_of_vapor_phase", "J/kg/K")),
149+
Rdair=float(get_constant("gas_constant_of_dry_air", "J/kg/degK")),
150+
Pref=float(get_constant("reference_air_pressure", "Pa")),
151+
Rv=float(get_constant("gas_constant_of_vapor_phase", "J/kg/K"))
152+
)
153+
154+
print(f"\nDryConvectiveAdjustment ({NCOL} cols, {ITERS} iters)")
155+
t_py = wall(lambda: _dry_adj_kernel_np.py_func(t_flat, q_flat, p_flat, pint_flat, params), ITERS)
156+
t_nb = wall(lambda: _dry_adj_kernel_np(t_flat, q_flat, p_flat, pint_flat, params), ITERS)
157+
report("DryConv", t_py, t_nb, ITERS, NCOL)
158+
159+
160+
def bench_berger():
161+
from climt import BergerSolarInsolation
162+
from climt._components.berger_solar_insolation import _get_solar_parameters_np
163+
from datetime import datetime
164+
165+
ITERS = 100
166+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
167+
sympl.set_backend(sympl.DataArrayBackend())
168+
comp = BergerSolarInsolation()
169+
state = get_default_state([comp], grid_state=grid)
170+
state['time'] = datetime(2000, 6, 21)
171+
comp(state) # trigger JIT
172+
173+
year = state['time'].year
174+
lambda_m0, ecc, omega, obl = comp._orbital_parameters[year]
175+
from climt._components.berger_solar_insolation import years_since_vernal_equinox
176+
yve = years_since_vernal_equinox(state['time'])
177+
frac_day = (state['time'].hour + state['time'].minute / 60.) / 24.
178+
lat_flat = np.ascontiguousarray(state['latitude'].values.reshape(-1).astype(np.float64))
179+
lon_flat = np.ascontiguousarray(state['longitude'].values.reshape(-1).astype(np.float64))
180+
solar_const = float(climt.get_default_state([comp], grid_state=grid)['solar_constant'].values.flat[0]) if 'solar_constant' in state else 1367.0
181+
182+
from sympl import get_constant
183+
solar_const = float(get_constant('stellar_irradiance', 'W/m^2'))
184+
185+
print(f"\nBergerSolarInsolation ({NCOL} cols, {ITERS} iters)")
186+
t_py = wall(lambda: _get_solar_parameters_np.py_func(lambda_m0, ecc, omega, obl, yve, frac_day, lat_flat, lon_flat, solar_const), ITERS)
187+
t_nb = wall(lambda: _get_solar_parameters_np(lambda_m0, ecc, omega, obl, yve, frac_day, lat_flat, lon_flat, solar_const), ITERS)
188+
report("Berger", t_py, t_nb, ITERS, NCOL)
189+
190+
191+
def bench_slab_surface():
192+
from climt import SlabSurface
193+
from climt._components.slab_surface import _slab_surface_kernel_np
194+
195+
ITERS = 50
196+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
197+
sympl.set_backend(sympl.DataArrayBackend())
198+
comp = SlabSurface()
199+
state = get_default_state([comp], grid_state=grid)
200+
comp(state) # trigger JIT
201+
202+
def _flat(key):
203+
return np.ascontiguousarray(state[key].values.reshape(-1))
204+
205+
def _surf(key):
206+
v = state[key].values
207+
return np.ascontiguousarray((v[0] if v.ndim > 2 else v).reshape(-1))
208+
209+
area_type_raw = state['area_type'].values.reshape(-1)
210+
area_map = {'land': 0, 'land_ice': 1, 'sea': 2, 'sea_ice': 3}
211+
area_code = np.array([area_map.get(str(a), 2) for a in area_type_raw], dtype=np.float64)
212+
213+
sw_down = _surf('downwelling_shortwave_flux_in_air')
214+
lw_down = _surf('downwelling_longwave_flux_in_air')
215+
sw_up = _surf('upwelling_shortwave_flux_in_air')
216+
lw_up = _surf('upwelling_longwave_flux_in_air')
217+
lh = _flat('surface_upward_latent_heat_flux')
218+
sh = _flat('surface_upward_sensible_heat_flux')
219+
up_soil = _flat('upward_heat_flux_at_ground_level_in_soil')
220+
hf_ice = _flat('heat_flux_into_sea_water_due_to_sea_ice')
221+
sw_dens = _flat('sea_water_density')
222+
sf_dens = _flat('surface_material_density')
223+
hc_soil = _flat('heat_capacity_of_soil')
224+
stc = _flat('surface_thermal_capacity')
225+
omt = _flat('ocean_mixed_layer_thickness')
226+
slt = _flat('soil_layer_thickness')
227+
228+
print(f"\nSlabSurface ({NCOL} cols, {ITERS} iters)")
229+
t_py = wall(lambda: _slab_surface_kernel_np.py_func(sw_down, lw_down, sw_up, lw_up, lh, sh, area_code, up_soil, hf_ice, sw_dens, sf_dens, hc_soil, stc, omt, slt), ITERS)
230+
t_nb = wall(lambda: _slab_surface_kernel_np(sw_down, lw_down, sw_up, lw_up, lh, sh, area_code, up_soil, hf_ice, sw_dens, sf_dens, hc_soil, stc, omt, slt), ITERS)
231+
report("SlabSurface", t_py, t_nb, ITERS, NCOL)
232+
233+
234+
def bench_instellation():
235+
from climt import Instellation
236+
from climt._components.instellation.component import (
237+
_instellation_kernel_np, days_from_2000, fractional_day
238+
)
239+
from datetime import datetime
240+
241+
ITERS = 100
242+
grid = get_grid(nx=NCOL, ny=1, nz=NLEV)
243+
sympl.set_backend(sympl.DataArrayBackend())
244+
comp = Instellation()
245+
state = get_default_state([comp], grid_state=grid)
246+
state['time'] = datetime(2000, 6, 21, 12)
247+
comp(state) # trigger JIT
248+
249+
lat_flat = np.ascontiguousarray(state['latitude'].values.reshape(-1).astype(np.float64))
250+
lon_flat = np.ascontiguousarray(state['longitude'].values.reshape(-1).astype(np.float64))
251+
252+
t = state['time']
253+
julian_centuries = days_from_2000(t) / 36525.0
254+
frac_day = fractional_day(t)
255+
256+
print(f"\nInstellation ({NCOL} cols, {ITERS} iters)")
257+
args = (lat_flat, lon_flat, julian_centuries, frac_day)
258+
t_py = wall(lambda: _instellation_kernel_np.py_func(*args), ITERS)
259+
t_nb = wall(lambda: _instellation_kernel_np(*args), ITERS)
260+
report("Instellation", t_py, t_nb, ITERS, NCOL)
261+
262+
263+
if __name__ == "__main__":
264+
print(f"Numba speedup benchmark — {NCOL} columns, {NLEV} levels")
265+
print("=" * 55)
266+
bench_held_suarez()
267+
bench_gray_radiation()
268+
bench_gsc()
269+
bench_dry_convection()
270+
bench_berger()
271+
bench_slab_surface()
272+
bench_instellation()
273+
print()

0 commit comments

Comments
 (0)