Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions src/llmcompressor/modeling/moe_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from abc import ABC

import torch
import torch.distributed as dist
from compressed_tensors.offload import is_distributed
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
from loguru import logger
from torch import distributed as dist
from tqdm import tqdm
from transformers import PreTrainedModel

Expand Down Expand Up @@ -99,11 +99,33 @@ def moe_calibration_context(
if _is_registered(class_name, MoECalibrationModule):
modules_to_replace.append((name, module, class_name))

# Step 1.5: Verify all ranks have same number of modules (distributed mode)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't look like you actually assign modules to ranks. It seems like right now, all ranks are still doing duplicate work.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't think this check is necessary

if is_distributed():
rank = dist.get_rank()
world_size = dist.get_world_size()

# Verify all ranks have same number of modules
num_modules = torch.tensor([len(modules_to_replace)], dtype=torch.long, device=next(model.parameters()).device)
all_counts = [torch.zeros_like(num_modules) for _ in range(world_size)]
dist.all_gather(all_counts, num_modules)

if not all(count.item() == num_modules.item() for count in all_counts):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes, the number of modules will be not evenly divisible by the number of ranks, so this check can be harmful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just checking that each rank has the same number of modules total, not assigned I think

raise RuntimeError(
f"Rank {rank} found {num_modules.item()} MoE modules, but other "
f"ranks found different counts: {[c.item() for c in all_counts]}. "
"All ranks must have identical model structures."
)

# Step 2: Replace modules with progress bar
if modules_to_replace:
logger.info(f"Found {len(modules_to_replace)} MoE modules to replace")
# Only rank 0 shows progress bar and logs
show_progress = not is_distributed() or dist.get_rank() == 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can use is_rank0, also probably simpler to just inline this into the if statement

if show_progress:
logger.info(f"Found {len(modules_to_replace)} MoE modules to replace")
for name, module, class_name in tqdm(
modules_to_replace, desc="Replacing MoE modules for calibration"
modules_to_replace,
desc="Replacing MoE modules for calibration",
disable=not show_progress,
):
replacement = MoECalibrationModule.load_from_registry(
class_name,
Expand All @@ -113,35 +135,45 @@ def moe_calibration_context(
)
model.set_submodule(name, replacement)
replaced[name] = (module, replacement)
if is_distributed():
dist.barrier()

# Log what was replaced
# Synchronization barrier: all ranks complete replacement before calib
if is_distributed():
dist.barrier()
logger.debug(f"Rank {dist.get_rank()}: Completed MoE module replacement")

# Log what was replaced (only rank 0 in distributed mode)
if replaced:
logger.info(f"Replaced {len(replaced)} MoE modules for calibration")
permanent_count = sum(
1 for _, (_, repl) in replaced.items() if repl.is_permanent
)
if permanent_count > 0:
logger.info(
f"{permanent_count}/{len(replaced)} modules will remain in "
"calibration form (permanent)"
)
if permanent_count < len(replaced):
logger.info(
f"{len(replaced) - permanent_count}/{len(replaced)} modules will "
"be restored after calibration"
show_logs = not is_distributed() or dist.get_rank() == 0
if show_logs:
logger.info(f"Replaced {len(replaced)} MoE modules for calibration")
permanent_count = sum(
1 for _, (_, repl) in replaced.items() if repl.is_permanent
)
if permanent_count > 0:
logger.info(
f"{permanent_count}/{len(replaced)} modules will remain in "
"calibration form (permanent)"
)
if permanent_count < len(replaced):
logger.info(
f"{len(replaced) - permanent_count}/{len(replaced)} modules will "
"be restored after calibration"
)

try:
yield
finally:
# Step 2: Restore non-permanent modules
# Step 3: Restore non-permanent modules
for name, (original, replacement) in replaced.items():
if not replacement.is_permanent:
restored = replacement.restore(original)
model.set_submodule(name, restored)

# Synchronization barrier: ensure all ranks complete restoration
if is_distributed():
dist.barrier()
logger.debug(f"Rank {dist.get_rank()}: Completed MoE module restoration")


def _is_registered(name: str, subclass: RegistryMixin):
return standardize_lookup_name(name) in subclass.registered_names()
60 changes: 60 additions & 0 deletions tests/llmcompressor/modeling/test_moe_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Unit tests for MoE calibration context in single-rank mode."""

import pytest
import torch
from transformers import AutoModelForCausalLM
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3MoE as OriginalDeepseekV3MoE,
)

from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.dev import skip_weights_download


@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"])
def test_moe_context_replacement(model_stub):
    """Check that entering the MoE calibration context swaps each original
    MoE layer for its calibration counterpart, and that the (permanent)
    replacements are still present after the context exits."""
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained(model_stub)

    def _count_instances(module_cls):
        # Number of submodules of the given class currently in the model.
        return sum(isinstance(m, module_cls) for _, m in model.named_modules())

    original_count = _count_instances(OriginalDeepseekV3MoE)
    assert original_count > 0, "Model should have MoE modules"

    with moe_calibration_context(model, calibrate_all_experts=True):
        # Every original MoE layer should now be a calibration module.
        assert _count_instances(CalibrationDeepseekV3MoE) == original_count

    # DeepSeekV3 calibration modules are permanent, so they remain after exit.
    assert _count_instances(CalibrationDeepseekV3MoE) == original_count


def test_moe_context_calibrate_flag():
    """Verify that the ``calibrate_all_experts`` flag propagates onto every
    replaced calibration module, for both True and False."""
    config = DeepseekV3Config()
    with torch.device("cpu"):
        base_moe = OriginalDeepseekV3MoE(config).eval()

    class TestModel(torch.nn.Module):
        # Minimal wrapper exposing one MoE layer plus the config that the
        # calibration context needs to build replacement modules.
        def __init__(self):
            super().__init__()
            self.moe = base_moe
            self.config = config

    for expected in [True, False]:
        wrapper = TestModel()
        with moe_calibration_context(wrapper, calibrate_all_experts=expected):
            calibration_modules = [
                m
                for _, m in wrapper.named_modules()
                if isinstance(m, CalibrationDeepseekV3MoE)
            ]
            for module in calibration_modules:
                assert module.calibrate_all_experts is expected
59 changes: 59 additions & 0 deletions tests/llmcompressor/modeling/test_moe_context_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Integration tests for MoE calibration context in DDP mode.

Run with: torchrun --nproc_per_node=2 -m pytest
tests/llmcompressor/modeling/test_moe_context_ddp.py -v
"""

import pytest
import torch
from compressed_tensors.offload import load_offloaded_model
from torch import distributed as dist
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.dev import skip_weights_download


@pytest.fixture(scope="module")
def ddp_environment():
    """Module-scoped guard: skip the DDP tests unless torch.distributed has
    already been initialized (i.e. the suite was launched via torchrun)."""
    if dist.is_initialized():
        yield
    else:
        # pytest.skip raises, so no test in this module runs without DDP.
        pytest.skip("DDP not initialized - run with torchrun")


@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"])
def test_moe_context_ddp(ddp_environment, model_stub):
    """Verify MoE replacement works correctly in DDP mode."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    with load_offloaded_model():
        with skip_weights_download():
            model = AutoModelForCausalLM.from_pretrained(
                model_stub, device_map="auto_offload"
            )

    def _calibration_count():
        # How many calibration MoE modules the model currently contains.
        return sum(
            isinstance(m, CalibrationDeepseekV3MoE)
            for _, m in model.named_modules()
        )

    with moe_calibration_context(model, calibrate_all_experts=True):
        replaced_count = _calibration_count()
        assert replaced_count > 0, f"Rank {rank}: No modules replaced"

        # Gather every rank's local count and require that they all agree.
        device = next(model.parameters()).device
        local_count = torch.tensor(
            [replaced_count], dtype=torch.long, device=device
        )
        gathered = [torch.zeros_like(local_count) for _ in range(world_size)]
        dist.all_gather(gathered, local_count)
        assert all(
            c.item() == replaced_count for c in gathered
        ), f"Rank {rank}: Inconsistent counts {[c.item() for c in gathered]}"

    # DeepSeekV3 replacements are permanent, so the count persists after exit.
    assert _calibration_count() == replaced_count
Loading