 '''
-Description :  
+Description :
 Author : Boxin Zhang
 Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import torch
 from torch import nn
@@ -16,17 +16,16 @@
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.util.utils import get_compute_capability
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 try:
     from flash_attn import flash_attn_func
 except:
     pass
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped 
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
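Note: the platform probing previously done inline (os.name, get_compute_capability(), device_manager.gpu_vendor) now lives behind two module-level flags imported from ktransformers.util.feature_gate. That module is not part of this diff; a minimal sketch of the idea, assuming it simply hoists the old inline checks to import time (every name besides the two flags is taken from the lines removed above), could look like this:

# Hypothetical sketch of ktransformers/util/feature_gate.py -- not the actual module,
# only an illustration of moving the old inline checks into import-time constants.
import os

from ktransformers.util.utils import get_compute_capability
from ktransformers.util.vendors import device_manager, GPUVendor
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

# Fall back to the torch-native path when the CUDA fast paths are unavailable.
KTRANSFORMERS_USE_TORCH_NATIVE = (
    os.name == "nt"
    or get_compute_capability() < 8
    or device_manager.gpu_vendor != GPUVendor.NVIDIA
)

# Use the flashinfer MLA wrapper only when it is installed and the fast path is active.
KTRANSFORMERS_USE_FLASHINFER = (not KTRANSFORMERS_USE_TORCH_NATIVE) and flashinfer_enabled

Centralizing the checks this way means every attention path consults the same flags, and the per-call branching in forward() reduces to a constant test.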
@@ -69,7 +68,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
             kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
             self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
             self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        
+
         return self.q_absorb, self.out_absorb

     def forward_chunck(
@@ -117,7 +116,7 @@ def forward_chunck(

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-            
+
             # compressed_kv [bsz, q_len, self.kv_lora_rank]
             # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
             k_pe = k_pe.transpose(1, 2)
@@ -128,7 +127,7 @@ def forward_chunck(
             )
             # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
             # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
-            
+
         q_absorb, out_absorb = self.get_absorbed()

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -142,9 +141,9 @@ def forward_chunck(
         #print(k_pe.shape)
         #print(q_nope.shape)
         #print(compressed_kv.shape)
-        
+
         attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
-        
+
         #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
         compressed_kv = compressed_kv.squeeze(1)
         """
@@ -172,10 +171,10 @@ def forward_chunck(
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.attention_dropout, training=self.training
         )
-        
+
         attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-        
-        attn_output = torch.matmul(attn_output, out_absorb.mT) 
+
+        attn_output = torch.matmul(attn_output, out_absorb.mT)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
             raise ValueError(
@@ -184,7 +183,7 @@ def forward_chunck(
             )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        
+
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

         attn_output = self.o_proj(attn_output)
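The forward_chunck hunks above implement the absorbed form of MLA attention: the scores are the sum of a RoPE term (q_pe against the shared k_pe) and a no-PE term taken directly against the compressed latent cache after folding the key-side projection into the query (q_absorb), and the latent-space context is mapped back to per-head values through out_absorb. A standalone shape sketch of that computation, with illustrative sizes and a simplified softmax scale (these are placeholders, not values from this file):

# Shape sketch of the absorbed MLA attention used in forward_chunck (illustrative sizes).
import torch

bsz, q_len, kv_len = 1, 4, 16
num_heads, kv_lora_rank = 8, 512
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)      # latent KV cache
k_pe = torch.randn(bsz, 1, kv_len, qk_rope_head_dim)        # shared RoPE key

q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5  # simplified

# Absorb the key projection into the query, then score against the latent cache.
q_nope = torch.matmul(q_nope, q_absorb)                      # [bsz, heads, q_len, kv_lora_rank]
attn_weights = (torch.matmul(q_pe, k_pe.mT)
                + torch.matmul(q_nope, compressed_kv.unsqueeze(1).mT)) * softmax_scale
attn_weights = torch.softmax(attn_weights, dim=-1)

# Attend in the latent space, then project back to the per-head value dimension.
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.matmul(attn_output, out_absorb.mT)       # [bsz, heads, q_len, v_head_dim]
assert attn_output.shape == (bsz, num_heads, q_len, v_head_dim)

The benefit of the absorption is that the cache holds only the kv_lora_rank latent plus the shared RoPE key per token, rather than full per-head keys and values.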
@@ -231,11 +230,11 @@ def forward_linux_triton(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-        
+
         # decode
         if q_len == 1:
             if past_key_value is not None:
@@ -252,20 +251,20 @@ def forward_linux_triton(
             q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
             q_nope = q_nope.transpose(1, 2)
             #assert q_nope.is_contiguous()
-            
+
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
             query_states = torch.cat([q_nope, q_pe], dim=-1)
-            
+
             query_states = query_states.squeeze(1)
             attn_output = torch.zeros_like(q_nope)  # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-            
+
             attn_logits = torch.empty(
                 (
                     bsz,
                     self.num_heads,
                     4, #num_kv_splits # follow vLLM, fix it TODO
-                    self.kv_lora_rank + 1,  
+                    self.kv_lora_rank + 1,
                 ),
                 dtype=torch.float32,
                 device=attn_output.device
@@ -286,16 +285,16 @@ def forward_linux_triton(
                 4, #num_kv_splits # follow vLLM, fix it TODO
                 self.softmax_scale,
                 past_key_value.page_size)
-            
+
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2)
             attn_output = torch.matmul(attn_output, out_absorb.mT)
             attn_output = attn_output.transpose(1, 2)
-            
+
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
             attn_output = self.o_proj(attn_output)
-            
+
             #print("attn_output", torch.isnan(attn_output).any())
             return attn_output, None, past_key_value
         else:
@@ -323,7 +322,7 @@ def forward_linux_triton(
             key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            
+
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -384,11 +383,11 @@ def forward_linux_flashinfer(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-        
+
         # decode
         if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
@@ -407,7 +406,7 @@ def forward_linux_flashinfer(
             q_nope = q_nope.transpose(1, 2)
             q_nope = q_nope.contiguous()
             #assert q_nope.is_contiguous()
-            
+
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
             q_nope.squeeze_(0)
@@ -460,17 +459,17 @@ def forward_linux_flashinfer(
             )
             attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
             """
-            
+
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2)  # [bsz, self.num_heads, q_len, self.kv_lora_rank]
             attn_output = torch.matmul(attn_output, out_absorb.mT)  # [bsz, self.num_heads, q_len, self.v_head_dim]
             attn_output = attn_output.transpose(1, 2).contiguous()  # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-            
+
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)  # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
-            
+
             return attn_output, None, past_key_value
         else:
             if past_key_value is not None:
@@ -497,7 +496,7 @@ def forward_linux_flashinfer(
             key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            
+
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -517,7 +516,7 @@ def forward_linux_flashinfer(
             ).contiguous()
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
-    
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
@@ -581,7 +580,7 @@ def forward_windows(
                 attn_output = cur_output
             else:
                 attn_output = torch.cat((attn_output, cur_output), dim=-2)
-            
+
         return attn_output, None, past_key_value

     def forward(
@@ -595,7 +594,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+        if KTRANSFORMERS_USE_TORCH_NATIVE:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -607,7 +606,7 @@ def forward(
                 **kwargs,
             )
         else:
-            if flashinfer_enabled:
+            if KTRANSFORMERS_USE_FLASHINFER:
                 return self.forward_linux_flashinfer(
                     hidden_states,
                     attention_mask,
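With the feature-gate flags in place, the dispatch in forward() reduces to constant tests. A condensed view of the resulting control flow (illustration only; the full argument lists appear in the hunks above, and the forward_linux_triton fallback sits below the lines shown here, so its exact call site is assumed):

# Condensed dispatch after this change -- not a verbatim excerpt of the file.
if KTRANSFORMERS_USE_TORCH_NATIVE:
    return self.forward_windows(hidden_states, attention_mask, position_ids,
                                past_key_value, output_attentions, use_cache,
                                cache_position, **kwargs)
if KTRANSFORMERS_USE_FLASHINFER:
    return self.forward_linux_flashinfer(hidden_states, attention_mask, position_ids,
                                         past_key_value, output_attentions, use_cache,
                                         cache_position, **kwargs)
return self.forward_linux_triton(hidden_states, attention_mask, position_ids,
                                 past_key_value, output_attentions, use_cache,
                                 cache_position, **kwargs)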