Skip to content

Commit 966ef0a

Browse files
Migrate Gpt3 to NNX.
1 parent: 69ed0c5 · commit: 966ef0a

File tree

4 files changed

+208
-170
lines changed

4 files changed

+208
-170
lines changed

src/MaxText/layers/decoders.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -411,7 +411,7 @@ def get_decoder_layers(self):
411411
case DecoderBlockType.GEMMA3:
412412
return [gemma3.Gemma3DecoderLayerToLinen]
413413
case DecoderBlockType.GPT3:
414-
return [gpt3.Gpt3DecoderLayer]
414+
return [gpt3.Gpt3DecoderLayerToLinen]
415415
case DecoderBlockType.GPT_OSS:
416416
return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
417417
case DecoderBlockType.QWEN3:
@@ -590,7 +590,7 @@ def _apply_embedding(
590590
name="position_embedder",
591591
config=cfg,
592592
mesh=self.mesh,
593-
)(decoder_positions, model_mode=model_mode)
593+
)(decoder_positions.astype("int32"), model_mode=model_mode)
594594
return y
595595

596596
@nn.compact
@@ -843,9 +843,7 @@ def __call__(
843843
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
844844
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
845845
for index in range(num_layers):
846-
y = layer(
847-
config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
848-
)(
846+
y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
849847
y,
850848
decoder_segment_ids,
851849
decoder_positions,

0 commit comments

Comments
 (0)