Skip to content

Commit d12fdf9

Browse files
Migrate Gpt3 to NNX.
1 parent 9d0c860 commit d12fdf9

File tree

2 files changed

+203
-166
lines changed

2 files changed

+203
-166
lines changed

src/MaxText/layers/decoders.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -405,7 +405,7 @@ def get_decoder_layers(self):
405 405
case DecoderBlockType.GEMMA3:
406 406
return [gemma3.Gemma3DecoderLayerToLinen]
407 407
case DecoderBlockType.GPT3:
408 -
return [gpt3.Gpt3DecoderLayer]
408 +
return [gpt3.Gpt3DecoderLayerToLinen]
409 409
case DecoderBlockType.GPT_OSS:
410 410
return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
411 411
case DecoderBlockType.QWEN3:
@@ -584,7 +584,7 @@ def _apply_embedding(
584 584
name="position_embedder",
585 585
config=cfg,
586 586
mesh=self.mesh,
587 -
)(decoder_positions, model_mode=model_mode)
587 +
)(decoder_positions.astype("int32"), model_mode=model_mode)
588 588
return y
589 589

590 590
@nn.compact
@@ -837,9 +837,7 @@ def __call__(
837 837
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
838 838
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
839 839
for index in range(num_layers):
840 -
y = layer(
841 -
config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
842 -
)(
840 +
y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
843 841
y,
844 842
decoder_segment_ids,
845 843
decoder_positions,

0 commit comments

Comments
 (0)