@@ -1,8 +1,8 @@
'''
- Description :
+ Description :
Author : Boxin Zhang
Version : 0.1.0
- Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
@@ -16,13 +16,12 @@
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
- from ktransformers.util.utils import get_compute_capability
+ from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
- import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
@@ -63,7 +62,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-
+

        return self.q_absorb, self.out_absorb

    def forward_chunck(
@@ -111,7 +110,7 @@ def forward_chunck(

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-
+
            # compressed_kv [bsz, q_len, self.kv_lora_rank]
            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
            k_pe = k_pe.transpose(1, 2)
@@ -122,7 +121,7 @@ def forward_chunck(
            )
            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
-
+
        q_absorb, out_absorb = self.get_absorbed()

        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -136,9 +135,9 @@ def forward_chunck(
        #print(k_pe.shape)
        #print(q_nope.shape)
        #print(compressed_kv.shape)
-
+
        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
-
+
        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
        compressed_kv = compressed_kv.squeeze(1)
        """
@@ -166,10 +165,10 @@ def forward_chunck(
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
-
+
        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-
-         attn_output = torch.matmul(attn_output, out_absorb.mT)
+
+         attn_output = torch.matmul(attn_output, out_absorb.mT)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
@@ -178,7 +177,7 @@ def forward_chunck(
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
-
+
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)
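For reference, the absorbed-attention output path above (the einsum over compressed_kv followed by the out_absorb projection) can be sanity-checked in isolation. The snippet below is only an illustration: the sizes are made up, the tensors are random stand-ins rather than the module's real activations and weights, and only the layouts follow the shape comments in the hunks above.

import torch

# Hypothetical sizes for illustration; only the layouts mirror the shape comments above.
bsz, num_heads, q_len, kv_seq_len = 2, 16, 4, 64
kv_lora_rank, v_head_dim = 512, 128

attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len)   # [b, h, q, l]
compressed_kv = torch.randn(bsz, kv_seq_len, kv_lora_rank)      # [b, l, c]
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)   # [h, v, c]

# Weighted sum over the cached latent vectors, staying in the kv_lora_rank space.
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)   # [b, h, q, c]
# Project from the latent space back to the per-head value dimension.
attn_output = torch.matmul(attn_output, out_absorb.mT)                      # [b, h, q, v]
assert attn_output.shape == (bsz, num_heads, q_len, v_head_dim)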
@@ -225,11 +224,11 @@ def forward_linux_triton(
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
+
        cos, sin = self.rotary_emb(q_pe, position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-
+
        # decode
        if q_len == 1:
            if past_key_value is not None:
@@ -246,20 +245,20 @@ def forward_linux_triton(
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            q_nope = q_nope.transpose(1, 2)
            #assert q_nope.is_contiguous()
-
+
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            query_states = torch.cat([q_nope, q_pe], dim=-1)
-
+
            query_states = query_states.squeeze(1)
            attn_output = torch.zeros_like(q_nope)  # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-
+
            attn_logits = torch.empty(
                (
                    bsz,
                    self.num_heads,
                    4, #num_kv_splits # follow vLLM, fix it TODO
-                     self.kv_lora_rank + 1,
+                     self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=attn_output.device
@@ -280,16 +279,16 @@ def forward_linux_triton(
                4, #num_kv_splits # follow vLLM, fix it TODO
                self.softmax_scale,
                past_key_value.page_size)
-
+
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)
            attn_output = torch.matmul(attn_output, out_absorb.mT)
            attn_output = attn_output.transpose(1, 2)
-
+
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output)
-
+
            #print("attn_output", torch.isnan(attn_output).any())
            return attn_output, None, past_key_value
        else:
@@ -317,7 +316,7 @@ def forward_linux_triton(
            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-
+
            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -378,11 +377,11 @@ def forward_linux_flashinfer(
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
+
        cos, sin = self.rotary_emb(q_pe, position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-
+
        # decode
        if q_len == 1 or self.absorb_for_prefill:
            if past_key_value is not None:
@@ -401,7 +400,7 @@ def forward_linux_flashinfer(
            q_nope = q_nope.transpose(1, 2)
            q_nope = q_nope.contiguous()
            #assert q_nope.is_contiguous()
-
+
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            q_nope.squeeze_(0)
@@ -454,17 +453,17 @@ def forward_linux_flashinfer(
            )
            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
            """
-
+
            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)  # [bsz, self.num_heads, q_len, self.kv_lora_rank]
            attn_output = torch.matmul(attn_output, out_absorb.mT)  # [bsz, self.num_heads, q_len, self.v_head_dim]
            attn_output = attn_output.transpose(1, 2).contiguous()  # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-
+
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)  # [bsz, q_len, self.num_heads * self.v_head_dim]
            attn_output = self.o_proj(attn_output)
-
+
            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:
@@ -491,7 +490,7 @@ def forward_linux_flashinfer(
            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-
+
            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -511,7 +510,7 @@ def forward_linux_flashinfer(
            ).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value
-
+
    def forward_windows(
        self,
        hidden_states: torch.Tensor,
@@ -575,7 +574,7 @@ def forward_windows(
                attn_output = cur_output
            else:
                attn_output = torch.cat((attn_output, cur_output), dim=-2)
-
+
        return attn_output, None, past_key_value

    def forward(
@@ -589,8 +588,7 @@ def forward(
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         if os.name == 'nt' or get_compute_capability() < 8:
-             print("for Windows or GPU before ampere, use forward_windows")
+         if KTRANSFORMERS_USE_TORCH_NATIVE:
            return self.forward_windows(
                hidden_states,
                attention_mask,
@@ -602,7 +600,7 @@ def forward(
                **kwargs,
            )
        else:
-             if flashinfer_enabled:
+             if KTRANSFORMERS_USE_FLASHINFER:
                return self.forward_linux_flashinfer(
                    hidden_states,
                    attention_mask,
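The two flags used in the dispatch above are imported from the new ktransformers.util.feature_gate module, which is not shown in this diff. A minimal sketch of how such flags could be computed, assuming they simply reuse the checks the removed inline branch performed (os.name, get_compute_capability from ktransformers.util.utils, and flashinfer_enabled), is given below; the real module may differ.

# Hypothetical sketch of ktransformers/util/feature_gate.py; the actual module is not
# shown in this diff. The logic below only mirrors the inline checks removed from forward().
import os

from ktransformers.util.utils import get_compute_capability
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

# Fall back to the torch-native path (forward_windows) on Windows or pre-Ampere GPUs,
# matching the removed "os.name == 'nt' or get_compute_capability() < 8" check.
KTRANSFORMERS_USE_TORCH_NATIVE = (os.name == "nt") or (get_compute_capability() < 8)

# Route to the flashinfer MLA wrapper only when it is available and the
# torch-native fallback is not forced.
KTRANSFORMERS_USE_FLASHINFER = flashinfer_enabled and not KTRANSFORMERS_USE_TORCH_NATIVE

Centralizing the checks this way means the platform and GPU probing runs once at import time, and forward() only tests plain booleans.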