We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents b8d411a + c47ace2 commit 0f879f0Copy full SHA for 0f879f0
e3nn_jax/_src/utils/jit.py
@@ -7,7 +7,7 @@ def jit_code(f, *args, **kwargs):
7
import jaxlib.xla_extension as xla_ext
8
9
f_jax = jax.jit(f)
10
- jax_comp = f_jax.lower(*args, **kwargs).compiler_ir(dialect="mhlo")
+ jax_comp = f_jax.lower(*args, **kwargs).compiler_ir(dialect="stablehlo")
11
jax_hlo = str(jax_comp)
12
backend = xla_bridge.get_backend()
13
jax_optimized_hlo = backend.compile(jax_hlo)
0 commit comments