diff --git a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml index 056c15294a..5ef1591ecd 100644 --- a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml +++ b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml @@ -44,6 +44,7 @@ policy: empty_unused_memory_level: 2 enabled: true activation_checkpointing: true + moe_grouped_gemm: true tensor_model_parallel_size: 8 expert_model_parallel_size: 32 pipeline_model_parallel_size: 8 diff --git a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml index c8a9487175..8405629c66 100644 --- a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml @@ -24,6 +24,7 @@ policy: pipeline_model_parallel_size: 16 expert_model_parallel_size: 16 activation_checkpointing: true + moe_grouped_gemm: true num_layers_in_first_pipeline_stage: 3 num_layers_in_last_pipeline_stage: 2 apply_rope_fusion: false diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml index aecdabba73..1228db60bb 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml @@ -29,6 +29,7 @@ policy: context_parallel_size: 2 expert_model_parallel_size: 16 activation_checkpointing: true + moe_grouped_gemm: true num_layers_in_first_pipeline_stage: 11 num_layers_in_last_pipeline_stage: 11 defer_fp32_logits: true diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml index 21b9746f4b..41fd83ec21 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml +++ 
b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml @@ -21,6 +21,7 @@ policy: pipeline_model_parallel_size: 1 expert_model_parallel_size: 16 sequence_parallel: false + moe_grouped_gemm: true optimizer: lr: 3.0e-07 min_lr: 3.0e-08 diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml index 3b4f22ffbd..a8e130f853 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml @@ -24,6 +24,7 @@ policy: expert_model_parallel_size: 8 sequence_parallel: true context_parallel_size: 8 + moe_grouped_gemm: true optimizer: lr: 3.0e-07 min_lr: 3.0e-08 diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml index 795764d3ee..7a03394402 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml @@ -21,6 +21,7 @@ policy: pipeline_model_parallel_size: 1 expert_model_parallel_size: 8 sequence_parallel: false + moe_grouped_gemm: true optimizer: lr: 3.0e-07 min_lr: 3.0e-08 diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index fc5c6c44fa..4c99bacb0b 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -432,6 +432,9 @@ def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None: model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"] + if "moe_grouped_gemm" in config["megatron_cfg"]: + model_cfg.moe_grouped_gemm = config["megatron_cfg"]["moe_grouped_gemm"] + def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None: if "mtp_num_layers" in config["megatron_cfg"]: diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 
ec4c9e66bb..63af536819 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -236,6 +236,10 @@ class MegatronConfig(TypedDict): moe_token_dispatcher_type: str # Can be used only with 'alltoall' token dispatcher moe_shared_expert_overlap: bool + # Enable grouped GEMM for MoE experts via CUTLASS. Gives a significant throughput + # gain when multiple experts are assigned per rank (num_local_experts > 1). + # Requires TE >= 1.11.0 for FP8 support and an Ampere (sm_80) or newer GPU. + moe_grouped_gemm: NotRequired[bool] peft: NotRequired[MegatronPeftConfig | MegatronPeftConfigDisabled] optimizer: MegatronOptimizerConfig scheduler: MegatronSchedulerConfig