diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 1a5f55f9..65e739dc 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -1,12 +1,11 @@
 """
-Description : 
+Description :
 Author : Boxin Zhang, Azure-Tang
 Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 """
 
 import os
-import platform
 import sys
 
 project_dir = os.path.dirname(os.path.dirname(__file__))
@@ -28,10 +27,9 @@
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_FLASHINFER
 from ktransformers.server.config.config import Config
-from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 
 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -110,7 +108,7 @@ def local_chat(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
     optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
-    
+
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
     except Exception as e:
@@ -127,8 +125,7 @@ def local_chat(
     model.eval()
     logging.basicConfig(level=logging.INFO)
 
-    system = platform.system()
-    if system == "Windows":
+    if os.name == 'nt':
         os.system("cls")
     else:
         os.system("clear")
@@ -156,7 +153,7 @@ def local_chat(
             content = "Please write a piece of quicksort code in C++."
         elif os.path.isfile(content):
             content = open(content, "r").read()
-        
+
         messages = [{"role": "user", "content": content}]
         input_tensor = tokenizer.apply_chat_template(
             messages, add_generation_prompt=True, return_tensors="pt"
         )
@@ -169,8 +166,8 @@ def local_chat(
         if mode == 'long_context':
            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
            "please change max_seq_len in ~/.ktransformers/config.yaml"
-        
-        if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
+
+        if KTRANSFORMERS_USE_FLASHINFER and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM"):
             generated = prefill_and_generate(
                 model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
                 use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index db65f343..b362c381 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -1,8 +1,8 @@
 '''
-Description : 
+Description :
 Author : Boxin Zhang
 Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import torch
 from torch import nn
@@ -16,17 +16,16 @@
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.util.utils import get_compute_capability
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 
 try:
     from flash_attn import flash_attn_func
 except:
     pass
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped 
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -69,7 +68,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
         self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
         self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        
+
         return self.q_absorb, self.out_absorb
 
     def forward_chunck(
@@ -117,7 +116,7 @@ def forward_chunck(
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-            
+
             # compressed_kv [bsz, q_len, self.kv_lora_rank]
             # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
             k_pe = k_pe.transpose(1,2)
@@ -128,7 +127,7 @@
             )
             # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
             # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
-            
+
         q_absorb, out_absorb = self.get_absorbed()
 
         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -142,9 +141,9 @@
         #print(k_pe.shape)
         #print(q_nope.shape)
         #print(compressed_kv.shape)
-        
+
         attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
-        
+
         #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
         compressed_kv = compressed_kv.squeeze(1)
         """
@@ -172,10 +171,10 @@
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.attention_dropout, training=self.training
         )
-        
+
         attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-        
-        attn_output = torch.matmul(attn_output, out_absorb.mT) 
+
+        attn_output = torch.matmul(attn_output, out_absorb.mT)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
             raise ValueError(
@@ -184,7 +183,7 @@
             )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        
+
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
 
         attn_output = self.o_proj(attn_output)
@@ -231,11 +230,11 @@ def forward_linux_triton(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-        
+
         # decode
         if q_len == 1:
             if past_key_value is not None:
@@ -252,20 +251,20 @@ def forward_linux_triton(
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
             #assert q_nope.is_contiguous()
-            
+
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
             query_states = torch.cat([q_nope, q_pe], dim=-1)
-            
+
             query_states = query_states.squeeze(1)
             attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-            
+
             attn_logits = torch.empty(
                     (
                         bsz,
                         self.num_heads,
                         4, #num_kv_splits # follow vLLM, fix it TODO
-                        self.kv_lora_rank + 1, 
+                        self.kv_lora_rank + 1,
                     ),
                     dtype=torch.float32,
                     device = attn_output.device
@@ -286,16 +285,16 @@ def forward_linux_triton(
                     4, #num_kv_splits # follow vLLM, fix it TODO
                     self.softmax_scale,
                     past_key_value.page_size)
-            
+
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2)
             attn_output = torch.matmul(attn_output, out_absorb.mT)
             attn_output = attn_output.transpose(1, 2)
-            
+
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
             attn_output = self.o_proj(attn_output)
-            
+
             #print("attn_output", torch.isnan(attn_output).any())
             return attn_output, None, past_key_value
         else:
@@ -323,7 +322,7 @@ def forward_linux_triton(
             key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            
+
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
@@ -384,11 +383,11 @@ def forward_linux_flashinfer(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-        
+
         # decode
         if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
@@ -407,7 +406,7 @@ def forward_linux_flashinfer(
             q_nope = q_nope.transpose(1, 2)
             q_nope = q_nope.contiguous()
             #assert q_nope.is_contiguous()
-            
+
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
             q_nope.squeeze_(0)
@@ -460,17 +459,17 @@ def forward_linux_flashinfer(
             )
             attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
             """
-            
+
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
             attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
             attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-            
+
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
-            
+
             return attn_output, None, past_key_value
         else:
             if past_key_value is not None:
@@ -497,7 +496,7 @@ def forward_linux_flashinfer(
             key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-            
+
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
@@ -517,7 +516,7 @@ def forward_linux_flashinfer(
             ).contiguous()
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
-        
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
@@ -581,7 +580,7 @@ def forward_windows(
                 attn_output = cur_output
             else:
                 attn_output = torch.cat((attn_output, cur_output), dim=-2)
-        
+
         return attn_output, None, past_key_value
 
     def forward(
@@ -595,7 +594,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+        if KTRANSFORMERS_USE_TORCH_NATIVE:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -607,7 +606,7 @@ def forward(
                 **kwargs,
             )
         else:
-            if flashinfer_enabled:
+            if KTRANSFORMERS_USE_FLASHINFER:
                 return self.forward_linux_flashinfer(
                     hidden_states,
                     attention_mask,
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index bbac29a3..e8aa3220 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 # coding=utf-8
 """
-Description : 
+Description :
 Author : Azure-Tang
 Date : 2024-07-25 11:25:24
 Version : 1.0.0
-LastEditors : Azure 
+LastEditors : Azure
 LastEditTime : 2024-08-27 07:29:04
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
""" import inspect @@ -53,12 +53,12 @@ DeepseekV2DecoderLayer, DeepseekV2MoE, ) -from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.utils import InferenceState, get_compute_capability +from ktransformers.util.utils import InferenceState from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_llama import ( LlamaDecoderLayer, @@ -626,7 +626,7 @@ def forward( if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) - + if inputs_embeds is None: org_device = input_ids.device # TODO move to embed_tokens's device, not hard code to cpu @@ -650,8 +650,7 @@ def forward( if per_layer_prefill_flag: causal_mask = None else: - if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: - # print("for Windows or GPU before ampere, use forward_windows") + if KTRANSFORMERS_USE_TORCH_NATIVE: # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/ktransformers/util/feature_gate.py b/ktransformers/util/feature_gate.py new file mode 100644 index 00000000..07df2970 --- /dev/null +++ b/ktransformers/util/feature_gate.py @@ -0,0 +1,21 @@ +import os +from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled +from ktransformers.util.utils import get_compute_capability +from ktransformers.util.vendors import device_manager, GPUVendor + +# Feature gate default values +KTRANSFORMERS_USE_TORCH_NATIVE = False +KTRANSFORMERS_USE_FLASHINFER = False + +if os.name == "nt" or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: + print("Using torch native for Windows or Nvidia GPUs before Ampere.") + KTRANSFORMERS_USE_TORCH_NATIVE = True + +if not KTRANSFORMERS_USE_TORCH_NATIVE and flashinfer_enabled: + print("Using FlashInfer for Nvidia GPUs after Ampere.") + KTRANSFORMERS_USE_FLASHINFER = True + +print( + f"Feature gate initialized: KTRANSFORMERS_USE_TORCH_NATIVE={KTRANSFORMERS_USE_TORCH_NATIVE}," + f" KTRANSFORMERS_USE_FLASHINFER={KTRANSFORMERS_USE_FLASHINFER}" +)