
Conversation

@hsuan-lun-chiang (Collaborator)

Description

Migrate the GPT-OSS implementation from Flax Linen to NNX.
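For context, the core of such a migration looks roughly like the following toy sketch (illustrative only, not the actual MaxText GPT-OSS modules): Linen declares parameters lazily inside @nn.compact, while NNX creates them eagerly in __init__ with explicit input shapes and an explicit nnx.Rngs stream.

import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

# Linen: parameter shapes are inferred lazily at init/apply time.
class BlockLinen(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

# NNX: parameters are created eagerly in __init__, so input shapes are explicit.
class BlockNNX(nnx.Module):
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    self.dense = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return self.dense(x)

model = BlockNNX(16, 32, rngs=nnx.Rngs(0))
y = model(jnp.ones((4, 16)))  # state lives on the object; no separate params dict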

Tests

Ran the train command to train gpt-oss-20b for 30 steps:

python3 -m MaxText.train src/MaxText/configs/base.yml \
base_output_directory=gs://maxtext-test/gpt-oss-train/ \
run_name=megablox_pre_training \
model_name=gpt-oss-20b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
dataset_type=synthetic \
enable_checkpointing=true \
attention=flash \
sparse_matmul=True \
megablox=True \
dtype=bfloat16 \
weight_dtype=bfloat16 \
per_device_batch_size=4 \
steps=30 \
max_target_length=1024 \
ici_fsdp_parallelism=8
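For background on the sparse_matmul=True / megablox=True flags: the MoE layers route tokens through a grouped matmul over per-expert weight blocks instead of dense per-expert einsums. A conceptual sketch in plain jnp (the real Megablox path is a fused TPU kernel; this loop and all names here are illustrative):

import jax.numpy as jnp

def grouped_matmul(tokens, expert_weights, group_sizes):
  """Multiply contiguous groups of tokens by their expert's weights.

  tokens: [num_tokens, d_model], pre-sorted so each expert's tokens are contiguous
  expert_weights: [num_experts, d_model, d_ff]
  group_sizes: Python list of token counts per expert
  """
  outputs, start = [], 0
  for e, size in enumerate(group_sizes):
    outputs.append(tokens[start:start + size] @ expert_weights[e])
    start += size
  return jnp.concatenate(outputs, axis=0)

out = grouped_matmul(jnp.ones((6, 8)), jnp.ones((2, 8, 16)), [4, 2])  # -> [6, 16]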

Logs:

Logs - Before Migration
Logs - After Migration

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@RissyRan (Collaborator) commented Oct 2, 2025

I see you added the gemini-review label, and it didn't work. We have a rule that checks whether the branch is a forked version here

@hsuan-lun-chiang (Collaborator, Author)

> I see you added the gemini-review label, and it didn't work. We have a rule that checks whether the branch is a forked version here

I wasn’t aware of that rule, thank you for pointing it out!

@bvandermoon (Collaborator) left a comment

Thanks @hsuan-lun-chiang. Could you please run train (you already have this), decode, and then maxengine/jetstream (with profiles collected for the maxengine/jetstream run)? Similar to #2088

@hsuan-lun-chiang force-pushed the feat/Migrate-Gpt-OSS-to-NNX branch from cddd713 to d200e51 on October 15, 2025 07:21
@ecnal-cienet commented Oct 15, 2025

Additional Tests

Environment

Machine Type: TPU v6e-8
How commands were executed: SSH into the TPU VM, activate the Python venv, and run the commands directly.


Train

Executed Command:

python3 -m MaxText.train src/MaxText/configs/base.yml \
base_output_directory=gs://lance-maxtext/gpt-oss-train \
run_name=gpt_oss_train_after \
model_name=gpt-oss-20b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
dataset_type=synthetic \
attention=flash \
weight_dtype=bfloat16 \
per_device_batch_size=4 \
steps=30 \
max_target_length=1024 \
ici_fsdp_parallelism=8

Logs:
Before Migration
After Migration
WebDiff


Decode

Executed Command:

python3 -m MaxText.decode MaxText/configs/base.yml \
    model_name=gpt-oss-20b \
    tokenizer_path=openai/gpt-oss-20b \
    tokenizer_type=huggingface \
    scan_layers=false \
    per_device_batch_size=1 \
    max_prefill_predict_length=128 \
    max_target_length=1024 \
    ici_fsdp_parallelism=1 \
    ici_tensor_parallelism=8 \
    prompt="I love to" \
    attention=dot_product \
    mla_naive_kvcache=False \
    load_parameters_path=gs://lance-maxtext/gpt-oss-20b-ckpt-unscanned/0/items

Logs:
Before Migration
After Migration
WebDiff


MaxEngine / JetStream

Step 1: Launch MaxEngine:

# On terminal 1
export LIBTPU_INIT_ARGS="--xla_jf_auto_cross_replica_sharding=false --xla_tpu_enable_windowed_einsum_for_reduce_scatter=false --xla_tpu_enable_windowed_einsum_for_all_gather=false --xla_tpu_prefer_latch_optimized_rhs_layouts=true --xla_tpu_enable_experimental_fusion_cost_model=false --xla_tpu_dot_dot_fusion_duplicated=false --xla_tpu_dot_dot_fusion=true --xla_jf_conv_input_fusion=true --xla_jf_conv_output_fusion=true --xla_tpu_rwb_fusion=false --xla_tpu_copy_fusion_pad_unpad_ratio=0 --xla_tpu_licm_size_inflation_ratio=1 --xla_tpu_copy_elision_analysis_allowance=150000 --xla_tpu_copy_insertion_use_region_analysis_limit=10000 --xla_tpu_order_dot_after_layout=true --xla_jf_rematerialization_percent_shared_memory_limit=100 --xla_tpu_use_repeated_instance_for_preferred_prefetch_time=true --xla_tpu_enforce_prefetch_fifo_order=false --xla_tpu_prefetch_interval_picker_size_override=6000000 --xla_tpu_async_copy_bandwidth_scaling_factor=1 --xla_tpu_nd_short_transfer_max_chunks=-1 --xla_tpu_enable_aggressive_broadcast_priority_update=true --xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers=SQRT --xla_tpu_memory_bound_loop_optimizer_options=enabled:true --xla_tpu_enable_copy_fusion=true --xla_tpu_enable_cross_program_prefetch_freeing=false --xla_tpu_enable_dot_strength_reduction=true --xla_tpu_layout_use_dot_grouping=false --xla_tpu_msa_inefficient_use_to_copy_ratio=0.5 --xla_tpu_reduce_loop_fusion_dup_with_unfusable_user=false --xla_tpu_vector_load_fusion_window=1024 --xla_tpu_vector_store_fusion_window=256 --xla_jf_conv_reshape_fusion=false --xla_tpu_input_conv_multi_users=false --xla_tpu_enable_multi_level_input_dot_dot_fusion=false --xla_tpu_enable_multi_level_output_dot_dot_fusion=false --xla_tpu_dot_dot_fusion_separable_convs_only=false --xla_tpu_enable_multi_level_nested_loop_fusion=true --xla_tpu_nested_dot_fusion=true --xla_tpu_enable_multi_level_nested_dot_fusion=false --xla_jf_enable_multi_output_fusion=true --xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions=false --xla_tpu_enable_flash_attention=true"

python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
tokenizer_type=huggingface \
tokenizer_path=openai/gpt-oss-20b \
scan_layers=false \
per_device_batch_size=1 \
ici_fsdp_parallelism=1 \
ici_tensor_parallelism=8 \
max_prefill_predict_length=4096 \
max_target_length=8192 \
attention=dot_product \
mla_naive_kvcache=False \
load_parameters_path=gs://lance-maxtext/gpt-oss-20b-ckpt-unscanned/0/items \
enable_jax_profiler=True \
hf_access_token=$HF_TOKEN

Step 2: Launch JetStream

# On terminal 2
JAX_PLATFORMS=tpu python benchmarks/benchmark_serving.py \
--tokenizer=openai/gpt-oss-20b \
--num-prompts 5000 \
--dataset mmlu \
--dataset-path mmlu/data/test/ \
--request-rate 0 \
--warmup-mode sampled \
--save-request-outputs \
--run-eval True \
--use-hf-tokenizer True

Step 3: Output Collection

# On terminal 2
# Define data saving location
RUN=run-$(date +%Y-%m-%d-%H-%M-%S)
echo $RUN
log_dir=$HOME/test_memory/mistral/$RUN
echo $log_dir

# Collect data
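# Port 9999 matches the JAX profiler server started by enable_jax_profiler=True
# above (presumably MaxText's default profiler port); 6000 is the trace duration in ms.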
python -m jax.collect_profile 9999 6000 --log_dir=$log_dir --no_perfetto_link

Logs:
MaxEngine-Before
MaxEngine-After
WebDiff

JetStream-Before
JetStream-After
WebDiff

Xprof-Before
Xprof-After

@bvandermoon (Collaborator)

> [Quotes @ecnal-cienet's test report above in full.]

Thanks @ecnal-cienet for running these tests. The MaxEngine/JetStream profiles look a bit off. Could you try starting the profiles after the JetStream warmup requests have finished and the real requests have started? That is the portion we want to profile. Let me know if you have already done this and we can discuss further.
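For example, a hypothetical sketch (the 120-second wait is a placeholder for however long the sampled warmup actually takes):

# On terminal 2, once benchmark_serving.py has moved past the warmup requests
sleep 120   # placeholder: wait out the sampled warmup phase
python -m jax.collect_profile 9999 6000 --log_dir=$log_dir --no_perfetto_link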

@bvandermoon (Collaborator) left a comment

Generally LGTM. Left a comment on the profile collection

@ecnal-cienet

Hi @bvandermoon, I re-tested JetStream profiling and have updated the xprof links in my previous comment. If there are still issues, please let me know. Thanks.

@bvandermoon (Collaborator)

> Hi @bvandermoon, I re-tested JetStream profiling and have updated the xprof links in my previous comment. If there are still issues, please let me know. Thanks.

Awesome, thank you @ecnal-cienet. The profiles look good to me now

@bvandermoon (Collaborator) left a comment

LGTM. Thank you @hsuan-lun-chiang and @ecnal-cienet. @cgarciae could you please take a look as well?

@copybara-service bot merged commit 3c0d285 into AI-Hypercomputer:main on Nov 6, 2025
30 checks passed
