Commit ec6e822

Implement approx_tanh for ROCm using OCML tanh function
AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction like NVIDIA's PTX tanh.approx. This implements approx_tanh for ROCm using:

- For f32 (and f16/bf16 via casting): Triton's __triton_hip_fast_tanhf, which uses a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
- For f64: OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library)

Changes:
- Add f64 support to approx_tanh
- Add ROCm platform detection in _elementwise_inline_asm_lowering
- Add _approx_tanh_rocm_lowering for ROCm-specific lowering
- Add test_approx_tanh with f16/bf16/f32/f64 support

See: triton-lang/triton#7780
1 parent 09e023e commit ec6e822
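
As a quick check on the exp-based formula cited above, here is a small sketch (not part of the commit) that evaluates tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) in plain JAX and compares it against jnp.tanh on the same inputs the new test uses; fast_tanh is an illustrative name, not the Triton builtin:

import jax.numpy as jnp

def fast_tanh(x):
  # The identity behind the fast path: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1).
  e2x = jnp.exp(2.0 * x)
  return (e2x - 1.0) / (e2x + 1.0)

x = jnp.asarray([-1.0, 0.42, 0.24, 1.0], dtype=jnp.float32)
# Only float32 rounding separates the two, well within the test's 5e-3 tolerance.
print(jnp.max(jnp.abs(fast_tanh(x) - jnp.tanh(x))))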

File tree: 2 files changed (+133, −0 lines)

jax/_src/pallas/triton/primitives.py

Lines changed: 92 additions & 0 deletions
@@ -48,6 +48,11 @@ def approx_tanh(x: jax.Array) -> jax.Array:
   elif x.dtype == jnp.float32:
     asm = "tanh.approx.f32 $0, $1;"
     constraint = "f"
+  elif x.dtype == jnp.float64:
+    # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64)
+    # CUDA does not have a PTX instruction for f64 approximate tanh
+    asm = "tanh.approx.f64 $0, $1;"
+    constraint = "d"
   else:
     raise TypeError(f"approx_tanh does not accept {x.dtype} arrays")
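
For orientation, a minimal usage sketch (not part of this commit) of the user-facing entry point that the new test below also exercises; the pl and plgpu_triton import aliases follow JAX's test conventions and may differ across versions:

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu_triton

@functools.partial(
    pl.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
)
def kernel(x_ref, o_ref):
  # On CUDA this lowers to the tanh.approx.f32 PTX instruction; on ROCm the
  # lowering added below dispatches to __triton_hip_fast_tanhf instead.
  o_ref[...] = plgpu_triton.approx_tanh(x_ref[...])

x = jnp.asarray([-1.0, 0.42, 0.24, 1.0], dtype=jnp.float32)
print(kernel(x))  # approximately tanh(x)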

@@ -119,6 +124,13 @@ def _elementwise_inline_asm_lowering(
     result_shape_dtypes,
 ):
   del result_shape_dtypes  # Unused.
+
+  # For ROCm, PTX inline assembly is not supported. For tanh.approx, we use
+  # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and
+  # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780
+  if ctx.context.platform == "rocm" and "tanh.approx" in asm:
+    return _approx_tanh_rocm_lowering(ctx, *args)
+
   return tt_dialect.ElementwiseInlineAsmOp(
       [*map(mlir.aval_to_ir_type, ctx.avals_out)],
       asm,
@@ -129,6 +141,86 @@ def _elementwise_inline_asm_lowering(
   ).result
 
 
+def _approx_tanh_rocm_lowering(
+    ctx: lowering.LoweringRuleContext,
+    *args,
+):
+  """Lower approx_tanh for ROCm.
+
+  AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction.
+
+  For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf
+  which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+  See: https://github.com/triton-lang/triton/pull/7780
+
+  For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library)
+  since fast_tanhf only supports f32.
+  """
+  from jax._src.lib.mlir import ir
+  from jax._src.lib.mlir.dialects import arith as arith_dialect
+
+  [arg] = args
+  [out_aval] = ctx.avals_out
+  in_dtype = ctx.avals_in[0].dtype
+
+  # Helper to get IR type for a dtype
+  def dtype_to_ir_type(dtype):
+    dtype = jnp.dtype(dtype)
+    return mlir.dtype_to_ir_type(dtype)
+
+  # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32)
+  if in_dtype == jnp.float64:
+    result_type = mlir.aval_to_ir_type(out_aval)
+    result = tt_dialect.extern_elementwise(
+        result_type,
+        list(args),
+        libname="",
+        libpath="",
+        symbol="__ocml_tanh_f64",
+        pure=True,
+    )
+    return [result]
+
+  # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back.
+  needs_cast = in_dtype in (jnp.float16, jnp.bfloat16)
+
+  if needs_cast:
+    # Cast input to f32 (extend)
+    f32_type = dtype_to_ir_type(jnp.float32)
+    if out_aval.shape:
+      f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type)
+    else:
+      f32_result_type = f32_type
+    arg_f32 = arith_dialect.extf(f32_result_type, arg)
+
+    # Call __triton_hip_fast_tanhf (fast exp-based implementation)
+    tanh_result = tt_dialect.extern_elementwise(
+        f32_result_type,
+        [arg_f32],
+        libname="libdevice",
+        libpath="",
+        symbol="__triton_hip_fast_tanhf",
+        pure=True,
+    )
+
+    # Cast result back to original dtype (truncate)
+    out_type = mlir.aval_to_ir_type(out_aval)
+    result = arith_dialect.truncf(out_type, tanh_result)
+  else:
+    # f32: call __triton_hip_fast_tanhf directly
+    result_type = mlir.aval_to_ir_type(out_aval)
+    result = tt_dialect.extern_elementwise(
+        result_type,
+        list(args),
+        libname="libdevice",
+        libpath="",
+        symbol="__triton_hip_fast_tanhf",
+        pure=True,
+    )
+
+  return [result]
+
+
 def debug_barrier() -> None:
   """Synchronizes all kernel executions in the grid."""
   return debug_barrier_p.bind()
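
As a rough numerical illustration (an assumption-laden sketch, not code from the commit) of why the f16/bf16 path can extend to f32, evaluate, and truncate back: the only extra error is one rounding step in the narrow dtype, which is small next to the test's 5e-3 tolerance for these inputs:

import jax.numpy as jnp

x = jnp.asarray([-1.0, 0.42, 0.24, 1.0], dtype=jnp.bfloat16)
# Emulate the lowering's extf -> tanh-in-f32 -> truncf round trip.
rounded = jnp.tanh(x.astype(jnp.float32)).astype(jnp.bfloat16)
reference = jnp.tanh(x.astype(jnp.float32))
print(jnp.max(jnp.abs(rounded.astype(jnp.float32) - reference)))  # one bf16 rounding step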

tests/pallas/ops_test.py

Lines changed: 41 additions & 0 deletions
@@ -1870,6 +1870,47 @@ def kernel(o_ref):
 
     np.testing.assert_allclose(f(), kernel())
 
+  @parameterized.parameters("float16", "bfloat16", "float32", "float64")
+  def test_approx_tanh(self, dtype):
+    self.skip_if_mosaic_gpu()
+
+    if jtu.test_device_matches(["tpu"]):
+      self.skipTest("Not implemented on TPU")
+
+    if self.INTERPRET:
+      self.skipTest("approx_tanh is not supported in interpret mode")
+
+    if (dtype == "bfloat16" and
+        jtu.test_device_matches(["cuda"]) and
+        not jtu.is_cuda_compute_capability_at_least("9.0")):
+      self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
+
+    if dtype == "float64":
+      if jtu.test_device_matches(["cuda"]):
+        self.skipTest("f64 approx_tanh is only supported on ROCm")
+
+    # Enable x64 for f64 test if not already enabled, restore after test
+    original_x64 = jax.config.x64_enabled
+    if dtype == "float64" and not original_x64:
+      jax.config.update("jax_enable_x64", True)
+      self.addCleanup(lambda: jax.config.update("jax_enable_x64", False))
+
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = plgpu_triton.approx_tanh(x_ref[...])
+
+    x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype)
+    # We upcast to float32 because NumPy <2.0 does not handle custom dtypes
+    # properly. See https://github.com/jax-ml/jax/issues/11014.
+    np.testing.assert_allclose(
+        kernel(x).astype(jnp.float32),
+        jnp.tanh(x).astype(jnp.float32),
+        atol=5e-3,
+        rtol=5e-3,
+    )
+
   @parameterized.parameters(
       ((2, 4), (8,)),
       ((2, 4), (8, 1)),
