@@ -405,7 +405,7 @@ def get_decoder_layers(self):
       case DecoderBlockType.GEMMA3:
         return [gemma3.Gemma3DecoderLayerToLinen]
       case DecoderBlockType.GPT3:
-        return [gpt3.Gpt3DecoderLayer]
+        return [gpt3.Gpt3DecoderLayerToLinen]
       case DecoderBlockType.GPT_OSS:
         return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
       case DecoderBlockType.QWEN3:
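
Note on this hunk: every other branch of `get_decoder_layers` already returns a `*ToLinen` wrapper, so returning the raw `gpt3.Gpt3DecoderLayer` was an inconsistency left over from the NNX-to-Linen migration. Below is a minimal sketch of how such a wrapper is typically produced with Flax's `nnx.bridge` API; the toy layer and its `features` argument are stand-ins, not MaxText's actual `Gpt3DecoderLayer` signature or wrapper mechanism:

```python
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

# Toy NNX module standing in for Gpt3DecoderLayer (illustrative only;
# the real MaxText layer takes config/mesh/quant arguments).
class ToyDecoderLayer(nnx.Module):
  def __init__(self, features: int, *, rngs: nnx.Rngs):
    self.proj = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

# bridge.to_linen wraps the NNX class as a linen.Module, so callers that
# expect Linen layer classes (like the match/case above) work unchanged.
layer = bridge.to_linen(ToyDecoderLayer, features=8)
variables = layer.init(jax.random.key(0), jnp.ones((1, 8)))
y = layer.apply(variables, jnp.ones((1, 8)))
```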
@@ -584,7 +584,7 @@ def _apply_embedding(
           name="position_embedder",
           config=cfg,
           mesh=self.mesh,
-      )(decoder_positions, model_mode=model_mode)
+      )(decoder_positions.astype("int32"), model_mode=model_mode)
     return y

   @nn.compact
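
Note on this hunk: Flax's `nn.Embed` only accepts integer indices, so casting `decoder_positions` to `int32` before the position-embedder call guards against callers that supply positions in a floating-point dtype. A minimal sketch of the failure mode the cast avoids (generic names, not MaxText's):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

embed = nn.Embed(num_embeddings=2048, features=16)
positions = jnp.arange(8)[None, :]  # int32 by default
variables = embed.init(jax.random.key(0), positions)

y = embed.apply(variables, positions.astype("int32"))  # integer indices: ok
# embed.apply(variables, positions.astype("float32"))  # raises: Embed requires integer input
```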
@@ -837,9 +837,7 @@ def __call__(
         # Iterate over the two layer groups (dense and MoE) and apply layer transformation
         for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
           for index in range(num_layers):
-            y = layer(
-                config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
-            )(
+            y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
                 y,
                 decoder_segment_ids,
                 decoder_positions,