1- from sympl import get_constant , Stepper
2- import numpy as np
31from typing import NamedTuple
2+
3+ import numpy as np
4+ from sympl import Stepper , get_constant
5+
46from ..._core .backend import jit_compile , prange
57
6- try :
7- from numba import njit
8- HAS_NUMBA = True
9- except ImportError :
10- HAS_NUMBA = False
11- njit = lambda x , ** kwargs : x
128
139class DryAdjParams (NamedTuple ):
14- Cpd : float ; Cvap : float ; Rdair : float ; Pref : float ; Rv : float
10+ Cpd : float
11+ Cvap : float
12+ Rdair : float
13+ Pref : float
14+ Rv : float
15+
1516
1617class DryConvectiveAdjustment (Stepper ):
1718 """
@@ -51,8 +52,10 @@ class DryConvectiveAdjustment(Stepper):
5152 diagnostic_properties = {}
5253
5354 def array_call (self , state , time_step ):
54- t = state ["air_temperature" ]; q = state ["specific_humidity" ]
55- p = state ["air_pressure" ]; p_int = state ["P_int" ]
55+ t = state ["air_temperature" ]
56+ q = state ["specific_humidity" ]
57+ p = state ["air_pressure" ]
58+ p_int = state ["P_int" ]
5659
5760 orig_shape = t .shape
5861 t_flat = np .reshape (t , (t .shape [0 ], - 1 ))
@@ -61,21 +64,24 @@ def array_call(self, state, time_step):
6164 p_int_flat = np .reshape (p_int , (p_int .shape [0 ], - 1 ))
6265
6366 params = DryAdjParams (
64- Cpd = get_constant ("heat_capacity_of_dry_air_at_constant_pressure" , "J/kg/degK" ),
67+ Cpd = get_constant (
68+ "heat_capacity_of_dry_air_at_constant_pressure" , "J/kg/degK"
69+ ),
6570 Cvap = get_constant ("heat_capacity_of_vapor_phase" , "J/kg/K" ),
6671 Rdair = get_constant ("gas_constant_of_dry_air" , "J/kg/degK" ),
6772 Pref = get_constant ("reference_air_pressure" , "Pa" ),
68- Rv = get_constant ("gas_constant_of_vapor_phase" , "J/kg/K" )
73+ Rv = get_constant ("gas_constant_of_vapor_phase" , "J/kg/K" ),
6974 )
7075
7176 t_new , q_new = _dry_adj_kernel_np (t_flat , q_flat , p_flat , p_int_flat , params )
7277
7378 return {}, {
7479 "air_temperature" : np .reshape (t_new , orig_shape ),
75- "specific_humidity" : np .reshape (q_new , orig_shape )
80+ "specific_humidity" : np .reshape (q_new , orig_shape ),
7681 }
7782
78- @njit
83+
84+ @jit_compile
7985def _dry_adj_kernel_np (T , q , p , p_int , params ):
8086 nlev , ncol = T .shape
8187 T_new = T .copy ()
@@ -88,18 +94,21 @@ def _dry_adj_kernel_np(T, q, p, p_int, params):
8894 p_int_col = p_int [:, i ]
8995 pdiff = np .zeros (nlev )
9096 for m in range (nlev ):
91- pdiff [m ] = p_int_col [m ] - p_int_col [m + 1 ]
97+ pdiff [m ] = p_int_col [m ] - p_int_col [m + 1 ]
9298
9399 # TOA to Surface
94100 for k in range (nlev - 1 , - 1 , - 1 ):
95-
96101 rd_cp = np .zeros (nlev )
97102 theta_q = np .zeros (nlev )
98103 for m in range (nlev ):
99- rd_cp [m ] = (params .Rdair * (1.0 - q_new [m , i ]) + params .Rv * q_new [m , i ]) / \
100- (params .Cpd * (1.0 - q_new [m , i ]) + params .Cvap * q_new [m , i ])
101- theta_q [m ] = T_new [m , i ] * (params .Pref / p_col [m ]) ** rd_cp [m ] * \
102- (1.0 + q_new [m , i ] * eps - q_new [m , i ])
104+ rd_cp [m ] = (
105+ params .Rdair * (1.0 - q_new [m , i ]) + params .Rv * q_new [m , i ]
106+ ) / (params .Cpd * (1.0 - q_new [m , i ]) + params .Cvap * q_new [m , i ])
107+ theta_q [m ] = (
108+ T_new [m , i ]
109+ * (params .Pref / p_col [m ]) ** rd_cp [m ]
110+ * (1.0 + q_new [m , i ] * eps - q_new [m , i ])
111+ )
103112
104113 current_theta_sum = 0.0
105114 max_unstable_idx = - 1
0 commit comments