diff --git a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
index d0f0c603..c75672d6 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #ifdef KTRANSFORMERS_USE_CUDA
 #include "vendors/cuda.h"
 #elif KTRANSFORMERS_USE_MUSA
@@ -62,10 +63,14 @@ class CPUInfer {
     }

     void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) {
-        void (*func)(void*) = (void (*)(void*))params.first;
-        void* args = (void*)params.second;
-        *((CPUInfer**)args) = this;
-        cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA)
+            void (*func)(void*) = (void (*)(void*))params.first;
+            void* args = (void*)params.second;
+            *((CPUInfer**)args) = this;
+            cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+        #else
+            throw std::runtime_error("submit_with_cuda_stream is not supported on this platform");
+        #endif
     }

     static void sync_(void* cpu_infer_ptr) {
@@ -74,7 +79,11 @@ class CPUInfer {
     }

     void sync_with_cuda_stream(intptr_t user_cuda_stream) {
-        cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA)
+            cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+        #else
+            throw std::runtime_error("sync_with_cuda_stream is not supported on this platform");
+        #endif
     }

    public:
@@ -82,4 +91,4 @@ class CPUInfer {
     TaskQueue* task_queue_;
 };

-#endif
\ No newline at end of file
+#endif
diff --git a/ktransformers/ktransformers_ext/ext_bindings.cpp b/ktransformers/ktransformers_ext/ext_bindings.cpp
index 902d4271..08d39564 100644
--- a/ktransformers/ktransformers_ext/ext_bindings.cpp
+++ b/ktransformers/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,9 @@
 **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#include "device_launch_parameters.h"
+#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA)
+    #include "device_launch_parameters.h"
+#endif
 #include "llamafile/flags.h"
 #include "operators/kvcache/kvcache.h"
 #include "operators/llamafile/linear.h"
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 4acaf86a..a1ea8ecf 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -62,12 +62,17 @@ def local_chat(
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
+    device: str = "cuda",
     chunk_prefill_size: int = 8192
 ):

     torch.set_grad_enabled(False)

     Config().cpu_infer = cpu_infer
+
+    if device != "cuda":
+        print("Warning: cuda graph is only supported on cuda devices, setting use_cuda_graph to False")
+        use_cuda_graph = False

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
@@ -108,7 +113,7 @@ def local_chat(
         gguf_path = input(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
-    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
+    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)

     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
@@ -176,7 +181,7 @@ def local_chat(
         )
     else:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
+            model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
         )
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 35c80932..8ffc397e 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -20,7 +20,8 @@ import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
+if not torch.xpu.is_available():
+    from flash_attn import flash_attn_func
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -512,7 +513,8 @@ def forward_linux_flashinfer(
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_windows(
+
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -589,9 +591,8 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability()<8:
-            print("for Windows or GPU before ampere, use forward_windows")
-            return self.forward_windows(
+        if os.name == 'nt' or torch.xpu.is_available() or get_compute_capability()<8:
+            return self.forward_native(
             hidden_states,
             attention_mask,
             position_ids,
diff --git a/ktransformers/operators/dynamic_attention.py b/ktransformers/operators/dynamic_attention.py
index 13a74b43..051744bf 100644
--- a/ktransformers/operators/dynamic_attention.py
+++ b/ktransformers/operators/dynamic_attention.py
@@ -17,7 +17,9 @@ logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+
+if torch.cuda.is_available():
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
 import math
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 88960c70..d1d7d893 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -207,7 +207,7 @@ def sync_for_one_decode(self):
     def forward(self, input_tensor, expert_ids, weights):
         # generate, capture and run cuda graph
         # print(expert_ids)
-        if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
+        if input_tensor.size(0)==1 and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
             #print("capturing experts")
             KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
@@ -725,7 +725,7 @@ def forward(self, hidden_states):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
+        if sequence_length == 1 and torch.cuda.is_available() and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 96d35787..8fc563d1 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -14,16 +14,22 @@ import ctypes
 import torch
 from torch import Tensor, nn
-import KTransformersOps
+
+if torch.cuda.is_available():
+    import KTransformersOps
+
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.util.utils import InferenceState
-from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
-    MarlinWorkspace,
-    marlin_quantize,
-    GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_MIN_THREAD_K,
-    GPTQ_MARLIN_MAX_PARALLEL,
-)
+
+if not torch.xpu.is_available():
+    from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
+        MarlinWorkspace,
+        marlin_quantize,
+        GPTQ_MARLIN_MIN_THREAD_N,
+        GPTQ_MARLIN_MIN_THREAD_K,
+        GPTQ_MARLIN_MAX_PARALLEL,
+    )
+
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index 57d4bea0..c05b7c9f 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -649,8 +649,7 @@ def forward(
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability()<8:
-                print("for Windows or GPU before ampere, use forward_windows")
+            if os.name == 'nt' or torch.xpu.is_available() or get_compute_capability()<8:
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py
index 331e6cf9..ab1bc089 100644
--- a/ktransformers/optimize/optimize.py
+++ b/ktransformers/optimize/optimize.py
@@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
     for name, child in module._modules.items():
         if child is not None:
             child_prefix = prefix + name + "."
-            gen_optimize_config(child, out_data, rule_list, child_prefix)
+            gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)


 def translate_model_config(model_config: PretrainedConfig):
@@ -115,6 +115,10 @@ def translate_model_config(model_config: PretrainedConfig):

 def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
+    if 'cuda' in default_device:
+        default_device = "cuda:0"
+    elif 'xpu' in default_device:
+        default_device = "xpu:0"
     with open(rule_file, 'r', encoding='utf-8') as f:
         rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)

@@ -131,4 +135,7 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
     load_weights(module, gguf_loader)
     module.gguf_loader = gguf_loader
     del_meta(module)
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.xpu.is_available():
+        torch.xpu.empty_cache()
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-XPU.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-XPU.yaml
new file mode 100644
index 00000000..f832b172
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-XPU.yaml
@@ -0,0 +1,56 @@
+- match:
+    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+      generate_op: "KLinearTorch"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV2MoE  # mlp module with custom forward function
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert parallelism
+    kwargs:
+      prefill_device: "xpu"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "xpu"
+  recursive: False  # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "xpu"
+      prefill_device: "xpu"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 84ada15a..b48bba54 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -24,7 +24,8 @@ import os
 from enum import IntEnum
 import torch
-import KTransformersOps
+if not torch.xpu.is_available():
+    import KTransformersOps
 from .custom_loader import SafeTensorLoader
 import ctypes
 import math
@@ -380,12 +381,11 @@ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->
             values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            values = torch.from_numpy(values).to(device)

         if ggml_name == "BF16":
             values = values.view(torch.bfloat16)
-
         values = values.view(shape[::-1])

         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
index ecc09a0a..e20b79fd 100644
--- a/ktransformers/util/custom_loader.py
+++ b/ktransformers/util/custom_loader.py
@@ -7,7 +7,8 @@ import os
 from enum import IntEnum
 import torch
-import KTransformersOps
+if not torch.xpu.is_available():
+    import KTransformersOps
 from safetensors import safe_open
 from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
 from safetensors.torch import save_file
@@ -83,4 +84,4 @@ def load_dequantized_tensor(self, key:str, device: str="cpu"):
         if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
             weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
             tensor = weight_dequant(tensor, weight_scale_inv)
-        return tensor.to(device)
\ No newline at end of file
+        return tensor.to(device)
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index ec862c1c..97f0b29a 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -92,7 +92,11 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()  # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
+
             weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
@@ -109,6 +113,17 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()

+def sync_all_device(all_device_list):
+    for device in all_device_list:
+        if "cuda" in device.lower():
+            torch.cuda.synchronize(device)
+        elif "xpu" in device.lower():
+            torch.xpu.synchronize(device)
+        else:
+            raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"}
+
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False,
                          chunk_prefill_size = 16384, use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
@@ -118,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     batch_size, seq_length = inputs.shape
     device_map = model.gguf_loader.tensor_device_map
     torch_device = get_device('blk.0.self_attn', device_map)
-    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
     inputs = inputs.to(torch_device)
     all_cuda_device = get_all_used_cuda_device(device_map)

@@ -131,7 +146,12 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position
         logits = cuda_graph_runner(cur_token, position_ids, cache_position)
     else:
         # custom_stream = torch.cuda.Stream()
-        torch.cuda.set_device(torch_device)
+        if torch.cuda.is_available():
+            torch.cuda.set_device(torch_device)
+        elif torch.xpu.is_available():
+            torch.xpu.set_device(torch_device)
+        else:
+            raise RuntimeError(f"The device {torch_device} is not available")
         inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
         # with torch.cuda.stream(custom_stream):
         logits=model(inputs_embeds=inputs_embeds,
@@ -141,8 +161,7 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position
                      return_dict=False, use_cache=True)[0]
         if past_key_values != None:
             past_key_values.change_seq_length(1)
-        for device in all_cuda_device:
-            torch.cuda.synchronize(device)
+        sync_all_device(all_cuda_device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -151,7 +170,6 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token
-
     # TODO: use CUDA Graph for chunk prefill, may get small improvement
     def chunk_prefill(inputs, cache_position, past_key_values):
         if mode == "long_context":
@@ -167,8 +185,12 @@ def chunk_prefill(inputs, cache_position, past_key_values):
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         return logits

-
-    torch.cuda.set_device(torch_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(torch_device)
+    elif torch.xpu.is_available():
+        torch.xpu.set_device(torch_device)
+    else:
+        raise RuntimeError(f"The device {torch_device} is not available")
     with torch.no_grad():

         stream = TextStreamer(tokenizer)
diff --git a/setup.py b/setup.py
index ea154828..3c6eea44 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,8 @@
 except ImportError:
     MUSA_HOME=None

+KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
+
 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
     FANCY = "FANCY"
@@ -151,9 +153,12 @@ def get_package_version(self, full_version=False):
             backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        elif torch.xpu.is_available():
+            backend_version = "xpu"
         else:
             raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
+
         if full_version:
             return package_version
         if not VersionInfo.FORCE_BUILD:
@@ -247,8 +252,10 @@ def build_extension(self, ext) -> None:
             cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
         elif MUSA_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        elif KTRANSFORMERS_BUILD_XPU:
+            cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON"]
         else:
-            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+            raise ValueError("Unsupported backend: neither CUDA_HOME nor MUSA_HOME is set.")

         build_args = []
         if "CMAKE_ARGS" in os.environ:
@@ -367,14 +374,26 @@ def build_extension(self, ext) -> None:
             ]
         }
     )
+elif torch.xpu.is_available():  # XPUExtension is not available yet.
+    pass
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

-setup(
-    version=VersionInfo().get_package_version(),
-    cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
-    ext_modules=[
-        CMakeExtension("cpuinfer_ext"),
-        ops_module,
-    ]
-)
+if torch.xpu.is_available():
+    setup(
+        version=VersionInfo().get_package_version(),
+        cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
+        ext_modules=[
+            CMakeExtension("cpuinfer_ext"),
+        ]
+    )
+else:
+    setup(
+        version=VersionInfo().get_package_version(),
+        cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
+        ext_modules=[
+            CMakeExtension("cpuinfer_ext"),
+            ops_module,
+        ]
+    )
+