-
Notifications
You must be signed in to change notification settings - Fork 432
[MoE][ddp] Enable distributed MoE calibration replacement #2449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,10 +15,10 @@ | |
| from abc import ABC | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from compressed_tensors.offload import is_distributed | ||
| from compressed_tensors.offload.dist_utils import is_distributed | ||
| from compressed_tensors.registry import RegistryMixin, standardize_lookup_name | ||
| from loguru import logger | ||
| from torch import distributed as dist | ||
| from tqdm import tqdm | ||
| from transformers import PreTrainedModel | ||
|
|
||
|
|
@@ -99,11 +99,33 @@ def moe_calibration_context( | |
| if _is_registered(class_name, MoECalibrationModule): | ||
| modules_to_replace.append((name, module, class_name)) | ||
|
|
||
| # Step 1.5: Verify all ranks have same number of modules (distributed mode) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I also don't think this check is necessary. |
||
| if is_distributed(): | ||
| rank = dist.get_rank() | ||
| world_size = dist.get_world_size() | ||
|
|
||
| # Verify all ranks have same number of modules | ||
| num_modules = torch.tensor([len(modules_to_replace)], dtype=torch.long, device=next(model.parameters()).device) | ||
| all_counts = [torch.zeros_like(num_modules) for _ in range(world_size)] | ||
| dist.all_gather(all_counts, num_modules) | ||
|
|
||
| if not all(count.item() == num_modules.item() for count in all_counts): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sometimes the number of modules will not be evenly divisible by the number of ranks, so this check can be harmful.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this only checks that each rank found the same total number of modules, not the number of modules assigned to each rank. |
||
| raise RuntimeError( | ||
| f"Rank {rank} found {num_modules.item()} MoE modules, but other " | ||
| f"ranks found different counts: {[c.item() for c in all_counts]}. " | ||
| "All ranks must have identical model structures." | ||
| ) | ||
|
|
||
| # Step 2: Replace modules with progress bar | ||
| if modules_to_replace: | ||
| logger.info(f"Found {len(modules_to_replace)} MoE modules to replace") | ||
| # Only rank 0 shows progress bar and logs | ||
| show_progress = not is_distributed() or dist.get_rank() == 0 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can use is_rank0 here; it is also probably simpler to inline this into the if statement. |
||
| if show_progress: | ||
| logger.info(f"Found {len(modules_to_replace)} MoE modules to replace") | ||
| for name, module, class_name in tqdm( | ||
| modules_to_replace, desc="Replacing MoE modules for calibration" | ||
| modules_to_replace, | ||
| desc="Replacing MoE modules for calibration", | ||
| disable=not show_progress, | ||
| ): | ||
| replacement = MoECalibrationModule.load_from_registry( | ||
| class_name, | ||
|
|
@@ -113,35 +135,45 @@ def moe_calibration_context( | |
| ) | ||
| model.set_submodule(name, replacement) | ||
| replaced[name] = (module, replacement) | ||
| if is_distributed(): | ||
| dist.barrier() | ||
|
|
||
| # Log what was replaced | ||
| # Synchronization barrier: all ranks complete replacement before calib | ||
| if is_distributed(): | ||
| dist.barrier() | ||
| logger.debug(f"Rank {dist.get_rank()}: Completed MoE module replacement") | ||
|
|
||
| # Log what was replaced (only rank 0 in distributed mode) | ||
| if replaced: | ||
| logger.info(f"Replaced {len(replaced)} MoE modules for calibration") | ||
| permanent_count = sum( | ||
| 1 for _, (_, repl) in replaced.items() if repl.is_permanent | ||
| ) | ||
| if permanent_count > 0: | ||
| logger.info( | ||
| f"{permanent_count}/{len(replaced)} modules will remain in " | ||
| "calibration form (permanent)" | ||
| ) | ||
| if permanent_count < len(replaced): | ||
| logger.info( | ||
| f"{len(replaced) - permanent_count}/{len(replaced)} modules will " | ||
| "be restored after calibration" | ||
| show_logs = not is_distributed() or dist.get_rank() == 0 | ||
| if show_logs: | ||
| logger.info(f"Replaced {len(replaced)} MoE modules for calibration") | ||
| permanent_count = sum( | ||
| 1 for _, (_, repl) in replaced.items() if repl.is_permanent | ||
| ) | ||
| if permanent_count > 0: | ||
| logger.info( | ||
| f"{permanent_count}/{len(replaced)} modules will remain in " | ||
| "calibration form (permanent)" | ||
| ) | ||
| if permanent_count < len(replaced): | ||
| logger.info( | ||
| f"{len(replaced) - permanent_count}/{len(replaced)} modules will " | ||
| "be restored after calibration" | ||
| ) | ||
|
|
||
| try: | ||
| yield | ||
| finally: | ||
| # Step 2: Restore non-permanent modules | ||
| # Step 3: Restore non-permanent modules | ||
| for name, (original, replacement) in replaced.items(): | ||
| if not replacement.is_permanent: | ||
| restored = replacement.restore(original) | ||
| model.set_submodule(name, restored) | ||
|
|
||
| # Synchronization barrier: ensure all ranks complete restoration | ||
| if is_distributed(): | ||
| dist.barrier() | ||
| logger.debug(f"Rank {dist.get_rank()}: Completed MoE module restoration") | ||
|
|
||
|
|
||
| def _is_registered(name: str, subclass: RegistryMixin): | ||
| return standardize_lookup_name(name) in subclass.registered_names() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| """Unit tests for MoE calibration context in single-rank mode.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
| from transformers import AutoModelForCausalLM | ||
| from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config | ||
| from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( | ||
| DeepseekV3MoE as OriginalDeepseekV3MoE, | ||
| ) | ||
|
|
||
| from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE | ||
| from llmcompressor.modeling.moe_context import moe_calibration_context | ||
| from llmcompressor.utils.dev import skip_weights_download | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"]) | ||
| def test_moe_context_replacement(model_stub): | ||
| """Verify that MoE modules are correctly replaced and restored.""" | ||
| with skip_weights_download(): | ||
| model = AutoModelForCausalLM.from_pretrained(model_stub) | ||
|
|
||
| original_count = sum( | ||
| 1 for _, m in model.named_modules() if isinstance(m, OriginalDeepseekV3MoE) | ||
| ) | ||
| assert original_count > 0, "Model should have MoE modules" | ||
|
|
||
| with moe_calibration_context(model, calibrate_all_experts=True): | ||
| # Verify replacement | ||
| calibration_count = sum( | ||
| 1 | ||
| for _, m in model.named_modules() | ||
| if isinstance(m, CalibrationDeepseekV3MoE) | ||
| ) | ||
| assert calibration_count == original_count | ||
|
|
||
| # Verify permanent modules remain | ||
| final_count = sum( | ||
| 1 for _, m in model.named_modules() if isinstance(m, CalibrationDeepseekV3MoE) | ||
| ) | ||
| assert final_count == original_count | ||
|
|
||
|
|
||
| def test_moe_context_calibrate_flag(): | ||
| """Verify calibrate_all_experts flag is passed correctly.""" | ||
| config = DeepseekV3Config() | ||
| with torch.device("cpu"): | ||
| original = OriginalDeepseekV3MoE(config).eval() | ||
|
|
||
| class TestModel(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.moe = original | ||
| self.config = config | ||
|
|
||
| for flag_value in [True, False]: | ||
| model = TestModel() | ||
| with moe_calibration_context(model, calibrate_all_experts=flag_value): | ||
| for _, m in model.named_modules(): | ||
| if isinstance(m, CalibrationDeepseekV3MoE): | ||
| assert m.calibrate_all_experts is flag_value |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| """Integration tests for MoE calibration context in DDP mode. | ||
|
|
||
| Run with: torchrun --nproc_per_node=2 -m pytest | ||
| tests/llmcompressor/modeling/test_moe_context_ddp.py -v | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
| from compressed_tensors.offload import load_offloaded_model | ||
| from torch import distributed as dist | ||
| from transformers import AutoModelForCausalLM | ||
|
|
||
| from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE | ||
| from llmcompressor.modeling.moe_context import moe_calibration_context | ||
| from llmcompressor.utils.dev import skip_weights_download | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def ddp_environment(): | ||
| """Initialize DDP environment once for all tests.""" | ||
| if not dist.is_initialized(): | ||
| pytest.skip("DDP not initialized - run with torchrun") | ||
| yield | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"]) | ||
| def test_moe_context_ddp(ddp_environment, model_stub): | ||
| """Verify MoE replacement works correctly in DDP mode.""" | ||
| rank = dist.get_rank() | ||
| world_size = dist.get_world_size() | ||
|
|
||
| with load_offloaded_model(): | ||
| with skip_weights_download(): | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_stub, device_map="auto_offload" | ||
| ) | ||
|
|
||
| with moe_calibration_context(model, calibrate_all_experts=True): | ||
| # Count replaced modules | ||
| replaced_count = sum( | ||
| 1 | ||
| for _, m in model.named_modules() | ||
| if isinstance(m, CalibrationDeepseekV3MoE) | ||
| ) | ||
| assert replaced_count > 0, f"Rank {rank}: No modules replaced" | ||
|
|
||
| # Verify consistency across ranks | ||
| count_tensor = torch.tensor([replaced_count], dtype=torch.long, device=next(model.parameters()).device) | ||
| all_counts = [torch.zeros_like(count_tensor) for _ in range(world_size)] | ||
| dist.all_gather(all_counts, count_tensor) | ||
| assert all( | ||
| c.item() == replaced_count for c in all_counts | ||
| ), f"Rank {rank}: Inconsistent counts {[c.item() for c in all_counts]}" | ||
|
|
||
| # Verify permanent modules remain (DeepSeekV3 is permanent) | ||
| final_count = sum( | ||
| 1 for _, m in model.named_modules() if isinstance(m, CalibrationDeepseekV3MoE) | ||
| ) | ||
| assert final_count == replaced_count |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't look like you actually assign modules to ranks. It seems like right now, all ranks are still doing duplicate work.