
Conversation


@pcmoritz pcmoritz commented Jan 5, 2026

This PR implements expert parallelism with jax.shard_map. The tokens are already replicated on each rank from the previous layer, so we just filter the tokens that should be processed by each local expert using the group_offset feature implemented in #860.
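
For illustration, here is a minimal self-contained sketch of this dispatch/combine pattern (toy shapes, top-1 routing, and illustrative names such as ep_forward -- not the actual tx implementation, which works on tokens pre-sorted by expert id via the group_offset path):

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

NUM_EXPERTS, DIM, TOKENS = 8, 16, 32  # NUM_EXPERTS must divide evenly over "ep"
mesh = Mesh(np.array(jax.devices()), axis_names=("ep",))

@partial(jax.shard_map, mesh=mesh,
         in_specs=(P("ep"), P(), P(), P()), out_specs=P())
def ep_forward(w_local, tokens, expert_ids, gates):
    # w_local is this rank's shard of the expert weights, with shape
    # [NUM_EXPERTS // ep, DIM, DIM]; the other inputs are replicated.
    ep_rank = jax.lax.axis_index("ep")
    experts_per_rank = NUM_EXPERTS // jax.lax.axis_size("ep")
    # Filter: keep only the tokens routed to one of this rank's local experts.
    local_ids = expert_ids - ep_rank * experts_per_rank
    mask = (local_ids >= 0) & (local_ids < experts_per_rank)
    safe_ids = jnp.clip(local_ids, 0, experts_per_rank - 1)
    out = jnp.einsum("td,tdo->to", tokens, w_local[safe_ids])
    out = jnp.where(mask[:, None], out * gates[:, None], 0.0)
    # Combine: every token is handled on exactly one rank, so summing the
    # partial outputs across "ep" reassembles the full MoE output everywhere.
    return jax.lax.psum(out, "ep")

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (NUM_EXPERTS, DIM, DIM))
tokens = jax.random.normal(key, (TOKENS, DIM))
expert_ids = jax.random.randint(key, (TOKENS,), 0, NUM_EXPERTS)
out = ep_forward(w, tokens, expert_ids, jnp.ones((TOKENS,)))  # [TOKENS, DIM]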

Here are some performance results; they were generated on 8xH100 using

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-30B-A3B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

(or replacing expert_parallel_size with tensor_parallel_size for the TP case).

First with the new PR and EP:

uv run --with wandb --with tinker==0.3.0 sl_loop.py     base_url=http://localhost:8000     model_name=Qwen/Qwen3-30B-A3B lora_rank=1 max_length=512
WARNING: WANDB_API_KEY environment variable not set. Skipping W&B logging. 
tinker_cookbook.utils.ml_log:475 [INFO] Logging to: /tmp/tinker-examples/sl-loop
__main__:49 [INFO] Using renderer: qwen3
__main__:52 [INFO] Loading dataset...
__main__:58 [INFO] Train batches: 74
tinker.lib.public_interfaces.service_client:61 [INFO] ServiceClient initialized for session session_4906965b
tinker_cookbook.checkpoint_utils:19 [INFO] No checkpoints found at /tmp/tinker-examples/sl-loop/checkpoints.jsonl
tinker_cookbook.checkpoint_utils:48 [INFO] No checkpoints found with key state_path in /tmp/tinker-examples/sl-loop
tinker.lib.public_interfaces.service_client:126 [INFO] Creating TrainingClient for model_id='model_ecb81c50'
__main__:78 [INFO] Training for 74 steps
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 0                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000100   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 34993      │
│ progress                       │ 0.000000   │
│ time_total                     │ 271.401147 │
│ train_mean_nll                 │ 2.839665   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 1                    
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric                         ┃ Value     ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ learning_rate                  │ 0.000099  │
│ num_sequences                  │ 128       │
│ num_tokens                     │ 32341     │
│ progress                       │ 0.013514  │
│ time_total                     │ 47.546760 │
│ train_mean_nll                 │ 2.636626  │
└────────────────────────────────┴───────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 2                    
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric                         ┃ Value     ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ learning_rate                  │ 0.000097  │
│ num_sequences                  │ 128       │
│ num_tokens                     │ 32905     │
│ progress                       │ 0.027027  │
│ time_total                     │ 40.470125 │
│ train_mean_nll                 │ 2.335052  │
└────────────────────────────────┴───────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 3                    
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric                         ┃ Value     ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ learning_rate                  │ 0.000096  │
│ num_sequences                  │ 128       │
│ num_tokens                     │ 35807     │
│ progress                       │ 0.040541  │
│ time_total                     │ 41.397586 │
│ train_mean_nll                 │ 2.240090  │
└────────────────────────────────┴───────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl

Then with the new PR and TP (much slower):

uv run --with wandb --with tinker==0.3.0 sl_loop.py     base_url=http://localhost:8000     model_name=Qwen/Qwen3-30B-A3B lora_rank=1 max_length=512
WARNING: WANDB_API_KEY environment variable not set. Skipping W&B logging. 
tinker_cookbook.utils.ml_log:475 [INFO] Logging to: /tmp/tinker-examples/sl-loop
__main__:49 [INFO] Using renderer: qwen3
__main__:52 [INFO] Loading dataset...
__main__:58 [INFO] Train batches: 74
tinker.lib.public_interfaces.service_client:61 [INFO] ServiceClient initialized for session session_0c5ed6e0
tinker_cookbook.checkpoint_utils:19 [INFO] No checkpoints found at /tmp/tinker-examples/sl-loop/checkpoints.jsonl
tinker_cookbook.checkpoint_utils:48 [INFO] No checkpoints found with key state_path in /tmp/tinker-examples/sl-loop
tinker.lib.public_interfaces.service_client:126 [INFO] Creating TrainingClient for model_id='model_c1342141'
__main__:78 [INFO] Training for 74 steps
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 0                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000100   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 34993      │
│ progress                       │ 0.000000   │
│ time_total                     │ 352.432744 │
│ train_mean_nll                 │ 2.835570   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 1                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000099   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 32341      │
│ progress                       │ 0.013514   │
│ time_total                     │ 121.366442 │
│ train_mean_nll                 │ 2.630018   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 2                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000097   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 32905      │
│ progress                       │ 0.027027   │
│ time_total                     │ 113.349039 │
│ train_mean_nll                 │ 2.326924   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 3                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000096   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 35807      │
│ progress                       │ 0.040541   │
│ time_total                     │ 113.365831 │
│ train_mean_nll                 │ 2.238438   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 4                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000095   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 35795      │
│ progress                       │ 0.054054   │
│ time_total                     │ 113.418735 │
│ train_mean_nll                 │ 2.176662   │
└────────────────────────────────┴────────────┘

And finally with TP on current main, to confirm there is no performance regression:

uv run --with wandb --with tinker==0.3.0 sl_loop.py     base_url=http://localhost:8000     model_name=Qwen/Qwen3-30B-A3B lora_rank=1 max_length=512
WARNING: WANDB_API_KEY environment variable not set. Skipping W&B logging. 
tinker_cookbook.utils.ml_log:475 [INFO] Logging to: /tmp/tinker-examples/sl-loop
__main__:49 [INFO] Using renderer: qwen3
__main__:52 [INFO] Loading dataset...
__main__:58 [INFO] Train batches: 74
tinker.lib.public_interfaces.service_client:61 [INFO] ServiceClient initialized for session session_5b10a9a0
tinker_cookbook.checkpoint_utils:19 [INFO] No checkpoints found at /tmp/tinker-examples/sl-loop/checkpoints.jsonl
tinker_cookbook.checkpoint_utils:48 [INFO] No checkpoints found with key state_path in /tmp/tinker-examples/sl-loop
tinker.lib.public_interfaces.service_client:126 [INFO] Creating TrainingClient for model_id='model_bc6d7add'
__main__:78 [INFO] Training for 74 steps
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 0                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000100   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 34993      │
│ progress                       │ 0.000000   │
│ time_total                     │ 419.397852 │
│ train_mean_nll                 │ 2.835570   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 1                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000099   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 32341      │
│ progress                       │ 0.013514   │
│ time_total                     │ 235.424837 │
│ train_mean_nll                 │ 2.642184   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 2                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000097   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 32905      │
│ progress                       │ 0.027027   │
│ time_total                     │ 227.473726 │
│ train_mean_nll                 │ 2.337117   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 3                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000096   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 35807      │
│ progress                       │ 0.040541   │
│ time_total                     │ 227.534498 │
│ train_mean_nll                 │ 2.240172   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:143 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:195 [INFO] 
                    Step 4                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000095   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 35795      │
│ progress                       │ 0.054054   │
│ time_total                     │ 228.449649 │
│ train_mean_nll                 │ 2.175244   │
└────────────────────────────────┴────────────┘

@pcmoritz pcmoritz added the tx label Jan 5, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements expert parallelism for Mixture-of-Experts (MoE) layers, which is a significant enhancement for model scaling. The changes are well-structured, introducing a clean expert_parallel_dispatch_combine utility that encapsulates the parallelism logic using jax.shard_map. The modifications to LoRAExpert and Qwen3Experts to support this are logical and correctly integrated. The necessary configuration and device mesh updates are also included. Overall, this is a solid implementation. I've provided a few suggestions to improve code readability and maintainability.

lora_A_reshaped = self.lora_A.value.reshape(num_flattened_groups, self.in_features, self.max_lora_rank)
lora_B_reshaped = self.lora_B.value.reshape(num_flattened_groups, self.max_lora_rank, self.out_features)
# Reshape LoRA weights in expert-first order (already local shards)
lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(

Collaborator Author

I also tried skipping the transpose and putting experts first in the weight tensors when they are initialized -- the code is more complicated and also slower, e.g. 54s step time vs 40s with sl_loop.py and

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-30B-A3B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

This is somewhat surprising, and there may be more optimization potential here in the future, but for now it is best to keep things as simple as possible.
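
To make the comparison concrete, here is a toy sketch of the two layouts (the (adapters, experts, in_features, rank) axis order is an assumption based on the snippet above):

import jax.numpy as jnp

adapters, experts, in_features, rank = 2, 8, 16, 1
lora_A = jnp.zeros((adapters, experts, in_features, rank))

# Kept in this PR: store adapter-first and transpose to expert-first right
# before flattening, so each ep shard sees its experts as a contiguous block.
flat = lora_A.transpose((1, 0, 2, 3)).reshape(experts * adapters, in_features, rank)

# Rejected alternative: allocate the weights expert-first at initialization
# and drop the transpose here -- more complicated and measured slower above.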

@pcmoritz
Collaborator Author

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces expert parallelism for Mixture-of-Experts (MoE) models, a significant enhancement for training large-scale models. The implementation correctly uses jax.shard_map to distribute expert computations across devices, with a new shard_map_ep utility function and corresponding updates to model partitioning specs. The changes also account for LoRA weights in the expert-parallel setup. My review focuses on improving robustness and usability. I've suggested adding an assertion to prevent silent errors from incorrect expert distribution, refactoring duplicated test setup code into a pytest fixture for better maintainability, and enhancing the standalone training script to support the new parallelism feature.

def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
    # Calculate local offset for this shard
    ep_rank = jax.lax.axis_index("ep")
    experts_per_rank = num_experts // jax.lax.axis_size("ep")

Contributor

high

This integer division assumes that num_experts is evenly divisible by the number of devices in the 'ep' mesh axis. If it's not, this could lead to an incorrect number of experts being assigned per rank, causing silent errors or incorrect model behavior. It would be much safer to add an assertion to validate this assumption, ideally during model initialization.

For example, you could add this check in Qwen3Experts.__init__:

ep_size = get_abstract_mesh().shape.get("ep", 1)
assert self.config.num_experts % ep_size == 0, f"Number of experts ({self.config.num_experts}) must be divisible by expert parallel size ({ep_size})."

Comment on lines 69 to 70
if not jax._src.xla_bridge.backends_are_initialized():
    jax.config.update("jax_num_cpu_devices", ep * tp)

Contributor

medium

This device initialization logic is duplicated in test_qwen3_moe_layer_lora (lines 117-118). To improve maintainability and reduce redundancy, consider extracting this setup into a shared pytest fixture. The fixture could automatically run for tests that require ep and tp parameters, making the test suite cleaner and easier to manage.
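
One possible shape for such a fixture (hypothetical names, shown with indirect parametrization; the real parameter sets may differ):

import jax
import pytest

@pytest.fixture
def mesh_devices(request):
    ep, tp = request.param
    # Must run before any JAX backend has been initialized.
    if not jax._src.xla_bridge.backends_are_initialized():
        jax.config.update("jax_num_cpu_devices", ep * tp)
    return ep, tp

@pytest.mark.parametrize("mesh_devices", [(8, 1), (4, 2)], indirect=True)
def test_qwen3_moe_layer(mesh_devices):
    ep, tp = mesh_devices
    ...  # build the (ep, tp) mesh and run the layer test as before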

@pcmoritz pcmoritz merged commit a6ba2fc into NovaSky-AI:main Jan 12, 2026
4 checks passed