File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -21,9 +21,16 @@ public class Gemma4RMSNorm: Module, UnaryLayer {
2121 }
2222
/// Applies RMS normalization to `x` along its last axis and scales the result by `weight`.
///
/// The mean-of-squares statistic is computed on a Float32 copy of the input.
/// The normalized values are cast back to the input's dtype before the weight
/// is multiplied in.
///
/// - Parameter x: The array to normalize.
/// - Returns: `weight * (x * rsqrt(mean(x^2, axis: -1) + eps))`.
public func callAsFunction(_ x: MLXArray) -> MLXArray {
    // Squaring in half precision can exceed the Float16 maximum of 65504
    // (an activation of ~300 squares to 90,000), which would yield inf/NaN,
    // so the input is widened before any arithmetic is done on it.
    let inputDType = x.dtype
    let widened = x.asType(.float32)

    // Mean of squares over the last axis; keepDims leaves a size-1 axis so the
    // result lines up with `widened` in the multiplication below.
    let meanSquare = MLX.mean(MLX.square(widened), axis: -1, keepDims: true)
    let inverseRMS = MLX.rsqrt(meanSquare + eps)

    // Narrow back to the caller's dtype first, then apply the weight.
    let normalized = (widened * inverseRMS).asType(inputDType)
    return self.weight * normalized
}
2835}
2936
You can’t perform that action at this time.
0 commit comments