forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathrouter.py
More file actions
707 lines (622 loc) · 28.6 KB
/
router.py
File metadata and controls
707 lines (622 loc) · 28.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from abc import ABC, abstractmethod
from typing import Optional, Union
import torch
from megatron.core.jit import jit_fuser
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
ProcessGroupCollection,
apply_biased_logits,
apply_random_logits,
apply_router_token_dropping,
compute_routing_scores_for_aux_loss,
get_tokens_per_expert_and_token_count,
router_gating_linear,
save_to_aux_losses_tracker,
sinkhorn,
switch_load_balancing_loss_func,
topk_routing_with_score_function,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig
# Optional integration with the external true-on-policy plugin. When the plugin is
# importable, the router consults `resolve_true_on_policy_runtime_policy(config)` to
# decide the gating projection dtype and the top-k tie-break mode; when it is not
# installed, _HAS_TRUE_ON_POLICY is False and those overrides are skipped entirely.
try:
    from miles_megatron_plugins.true_on_policy.contracts import (
        resolve_true_on_policy_runtime_policy,
    )

    _HAS_TRUE_ON_POLICY = True
except ImportError:
    _HAS_TRUE_ON_POLICY = False
class Router(ABC, MegatronModule):
    """Base Router class.

    Owns the router gating projection (``weight`` and an optional ``bias``) and the
    process groups used for MoE loss reductions. Subclasses implement ``routing`` and
    ``forward``.
    """

    def __init__(
        self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None
    ) -> None:
        """
        Initialize the Router module.

        Args:
            config (TransformerConfig): Configuration object for the Transformer model.
            pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations.
                The router reads its ``tp``, ``cp``, ``tp_cp`` and ``tp_dp_cp`` groups from
                this collection, so it must be provided despite the ``None`` default.

        Raises:
            ValueError: If ``pg_collection`` is None.
        """
        if pg_collection is None:
            # Fail early with an actionable message instead of an AttributeError on `None.tp`.
            raise ValueError(
                "Router requires a ProcessGroupCollection (pg_collection) providing the "
                "tp, cp, tp_cp and tp_dp_cp process groups."
            )
        super().__init__(config)
        self.config = config
        self.num_experts = self.config.num_moe_experts
        self.moe_aux_loss_func = None
        self.layer_number = None
        self.is_mtp = False
        self.tp_group = pg_collection.tp
        self.cp_group = pg_collection.cp
        self.tp_cp_group = pg_collection.tp_cp
        self.tp_dp_cp_group = pg_collection.tp_dp_cp

        # Initialize the gate weights.
        # TODO: Add support for GPU initialization, which requires updating the golden values.
        self.weight = torch.nn.Parameter(
            torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
        )
        if self.config.add_bias_linear:
            self.bias = torch.nn.Parameter(
                torch.empty((self.config.num_moe_experts), dtype=torch.float32)
            )
        else:
            self.bias = None
        # If calculate per token loss, we need to scale up moe aux loss by the number of tokens.
        # So we need to know if the model is configured to calculate per token loss.
        self.calculate_per_token_loss = self.config.calculate_per_token_loss
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the router parameters.

        Runs ``config.init_method`` on the weight (and bias) when
        ``config.perform_initialization`` is set, then casts the parameters to
        ``config.params_dtype`` and tags them with the ``sequence_parallel`` flag.
        """
        if self.config.perform_initialization:
            self.config.init_method(self.weight)
            if self.bias is not None:
                self.config.init_method(self.bias)
        self.weight.data = self.weight.data.to(dtype=self.config.params_dtype)
        setattr(self.weight, 'sequence_parallel', self.config.sequence_parallel)
        if self.bias is not None:
            self.bias.data = self.bias.data.to(dtype=self.config.params_dtype)
            setattr(self.bias, 'sequence_parallel', self.config.sequence_parallel)

    def gating(self, input: torch.Tensor):
        """Forward pass of the router gate.

        Args:
            input (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Logits tensor.
        """
        # Parameters are created on CPU (see __init__); move them lazily on first use.
        if self.weight.device.type == 'cpu':
            # move weights to GPU
            self.weight.data = self.weight.data.to(device=torch.cuda.current_device())
        if self.bias is not None and self.bias.device.type == 'cpu':
            self.bias.data = self.bias.data.to(device=torch.cuda.current_device())

        # When the true-on-policy MoE contract is active, cast the post-norm
        # activation back to parameter dtype for the router projection so the
        # softmax/top-k numerics match SGLang's routed expert path.
        if (
            _HAS_TRUE_ON_POLICY
            and resolve_true_on_policy_runtime_policy(self.config).deterministic_moe_routing
        ):
            router_dtype = self.config.params_dtype
        else:
            router_dtype = input.dtype
        # An explicit moe_router_dtype takes precedence over both choices above.
        if self.config.moe_router_dtype == 'fp32':
            router_dtype = torch.float32
        elif self.config.moe_router_dtype == 'fp64':
            router_dtype = torch.float64
        logits = router_gating_linear(input, self.weight, self.bias, router_dtype)

        return logits

    @abstractmethod
    def routing(self, logits: torch.Tensor):
        """Routing function.

        Args:
            logits (torch.Tensor): Logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
            probabilities and mapping.
        """
        raise NotImplementedError("Routing function not implemented.")

    @abstractmethod
    def forward(self, input: torch.Tensor):
        """
        Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor.
        """
        raise NotImplementedError("Forward function not implemented.")

    def set_layer_number(self, layer_number: int):
        """Set the layer number for the router (used when logging aux losses)."""
        self.layer_number = layer_number
class TopKRouter(Router):
    """Route each token to the top-k experts.

    The workflow of TopKRouter is as follows:
    (1) Calculate the logits by the router gating network.
    (2) Calculate the routing probabilities and map for top-k selection with score function.
    (3) [Optional] Apply token dropping to top-k expert selection.
    (4) [Optional] Apply the auxiliary load balancing loss for the given scores and routing map.

    Naming convention:
        logits: The output logits by the router gating network.
        scores: The scores after score function used to select the experts and calculate aux loss.
        probs: The topk weights used to combine the experts' outputs.
        routing_map: The masked routing map between tokens and experts.

    Padding masks accepted by this class mark padding positions: True means the token is
    padding (excluded), False means the token is valid.
    """

    def __init__(
        self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None
    ) -> None:
        """Initialize the zero token dropping router.

        Args:
            config (TransformerConfig): The configuration for the transformer model.
            pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations.
        """
        super().__init__(config=config, pg_collection=pg_collection)
        self.topk = self.config.moe_router_topk
        self.routing_type = self.config.moe_router_load_balancing_type
        self.score_function = self.config.moe_router_score_function
        self.input_jitter = None

        self.enable_expert_bias = self.config.moe_router_enable_expert_bias
        if self.enable_expert_bias:
            # Running per-expert token counts accumulated by _apply_expert_bias;
            # non-persistent, so it is not written to checkpoints.
            self.register_buffer(
                'local_tokens_per_expert',
                torch.zeros(
                    self.config.num_moe_experts,
                    dtype=torch.float32,
                    device=torch.cuda.current_device(),
                ),
                persistent=False,
            )
            # Expert bias passed to the top-k routing function (persistent buffer).
            self.register_buffer(
                'expert_bias',
                torch.zeros(
                    self.config.num_moe_experts,
                    dtype=torch.float32,
                    device=torch.cuda.current_device(),
                ),
            )
        else:
            self.local_tokens_per_expert = None
            self.expert_bias = None

        # Initialize global tokens per expert for global aux loss
        if self.get_aux_loss_coeff("global_aux_loss") > 0:
            self.register_buffer(
                'global_tokens_per_expert',
                torch.zeros(
                    self.config.num_moe_experts,
                    dtype=torch.float32,
                    device=torch.cuda.current_device(),
                ),
                persistent=False,
            )
            self.register_buffer(
                'ga_steps',
                torch.tensor(0, dtype=torch.float32, device=torch.cuda.current_device()),
                persistent=False,
            )
        else:
            self.global_tokens_per_expert = None
            self.ga_steps = None

        # Best-effort registration with the optional `miles` routing-replay manager;
        # skipped silently when that package is not installed.
        try:
            from miles.utils.replay_base import routing_replay_manager

            routing_replay_manager.register_to_module(self, "routing_replay")
        except ImportError:
            pass

    def _maintain_float32_expert_bias(self):
        """
        Maintain the expert bias in float32.

        When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module.
        We keep it in float32 to avoid routing errors when updating the expert_bias.
        """
        if hasattr(self, 'expert_bias') and self.expert_bias is not None:
            if self.expert_bias.dtype != torch.float32:
                self.expert_bias.data = self.expert_bias.data.to(torch.float32)

    def sinkhorn_load_balancing(self, logits: torch.Tensor):
        """Apply sinkhorn routing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
            probabilities and mask.
        """

        def _sinkhorn_activation(logits):
            if self.topk == 1:
                logits = torch.sigmoid(logits)
            else:  # k > 1
                logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            return logits

        assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
        if self.training:
            # Expert selection uses sinkhorn-normalized logits without gradient tracking.
            with torch.no_grad():
                norm_logits = sinkhorn(
                    logits.to(dtype=torch.float32)
                )  # explicit fp32 conversion for stability
                _, indices = torch.topk(norm_logits, k=self.topk, dim=1)
            logits = _sinkhorn_activation(logits)
        else:
            logits = _sinkhorn_activation(logits)
            _, indices = torch.topk(logits, k=self.topk, dim=1)
        routing_map = torch.zeros_like(logits).int().scatter(1, indices, 1).bool()
        scores = logits * routing_map
        return scores, routing_map

    def get_aux_loss_coeff(self, aux_loss_type: str) -> float:
        """Return the aux loss coeff for the given auxiliary loss type.

        ``moe_router_load_balancing_type`` may be a single string or a list of types; in
        the list case the coefficient at the matching index is returned.
        If the auxiliary loss type is not found, return 0.0.
        """
        if isinstance(self.routing_type, str):
            if self.routing_type == aux_loss_type:
                return self.config.moe_aux_loss_coeff
        if isinstance(self.routing_type, list):
            try:
                idx = self.routing_type.index(aux_loss_type)
                return self.config.moe_aux_loss_coeff[idx]
            except ValueError:
                return 0.0
        return 0.0

    def is_aux_loss_enabled(self) -> bool:
        """Check if any auxiliary loss (aux, seq_aux or global_aux) has a positive coeff."""
        for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]:
            if self.get_aux_loss_coeff(aux_loss_type) > 0:
                return True
        return False

    def _apply_aux_loss(
        self,
        probs: torch.Tensor,
        scores_for_aux_loss: torch.Tensor,
        routing_map: torch.Tensor,
        with_padding_mask: bool = False,
    ):
        """Apply the auxiliary loss for the given scores and routing map."""
        aux_loss_coeff = self.get_aux_loss_coeff("aux_loss")
        if aux_loss_coeff == 0:
            return probs

        global_tokens_per_expert, local_num_tokens, total_num_tokens = (
            get_tokens_per_expert_and_token_count(
                routing_map=routing_map,
                reduce_group=self.tp_cp_group,
                topk=self.topk,
                with_padding_mask=with_padding_mask,
            )
        )

        aux_loss = switch_load_balancing_loss_func(
            probs=scores_for_aux_loss,
            tokens_per_expert=global_tokens_per_expert,
            total_num_tokens=total_num_tokens,
            topk=self.topk,
            num_experts=self.config.num_moe_experts,
            moe_aux_loss_coeff=aux_loss_coeff,
            fused=self.config.moe_router_fusion,
        )
        probs = self.attach_and_log_load_balancing_loss(
            probs,
            aux_loss_coeff,
            aux_loss,
            "load_balancing_loss",
            self.tp_cp_group,
            valid_token_count=local_num_tokens,
        )
        return probs

    def _apply_seq_aux_loss(
        self,
        probs: torch.Tensor,
        scores_for_aux_loss: torch.Tensor,
        routing_map: torch.Tensor,
        seq_length: int,
        bsz: int,
        with_padding_mask: bool = False,
    ):
        """Apply the sequence-level auxiliary loss for the given scores and routing map.

        To calculate the sequence-level aux loss, we reshape the batch_size dimension to
        experts dimension. The resulting loss from switch_load_balancing_loss_func is equal
        to the sum of aux loss for each sequence in the batch. And then we divide the aux
        loss by the batch size to get averaged aux loss.
        """
        seq_aux_loss_coeff = self.get_aux_loss_coeff("seq_aux_loss")
        if seq_aux_loss_coeff == 0:
            return probs

        scores_for_aux_loss = scores_for_aux_loss.reshape(seq_length, -1)
        routing_map = routing_map.reshape(seq_length, -1)

        global_tokens_per_expert, local_num_tokens, total_num_tokens = (
            get_tokens_per_expert_and_token_count(
                routing_map=routing_map,
                reduce_group=self.tp_cp_group,
                with_padding_mask=with_padding_mask,
                topk=self.topk * bsz,
            )
        )

        aux_loss = (
            switch_load_balancing_loss_func(
                probs=scores_for_aux_loss,
                tokens_per_expert=global_tokens_per_expert,
                total_num_tokens=total_num_tokens,
                topk=self.topk,
                num_experts=self.config.num_moe_experts,
                moe_aux_loss_coeff=seq_aux_loss_coeff,
                fused=self.config.moe_router_fusion,
            )
            / bsz
        )
        probs = self.attach_and_log_load_balancing_loss(
            probs,
            seq_aux_loss_coeff,
            aux_loss,
            "seq_load_balancing_loss",
            self.tp_cp_group,
            valid_token_count=local_num_tokens,
        )
        return probs

    def _apply_global_aux_loss(
        self,
        probs: torch.Tensor,
        scores_for_aux_loss: torch.Tensor,
        routing_map: torch.Tensor,
        with_padding_mask: bool = False,
    ):
        """Apply the global auxiliary loss for the given scores and routing map.

        Token counts are reduced over the tp/dp/cp group and accumulated across calls in
        ``global_tokens_per_expert``; the loss uses their running average over ``ga_steps``.
        """
        global_aux_loss_coeff = self.get_aux_loss_coeff("global_aux_loss")
        if global_aux_loss_coeff == 0:
            return probs

        # Use unified function to compute tokens_per_expert and num_tokens
        global_tokens_per_expert, local_num_tokens, total_num_tokens = (
            get_tokens_per_expert_and_token_count(
                routing_map=routing_map,
                reduce_group=self.tp_dp_cp_group,
                with_padding_mask=with_padding_mask,
                topk=self.topk,
            )
        )

        self.global_tokens_per_expert += global_tokens_per_expert
        self.ga_steps += 1
        averaged_tokens_per_expert = self.global_tokens_per_expert / self.ga_steps

        global_aux_loss = switch_load_balancing_loss_func(
            probs=scores_for_aux_loss,
            tokens_per_expert=averaged_tokens_per_expert,
            total_num_tokens=total_num_tokens,
            topk=self.topk,
            num_experts=self.config.num_moe_experts,
            moe_aux_loss_coeff=global_aux_loss_coeff,
            fused=self.config.moe_router_fusion,
        )
        probs = self.attach_and_log_load_balancing_loss(
            probs,
            global_aux_loss_coeff,
            global_aux_loss,
            "global_load_balancing_loss",
            self.tp_dp_cp_group,
            reduce_group_has_dp=True,
            valid_token_count=local_num_tokens,
        )
        return probs

    def attach_and_log_load_balancing_loss(
        self,
        activation: torch.Tensor,
        aux_loss_coeff: float,
        aux_loss: torch.Tensor,
        aux_loss_name: str,
        reduce_group: torch.distributed.ProcessGroup,
        reduce_group_has_dp: bool = False,
        valid_token_count: Optional[Union[int, torch.Tensor]] = None,
    ):
        """Attach aux loss function to activation and add to logging.

        Args:
            activation (torch.Tensor): Activation tensor to attach the aux loss to.
            aux_loss_coeff (float): Coefficient for the aux loss. Must be non-zero, since
                the logged value is ``aux_loss / aux_loss_coeff``.
            aux_loss (torch.Tensor): Computed aux loss.
            aux_loss_name (str): Name of the aux loss for logging.
            reduce_group (torch.distributed.ProcessGroup): Process group for reduction.
            reduce_group_has_dp (bool): Whether the reduce group has data parallel ranks.
                Set this to True if the reduce group has data parallel ranks. This flag is used to
                ensure the correct reduction in aux loss tracking.
            valid_token_count (int or torch.Tensor, optional): Number of valid tokens excluding
                padding tokens. Can be a Python int or a torch.Tensor (typically 0-d tensor).
                If None, uses activation.shape[0]. Defaults to None.

        Returns:
            torch.Tensor: ``activation`` wrapped so the aux loss gradient is injected on backward.
        """
        # TODO (zijiey): fix the per_layer_logging for MTP, currently it will incorrectly
        # add the aux loss logging value to other layer's since it is difficult to get the
        # correct layer_number for MTP. It does not affect the correctness of the calculation
        # results and the reduced load_balancing_loss logging value.
        num_layers = self.config.num_layers
        if self.config.mtp_num_layers is not None:
            num_layers += self.config.mtp_num_layers
        save_to_aux_losses_tracker(
            aux_loss_name,
            aux_loss / aux_loss_coeff,
            self.layer_number,
            num_layers,
            reduce_group=reduce_group,
            reduce_group_has_dp=reduce_group_has_dp,
        )
        if self.calculate_per_token_loss:
            # Scale the aux_loss by the number of tokens.
            # The expected final scaling for aux_loss gradients is 1/(num_micro_batches * dp_size).
            # After commit 02648000, Megatron started using the number of total tokens to scale
            # gradients under the argument of calculate_per_token_loss,
            # which scales both the main_loss gradient and aux_loss gradient by
            # 1/(num_local_tokens * dp_size * num_micro_batches) in finalize_model_grads function.
            # To correct this scaling, we need to scale the aux_loss by num_local_tokens here.
            # Use valid_token_count (excluding padding) if provided, otherwise use total tokens.
            num_tokens = valid_token_count if valid_token_count is not None else activation.shape[0]
            activation = MoEAuxLossAutoScaler.apply(activation, aux_loss * num_tokens)
        else:
            activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
        return activation

    def apply_z_loss(self, logits, padding_mask: Optional[torch.Tensor] = None):
        """Encourages the router's logits to remain small to enhance stability.

        Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
        The z-loss is only attached in training mode, with autograd enabled, and with a
        non-zero ``moe_z_loss_coeff``; otherwise ``logits`` is returned unchanged.

        Args:
            logits (torch.Tensor): The logits of the router.
            padding_mask (torch.Tensor, optional): Boolean mask of padding positions.
                Shape in [num_tokens]. True marks a padding token, False marks a valid
                token (valid tokens are counted as ``~padding_mask``). Defaults to None.

        Returns:
            torch.Tensor: The logits after applying the z-loss.
        """
        z_loss_coeff = self.config.moe_z_loss_coeff
        # Skip Z loss calculations when using torch.no_grad() or checkpointing.
        # A zero coefficient is treated as disabled: the tracker value below divides by
        # the coefficient, which would otherwise log NaN/inf.
        if (
            z_loss_coeff is None
            or z_loss_coeff == 0
            or not self.training
            or not torch.is_grad_enabled()
        ):
            return logits

        moe_z_loss_coeff = z_loss_coeff / self.tp_cp_group.size()
        z_loss = z_loss_func(logits, moe_z_loss_coeff, padding_mask=padding_mask)
        if self.calculate_per_token_loss:
            # The expected final scaling for z_loss gradients is
            # 1/(num_micro_batches * dp_size).
            # After commit 02648000, Megatron started using the number of total tokens
            # to scale gradients under the argument of calculate_per_token_loss,
            # which scales both the main_loss gradient and z_loss gradient by
            # 1/(num_local_tokens * dp_size * num_micro_batches) in finalize_model_grads().
            # To correct this scaling, we need to scale the z_loss by num_local_tokens here.
            # Count valid tokens: sum of inverted mask (False -> True = valid)
            num_tokens = (~padding_mask).sum() if padding_mask is not None else logits.shape[0]
            logits = MoEAuxLossAutoScaler.apply(logits, z_loss * num_tokens)
        else:
            logits = MoEAuxLossAutoScaler.apply(logits, z_loss)

        num_layers = self.config.num_layers
        if self.config.mtp_num_layers is not None:
            num_layers += self.config.mtp_num_layers
        save_to_aux_losses_tracker(
            "z_loss", z_loss / moe_z_loss_coeff, self.layer_number, num_layers
        )
        return logits

    def apply_input_jitter(self, input: torch.Tensor):
        """Add noise to the input tensor.

        Refer to https://arxiv.org/abs/2101.03961.
        The multiplicative noise is drawn from Uniform(1 - eps, 1 + eps); the sampler is
        created on first use with the dtype/device of that first input and then reused.

        Args:
            input (Tensor): Input tensor.

        Returns:
            Tensor: Jittered input.
        """
        if self.config.moe_input_jitter_eps is not None:
            eps = self.config.moe_input_jitter_eps
            if self.input_jitter is None:
                self.input_jitter = torch.distributions.uniform.Uniform(
                    torch.tensor(1.0 - eps, dtype=input.dtype, device=input.device),
                    torch.tensor(1.0 + eps, dtype=input.dtype, device=input.device),
                ).rsample
            return input * self.input_jitter(input.shape)
        else:
            return input

    @jit_fuser
    def _apply_expert_bias(
        self, routing_map: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ):
        """
        Accumulate per-expert token counts into ``local_tokens_per_expert``.

        Only runs when expert bias is enabled and autograd is enabled, which prevents
        extra local token accumulation on evaluation or activation recomputation.

        Args:
            routing_map (torch.Tensor): Token-to-expert map, shape [num_tokens, num_experts].
            padding_mask (torch.Tensor, optional): Boolean mask of padding positions with
                one entry per token (True = padding). Padding tokens are not counted.
        """
        if self.enable_expert_bias and torch.is_grad_enabled():
            with torch.no_grad():
                if padding_mask is not None:
                    # Broadcast the per-token mask across the expert dimension. A flat
                    # [num_tokens] mask would otherwise be aligned with the trailing
                    # num_experts dimension of routing_map by broadcasting rules.
                    routing_map = routing_map & (~padding_mask.reshape(-1, 1))
                self.local_tokens_per_expert += routing_map.sum(dim=0)

    def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        """Top-k routing function

        Args:
            logits (torch.Tensor): Logits tensor after gating.
            padding_mask (torch.Tensor, optional): Boolean mask of padding positions.
                Shape [seq_length, bsz]. True marks a padding token, False marks a
                valid token. Defaults to None.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                with shape [num_tokens, num_experts].
        """
        seq_length, bsz = logits.shape[:2]
        logits = logits.view(-1, self.config.num_moe_experts)

        # Flatten padding_mask to [num_tokens] if provided
        if padding_mask is not None:
            padding_mask = padding_mask.reshape(-1)
        with_padding_mask = padding_mask is not None

        # Apply Z-Loss
        logits = self.apply_z_loss(logits, padding_mask=padding_mask)

        # Optional top-k tie-break override from the true-on-policy contract.
        if _HAS_TRUE_ON_POLICY:
            _top = resolve_true_on_policy_runtime_policy(self.config)
            _topk_tiebreak = _top.moe_topk_tiebreak if _top.deterministic_moe_routing else None
        else:
            _topk_tiebreak = None

        # Calculate probs and routing_map for token dispatching
        if self.routing_type == "sinkhorn":
            probs, routing_map = self.sinkhorn_load_balancing(logits)
        else:
            probs, routing_map = topk_routing_with_score_function(
                logits,
                self.topk,
                use_pre_softmax=self.config.moe_router_pre_softmax,
                num_groups=self.config.moe_router_num_groups,
                group_topk=self.config.moe_router_group_topk,
                scaling_factor=self.config.moe_router_topk_scaling_factor,
                score_function=self.score_function,
                expert_bias=self.expert_bias,
                fused=self.config.moe_router_fusion,
                is_mtp=self.is_mtp,
                topk_tiebreak=_topk_tiebreak,
            )

        # Apply token dropping to probs and routing_map.
        if self.config.moe_expert_capacity_factor is not None:
            probs, routing_map = apply_router_token_dropping(
                probs,
                routing_map,
                router_topk=self.topk,
                capacity_factor=self.config.moe_expert_capacity_factor,
                drop_policy=self.config.moe_token_drop_policy,
                pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            )

        # Apply each aux loss type and attach aux loss autograd function to probs
        if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
            # Calculate scores and routing_map for aux loss
            routing_map_for_aux_loss, scores_for_aux_loss = compute_routing_scores_for_aux_loss(
                logits,
                self.topk,
                self.score_function,
                fused=self.config.moe_router_fusion,
                padding_mask=padding_mask,
                topk_tiebreak=_topk_tiebreak,
            )
            probs = self._apply_aux_loss(
                probs,
                scores_for_aux_loss,
                routing_map_for_aux_loss,
                with_padding_mask=with_padding_mask,
            )
            probs = self._apply_seq_aux_loss(
                probs,
                scores_for_aux_loss,
                routing_map_for_aux_loss,
                seq_length,
                bsz,
                with_padding_mask=with_padding_mask,
            )
            probs = self._apply_global_aux_loss(
                probs,
                scores_for_aux_loss,
                routing_map_for_aux_loss,
                with_padding_mask=with_padding_mask,
            )

        # Optionally apply expert bias
        self._apply_expert_bias(routing_map, padding_mask=padding_mask)

        return probs, routing_map

    def reset_global_aux_loss_tracker(self):
        """Reset the global aux loss tracker (accumulated token counts and step count)."""
        if self.global_tokens_per_expert is not None:
            self.global_tokens_per_expert.zero_()
            self.ga_steps.zero_()

    def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        """
        Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor.
            padding_mask (torch.Tensor, optional): Boolean mask of padding positions.
                Shape [seq_length, bsz]. True marks a padding token, False marks a
                valid token. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``probs`` and ``routing_map`` from ``routing``.
        """
        self._maintain_float32_expert_bias()

        # Apply input jitter
        input = self.apply_input_jitter(input)
        logits = self.gating(input)

        if self.config.moe_router_force_load_balancing:
            # Apply force load balancing with random logits for benchmark
            logits = apply_random_logits(logits)

        if self.config.moe_router_force_biased is not None:
            # Apply biased logits with shared random bias across all ranks
            logits = apply_biased_logits(
                logits, self.config.moe_router_force_biased, self.layer_number
            )

        probs, routing_map = self.routing(logits, padding_mask=padding_mask)

        return probs, routing_map

    def _load_from_state_dict(self, *args, **kwargs):
        """Load the state dict of the router."""
        self._maintain_float32_expert_bias()  # switch to float32 before loading
        return super()._load_from_state_dict(*args, **kwargs)

    def _save_to_state_dict(self, *args, **kwargs):
        """Save the state dict of the router."""
        self._maintain_float32_expert_bias()  # switch to float32 before saving
        return super()._save_to_state_dict(*args, **kwargs)