forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexperts.py
More file actions
1002 lines (901 loc) · 43.5 KB
/
experts.py
File metadata and controls
1002 lines (901 loc) · 43.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import copy
import logging
from copy import deepcopy
from functools import partial
from math import ceil
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.core import tensor_parallel
from megatron.core.activations import squared_relu
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
LocalNonpersistentObject,
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
)
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_bias_geglu import quick_gelu, weighted_bias_quick_geglu_impl
from megatron.core.fusions.fused_bias_swiglu import weighted_bias_swiglu_impl
from megatron.core.fusions.fused_weighted_squared_relu import weighted_squared_relu_impl
from megatron.core.jit import jit_fuser
from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.utils import divide
from miles_megatron_plugins.true_on_policy.contracts import resolve_true_on_policy_runtime_policy
from megatron.core.transformer.mlp import MLP, MLPSubmodules, apply_swiglu_sharded_factory
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.moe.moe_utils import (
ProcessGroupCollection,
get_align_size_for_quantization,
)
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
ensure_metadata_has_dp_cp_group,
make_sharded_object_for_checkpoint,
sharded_state_dict_default,
)
try:
import transformer_engine as te # pylint: disable=unused-import
from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
HAVE_TE = True
except ImportError:
HAVE_TE = False
logger = logging.getLogger(__name__)
class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using GroupedGEMM.
Executes multiple experts in parallel to maximize computational efficiency.
"""
# TODO(M4): breaking api, switched from pass in tp_group to pass in pg_collection.
def __init__(
self,
num_local_experts: int,
config: TransformerConfig,
pg_collection: Optional[ProcessGroupCollection] = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.num_local_experts = num_local_experts
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
assert (
config.moe_latent_size is None
), "MoE latent projection not supported in GroupedMLP yet."
self.expert_parallel = config.expert_model_parallel_size > 1
true_on_policy = resolve_true_on_policy_runtime_policy(config)
self.cast_grouped_gemm_input_to_weight_dtype = true_on_policy.use_sglang_backend
if self.config.gated_linear_unit:
if self.config.activation_func not in (F.silu, F.gelu):
raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
@jit_fuser
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
self.activation_recompute = (
self.config.recompute_granularity == 'selective'
and "moe_act" in self.config.recompute_modules
)
if self.activation_recompute and (self.config.fp8 or self.config.fp4):
raise ValueError(
"moe_act recompute for fp8 or fp4 cannot work with the legacy GroupedMLP."
)
@jit_fuser
def activation_func_with_probs(x, probs):
dtype = x.dtype
res = self.activation_func(x) * probs
return res.to(dtype)
self.activation_func_with_probs = activation_func_with_probs
self.ep_group = pg_collection.ep
# use pg_collection.expt_tp_group as tensor parallel group in this module.
self.tp_group = pg_collection.expt_tp
# use pg_collection.expt_dp_group as data parallel group in this module.
self.dp_group = pg_collection.expt_dp
# How many feature each rank holds for fc1 and fc2, respectively.
tp_size = self.tp_group.size()
tp_rank = self.tp_group.rank()
fc1_output_size = self.config.moe_ffn_hidden_size * self.num_local_experts
if config.gated_linear_unit:
# Project to 4h. If using swiglu double the output width,
# see https://arxiv.org/pdf/2002.05202.pdf
fc1_output_size *= 2
fc1_output_size_per_partition = divide(fc1_output_size, tp_size)
fc2_input_size = self.config.moe_ffn_hidden_size * self.num_local_experts
fc2_input_size_per_partition = divide(fc2_input_size, tp_size)
# Note: The current kernel implementations of grouped_gemm
# does not support transposition with CUTLASS grouped GEMM
# (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358)
# and as a result we avoid allocate the transpose of weights.
# Initialize weight.
if config.use_cpu_initialization:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition, self.config.hidden_size, dtype=config.params_dtype
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight1,
self.config.hidden_size,
fc1_output_size,
fc1_output_size_per_partition,
partition_dim=1,
init_method=config.init_method,
params_dtype=config.params_dtype,
rank=tp_rank,
world_size=tp_size,
)
_initialize_affine_weight_cpu(
self.weight2,
fc2_input_size,
self.config.hidden_size,
fc2_input_size_per_partition,
partition_dim=0,
init_method=config.output_layer_init_method,
params_dtype=config.params_dtype,
rank=tp_rank,
world_size=tp_size,
)
else:
# Ensure TP attrs are set even when not initializing
set_tensor_model_parallel_attributes(
tensor=self.weight1, is_parallel=True, dim=1, stride=1
)
set_tensor_model_parallel_attributes(
tensor=self.weight2, is_parallel=True, dim=0, stride=1
)
else:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight1, config.init_method, partition_dim=1, is_expert=True
)
_initialize_affine_weight_gpu(
self.weight2, config.output_layer_init_method, partition_dim=0, is_expert=True
)
else:
# Ensure TP attrs are set even when not initializing
set_tensor_model_parallel_attributes(
tensor=self.weight1, is_parallel=True, dim=1, stride=1
)
set_tensor_model_parallel_attributes(
tensor=self.weight2, is_parallel=True, dim=0, stride=1
)
setattr(self.weight1, 'allreduce', not self.expert_parallel)
setattr(self.weight2, 'allreduce', not self.expert_parallel)
def remove_extra_states_check(self, incompatible_keys):
"""
Remove _extra_state from unexpected keys.
These keys are for dist ckpt compatibility with SequentialMLP.
"""
keys = deepcopy(incompatible_keys.unexpected_keys)
for key in keys:
if '_extra_state' in key:
incompatible_keys.unexpected_keys.remove(key)
self.register_load_state_dict_post_hook(remove_extra_states_check)
def forward(
self,
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
):
"""Forward step of the GroupedMLP."""
assert self.config.bf16, "Currently GroupedMLP for MoE only supports bf16."
if self.activation_recompute:
self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput()
if self.config.moe_apply_probs_on_input:
assert (
self.config.moe_router_topk == 1
), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1."
original_dtype = permuted_local_hidden_states.dtype
permuted_local_hidden_states = (
permuted_probs.unsqueeze(-1) * permuted_local_hidden_states
)
permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype)
# Probs already applied, so reset to 1.
permuted_probs = torch.ones_like(permuted_probs)
if (
self.cast_grouped_gemm_input_to_weight_dtype
and permuted_local_hidden_states.dtype != self.weight1.dtype
):
permuted_local_hidden_states = permuted_local_hidden_states.to(self.weight1.dtype)
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = gg.ops.gmm(
permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
)
if self.activation_recompute:
intermediate_parallel = self.activation_checkpoint.checkpoint(
self.activation_func_with_probs, fc1_output, permuted_probs.unsqueeze(-1)
)
fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
self.activation_checkpoint.discard_output_and_register_recompute(fc2_output)
else:
intermediate_parallel = self.activation_func_with_probs(
fc1_output, permuted_probs.unsqueeze(-1)
)
fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
else:
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure params of experts still have gradients even given zero tokens.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
h = torch.matmul(permuted_local_hidden_states, w1)
if self.activation_recompute:
h = self.activation_checkpoint.checkpoint(
self.activation_func_with_probs, h, permuted_probs.unsqueeze(-1)
)
fc2_output = torch.matmul(h, w2)
self.activation_checkpoint.discard_output_and_register_recompute(fc2_output)
else:
h = self.activation_func_with_probs(h, permuted_probs.unsqueeze(-1))
fc2_output = torch.matmul(h, w2)
return fc2_output, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
Maps local expert to global experts.
The sharded_state_dict for the weight parts are compatible with the SequentialMLP,
whereas the optimizer states are not due to the limitation from weight transposing.
That is, for finetuning scenario, the checkpoint is compatible with the SequentialMLP.
When `singleton_local_shards` metadata flag is True, experts are broken down into
separate tensors and stored under separate global keys. Additionally, similarly to MLP,
layers with GLU activations are broken down into separate `w` and `v` tensors.
"""
singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
sharded_state_dict = {}
ep_size = self.ep_group.size()
ep_rank = self.ep_group.rank()
tp_size = self.tp_group.size()
tp_rank = self.tp_group.rank()
dp_rank = self.dp_group.rank()
num_global_experts = ep_size * self.num_local_experts
local_expert_indices_offset = ep_rank * self.num_local_experts
prepend_axis_num = len(sharded_offsets)
replica_id = (0, 0, dp_rank)
local_ffn_dim_size = (
self.weight2.numel() // self.num_local_experts // self.config.hidden_size
)
def _break_into_individual_experts(
experts_ten: torch.Tensor,
key: str,
tp_offset: Tuple[int, int, int],
replica_id: ReplicaId,
):
"""Breaks experts into individual tensors and stores them under separate global keys"""
experts_state = []
assert len(experts_ten) == self.num_local_experts, (
experts_ten.shape,
self.num_local_experts,
)
for local_expert_idx, expert_ten in enumerate(experts_ten):
global_expert_idx = local_expert_indices_offset + local_expert_idx
expert_key = key.replace(
f'{prefix}experts.', f'{prefix}experts.{global_expert_idx}.'
)
experts_state.append(
ShardedTensor.from_rank_offsets(
expert_key,
expert_ten.contiguous(),
*sharded_offsets,
tp_offset,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
)
)
return experts_state
@torch.no_grad()
def sh_ten_build_fn(
key: str,
t: torch.Tensor,
replica_id: ReplicaId,
flattened_range: Optional[slice],
tp_axis: int,
with_glu: bool,
):
# TODO: write a generic implementation to cover both cases with and without GLU
if tp_axis == 1:
# weight1
if with_glu:
last_dim_size = local_ffn_dim_size * 2
else:
last_dim_size = local_ffn_dim_size
real_shape = (self.num_local_experts, self.config.hidden_size, last_dim_size)
elif tp_axis == 0:
# weight2
real_shape = (self.num_local_experts, local_ffn_dim_size, self.config.hidden_size)
assert with_glu == False
else:
raise ValueError("tp_axis should be 0 or 1.")
if flattened_range is None:
# weights
t = t.view(real_shape).transpose(-1, -2)
# change tp_axis due to the transposing
tp_axis = 1 - tp_axis
if with_glu:
assert tp_axis == 0, tp_axis
if singleton_local_shards:
w_tensor, v_tensor = torch.chunk(t, 2, -2)
w_key = f'{key}_w'
v_key = f'{key}_v'
sub_states = {
'singleton_local_shards': LocalNonpersistentObject(True),
'data': {
'w': _break_into_individual_experts(
w_tensor,
w_key,
(prepend_axis_num, tp_rank, tp_size),
replica_id,
),
'v': _break_into_individual_experts(
v_tensor,
v_key,
(prepend_axis_num, tp_rank, tp_size),
replica_id,
),
},
}
else:
local_tensors = torch.chunk(t, 2, -2)
sub_states = [
ShardedTensor.from_rank_offsets(
key,
local_tensors[0].contiguous(),
*sharded_offsets,
(prepend_axis_num, ep_rank, ep_size),
(prepend_axis_num + 1, tp_rank, tp_size * 2),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
ShardedTensor.from_rank_offsets(
key,
local_tensors[1].contiguous(),
*sharded_offsets,
(prepend_axis_num, ep_rank, ep_size),
(prepend_axis_num + 1, tp_size + tp_rank, tp_size * 2),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
]
else:
if singleton_local_shards:
sub_states = {
'singleton_local_shards': LocalNonpersistentObject(True),
'data': _break_into_individual_experts(
t, key, (prepend_axis_num + tp_axis, tp_rank, tp_size), replica_id
),
}
else:
sub_states = ShardedTensor.from_rank_offsets(
key,
t.contiguous(),
*sharded_offsets,
(prepend_axis_num, ep_rank, ep_size),
(prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
)
return sub_states # pylint: disable=possibly-used-before-assignment
@torch.no_grad()
def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool):
if tp_axis == 1:
# weight1
weight_shape = (self.config.hidden_size, -1)
elif tp_axis == 0:
# weight2
weight_shape = (-1, self.config.hidden_size)
assert with_glu == False
else:
raise ValueError("tp_axis should be 0 or 1.")
if isinstance(sub_state_dict, dict):
assert sub_state_dict['singleton_local_shards']
if with_glu:
assert isinstance(sub_state_dict['data'], dict)
sub_state_dict = torch.cat(
(
torch.stack(sub_state_dict['data']['w']),
torch.stack(sub_state_dict['data']['v']),
),
dim=-2,
)
else:
assert isinstance(sub_state_dict['data'], list)
sub_state_dict = torch.stack(sub_state_dict['data'])
else:
if with_glu:
sub_state_dict = torch.cat(sub_state_dict, -2)
return sub_state_dict.transpose(-1, -2).reshape(weight_shape)
state_dict = self.state_dict(prefix='', keep_vars=True)
for name, tensor in state_dict.items():
if name == 'weight1':
tp_axis = 1
with_glu = self.config.gated_linear_unit
wkey = f'{prefix}experts.linear_fc1.weight'
else:
tp_axis = 0
with_glu = False
wkey = f'{prefix}experts.linear_fc2.weight'
this_replica_id = list(copy.deepcopy(replica_id))
sharded_state_dict[f'{prefix}{name}'] = ShardedTensorFactory(
wkey,
tensor,
partial(sh_ten_build_fn, tp_axis=tp_axis, with_glu=with_glu),
partial(sh_ten_merge_fn, tp_axis=tp_axis, with_glu=with_glu),
tuple(this_replica_id),
)
replica_id = (0, tp_rank, dp_rank)
# Add fake _extra_state to be compatible with SequentialMLP
for expert_local_idx in range(self.num_local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
if singleton_local_shards:
expert_sharded_offsets = sharded_offsets
else:
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
for mod in ['linear_fc1', 'linear_fc2']:
if singleton_local_shards:
expert_key = f'{prefix}experts.{expert_global_idx}.{mod}._extra_state'
else:
expert_key = f'{prefix}experts.{mod}._extra_state'
sharded_state_dict[f'{prefix}expert{expert_global_idx}.{mod}._extra_state'] = (
make_sharded_object_for_checkpoint(
None, expert_key, expert_sharded_offsets, replica_id
)
)
return sharded_state_dict
def backward_dw(self):
"""Performs backward pass for weight gradients in Experts.
Empty implementation for compatibility with SequentialMLP and TEGroupedMLP.
"""
pass
class TEGroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using TE's GroupedLinear.
Executes multiple experts in parallel to maximize computational efficiency.
"""
# TODO(M4): breaking api, switched from pass in tp_group to pass in pg_collection.
def __init__(
self,
num_local_experts,
config: TransformerConfig,
submodules: MLPSubmodules,
pg_collection: Optional[ProcessGroupCollection] = None,
):
super().__init__(config=config)
self.num_local_experts = num_local_experts
self.input_size = self.config.hidden_size
assert not (
self.config.add_bias_linear and config.bias_dropout_fusion
), "bias_dropout_fusion is not supported in TEGroupedMLP when add_bias_linear=True"
self.ep_group = pg_collection.ep
self.tp_group = pg_collection.expt_tp
# Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.moe_ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.num_local_experts,
self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=True,
tp_comm_buffer_name='fc1',
pg_collection=pg_collection,
)
if self.config.use_te_activation_func and not (submodules.activation_func is None):
self.activation_func = build_module(submodules.activation_func, config=self.config)
else:
self.activation_func = self.config.activation_func
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.num_local_experts,
self.config.moe_ffn_hidden_size,
(
self.config.hidden_size
if self.config.moe_latent_size is None
else self.config.moe_latent_size
),
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name='fc2',
pg_collection=pg_collection,
)
self.offload_expert_fc1 = (
self.config.fine_grained_activation_offloading
and "expert_fc1" in self.config.offload_modules
)
self.offload_moe_act = (
self.config.fine_grained_activation_offloading
and "moe_act" in self.config.offload_modules
)
self.activation_recompute = (
self.config.recompute_granularity == 'selective'
and "moe_act" in self.config.recompute_modules
)
if self.activation_recompute and (self.config.fp8 or self.config.fp4):
from megatron.core.extensions.transformer_engine import set_save_original_input
set_save_original_input(self.linear_fc2)
# This is to avoid the CPU overhead of multiple d2h copies
if self.offload_expert_fc1:
from megatron.core.extensions.transformer_engine import set_save_original_input
set_save_original_input(self.linear_fc1)
if self.config.fp8 or self.config.fp4:
assert HAVE_TE, "FP8 and FP4 requires TE."
self.quantization_padding = Fp8Padding(self.num_local_experts)
self.quantization_unpadding = Fp8Unpadding(self.num_local_experts)
@staticmethod
def _apply_bias(intermediate_parallel, bias_parallel, tokens_per_expert, permuted_probs):
if bias_parallel is None:
return intermediate_parallel
shape = intermediate_parallel.shape
return (
torch.cat(
[
t + b * p
for t, b, p in zip(
torch.split(intermediate_parallel.view(-1, shape[-1]), tokens_per_expert),
bias_parallel,
torch.split(permuted_probs, tokens_per_expert),
)
]
)
.view(shape)
.to(intermediate_parallel.dtype)
)
def forward(
self,
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward of TEGroupedMLP
Args:
permuted_local_hidden_states (torch.Tensor): The permuted input hidden states of the
local experts.
tokens_per_expert (torch.Tensor): The number of tokens per expert.
permuted_probs (torch.Tensor): The permuted probs of each token produced by the router.
Return:
output (torch.Tensor): The output of the local experts.
"""
tokens_per_expert = tokens_per_expert.tolist()
if self.config.fp8 or self.config.fp4:
actual_tokens_per_expert = tokens_per_expert
permuted_local_hidden_states, tokens_per_expert = self.quantization_padding(
permuted_local_hidden_states, tokens_per_expert
)
permuted_probs, _ = self.quantization_padding(
permuted_probs.unsqueeze(-1), actual_tokens_per_expert
)
else:
permuted_probs = permuted_probs.unsqueeze(-1)
if self.config.moe_apply_probs_on_input:
assert (
self.config.moe_router_topk == 1
), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1."
original_dtype = permuted_local_hidden_states.dtype
permuted_local_hidden_states = permuted_probs * permuted_local_hidden_states
permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype)
# Probs already applied, so reset to 1.
permuted_probs = torch.ones_like(permuted_probs)
with off_interface(
self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1"
) as permuted_local_hidden_states:
fc1_output, bias_parallel = self.linear_fc1(
permuted_local_hidden_states, tokens_per_expert
)
if self.offload_expert_fc1:
fc1_output = off_interface.group_commit(
fc1_output,
name="expert_fc1",
forced_released_tensors=[permuted_local_hidden_states],
)
def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs):
if self.config.use_te_activation_func:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
if permuted_probs is not None:
original_dtype = intermediate_parallel.dtype
intermediate_parallel = intermediate_parallel * permuted_probs
intermediate_parallel = intermediate_parallel.to(original_dtype)
elif self.config.bias_activation_fusion:
if self.activation_func == F.silu and self.config.gated_linear_unit:
# dtype is handled inside the fused kernel
intermediate_parallel = weighted_bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
permuted_probs,
self.config.activation_func_fp8_input_store,
)
elif self.activation_func == quick_gelu and self.config.gated_linear_unit:
intermediate_parallel = weighted_bias_quick_geglu_impl(
intermediate_parallel,
bias_parallel,
permuted_probs,
self.config.activation_func_fp8_input_store,
self.config.glu_linear_offset,
self.config.activation_func_clamp_value,
)
else:
raise ValueError(
"Only support fusion of swiglu and quick_gelu in TEGroupedMLP."
)
elif (
self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu
):
assert bias_parallel is None
intermediate_parallel = weighted_squared_relu_impl(
intermediate_parallel, permuted_probs
)
else:
if self.config.gated_linear_unit:
def glu(x):
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if (val := self.config.activation_func_clamp_value) is not None:
x_glu = x_glu.clamp(min=None, max=val)
x_linear = x_linear.clamp(min=-val, max=val)
return self.config.activation_func(x_glu) * (
x_linear + self.config.glu_linear_offset
)
intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)
original_dtype = intermediate_parallel.dtype
intermediate_parallel = intermediate_parallel * permuted_probs
intermediate_parallel = intermediate_parallel.to(original_dtype)
return intermediate_parallel
if self.activation_recompute:
self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput()
with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output:
bias_act_output = self.activation_checkpoint.checkpoint(
bias_act_func, fc1_output, bias_parallel, permuted_probs
)
else:
with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output:
bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs)
output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert)
if self.activation_recompute:
self.activation_checkpoint.discard_output_and_register_recompute(output)
# Delay the offload of the moe act until after the linear_fc2 has been computed
# to make sure the fc1_output is reloaded to GPU before recomputing moe_act.
if self.offload_moe_act:
output = off_interface.group_commit(
output, name="moe_act", forced_released_tensors=[fc1_output]
)
output = self._apply_bias(output, output_bias, tokens_per_expert, permuted_probs)
# upad and concat the output
if self.config.fp8 or self.config.fp4:
output = self.quantization_unpadding(output, actual_tokens_per_expert)
output_bias = None
return output, output_bias
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Maps local expert to global experts.
The sharded state dict is interchangable with SequentialMLP's.
"""
# Guard for cases metadata is not provided
metadata = ensure_metadata_has_dp_cp_group(metadata)
singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
sharded_state_dict = {}
for name, module in self._modules.items():
sub_sd = sharded_state_dict_default(
module, f'{name}.', sharded_offsets, metadata, tp_group=self.tp_group
)
if name == 'linear_fc1' and self.config.gated_linear_unit:
num_global_experts = self.ep_group.size() * self.num_local_experts
local_expert_indices_offset = self.ep_group.rank() * self.num_local_experts
ep_axis = len(sharded_offsets)
for i in range(self.num_local_experts):
if singleton_local_shards:
new_sharded_offsets = sharded_offsets
else:
new_sharded_offsets = (
*sharded_offsets,
(ep_axis, local_expert_indices_offset + i, num_global_experts),
)
for k in (f'{name}.weight{i}', f'{name}.bias{i}'):
if k in sub_sd:
sub_sd[k] = apply_swiglu_sharded_factory(
sub_sd[k], new_sharded_offsets, singleton_local_shards
)
if singleton_local_shards:
replace_prefix_for_sharding(sub_sd, '', f'{prefix}experts.')
else:
# Add prefix here to match sequential's keys
replace_prefix_for_sharding(sub_sd, f'{name}.', f'{prefix}experts.{name}.')
sharded_state_dict.update({f"{prefix}{k}": v for k, v in sub_sd.items()})
return sharded_state_dict
def backward_dw(self):
"""Performs backward pass for weight gradients in TEGroupedMLP.
This method executes the backward pass for weight gradients by calling
backward_dw() on the linear layers in reverse order (fc2 followed by fc1).
If an error occurs during execution, it is caught and re-raised with a
descriptive message.
"""
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""
# TODO(M4): breaking api, switched from pass in tp_group to pass in pg_collection.
def __init__(
self,
num_local_experts,
config: TransformerConfig,
submodules: MLPSubmodules,
pg_collection: Optional[ProcessGroupCollection] = None,
):
if config.moe_ffn_hidden_size == config.ffn_hidden_size:
super().__init__(config=config)
else:
# Local SequentialMLP can still be used here by overriding the ffn_hidden_size
# with a deepcopied config.
sequential_mlp_config = deepcopy(config)
sequential_mlp_config.ffn_hidden_size = config.moe_ffn_hidden_size
super().__init__(config=sequential_mlp_config)
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
self.ep_group = pg_collection.ep
self.tp_group = pg_collection.expt_tp
# use pg_collection.expt_dp_group as data parallel group in this module.
# TODO (Hepteract): expt_dp wont be needed here once distributed checkpoint is refactored
self.dp_group = pg_collection.expt_dp
for _ in range(self.num_local_experts):
expert = MLP(
self.config,
submodules,
ffn_hidden_size=self.config.moe_ffn_hidden_size,
is_expert=True,
tp_group=pg_collection.expt_tp,
)
self.local_experts.append(expert)
def _pad_tensor_for_quantization(self, hidden, probs):
"""Padding tensor shape to multiples of 16/32."""
actual_num_tokens = hidden.shape[0]
divisor = get_align_size_for_quantization(self.config)
padded_num_tokens = ceil(actual_num_tokens / divisor) * divisor - actual_num_tokens
if padded_num_tokens > 0:
pad_tensor = torch.zeros(
padded_num_tokens, hidden.shape[1], dtype=hidden.dtype, device=hidden.device
)
hidden = torch.cat((hidden, pad_tensor), dim=0)
pad_probs = torch.zeros(padded_num_tokens, dtype=probs.dtype, device=probs.device)
probs = torch.cat((probs, pad_probs), dim=0)
return hidden, probs
def forward(
self,
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
):
"""Forward step of the SequentialMLP."""
if self.config.moe_apply_probs_on_input:
assert (
self.config.moe_router_topk == 1
), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1."
original_dtype = permuted_local_hidden_states.dtype
permuted_local_hidden_states = (
permuted_probs.unsqueeze(-1) * permuted_local_hidden_states
)
permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype)
# Probs already applied, so reset to 1.
permuted_probs = torch.ones_like(permuted_probs)
if self.num_local_experts == 1:
if self.config.fp8 or self.config.fp4:
hidden, probs = self._pad_tensor_for_quantization(
permuted_local_hidden_states, permuted_probs
)
output, output_bias = self.local_experts[0](hidden, probs)
output = output[: permuted_local_hidden_states.shape[0]]
else:
output, output_bias = self.local_experts[0](
permuted_local_hidden_states, permuted_probs
)
return output, output_bias
else:
tokens_per_expert = tokens_per_expert.tolist()
tokens_list = torch.split(permuted_local_hidden_states, tokens_per_expert)
probs_list = torch.split(permuted_probs, tokens_per_expert)
output_local_list = []
for expert, tokens, probs in zip(self.local_experts, tokens_list, probs_list):
if self.config.fp8 or self.config.fp4:
hidden, probs = self._pad_tensor_for_quantization(tokens, probs)
output, output_bias = expert(hidden, probs)
output = output[: tokens.shape[0]]
else:
output, output_bias = expert(tokens, probs)
output_local_list.append(output)
output_local = torch.cat(output_local_list, dim=0)
output_bias_local = None
# Note: if bias is enabled on experts, it is already added to the output at this point
return output_local, output_bias_local
def backward_dw(self):
"""Backward pass for weight gradients in SequentialMLP."""
for expert in self.local_experts:
expert.backward_dw()
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Maps local expert to global experts."""
# Guard for cases metadata is not provided
metadata = ensure_metadata_has_dp_cp_group(metadata)
sharded_state_dict = {}
num_global_experts = self.ep_group.size() * self.num_local_experts
local_expert_indices_offset = self.ep_group.rank() * self.num_local_experts
singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
for expert_local_idx, expert in enumerate(self.local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
expert_state_dict_prefix = f'{prefix}local_experts.{expert_local_idx}.'
if singleton_local_shards:
expert_sharded_prefix = f'{prefix}experts.{expert_global_idx}.'
expert_sharded_offsets = sharded_offsets
else:
expert_sharded_prefix = f'{prefix}experts.'
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
expert_state_dict = expert.sharded_state_dict(
expert_state_dict_prefix, expert_sharded_offsets, metadata
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(
expert_state_dict, expert_state_dict_prefix, expert_sharded_prefix
)
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in expert_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (*replica_id[:2], self.dp_group.rank())