@@ -229,21 +229,31 @@ def _with_real_sm_count(self):
229229 else :
230230 yield
231231
232- @torch .compile (dynamic = True ) if not _is_hip else lambda f : f
233- def _project_and_scale_head_gates (self , x : torch .Tensor ):
def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
    """Apply ``self.weights_proj`` to ``x`` and return a float32 result.

    Two paths produce the output:

    * When ``_is_cuda`` is set and ``deep_gemm`` exposes ``bf16_gemm_nt``,
      the GEMM writes straight into a freshly allocated float32 buffer of
      shape ``(x.shape[0], weight.shape[0])``.
    * Otherwise the ``weights_proj`` layer is called (after casting ``x`` to
      the weight dtype on HIP) and its first return value is upcast with
      ``.float()``.

    Args:
        x: Input activations.  The fast path sizes its output from
            ``x.shape[0]`` only, so it presumably expects a 2-D tensor --
            TODO confirm against callers.

    Returns:
        The projected tensor in ``torch.float32``.
    """
    fast_gemm_available = _is_cuda and hasattr(deep_gemm, "bf16_gemm_nt")
    if not fast_gemm_available:
        if _is_hip:
            # HIP path: align the activation dtype with the projection weight.
            x = x.to(self.weights_proj.weight.dtype)
        # The layer returns a pair; only the first element is used here.
        projected, _ = self.weights_proj(x)
        return projected.float()

    # NOTE(review): this path reads only ``weights_proj.weight`` and bypasses
    # the layer's own forward -- confirm the layer carries no bias or
    # quantization that the fallback path would apply.
    proj_weight = self.weights_proj.weight
    result = torch.empty(
        (x.shape[0], proj_weight.shape[0]),
        dtype=torch.float32,
        device=x.device,
    )
    deep_gemm.bf16_gemm_nt(x, proj_weight, result)
    return result
247+
@torch.compile(dynamic=True) if not _is_hip else lambda f: f
def _project_and_scale_head_gates(self, x: torch.Tensor):
    """Return the float32 head-gate projection of ``x`` scaled by ``n_heads ** -0.5``.

    The projection itself is delegated to
    ``self._weights_proj_bf16_in_fp32_out``.  The method is wrapped in
    ``torch.compile(dynamic=True)`` unless ``_is_hip`` is set, in which case
    it runs undecorated.
    """
    return self._weights_proj_bf16_in_fp32_out(x) * self.n_heads**-0.5
240253
@torch.compile(dynamic=True) if not _is_hip else lambda f: f
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
    """Build the per-head gate applied to logits.

    The float32 projection of ``x`` (from
    ``self._weights_proj_bf16_in_fp32_out``) is scaled by
    ``n_heads ** -0.5``, given a trailing singleton dimension, and then
    multiplied by ``q_scale`` and ``self.softmax_scale`` in that order.

    The method is wrapped in ``torch.compile(dynamic=True)`` unless
    ``_is_hip`` is set.
    """
    head_gates = self._weights_proj_bf16_in_fp32_out(x) * self.n_heads**-0.5
    # The trailing unsqueeze lets the gate broadcast against ``q_scale``.
    return head_gates.unsqueeze(-1) * q_scale * self.softmax_scale
0 commit comments