-
Notifications
You must be signed in to change notification settings - Fork 126
Expand file tree
/
Copy pathdeepseek_v3.py
More file actions
1466 lines (1253 loc) · 56.1 KB
/
deepseek_v3.py
File metadata and controls
1466 lines (1253 loc) · 56.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from abc import abstractmethod
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Iterable, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax import lax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jaxtyping import Float
from vllm.config import VllmConfig
from tpu_inference import utils
from tpu_inference.distributed.jax_parallel_state import get_pp_group
from tpu_inference.kernels.quantized_matmul.util import quantize_tensor
from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
ragged_paged_attention
from tpu_inference.layers.common.attention_interface import mla_attention
from tpu_inference.layers.common.moe import MoEBackend
from tpu_inference.layers.common.quantization import (dequantize_tensor,
quantize_kv)
from tpu_inference.layers.common.sharding import \
ShardingAxisNameBase as ShardingAxisName
from tpu_inference.layers.common.utils import cpu_mesh_context
from tpu_inference.layers.jax import JaxModule
from tpu_inference.layers.jax.attention.attention import AttentionMetadata
from tpu_inference.layers.jax.base import _init_fn as init_fn
from tpu_inference.layers.jax.base import create_param, sharded_initializer
from tpu_inference.layers.jax.embed import JaxEmbed
from tpu_inference.layers.jax.layers import FlaxUtils
from tpu_inference.layers.jax.linear import JaxEinsum
from tpu_inference.layers.jax.moe.moe import JaxMoE
from tpu_inference.layers.jax.moe.utils import (get_expert_parallelism,
select_moe_backend)
from tpu_inference.layers.jax.norm import JaxRmsNorm
from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
from tpu_inference.layers.jax.quantization.configs import QuantizationConfig
from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
from tpu_inference.models.jax.utils.weight_utils import (JaxAutoWeightsLoader,
LoadableWithIterator,
shard_put)
# Type alias used to annotate the `kv_cache` arguments and return values of
# the attention modules in this file.
KVCache = Tuple[jax.Array, jax.Array]

# Module-level logger, named after this module.
logger = init_logger(__name__)
def _weight_init(random_init: bool):
    """Select the kernel initializer used for the projection weights.

    Args:
        random_init: When truthy, `sharded_initializer` is returned as-is;
            otherwise a new `nnx.initializers.uniform()` initializer is built.

    Returns:
        The initializer callable; callers in this file wrap it with
        `nnx.with_partitioning`.
    """
    if random_init:
        return sharded_initializer
    return nnx.initializers.uniform()
# Shared helper instance; the MLP in this file looks up its activation
# function through `modeling_flax_utils.ACT2FN`.
modeling_flax_utils = FlaxUtils()

# Hard-coded model hyperparameters, read as module globals by the layers
# defined in this file.
# TODO: read these configs from HF config.
num_local_experts: int = 256
vocab_size: int = 129280
hidden_size: int = 7168
num_attention_heads: int = 128
num_key_value_heads: int = 128
ffw_intermediate_size: int = 18432
moe_intermediate_size: int = 2048
num_experts_per_token: int = 8
n_group: int = 8
interleave_moe_layer_step: int = 1  # Deepseek V3 has moe_layer_freq=1 in hf config.
hidden_act: str = "silu"
rms_norm_eps: float = 1e-06
routed_scaling_factor: float = 2.5
first_k_dense_replace: int = 3  # the first few layers use a dense MLP instead of MoE.
num_shared_experts = 1

# Rotary position embedding settings.
rope_theta = 10000
rope_scaling = {
    "beta_fast": 32,
    "beta_slow": 1,
    "factor": 40,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
    "original_max_position_embeddings": 4096,
    "type": "yarn"
}

# Attention dimensions; the names mirror the fields of the attention classes
# defined below.
q_lora_rank = 1536
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128

# Mesh axis name handed to the routed-experts layer as `expert_axis_name`.
expert_axis_name = ShardingAxisName.ATTN_DATA_EXPERT
@dataclass(kw_only=True)
class DeepseekV3BaseAttention(JaxModule):
    """Base class containing shared logic for DeepSeek attention mechanisms.

    `__post_init__` builds the projection and RMS-norm layers shared by the
    attention variants in this file, and `__call__` defines the skeleton
    forward pass. Subclasses are expected to implement
    `compute_q_projection`, `compute_kv_projection` and `compute_attention`,
    and may override `process_output`.
    """
    # Core configuration
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    rope: DeepseekScalingRotaryEmbedding
    dtype: jnp.dtype
    # "auto" leaves K/V unquantized; any other value is converted to a JAX
    # dtype in __post_init__ and enables KV quantization in the subclasses.
    kv_cache_dtype: str
    mesh: Mesh
    # Attention-specific configuration
    q_lora_rank: int
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
    rms_norm_eps: float
    # Sharding
    rd_sharding: P = P()
    q_da_sharding: P = P()
    ap_sharding: P = P()
    kv_da_sharding: P = P()
    activation_attention_td: P = P()
    activation_q_td: P = P()
    query_tnh: P = P()
    keyvalue_skh: P = P()
    attn_o_tnh: P = P()
    activation_attention_out_td: P = P()
    # Weight initialization
    random_init: bool = False
    rope_mscale_all_dim: float = 1.0
    # RNG for weight initialization
    rngs: InitVar[nnx.Rngs]
    quant_config: Optional[QuantizationConfig] = None
    # Scales for Q/KV quantization (per-tensor)
    _q_scale: float = 1
    _k_scale: float = 1
    _v_scale: float = 1
    # Prepended to each sub-layer's name (e.g. prefix + ".q_a_proj").
    prefix: str = ""

    def __post_init__(self, rngs: nnx.Rngs):
        """Derive attention constants and build the shared sub-layers.

        Args:
            rngs: RNG container passed to every layer constructed here.
        """
        # Short aliases matching the axis letters used in the einsum strings
        # and tensor-name suffixes below.
        self.N = self.num_attention_heads
        self.K = self.num_key_value_heads
        self.D = self.hidden_size
        # Per-head query/key width: non-rotary part plus rotary part.
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

        # Softmax scale: qk_head_dim**-0.5, multiplied by yarn_mscale**2.
        # yarn_mscale is 1.0 unless the rope scaling factor is greater than 1.
        if self.rope.scaling_factor <= 1.0:
            yarn_mscale = 1.0
        else:
            yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
                self.rope.scaling_factor) + 1.0
        self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2

        weight_init = _weight_init(self.random_init)

        # Query down-projection: hidden_size -> q_lora_rank.
        self.q_a_proj = JaxEinsum(
            einsum_str="TD,DA->TA",
            kernel_shape=(self.D, self.q_lora_rank),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(weight_init, self.q_da_sharding),
            prefix=self.prefix + ".q_a_proj",
        )
        # Query up-projection: q_lora_rank -> N * qk_head_dim.
        self.q_b_proj = JaxEinsum(
            einsum_str="TA,AP->TP",
            kernel_shape=(self.q_lora_rank, self.N * self.qk_head_dim),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(weight_init, self.ap_sharding),
            prefix=self.prefix + ".q_b_proj")
        # Joint KV down-projection: hidden_size -> kv_lora_rank plus the
        # rotary key slice of width qk_rope_head_dim.
        self.kv_a_proj_with_mqa = JaxEinsum(
            einsum_str="SD,DA->SA",
            kernel_shape=(self.D, self.kv_lora_rank + self.qk_rope_head_dim),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(weight_init,
                                              self.kv_da_sharding),
            prefix=self.prefix + ".kv_a_proj_with_mqa")
        # Output projection: N * v_head_dim -> hidden_size.
        self.o_proj = JaxEinsum(
            einsum_str="TR,RD->TD",
            kernel_shape=(self.N * self.v_head_dim, self.D),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(weight_init, self.rd_sharding),
            prefix=self.prefix + ".o_proj")
        # RMS norms applied to the low-rank query and KV latents.
        self.q_a_layernorm = JaxRmsNorm(self.q_lora_rank,
                                        epsilon=self.rms_norm_eps,
                                        scale_init=nnx.with_partitioning(
                                            init_fn, (None, )),
                                        param_dtype=self.dtype,
                                        dtype=self.dtype,
                                        quant_config=self.quant_config,
                                        prefix=self.prefix + ".q_a_layernorm",
                                        rngs=rngs)
        self.kv_a_layernorm = JaxRmsNorm(
            self.kv_lora_rank,
            epsilon=self.rms_norm_eps,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            param_dtype=self.dtype,
            dtype=self.dtype,
            quant_config=self.quant_config,
            prefix=self.prefix + ".kv_a_layernorm",
            rngs=rngs)

        # None means "do not quantize K/V before writing to the cache".
        self.kv_cache_quantized_dtype = None
        if self.kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                self.kv_cache_dtype)

        # KV up-projection: kv_lora_rank -> N * (qk_nope_head_dim + v_head_dim).
        # NOTE(review): this one uses `init_fn` rather than `weight_init`, and
        # both DeepseekV3Attention and DeepseekV3MLA in this file assign
        # `kv_b_proj` again in their own __post_init__.
        self.kv_b_proj = JaxEinsum(
            einsum_str="SA,AL->SL",
            kernel_shape=(self.kv_lora_rank,
                          self.N * (self.qk_nope_head_dim + self.v_head_dim)),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(init_fn, self.ap_sharding),
            prefix=self.prefix + ".kv_b_proj",
        )

    @abstractmethod
    def compute_q_projection(self, *args) -> jax.Array:
        """Compute the query-side data; `__call__` passes `(x_q_TD, input_positions)`."""
        raise NotImplementedError

    @abstractmethod
    def compute_kv_projection(self, *args) -> Tuple[jax.Array, jax.Array]:
        """Compute the key/value-side data; `__call__` passes `(x_SD, input_positions)`."""
        raise NotImplementedError

    @abstractmethod
    def compute_attention(self, *args) -> Tuple[KVCache, jax.Array]:
        """Run attention; `__call__` passes `(q_data, kv_data, kv_cache, md)`.

        Must return `(new_kv_cache, outputs)` in that order.
        """
        raise NotImplementedError

    def process_output(self, outputs_TNH) -> jax.Array:
        """Post-process the attention output; the base version returns it unchanged."""
        return outputs_TNH

    def __call__(
        self, x: jax.Array, kv_cache: KVCache,
        attention_metadata: AttentionMetadata
    ) -> Tuple[KVCache, jax.Array]:
        """Performs the forward pass of the attention module. Expects that the
        child class has implemented the `compute_q_projection`, `compute_kv_projection`,
        and `compute_attention` methods.

        Args:
            x: The input tensor of shape `(tokens, hidden_size)`; it is cast
                to `self.dtype` before use.
            kv_cache: The key-value cache for storing past attention states.
            attention_metadata: Metadata for attention, such as input positions.

        Returns:
            A tuple containing:
                - The updated KV cache.
                - The attention output tensor of shape `(tokens, hidden_size)`.
        """
        md = attention_metadata
        x = jnp.asarray(x, self.dtype)
        # The same input is given separate sharding constraints for the KV
        # path and the query path.
        x_SD = lax.with_sharding_constraint(x, self.activation_attention_td)
        x_q_TD = lax.with_sharding_constraint(x, self.activation_q_td)

        with jax.named_scope("q_proj"):
            q_data = self.compute_q_projection(x_q_TD, md.input_positions)

        with jax.named_scope("kv_proj"):
            kv_data = self.compute_kv_projection(x_SD, md.input_positions)

        with jax.named_scope("attn_op"):
            new_kv_cache, outputs_TNH = self.compute_attention(
                q_data, kv_data, kv_cache, md)
            outputs_TNH = self.process_output(outputs_TNH)
            # Drop any trailing entries beyond v_head_dim (e.g. head-dim
            # padding added by a subclass before the attention kernel).
            if outputs_TNH.shape[-1] != self.v_head_dim:
                outputs_TNH = outputs_TNH[..., :self.v_head_dim]

        with jax.named_scope("o_proj"):
            # Merge the head axes: (T, N, v_head_dim) -> (T, N * v_head_dim).
            outputs_TR = outputs_TNH.reshape(outputs_TNH.shape[0],
                                             self.N * self.v_head_dim)
            o_TD = self.o_proj(outputs_TR)

        return new_kv_cache, o_TD
@dataclass(kw_only=True)
class DeepseekV3Attention(DeepseekV3BaseAttention):
    """Standard Multi-Head Attention (MHA) for DeepSeek models."""

    def __post_init__(self, rngs: nnx.Rngs):
        """Build the shared layers, then re-create `kv_b_proj`.

        The replacement uses the initializer selected by `random_init`,
        whereas the base class built `kv_b_proj` with `init_fn`.

        Args:
            rngs: RNG container passed to the constructed layers.
        """
        super().__post_init__(rngs)
        weight_init = _weight_init(self.random_init)
        self.kv_b_proj = JaxEinsum(
            einsum_str="SA,AL->SL",
            kernel_shape=(self.kv_lora_rank,
                          self.N * (self.qk_nope_head_dim + self.v_head_dim)),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(weight_init, self.ap_sharding),
            prefix=self.prefix + ".kv_b_proj",
        )

    def compute_q_projection(self, x_q_TD: jax.Array,
                             input_positions: jax.Array) -> jax.Array:
        """
        Computes the query projection for MHA.

        Args:
            x_q_TD: The input tensor of shape `(tokens_query, d_model)`.
            input_positions: The input positions tensor of shape `(padded_total_num_scheduled_tokens,)`.

        Returns:
            The query tensor of shape `(tokens_query, num_query_heads, qk_head_dim)`,
            where `qk_head_dim = qk_nope_head_dim + qk_rope_head_dim`.
        """
        # Low-rank query path: down-project, normalize, up-project.
        q_TA = self.q_a_proj(x_q_TD)
        q_TA = self.q_a_layernorm(q_TA)
        q_TP = self.q_b_proj(q_TA)
        q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)

        # Only the trailing qk_rope_head_dim entries of each head get RoPE.
        q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
        q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
        q_rope_TNH = self.rope.apply_rope(input_positions, q_rope_TNH)
        q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)

        return lax.with_sharding_constraint(q_TNH, self.query_tnh)

    def compute_kv_projection(
            self, x_SD: jax.Array,
            input_positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """
        Computes the key-value projection for MHA.

        Args:
            x_SD: The input tensor of shape `(tokens_kv, d_model)`.
            input_positions: The input positions tensor of shape `(padded_total_num_scheduled_tokens,)`.

        Returns:
            A tuple `(k_SNH, v_SNH)`: keys of shape
            `(tokens_kv, num_query_heads, qk_head_dim)` and values of shape
            `(tokens_kv, num_query_heads, v_head_dim)`.
        """
        kv_SA = self.kv_a_proj_with_mqa(x_SD)

        # The trailing qk_rope_head_dim columns are the rotary key. A
        # singleton head axis is added for apply_rope, and the result is then
        # broadcast to all N heads.
        k_rope_SH = kv_SA[..., self.kv_lora_rank:]
        k_rope_SNH = k_rope_SH[..., None, :]
        k_rope_SNH = self.rope.apply_rope(input_positions, k_rope_SNH)
        assert k_rope_SNH.shape[1] == 1
        k_rope_SNH = jnp.broadcast_to(
            k_rope_SNH, (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))

        # The leading kv_lora_rank columns are the compressed KV latent.
        kv_SA = kv_SA[..., :self.kv_lora_rank]
        kv_SA = self.kv_a_layernorm(kv_SA)
        kv_SA = lax.with_sharding_constraint(kv_SA, self.keyvalue_skh)

        # Up-project the latent and split each head into its non-rotary key
        # and its value.
        kv_SL = self.kv_b_proj(kv_SA)
        kv_nope_SNH = kv_SL.reshape(kv_SA.shape[0], self.N,
                                    self.qk_nope_head_dim + self.v_head_dim)
        k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
        v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]

        k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)

        # Shard
        k_SNH = lax.with_sharding_constraint(k_SNH, self.keyvalue_skh)
        v_SNH = lax.with_sharding_constraint(v_SNH, self.keyvalue_skh)

        return (k_SNH, v_SNH)

    def compute_attention(self, q_data: jax.Array, kv_data: Tuple[jax.Array,
                                                                  jax.Array],
                          kv_cache: KVCache,
                          md: AttentionMetadata) -> Tuple[KVCache, jax.Array]:
        """
        Computes self-attention for MHA.

        Args:
            q_data: The query tensor of shape `(tokens_query, num_query_heads, qk_head_dim)`.
            kv_data: Tuple `(k_SNH, v_SNH)` as returned by `compute_kv_projection`.
            kv_cache: KVCache object.
            md: AttentionMetadata object.

        Returns:
            Tuple `(kv_cache, output_TNH)`: the cache returned by the
            attention kernel followed by the attention output. The output's
            last dimension is presumably the padded head dim; the base
            `__call__` slices it back to `v_head_dim`.
        """
        q_TNH = q_data
        k_SNH, v_SNH = kv_data

        # Pad the last (head) dimension of q, k and v up to the next multiple
        # of 128 of qk_head_dim.
        multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
        q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
                                (0, multiple_of_128 - self.qk_head_dim)))
        k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
                                (0, multiple_of_128 - self.qk_head_dim)))
        v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
                                (0, multiple_of_128 - self.v_head_dim)))

        # Scales stay None unless KV-cache quantization is enabled; q_scale
        # is never set on this path.
        q_scale = k_scale = v_scale = None
        if self.kv_cache_quantized_dtype:
            k_scale = self._k_scale
            v_scale = self._v_scale
            k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype, k_SNH,
                                       v_SNH, k_scale, v_scale)

        def _ragged_paged_attention(q, k, v, cache, seq_lens, block_tables,
                                    starts, dist):
            # Per-shard body: forwards to the kernel with the softmax scale
            # and quantization scales captured from the enclosing scope.
            return ragged_paged_attention(q,
                                          k,
                                          v,
                                          cache,
                                          seq_lens,
                                          block_tables,
                                          starts,
                                          dist,
                                          sm_scale=self.scale,
                                          q_scale=q_scale,
                                          k_scale=k_scale,
                                          v_scale=v_scale)

        in_specs = (
            self.query_tnh,  # q
            self.keyvalue_skh,  # k
            self.keyvalue_skh,  # v
            P(None, None, ShardingAxisName.ATTN_HEAD),  # kv_cache
            P(),  # md.seq_lens: Replicated
            P(),  # page_indices_flat: Replicated
            P(),  # query_start_loc: Replicated
            P(),  # distribution: Replicated
        )
        out_specs = (self.attn_o_tnh, P(None, None,
                                        ShardingAxisName.ATTN_HEAD))

        # The kernel returns (output, cache); the order is swapped on return
        # to match the base-class contract.
        output_TNH, kv_cache = jax.jit(
            jax.shard_map(_ragged_paged_attention,
                          mesh=self.mesh,
                          in_specs=in_specs,
                          out_specs=out_specs,
                          check_vma=False))(q_TNH, k_SNH, v_SNH, kv_cache,
                                            md.seq_lens, md.block_tables,
                                            md.query_start_loc,
                                            md.request_distribution)

        return kv_cache, output_TNH
class MLAEinsum(JaxEinsum):
    """Extending JaxEinsum to handle MLA.

    This class is used for the MLA `kv_b_proj`, where:
    1) the weight is split into k/v parts after loading, and
    2) the owning MLA layer is modified to hold the resulting k/v
       up-projection layers (`k_up_proj` and `v_up_proj`).
    """

    def __init__(self,
                 mla_layer,
                 einsum_str: str,
                 kernel_shape: tuple[int, ...],
                 rngs,
                 bias_shape: Optional[tuple[int, ...]] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 **kwargs):
        """Create the einsum layer and remember the owning MLA layer.

        Args:
            mla_layer: The MLA attention layer that receives `k_up_proj` and
                `v_up_proj` once all weights are loaded.
            einsum_str: Einsum specification forwarded to `JaxEinsum`.
            kernel_shape: Kernel shape forwarded to `JaxEinsum`.
            rngs: RNG container forwarded to `JaxEinsum`.
            bias_shape: Optional bias shape forwarded to `JaxEinsum`.
            quant_config: Quantization config; required by `load_weights`
                when the k/v split is performed.
            prefix: Name prefix forwarded to `JaxEinsum`.
            **kwargs: Extra keyword arguments forwarded to `JaxEinsum`.
        """
        super().__init__(einsum_str,
                         kernel_shape,
                         rngs,
                         bias_shape=bias_shape,
                         quant_config=quant_config,
                         prefix=prefix,
                         **kwargs)
        # Names of the parameters that have been loaded so far.
        self.loaded = set()
        self.mla_layer = mla_layer
        self.quant_config = quant_config

    def named_children(self):
        """Yield no children."""
        # Override, otherwise "mla_layer" will be visited, causing infinite recursion.
        yield from []

    def load_weights(self, weights):
        """Load checkpoint weights, then split them into k/v up-projections.

        Each `(name, weight)` pair is handed to the matching parameter's
        `weight_loader`. Once every parameter of this layer has been loaded,
        the weight is dequantized, split per head into its key and value
        parts, re-quantized, and installed on the owning MLA layer as
        `k_up_proj` and `v_up_proj`. This layer's own `weight`,
        `weight_scale_inv` and `quant_method` attributes are then deleted.

        Args:
            weights: Iterable of `(name, weight)` pairs.

        Raises:
            ValueError: If two parameters were already loaded, if no
                quantization config is set when the split is due, or if the
                dequantized weight has an unexpected shape.
        """
        named_params = dict(self.named_parameters())

        if len(self.loaded) >= 2:
            raise ValueError(
                f"Expect at most 2 params to load for kv_b_proj, already got {self.loaded}, still have {[name for name, _ in weights]} coming."
            )

        for name, weight in weights:
            param = named_params[name]
            weight_loader = param.weight_loader
            weight_loader(param, weight)
            self.loaded.add(name)

        # Wait until every parameter has arrived before splitting.
        if len(self.loaded) != len(named_params):
            return

        # Explicit check instead of `assert`, so it still fires under
        # `python -O`: the split below relies on the quantized layout
        # (`weight_scale_inv`).
        if self.quant_config is None:
            raise ValueError(
                "MLAEinsum requires a quant_config to split kv_b_proj into "
                "k/v up-projections.")

        # After loading, split the weights into k/v
        with cpu_mesh_context():
            dequantized_weight = dequantize_tensor(
                self.weight,
                self.weight_scale_inv,
                (0, 1),
                block_size=None,
            ).T

            A, N, qk_nope_head_dim, v_head_dim = self.mla_layer.kv_lora_rank, self.mla_layer.N, self.mla_layer.qk_nope_head_dim, self.mla_layer.v_head_dim
            if dequantized_weight.shape != (A, N *
                                            (qk_nope_head_dim + v_head_dim)):
                raise ValueError(
                    f"Unexpected weight shape after dequantization: {dequantized_weight.shape}, expected {(A, N * (qk_nope_head_dim + v_head_dim))=}"
                )

            # Split each head's columns into key (first qk_nope_head_dim)
            # and value (remaining v_head_dim) parts.
            dequantized_weight = dequantized_weight.reshape(
                A, N, qk_nope_head_dim + v_head_dim)
            k_ANH, v_ANH = jnp.split(dequantized_weight, [qk_nope_head_dim],
                                     axis=-1)

            # Re-quantize back to the stored weight dtype.
            k_ANH_weight, k_ANH_scale = quantize_tensor(k_ANH,
                                                        self.weight.dtype,
                                                        dim=-1)
            v_ANH_weight, v_1NH_scale = quantize_tensor(v_ANH,
                                                        self.weight.dtype,
                                                        dim=0)

            # As of writing, sharded_quantized_batched_matmul expects scale to be
            # a different shape order than weight
            k_N1A_scale = k_ANH_scale.transpose(1, 2, 0)
            v_N1H_scale = v_1NH_scale.transpose(1, 0, 2)

        mla_layer = self.mla_layer
        setattr(
            mla_layer, "k_up_proj",
            JaxEinsum(
                einsum_str="TNH,ANH->TNA",
                kernel_shape=(A, N, qk_nope_head_dim),
                rngs=nnx.Rngs(0),
                prefix=mla_layer.prefix + ".k_up_proj",
                quant_config=self.quant_config,
            ))
        setattr(
            mla_layer, "v_up_proj",
            JaxEinsum(
                einsum_str="TNA,ANH->TNH",
                kernel_shape=(A, N, v_head_dim),
                rngs=nnx.Rngs(0),
                prefix=mla_layer.prefix + ".v_up_proj",
                quant_config=self.quant_config,
            ))

        # Cannot apply anh_sharding to scales, otherwise it complains about shape mismatch.
        mla_layer.k_up_proj.weight.value = shard_put(
            k_ANH_weight, self.mla_layer.anh_sharding)
        mla_layer.k_up_proj.weight_scale_inv.value = shard_put(k_N1A_scale, ())

        mla_layer.v_up_proj.weight.value = shard_put(
            v_ANH_weight, self.mla_layer.anh_sharding)
        mla_layer.v_up_proj.weight_scale_inv.value = shard_put(v_N1H_scale, ())

        # The combined kv_b_proj weight is no longer needed once split.
        delattr(self, 'weight')
        delattr(self, 'weight_scale_inv')
        delattr(self, 'quant_method')
@dataclass(kw_only=True)
class DeepseekV3MLA(DeepseekV3BaseAttention):
    """Multi-Head Latent Attention (MLA) for DeepSeek V3.

    `compute_q_projection` and `process_output` use `k_up_proj` and
    `v_up_proj`, which are not created here: `MLAEinsum.load_weights` sets
    them on this layer once the `kv_b_proj` weights have been loaded.
    """
    # Sharding applied to the k/v up-projection weights by MLAEinsum.
    anh_sharding: Sharding = ()

    def __post_init__(self, rngs: nnx.Rngs):
        """Build the shared layers, then replace `kv_b_proj` with an `MLAEinsum`.

        Args:
            rngs: RNG container passed to the constructed layers.
        """
        super().__post_init__(rngs)
        weight_init = _weight_init(self.random_init)
        self.kv_b_proj = MLAEinsum(
            mla_layer=self,
            einsum_str="SA,AL->SL",
            kernel_shape=(self.kv_lora_rank,
                          self.N * (self.qk_nope_head_dim + self.v_head_dim)),
            rngs=rngs,
            quant_config=self.quant_config,
            param_dtype=self.dtype,
            kernel_init=nnx.with_partitioning(weight_init, self.ap_sharding),
            prefix=self.prefix + ".kv_b_proj",
        )

    def compute_q_projection(
            self, x_q_TD: jax.Array,
            input_positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """
        Computes the query projection for MLA.

        Args:
            x_q_TD: The input tensor of shape `(tokens_query, d_model)`.
            input_positions: The input positions tensor of shape `(padded_total_num_scheduled_tokens,)`.

        Returns:
            A tuple of query tensor of shape `(tokens_query, num_query_heads, kv_lora_rank)` and
            rope tensor of shape `(tokens_query, num_query_heads, qk_rope_head_dim)`.
        """
        # Low-rank query path: down-project, normalize, up-project.
        q_TA = self.q_a_proj(x_q_TD)
        q_TA = self.q_a_layernorm(q_TA)
        q_TP = self.q_b_proj(q_TA)
        q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)

        q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
        q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
        q_rope_TNH = self.rope.apply_rope(input_positions, q_rope_TNH)

        # Project the non-rotary query into the KV latent space with the key
        # up-projection weights ("TNH,ANH->TNA").
        q_TNA = self.k_up_proj(q_nope_TNH)
        q_TNA = lax.with_sharding_constraint(q_TNA, self.query_tnh)
        return (q_TNA, q_rope_TNH)

    def compute_kv_projection(
            self, x_SD: jax.Array,
            input_positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """
        Computes the key-value projection for MLA.

        Args:
            x_SD: The input tensor of shape `(tokens_kv, d_model)`.
            input_positions: The input positions tensor of shape `(padded_total_num_scheduled_tokens,)`.

        Returns:
            A tuple of key-value tensor of shape `(tokens_kv, kv_lora_rank)` and
            rope tensor of shape `(tokens_kv, qk_rope_head_dim)`.
        """
        kv_SA = self.kv_a_proj_with_mqa(x_SD)

        # The trailing qk_rope_head_dim columns are the rotary key. A
        # singleton head axis is added for apply_rope and removed afterwards.
        k_rope_SH = kv_SA[..., self.kv_lora_rank:]
        k_rope_SNH = k_rope_SH[..., None, :]
        k_rope_SNH = self.rope.apply_rope(input_positions, k_rope_SNH)
        assert k_rope_SNH.shape[1] == 1
        k_rope_SH = k_rope_SNH[:, 0, :]

        # The leading kv_lora_rank columns are the compressed KV latent; it
        # is normalized but not up-projected.
        kv_SA = kv_SA[..., :self.kv_lora_rank]
        kv_SA = self.kv_a_layernorm(kv_SA)
        kv_SA = lax.with_sharding_constraint(kv_SA, self.keyvalue_skh)

        return (kv_SA, k_rope_SH)

    def compute_attention(self, q_data: Tuple[jax.Array, jax.Array],
                          kv_data: Tuple[jax.Array,
                                         jax.Array], kv_cache: KVCache,
                          md: AttentionMetadata) -> Tuple[KVCache, jax.Array]:
        """
        Computes the attention for MLA.

        Args:
            q_data: A tuple of query tensor of shape `(tokens_query, num_query_heads, kv_lora_rank)` and
                rope tensor of shape `(tokens_query, num_query_heads, qk_rope_head_dim)`.
            kv_data: A tuple of key-value tensor of shape `(tokens_kv, kv_lora_rank)` and
                rope tensor of shape `(tokens_kv, qk_rope_head_dim)`.
            kv_cache: The key-value cache.
            md: The attention metadata.

        Returns:
            The result of `mla_attention`; per the base-class contract this is
            a tuple of key-value cache and output tensor, whose output is fed
            to `v_up_proj` in `process_output` and is therefore expected to
            have shape `(tokens_query, num_query_heads, kv_lora_rank)`.
        """
        q_TNA, q_rope_TNH = q_data
        k_SA, k_rope_SH = kv_data

        # Scales stay None unless KV-cache quantization is enabled; q_scale
        # is never set on this path.
        q_scale = k_scale = None
        if self.kv_cache_quantized_dtype:
            k_scale = self._k_scale
            # TODO: May need to apply quantization separately for k_c & k_pe
            k_SA, _ = quantize_kv(self.kv_cache_quantized_dtype,
                                  k_SA,
                                  value=None,
                                  k_scale=k_scale)
            k_rope_SH, _ = quantize_kv(self.kv_cache_quantized_dtype,
                                       k_rope_SH,
                                       value=None,
                                       k_scale=k_scale)

        # The key scale is passed for both k_scale and v_scale.
        return mla_attention(q_TNA,
                             q_rope_TNH,
                             k_SA,
                             k_rope_SH,
                             kv_cache,
                             md,
                             self.mesh,
                             self.num_attention_heads,
                             self.qk_nope_head_dim,
                             query_tnh_sharding=self.query_tnh,
                             keyvalue_skh_sharding=self.keyvalue_skh,
                             attn_o_tnh_sharding=self.attn_o_tnh,
                             q_scale=q_scale,
                             k_scale=k_scale,
                             v_scale=k_scale,
                             sm_scale=self.scale)

    def process_output(self, outputs_TNA: jax.Array) -> jax.Array:
        """
        Processes output for MLA specifically.

        Args:
            outputs_TNA: The output tensor of shape `(tokens_query, num_query_heads, kv_lora_rank)`.

        Returns:
            The processed output tensor of shape `(tokens_query, num_query_heads, v_head_dim)`.
        """
        # MLA Specific: Apply V-Up Projection after attention
        # Outputs from MLA kernel are in latent space (TNA), project to TNH
        outputs_TNH = self.v_up_proj(outputs_TNA)
        return outputs_TNH
@dataclass(kw_only=True)
class DeepseekV3MLP(JaxModule):
    """Gated feed-forward network (FFN) block.

    The input is projected twice (gate and up). The gate branch is passed
    through the `hidden_act` activation and multiplied element-wise with the
    up branch, and the product is projected back down to `hidden_size`.

    Attributes:
        dtype: Dtype the input is cast to; also the parameter dtype.
        hidden_act: Key into `modeling_flax_utils.ACT2FN` selecting the
            gate activation.
        hidden_size: Model width (the `D` axis).
        intermediate_size: FFN inner width (the `F` axis).
        df_sharding: Partitioning of the gate/up kernels, shape `(D, F)`.
        fd_sharding: Partitioning of the down kernel, shape `(F, D)`.
        activation_ffw_td: Sharding constraint applied to the input.
        random_init: Selects the kernel initializer via `_weight_init`.
        quant_config: Optional quantization config passed to each projection.
    """
    dtype: jnp.dtype
    hidden_act: str
    hidden_size: int
    intermediate_size: int
    df_sharding: P = P()
    fd_sharding: P = P()
    activation_ffw_td: P = P()
    random_init: bool = False
    quant_config: Optional[QuantizationConfig] = None
    rngs: InitVar[nnx.Rngs]

    def __call__(self, x_TD):
        """Run the gated FFN.

        Args:
            x_TD: Input of shape `(tokens, d_model)`.

        Returns:
            Output of shape `(tokens, d_model)`.
        """
        inputs_TD = lax.with_sharding_constraint(
            jnp.asarray(x_TD, self.dtype), self.activation_ffw_td)

        with jax.named_scope("wi_0"):
            gate_TF = self.gate_proj(inputs_TD)
            gate_act_TF = modeling_flax_utils.ACT2FN[self.hidden_act](gate_TF)

        with jax.named_scope("wi_1"):
            up_TF = self.up_proj(inputs_TD)

        gated_TF = gate_act_TF * up_TF

        with jax.named_scope("wo"):
            out_TD = self.down_proj(gated_TF)

        return out_TD

    def __post_init__(self, rngs: nnx.Rngs):
        """Create the gate, up and down projections, in that order.

        Args:
            rngs: RNG container shared by the three projections.
        """
        d_model = self.hidden_size
        d_ff = self.intermediate_size
        initializer = _weight_init(self.random_init)

        def _projection(spec, shape, sharding):
            # All three projections differ only in einsum spec, kernel shape
            # and kernel partitioning.
            return JaxEinsum(
                einsum_str=spec,
                kernel_shape=shape,
                rngs=rngs,
                quant_config=self.quant_config,
                param_dtype=self.dtype,
                kernel_init=nnx.with_partitioning(initializer, sharding),
            )

        self.gate_proj = _projection("TD,DF->TF", (d_model, d_ff),
                                     self.df_sharding)
        self.up_proj = _projection("TD,DF->TF", (d_model, d_ff),
                                   self.df_sharding)
        self.down_proj = _projection("TF,FD->TD", (d_ff, d_model),
                                     self.fd_sharding)
@dataclass(kw_only=True)
class SharedFusedMoe(JaxMoE):
    """Routed experts plus optional shared experts in one forward pass.

    Corresponds to vLLM's SharedFusedMoe.
    Reference here: https://github.com/vllm-project/vllm/blob/168ee03e1cbba2b962adbc704b16762b266be184/vllm/model_executor/layers/fused_moe/shared_fused_moe.py#L14
    """
    shared_experts: Optional[DeepseekV3MLP] = None
    routed_scaling_factor: float = 1.0

    def __call__(self, x_TD: jax.Array) -> jax.Array:
        """Return the routed-expert output, plus the shared-expert output if configured.

        Args:
            x_TD: Input hidden states; the same tensor is fed to both the
                routed and the shared experts.

        Returns:
            The combined hidden states.
        """
        # Routed experts are handled entirely by the parent class.
        hidden = super().__call__(x_TD)
        if self.shared_experts is None:
            return hidden

        # Shared experts see the same input and are summed in.
        hidden += self.shared_experts(x_TD)
        return hidden
class DeepseekV2Moe(JaxModule):
    """Jax implementation of the Deepseek MoE layer.

    Bundles the router (`gate`), the shared-expert MLP and the routed experts;
    calling the layer delegates to the routed-experts module, which was given
    the router and the shared experts at construction time.

    vllm ref. https://github.com/vllm-project/vllm/blob/168ee03e1cbba2b962adbc704b16762b266be184/vllm/model_executor/models/deepseek_v2.py#L225
    """

    def __init__(self,
                 *,
                 mesh,
                 dtype,
                 num_expert_parallelism,
                 moe_backend,
                 quant_config,
                 scoring_func,
                 rng,
                 prefix: str = ""):
        """Build the router, the shared experts and the routed experts.

        Model dimensions are taken from the module-level constants of this
        file.

        Args:
            mesh: Device mesh passed to the routed experts.
            dtype: Dtype used by all sub-layers.
            num_expert_parallelism: Passed through to the routed experts.
            moe_backend: MoE backend; also selects the sharding layout of the
                routed experts.
            quant_config: Quantization config passed to all sub-layers.
            scoring_func: Scoring function name passed to the router and the
                routed experts.
            rng: RNG container shared by all sub-layers.
            prefix: Name prefix; the routed experts get `"{prefix}.experts"`.
        """
        axes = ShardingAxisName

        self.gate = DeepSeekV3Router(
            hidden_size=hidden_size,
            num_experts=num_local_experts,
            num_experts_per_tok=num_experts_per_token,
            n_groups=n_group,
            topk_groups=4,
            norm_topk_prob=True,
            rngs=rng,
            routed_scaling_factor=routed_scaling_factor,
            dtype=dtype,
            moe_backend=moe_backend,
            activation_ffw_td=(axes.MLP_DATA, None),
            ed_sharding=(None, None),
            e_sharding=(None, ),
            scoring_func=scoring_func,
            quant_config=quant_config,
        )

        # Shared experts: one dense MLP whose inner width scales with the
        # number of shared experts.
        self.shared_experts = DeepseekV3MLP(
            dtype=dtype,
            hidden_act=hidden_act,
            hidden_size=hidden_size,
            intermediate_size=num_shared_experts * moe_intermediate_size,
            rngs=rng,
            activation_ffw_td=P(axes.MLP_DATA, None),
            df_sharding=P(None, axes.ATTN_HEAD),
            fd_sharding=P(axes.ATTN_HEAD, None),
            quant_config=quant_config,
        )

        # Routed experts: the sharding layout depends on the backend; only
        # the `ted` activation spec is the same for both.
        if moe_backend == MoEBackend.GMM_TP:
            routed_sharding = dict(
                activation_ffw_td=P(axes.MLP_DATA, None),
                edf_sharding=P(None, axes.ATTN_DATA_EXPERT, axes.MOE_TENSOR),
                efd_sharding=P(None, axes.MOE_TENSOR, axes.ATTN_DATA_EXPERT),
            )
        else:
            routed_sharding = dict(
                activation_ffw_td=P(axes.MLP_DATA, axes.MOE_TENSOR),
                edf_sharding=P(axes.ATTN_DATA_EXPERT, None, None),
                efd_sharding=P(axes.ATTN_DATA_EXPERT, None, None),
            )
        routed_sharding["activation_ffw_ted"] = P(axes.MLP_DATA, None,
                                                  axes.MOE_TENSOR)

        self.experts = SharedFusedMoe(
            dtype=dtype,
            num_local_experts=num_local_experts,
            apply_expert_weight_before_computation=False,
            expert_axis_name=expert_axis_name,
            num_expert_parallelism=num_expert_parallelism,
            hidden_size=hidden_size,
            intermediate_size_moe=moe_intermediate_size,
            num_experts_per_tok=num_experts_per_token,
            mesh=mesh,
            hidden_act=hidden_act,
            rngs=rng,
            quant_config=quant_config,
            moe_backend=moe_backend,
            qwix_quantized_weight_dtype=None,
            # The prefix looks unusual because SharedFusedMoe and JaxMoE are
            # dataclasses. The cleaner fix is to turn both into regular
            # classes, pass prefix + ".mlp" here, and have their __init__
            # forward prefix + ".experts" to super().__init__.
            prefix=f"{prefix}.experts",
            router=self.gate,
            shared_experts=self.shared_experts,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            **routed_sharding,
        )

    def __call__(self, x_TD: jax.Array):
        """Apply the MoE layer by delegating to the routed-experts module."""
        return self.experts(x_TD)
class DeepseekV3DecoderLayer(JaxModule):
    """Implements the decoder layer for DeepseekV3.

    Two sub-blocks, each applied as `x + f(norm(x))`: self-attention first,
    then the MLP/MoE.
    """

    def __init__(
            self,
            input_layernorm: JaxRmsNorm,
            post_attention_layernorm: JaxRmsNorm,
            self_attn: Union[DeepseekV3Attention, DeepseekV3MLA],
            # MLP can be either the Dense MLP (for first k layers) or SharedFusedMoe
            mlp: nnx.Module | SharedFusedMoe | DeepseekV3MLP,
            prefix: str = ""):
        """Store the pre-built sub-modules.

        Args:
            input_layernorm: Norm applied before self-attention.
            post_attention_layernorm: Norm applied before the MLP/MoE.
            self_attn: Attention module (MHA or MLA variant).
            mlp: Dense MLP or MoE module.
            prefix: Accepted for interface compatibility; not used here.
        """
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm
        self.self_attn = self_attn
        self.mlp = mlp

    def __call__(
        self, x_TD: jax.Array, *, kv_cache: List[jax.Array],
        attention_metadata: AttentionMetadata
    ) -> Tuple[List[jax.Array], jax.Array]:
        """Run one decoder layer.

        Args:
            x_TD: Input hidden states.
            kv_cache: KV cache for this layer, forwarded to the attention
                module.
            attention_metadata: Attention metadata forwarded to the attention
                module.

        Returns:
            A tuple of the updated KV cache and the output hidden states.
        """
        # Attention sub-block with residual connection.
        normed_input = self.input_layernorm(x_TD)
        updated_cache, attn_out = self.self_attn(normed_input, kv_cache,
                                                 attention_metadata)
        post_attn = x_TD + attn_out

        # MLP/MoE sub-block with residual connection.
        normed_post_attn = self.post_attention_layernorm(post_attn)
        ffn_out = self.mlp(normed_post_attn)

        return updated_cache, post_attn + ffn_out
class DeepSeekV3Router(JaxEinsum):
"""Router module for Mixture-of-Experts (MoE) layers.
This module determines which experts each token should be routed to based on the input.
"""
def __init__(
self,
hidden_size: int,
num_experts: int,
num_experts_per_tok: int,
n_groups: int,
topk_groups: int,
norm_topk_prob: bool,
routed_scaling_factor,
dtype: jnp.dtype,
rngs: nnx.Rngs,
# Sharding Attributes
activation_ffw_td: P = P(),
ed_sharding: P = P(),
e_sharding: P = P(),
random_init: bool = False,
quant_config: Optional[QuantizationConfig] = None,
router_bias_dtype: jnp.dtype = jnp.float32,
scoring_func: str = "sigmoid",
moe_backend: MoEBackend = MoEBackend.DENSE_MAT):
self.hidden_size = hidden_size
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
self.n_groups = n_groups
self.topk_groups = topk_groups
self.norm_topk_prob = norm_topk_prob
self.routed_scaling_factor = routed_scaling_factor