
Conversation


@hsuan-lun-chiang hsuan-lun-chiang commented Aug 1, 2025

Description

This PR:

  1. Migrates the Gpt3 implementation from Linen to NNX, covering the following classes:

  • Gpt3MultiHeadAttention
  • Gpt3DecoderLayer

  2. Fixes the decode function of Gpt3 by casting decoder_positions to int32. The trainable position embedding layer (an Embed layer) requires integer indices for its lookup, but decoder_positions was passed as a float; casting it to int32 prevents the ValueError. A sketch of both changes follows below.
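
For illustration, a minimal sketch of both changes in NNX style — Gpt3Decoder and its constructor arguments here are hypothetical names, not the actual MaxText classes, with flax.nnx's Embed standing in for MaxText's embedding layer:

```python
from flax import nnx
import jax.numpy as jnp

class Gpt3Decoder(nnx.Module):
  """Hypothetical sketch; not the actual MaxText implementation."""

  def __init__(self, vocab_size: int, max_len: int, emb_dim: int, *, rngs: nnx.Rngs):
    # NNX style: submodules are created eagerly in __init__ with explicit RNGs,
    # instead of Linen's lazy setup()/@nn.compact definition.
    self.token_embed = nnx.Embed(num_embeddings=vocab_size, features=emb_dim, rngs=rngs)
    self.position_embed = nnx.Embed(num_embeddings=max_len, features=emb_dim, rngs=rngs)

  def __call__(self, decoder_input_tokens, decoder_positions):
    # The decode fix: Embed is an integer table lookup, so positions that
    # arrive as floats must be cast before indexing into the table.
    decoder_positions = decoder_positions.astype(jnp.int32)
    return self.token_embed(decoder_input_tokens) + self.position_embed(decoder_positions)
```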

Tests

Ran the train command to train gpt3-6b for 10 steps:

python3 -m MaxText.train  MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/1/ model_name=gpt3-6b dataset_type=synthetic steps=10

Logs:
Linen, before migration
NNX, after migration

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.


@bvandermoon bvandermoon left a comment


Awesome to see this @hsuan-lun-chiang. Could you please add before/after logs for the training command shown in your test section? It would be great to see that perf is the same before/after here

@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch 7 times, most recently from 1df6836 to b1f7bd8 Compare August 5, 2025 06:14
@hsuan-lun-chiang

Awesome to see this @hsuan-lun-chiang. Could you please add before/after logs for the training command shown in your test section? It would be great to see that perf is the same before/after here

Updated the description with the before/after logs, thank you.


@bvandermoon bvandermoon left a comment


@cgarciae can you please take a look at this also?

@mesakhcienet mesakhcienet force-pushed the feat/Migrate-Gpt3-to-NNX branch from b1f7bd8 to f779923 Compare August 5, 2025 07:25
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch 5 times, most recently from cacd42b to 4d57ee8 Compare August 6, 2025 11:16

@bvandermoon bvandermoon left a comment


Thanks @hsuan-lun-chiang. Could you please run train (you already have this), decode, and then maxengine/jetstream (with profiles collected for maxengine/jetstream)? Similar to #2088. I can help with profile collection offline if you want, just let me know


ecnal-cienet commented Oct 9, 2025

Results for Train and Jetstream

Test Environment

Machine Type: TPU v6e-8

Train

Executed Command:

python3 -m MaxText.train \
MaxText/configs/base.yml \
run_name=gpt3-train-run \
base_output_directory=gs://lance-maxtext/gpt3-6b-train-before/ \
model_name=gpt3-6b \
dataset_type=synthetic \
steps=10

Results:


Maxengine / Jetstream

Step 1: Launch Maxengine

# On terminal 1
export LIBTPU_INIT_ARGS="--xla_jf_auto_cross_replica_sharding=false --xla_tpu_enable_windowed_einsum_for_reduce_scatter=false --xla_tpu_enable_windowed_einsum_for_all_gather=false --xla_tpu_prefer_latch_optimized_rhs_layouts=true --xla_tpu_enable_experimental_fusion_cost_model=false --xla_tpu_dot_dot_fusion_duplicated=false --xla_tpu_dot_dot_fusion=true --xla_jf_conv_input_fusion=true --xla_jf_conv_output_fusion=true --xla_tpu_rwb_fusion=false --xla_tpu_copy_fusion_pad_unpad_ratio=0 --xla_tpu_licm_size_inflation_ratio=1 --xla_tpu_copy_elision_analysis_allowance=150000 --xla_tpu_copy_insertion_use_region_analysis_limit=10000 --xla_tpu_order_dot_after_layout=true --xla_jf_rematerialization_percent_shared_memory_limit=100 --xla_tpu_use_repeated_instance_for_preferred_prefetch_time=true --xla_tpu_enforce_prefetch_fifo_order=false --xla_tpu_prefetch_interval_picker_size_override=6000000 --xla_tpu_async_copy_bandwidth_scaling_factor=1 --xla_tpu_nd_short_transfer_max_chunks=-1 --xla_tpu_enable_aggressive_broadcast_priority_update=true --xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers=SQRT --xla_tpu_memory_bound_loop_optimizer_options=enabled:true --xla_tpu_enable_copy_fusion=true --xla_tpu_enable_cross_program_prefetch_freeing=false --xla_tpu_enable_dot_strength_reduction=true --xla_tpu_layout_use_dot_grouping=false --xla_tpu_msa_inefficient_use_to_copy_ratio=0.5 --xla_tpu_reduce_loop_fusion_dup_with_unfusable_user=false --xla_tpu_vector_load_fusion_window=1024 --xla_tpu_vector_store_fusion_window=256 --xla_jf_conv_reshape_fusion=false --xla_tpu_input_conv_multi_users=false --xla_tpu_enable_multi_level_input_dot_dot_fusion=false --xla_tpu_enable_multi_level_output_dot_dot_fusion=false --xla_tpu_dot_dot_fusion_separable_convs_only=false --xla_tpu_enable_multi_level_nested_loop_fusion=true --xla_tpu_nested_dot_fusion=true --xla_tpu_enable_multi_level_nested_dot_fusion=false --xla_jf_enable_multi_output_fusion=true --xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions=false --xla_tpu_enable_flash_attention=true"

python3 -m MaxText.maxengine_server \
MaxText/configs/base.yml \
model_name=gpt-oss-20b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
per_device_batch_size=1 \
max_target_length=1024 \
ici_fsdp_parallelism=1 \
ici_tensor_parallelism=8 \
attention=dot_product \
load_parameters_path=gs://lance-maxtext/gpt-oss-train-after/test_pre_gpt_oss_20b/checkpoints/9/items \
max_prefill_predict_length=128 \
prompt="I love to" \
mla_naive_kvcache=False \
enable_jax_profiler=True hf_access_token=$HF_TOKEN

Step 2: Execute Jetstream

# On terminal 2
JAX_PLATFORMS=tpu python benchmarks/benchmark_serving.py \
--tokenizer=openai/gpt-oss-20b \
--num-prompts 5000 \
--dataset mmlu \
--dataset-path mmlu/data/test/ \
--request-rate 0 \
--warmup-mode sampled \
--save-request-outputs \
--run-eval True \
--use-hf-tokenizer True

Step 3: Output Collection

# On terminal 2
# Define data saving location
RUN=run-$(date +%Y-%m-%d-%H-%M-%S)
echo $RUN
log_dir=$HOME/test_memory/mistral/$RUN
echo $log_dir

# Collect data
python -m jax.collect_profile 9999 6000 --log_dir=$log_dir --no_perfetto_link
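
For context on the collection step above: `jax.collect_profile` takes a port and a duration in milliseconds, and attaches to a profiler server that the serving process must already expose — presumably what `enable_jax_profiler=True` arranges in the maxengine command. In a standalone script, the equivalent setup is roughly this sketch (port 9999 matches the command above):

```python
import jax.profiler

# Expose a profiling server on port 9999 so that, from another terminal,
# `python -m jax.collect_profile 9999 6000 ...` can attach for 6000 ms.
server = jax.profiler.start_server(9999)

# ... run the workload; keep `server` alive while profiling ...
```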

Results:


Decode

Executed Command:

python3 -m MaxText.decode MaxText/configs/base.yml \
model_name=gpt3-6b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
per_device_batch_size=1 \
ici_fsdp_parallelism=2 \
ici_autoregressive_parallelism=4 \
max_prefill_predict_length=128 \
prefill_chunk_size=0 \
prompt="I love to" \
attention=dot_product \
weight_dtype=bfloat16 \
load_parameters_path=gs://lance-maxtext/gpt3-6b-train-before/gpt3-train-run/checkpoints/0/items

For both before and after migration, we received the same error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/wanglance_google_com/maxtext/src/MaxText/decode.py", line 211, in <module>
    app.run(main)
  File "/home/wanglance_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/wanglance_google_com/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/decode.py", line 97, in main
    params = engine.load_params(rng_load_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/maxengine.py", line 252, in load_params
    self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations(
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/maxtext_utils.py", line 1083, in get_prefill_kv_cache_annotations
    abstract_state = jax.eval_shape(init_kv_cache_partial)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/maxtext_utils.py", line 1070, in init_kv_cache
    model_vars = model.init(
                 ^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/models.py", line 64, in init
    return nn.Module.init(module, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/models.py", line 154, in __call__
    logits, hidden_state = self.decoder(
                           ^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/decoders.py", line 652, in __call__
    y = self._apply_embedding(
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/decoders.py", line 563, in _apply_embedding
    y += embed_as_linen(
         ^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/nnx_wrappers.py", line 437, in __call__
    out = method_fn(module, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wanglance_google_com/maxtext/src/MaxText/layers/embeddings.py", line 143, in __call__
    raise ValueError("Input type must be an integer or unsigned integer.")
ValueError: Input type must be an integer or unsigned integer.

@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch 4 times, most recently from 691cf2f to edfe0a9 Compare October 21, 2025 03:53
Comment on lines 263 to 285
self.kv_cache_layer = kvcache.KVCache(
    max_prefill_length=self.max_prefill_predict_length,
    max_target_length=self.max_target_length,
    batch=feature_dim[0],
    key_seq_len=feature_dim[1],
    value_seq_len=feature_dim[1],
    key_heads=self.num_heads,
    value_heads=self.num_heads,
    key_head_size=self.head_dim,
    value_head_size=self.head_dim,
    dtype=self.dtype,
    kv_quant=self.kv_quant,
    prefill_cache_axis_order=prefill_cache_axis_order,
    ar_cache_axis_order=ar_cache_axis_order,
    model_mode=model_mode,
    rngs=self.rngs,
)
Collaborator Author


Hi @bvandermoon,

During testing, we discovered that the decode function in Gpt3 failed with an AssertionError (assert prefill_kv_cache) on the main branch, as indicated by the logs shared by @ecnal-cienet earlier. To address this, I've patched the code by adding a KVCache to Gpt3MultiHeadAttention. I've also updated the PR description to reflect these changes.

However, since we currently lack a reference model, we're unable to verify the results with certainty.

It would be great if you could review these changes and provide any feedback. Thank you!
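
For readers following the thread, a rough, framework-agnostic sketch of the bookkeeping a decode-time KV cache performs — kvcache.KVCache wraps this (plus sharding, axis ordering, and quantization) in MaxText, and the names below (cache_k, cache_v, step) are illustrative, not its real API:

```python
import jax
import jax.numpy as jnp

def decode_step_attention(query, new_key, new_value, cache_k, cache_v, step):
  """One autoregressive step. query/new_key/new_value: [batch, 1, heads, head_dim];
  cache_k/cache_v: preallocated [batch, max_len, heads, head_dim] buffers."""
  # Write this step's key/value into the preallocated cache buffers ...
  cache_k = cache_k.at[:, step].set(new_key[:, 0])
  cache_v = cache_v.at[:, step].set(new_value[:, 0])
  # ... mask out slots that have not been filled yet ...
  valid = jnp.arange(cache_k.shape[1]) <= step
  scores = jnp.einsum("bqhd,bkhd->bhqk", query, cache_k) / jnp.sqrt(query.shape[-1])
  scores = jnp.where(valid[None, None, None, :], scores, -1e9)
  # ... and attend over everything cached so far.
  weights = jax.nn.softmax(scores, axis=-1)
  out = jnp.einsum("bhqk,bkhd->bqhd", weights, cache_v)
  return out, cache_k, cache_v
```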

Collaborator


Thanks @hsuan-lun-chiang. Does the error show up on main as well? I am wondering if decode has not been supported. If so, we don't need to add the KVCache here

Collaborator Author


Yes, the error also shows up on the main branch.

Collaborator


Thanks @hsuan-lun-chiang. Can we remove the KVCache portion from this PR? That way we can just focus on ensuring the before/after match for the migration. It would be good to add it back as a follow-up though


@bvandermoon bvandermoon left a comment


@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?


hsuan-lun-chiang commented Nov 4, 2025

@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?

Hi @bvandermoon ,
Here are profiles collected with xplane:
Before
After

command:
python3 -m src.MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/33/ model_name=gpt3-6b dataset_type=synthetic steps=10

@bvandermoon

@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?

Hi @bvandermoon , Here are profiles collected with xplane: Before After

command: python3 -m src.MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/33/ model_name=gpt3-6b dataset_type=synthetic steps=10

Thanks @hsuan-lun-chiang. The profiles look good. I will take one more pass on the PR tomorrow

@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch from edfe0a9 to efaaa20 Compare November 5, 2025 10:21
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch 2 times, most recently from 3db6d5d to 51449d6 Compare November 5, 2025 10:45
@hsuan-lun-chiang

@hsuan-lun-chiang do you have train profiles for this run? I see some issues with the Jetstream profiles, maybe they were started before the actual requests started happening?

Hi @bvandermoon , Here are profiles collected with xplane: Before After
command: python3 -m src.MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/33/ model_name=gpt3-6b dataset_type=synthetic steps=10

Thanks @hsuan-lun-chiang. The profiles look good. I will take one more pass on the PR tomorrow

Thank you! I also rebased the code to the latest version.


@bvandermoon bvandermoon left a comment


Thanks @hsuan-lun-chiang. Looking good, just have a few comments


@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch 3 times, most recently from d12fdf9 to 350b7ec Compare November 7, 2025 08:47
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Migrate-Gpt3-to-NNX branch from 350b7ec to 966ef0a Compare November 7, 2025 10:23