Skip to content

Commit 2e96ab0

Browse files
Nush395Torax team
authored andcommitted
Move usages of functools.partial(jax.jit, ...) to use decorator factory pattern.
PiperOrigin-RevId: 836233292
1 parent 8a206ab commit 2e96ab0

File tree

7 files changed

+7
-13
lines changed

7 files changed

+7
-13
lines changed

docs/jax_classes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Designing JAX-compatible classes
8787
import functools
8888
import jax
8989

90-
@functools.partial(jax.jit, static_argnums=0)
90+
@jax.jit(static_argnums=0)
9191
def f(x):
9292
return x.num_wheels
9393

torax/_src/core_profiles/updaters.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
the next timestep and returns updates to State.
2929
"""
3030
import dataclasses
31-
import functools
3231

3332
import jax
3433
from jax import numpy as jnp
@@ -135,7 +134,7 @@ def get_prescribed_core_profile_values(
135134
}
136135

137136

138-
@functools.partial(jax.jit, static_argnames=['evolving_names'])
137+
@jax.jit(static_argnames='evolving_names')
139138
def update_core_profiles_during_step(
140139
x_new: tuple[cell_variable.CellVariable, ...],
141140
runtime_params: runtime_params_lib.RuntimeParams,

torax/_src/orchestration/jit_run_loop.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414

1515
"""JITted run_loop for iterating over the simulation step function."""
16-
import functools
17-
1816
import chex
1917
import jax
2018
import jax.numpy as jnp
@@ -27,7 +25,7 @@
2725
from torax._src.output_tools import post_processing
2826

2927

30-
@functools.partial(jax.jit, static_argnames=['max_steps'])
28+
@jax.jit(static_argnames='max_steps')
3129
def run_loop_jit(
3230
step_fn: step_function.SimulationStepFn,
3331
max_steps: int,

torax/_src/solver/tests/jax_root_finding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_root_newton_raphson_basic(self):
5252
x_init = np.array([0.0, 0.0], dtype=dtype)
5353
sol_np = optimize.root(f_closed, [0, 0], tol=tol)
5454

55-
@functools.partial(jax.jit, static_argnames=['tol', 'maxiter'])
55+
@jax.jit(static_argnames=['tol', 'maxiter'])
5656
def root_jax(x, tol, maxiter):
5757
return jax_root_finding.root_newton_raphson(
5858
f_closed, x, tol=tol, maxiter=maxiter

torax/_src/sources/ion_cyclotron_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def __eq__(self, other: typing_extensions.Self) -> bool:
266266
return isinstance(other, ToricNNWrapper)
267267

268268

269-
@functools.partial(jax.jit, static_argnames='toric_nn')
269+
@jax.jit(static_argnames='toric_nn')
270270
def _toric_nn_predict(
271271
toric_nn: ToricNNWrapper,
272272
inputs: ToricNNInputs,

torax/_src/tests/jax_utils_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
import functools
1614
import os
1715
from unittest import mock
1816

@@ -134,7 +132,7 @@ def f(x: jax.Array):
134132
@parameterized.parameters(['while_loop', 'pure_callback'])
135133
def test_non_inlined_function(self, implementation):
136134

137-
@functools.partial(jax.jit, static_argnames=['z'])
135+
@jax.jit(static_argnames='z')
138136
def f(x, z, y=2.0):
139137
if z == 'left':
140138
return x['temp1'] * y + jnp.sin(x['temp2'])

torax/_src/transport_model/transport_coefficients_builder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""Code to build the combined transport coefficients for a simulation."""
1616
import dataclasses
17-
import functools
1817

1918
import jax
2019
from torax._src import state
@@ -25,7 +24,7 @@
2524
from torax._src.transport_model import transport_model as transport_model_lib
2625

2726

28-
@functools.partial(jax.jit, static_argnums=(0, 1, 2))
27+
@jax.jit(static_argnums=(0, 1, 2))
2928
def calculate_total_transport_coeffs(
3029
pedestal_model: pedestal_model_lib.PedestalModel,
3130
transport_model: transport_model_lib.TransportModel,

0 commit comments

Comments
 (0)