
[Hardware] Enable XPU Device on Intel GPU #651

Open · wants to merge 19 commits into main
21 changes: 15 additions & 6 deletions ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,6 +17,7 @@
#include <queue>
#include <thread>
#include <vector>
#include <stdexcept>
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
@@ -62,10 +63,14 @@ class CPUInfer {
}

void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA)
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
#else
throw std::runtime_error("submit_with_cuda_stream is not supported on this platforma");
#endif
}

static void sync_(void* cpu_infer_ptr) {
@@ -74,12 +79,16 @@
}

void sync_with_cuda_stream(intptr_t user_cuda_stream) {
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA)
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
#else
throw std::runtime_error("sync_with_cuda_stream is not supported on this platforma");
#endif
}

public:
Backend* backend_;
TaskQueue* task_queue_;
};

#endif
#endif
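The guarded methods above now raise std::runtime_error on builds compiled without CUDA or MUSA. A minimal caller-side sketch in Python (hypothetical wrapper; it assumes the CPUInfer bindings also expose plain submit()/sync() counterparts) of choosing between the stream path and the synchronous fallback:

import torch

def run_on_cpuinfer(cpu_infer, params, stream=None):
    # Use the CUDA-stream path only when CUDA is actually present;
    # otherwise fall back to the synchronous submit/sync pair.
    if stream is not None and torch.cuda.is_available():
        cpu_infer.submit_with_cuda_stream(stream.cuda_stream, params)
        cpu_infer.sync_with_cuda_stream(stream.cuda_stream)
    else:
        cpu_infer.submit(params)
        cpu_infer.sync()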
4 changes: 3 additions & 1 deletion ktransformers/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,9 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#include "device_launch_parameters.h"
#if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA)
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
9 changes: 7 additions & 2 deletions ktransformers/local_chat.py
@@ -62,12 +62,17 @@ def local_chat(
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
device: str = "cuda",
chunk_prefill_size: int = 8192
):

torch.set_grad_enabled(False)

Config().cpu_infer = cpu_infer

if device != "cuda":
Warning("cuda graph is only supported on cuda device, please set use_cuda_graph to False")
use_cuda_graph = False

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
@@ -108,7 +113,7 @@
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config, default_device=device)

try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
@@ -176,7 +181,7 @@
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
)


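Since local_chat now takes a device argument, a small sketch (assumed fallback order, not part of this PR) of resolving the requested device string before tensors are moved with input_tensor.to(device):

import torch

def pick_device(preferred: str = "cuda") -> str:
    # Fall back to CPU when the requested accelerator is not available.
    if preferred.startswith("cuda") and torch.cuda.is_available():
        return preferred
    if preferred.startswith("xpu") and hasattr(torch, "xpu") and torch.xpu.is_available():
        return preferred
    return "cpu"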
11 changes: 6 additions & 5 deletions ktransformers/operators/attention.py
@@ -20,7 +20,8 @@
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func
if not torch.xpu.is_available():
from flash_attn import flash_attn_func
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -512,7 +513,8 @@ def forward_linux_flashinfer(
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

def forward_windows(

def forward_native(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -589,9 +591,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
return self.forward_windows(
if os.name == 'nt' or torch.xpu.is_available() or get_compute_capability()<8:
return self.forward_native(
hidden_states,
attention_mask,
position_ids,
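The dispatch above routes Windows, XPU, and pre-Ampere CUDA devices to forward_native. A hypothetical standalone helper sketching the same decision (get_compute_capability is the repo's utility; torch.cuda.get_device_capability stands in for it here):

import os
import torch

def use_native_attention() -> bool:
    # Native (mask-based) path on Windows, on XPU, or on GPUs older than Ampere.
    if os.name == "nt":
        return True
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        return major < 8
    return True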
4 changes: 3 additions & 1 deletion ktransformers/operators/dynamic_attention.py
@@ -17,7 +17,9 @@
logger = logging.getLogger("dynamic_attention")
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
from flash_attn import flash_attn_func, flash_attn_with_kvcache

if torch.cuda.is_available():
from flash_attn import flash_attn_func, flash_attn_with_kvcache


import math
4 changes: 2 additions & 2 deletions ktransformers/operators/experts.py
@@ -207,7 +207,7 @@ def sync_for_one_decode(self):
def forward(self, input_tensor, expert_ids, weights):
# generate, capture and run cuda graph
# print(expert_ids)
if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
if input_tensor.size(0)==1 and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
#print("capturing experts")
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
@@ -725,7 +725,7 @@ def forward(self, hidden_states):
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
if sequence_length == 1 and torch.cuda.is_available() and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
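Both forward paths now short-circuit the capture check, so torch.cuda.is_current_stream_capturing() is never evaluated on CPU- or XPU-only setups, where it may fail. A minimal sketch of the same guard as a helper (hypothetical name):

import torch

def capturing_cuda_graph() -> bool:
    # Only query stream-capture state when CUDA is present.
    return torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()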
22 changes: 14 additions & 8 deletions ktransformers/operators/linear.py
@@ -14,16 +14,22 @@
import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps

if torch.cuda.is_available():
import KTransformersOps

from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
)

if not torch.xpu.is_available():
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
)

from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
3 changes: 1 addition & 2 deletions ktransformers/operators/models.py
@@ -649,8 +649,7 @@ def forward(
if per_layer_prefill_flag:
causal_mask = None
else:
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
if os.name == 'nt' or torch.xpu.is_available() or get_compute_capability()<8:
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
11 changes: 9 additions & 2 deletions ktransformers/optimize/optimize.py
@@ -103,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
gen_optimize_config(child, out_data, rule_list, child_prefix)
gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)


def translate_model_config(model_config: PretrainedConfig):
@@ -115,6 +115,10 @@ def translate_model_config(model_config: PretrainedConfig):


def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
if 'cuda' in default_device:
default_device = "cuda:0"
elif 'xpu' in default_device:
default_device = "xpu:0"
with open(rule_file, 'r', encoding='utf-8') as f:
rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)

@@ -131,4 +135,7 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader
del_meta(module)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
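A short sketch of the device handling this hunk introduces: the device string is normalized to an explicit index-0 device, and the matching cache is cleared after loading (assumed standalone helpers, mirroring but not copying the code above):

import torch

def normalize_device(device: str) -> str:
    if "cuda" in device:
        return "cuda:0"
    if "xpu" in device:
        return "xpu:0"
    return device

def free_device_cache() -> None:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()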
56 changes: 56 additions & 0 deletions ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-XPU.yaml
@@ -0,0 +1,56 @@
- match:
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "xpu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "xpu"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "xpu"
      prefill_device: "xpu"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
6 changes: 3 additions & 3 deletions ktransformers/util/custom_gguf.py
@@ -24,7 +24,8 @@
import os
from enum import IntEnum
import torch
import KTransformersOps
if not torch.xpu.is_available():
import KTransformersOps
from .custom_loader import SafeTensorLoader
import ctypes
import math
@@ -380,12 +381,11 @@ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
values = torch.from_numpy(values).to(device)

if ggml_name == "BF16":
values = values.view(torch.bfloat16)


values = values.view(shape[::-1])
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count']
5 changes: 3 additions & 2 deletions ktransformers/util/custom_loader.py
@@ -7,7 +7,8 @@
import os
from enum import IntEnum
import torch
import KTransformersOps
if not torch.xpu.is_available():
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file
@@ -83,4 +84,4 @@ def load_dequantized_tensor(self, key:str, device: str="cpu"):
if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
tensor = weight_dequant(tensor, weight_scale_inv)
return tensor.to(device)
return tensor.to(device)