mx.fast.metal_kernel: Add support for compiler options (-fmetal-math-mode, integer template parameters, Metal 4 Tensor types) #3592

@rajveer43

Description

mx.fast.metal_kernel has no way to pass compiler flags, so the kernel source is always compiled with MLX's default math mode. Under that mode, bare exp() calls in the kernel source resolve to metal::fast::exp, which may not return exactly 0.0 for exp(-INFINITY) because fast math relaxes IEEE guarantees.

Any kernel implementing softmax with masked attention (causal, sliding-window) relies on exp(-∞) = 0 so that masked positions receive zero weight. In fast math mode this is not guaranteed. The only current workaround is to call metal::precise::exp(...) explicitly in the kernel source. This works, but it requires users to know about an undocumented constraint.
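For reference, here is a minimal sketch of that workaround. The names PRECISE_EXP_SOURCE, masked_softmax, and run_precise_exp are made up for illustration. The MLX call is guarded, so the snippet still imports on machines without MLX or Metal. The pure-Python masked_softmax only shows why exp(-∞) = 0 matters, and assumes at least one position is unmasked.

```python
import math

# Kernel body only; MLX generates the function signature around it.
# metal::precise::exp is named explicitly so the result does not depend on
# the math mode the source is compiled with.
PRECISE_EXP_SOURCE = """
    uint elem = thread_position_in_grid.x;
    y[elem] = metal::precise::exp(x[elem]);
"""


def masked_softmax(scores, mask):
    """CPU reference: masked positions are set to -inf before the softmax.

    Assumes at least one entry of `mask` is True.
    """
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
    peak = max(masked)
    # math.exp(-inf) is exactly 0.0, so masked positions contribute nothing.
    exps = [math.exp(v - peak) for v in masked]
    total = sum(exps)
    return [e / total for e in exps]


def run_precise_exp(values):
    """Run the kernel through MLX when it is available, else return None."""
    try:
        import mlx.core as mx
    except ImportError:
        return None
    if not mx.metal.is_available():
        return None

    x = mx.array(values, dtype=mx.float32)
    kernel = mx.fast.metal_kernel(
        name="precise_exp",
        input_names=["x"],
        output_names=["y"],
        source=PRECISE_EXP_SOURCE,
    )
    (y,) = kernel(
        inputs=[x],
        grid=(x.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[x.shape],
        output_dtypes=[mx.float32],
    )
    return y.tolist()
```

On a machine with Metal, run_precise_exp([float("-inf"), 0.0]) should return 0.0 for the first element. A kernel using a bare exp() is the case this issue is concerned about.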

Proposed API addition:

kernel = mx.fast.metal_kernel(
    name="my_kernel",
    input_names=["x"],
    output_names=["y"],
    source=src,
    compile_options=["-fmetal-math-mode=relaxed"],   # NEW
)

Alternatively, expose a math_mode parameter that accepts "fast", "relaxed", or "safe".
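To show how the two proposals could relate, here is a rough sketch of a helper that maps a math_mode value onto the equivalent compile_options list. Everything in it is hypothetical: neither MathMode nor math_mode_to_compile_options exists in MLX, and compile_options itself is only the parameter proposed above.

```python
from enum import Enum


class MathMode(Enum):
    """Hypothetical enum mirroring the three proposed math_mode values."""

    FAST = "fast"
    RELAXED = "relaxed"
    SAFE = "safe"


def math_mode_to_compile_options(mode):
    """Translate a math mode (string or MathMode) into a compiler flag list.

    Raises ValueError for anything other than the three proposed values.
    """
    mode = MathMode(mode)
    return [f"-fmetal-math-mode={mode.value}"]
```

With this shape, math_mode="relaxed" and compile_options=["-fmetal-math-mode=relaxed"] would mean the same thing. The enum form could reject unknown values early, while the raw list leaves room for other flags.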
