[tx] Implement expert parallelism #842
Conversation
Code Review
This pull request implements expert parallelism for Mixture-of-Experts (MoE) layers, which is a significant enhancement for model scaling. The changes are well-structured, introducing a clean expert_parallel_dispatch_combine utility that encapsulates the parallelism logic using jax.shard_map. The modifications to LoRAExpert and Qwen3Experts to support this are logical and correctly integrated. The necessary configuration and device mesh updates are also included. Overall, this is a solid implementation. I've provided a few suggestions to improve code readability and maintainability.
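For readers less familiar with the pattern the review describes, here is a minimal, self-contained sketch of the general idea: expert weights are sharded along an "ep" mesh axis while tokens stay replicated, so each device only runs its local slice of the experts. The function names and shapes below are hypothetical and are not the PR's actual expert_parallel_dispatch_combine.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Build a 1D mesh over all devices with an "ep" (expert-parallel) axis.
mesh = Mesh(jax.devices(), axis_names=("ep",))

def run_local_experts(tokens, expert_weights):
    # tokens: replicated (num_tokens, hidden)
    # expert_weights: local shard of shape (num_experts // ep_size, hidden, hidden)
    return jnp.einsum("th,ehf->etf", tokens, expert_weights)

expert_forward = jax.shard_map(
    run_local_experts,
    mesh=mesh,
    in_specs=(P(), P("ep")),  # tokens replicated, expert dim sharded on "ep"
    out_specs=P("ep"),        # per-expert outputs stay sharded along "ep"
)
```

In recent JAX releases this entry point is available as jax.shard_map; in older versions the same functionality lives under jax.experimental.shard_map.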
lora_A_reshaped = self.lora_A.value.reshape(num_flattened_groups, self.in_features, self.max_lora_rank)
lora_B_reshaped = self.lora_B.value.reshape(num_flattened_groups, self.max_lora_rank, self.out_features)
# Reshape LoRA weights in expert-first order (already local shards)
lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(
I also tried dropping the transpose and just putting experts first in the weight tensors when initializing them -- the code is more complicated and also slower, e.g. 54s step time vs 40s with sl_loop.py and

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-30B-A3B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

This is somewhat surprising and there might be more optimization potential here in the future, but for now it is best to keep things as simple as possible.
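As a toy illustration of the two layouts being compared (the shapes here are made up and much smaller than the real LoRA parameters):

```python
import jax.numpy as jnp

# Hypothetical small shapes: (num_adapters, num_experts, in_features, rank).
adapters, experts, in_features, rank = 2, 8, 16, 4

# Layout kept in the PR: experts are the second axis, so the forward pass
# transposes to expert-first before flattening the (expert, adapter) groups.
lora_A = jnp.zeros((adapters, experts, in_features, rank))
expert_first = lora_A.transpose((1, 0, 2, 3)).reshape(experts * adapters, in_features, rank)

# Alternative that was benchmarked: store the parameter in expert-first
# order up front, so the forward pass only needs a reshape.
lora_A_alt = jnp.zeros((experts, adapters, in_features, rank))
expert_first_alt = lora_A_alt.reshape(experts * adapters, in_features, rank)

assert expert_first.shape == expert_first_alt.shape == (experts * adapters, in_features, rank)
```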
/gemini review
Code Review
This pull request introduces expert parallelism for Mixture-of-Experts (MoE) models, a significant enhancement for training large-scale models. The implementation correctly uses jax.shard_map to distribute expert computations across devices, with a new shard_map_ep utility function and corresponding updates to model partitioning specs. The changes also account for LoRA weights in the expert-parallel setup. My review focuses on improving robustness and usability. I've suggested adding an assertion to prevent silent errors from incorrect expert distribution, refactoring duplicated test setup code into a pytest fixture for better maintainability, and enhancing the standalone training script to support the new parallelism feature.
def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
    # Calculate local offset for this shard
    ep_rank = jax.lax.axis_index("ep")
    experts_per_rank = num_experts // jax.lax.axis_size("ep")
This integer division assumes that num_experts is evenly divisible by the number of devices in the 'ep' mesh axis. If it's not, this could lead to an incorrect number of experts being assigned per rank, causing silent errors or incorrect model behavior. It would be much safer to add an assertion to validate this assumption, ideally during model initialization.
For example, you could add this check in Qwen3Experts.__init__:
ep_size = get_abstract_mesh().shape.get("ep", 1)
assert self.config.num_experts % ep_size == 0, f"Number of experts ({self.config.num_experts}) must be divisible by expert parallel size ({ep_size})."
skyrl-tx/tests/models/test_qwen3.py
Outdated
if not jax._src.xla_bridge.backends_are_initialized():
    jax.config.update("jax_num_cpu_devices", ep * tp)
This device initialization logic is duplicated in test_qwen3_moe_layer_lora (lines 117-118). To improve maintainability and reduce redundancy, consider extracting this setup into a shared pytest fixture. The fixture could automatically run for tests that require ep and tp parameters, making the test suite cleaner and easier to manage.
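One possible shape for such a fixture, sketched here with a made-up name and parametrization rather than the actual test code:

```python
import jax
import pytest

@pytest.fixture
def cpu_devices(request):
    # request.param is the total device count a test needs, e.g. ep * tp.
    if not jax._src.xla_bridge.backends_are_initialized():
        jax.config.update("jax_num_cpu_devices", request.param)
    return jax.devices()

# Tests opt in via indirect parametrization, e.g. 2 EP ranks x 4 TP ranks.
@pytest.mark.parametrize("cpu_devices", [8], indirect=True)
def test_moe_layer_runs(cpu_devices):
    assert len(cpu_devices) >= 1
```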
This PR implements expert parallelism with jax.shard_map. The tokens are already replicated on each rank from the previous layer, so we just filter the tokens that should be processed by each local expert using the group_offset feature implemented in #860.

Here are some performance results; they were generated on 8xH100 using

(or replacing expert_parallel_size with tensor_parallel_size for the TP case). First with the new PR and EP:

Then with the new PR and TP (much slower):

And last with TP on current main, to make sure there is no performance regression:
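For intuition on the token-filtering step mentioned in the description above, here is a simplified sketch of how each rank can determine which tokens belong to its local experts inside the shard_map region. This is only an illustration and does not reflect the actual group_offset implementation from #860.

```python
import jax
import jax.numpy as jnp

def local_expert_mask(expert_ids: jnp.ndarray, num_experts: int) -> jnp.ndarray:
    # Must be called inside shard_map/pmap with a named "ep" axis.
    ep_size = jax.lax.axis_size("ep")
    ep_rank = jax.lax.axis_index("ep")
    experts_per_rank = num_experts // ep_size  # assumes even divisibility
    group_offset = ep_rank * experts_per_rank
    # True for tokens routed to an expert owned by this rank.
    return (expert_ids >= group_offset) & (expert_ids < group_offset + experts_per_rank)
```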