-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathdeepseek_v2.py
More file actions
1917 lines (1756 loc) · 68.5 KB
/
deepseek_v2.py
File metadata and controls
1917 lines (1756 loc) · 68.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model."""
import logging
from typing import Optional, Tuple, Union
import torch
from aiter import (
QuantType,
cp_gather_indexer_k_quant_cache,
dtypes,
gemm_a8w8_blockscale_bpreshuffle,
get_hip_quant,
indexer_k_quant_and_cache,
top_k_per_row_decode,
top_k_per_row_prefill,
)
from aiter.dist.communication_op import tensor_model_parallel_all_reduce
from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
from aiter.ops.triton.fused_fp8_quant import (
fused_reduce_rms_fp8_group_quant,
fused_rms_fp8_group_quant,
)
from aiter.ops.triton.fused_mxfp4_quant import (
fused_reduce_rms_mxfp4_quant,
fused_rms_mxfp4_quant,
)
from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits
from aiter.rotary_embedding import get_rope
from atom.config import (
CompilationLevel,
Config,
QuantizationConfig,
get_current_atom_config,
)
from atom.model_ops.activation import SiluAndMul
from atom.model_ops.attention_mla import MLAModules, is_rocm_aiter_fp4bmm_enabled
from atom.model_ops.base_attention import Attention
from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
from atom.model_ops.layernorm import LayerNorm, RMSNorm
from atom.model_ops.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear,
use_triton_gemm,
)
from atom.model_ops.moe import FusedMoE
from atom.model_ops.topK import (
is_rocm_aiter_fuse_routed_scaling_factor,
is_rocm_aiter_fusion_shared_expert_enabled,
)
from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE, _has_module
from atom.models.utils import (
IntermediateTensors,
PPMissingLayer,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
should_ignore_layer,
)
from atom.utils import envs
from atom.utils.custom_register import direct_register_custom_op
from atom.utils.decorators import support_torch_compile
from atom.utils.forward_context import get_forward_context
from torch import nn
from transformers import PretrainedConfig
# from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8
logger = logging.getLogger("atom")

# Optional Triton preshuffled GEMM kernels used by the fused QKV-A projection
# helpers in this module. Bind the names up front so they always exist: they
# stay ``None`` when Triton GEMMs are disabled (``use_triton_gemm()`` is False)
# or when the installed AITER does not provide them.
gemm_afp4wfp4_preshuffle = None
gemm_a16wfp4_preshuffle = None
gemm_a8w8_blockscale_preshuffle = None
gemm_a16w8_blockscale_preshuffle = None
if use_triton_gemm():
    try:
        from aiter.ops.triton.gemm_a8w8_blockscale import (
            gemm_a8w8_blockscale_preshuffle,
        )
        from aiter.ops.triton.gemm_a16w8_blockscale import (
            gemm_a16w8_blockscale_preshuffle,
        )
        from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
    except ImportError as e:
        logger.warning(
            "Triton GEMM kernels not available: %s. Ensure AITER is up-to-date.", e
        )
        # A partial import may have bound some names; reset all of them so the
        # kernels are either all available or all ``None``.
        gemm_afp4wfp4_preshuffle = None
        gemm_a16wfp4_preshuffle = None
        gemm_a8w8_blockscale_preshuffle = None
        gemm_a16w8_blockscale_preshuffle = None

# Feature flags read once from the environment at import time.
ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION
ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION
ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION
def _fuse_rmsnorm_fp4_quant_fake(
    x1: torch.Tensor,
    x1_weight: torch.Tensor,
    x1_epsilon: float,
    x2: Optional[torch.Tensor] = None,
    x2_weight: Optional[torch.Tensor] = None,
    x2_epsilon: Optional[float] = None,
    res1: Optional[torch.Tensor] = None,
    shuffle: bool = True,
    scale_shuffle_padding: bool = True,
    output_unquantized_inp1: bool = False,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Fake (shape/dtype-only) counterpart of ``_fuse_rmsnorm_fp4_quant``.

    Allocates uninitialised outputs for tracing:
      * packed MXFP4 ``x1``: ``(rows, cols // 2)`` uint8;
      * block scales: uint8, rows padded up to a multiple of 256 and
        ``ceil(cols / MXFP4_QUANT_BLOCK_SIZE)`` padded up to a multiple of 8;
      * ``None`` for the unquantized copy (the real wrapper never returns it);
      * ``(rows, x2.shape[1])`` in ``x1.dtype`` when ``x2`` is given, else None;
      * ``(rows, cols)`` in ``x1.dtype`` when ``res1`` is given, else None.
    """

    def _round_up(value: int, multiple: int) -> int:
        return ((value + multiple - 1) // multiple) * multiple

    rows, cols = x1.shape
    dev = x1.device

    packed = torch.empty((rows, cols // 2), dtype=torch.uint8, device=dev)
    scale_cols = _round_up(cols, MXFP4_QUANT_BLOCK_SIZE) // MXFP4_QUANT_BLOCK_SIZE
    scales = torch.empty(
        (_round_up(rows, 256), _round_up(scale_cols, 8)),
        dtype=torch.uint8,
        device=dev,
    )
    normed_x2 = (
        None
        if x2 is None
        else torch.empty((rows, x2.shape[1]), dtype=x1.dtype, device=dev)
    )
    residual_out = (
        None
        if res1 is None
        else torch.empty((rows, cols), dtype=x1.dtype, device=dev)
    )
    return packed, scales, None, normed_x2, residual_out
def _fused_rms_fp8_group_quant_fake(
    x1: torch.Tensor,
    x1_weight: torch.Tensor,
    x1_epsilon: float,
    x2: Optional[torch.Tensor] = None,
    x2_weight: Optional[torch.Tensor] = None,
    x2_epsilon: Optional[float] = None,
    res1: Optional[torch.Tensor] = None,
    dtype_quant: torch.dtype = dtypes.fp8,
    group_size: int = 128,
    output_unquantized_inp1: bool = False,
    transpose_scale: bool = False,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Fake (shape/dtype-only) counterpart of ``_fused_rms_fp8_group_quant``.

    Allocates uninitialised outputs for tracing:
      * quantized ``x1``: ``(rows, cols)`` in ``dtype_quant``;
      * per-group scales: ``(rows, ceil(cols / group_size))`` float32;
      * unquantized ``x1`` copy (``empty_like(x1)``) iff
        ``output_unquantized_inp1``, else None;
      * ``(rows, x2 cols)`` in ``x1.dtype`` when ``x2`` is given, else None;
      * ``(rows, cols)`` in ``x1.dtype`` when ``res1`` is given, else None.
    """
    rows, cols = x1.shape
    dev = x1.device
    n_groups = (cols + group_size - 1) // group_size

    quantized = torch.empty((rows, cols), dtype=dtype_quant, device=dev)
    scales = torch.empty((rows, n_groups), dtype=torch.float32, device=dev)
    if transpose_scale:
        # Route through a transposed contiguous copy, then view it back as
        # (rows, n_groups).
        scales = scales.transpose(0, 1).contiguous().view(rows, n_groups)

    unquantized = torch.empty_like(x1) if output_unquantized_inp1 else None

    if x2 is None:
        normed_x2 = None
    else:
        _, x2_cols = x2.shape
        normed_x2 = torch.empty((rows, x2_cols), dtype=x1.dtype, device=dev)

    if res1 is None:
        residual_out = None
    else:
        residual_out = torch.empty((rows, cols), dtype=x1.dtype, device=dev)

    return quantized, scales, unquantized, normed_x2, residual_out
@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake)
def _fuse_rmsnorm_fp4_quant(
    x1: torch.Tensor,
    x1_weight: torch.Tensor,
    x1_epsilon: float,
    x2: Optional[torch.Tensor] = None,
    x2_weight: Optional[torch.Tensor] = None,
    x2_epsilon: Optional[float] = None,
    res1: Optional[torch.Tensor] = None,
    shuffle: bool = True,
    scale_shuffle_padding: bool = True,
    output_unquantized_inp1: bool = False,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """RMSNorm ``x1`` (optionally ``x2`` and a residual) and quantize ``x1`` to MXFP4.

    Thin wrapper over AITER's ``fused_rms_mxfp4_quant``. Returns
    ``(packed x1, x1 scales, None, x2 output, residual output)``. The third
    slot is always ``None``: any unquantized copy produced by the kernel is
    discarded.
    """
    # A shuffled layout is requested only when there are at least
    # MXFP4_QUANT_BLOCK_SIZE rows.
    use_shuffle = shuffle and (x1.shape[0] >= MXFP4_QUANT_BLOCK_SIZE)
    eps2 = x2_epsilon if x2_epsilon is not None else 0.0

    quant_pair, _, normed_x2, residual_out = fused_rms_mxfp4_quant(
        x1=x1,
        x1_weight=x1_weight,
        x1_epsilon=x1_epsilon,
        x2=x2,
        x2_weight=x2_weight,
        x2_epsilon=eps2,
        res1=res1,
        shuffle=use_shuffle,
        scale_shuffle_padding=scale_shuffle_padding,
        output_unquantized_inp1=output_unquantized_inp1,
    )
    packed, scales = quant_pair
    return packed, scales, None, normed_x2, residual_out
@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake)
def _fused_rms_fp8_group_quant(
    x1: torch.Tensor,
    x1_weight: torch.Tensor,
    x1_epsilon: float,
    x2: Optional[torch.Tensor] = None,
    x2_weight: Optional[torch.Tensor] = None,
    x2_epsilon: Optional[float] = None,
    res1: Optional[torch.Tensor] = None,
    dtype_quant: torch.dtype = dtypes.fp8,
    group_size: int = 128,
    output_unquantized_inp1: bool = False,
    transpose_scale: bool = False,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """RMSNorm ``x1`` (optionally ``x2`` and a residual) and group-quantize ``x1`` to FP8.

    Thin wrapper over AITER's ``fused_rms_fp8_group_quant`` that flattens its
    nested result into ``(quantized x1, x1 scales, unquantized x1, x2 output,
    residual output)``.
    """
    # Arguments are forwarded positionally. Note that group_size is passed
    # before dtype_quant here, the reverse of this wrapper's own parameter order.
    kernel_out = fused_rms_fp8_group_quant(
        x1,
        x1_weight,
        x1_epsilon,
        x2,
        x2_weight,
        x2_epsilon,
        group_size,
        dtype_quant,
        res1,
        output_unquantized_inp1,
        transpose_scale,
    )
    (quantized, scales), unquantized, normed_x2, residual_out = kernel_out
    return quantized, scales, unquantized, normed_x2, residual_out
def _fuse_rmsnorm_quant(
    x1: torch.Tensor,
    x1_weight: torch.Tensor,
    x1_epsilon: float,
    x2: Optional[torch.Tensor] = None,
    x2_weight: Optional[torch.Tensor] = None,
    x2_epsilon: Optional[float] = None,
    res1: Optional[torch.Tensor] = None,
    dtype_quant: torch.dtype = dtypes.fp8,
    shuffle: bool = True,
    scale_shuffle_padding: bool = False,
    group_size: int = 128,
    output_unquantized_inp1: bool = False,
    transpose_scale: bool = False,
):
    """Dispatch fused RMSNorm + quantization to the kernel for ``dtype_quant``.

    ``dtypes.fp4x2`` uses the MXFP4 path (``shuffle`` / ``scale_shuffle_padding``
    apply; ``group_size`` / ``transpose_scale`` are ignored). ``dtypes.fp8``
    uses the FP8 group-quant path (``group_size`` / ``transpose_scale`` apply;
    ``shuffle`` / ``scale_shuffle_padding`` are ignored).

    Returns:
        ``((quantized x1, x1 scales), unquantized x1, x2 output, residual output)``.

    Raises:
        ValueError: for any other ``dtype_quant``.
    """
    if dtype_quant == dtypes.fp4x2:
        outputs = _fuse_rmsnorm_fp4_quant(
            x1,
            x1_weight,
            x1_epsilon,
            x2,
            x2_weight,
            x2_epsilon,
            res1,
            shuffle,
            scale_shuffle_padding,
            output_unquantized_inp1,
        )
    elif dtype_quant == dtypes.fp8:
        outputs = _fused_rms_fp8_group_quant(
            x1,
            x1_weight,
            x1_epsilon,
            x2,
            x2_weight,
            x2_epsilon,
            res1,
            dtype_quant,
            group_size,
            output_unquantized_inp1,
            transpose_scale,
        )
    else:
        raise ValueError(
            f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}."
        )
    quantized, scales, unquantized, normed_x2, residual_out = outputs
    return (quantized, scales), unquantized, normed_x2, residual_out
def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake(
    hidden_states_quant: torch.Tensor,
    weight_qkv_a_proj: torch.Tensor,
    weight_scale_qkv_a_proj: torch.Tensor,
    q_a_layernorm_weight: torch.Tensor,
    q_a_layernorm_variance_epsilon: float,
    kv_a_layernorm_weight: torch.Tensor,
    kv_a_layernorm_variance_epsilon: float,
    q_lora_rank: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    hidden_states_quant_scale: Optional[torch.Tensor] = None,
    shuffle: Optional[bool] = True,
    scale_shuffle_padding: Optional[bool] = True,
    output_unquantized_inp1: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (shape/dtype-only) counterpart of ``_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4``.

    Only the row count and device of ``hidden_states_quant`` are read. Returns
    uninitialised tensors:
      * ``q_c``: ``(M, q_lora_rank // 2)`` uint8 (packed MXFP4);
      * ``q_c_scale``: uint8, ``M`` padded up to a multiple of 256 by
        ``ceil(q_lora_rank / MXFP4_QUANT_BLOCK_SIZE)`` padded up to a multiple of 8;
      * ``kv_c_normed``: ``(M, kv_lora_rank)`` bfloat16;
      * ``k_pe``: ``(M, qk_rope_head_dim)`` bfloat16.
    """
    num_rows = hidden_states_quant.shape[0]
    dev = hidden_states_quant.device

    packed_q = torch.empty((num_rows, q_lora_rank // 2), dtype=torch.uint8, device=dev)

    valid_scale_cols = (
        q_lora_rank + MXFP4_QUANT_BLOCK_SIZE - 1
    ) // MXFP4_QUANT_BLOCK_SIZE
    padded_rows = ((num_rows + 255) // 256) * 256
    padded_cols = ((valid_scale_cols + 7) // 8) * 8
    q_scale = torch.empty((padded_rows, padded_cols), dtype=torch.uint8, device=dev)

    kv_normed = torch.empty((num_rows, kv_lora_rank), dtype=torch.bfloat16, device=dev)

    # k_pe is a column slice of a wider (M, q + kv + rope) buffer, so its row
    # stride matches the slice the real op takes from the fused projection output.
    fused_width = q_lora_rank + kv_lora_rank + qk_rope_head_dim
    wide = torch.empty((num_rows, fused_width), dtype=torch.bfloat16, device=dev)
    rope_view = wide[..., :qk_rope_head_dim]

    return packed_q, q_scale, kv_normed, rope_view
def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8_fake(
    hidden_states_quant: torch.Tensor,
    weight_qkv_a_proj: torch.Tensor,
    weight_scale_qkv_a_proj: torch.Tensor,
    q_a_layernorm_weight: torch.Tensor,
    q_a_layernorm_variance_epsilon: float,
    kv_a_layernorm_weight: torch.Tensor,
    kv_a_layernorm_variance_epsilon: float,
    q_lora_rank: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    hidden_states_quant_scale: Optional[torch.Tensor] = None,
    output_unquantized_inp1: Optional[bool] = False,
    transpose_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (shape/dtype-only) counterpart of ``_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8``.

    Only the row count ``M`` and device of ``hidden_states_quant`` are read;
    the other arguments mirror the real op's signature. Returns uninitialised
    tensors:
      * ``q_c``: ``(M, q_lora_rank)`` in ``dtypes.fp8``;
      * ``q_c_scale``: ``(M, ceil(q_lora_rank / 128))`` float32 per-group scales;
      * ``kv_c_normed``: ``(M, kv_lora_rank)`` bfloat16;
      * ``k_pe``: ``(M, qk_rope_head_dim)`` bfloat16, a column slice of a wider
        ``(M, q_lora_rank + kv_lora_rank + qk_rope_head_dim)`` buffer.
    """
    M = hidden_states_quant.shape[0]
    FP8_QUANT_BLOCK_SIZE = 128  # FP8 quantization group size along q_lora_rank
    device = hidden_states_quant.device
    q_c = torch.empty((M, q_lora_rank), dtype=dtypes.fp8, device=device)
    scale_n = (q_lora_rank + FP8_QUANT_BLOCK_SIZE - 1) // FP8_QUANT_BLOCK_SIZE
    # Per-group FP8 scales are float32, matching _fused_rms_fp8_group_quant_fake
    # and the block-scale GEMM inputs. An fp8 scale tensor here would give the
    # fake a different dtype from the real op's output.
    q_c_scale = torch.empty((M, scale_n), dtype=torch.float32, device=device)
    kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device)
    # Slice of a wider buffer so the row stride matches the k_pe column slice
    # the real op returns.
    k_pe = torch.empty(
        (M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
        dtype=torch.bfloat16,
        device=device,
    )[..., :qk_rope_head_dim]
    return q_c, q_c_scale, kv_c_normed, k_pe
@torch_compile_guard(
    gen_fake=_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake, mutates_args=[]
)
def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4(
    hidden_states_quant: torch.Tensor,
    weight_qkv_a_proj: torch.Tensor,
    weight_scale_qkv_a_proj: torch.Tensor,
    q_a_layernorm_weight: torch.Tensor,
    q_a_layernorm_variance_epsilon: float,
    kv_a_layernorm_weight: torch.Tensor,
    kv_a_layernorm_variance_epsilon: float,
    q_lora_rank: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    hidden_states_quant_scale: Optional[torch.Tensor] = None,
    shuffle: Optional[bool] = True,
    scale_shuffle_padding: Optional[bool] = True,
    output_unquantized_inp1: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused MXFP4 q/kv "A" projection, RMSNorm, and MXFP4 quantization of ``q_c``.

    Runs one GEMM against the concatenated MXFP4 ``weight_qkv_a_proj``, splits
    the output along the last dim into ``q_c`` / ``kv_c`` / ``k_pe`` of widths
    ``q_lora_rank`` / ``kv_lora_rank`` / ``qk_rope_head_dim``, then calls
    ``fused_reduce_rms_mxfp4_quant`` with the q_a / kv_a layernorm weights.
    That call produces MXFP4 ``q_c`` plus scales and the normalised ``kv_c``.

    Args:
        hidden_states_quant: Activations. If ``hidden_states_quant_scale`` is
            None these are unquantized: for ``M <= MXFP4_QUANT_BLOCK_SIZE``
            they go straight to ``gemm_a16wfp4_preshuffle``; otherwise they
            are first quantized per 1x32 to ``dtypes.fp4x2``. If a scale is
            given they are treated as already-quantized MXFP4 data.
        hidden_states_quant_scale: MXFP4 block scales for pre-quantized input.
        shuffle / scale_shuffle_padding: forwarded to the norm+quant kernel;
            shuffle is only honoured when ``M >= MXFP4_QUANT_BLOCK_SIZE``.

    Returns:
        ``(q_c, q_c_scale, kv_c_normed, k_pe)``.

    NOTE(review): the Triton ``gemm_*_preshuffle`` kernels are module globals
    that may be ``None`` or unbound when Triton GEMMs are disabled or
    unavailable; this function does not check.
    """
    M = hidden_states_quant.shape[0]
    if hidden_states_quant_scale is None:
        if M <= MXFP4_QUANT_BLOCK_SIZE:
            # Small M: a16wfp4 GEMM on the unquantized activations
            # (prequant=True is forwarded to the kernel).
            # Weights are re-viewed as (rows // 16, -1) uint8 — presumably the
            # preshuffled layout the Triton kernel expects; confirm.
            qkv_lora = gemm_a16wfp4_preshuffle(
                hidden_states_quant,
                weight_qkv_a_proj.view(torch.uint8).view(
                    weight_qkv_a_proj.shape[0] // 16, -1
                ),
                weight_scale_qkv_a_proj.view(torch.uint8).view(
                    weight_scale_qkv_a_proj.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1
                ),
                prequant=True,
                skip_reduce=True,
            )
        else:
            # Larger M: quantize activations per 1x32 to MXFP4, then fp4 x fp4 GEMM.
            quant_func = get_hip_quant(QuantType.per_1x32)
            x, x_scale = quant_func(
                hidden_states_quant,
                quant_dtype=dtypes.fp4x2,
                shuffle=(M >= MXFP4_QUANT_BLOCK_SIZE),
            )
            # NOTE(review): this branch only runs when M > MXFP4_QUANT_BLOCK_SIZE,
            # so shuffle above is always True and the `else` below is unreachable.
            if M >= MXFP4_QUANT_BLOCK_SIZE:
                x_scale = x_scale.view(torch.uint8).view(
                    x_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1
                )
            else:
                x_scale = x_scale[:M, ...].view(torch.uint8)
            qkv_lora = gemm_afp4wfp4_preshuffle(
                x.view(torch.uint8),
                weight_qkv_a_proj.view(torch.uint8).view(
                    weight_qkv_a_proj.shape[0] // 16, -1
                ),
                x_scale,
                weight_scale_qkv_a_proj.view(torch.uint8).view(
                    weight_scale_qkv_a_proj.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1
                ),
                skip_reduce=True,
            )
    else:
        # Pre-quantized MXFP4 input. Shuffled scales are re-viewed the same way
        # as the weight scales; unshuffled scales are trimmed to the M valid rows
        # (the buffer may be row-padded).
        if M >= MXFP4_QUANT_BLOCK_SIZE:
            hidden_states_quant_scale = hidden_states_quant_scale.view(
                torch.uint8
            ).view(hidden_states_quant_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1)
        else:
            hidden_states_quant_scale = hidden_states_quant_scale[:M, ...].view(
                torch.uint8
            )
        qkv_lora = gemm_afp4wfp4_preshuffle(
            hidden_states_quant.view(torch.uint8),
            weight_qkv_a_proj.view(torch.uint8).view(
                weight_qkv_a_proj.shape[0] // 16, -1
            ),
            hidden_states_quant_scale,
            weight_scale_qkv_a_proj.view(torch.uint8).view(
                weight_scale_qkv_a_proj.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1
            ),
            skip_reduce=True,
        )
    # Split the fused projection output into the q, kv and rope parts.
    q_c, kv_c, k_pe = torch.split(
        qkv_lora,
        [q_lora_rank, kv_lora_rank, qk_rope_head_dim],
        dim=-1,
    )
    shuffle_bool = shuffle and (M >= MXFP4_QUANT_BLOCK_SIZE)
    k_pe_reduced = None
    k_pe_reduced_out = None
    # With skip_reduce=True the GEMM output may be 3-D (presumably un-reduced
    # split-K partials — confirm). In that case k_pe is handed to the fused
    # kernel to be reduced into a fresh 2-D bf16 buffer. That buffer is a
    # column slice of a (M, q + kv + rope) tensor, matching the fake's layout.
    if k_pe.dim() == 3:
        device = hidden_states_quant.device
        k_pe_reduced = k_pe
        k_pe_reduced_out = torch.empty(
            (M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
            dtype=torch.bfloat16,
            device=device,
        )[..., :qk_rope_head_dim]
    (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = (
        fused_reduce_rms_mxfp4_quant(
            q_c,
            q_a_layernorm_weight,
            q_a_layernorm_variance_epsilon,
            kv_c,
            kv_a_layernorm_weight,
            kv_a_layernorm_variance_epsilon,
            k_pe_reduced,
            res1=None,
            shuffle=shuffle_bool,
            scale_shuffle_padding=scale_shuffle_padding,
            output_unquantized_inp1=output_unquantized_inp1,
            dtype=torch.bfloat16,
            out3=k_pe_reduced_out,
        )
    )
    # Use the reduced k_pe only when the kernel produced one.
    if k_pe_reduced_out is not None:
        k_pe = k_pe_reduced_out
    return q_c, q_c_scale, kv_c_normed, k_pe
@torch_compile_guard(
    gen_fake=_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8_fake, mutates_args=[]
)
def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8(
    hidden_states_quant: torch.Tensor,
    weight_qkv_a_proj: torch.Tensor,
    weight_scale_qkv_a_proj: torch.Tensor,
    q_a_layernorm_weight: torch.Tensor,
    q_a_layernorm_variance_epsilon: float,
    kv_a_layernorm_weight: torch.Tensor,
    kv_a_layernorm_variance_epsilon: float,
    q_lora_rank: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    hidden_states_quant_scale: Optional[torch.Tensor] = None,
    output_unquantized_inp1: Optional[bool] = False,
    transpose_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused FP8 block-scale q/kv "A" projection, RMSNorm, and FP8 quantization of ``q_c``.

    Runs one GEMM against the concatenated FP8 ``weight_qkv_a_proj``, splits
    the output along the last dim into ``q_c`` / ``kv_c`` / ``k_pe`` of widths
    ``q_lora_rank`` / ``kv_lora_rank`` / ``qk_rope_head_dim``, then calls
    ``fused_reduce_rms_fp8_group_quant`` with the q_a / kv_a layernorm weights.
    That call produces FP8 ``q_c`` plus group scales and the normalised ``kv_c``.

    GEMM selection by row count ``M``:
      * unquantized input, ``M <= 32``: Triton a16w8 GEMM on the raw activations;
      * unquantized input, ``M > 32``: per-1x128 FP8 quantization first, then the
        Triton a8w8 preshuffle GEMM (``M <= 128``) or AITER's
        ``gemm_a8w8_blockscale_bpreshuffle`` (``M > 128``);
      * pre-quantized input (``hidden_states_quant_scale`` given): the same
        a8w8 split at ``M <= 128``.

    Returns:
        ``(q_c, q_c_scale, kv_c_normed, k_pe)``.

    NOTE(review): the Triton ``gemm_*_preshuffle`` kernels are module globals
    that may be ``None`` or unbound when Triton GEMMs are disabled or
    unavailable; this function does not check.
    """
    M = hidden_states_quant.shape[0]
    if hidden_states_quant_scale is None:
        if M <= 32:
            # Very small M: skip activation quantization entirely.
            qkv_lora = gemm_a16w8_blockscale_preshuffle(
                hidden_states_quant,
                weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1),
                weight_scale_qkv_a_proj,
                prequant=False,
                skip_reduce=True,
            )
        else:
            # Quantize activations per 1x128 group to FP8 before the a8w8 GEMM.
            quant_func = get_hip_quant(QuantType.per_1x128)
            x, x_scale = quant_func(
                hidden_states_quant,
                quant_dtype=dtypes.fp8,
                transpose_scale=transpose_scale,
            )
            if M <= 128:
                # Triton kernel: weights re-viewed as (rows // 16, -1) —
                # presumably its preshuffled layout; confirm.
                qkv_lora = gemm_a8w8_blockscale_preshuffle(
                    x,
                    weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1),
                    x_scale,
                    weight_scale_qkv_a_proj,
                    skip_reduce=True,
                )
            else:
                # AITER kernel: takes the weight as-is and an explicit bf16 output dtype.
                qkv_lora = gemm_a8w8_blockscale_bpreshuffle(
                    x,
                    weight_qkv_a_proj,
                    x_scale,
                    weight_scale_qkv_a_proj,
                    torch.bfloat16,
                )
    else:
        # Pre-quantized FP8 input with caller-provided scales.
        if M <= 128:
            qkv_lora = gemm_a8w8_blockscale_preshuffle(
                hidden_states_quant,
                weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1),
                hidden_states_quant_scale,
                weight_scale_qkv_a_proj,
                skip_reduce=True,
            )
        else:
            qkv_lora = gemm_a8w8_blockscale_bpreshuffle(
                hidden_states_quant,
                weight_qkv_a_proj,
                hidden_states_quant_scale,
                weight_scale_qkv_a_proj,
                torch.bfloat16,
            )
    # Split the fused projection output into the q, kv and rope parts.
    q_c, kv_c, k_pe = torch.split(
        qkv_lora,
        [q_lora_rank, kv_lora_rank, qk_rope_head_dim],
        dim=-1,
    )
    k_pe_reduced = None
    k_pe_reduced_out = None
    # The skip_reduce=True GEMMs may return a 3-D output (presumably un-reduced
    # split-K partials — confirm). In that case k_pe is reduced by the fused
    # kernel into a fresh 2-D bf16 buffer. That buffer is a column slice of a
    # (M, q + kv + rope) tensor, matching the fake's layout.
    if k_pe.dim() == 3:
        device = hidden_states_quant.device
        k_pe_reduced = k_pe
        k_pe_reduced_out = torch.empty(
            (M, q_lora_rank + kv_lora_rank + qk_rope_head_dim),
            dtype=torch.bfloat16,
            device=device,
        )[..., :qk_rope_head_dim]
    (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = (
        fused_reduce_rms_fp8_group_quant(
            q_c,
            q_a_layernorm_weight,
            q_a_layernorm_variance_epsilon,
            kv_c,
            kv_a_layernorm_weight,
            kv_a_layernorm_variance_epsilon,
            k_pe_reduced,
            res1=None,
            output_unquantized_inp1=output_unquantized_inp1,
            dtype=torch.bfloat16,
            out3=k_pe_reduced_out,
            transpose_scale=transpose_scale,
        )
    )
    # Use the reduced k_pe only when the kernel produced one.
    if k_pe_reduced_out is not None:
        k_pe = k_pe_reduced_out
    return q_c, q_c_scale, kv_c_normed, k_pe
def _fuse_qkv_a_proj_reduce_rmsnorm_quant(
    hidden_states_quant: torch.Tensor,
    weight_qkv_a_proj: torch.Tensor,
    weight_scale_qkv_a_proj: torch.Tensor,
    q_a_layernorm_weight: torch.Tensor,
    q_a_layernorm_variance_epsilon: float,
    kv_a_layernorm_weight: torch.Tensor,
    kv_a_layernorm_variance_epsilon: float,
    q_lora_rank: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    dtype_quant=dtypes.fp8,
    hidden_states_quant_scale: Optional[torch.Tensor] = None,
    shuffle: Optional[bool] = False,
    scale_shuffle_padding: Optional[bool] = False,
    group_size: Optional[int] = 128,
    output_unquantized_inp1: Optional[bool] = False,
    transpose_scale: Optional[bool] = False,
):
    """Dispatch the fused QKV-A projection + RMSNorm + quant to the ``dtype_quant`` kernel.

    ``dtypes.fp4x2`` forwards ``shuffle`` / ``scale_shuffle_padding``;
    ``dtypes.fp8`` forwards ``transpose_scale``. ``group_size`` is accepted
    but not forwarded to either kernel.

    Returns:
        ``(q_c, q_c_scale, kv_c_normed, k_pe)``.

    Raises:
        ValueError: for any other ``dtype_quant``.
    """
    # Positional arguments shared by both kernel wrappers, in their order.
    shared_args = (
        hidden_states_quant,
        weight_qkv_a_proj,
        weight_scale_qkv_a_proj,
        q_a_layernorm_weight,
        q_a_layernorm_variance_epsilon,
        kv_a_layernorm_weight,
        kv_a_layernorm_variance_epsilon,
        q_lora_rank,
        kv_lora_rank,
        qk_rope_head_dim,
        hidden_states_quant_scale,
    )
    if dtype_quant == dtypes.fp4x2:
        q_c, q_c_scale, kv_c_normed, k_pe = _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4(
            *shared_args,
            shuffle,
            scale_shuffle_padding,
            output_unquantized_inp1,
        )
    elif dtype_quant == dtypes.fp8:
        q_c, q_c_scale, kv_c_normed, k_pe = _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8(
            *shared_args,
            output_unquantized_inp1,
            transpose_scale,
        )
    else:
        raise ValueError(
            f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}."
        )
    return q_c, q_c_scale, kv_c_normed, k_pe
class DeepseekV2MLP(nn.Module):
    """SiLU-gated dense MLP: merged gate/up projection, SiluAndMul, down projection.

    ``gate_up_proj`` is a ``MergedColumnParallelLinear`` producing two
    ``intermediate_size`` halves. ``down_proj`` is a ``RowParallelLinear``
    that receives ``reduce_results``. Only ``hidden_act == "silu"`` is
    accepted; anything else raises ``ValueError`` after the projections
    are built.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x):
        """Apply gate/up projection, gated SiLU, then down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2/V3 mixture-of-experts feed-forward layer.

    Combines a replicated router (``gate``) and the routed experts
    (``FusedMoE``). When ``config.n_shared_experts`` is set and
    shared-expert fusion is disabled, it adds a dense ``DeepseekV2MLP`` of
    shared experts.

    For non-float16 inputs the routed output is multiplied by
    ``routed_scaling_factor``, unless ``is_rocm_aiter_fuse_routed_scaling_factor()``
    reports the scaling as fused. For float16 inputs the routed output is left
    unscaled and the shared output is divided by the factor instead, to avoid
    fp16 overflow.

    With tensor parallelism active (``tp_size > 1``), the result is
    all-reduced iff ``reduce_results`` is true, on both forward paths.

    For 1..``_DUAL_STREAM_TOKEN_THRESHOLD`` tokens, and only when enabled in
    ``__init__`` (mori installed, no shared-expert fusion, compilation level
    not PIECEWISE), shared and routed experts run on two CUDA streams.
    """

    # Using a single shared stream avoids exhausting GPU/HSA resources
    _shared_alt_stream: Optional[torch.cuda.Stream] = None
    # Maximum token count for which the dual-stream path is taken.
    _DUAL_STREAM_TOKEN_THRESHOLD = 1024

    @staticmethod
    def _get_shared_stream() -> torch.cuda.Stream:
        """Return the process-wide alternate CUDA stream, creating it lazily."""
        if DeepseekV2MoE._shared_alt_stream is None:
            DeepseekV2MoE._shared_alt_stream = torch.cuda.Stream()
        return DeepseekV2MoE._shared_alt_stream

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.reduce_results = reduce_results
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            # The MoE gate normally stays unquantized, but checkpoints may not
            # list it among the quantization config's ignored layers, so force
            # quant_config=None here.
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        if config.topk_method == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts)
            )
        else:
            self.gate.e_score_correction_bias = None
        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            config=config,
        )
        # Dual-stream support: when mori is enabled,
        # parallelize shared expert and routed expert computation
        self._use_dual_stream = False
        self.alt_stream: Optional[torch.cuda.Stream] = None
        if config.n_shared_experts is not None:
            if (
                not is_rocm_aiter_fusion_shared_expert_enabled()
                and _has_module("mori")
                and get_current_atom_config().compilation_config.level
                != CompilationLevel.PIECEWISE
            ):
                self._use_dual_stream = True
                self.alt_stream = DeepseekV2MoE._get_shared_stream()
            if not is_rocm_aiter_fusion_shared_expert_enabled():
                intermediate_size = (
                    config.moe_intermediate_size * config.n_shared_experts
                )
                self.shared_experts = DeepseekV2MLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                    prefix=f"{prefix}.shared_experts",
                )

    def _forward_dual_stream(
        self,
        hidden_states: torch.Tensor,
        num_tokens: int,
        hidden_dim: int,
    ) -> torch.Tensor:
        """Run shared experts on the current stream and routed experts on ``alt_stream``.

        Numerically identical to the single-stream path in ``forward``, including
        the all-reduce rule. Only called when ``self.shared_experts`` exists
        (``_use_dual_stream`` is set only in that configuration).
        """
        current_stream = torch.cuda.current_stream()
        alt_stream = self.alt_stream
        # The alternate stream must see hidden_states fully produced.
        alt_stream.wait_stream(current_stream)
        # Execute shared experts on current_stream
        shared_output = self.shared_experts(hidden_states)
        # Execute routed experts on alt_stream
        with torch.cuda.stream(alt_stream):
            router_logits = self.gate(hidden_states)
            if hidden_states.dtype != torch.float16:
                final_hidden_states = self.experts(
                    hidden_states=hidden_states, router_logits=router_logits
                )
                if not is_rocm_aiter_fuse_routed_scaling_factor():
                    final_hidden_states = (
                        final_hidden_states * self.routed_scaling_factor
                    )
            else:
                final_hidden_states = self.experts(
                    hidden_states=hidden_states, router_logits=router_logits
                )
        # Join: routed results must be complete before combining on this stream.
        current_stream.wait_stream(alt_stream)
        if hidden_states.dtype != torch.float16:
            final_hidden_states = final_hidden_states + shared_output
        else:
            # Fix FP16 overflow: scale the shared output down instead of the
            # routed output up.
            final_hidden_states = final_hidden_states + shared_output * (
                1.0 / self.routed_scaling_factor
            )
        # Same reduction rule as forward(), so the result does not depend on
        # which path the token count selects.
        if self.tp_size > 1 and self.reduce_results:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply routed (+ shared) experts to ``hidden_states`` of shape (num_tokens, hidden_dim)."""
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        # Use dual-stream forward when mori is enabled
        if (
            self._use_dual_stream
            and self.alt_stream is not None
            and num_tokens > 0
            and num_tokens <= self._DUAL_STREAM_TOKEN_THRESHOLD
        ):
            return self._forward_dual_stream(hidden_states, num_tokens, hidden_dim)
        if (
            self.n_shared_experts is not None
            and not is_rocm_aiter_fusion_shared_expert_enabled()
        ):
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        if hidden_states.dtype != torch.float16:
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )
            if not is_rocm_aiter_fuse_routed_scaling_factor():
                final_hidden_states = final_hidden_states * self.routed_scaling_factor
        else:
            # Fix FP16 overflow
            # See DeepseekV2DecoderLayer for more details.
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )
        if shared_output is not None:
            if hidden_states.dtype != torch.float16:
                final_hidden_states = final_hidden_states + shared_output
            else:
                # Fix FP16 overflow
                # See DeepseekV2DecoderLayer for more details.
                final_hidden_states = final_hidden_states + shared_output * (
                    1.0 / self.routed_scaling_factor
                )
        if self.tp_size > 1 and self.reduce_results:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN mscale factor for a context-extension ``scale``.

    For ``scale <= 1`` (no extension) the factor is exactly ``1.0``;
    otherwise it is ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    return 0.1 * mscale * math.log(scale) + 1.0 if scale > 1 else 1.0
class DeepseekV32IndexerCache(nn.Module):
    """Configuration holder for the DeepSeek-V3.2 sparse-attention indexer key cache.

    Stores ``head_dim``, ``dtype``, ``prefix`` and ``cache_config`` as-is.
    ``kv_cache`` starts as a one-element list holding an empty tensor.
    NOTE(review): presumably replaced by the KV-cache allocator later; confirm.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: str
    ):
        super().__init__()
        self.head_dim = head_dim
        self.dtype = dtype
        self.prefix = prefix
        self.cache_config = cache_config
        # Placeholder until a real cache tensor is assigned.
        self.kv_cache = [torch.tensor([])]
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
# careful! this will be None in dummy run
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
context = forward_context.context
slot_mapping = attn_metadata.slot_mapping
# Skip for dummy runs to avoid corrupting KV cache
if forward_context.context.is_dummy_run:
# dummy runner
return weights
num_decode_tokens = context.batch_size if not context.is_prefill else 0
indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
if context.is_prefill:
if attn_metadata.max_seqlen_k <= topk_indices_buffer.shape[1]:
return weights
prefill_metadata = attn_metadata
num_prefills = context.batch_size
total_seq_lens = hidden_states.shape[0]
k_fp8 = torch.empty(
[total_seq_lens, head_dim], device=k.device, dtype=dtypes.fp8
)
k_scale = torch.empty([total_seq_lens, 1], device=k.device, dtype=torch.float32)
if prefill_metadata.block_tables.shape[0] < num_prefills:
new_shape = (num_prefills, prefill_metadata.block_tables.shape[1])
prefill_metadata.block_tables = torch.full(
new_shape,
-1,
dtype=torch.long,
device=prefill_metadata.block_tables.device,
)
cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale.view(dtypes.fp8),
prefill_metadata.block_tables,
prefill_metadata.cu_seqlens_q,
# num_prefills,
)
cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
num_tokens = hidden_states.shape[0]
logits = fp8_mqa_logits(
Q=q_fp8[num_decode_tokens:num_tokens],
KV=k_fp8,
kv_scales=k_scale,
weights=weights[num_decode_tokens:num_tokens],
cu_starts=cu_seqlen_ks,