diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py index e14a5214..ef35dad5 100644 --- a/ktransformers/models/modeling_deepseek.py +++ b/ktransformers/models/modeling_deepseek.py @@ -1,6 +1,6 @@ # coding=utf-8 ''' -Description : +Description : Author : Boxin Zhang Version : 0.1.0 ''' @@ -8,7 +8,7 @@ # https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/modeling_deepseek.py # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. -# +# # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared @@ -31,6 +31,7 @@ from typing import List, Optional, Tuple, Union import torch +from ktransformers.util.torch_auto_backend import CUDA import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -145,7 +146,7 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): @@ -322,7 +323,7 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos()* self._mscale sin = emb.sin()* self._mscale - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): @@ -1112,7 +1113,7 @@ def _flash_attention_forward( cache_seqlens=position_ids, softmax_scale=softmax_scale, causal=causal, - ) + ) else: attn_output = flash_attn_func( query_states, @@ -1557,7 +1558,7 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, @@ -1629,7 +1630,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type == CUDA and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index 952eed7b..5a34938c 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -23,6 +23,7 @@ from typing import List, Optional, Tuple, Union import torch +from ktransformers.util.torch_auto_backend import CUDA import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -1587,7 +1588,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type == CUDA and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/ktransformers/models/modeling_llama.py b/ktransformers/models/modeling_llama.py index 5271ed5e..69ecb5bd 100644 --- a/ktransformers/models/modeling_llama.py +++ b/ktransformers/models/modeling_llama.py @@ 
-21,6 +21,7 @@ from typing import List, Optional, Tuple, Union import torch +from ktransformers.util.torch_auto_backend import CUDA import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -709,7 +710,7 @@ def forward( # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: + if query_states.device.type == CUDA and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -1220,7 +1221,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type == CUDA and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/ktransformers/models/modeling_mixtral.py b/ktransformers/models/modeling_mixtral.py index 87d8cf1a..9f48c7f7 100644 --- a/ktransformers/models/modeling_mixtral.py +++ b/ktransformers/models/modeling_mixtral.py @@ -1,6 +1,6 @@ # coding=utf-8 ''' -Description : +Description : Author : kkk1nak0 Date : 2024-07-29 02:58:57 Version : 1.0.0 @@ -8,7 +8,7 @@ LastEditTime : 2024-08-02 06:08:34 ''' -# Adapted from +# Adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. @@ -31,11 +31,12 @@ # limitations under the License. """PyTorch Mixtral model.""" -import inspect +import inspect import math from typing import List, Optional, Tuple, Union import torch +from ktransformers.util.torch_auto_backend import CUDA import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -201,7 +202,7 @@ def extra_repr(self): class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - + self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -544,7 +545,7 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - + def _flash_attention_forward( self, @@ -575,9 +576,9 @@ def _flash_attention_forward( position of padding tokens and 1 for the position of non-padding tokens. dropout (`float`): Attention dropout - + """ - + # Decide whether to use SWA or not by layer index. # if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: # use_sliding_windows = False @@ -633,7 +634,7 @@ def _flash_attention_forward( cache_seqlens=position_ids, softmax_scale=softmax_scale, causal=is_causal, - ) + ) else: attn_output = flash_attn_func( query_states, @@ -766,7 +767,7 @@ def forward( # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. 
- if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == CUDA and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -1323,7 +1324,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type == CUDA and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/ktransformers/models/modeling_qwen2_moe.py b/ktransformers/models/modeling_qwen2_moe.py index 4c66cefe..14761711 100644 --- a/ktransformers/models/modeling_qwen2_moe.py +++ b/ktransformers/models/modeling_qwen2_moe.py @@ -1,14 +1,14 @@ # coding=utf-8 ''' -Description : +Description : Author : Boxin Zhang Version : 0.1.0 -''' +''' # Adapted from # https://github.com/huggingface/transformers/blob/v4.42.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. -# +# # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared @@ -32,6 +32,7 @@ from typing import List, Optional, Tuple, Union import torch +from ktransformers.util.torch_auto_backend import CUDA import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -636,7 +637,7 @@ def _flash_attention_forward( cache_seqlens=position_ids, softmax_scale=softmax_scale, causal=causal, - ) + ) else: attn_output = flash_attn_func( query_states, @@ -766,7 +767,7 @@ def forward( # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == CUDA and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -1314,7 +1315,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type == CUDA and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index adc1c5f5..46315d7d 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -1,8 +1,8 @@ """ -Description : +Description : Author : Boxin Zhang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
""" from torch import nn @@ -27,6 +27,7 @@ from ktransformers.util.utils import InferenceState from transformers.configuration_utils import PretrainedConfig import torch +from ktransformers.util.torch_auto_backend import CUDA # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): @@ -37,8 +38,8 @@ def __init__( config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", - generate_device: str = "cuda", - prefill_device: str = "cuda", + generate_device: str = CUDA, + prefill_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__( @@ -67,8 +68,8 @@ def __init__( config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", - generate_device: str = "cuda", - prefill_device: str = "cuda", + generate_device: str = CUDA, + prefill_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__( @@ -76,7 +77,7 @@ def __init__( ) self.generate_device = generate_device self.prefill_device = prefill_device - + @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] @@ -91,7 +92,7 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def load(self): self._init( @@ -117,8 +118,8 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - generate_device: str = "cuda", - prefill_device: str = "cuda", + generate_device: str = CUDA, + prefill_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__( @@ -155,8 +156,8 @@ def __init__( config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", - generate_device: str = "cuda", - prefill_device: str = "cuda", + generate_device: str = CUDA, + prefill_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__( @@ -225,8 +226,8 @@ def __init__( config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", - generate_device: str = "cuda", - prefill_device: str = "cuda", + generate_device: str = CUDA, + prefill_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__( @@ -234,7 +235,7 @@ def __init__( ) self.generate_device = generate_device self.prefill_device = prefill_device - + def load(self): kwargs = { key: self.config.rope_scaling[key] @@ -270,7 +271,7 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos()* self._mscale sin = emb.sin()* self._mscale - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def _init( self, @@ -332,8 +333,8 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - prefill_device: str = "cuda", - generate_device: str = "cuda", + prefill_device: str = CUDA, + generate_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__( diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index a9bbea6f..cc5e77d2 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -1,10 +1,11 @@ ''' -Description : +Description : Author : Boxin Zhang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
''' import torch +from ktransformers.util.torch_auto_backend import CUDA from torch import nn import warnings import torch.nn.functional as F @@ -46,8 +47,8 @@ def __init__(self, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - prefill_device: str = "cuda", - generate_device: str = "cuda", + prefill_device: str = CUDA, + generate_device: str = CUDA, chunck_size: int = 1000, absorb_for_prefill: bool = False, **kwargs): @@ -63,7 +64,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank) - + return self.q_absorb, self.out_absorb def forward_chunck( @@ -111,7 +112,7 @@ def forward_chunck( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - + # compressed_kv [bsz, q_len, self.kv_lora_rank] # k_pe [bsz, 1, q_len, self.qk_rope_head_dim] k_pe = k_pe.transpose(1,2) @@ -122,7 +123,7 @@ def forward_chunck( ) # k_pe [pages, page_size, 1, self.qk_rope_head_dim] # compressed_kv [pages, page_size, 1, self.kv_lora_rank] - + q_absorb, out_absorb = self.get_absorbed() # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] @@ -136,9 +137,9 @@ def forward_chunck( #print(k_pe.shape) #print(q_nope.shape) #print(compressed_kv.shape) - + attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale - + #attn_weights [bsz, self.num_heads, q_len, kv_seq_len] compressed_kv = compressed_kv.squeeze(1) """ @@ -166,10 +167,10 @@ def forward_chunck( attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) - + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) - - attn_output = torch.matmul(attn_output, out_absorb.mT) + + attn_output = torch.matmul(attn_output, out_absorb.mT) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( @@ -178,7 +179,7 @@ def forward_chunck( ) attn_output = attn_output.transpose(1, 2).contiguous() - + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) @@ -225,11 +226,11 @@ def forward_linux_triton( "with a layer index." 
) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - + cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2) # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] - + # decode if q_len == 1: if past_key_value is not None: @@ -246,20 +247,20 @@ def forward_linux_triton( q_nope = torch.matmul(q_nope, q_absorb) # batched MM q_nope = q_nope.transpose(1, 2) #assert q_nope.is_contiguous() - + # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] query_states = torch.cat([q_nope, q_pe], dim=-1) - + query_states = query_states.squeeze(1) attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank] - + attn_logits = torch.empty( ( bsz, self.num_heads, 4, #num_kv_splits # follow vLLM, fix it TODO - self.kv_lora_rank + 1, + self.kv_lora_rank + 1, ), dtype=torch.float32, device = attn_output.device @@ -280,16 +281,16 @@ def forward_linux_triton( 4, #num_kv_splits # follow vLLM, fix it TODO self.softmax_scale, past_key_value.page_size) - + # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] attn_output = attn_output.transpose(1, 2) attn_output = torch.matmul(attn_output, out_absorb.mT) attn_output = attn_output.transpose(1, 2) - + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) - + #print("attn_output", torch.isnan(attn_output).any()) return attn_output, None, past_key_value else: @@ -317,7 +318,7 @@ def forward_linux_triton( key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim) key_states[:, :, :, :self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) - + value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) @@ -378,11 +379,11 @@ def forward_linux_flashinfer( "with a layer index." 
) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - + cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2) # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] - + # decode if q_len == 1 or self.absorb_for_prefill: if past_key_value is not None: @@ -401,7 +402,7 @@ def forward_linux_flashinfer( q_nope = q_nope.transpose(1, 2) q_nope = q_nope.contiguous() #assert q_nope.is_contiguous() - + # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] q_nope.squeeze_(0) @@ -454,17 +455,17 @@ def forward_linux_flashinfer( ) attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank) """ - + # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank] # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank] attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim] attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank] - + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim] attn_output = self.o_proj(attn_output) - + return attn_output, None, past_key_value else: if past_key_value is not None: @@ -491,7 +492,7 @@ def forward_linux_flashinfer( key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim) key_states[:, :, :, :self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) - + value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) @@ -511,7 +512,7 @@ def forward_linux_flashinfer( ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value - + def forward_windows( self, hidden_states: torch.Tensor, @@ -575,7 +576,7 @@ def forward_windows( attn_output = cur_output else: attn_output = torch.cat((attn_output, cur_output), dim=-2) - + return attn_output, None, past_key_value def forward( @@ -634,8 +635,8 @@ def __init__(self, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - prefill_device: str = "cuda", - generate_device: str = "cuda", + prefill_device: str = CUDA, + generate_device: str = CUDA, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, diff --git a/ktransformers/operators/base_operator.py b/ktransformers/operators/base_operator.py index 0fa2efd2..c244582b 100644 --- a/ktransformers/operators/base_operator.py +++ b/ktransformers/operators/base_operator.py @@ -1,23 +1,24 @@ ''' -Description : +Description : Author : Boxin Zhang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
''' from typing import Any from torch import nn, Tensor +from ktransformers.util.torch_auto_backend import CUDA from ktransformers.util.custom_gguf import GGUFLoader from transformers.configuration_utils import PretrainedConfig import ktransformers.util.utils as utils class BaseInjectedModule(nn.Module): - + def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - prefill_device: str = "cuda", - generate_device: str = "cuda", + prefill_device: str = CUDA, + generate_device: str = CUDA, **kwargs): nn.Module.__init__(self) nn.Module.__setattr__(self, "orig_module", orig_module) @@ -27,10 +28,10 @@ def __init__(self, object.__setattr__(self, "prefill_device", prefill_device) object.__setattr__(self, "generate_device", generate_device) object.__setattr__(self, "device", generate_device) - + def __getattr__(self, name: str) -> Any: # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__, - # but __setattr__ in nn.Module call super().__setattr__ in that case, there may be some attribute set + # but __setattr__ in nn.Module call super().__setattr__ in that case, there may be some attribute set # but can't get using __getattr__, typically these attr is build in attr of the class, so class.attr does not # call __getattr__. # Example: @@ -54,10 +55,10 @@ def __setattr__(self, name: str, value: Tensor | nn.Module) -> None: elif hasattr(self, name): return object.__setattr__(self, name, value) return nn.Module.__getattr__(self, "orig_module").__setattr__(name, value) - + def forward(self, *args, **kwargs): return self.orig_module.forward(*args, **kwargs) - + def load(self): for name, child in self._modules.items(): utils.load_weights(child, self.gguf_loader, self.key+".") diff --git a/ktransformers/operators/dynamic_attention.py b/ktransformers/operators/dynamic_attention.py index 13a74b43..2000ad44 100644 --- a/ktransformers/operators/dynamic_attention.py +++ b/ktransformers/operators/dynamic_attention.py @@ -1,16 +1,17 @@ #!/usr/bin/env python # coding=utf-8 """ -Description : +Description : Author : Jianwei Dong Date : 2024-08-26 23:25:24 Version : 1.0.0 LastEditors : Jianwei Dong LastEditTime : 2024-08-26 23:25:24 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
""" import torch +from ktransformers.util.torch_auto_backend import CUDA from transformers import AutoConfig import sys, os import logging @@ -238,7 +239,7 @@ def get_attn_score_one_block( n_rep = self.q_head_num // self.kv_head_num importance = self.cache_importance.view(-1, self.q_head_num) importance = importance.narrow(0, batch_idx * max_block_num + offset, width) - n_gqa_ = self.q_head_num // self.kv_head_num + n_gqa_ = self.q_head_num // self.kv_head_num for head_idx in range(self.q_head_num): key_item = key[..., head_idx // n_gqa_, :].view(key.size(0), -1) qk = torch.einsum( @@ -258,7 +259,7 @@ def get_attn_score_one_block( qk = torch.nn.functional.softmax( qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32 ).to(torch.float16) - + qk = torch.sum(qk, dim=-2) importance[...,head_idx] += qk @@ -669,7 +670,7 @@ def apply( if layer_idx < self.dense_layer_num: self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True) self.cpu_infer.submit_with_cuda_stream( - torch.cuda.current_stream("cuda").cuda_stream, + torch.cuda.current_stream(CUDA).cuda_stream, self.local_thread.attn_with_kvcache( q_in=self.q_in_cpu, k_in=self.k_in_cpu, @@ -706,7 +707,7 @@ def apply( ) # print("submit_with_cuda_stream") self.cpu_infer.submit_with_cuda_stream( - torch.cuda.current_stream("cuda").cuda_stream, + torch.cuda.current_stream(CUDA).cuda_stream, self.local_thread.attn_with_kvcache( q_in=self.q_in_cpu, k_in=self.k_in_cpu, @@ -731,7 +732,7 @@ def apply( self.prefix_block_table, non_blocking=True ) self.cpu_infer.submit_with_cuda_stream( - torch.cuda.current_stream("cuda").cuda_stream, + torch.cuda.current_stream(CUDA).cuda_stream, self.local_thread.attn_with_kvcache( q_in=self.q_in_cpu, k_in=self.k_in_cpu, @@ -747,7 +748,7 @@ def apply( ), ) self.cpu_infer.sync_with_cuda_stream( - torch.cuda.current_stream("cuda").cuda_stream + torch.cuda.current_stream(CUDA).cuda_stream ) # print("submit_with_cuda_stream finished\n") self.output_cuda.copy_(self.output_cpu, non_blocking=True) diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 88960c70..478b5b45 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -1,13 +1,13 @@ #!/usr/bin/env python # coding=utf-8 ''' -Description : +Description : Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-25 11:25:24 Version : 0.1.0 -LastEditors : Azure +LastEditors : Azure LastEditTime : 2024-08-29 09:41:10 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
''' from typing import Any, Union @@ -16,6 +16,7 @@ from torch import Tensor, nn import torch.nn.functional as F import torch +from ktransformers.util.torch_auto_backend import CUDA import sys, os from ktransformers.operators.base_operator import BaseInjectedModule from tqdm import tqdm @@ -39,13 +40,13 @@ # class Base(BaseInjectedModule, ABC): class KExpertsBase(ABC): - def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs): + def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = CUDA, **kwargs): # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.key = key self.gguf_loader = gguf_loader self.config = config self.device = device - + @abstractmethod def forward(self, input_tensor, expert_ids, weights): pass @@ -53,7 +54,7 @@ def forward(self, input_tensor, expert_ids, weights): @abstractmethod def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False): pass - + @abstractmethod def unload(): pass @@ -83,7 +84,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info: - # for supporting Mixtral-8x7B-Instuct + # for supporting Mixtral-8x7B-Instuct gate = [] up = [] down = [] @@ -105,7 +106,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} return res - + def load_multi(self, key: str, keys: list[str], device: str = "cpu"): tensors = {} for k in keys: @@ -129,7 +130,7 @@ def __init__( n_routed_experts: int, orig_module: nn.Module = None, device: str = "cpu", - out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu. + out_device: str = CUDA, # this device mean which device the output should on. TODO: support cpu. 
**kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) @@ -192,13 +193,13 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = N KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) - + def submit_for_one_decode(self, input_tensor, expert_ids, weights): KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr())) - + def sync_for_one_decode(self): self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) @@ -225,7 +226,7 @@ def forward(self, input_tensor, expert_ids, weights): self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr())) self.cpu_infer.sync() return output.to(device=object.__getattribute__(self, "out_device")) - + def unload(self): return @@ -253,7 +254,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item() up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item() down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item() - + elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") @@ -262,7 +263,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info: - # for supporting Mixtral-8x7B-Instuct + # for supporting Mixtral-8x7B-Instuct gate = [] up = [] down = [] @@ -283,7 +284,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} return res - + class KExpertsMarlin(KExpertsBase): expert_num: int loaded_experts_idx: list[int] @@ -294,7 +295,7 @@ def __init__( config: PretrainedConfig, n_routed_experts: int, orig_module: nn.Module = None, - device: str = "cuda", + device: str = CUDA, **kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) @@ -329,7 +330,7 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, 
device=self.device) gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device) down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device) - + self.up_projs[i].load(nn.Parameter(up_weights), device=device) self.gate_projs[i].load(nn.Parameter(gate_weights), device=device) self.down_projs[i].load(nn.Parameter(down_weights), device=device) @@ -344,7 +345,7 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device) self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device) self.loaded_experts_idx.append(i) - return + return def unload(self): for i in self.loaded_experts_idx: @@ -378,7 +379,7 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T hidden_states_cpu = hidden_states_cpu.to(self.device) selected_experts_cpu = selected_experts_cpu.to(self.device) routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype) - + batch_sequence_length, hidden_dim = hidden_states_cpu.size() final_hidden_states = torch.zeros( @@ -405,9 +406,9 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states) - + return final_hidden_states.to(dtype=org_dtype, device=org_device) - + # untested, CUDA OOM class KExpertsTorch(KExpertsBase): expert_num: int @@ -448,7 +449,7 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device) gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device) down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device) - + self.up[i] = up_weights self.gate[i] = gate_weights self.down[i] = down_weights @@ -458,11 +459,11 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype) self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) - + self.up = torch.stack(self.up, dim=0) self.gate = torch.stack(self.gate, dim=0) self.down = torch.stack(self.down, dim=0) - return + return def unload(self): if self.gate is not None: @@ -495,7 +496,7 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T hidden_states_cpu = hidden_states_cpu.to(self.device) selected_experts_cpu = selected_experts_cpu.to(self.device) routing_weights_cpu = routing_weights_cpu.to(self.device) - + batch_sequence_length, hidden_dim = hidden_states_cpu.size() final_hidden_states = torch.zeros( @@ -540,7 +541,7 @@ def __init__(self, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", - prefill_device:str = "cuda", + prefill_device:str = CUDA, prefill_op: str | None = "KExpertsTorch", generate_device: str = "cpu", generate_op: str | None = "KExpertsCPU", @@ -628,7 +629,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: routing_weights /= 
routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) shared_expert_output = self.shared_expert(hidden_states) @@ -637,7 +638,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: y += shared_expert_output y.resize_(*orig_shape) return y, router_logits - + hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu() selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu() routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu() @@ -666,7 +667,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: y += shared_expert_output y.resize_(*orig_shape) return y, router_logits - + @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) @@ -685,11 +686,11 @@ def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu expert = self.experts[selected_experts_cpu[token_idx, expert_idx]] outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx] return outs - + @torch.no_grad() # TODO may bugs here def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor: - + batch_size, sequence_length, hidden_dim = orig_shape final_hidden_states = torch.zeros( @@ -724,7 +725,7 @@ def forward(self, hidden_states): sequence_length = orig_shape[1] topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: @@ -736,7 +737,7 @@ def forward(self, hidden_states): if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) - + if isinstance(self.experts, KExpertsBase): y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: @@ -816,14 +817,14 @@ def moe_infer(self, x, topk_ids, topk_weight): return final_out class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): - + def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - + # only for generate phase if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) @@ -836,7 +837,7 @@ def forward(self, hidden_states): if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) - + if isinstance(self.experts, 
KExpertsBase): y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: @@ -916,7 +917,7 @@ def moe_infer(self, x, topk_ids, topk_weight): return final_out class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock): - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ orig_shape = hidden_states.shape @@ -932,13 +933,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y.resize_(*orig_shape) return y, router_logits - + hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu() selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu() routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu() @@ -959,10 +960,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: y = self.moe_infer_simple( hidden_states_expert, selected_experts_expert, routing_weights_expert ).to(device=hidden_states.device) - + y.resize_(*orig_shape) return y, router_logits - + @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) @@ -981,11 +982,11 @@ def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu expert = self.experts[selected_experts_cpu[token_idx, expert_idx]] outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx] return outs - + @torch.no_grad() # TODO may bugs here def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor: - + batch_size, sequence_length, hidden_dim = orig_shape final_hidden_states = torch.zeros( diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py index a7028726..35554720 100644 --- a/ktransformers/operators/flashinfer_wrapper.py +++ b/ktransformers/operators/flashinfer_wrapper.py @@ -6,6 +6,7 @@ import torch import os from ktransformers.operators.triton_attention import decode_attention_fwd_grouped +from ktransformers.util.torch_auto_backend import CUDA flashinfer_enabled = False @@ -13,7 +14,7 @@ import flashinfer flashinfer_enabled = True print("found flashinfer") - + except ImportError: print("flashinfer not found, use triton for linux") @@ -72,7 +73,7 @@ def __init__(self, max_batch_size, max_pages, use_cuda_graph = True, - device = "cuda", + device = CUDA, ): self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device) self.max_batch_size = max_batch_size @@ -101,7 +102,7 @@ def __init__(self, kv_len_arr=self.kv_len_arr_buf, ) self.need_plan = True - + def plan(self, qo_indptr, kv_indptr, @@ -124,7 +125,7 @@ def plan(self, if kv_indices is None: assert self.max_batch_size == 1 kv_indices = 
self.kv_indices_buf - + self.wrapper.plan( qo_indptr, kv_indptr, @@ -151,7 +152,7 @@ def get_instance(cls, device, *args, **kwargs)->MLAWrapper: if device not in cls.wrappers: cls.make_instance(device, *args, **kwargs) return cls.wrappers[device] - + @classmethod def make_instance(cls, device, *args, **kwargs): cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device) @@ -182,17 +183,17 @@ def plan_all(cls, qo_indptr, q_data_type, kv_data_type,) wrapper.need_plan = False - + @classmethod def need_plan_all(cls): for device, wrapper in cls.wrappers.items(): wrapper.need_plan = True - + @classmethod def reset_buffer(cls): for device, wrapper in cls.wrappers.items(): wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here. - + @classmethod def update_buffer(cls, max_pages): for device, wrapper in cls.wrappers.items(): @@ -205,7 +206,7 @@ def checksame(): flashinfer_folder = "./kv_cache_flashinfer" triton_folder = "./triton_output" triton_folder = "./kv_cache_triton" - + max_layer_id = 1 max_forward_id = 2 @@ -216,18 +217,18 @@ def checksame(): #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt" #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt" file_name = f"layer_{layer_id}.pt" - + flashinfer_path = os.path.join(flashinfer_folder, file_name) triton_path = os.path.join(triton_folder, file_name) - + if not os.path.exists(triton_path): print(f"{file_name} not exist in {triton_folder}") continue if not os.path.exists(flashinfer_path): print(f"{file_name} not exist in {flashinfer_folder}") continue - - + + flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]# triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)# try: @@ -236,7 +237,7 @@ def checksame(): print(e) if __name__ == "__main__": - + #checksame() #exit(0) @@ -244,24 +245,24 @@ def checksame(): max_pages = 64 page_size = 64 num_heads = 128 - + # warm-up kv_len = 4023 q_len = 1 - q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda") - q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda") - kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda") + q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device=CUDA) + q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device=CUDA) + kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device=CUDA) ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1) - + wrapper = MLAWrapperSingleton.get_instance( - "cuda", + CUDA, max_batch_size, max_pages, ) - - kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda") - qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda") + + kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device=CUDA) + qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=CUDA) wrapper.plan( qo_indptr, None, @@ -278,7 +279,7 @@ def checksame(): attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe) print(attn_output.shape) - + graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe) @@ -298,17 +299,17 @@ def checksame(): q_len = 1 kv_len = 126 file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt" - q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda") + q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device=CUDA) file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt" - q_pe = 
torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda") + q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device=CUDA) q = torch.cat([q_nope, q_pe], dim=-1) - kv_cache = torch.load(kv_cache_path).to(device="cuda") + kv_cache = torch.load(kv_cache_path).to(device=CUDA) pages, page_size, _, head_dim = kv_cache.shape kv_cache = kv_cache.view(pages, page_size, head_dim) ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1) - - kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda") - qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda") + + kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device=CUDA) + qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=CUDA) wrapper.plan( None, None, @@ -322,7 +323,7 @@ def checksame(): torch.bfloat16, torch.bfloat16, ) - + q_nope_buf.copy_(q_nope) q_pe_buf.copy_(q_pe) kv_buf[:pages].copy_(kv_cache) @@ -347,21 +348,21 @@ def checksame(): 192 ** (-0.5) ) torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3) - + # ref_triton attn_logits = torch.empty( ( max_batch_size, num_heads, 4, #num_kv_splits # follow vLLM, fix it TODO - 512 + 1, + 512 + 1, ), dtype=torch.float32, - device = "cuda" + device = CUDA ) - + triton_ref = torch.zeros_like(q_nope) - page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda") + page_table = torch.arange(max_pages, dtype=torch.int32, device=CUDA) ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576) ckv = ckv.view(pages, page_size, 1, 512) decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref, @@ -372,9 +373,9 @@ def checksame(): page_size) torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3) - + #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt" #ktrans_output = torch.load(file_name) #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3) print("test past") - + diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index d9080939..9c3fae21 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -5,6 +5,7 @@ from torch import Tensor, nn import torch.nn.functional as F import torch +from ktransformers.util.torch_auto_backend import CUDA import sys, os from ktransformers.operators.base_operator import BaseInjectedModule @@ -24,12 +25,12 @@ # class Base(BaseInjectedModule, ABC): class KMoEGateBase(ABC): - def __init__(self, - key: str, - gguf_loader: GGUFLoader, - config: PretrainedConfig, - orig_module: nn.Module, - device: str = "cuda", + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = CUDA, **kwargs): # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) super().__init__() @@ -38,7 +39,7 @@ def __init__(self, self.config = config self.device = device self.orig_module = orig_module - + @abstractmethod def forward(self, input_tensor, expert_ids, weights): pass @@ -46,7 +47,7 @@ def forward(self, input_tensor, expert_ids, weights): @abstractmethod def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False): pass - + @abstractmethod def unload(): pass @@ -69,7 +70,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): key = ".".join(key.split(".")[:-1]) if self.gguf_loader.safetensor_loader is not None: targets = 
[".ffn_gate_inp.weight", ".exp_probs_b.bias"] - weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight") + weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight") e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias") weight_type = weight.dtype e_score_correction_bias_type = e_score_correction_bias.dtype @@ -85,7 +86,7 @@ def load_weights(self, override_key: str | None = None, device: str = "cpu"): raise ValueError(f"Experts {key} not found in gguf_loader") res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} return res - + def load_multi(self, key: str, keys: list[str], device: str = "cpu"): tensors = {} for k in keys: @@ -100,8 +101,8 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, - prefill_device: str = "cuda", - generate_device: str = "cuda", + prefill_device: str = CUDA, + generate_device: str = CUDA, **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) @@ -115,7 +116,7 @@ def forward(self, hidden_states) -> torch.Tensor: def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device if w is None: w = self.load_weights(device=device) - + if isinstance(w, dict): self.weight_type = w["weight_type"] self.e_score_correction_bias_type = w["e_score_correction_bias_type"] diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 103fc1a3..41bbd250 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -1,20 +1,21 @@ #!/usr/bin/env python # coding=utf-8 ''' -Description : +Description : Author : Azure-Tang, Boxin Zhang Date : 2024-07-25 11:25:24 Version : 0.1.0 -LastEditors : Azure +LastEditors : Azure LastEditTime : 2024-08-29 09:11:16 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
''' import ctypes import torch +from ktransformers.util.torch_auto_backend import CUDA from torch import Tensor, nn -import KTransformersOps +import KTransformersOps from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.utils import InferenceState from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( @@ -44,7 +45,7 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, - device: str = "cuda", + device: str = CUDA, **kwargs, ): # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) @@ -84,7 +85,7 @@ def load_weight(self, override_key: str | None = None, device: str | None = None tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight') weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv') return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) - + elif key + ".weight" in self.gguf_loader.tensor_file_map: if key + ".bias" in self.gguf_loader.tensor_file_map: tensors = self.load_multi(key, ["weight", "bias"], device=device) @@ -108,7 +109,7 @@ def load_multi(self, key: str, keys: list[str], device: str = "cpu"): return tensors @abstractmethod - def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"): + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = CUDA): pass @abstractmethod @@ -123,7 +124,7 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, - device: str = "cuda", + device: str = CUDA, **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) @@ -148,11 +149,11 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = if device is None: device = self.device if w is None: w = self.load_weight(device=device) # else: self.out_features = w.shape[0], self.in_features = w.shape[1] - + if isinstance(w, nn.Parameter): try: self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T - except: + except: self.weight = w.to(dtype=self.dtype).T self.has_bias = False elif isinstance(w, tuple): @@ -193,7 +194,7 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, - device: str = "cuda", + device: str = CUDA, block_size: int = 128, **kwargs, ): @@ -201,18 +202,18 @@ def __init__( self.has_bias = False self.dtype = torch.get_default_dtype() self.block_size = block_size - + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.to(self.device) - orig_dtype = x.dtype + orig_dtype = x.dtype x_quantized, scale_x = act_quant(x, self.block_size) y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv) return y.to(dtype=orig_dtype) - + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device - if w is None: - w = self.load_weight(device=device) + if w is None: + w = self.load_weight(device=device) ### TODO fit weight_inv format if isinstance(w, tuple): self.weight = w[0].to(device) @@ -223,14 +224,14 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = self.weight = self.weight.to(device) if self.has_bias: self.bias = self.bias.to(device) - + def unload(self): if self.weight is not None: self.weight = None if self.has_bias: self.bias = None - - + + class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor @@ -243,7 +244,7 @@ def __init__( gguf_loader: GGUFLoader, 
config: PretrainedConfig, orig_module: nn.Module = None, - device: str = "cuda", + device: str = CUDA, num_bits: int = 4, # 4-bit/8-bit is supported group_size: int = 64, # -1, 32, 64, 128 act_order: bool = False, @@ -265,7 +266,7 @@ def __init__( self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N #print(f"After padding: in_features={in_features}, out_features={out_features}") - + self.k = self.in_features self.n = self.out_features @@ -273,10 +274,10 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = if self.loaded: return if device is None: device = self.device assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" - + #if self.in_features * self.out_features: - if w is None: - w = self.load_weight(device=device) + if w is None: + w = self.load_weight(device=device) if isinstance(w, nn.Parameter): # pad weight @@ -293,7 +294,7 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = weight = weight.to(device) if self.has_bias: self.bias = self.bias.to(device) - + if self.padding: padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device) padded_weight[:self.orin_in_features, :self.orin_out_features] = weight @@ -368,7 +369,7 @@ def __init__( config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cpu", - out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu. + out_device: str = CUDA, # this device mean which device the output should on. TODO: support cpu. stride = 16, group_max_len = 1024, **kwargs, @@ -391,8 +392,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: KLinearCPUInfer.CPU_INFER.submit_with_cuda_stream( torch.cuda.current_stream().cuda_stream, self.linear.forward( - qlen, - self.input_tensor_cpu.data_ptr(), + qlen, + self.input_tensor_cpu.data_ptr(), self.output_cpu.data_ptr() ) ) @@ -410,8 +411,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = torch.empty(output_shape, device=x.device, dtype=x.dtype) KLinearCPUInfer.CPU_INFER.submit( self.linear.forward( - qlen, - x.data_ptr(), + qlen, + x.data_ptr(), output.data_ptr() ) ) @@ -428,13 +429,13 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = if self.bias is not None: self.has_bias = True self.bias = self.bias.to(device) - + weight_ptr = ctypes.addressof( ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30) self.linear = cpuinfer_ext.linear.Linear(config) - + if warmup: KLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up()) KLinearCPUInfer.CPU_INFER.sync() @@ -459,7 +460,7 @@ def unload(self): if self.w is not None: self.w = None if self.has_bias: - self.bias = None + self.bias = None LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, @@ -476,9 +477,9 @@ def __init__( config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", - generate_device: str = "cuda", + generate_device: str = CUDA, generate_op: str| None = "KLinearMarlin", - prefill_device: str = "cuda", + prefill_device: str = CUDA, prefill_op: str| None = "KLinearTorch", **kwargs, ): @@ -537,7 +538,7 @@ def unload(self): self.device = self.generate_linear.device def 
set_inference_mode(self, mode: InferenceState): - if not mode: + if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE) diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 57d4bea0..c2bd7325 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -1,13 +1,13 @@ #!/usr/bin/env python # coding=utf-8 """ -Description : +Description : Author : Azure-Tang Date : 2024-07-25 11:25:24 Version : 1.0.0 -LastEditors : Azure +LastEditors : Azure LastEditTime : 2024-08-27 07:29:04 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import inspect @@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union import time import torch +from ktransformers.util.torch_auto_backend import CUDA import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -189,7 +190,7 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + device: str = CUDA, per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, @@ -281,7 +282,7 @@ def forward( if inputs_embeds is None: input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds.to("cuda") + inputs_embeds = inputs_embeds.to(CUDA) if cache_position is None: past_seen_tokens = ( @@ -427,7 +428,7 @@ def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState): ), "module should be nn.ModuleList of decoder layers" # TODO Support restore to original device, not only cuda - device = "cpu" if target == InferenceState.UNLOAD else "cuda" + device = "cpu" if target == InferenceState.UNLOAD else CUDA # attn layer.self_attn.q_proj.set_inference_mode(target) @@ -539,7 +540,7 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + device: str = CUDA, per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, @@ -625,7 +626,7 @@ def forward( if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) - + if inputs_embeds is None: org_device = input_ids.device # TODO move to embed_tokens's device, not hard code to cpu @@ -801,7 +802,7 @@ def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState): ), "module should be nn.ModuleList of decoder layers" # TODO Support restore to original device, not only cuda - device = "cpu" if target == InferenceState.UNLOAD else "cuda" + device = "cpu" if target == InferenceState.UNLOAD else CUDA # TODO Support DFS to auto use {to, set_inference_mode} according to the module type @@ -961,7 +962,7 @@ def __init__( gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, - device: str = "cuda", + device: str = CUDA, per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, @@ -985,7 +986,7 @@ def __init__( max_seq_len=self.long_context_config["max_seq_len"], block_size=self.long_context_config["block_size"], config=config, - device=torch.device("cuda"), + device=torch.device(CUDA), local_windows_len=self.long_context_config["local_windows_len"], topk=self.long_context_config["second_select_num"], threads_num=self.ext_config["cpu_infer"], 
@@ -1066,7 +1067,7 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], - device="cuda", + device=CUDA, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1094,7 +1095,7 @@ def forward( return_dict, ) elif q_len <= chunck_size: - inputs_embeds = inputs_embeds.to('cuda') + inputs_embeds = inputs_embeds.to(CUDA) output = self.forward_chunk( inputs_embeds, causal_mask, @@ -1120,7 +1121,7 @@ def forward( print(f'current prefill length: {cur_idx}') chunk_mask = None if inputs_embeds.device.type == 'cpu': - tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda") + tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to(CUDA) else: tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)] output_with_past = self.forward_chunk( @@ -1337,7 +1338,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type == CUDA and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 331e6cf9..46e3e4e1 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -1,11 +1,12 @@ ''' -Description : +Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from typing import Mapping, List import torch +from ktransformers.util.torch_auto_backend import CUDA0 import yaml import re from torch import nn @@ -52,7 +53,7 @@ def del_meta(module:nn.Module): for name, child in module._modules.items(): del_meta(child) -def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"): +def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str = "", default_device: str = CUDA0): module_name = prefix[:-1] translated_name = translate_name_to_gguf(prefix)[:-1] #print("gen_optimize_config", prefix, module_name, translated_name) @@ -87,7 +88,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p if "recursive" in rule: recursive = bool(rule["recursive"]) break - + if module_name not in out_data: out_data[module_name]= { "class": "default", @@ -104,23 +105,23 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p if child is not None: child_prefix = prefix + name + "." 
gen_optimize_config(child, out_data, rule_list, child_prefix) - + def translate_model_config(model_config: PretrainedConfig): - # for supporting some special model + # for supporting some special model if model_config.model_type == "mixtral": model_config.moe_intermediate_size = model_config.intermediate_size - + return model_config -def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"): +def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = CUDA0): with open(rule_file, 'r', encoding='utf-8') as f: rule_list = yaml.load(f.read(), Loader=yaml.FullLoader) - + optimize_config = dict() gen_optimize_config(module, optimize_config, rule_list, default_device = default_device) - + model_config = translate_model_config(model_config) gguf_loader=GGUFLoader(gguf_path) diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 1752a3c6..aeffc5c1 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -1,4 +1,5 @@ import torch +from ktransformers.util.torch_auto_backend import CUDA, CUDA0 import asyncio from transformers import AutoTokenizer, AutoConfig, GenerationConfig from ktransformers.server.backend.interfaces.transformers import ( @@ -20,6 +21,7 @@ warm_uped = False + class KTransformersThreadContext(TransformersThreadContext): pass @@ -28,7 +30,9 @@ class KTransformersInterface(TransformersInterface): def __init__(self, args: ConfigArgs = default_args): self.args = args torch.set_grad_enabled(False) - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code) + self.tokenizer = AutoTokenizer.from_pretrained( + args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code + ) config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code) try: generation_config = GenerationConfig.from_pretrained(args.model_dir) @@ -39,7 +43,7 @@ def __init__(self, args: ConfigArgs = default_args): top_p=args.top_p, do_sample=True ) - + torch.set_default_dtype(config.torch_dtype) if config.architectures[0] == "Qwen2MoeForCausalLM": config._attn_implementation = "flash_attention_2" @@ -71,7 +75,6 @@ def __init__(self, args: ConfigArgs = default_args): dtype=self.model.dtype, ) # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}") - if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.streamer = TextStreamer(self.tokenizer) @@ -83,7 +86,7 @@ def decode_one_tokens(self): device_map = self.model.gguf_loader.tensor_device_map torch_device = get_device("blk.0.self_attn", device_map) - torch_device = "cuda:0" if torch_device == "cuda" else torch_device + torch_device = CUDA0 if torch_device == CUDA else torch_device torch.cuda.set_device(torch_device) if warm_uped and self.args.use_cuda_graph: if not hasattr(self, "cuda_graph_runner"): @@ -107,10 +110,10 @@ def decode_one_tokens(self): torch.cuda.synchronize() logits = logits[0, -1, :] return self.logits_to_token(logits) - + if self.args.use_cuda_graph: warm_uped = True - + if self.use_static_cache: logits = self.model( self.current_ids.to(torch_device), @@ -125,8 +128,6 @@ def decode_one_tokens(self): return 
self.logits_to_token(logits) - - @torch.no_grad def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]): input_ids_length = input_ids.shape[-1] @@ -135,30 +136,30 @@ def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[f self.seq_length = input_ids_length return logger.debug(f"input_ids: {input_ids.shape}") - device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") - device = "cuda:0" if device == "cuda" else device + device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", CUDA0) + device = CUDA0 if device == CUDA else device if is_new: self.ever_generated_ids.clear() same_prefix = 0 flat_input_ids = input_ids.flatten() - if getattr(self, 'generated_ids', None) is None: + if getattr(self, "generated_ids", None) is None: self.generated_ids = torch.zeros( self.args.batch_size, input_ids.shape[-1] + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device, ) - self.seq_length = 1 - + self.seq_length = 1 + flat_prev_ids = self.generated_ids.flatten() for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1): if flat_input_ids[i] == flat_prev_ids[i]: same_prefix += 1 else: break - + logger.debug(f"same prefix len: {same_prefix}") self.cache.remove_suffix(same_prefix) self.seq_length = same_prefix @@ -170,27 +171,25 @@ def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[f self.profiler.set_counter("prefill", input_ids_length) logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"generate_ids: {self.generated_ids.shape}") - + former_seq_length = self.seq_length self.seq_length += input_ids_length expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens) delta_length = expected_length - self.generated_ids.shape[-1] if delta_length > 0: - new_generate_ids = torch.zeros( - self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device - ) + new_generate_ids = torch.zeros(self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) else: logger.warning(f"seq_length bigger than cache_lens, killed") exit(0) - + logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") cache_position = torch.arange(former_seq_length, self.seq_length, device=device) self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) if not (type(self) is TransformersInterface): input_ids = input_ids.to("cpu") - + def chunk_prefill(input_ids, cache_position): inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) torch.cuda.set_device(device) @@ -216,7 +215,7 @@ def chunk_prefill(input_ids, cache_position): self.cache.cur_idx=cache_position[chunk_start:chunk_end] logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end]) chunk_start += self.args.chunk_prefill_size - + if flashinfer_enabled: MLAWrapperSingleton.reset_buffer() self.prepare_logits_wrapper(input_ids, device, temperature, top_p) @@ -225,14 +224,14 @@ def chunk_prefill(input_ids, cache_position): @property def active_cache_position(self): - device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", CUDA0) return torch.tensor([self.seq_length - 1], device=device) - + async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = 
None, top_p: Optional[float] = None): async with self._infer_lock: async for v in super().inference(local_messages, thread_id, temperature, top_p): yield v - + # return this inference raw usage yield RawUsage( tokenize_time = self.profiler.get_timer_sec('tokenize'), diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index 332e398e..153e55fa 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -10,10 +10,11 @@ """ import os import shutil +from typing import Optional import yaml from ktransformers.server.config.singleton import Singleton -from typing import Optional +from ktransformers.util.torch_auto_backend import CUDA0, CUDA2 class Config(metaclass=Singleton): @@ -93,7 +94,7 @@ def __init__(self): # to make sure it consistent with previous version self.model_path: str = self.model_dir self.model_name: str = self.model.get("name", "") - self.model_device: str = self.model.get("device", "cuda:0") + self.model_device: str = self.model.get("device", CUDA0) self.gguf_path: Optional[str] = self.model.get("gguf_path", None) self.use_cuda_graph = self.model.get("use_cuda_graph", True) self.trust_remote_code = self.model.get("trust_remote_code", True) @@ -148,7 +149,7 @@ def __init__(self): self.amnesia = self.model.get("amnesia", False) self.batch_size = self.model.get("batch_size", 1) self.cache_lens = self.model.get("cache_lens", 4096) - self.device = self.model.get("device", "cuda:2") + self.device = self.model.get("device", CUDA2) # web config self.web: dict = cfg.get("web", {}) diff --git a/ktransformers/server/schemas/assistants/messages.py b/ktransformers/server/schemas/assistants/messages.py index b65ca7cd..bcda62ea 100644 --- a/ktransformers/server/schemas/assistants/messages.py +++ b/ktransformers/server/schemas/assistants/messages.py @@ -56,13 +56,13 @@ class Text(BaseModel): class TextObject(ContentObject): text: Text delta_index: int = Field(default=0,exclude=True) - special_tokens_on: bool = Field(default=False,exclude=True) - last_two: str= Field(default='',exclude=True) + special_tokens_on: bool = Field(default=False,exclude=True) + last_two: str = Field(default='',exclude=True) - def filter_append(self,text:str): + def filter_append(self,text:str): self.text.value+=text self.delta_index+=1 - return True + return True @@ -115,7 +115,7 @@ class Status(Enum): class MessageObject(MessageBase, ObjectWithCreatedTime): _encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None) - + def get_text_content(self) -> str: text_content = "" @@ -134,19 +134,19 @@ async def get_encoded_content(self,encode_fn:Callable): for f in self.get_attached_files(): logger.info(f'encoding file: {f.filename}') self._encoded_content = torch.cat([self._encoded_content, encode_fn(await f.get_str(),self.role)],dim=-1) - yield None + yield None yield self._encoded_content def get_attached_files(self): - raise NotImplementedError # should be replaced + raise NotImplementedError # should be replaced def append_message_delta(self,text:str): - raise NotImplementedError # should be replaced - + raise NotImplementedError # should be replaced + def sync_db(self): # raise NotImplementedError # should be replaced sql_utils = SQLUtil() @@ -155,7 +155,7 @@ def sync_db(self): ) with sql_utils.get_db() as db: sql_utils.db_merge_commit(db, db_message) - + def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse: match event: @@ -164,7 +164,7 @@ def stream_response_with_event(self, event: 
MessageBase.Status) -> MessageStream case _: self.status = event return MessageStreamResponse(message=self, event=event) - + class MessageStreamResponse(BaseModel): message: MessageObject diff --git a/ktransformers/tests/torch_auto_backend_test.py b/ktransformers/tests/torch_auto_backend_test.py new file mode 100644 index 00000000..d059c796 --- /dev/null +++ b/ktransformers/tests/torch_auto_backend_test.py @@ -0,0 +1,14 @@ +import torch +from ktransformers.util.torch_auto_backend import CUDA, CUDA0 + +if __name__ == "__main__": + print(CUDA, CUDA0) + a = torch.tensor([1.2, 2.3], dtype=torch.float32, device=CUDA) + print(a) + b = torch.tensor([1.2, 2.3], dtype=torch.float32, device=CUDA0) + print(b) + print(torch.cuda.is_available()) + print(torch.cuda.device_count()) + print(torch.cuda.device(0)) + print(b.cuda()) + print(torch.cuda.current_stream(CUDA).cuda_stream) diff --git a/ktransformers/util/cuda_graph_runner.py b/ktransformers/util/cuda_graph_runner.py index b4b0adce..41c0f0c1 100644 --- a/ktransformers/util/cuda_graph_runner.py +++ b/ktransformers/util/cuda_graph_runner.py @@ -1,10 +1,11 @@ ''' -Description : +Description : Author : Boxin Zhang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import torch +from ktransformers.util.torch_auto_backend import CUDA, CUDA0 from typing import Dict class CUDAGraphRunner: @@ -32,13 +33,13 @@ def capture( self.model = model inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device) # torch.cuda.set_device can't set "cuda", must have a index - if main_device == "cuda": - main_device = "cuda:0" + if main_device == CUDA: + main_device = CUDA0 torch.cuda.set_device(main_device) self.main_device = main_device capture_stream = torch.cuda.Stream() with torch.cuda.graph(self.graph, stream = capture_stream): - logits=model(inputs_embeds=inputs_embeds, + logits=model(inputs_embeds=inputs_embeds, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, @@ -46,7 +47,7 @@ def capture( capture_stream.wait_stream(torch.cuda.current_stream()) torch.cuda.set_device(main_device) torch.cuda.set_stream(capture_stream) - if past_key_values != None: + if past_key_values != None: past_key_values.change_seq_length(-1) torch.cuda.synchronize(self.main_device) #self.graph.debug_dump("cuda_graph_hooked.dot") diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 84ada15a..b16d86a2 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 ''' -Description : +Description : Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-26 08:48:54 Version : 1.0.0 @@ -10,7 +10,7 @@ Adapted from https://github.com/99991/pygguf/blob/main/gguf.py Copyright (c) 2023-2024 The ggml authors Copyright (c) 2024 Thomas Germer -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
''' # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf # GGUF specification @@ -24,6 +24,7 @@ import os from enum import IntEnum import torch +from ktransformers.util.torch_auto_backend import CUDA import KTransformersOps from .custom_loader import SafeTensorLoader import ctypes @@ -180,7 +181,7 @@ def __init__(self, gguf_path: str): gguf_path = os.path.dirname(gguf_path) self.safetensor_loader = None - + self.tensor_info = {} self.gguf_path = gguf_path self.tensor_file_map = {} @@ -207,7 +208,7 @@ def __init__(self, gguf_path: str): self.file_data_map[file_name] = np.memmap(file_name, mode = 'r') if not found_gguf: raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}") - + def load_gguf(self, f): f.seek(0) assert f.read(4) == b'GGUF' @@ -235,7 +236,7 @@ def load_gguf(self, f): block_size, type_size = GGML_QUANT_SIZES[ggml_type] n_bytes = n_elems * type_size // block_size np_dims = tuple(reversed(shape)) - + item_type: npt.DTypeLike if ggml_type == GGMLQuantizationType.F16: item_count = n_elems @@ -284,12 +285,12 @@ def load_gguf(self, f): offset = start + t["bad_offset"] offset += (alignment - offset % alignment) % alignment t["offset"] = offset - + for name in tensor_info: self.tensor_file_map[name] = f.name self.tensor_info.update(tensor_info) self.gguf_file_meta.update(info) - + def get_mmap_tensor(self, name): t = self.tensor_info[name] mmap_data = self.file_data_map[ self.tensor_file_map[name] ] @@ -299,7 +300,7 @@ def get_mmap_tensor(self, name): item_count = t["item_count"] itemsize = int(np.empty([], dtype = item_type).itemsize) return mmap_data[offset : offset + itemsize * item_count] - + def get_undequanted_tensor_and_ggml_type(self, name): t = self.tensor_info[name] data = self.get_mmap_tensor(name) @@ -307,7 +308,7 @@ def get_undequanted_tensor_and_ggml_type(self, name): data = torch.from_numpy(data) return data, ggml_type - def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor: + def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = CUDA, target_dtype = torch.get_default_dtype())->torch.Tensor: t = self.tensor_info[name] if device.lower() == "cpu": print(f"loading expert {expert_id} of {name} with CPU") @@ -324,8 +325,8 @@ def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device block_size = GGML_BLOCK_SIZES[ggml_name] offset = expert_id * block_size * blocks_per_experts data = data[offset: offset + block_size * blocks_per_experts] - - if "cuda" in device.lower(): + + if CUDA in device.lower(): values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype) else: values = GGML_DEQUANTIZE[ggml_name](data) @@ -343,7 +344,7 @@ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)-> print(f"loading {name} with CPU") if target_dtype == None: target_dtype = torch.get_default_dtype() - + shape = t["shape"] ggml_type = t["ggml_type"] @@ -358,33 +359,33 @@ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)-> elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name] num_elements = int(np.prod(shape)) num_blocks = num_elements // elements_per_block - + blocks_per_iter = 16384 if num_blocks > blocks_per_iter: # dequant large tensor values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device) for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): blocks_begin = i * blocks_per_iter blocks_end = 
min(blocks_begin + blocks_per_iter, num_blocks) - if "cuda" in device.lower(): + if CUDA in device.lower(): cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) else: cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) cur_values = torch.from_numpy(cur_values.copy()) - + cur_values = cur_values.view(-1, elements_per_block) if ggml_name == "BF16": cur_values = cur_values.view(torch.bfloat16) values[blocks_begin : blocks_end] = cur_values else: - if "cuda" in device.lower(): + if CUDA in device.lower(): values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) else: values = GGML_DEQUANTIZE[ggml_name](data) values = torch.from_numpy(values) - + if ggml_name == "BF16": values = values.view(torch.bfloat16) - + values = values.view(shape[::-1]) if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: @@ -393,7 +394,7 @@ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)-> .swapaxes(1, 2) .reshape(values.shape)) elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: - n_head = self.gguf_file_meta['llama.attention.head_count_kv'] + n_head = self.gguf_file_meta['llama.attention.head_count_kv'] values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) .swapaxes(1, 2) .reshape(values.shape)) @@ -484,12 +485,12 @@ def dequantize_q2_k(data): return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) -def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q2_k_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q2_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) - # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) @@ -536,12 +537,12 @@ def dequantize_q3_k(data): (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) ], axis=1) -def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q3_k_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q3_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) - # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. 
c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) @@ -568,12 +569,12 @@ def dequantize_q4_k(data): # Dequantize final weights using scales and offsets return factors * qs2 - offsets -def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q4_k_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q4_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) - # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) @@ -634,12 +635,12 @@ def dequantize_q5_k(data): d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, ], axis=1) -def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q5_k_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q5_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) - # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. 
c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) @@ -690,10 +691,10 @@ def dequantize_q6_k(data): sc[:, 13] * q7[:, 16:], sc[:, 14] * q8[:, :16], sc[:, 15] * q8[:, 16:], - ], axis=1) + ], axis=1) # @torch.jit.script -def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q6_k_gpu(data: np.ndarray, device:str = CUDA, target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q6_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"] device = torch.device(device) @@ -734,7 +735,7 @@ def dequantize_iq4_xs(data): return y.flatten() -def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): +def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = CUDA, target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["IQ4_XS"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"] device = torch.device(device) @@ -758,7 +759,7 @@ def dequantize_q4_0(data): scales * ((qs >> 4).astype(np.int8) - 8), ], axis=1) -def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q4_0_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q5_0(data): @@ -782,7 +783,7 @@ def dequantize_q5_0(data): scales * x1, ], axis=1) -def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q5_0_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q8_0(data): @@ -794,10 +795,10 @@ def dequantize_q8_0(data): qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] return scales * qs -def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): +def dequantize_q8_0_gpu(data, device:str = CUDA, target_dtype = torch.get_default_dtype()): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 - + block_size = GGML_BLOCK_SIZES["Q8_0"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"] device = torch.device(device) @@ -865,12 +866,12 @@ def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()): def translate_name_to_gguf_mixtral(name): - + replacement_template = { "w1.weight": "ffn_gate", "w2.weight": "ffn_down", "w3.weight": "ffn_up" - } + } pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)") @@ -884,7 +885,7 @@ def replace_match(match): return match.group(0) new_name = re.sub(pattern, replace_match, name) - + return new_name def translate_name_to_gguf(name): @@ -894,7 +895,7 @@ def translate_name_to_gguf(name): name = name.replace("lm_head.", "output.") name = name.replace("model.embed_tokens.", "token_embd.") name = name.replace("model.norm.", "output_norm.") - + name = name.replace("model.layers.", "blk.") name = name.replace(".input_layernorm", ".attn_norm") name = name.replace(".mlp.down_proj", ".ffn_down") @@ -912,11 +913,11 @@ def translate_name_to_gguf(name): name = name.replace(".self_attn.q_a_proj", ".attn_q_a") name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm") name = name.replace(".self_attn.q_b_proj", ".attn_q_b") - + name = name.replace(".shared_expert.", ".shared_experts.") name = name.replace(".shared_expert_", 
".shared_experts_") name = name.replace(".gate_up_proj.", ".up_proj") - + name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp") name = name.replace(".mlp.gate", ".ffn_gate_inp") name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp") @@ -927,10 +928,10 @@ def translate_name_to_gguf(name): name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps") name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps") - + name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.") name = name.replace(".block_sparse_moe.experts", "") - + return name if __name__ == '__main__': diff --git a/ktransformers/util/torch_auto_backend.py b/ktransformers/util/torch_auto_backend.py new file mode 100644 index 00000000..7c861c99 --- /dev/null +++ b/ktransformers/util/torch_auto_backend.py @@ -0,0 +1,47 @@ +import sys +import torch +from torch.utils.cpp_extension import CUDA_HOME +try: + from torch_musa.utils.musa_extension import MUSA_HOME +except ImportError: + MUSA_HOME=None + +if CUDA_HOME is not None: + CUDA = "cuda" +elif MUSA_HOME is not None: + CUDA = "musa" + + torch.cuda = torch.musa + torch.cuda.CUDAGraph = torch.musa.MUSAGraph + + # **Monkey Patch `torch.Tensor.cuda()`** + def tensor_cuda(self, device=None, non_blocking=False, memory_format=None): + if device is None: + device = CUDA + elif isinstance(device, int): + device = f"{CUDA}:{device}" + return self.to(device, non_blocking=non_blocking, memory_format=memory_format) + + torch.Tensor.cuda = tensor_cuda + + # **Monkey Patch `torch.cuda.current_stream`** + original_musa_current_stream = torch.musa.current_stream + + def patch_stream_object(stream): + if not hasattr(stream, "cuda_stream"): + stream.cuda_stream = stream.musa_stream + return stream + + def patched_current_stream(device=None): + return patch_stream_object(original_musa_current_stream(device)) + + torch.cuda.current_stream = patched_current_stream + +else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + +CUDA0 = f"{CUDA}:0" +CUDA1 = f"{CUDA}:1" +CUDA2 = f"{CUDA}:2" + +print(f"Torch backend loaded: CUDA={CUDA}, CUDA0={CUDA0}") diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 6f3b0492..d43bde0e 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -1,12 +1,13 @@ #!/usr/bin/env python # coding=utf-8 ''' -Description : +Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import torch +from ktransformers.util.torch_auto_backend import CUDA, CUDA0 from torch import nn import itertools import time @@ -48,7 +49,7 @@ def set_module(model, submodule_key, module): cur_mod[int(tokens[-1])] = module def set_param(module: nn.Module, name: str, weights: torch.Tensor): - + param=nn.parameter.Parameter(weights, requires_grad=False) if isinstance(module, nn.Linear) and len(weights.shape)==1: param.unsqueeze_(0) @@ -58,7 +59,7 @@ def get_device(gguf_module_key:str, device_map:dict): if gguf_module_key in device_map: return device_map[gguf_module_key]["generate_device"] else: - return "cuda" + return CUDA def get_all_used_cuda_device(device_map:dict): all_device_list = set() @@ -78,7 +79,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str for name, param in local_state.items(): key = prefix + name translated_key = translate_name_to_gguf(key) - + # TODO: Merge all loader. # I know this is ugly but lets do it for now. 
if gguf_loader.safetensor_loader is not None: @@ -87,7 +88,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str else: load_dequantized_tensor = gguf_loader.load_gguf_tensor tensor_file_map = gguf_loader.tensor_file_map - + if translated_key in tensor_file_map: target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) @@ -99,7 +100,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str else: #print(load_config.tensor_file_map.keys()) raise Exception(f"can't find {translated_key} in GGUF file!") - + def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): #print(f"recursively loading weights {prefix}") if not isinstance(module, base_operator.BaseInjectedModule): @@ -118,12 +119,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud batch_size, seq_length = inputs.shape device_map = model.gguf_loader.tensor_device_map torch_device = get_device('blk.0.self_attn', device_map) - torch_device = "cuda:0" if torch_device == "cuda" else torch_device + torch_device = CUDA0 if torch_device == CUDA else torch_device inputs = inputs.to(torch_device) all_cuda_device = get_all_used_cuda_device(device_map) tokens = [] - + def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True): if cuda_graph_runner is None: use_cuda_graph = False @@ -151,7 +152,7 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position else: next_token = torch.argmax(next_token_scores, dim=-1) return next_token - + # TODO: use CUDA Graph for chunk prefill, may get small improvement def chunk_prefill(inputs, cache_position, past_key_values): if mode == "long_context": @@ -161,16 +162,16 @@ def chunk_prefill(inputs, cache_position, past_key_values): if use_flashinfer_mla: MLAWrapperSingleton.update_buffer(past_key_values.max_pages) MLAWrapperSingleton.need_plan_all() - + logits = model( inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True )[0][:,-1,:].unsqueeze(0).clone().to(torch_device) - + return logits - + torch.cuda.set_device(torch_device) with torch.no_grad(): - + stream = TextStreamer(tokenizer) if mode != 'long_context': past_key_values = StaticCache( @@ -178,7 +179,7 @@ def chunk_prefill(inputs, cache_position, past_key_values): ) else: past_key_values = None - + generation_config, model_kwargs = model._prepare_generation_config( None, do_sample=True # change this to modify generate config @@ -188,7 +189,7 @@ def chunk_prefill(inputs, cache_position, past_key_values): logits_warper = ( model._get_logits_warper(generation_config,device=inputs.device) ) - except: + except: logits_warper = ( model._get_logits_warper(generation_config) ) @@ -216,7 +217,7 @@ def chunk_prefill(inputs, cache_position, past_key_values): next_token = torch.argmax(next_token_scores, dim=-1) first_token_time = time.time() - start_time - + if use_flashinfer_mla: MLAWrapperSingleton.reset_buffer() @@ -231,9 +232,9 @@ def chunk_prefill(inputs, cache_position, past_key_values): cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32) position_ids = cache_position.unsqueeze(0) seq_length += 1 - + cuda_graph_runner = None - + start_time = time.time() for i in range(1, max_new_tokens): if use_flashinfer_mla: @@ -250,7 +251,7 @@ def 
chunk_prefill(inputs, cache_position, past_key_values): generated_ids[:, cache_position] = next_token.int() tokens.append(int(next_token)) seq_length += 1 - + if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>': print(stream.end(), end="", flush=True) break @@ -258,7 +259,7 @@ def chunk_prefill(inputs, cache_position, past_key_values): print(stream.put(next_token.item()), end="", flush=True) cache_position += 1 position_ids = cache_position.unsqueeze(0) - + total_time = time.time() - start_time tokens_generated = len(tokens)
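
The patch above routes every hard-coded `"cuda"` / `"cuda:0"` literal through the constants exported by the new `ktransformers/util/torch_auto_backend.py`, so the same call sites run on either CUDA or MUSA builds. Below is a minimal sketch of the intended call-site pattern, assuming a GPU backend is present; `pick_indexed_device` is an illustrative helper, not part of the patch, condensing the `device = CUDA0 if device == CUDA else device` normalization the diff repeats before `torch.cuda.set_device` (which rejects an index-less device string).

```python
import torch
from ktransformers.util.torch_auto_backend import CUDA, CUDA0  # "cuda"/"musa" and "cuda:0"/"musa:0"

def pick_indexed_device(device: str) -> str:
    # torch.cuda.set_device() needs an indexed device, so map the bare
    # backend name to device 0; anything already indexed passes through.
    return CUDA0 if device == CUDA else device

device = pick_indexed_device(CUDA)      # "cuda:0" or "musa:0"
torch.cuda.set_device(device)           # torch.cuda is aliased to torch.musa on MUSA builds
x = torch.zeros(4, device=device)       # identical tensor code on either backend
y = x.to(CUDA)                          # backend-agnostic replacement for .to("cuda")
```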
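
Because `torch_auto_backend.py` aliases `torch.cuda` to `torch.musa` and back-fills `Tensor.cuda()` and `current_stream().cuda_stream` on MUSA builds, call sites such as the CUDA-graph capture path and `submit_with_cuda_stream` need no second code path. The snippet below is an illustrative sketch of the pattern those monkey patches protect, again assuming a GPU backend is available; it is not a new API.

```python
import torch
from ktransformers.util.torch_auto_backend import CUDA0

torch.cuda.set_device(CUDA0)             # resolves to torch.musa.set_device on MUSA
stream = torch.cuda.current_stream()     # patched so .cuda_stream also exists on MUSA streams
raw_handle = stream.cuda_stream          # the handle CPU_INFER.submit_with_cuda_stream consumes
graph = torch.cuda.CUDAGraph()           # aliased to torch.musa.MUSAGraph when CUDA_HOME is absent
print(type(stream), hex(raw_handle), type(graph))
```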
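
The `custom_gguf.py` hunks replace `"cuda" in device.lower()` with `CUDA in device.lower()` when choosing between the GPU dequantizers and the NumPy fallbacks. Since `CUDA` holds the lowercase backend token, the substring test still accepts indexed device strings on either backend and still falls back to the CPU path for `"cpu"`. A tiny check, not part of the patch, illustrating that behaviour:

```python
from ktransformers.util.torch_auto_backend import CUDA  # "cuda" or "musa"

def uses_gpu_dequant(device: str) -> bool:
    # Mirrors the test used in load_gguf_tensor / load_expert_tensor.
    return CUDA in device.lower()

for dev in (CUDA, f"{CUDA}:0", f"{CUDA}:1", "cpu", "CPU"):
    print(dev, uses_gpu_dequant(dev))    # True for GPU device strings, False for CPU
```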