feat: Add feature gate #798

Open
wants to merge 1 commit into base: main
21 changes: 9 additions & 12 deletions ktransformers/local_chat.py
@@ -1,12 +1,11 @@
"""
Description :
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import os
import platform
import sys

project_dir = os.path.dirname(os.path.dirname(__file__))
@@ -28,10 +27,9 @@
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.util.utils import prefill_and_generate
from ktransformers.util.feature_gate import KTRANSFORMERS_USE_FLASHINFER
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -110,7 +108,7 @@ def local_chat(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)

try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except Exception as e:
@@ -127,8 +125,7 @@ def local_chat(
model.eval()
logging.basicConfig(level=logging.INFO)

system = platform.system()
if system == "Windows":
if os.name == 'nt':
os.system("cls")
else:
os.system("clear")
@@ -156,7 +153,7 @@ def local_chat(
content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content):
content = open(content, "r").read()

messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
@@ -169,8 +166,8 @@
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:

if KTRANSFORMERS_USE_FLASHINFER and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM"):
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
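With the gate in place, the flashinfer path in local_chat.py is selected by a single flag plus an architecture check rather than the old inline os/compute-capability/vendor probe. A minimal sketch of that predicate, using the names from this PR (the DEEPSEEK_ARCHS helper set is illustrative, not part of the change):

```python
from ktransformers.util.feature_gate import KTRANSFORMERS_USE_FLASHINFER

# Illustrative helper set; the PR writes the two architecture checks out inline.
DEEPSEEK_ARCHS = {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}

def use_flashinfer_mla(config) -> bool:
    """True when generation should take the FlashInfer MLA path."""
    return KTRANSFORMERS_USE_FLASHINFER and config.architectures[0] in DEEPSEEK_ARCHS
```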
69 changes: 34 additions & 35 deletions ktransformers/operators/attention.py
@@ -1,8 +1,8 @@
'''
Description :
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
@@ -16,17 +16,16 @@
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import get_compute_capability
from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

try:
from flash_attn import flash_attn_func
except:
pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -69,7 +68,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)

return self.q_absorb, self.out_absorb

def forward_chunck(
@@ -117,7 +116,7 @@ def forward_chunck(

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models

# compressed_kv [bsz, q_len, self.kv_lora_rank]
# k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
k_pe = k_pe.transpose(1,2)
@@ -128,7 +127,7 @@
)
# k_pe [pages, page_size, 1, self.qk_rope_head_dim]
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]

q_absorb, out_absorb = self.get_absorbed()

# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -142,9 +141,9 @@
#print(k_pe.shape)
#print(q_nope.shape)
#print(compressed_kv.shape)

attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale

#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
compressed_kv = compressed_kv.squeeze(1)
"""
@@ -172,10 +171,10 @@ def forward_chunck(
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)

attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.matmul(attn_output, out_absorb.mT)

attn_output = torch.matmul(attn_output, out_absorb.mT)

if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
@@ -184,7 +183,7 @@
)

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

attn_output = self.o_proj(attn_output)
@@ -231,11 +230,11 @@ def forward_linux_triton(
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

# decode
if q_len == 1:
if past_key_value is not None:
@@ -252,20 +251,20 @@ def forward_linux_triton(
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(1, 2)
#assert q_nope.is_contiguous()

# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
query_states = torch.cat([q_nope, q_pe], dim=-1)

query_states = query_states.squeeze(1)
attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]

attn_logits = torch.empty(
(
bsz,
self.num_heads,
4, #num_kv_splits # follow vLLM, fix it TODO
self.kv_lora_rank + 1,
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device = attn_output.device
@@ -286,16 +285,16 @@ def forward_linux_triton(
4, #num_kv_splits # follow vLLM, fix it TODO
self.softmax_scale,
past_key_value.page_size)

# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2)
attn_output = torch.matmul(attn_output, out_absorb.mT)
attn_output = attn_output.transpose(1, 2)

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)

#print("attn_output", torch.isnan(attn_output).any())
return attn_output, None, past_key_value
else:
@@ -323,7 +322,7 @@ def forward_linux_triton(
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -384,11 +383,11 @@ def forward_linux_flashinfer(
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

# decode
if q_len == 1 or self.absorb_for_prefill:
if past_key_value is not None:
@@ -407,7 +406,7 @@ def forward_linux_flashinfer(
q_nope = q_nope.transpose(1, 2)
q_nope = q_nope.contiguous()
#assert q_nope.is_contiguous()

# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
q_nope.squeeze_(0)
@@ -460,17 +459,17 @@ def forward_linux_flashinfer(
)
attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
"""

# mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
else:
if past_key_value is not None:
@@ -497,7 +496,7 @@ def forward_linux_flashinfer(
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -517,7 +516,7 @@
).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

def forward_windows(
self,
hidden_states: torch.Tensor,
@@ -581,7 +580,7 @@ def forward_windows(
attn_output = cur_output
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)

return attn_output, None, past_key_value

def forward(
@@ -595,7 +594,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
if KTRANSFORMERS_USE_TORCH_NATIVE:
return self.forward_windows(
hidden_states,
attention_mask,
Expand All @@ -607,7 +606,7 @@ def forward(
**kwargs,
)
else:
if flashinfer_enabled:
if KTRANSFORMERS_USE_FLASHINFER:
return self.forward_linux_flashinfer(
hidden_states,
attention_mask,
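The forward() dispatch in attention.py now reads the two gates instead of re-probing the platform on every call. A sketch of the resulting three-way dispatch, assuming (since the hunk above is truncated) that the remaining branch still falls back to forward_linux_triton as before:

```python
from ktransformers.util.feature_gate import (
    KTRANSFORMERS_USE_TORCH_NATIVE,
    KTRANSFORMERS_USE_FLASHINFER,
)

def forward(self, hidden_states, attention_mask=None, **kwargs):
    # Windows, non-NVIDIA GPUs, or NVIDIA GPUs below compute capability 8.0
    if KTRANSFORMERS_USE_TORCH_NATIVE:
        return self.forward_windows(hidden_states, attention_mask, **kwargs)
    # Ampere-or-newer NVIDIA GPU with the flashinfer wrapper available
    if KTRANSFORMERS_USE_FLASHINFER:
        return self.forward_linux_flashinfer(hidden_states, attention_mask, **kwargs)
    # Otherwise: the Triton decode/prefill kernels (unchanged by this PR)
    return self.forward_linux_triton(hidden_states, attention_mask, **kwargs)
```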
15 changes: 7 additions & 8 deletions ktransformers/operators/models.py
@@ -1,13 +1,13 @@
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Description :
Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
LastEditors : Azure
LastEditTime : 2024-08-27 07:29:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import inspect
@@ -53,12 +53,12 @@
DeepseekV2DecoderLayer,
DeepseekV2MoE,
)
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState, get_compute_capability
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
LlamaDecoderLayer,
@@ -626,7 +626,7 @@ def forward(
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)

if inputs_embeds is None:
org_device = input_ids.device
# TODO move to embed_tokens's device, not hard code to cpu
@@ -650,8 +650,7 @@ def forward(
if per_layer_prefill_flag:
causal_mask = None
else:
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
# print("for Windows or GPU before ampere, use forward_windows")
if KTRANSFORMERS_USE_TORCH_NATIVE:
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
21 changes: 21 additions & 0 deletions ktransformers/util/feature_gate.py
@@ -0,0 +1,21 @@
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.utils import get_compute_capability
from ktransformers.util.vendors import device_manager, GPUVendor

# Feature gate default values
KTRANSFORMERS_USE_TORCH_NATIVE = False
KTRANSFORMERS_USE_FLASHINFER = False

if os.name == "nt" or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
print("Using torch native for Windows or Nvidia GPUs before Ampere.")
KTRANSFORMERS_USE_TORCH_NATIVE = True

if not KTRANSFORMERS_USE_TORCH_NATIVE and flashinfer_enabled:
print("Using FlashInfer for Nvidia GPUs after Ampere.")
KTRANSFORMERS_USE_FLASHINFER = True

print(
f"Feature gate initialized: KTRANSFORMERS_USE_TORCH_NATIVE={KTRANSFORMERS_USE_TORCH_NATIVE},"
f" KTRANSFORMERS_USE_FLASHINFER={KTRANSFORMERS_USE_FLASHINFER}"
)