@@ -411,7 +411,7 @@ def get_decoder_layers(self):
       case DecoderBlockType.GEMMA3:
         return [gemma3.Gemma3DecoderLayerToLinen]
       case DecoderBlockType.GPT3:
-        return [gpt3.Gpt3DecoderLayer]
+        return [gpt3.Gpt3DecoderLayerToLinen]
       case DecoderBlockType.GPT_OSS:
         return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
       case DecoderBlockType.QWEN3:
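
The `*ToLinen` suffix follows the NNX-to-Linen bridge pattern already used by the other decoder blocks in this match statement: the layer itself is an NNX module, wrapped so Linen call sites can instantiate it. A minimal sketch of that pattern, assuming `flax.nnx.bridge.to_linen` and an illustrative stand-in layer (not the real `gpt3.Gpt3DecoderLayer`):

from flax import nnx
from flax.nnx import bridge

class ToyDecoderLayer(nnx.Module):
  """Hypothetical stand-in for a real NNX decoder layer."""

  def __init__(self, features: int, *, rngs: nnx.Rngs):
    self.proj = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

def ToyDecoderLayerToLinen(features: int):
  # bridge.to_linen wraps the NNX class into a Linen-compatible module,
  # so factories like get_decoder_layers return one uniform module type.
  return bridge.to_linen(ToyDecoderLayer, features)
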
@@ -590,7 +590,7 @@ def _apply_embedding(
           name="position_embedder",
           config=cfg,
           mesh=self.mesh,
-      )(decoder_positions, model_mode=model_mode)
+      )(decoder_positions.astype("int32"), model_mode=model_mode)
     return y
 
   @nn.compact
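
The cast matters because Flax embedding lookups reject non-integer indices, so a float-typed position tensor fails inside the embedder. A minimal sketch of the failure mode, assuming illustrative sizes (`max_len` and `features` are not values from this diff):

import jax
import jax.numpy as jnp
import flax.linen as nn

class PositionEmbedder(nn.Module):
  max_len: int = 2048   # assumed maximum sequence length
  features: int = 128   # assumed embedding width

  @nn.compact
  def __call__(self, positions):
    # nn.Embed raises on non-integer indices, so cast before the lookup.
    return nn.Embed(num_embeddings=self.max_len, features=self.features)(
        positions.astype("int32"))

positions = jnp.arange(16, dtype=jnp.float32)  # float positions, as upstream may produce
params = PositionEmbedder().init(jax.random.PRNGKey(0), positions)
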
@@ -843,9 +843,7 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
-          y = layer(
-              config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
-          )(
+          y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
               y,
               decoder_segment_ids,
               decoder_positions,