Conversation

@kocchop (Collaborator) commented on Nov 6, 2025

Description

Soften the expert parallelism (EP) check for GPUs. The check currently enforces TPU-specific constraints, but on GPUs there is no functional limitation on using EP across both ICI and DCN. This change therefore gates the check on the hardware type, as sketched below.
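
For illustration only, here is a minimal sketch of a hardware-gated check. The names (Config, validate_expert_parallelism, its fields) mirror the config flags in the test below but are hypothetical, not the actual MaxText code:

# Hypothetical sketch of the softened check; names are illustrative
# and do not correspond to the real MaxText implementation.
from dataclasses import dataclass

@dataclass
class Config:
  hardware: str                  # e.g. "tpu", "gpu", "gpu_multiprocess"
  ici_expert_parallelism: int
  dcn_expert_parallelism: int

def validate_expert_parallelism(config: Config) -> None:
  if config.hardware.startswith("gpu"):
    # GPUs have no functional limitation on EP across ICI and DCN,
    # so skip the TPU-specific restriction entirely.
    return
  # TPU path: keep the original constraint.
  if config.ici_expert_parallelism > 1 and config.dcn_expert_parallelism > 1:
    raise ValueError(
        "Expert parallelism across both ICI and DCN is restricted on TPU.")

# Example: the GPU config from the test script below passes the check.
validate_expert_parallelism(
    Config(hardware="gpu_multiprocess",
           ici_expert_parallelism=8,
           dcn_expert_parallelism=2))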

Tests

Sample end-to-end test script for GPUs:

export XLA_FLAGS= "--xla_gpu_enable_latency_hiding_scheduler=true
                --xla_gpu_enable_latency_hiding_scheduler=true
                --xla_gpu_all_reduce_combine_threshold_bytes=134217728
                --xla_gpu_all_gather_combine_threshold_bytes=352321536
                --xla_gpu_reduce_scatter_combine_threshold_bytes=44040192
                --xla_gpu_enable_pipelined_all_gather=true
                --xla_gpu_enable_pipelined_reduce_scatter=true
                --xla_gpu_enable_pipelined_all_reduce=true
                --xla_gpu_enable_while_loop_double_buffering=true
                --xla_gpu_enable_all_gather_combine_by_dim=false
                --xla_gpu_enable_reduce_scatter_combine_by_dim=false
                --xla_disable_hlo_passes=rematerialization"
                
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.90
export NVTE_FUSED_ATTN=1

python3 -m MaxText.train /opt/maxtext/src/MaxText/configs/base.yml \
    run_name=logdir \
    use_iota_embed=false \
    scan_layers=True \
    steps=31 \
    per_device_batch_size=1 \
    model_name=llama4-17b-16e \
    remat_policy=full \
    enable_checkpointing=false \
    logits_dot_in_fp32=false \
    base_output_directory=/opt/maxtext/local_train \
    dataset_type=synthetic \
    attention=cudnn_flash_te \
    max_target_length=8192 \
    sparse_matmul=True \
    megablox=false \
    enable_goodput_recording=false \
    monitor_goodput=false \
    hardware=gpu_multiprocess \
    dcn_fsdp_parallelism=2 \
    ici_fsdp_parallelism=1 \
    ici_expert_parallelism=8 \
    dcn_expert_parallelism=2 \
    ici_data_parallelism=1 \
    dcn_data_parallelism=1

Please note that the above example is for a JAX distributed setting using 4 nodes (32 GPUs); a sanity check of the mesh arithmetic is sketched below.
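
As an illustrative check of that node/GPU count (the dict layout is for exposition only, not a MaxText API): the product of the ICI factors gives GPUs per node, and the product of the DCN factors gives the node count.

# Sanity-check the mesh arithmetic from the command above.
import math

ici = {"fsdp": 1, "expert": 8, "data": 1}   # intra-node (ICI) parallelism
dcn = {"fsdp": 2, "expert": 2, "data": 1}   # cross-node (DCN) parallelism

gpus_per_node = math.prod(ici.values())  # 1 * 8 * 1 = 8
num_nodes = math.prod(dcn.values())      # 2 * 2 * 1 = 4
print(f"{num_nodes} nodes x {gpus_per_node} GPUs/node = "
      f"{num_nodes * gpus_per_node} GPUs")  # 4 x 8 = 32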

Checklist

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RissyRan (Collaborator) left a comment


LGTM! Please attach your GPU test to the PR description.

@kocchop force-pushed the faysal/soften-ep-check branch from c0a1744 to 60c5b6e on November 6, 2025 03:50
@kocchop requested a review from @RissyRan on November 6, 2025 19:06
@copybara-service bot merged commit 725d15c into AI-Hypercomputer:main on Nov 7, 2025
63 of 66 checks passed