@@ -8590,12 +8590,12 @@ def patched_qwen3_5_moe_sparse_moe_block(self, hidden_states: torch.Tensor) -> t
85908590 hidden_expanded = hidden_states_reshaped .unsqueeze (0 ).expand (num_experts , - 1 , - 1 )
85918591
85928592 # Vectorized expert computation using pre-transposed weights
8593- gate_up = torch .bmm (hidden_expanded , self ._gate_up_projs_t )
8593+ gate_up = torch .bmm (hidden_expanded , self ._gate_up_projs_t . to ( hidden_expanded . dtype ) )
85948594 intermediate_size = self .experts .intermediate_dim
85958595 gate = gate_up [:, :, :intermediate_size ]
85968596 up = gate_up [:, :, intermediate_size :]
85978597 activated = self .experts .act_fn (gate ) * up
8598- next_states = torch .bmm (activated , self ._down_projs_t )
8598+ next_states = torch .bmm (activated , self ._down_projs_t . to ( activated . dtype ) )
85998599
86008600 # Weight by routing and sum over experts
86018601 next_states = next_states * new_routing_weights .T .unsqueeze (- 1 )
@@ -8914,6 +8914,26 @@ def __enter__(self):
89148914 patched_qwen3_5_moe_sparse_moe_block , sparse_moe_block
89158915 )
89168916
def post_make_16bit_traceable(self):
    """Release redundant fp32 expert weights left behind by __make_16bit_traceable.

    __make_16bit_traceable invokes ``module.float()`` on Qwen3_5MoeExperts
    modules, which materializes fp32 copies of ``gate_up_proj`` and
    ``down_proj``. The patcher already captured bf16 views of these weights
    (``_gate_up_projs_t`` / ``_down_projs_t``) for the patched forward, so the
    fp32 duplicates are dead weight — drop them to keep memory in check.
    """
    import gc

    import torch
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock

    moe_mlps = (
        layer.mlp
        for layer in self._model.model.layers
        if isinstance(layer.mlp, Qwen3_5MoeSparseMoeBlock)
    )
    for mlp in moe_mlps:
        mlp.experts.gate_up_proj.data = torch.empty(0)
        mlp.experts.down_proj.data = torch.empty(0)
    gc.collect()
8936+
89178937 def __exit__ (self , exc_type , exc_value , traceback ):
89188938 from transformers .models .qwen3_5_moe .modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
89198939
@@ -9025,6 +9045,8 @@ def has_previous_state(self):
90259045 layer_idx = self .linear_attn_mapping [self .last_linear_layer ]
90269046 return self .conv_states [layer_idx ] is not None
90279047
9048+ _lm_head_weight = model .lm_head .weight
9049+
90289050 def patched_forward (
90299051 inputs_embeds ,
90309052 attention_mask = None ,
@@ -9063,7 +9085,7 @@ def patched_forward(
90639085 use_cache = use_cache ,
90649086 )
90659087 hidden_states = outputs [0 ]
9066- logits = model . lm_head (hidden_states )
9088+ logits = torch . nn . functional . linear (hidden_states , _lm_head_weight . to ( hidden_states . dtype ) )
90679089
90689090 result = {"logits" : logits }
90699091
@@ -9178,6 +9200,8 @@ def has_previous_state(self):
91789200 layer_idx = self .linear_attn_mapping [self .last_linear_layer ]
91799201 return self .conv_states [layer_idx ] is not None
91809202
9203+ _lm_head_weight = model .lm_head .weight
9204+
91819205 def patched_forward (
91829206 inputs_embeds ,
91839207 attention_mask = None ,
@@ -9216,7 +9240,7 @@ def patched_forward(
92169240 use_cache = use_cache ,
92179241 )
92189242 hidden_states = outputs [0 ]
9219- logits = model . lm_head (hidden_states )
9243+ logits = torch . nn . functional . linear (hidden_states , _lm_head_weight . to ( hidden_states . dtype ) )
92209244
92219245 result = {"logits" : logits }
92229246
@@ -9271,6 +9295,26 @@ def __enter__(self):
92719295 patched_qwen3_5_moe_sparse_moe_block , sparse_moe_block
92729296 )
92739297
def post_make_16bit_traceable(self):
    """Release redundant fp32 expert weights left behind by __make_16bit_traceable.

    __make_16bit_traceable invokes ``module.float()`` on Qwen3_5MoeExperts
    modules, which materializes fp32 copies of ``gate_up_proj`` and
    ``down_proj``. The patcher already captured bf16 views of these weights
    (``_gate_up_projs_t`` / ``_down_projs_t``) for the patched forward, so the
    fp32 duplicates are dead weight — drop them to keep memory in check.
    """
    import gc

    import torch
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock

    moe_mlps = (
        layer.mlp
        for layer in self._model.model.language_model.layers
        if isinstance(layer.mlp, Qwen3_5MoeSparseMoeBlock)
    )
    for mlp in moe_mlps:
        mlp.experts.gate_up_proj.data = torch.empty(0)
        mlp.experts.down_proj.data = torch.empty(0)
    gc.collect()
9317+
92749318 def __exit__ (self , exc_type , exc_value , traceback ):
92759319 from transformers .models .qwen3_5_moe .modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
92769320
0 commit comments