-
Notifications
You must be signed in to change notification settings - Fork 439
Expand file tree
/
Copy pathmoe_context.py
More file actions
179 lines (148 loc) · 6.56 KB
/
moe_context.py
File metadata and controls
179 lines (148 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
Simplified interface for MoE model calibration.
MoE (Mixture of Experts) models route tokens to different expert networks.
During calibration for quantization/compression, we need to ensure ALL experts
see data, not just the ones selected by the router. This module provides the
infrastructure to temporarily modify MoE modules for proper calibration.
Key components:
- MoECalibrationModule: Abstract base class for calibration modules
- moe_calibration_context: Context manager that applies calibration to a model
"""
import contextlib
from abc import ABC
import torch
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
from loguru import logger
from torch import distributed as dist
from tqdm import tqdm
from transformers import PreTrainedModel
__all__ = [
"MoECalibrationModule",
"moe_calibration_context",
]
class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
    """
    Abstract base class for MoE calibration modules.

    A calibration module stands in for an original MoE module while
    calibration runs, so that every expert (not just router-selected ones)
    accumulates quantization statistics.

    Subclasses must:
    1. Implement `__init__()` with signature:
       (self, original, config, calibrate_all_experts=True)
    2. Set `is_permanent` to indicate if module should stay in calibration form
    3. Optionally implement `restore()` if is_permanent=False
    """

    # When True, the module stays in calibration form after the context exits.
    is_permanent: bool = False

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Restore the original module structure.

        Permanent modules simply return themselves; non-permanent subclasses
        are expected to override this method.

        Returns:
            The original module (or self if permanent)
        """
        if not self.is_permanent:
            # A non-permanent subclass that reaches this base implementation
            # forgot to provide its own restore() — fail loudly.
            raise NotImplementedError(
                f"{self.__class__.__name__} has is_permanent=False but doesn't "
                "implement restore()"
            )
        return self
@contextlib.contextmanager
def moe_calibration_context(
    model: PreTrainedModel,
    calibrate_all_experts: bool = True,
):
    """
    Context manager that applies MoE calibration to a model.

    This scans all modules in the model and replaces any MoE modules with their
    calibration equivalents. After the context exits, non-permanent modules are
    restored to their original form.

    The model is modified in-place, so the same model object should be used
    within the context. If an error occurs while entering the context (e.g.
    partway through module replacement), any modules already swapped in are
    restored before the exception propagates, so the model is never left in a
    partially-replaced state.

    Args:
        model: The model to apply MoE calibration to (modified in-place)
        calibrate_all_experts: If True, all experts see all tokens during calibration.
            If False, use normal routing (useful for some techniques)

    Raises:
        RuntimeError: In distributed mode, if ranks disagree on the number of
            MoE modules found (model structures must be identical across ranks).

    Example:
        with moe_calibration_context(model):
            # Run calibration - all experts will see data
            for batch in dataloader:
                model(**batch)
        # Model is now restored (unless permanent)
    """
    # Maps module name -> (original module, calibration replacement).
    replaced = {}
    try:
        # Step 1: Collect all MoE modules that need replacement. Matching is
        # done by class name against the MoECalibrationModule registry.
        logger.debug("Entering MoE calibration context")
        modules_to_replace = []
        for name, module in model.named_modules():
            class_name = module.__class__.__name__
            if _is_registered(class_name, MoECalibrationModule):
                modules_to_replace.append((name, module, class_name))

        # Step 1.5: Verify all ranks have same number of modules (distributed
        # mode). A mismatch would desynchronize the collectives below.
        if is_distributed():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            num_modules = torch.tensor([len(modules_to_replace)], dtype=torch.long)
            all_counts = [torch.zeros_like(num_modules) for _ in range(world_size)]
            dist.all_gather(all_counts, num_modules)
            if not all(count.item() == num_modules.item() for count in all_counts):
                raise RuntimeError(
                    f"Rank {rank} found {num_modules.item()} MoE modules, but other "
                    f"ranks found different counts: {[c.item() for c in all_counts]}. "
                    "All ranks must have identical model structures."
                )

        # Step 2: Replace modules with progress bar
        if modules_to_replace:
            # Only rank 0 shows progress bar and logs
            show_progress = not is_distributed() or dist.get_rank() == 0
            if show_progress:
                logger.info(f"Found {len(modules_to_replace)} MoE modules to replace")
            for name, module, class_name in tqdm(
                modules_to_replace,
                desc="Replacing MoE modules for calibration",
                disable=not show_progress,
            ):
                replacement = MoECalibrationModule.load_from_registry(
                    class_name,
                    original=module,
                    config=model.config,
                    calibrate_all_experts=calibrate_all_experts,
                )
                model.set_submodule(name, replacement)
                # Record immediately so a failure on a later module still
                # restores this one in the finally block.
                replaced[name] = (module, replacement)

        # Synchronization barrier: all ranks complete replacement before calib
        if is_distributed():
            dist.barrier()
            logger.debug(f"Rank {dist.get_rank()}: Completed MoE module replacement")

        # Log what was replaced (only rank 0 in distributed mode)
        if replaced:
            show_logs = not is_distributed() or dist.get_rank() == 0
            if show_logs:
                logger.info(f"Replaced {len(replaced)} MoE modules for calibration")
                permanent_count = sum(
                    1 for _, (_, repl) in replaced.items() if repl.is_permanent
                )
                if permanent_count > 0:
                    logger.info(
                        f"{permanent_count}/{len(replaced)} modules will remain in "
                        "calibration form (permanent)"
                    )
                if permanent_count < len(replaced):
                    logger.info(
                        f"{len(replaced) - permanent_count}/{len(replaced)} modules will "
                        "be restored after calibration"
                    )

        yield
    finally:
        # Step 3: Restore non-permanent modules. This also runs if setup above
        # failed partway through, undoing any partial replacement.
        for name, (original, replacement) in replaced.items():
            if not replacement.is_permanent:
                restored = replacement.restore(original)
                model.set_submodule(name, restored)
        # Synchronization barrier: ensure all ranks complete restoration
        if is_distributed():
            dist.barrier()
            logger.debug(f"Rank {dist.get_rank()}: Completed MoE module restoration")
def _is_registered(name: str, subclass: RegistryMixin):
    """Return True if ``name`` (standardized) is in ``subclass``'s registry."""
    registered = subclass.registered_names()
    return standardize_lookup_name(name) in registered