
Commit 0fb7a33

feat: Add feature gate
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent f46b3fd commit 0fb7a33

File tree: 4 files changed (+66 -53 lines)


ktransformers/local_chat.py (+9 -11)

@@ -1,12 +1,11 @@
"""
-Description :
+Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import os
-import platform
import sys

project_dir = os.path.dirname(os.path.dirname(__file__))
@@ -28,9 +27,9 @@
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_FLASHINFER
from ktransformers.server.config.config import Config
-from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -109,7 +108,7 @@ def local_chat(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
-
+
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except Exception as e:
@@ -126,8 +125,7 @@ def local_chat(
model.eval()
logging.basicConfig(level=logging.INFO)

-system = platform.system()
-if system == "Windows":
+if os.name == 'nt':
os.system("cls")
else:
os.system("clear")
@@ -155,7 +153,7 @@ def local_chat(
content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content):
content = open(content, "r").read()
-
+
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
@@ -168,8 +166,8 @@ def local_chat(
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
-
-if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+
+if KTRANSFORMERS_USE_FLASHINFER and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM"):
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim

ktransformers/operators/attention.py (+33 -35)

@@ -1,8 +1,8 @@
'''
-Description :
+Description :
Author : Boxin Zhang
Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
@@ -16,13 +16,12 @@
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.util.utils import get_compute_capability
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
-import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
@@ -63,7 +62,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-
+
return self.q_absorb, self.out_absorb

def forward_chunck(
@@ -111,7 +110,7 @@ def forward_chunck(

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
-
+
# compressed_kv [bsz, q_len, self.kv_lora_rank]
# k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
k_pe = k_pe.transpose(1,2)
@@ -122,7 +121,7 @@ def forward_chunck(
)
# k_pe [pages, page_size, 1, self.qk_rope_head_dim]
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
-
+
q_absorb, out_absorb = self.get_absorbed()

# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -136,9 +135,9 @@ def forward_chunck(
#print(k_pe.shape)
#print(q_nope.shape)
#print(compressed_kv.shape)
-
+
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
-
+
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
compressed_kv = compressed_kv.squeeze(1)
"""
@@ -166,10 +165,10 @@ def forward_chunck(
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
-
+
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-
-attn_output = torch.matmul(attn_output, out_absorb.mT)
+
+attn_output = torch.matmul(attn_output, out_absorb.mT)

if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
@@ -178,7 +177,7 @@ def forward_chunck(
)

attn_output = attn_output.transpose(1, 2).contiguous()
-
+
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

attn_output = self.o_proj(attn_output)
@@ -225,11 +224,11 @@ def forward_linux_triton(
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
+
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-
+
# decode
if q_len == 1:
if past_key_value is not None:
@@ -246,20 +245,20 @@ def forward_linux_triton(
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(1, 2)
#assert q_nope.is_contiguous()
-
+
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
query_states = torch.cat([q_nope, q_pe], dim=-1)
-
+
query_states = query_states.squeeze(1)
attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-
+
attn_logits = torch.empty(
(
bsz,
self.num_heads,
4, #num_kv_splits # follow vLLM, fix it TODO
-self.kv_lora_rank + 1,
+self.kv_lora_rank + 1,
),
dtype=torch.float32,
device = attn_output.device
@@ -280,16 +279,16 @@ def forward_linux_triton(
4, #num_kv_splits # follow vLLM, fix it TODO
self.softmax_scale,
past_key_value.page_size)
-
+
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2)
attn_output = torch.matmul(attn_output, out_absorb.mT)
attn_output = attn_output.transpose(1, 2)
-
+
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
-
+
#print("attn_output", torch.isnan(attn_output).any())
return attn_output, None, past_key_value
else:
@@ -317,7 +316,7 @@ def forward_linux_triton(
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-
+
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -378,11 +377,11 @@ def forward_linux_flashinfer(
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
+
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-
+
# decode
if q_len == 1 or self.absorb_for_prefill:
if past_key_value is not None:
@@ -401,7 +400,7 @@ def forward_linux_flashinfer(
q_nope = q_nope.transpose(1, 2)
q_nope = q_nope.contiguous()
#assert q_nope.is_contiguous()
-
+
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
q_nope.squeeze_(0)
@@ -454,17 +453,17 @@ def forward_linux_flashinfer(
)
attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
"""
-
+
# mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-
+
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
attn_output = self.o_proj(attn_output)
-
+
return attn_output, None, past_key_value
else:
if past_key_value is not None:
@@ -491,7 +490,7 @@ def forward_linux_flashinfer(
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-
+
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -511,7 +510,7 @@ def forward_linux_flashinfer(
).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
-
+
def forward_windows(
self,
hidden_states: torch.Tensor,
@@ -575,7 +574,7 @@ def forward_windows(
attn_output = cur_output
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)
-
+
return attn_output, None, past_key_value

def forward(
@@ -589,8 +588,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-if os.name == 'nt' or get_compute_capability()<8:
-print("for Windows or GPU before ampere, use forward_windows")
+if KTRANSFORMERS_USE_TORCH_NATIVE:
return self.forward_windows(
hidden_states,
attention_mask,
@@ -602,7 +600,7 @@ def forward(
**kwargs,
)
else:
-if flashinfer_enabled:
+if KTRANSFORMERS_USE_FLASHINFER:
return self.forward_linux_flashinfer(
hidden_states,
attention_mask,

ktransformers/operators/models.py (+7 -7)

@@ -1,13 +1,13 @@
#!/usr/bin/env python
# coding=utf-8
"""
-Description :
+Description :
Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
-LastEditors : Azure
+LastEditors : Azure
LastEditTime : 2024-08-27 07:29:04
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import inspect
@@ -56,8 +56,9 @@
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState, get_compute_capability
+from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
LlamaDecoderLayer,
@@ -625,7 +626,7 @@ def forward(
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
-
+
if inputs_embeds is None:
org_device = input_ids.device
# TODO move to embed_tokens's device, not hard code to cpu
@@ -649,8 +650,7 @@ def forward(
if per_layer_prefill_flag:
causal_mask = None
else:
-if os.name == 'nt' or get_compute_capability()<8:
-print("for Windows or GPU before ampere, use forward_windows")
+if KTRANSFORMERS_USE_TORCH_NATIVE:
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

ktransformers/util/feature_gate.py (new file, +17)

@@ -0,0 +1,17 @@
+import os
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.utils import get_compute_capability
+
+# Feature gate default values
+KTRANSFORMERS_USE_TORCH_NATIVE = False
+KTRANSFORMERS_USE_FLASHINFER = False
+
+if os.name == 'nt' or get_compute_capability() < 8:
+    print("Using torch native for Windows or Nvidia GPUs before Ampere.")
+    KTRANSFORMERS_USE_TORCH_NATIVE = True
+
+if not KTRANSFORMERS_USE_TORCH_NATIVE and flashinfer_enabled:
+    print("Using FlashInfer for Nvidia GPUs after Ampere.")
+    KTRANSFORMERS_USE_FLASHINFER = True
+
+print(f"Feature gate initialized: KTRANSFORMERS_USE_TORCH_NATIVE={KTRANSFORMERS_USE_TORCH_NATIVE}, KTRANSFORMERS_USE_FLASHINFER={KTRANSFORMERS_USE_FLASHINFER}")
