Skip to content

Commit 0f879f0

Browse files
committed
Merge branch 'main' into so3
2 parents b8d411a + c47ace2 commit 0f879f0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

e3nn_jax/_src/utils/jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def jit_code(f, *args, **kwargs):
77
import jaxlib.xla_extension as xla_ext
88

99
f_jax = jax.jit(f)
10-
jax_comp = f_jax.lower(*args, **kwargs).compiler_ir(dialect="mhlo")
10+
jax_comp = f_jax.lower(*args, **kwargs).compiler_ir(dialect="stablehlo")
1111
jax_hlo = str(jax_comp)
1212
backend = xla_bridge.get_backend()
1313
jax_optimized_hlo = backend.compile(jax_hlo)

0 commit comments

Comments
 (0)