Skip to content

Commit d16b6c0

Browse files
author
Aegis AI Assistant
committed
fix(gemma4): implement float32 safe RMSNorm wrapper, resolving NaN overflow shattering on Apple Silicon without compromising base magnitudes
1 parent 112c45e commit d16b6c0

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

Libraries/MLXLLM/Models/Gemma4.swift

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,16 @@ public class Gemma4RMSNorm: Module, UnaryLayer {
2121
}
2222

2323
public func callAsFunction(_ x: MLXArray) -> MLXArray {
24-
// vLLM: standard RMSNorm — weight * x / sqrt(E[x^2] + eps)
25-
// Gemma 4 norm weights are trained scale factors (~8-10), NOT zero-init offsets.
26-
return MLXFast.rmsNorm(x, weight: self.weight, eps: self.eps)
24+
// Compute the norm in float32 to prevent float16 overflow to infinity on M1/M2 architectures.
25+
// Vision models like SigLIP/Qwen inject imageFeatures at magnitude ~300.
26+
// 300^2 = 90,000, which exceeds the float16 maximum of 65504 and causes NaNs unless the computation is isolated in float32.
27+
let originalType = x.dtype
28+
let xF32 = x.asType(.float32)
29+
let variance = MLX.mean(MLX.square(xF32), axis: -1, keepDims: true)
30+
let rsqrtVar = MLX.rsqrt(variance + eps)
31+
32+
// The weight is applied after casting back to the original dtype to avoid a type mismatch.
33+
return (self.weight * (xF32 * rsqrtVar).asType(originalType))
2734
}
2835
}
2936

0 commit comments

Comments
 (0)