mx.fast.metal_kernel has no way to pass compiler flags, so the kernel source is always compiled with MLX's default math mode. Under that mode, bare exp() calls in the kernel source resolve to metal::fast::exp, which is not guaranteed to return 0.0 for exp(-INFINITY) under relaxed IEEE assumptions.
Any kernel implementing softmax with masked attention (causal, sliding-window) relies on exp(-∞) = 0, because masked positions are set to -∞ and must contribute nothing to the normalizing sum. In fast math mode this is not guaranteed. The only current workaround is to call metal::precise::exp(...) explicitly in the kernel source. This works, but it requires users to know about an undocumented constraint.
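To make the dependency and the workaround concrete, here is a minimal sketch. The helper names (masked_softmax_row, precise_exp, PRECISE_EXP_SOURCE) and the buffer names inp/out are hypothetical and chosen only for illustration. The first function is a pure-Python reference showing where softmax depends on exp(-∞) = 0. The second wraps an elementwise kernel that calls metal::precise::exp explicitly; it assumes MLX is installed on Apple silicon and that the input is a non-empty float32 array.

```python
import math


def masked_softmax_row(scores, mask):
    """Reference softmax over one row; masked positions become -inf before exp."""
    masked = [s if keep else -math.inf for s, keep in zip(scores, mask)]
    row_max = max(masked)  # assumes at least one unmasked position
    # Correctness hinges on exp(-inf) == 0.0: masked entries must add nothing.
    exps = [math.exp(m - row_max) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Metal kernel body only; MLX generates the surrounding function signature.
PRECISE_EXP_SOURCE = """
    uint elem = thread_position_in_grid.x;
    // Call the precise variant explicitly instead of a bare exp().
    out[elem] = metal::precise::exp(float(inp[elem]));
"""


def precise_exp(x):
    """Elementwise exp via a custom kernel (sketch; assumes float32 input)."""
    import mlx.core as mx  # imported lazily: requires MLX on Apple silicon

    kernel = mx.fast.metal_kernel(
        name="precise_exp",
        input_names=["inp"],
        output_names=["out"],
        source=PRECISE_EXP_SOURCE,
    )
    outputs = kernel(
        inputs=[x],
        grid=(x.size, 1, 1),       # one thread per element
        threadgroup=(256, 1, 1),
        output_shapes=[x.shape],
        output_dtypes=[x.dtype],
    )
    return outputs[0]
```

With the reference function, masked_softmax_row([1.0, 2.0, 3.0], [True, True, False]) gives exactly 0.0 for the masked position and the remaining weights sum to 1. A kernel whose exp does not return 0 for -∞ would leak weight into masked positions.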
Proposed API addition:
kernel = mx.fast.metal_kernel(
    name="my_kernel",
    input_names=["x"],
    output_names=["y"],
    source=src,
    compile_options=["-fmetal-math-mode=relaxed"],  # NEW
)
Alternatively, a math_mode enum parameter ("fast" / "relaxed" / "safe") would cover the same need without exposing raw compiler flags.