Commit e0fe707

feat: Add feature gate
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 649489d commit e0fe707

File tree: 4 files changed, +71 -55 lines changed

ktransformers/local_chat.py (+9 -12)

@@ -1,12 +1,11 @@
 """
-Description : 
+Description :
 Author : Boxin Zhang, Azure-Tang
 Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 """

 import os
-import platform
 import sys

 project_dir = os.path.dirname(os.path.dirname(__file__))
@@ -28,10 +27,9 @@
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_FLASHINFER
 from ktransformers.server.config.config import Config
-from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -110,7 +108,7 @@ def local_chat(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
-    
+
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
     except Exception as e:
@@ -127,8 +125,7 @@ def local_chat(
     model.eval()
     logging.basicConfig(level=logging.INFO)

-    system = platform.system()
-    if system == "Windows":
+    if os.name == 'nt':
         os.system("cls")
     else:
         os.system("clear")
@@ -156,7 +153,7 @@ def local_chat(
             content = "Please write a piece of quicksort code in C++."
         elif os.path.isfile(content):
             content = open(content, "r").read()
-        
+
         messages = [{"role": "user", "content": content}]
         input_tensor = tokenizer.apply_chat_template(
             messages, add_generation_prompt=True, return_tensors="pt"
@@ -169,8 +166,8 @@ def local_chat(
         if mode == 'long_context':
             assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
             "please change max_seq_len in ~/.ktransformers/config.yaml"
-        
-        if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
+
+        if KTRANSFORMERS_USE_FLASHINFER and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM"):
             generated = prefill_and_generate(
                 model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
                 use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim

ktransformers/operators/attention.py (+34 -35)

@@ -1,8 +1,8 @@
 '''
-Description : 
+Description :
 Author : Boxin Zhang
 Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import torch
 from torch import nn
@@ -16,17 +16,16 @@
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.util.utils import get_compute_capability
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 try:
     from flash_attn import flash_attn_func
 except:
     pass
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped 
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -69,7 +68,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
             kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
             self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
             self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        
+
         return self.q_absorb, self.out_absorb

     def forward_chunck(
@@ -117,7 +116,7 @@ def forward_chunck(

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
-            
+
             # compressed_kv [bsz, q_len, self.kv_lora_rank]
             # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
             k_pe = k_pe.transpose(1,2)
@@ -128,7 +127,7 @@ def forward_chunck(
             )
             # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
             # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
-        
+
         q_absorb, out_absorb = self.get_absorbed()

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -142,9 +141,9 @@ def forward_chunck(
         #print(k_pe.shape)
         #print(q_nope.shape)
         #print(compressed_kv.shape)
-        
+
         attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
-        
+
         #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
         compressed_kv = compressed_kv.squeeze(1)
         """
@@ -172,10 +171,10 @@ def forward_chunck(
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.attention_dropout, training=self.training
         )
-        
+
         attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-        
-        attn_output = torch.matmul(attn_output, out_absorb.mT) 
+
+        attn_output = torch.matmul(attn_output, out_absorb.mT)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
             raise ValueError(
@@ -184,7 +183,7 @@ def forward_chunck(
             )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        
+
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

         attn_output = self.o_proj(attn_output)
@@ -231,11 +230,11 @@ def forward_linux_triton(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-        
+
         # decode
         if q_len == 1:
             if past_key_value is not None:
@@ -252,20 +251,20 @@ def forward_linux_triton(
                 q_nope = torch.matmul(q_nope, q_absorb) # batched MM
                 q_nope = q_nope.transpose(1, 2)
                 #assert q_nope.is_contiguous()
-                
+
                 # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
                 # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
                 query_states = torch.cat([q_nope, q_pe], dim=-1)
-                
+
                 query_states = query_states.squeeze(1)
                 attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-                
+
                 attn_logits = torch.empty(
                     (
                         bsz,
                         self.num_heads,
                         4, #num_kv_splits # follow vLLM, fix it TODO
-                        self.kv_lora_rank + 1, 
+                        self.kv_lora_rank + 1,
                     ),
                     dtype=torch.float32,
                     device = attn_output.device
@@ -286,16 +285,16 @@ def forward_linux_triton(
                     4, #num_kv_splits # follow vLLM, fix it TODO
                     self.softmax_scale,
                     past_key_value.page_size)
-                
+
                 # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
                 # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
                 attn_output = attn_output.transpose(1, 2)
                 attn_output = torch.matmul(attn_output, out_absorb.mT)
                 attn_output = attn_output.transpose(1, 2)
-                
+
                 attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
                 attn_output = self.o_proj(attn_output)
-                
+
                 #print("attn_output", torch.isnan(attn_output).any())
                 return attn_output, None, past_key_value
         else:
@@ -323,7 +322,7 @@ def forward_linux_triton(
             key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            
+
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -384,11 +383,11 @@ def forward_linux_flashinfer(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-        
+
         # decode
         if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
@@ -407,7 +406,7 @@ def forward_linux_flashinfer(
                 q_nope = q_nope.transpose(1, 2)
                 q_nope = q_nope.contiguous()
                 #assert q_nope.is_contiguous()
-                
+
                 # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
                 # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
                 q_nope.squeeze_(0)
@@ -460,17 +459,17 @@ def forward_linux_flashinfer(
                 )
                 attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
                 """
-                
+
                 # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
                 # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
                 # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
                 attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
                 attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
                 attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-                
+
                 attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
                 attn_output = self.o_proj(attn_output)
-                
+
                 return attn_output, None, past_key_value
         else:
             if past_key_value is not None:
@@ -497,7 +496,7 @@ def forward_linux_flashinfer(
             key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            
+
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -517,7 +516,7 @@ def forward_linux_flashinfer(
             ).contiguous()
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
-        
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
@@ -581,7 +580,7 @@ def forward_windows(
                 attn_output = cur_output
             else:
                 attn_output = torch.cat((attn_output, cur_output), dim=-2)
-        
+
         return attn_output, None, past_key_value

     def forward(
@@ -595,7 +594,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+        if KTRANSFORMERS_USE_TORCH_NATIVE:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -607,7 +606,7 @@ def forward(
                 **kwargs,
             )
         else:
-            if flashinfer_enabled:
+            if KTRANSFORMERS_USE_FLASHINFER:
                 return self.forward_linux_flashinfer(
                     hidden_states,
                     attention_mask,

ktransformers/operators/models.py (+7 -8)

@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 # coding=utf-8
 """
-Description : 
+Description :
 Author : Azure-Tang
 Date : 2024-07-25 11:25:24
 Version : 1.0.0
-LastEditors : Azure 
+LastEditors : Azure
 LastEditTime : 2024-08-27 07:29:04
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 """

 import inspect
@@ -53,12 +53,12 @@
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState, get_compute_capability
+from ktransformers.util.utils import InferenceState
 from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
     LlamaDecoderLayer,
@@ -626,7 +626,7 @@ def forward(
         if use_legacy_cache:
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
-        
+
         if inputs_embeds is None:
             org_device = input_ids.device
             # TODO move to embed_tokens's device, not hard code to cpu
@@ -650,8 +650,7 @@ def forward(
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
-                # print("for Windows or GPU before ampere, use forward_windows")
+            if KTRANSFORMERS_USE_TORCH_NATIVE:
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

ktransformers/util/feature_gate.py (new file, +21)

@@ -0,0 +1,21 @@
+import os
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.utils import get_compute_capability
+from ktransformers.util.vendors import device_manager, GPUVendor
+
+# Feature gate default values
+KTRANSFORMERS_USE_TORCH_NATIVE = False
+KTRANSFORMERS_USE_FLASHINFER = False
+
+if os.name == "nt" or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+    print("Using torch native for Windows or Nvidia GPUs before Ampere.")
+    KTRANSFORMERS_USE_TORCH_NATIVE = True
+
+if not KTRANSFORMERS_USE_TORCH_NATIVE and flashinfer_enabled:
+    print("Using FlashInfer for Nvidia GPUs after Ampere.")
+    KTRANSFORMERS_USE_FLASHINFER = True
+
+print(
+    f"Feature gate initialized: KTRANSFORMERS_USE_TORCH_NATIVE={KTRANSFORMERS_USE_TORCH_NATIVE},"
+    f" KTRANSFORMERS_USE_FLASHINFER={KTRANSFORMERS_USE_FLASHINFER}"
+)
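
Since feature_gate.py resolves both flags once at import time, every call site in this commit (local_chat.py, attention.py, models.py) reduces to a plain import plus a boolean check. A minimal usage sketch under that assumption; the pick_attention_backend helper below is illustrative only and not part of the commit:

from ktransformers.util.feature_gate import (
    KTRANSFORMERS_USE_TORCH_NATIVE,
    KTRANSFORMERS_USE_FLASHINFER,
)

def pick_attention_backend() -> str:
    # Mirrors the dispatch order introduced in attention.py's forward():
    # torch-native on Windows / non-NVIDIA / pre-Ampere GPUs, FlashInfer
    # when it is available, otherwise the Triton decode/prefill kernels.
    if KTRANSFORMERS_USE_TORCH_NATIVE:
        return "forward_windows"
    if KTRANSFORMERS_USE_FLASHINFER:
        return "forward_linux_flashinfer"
    return "forward_linux_triton"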
