This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Feature: Enable codellama on Intel GPUs #90

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions llama/__init__.py
@@ -4,3 +4,4 @@
from .generation import Llama
from .model import ModelArgs, Transformer
from .tokenizer import Tokenizer
+from .xpu_utils import is_xpu_available, is_ccl_available
22 changes: 16 additions & 6 deletions llama/generation.py
@@ -18,6 +18,8 @@

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer
+from xpu_utils import is_xpu_available, is_ccl_available


Role = Literal["system", "user", "assistant"]

@@ -65,14 +67,20 @@ def build(
        model_parallel_size: Optional[int] = None,
    ) -> "Llama":
        if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if is_ccl_available() and is_xpu_available():
+                torch.distributed.init_process_group("ccl")
+            else:
+                torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if is_xpu_available():
+            torch.xpu.set_device(local_rank)
+        else:
+            torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(1)
@@ -100,6 +108,8 @@ def build(
        model_args.vocab_size = tokenizer.n_words
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        elif is_xpu_available():
+            torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
@@ -135,14 +145,14 @@ def generate(
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="xpu") if is_xpu_available() else torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="xpu") if is_xpu_available() else torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+            token_logprobs = torch.zeros_like(tokens, dtype=torch.bfloat16) if is_xpu_available() else torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
-        stop_reached = torch.tensor([False] * bsz, device="cuda")
+        stop_reached = torch.tensor([False] * bsz, device="xpu") if is_xpu_available() else torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
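
Note: the `device="xpu" if is_xpu_available() else "cuda"` decision above is spelled out four separate times in `generate()`. As a sketch only (not part of this PR, and assuming either a CUDA or an XPU device is actually present), the same selection could live in one hypothetical helper:

import torch

from llama.xpu_utils import is_xpu_available  # module added by this PR (llama/xpu_utils.py)


def _default_device() -> str:
    # Hypothetical helper: prefer an Intel XPU when IPEX is installed, else fall back to CUDA.
    return "xpu" if is_xpu_available() else "cuda"


# Stand-ins for the local variables used in generate() above.
bsz, total_len, pad_id = 2, 16, -1
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=_default_device())
stop_reached = torch.tensor([False] * bsz, device=_default_device())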
44 changes: 32 additions & 12 deletions llama/model.py
@@ -8,13 +8,15 @@
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F

from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
-from torch import nn

+from torch import nn
+from xpu_utils import is_xpu_available

@dataclass
class ModelArgs:
@@ -125,23 +127,41 @@ def __init__(self, args: ModelArgs):
            input_is_parallel=True,
            init_method=lambda x: x,
        )

-        self.cache_k = torch.zeros(
-            (
-                args.max_batch_size,
-                args.max_seq_len,
-                self.n_local_kv_heads,
-                self.head_dim,
-            )
-        ).cuda()
-        self.cache_v = torch.zeros(
-            (
-                args.max_batch_size,
-                args.max_seq_len,
-                self.n_local_kv_heads,
-                self.head_dim,
-            )
-        ).cuda()
+        if is_xpu_available():
+            self.cache_k = torch.zeros(
+                (
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_kv_heads,
+                    self.head_dim,
+                )
+            ).to("xpu")
+            self.cache_v = torch.zeros(
+                (
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_kv_heads,
+                    self.head_dim,
+                )
+            ).to("xpu")
+        else:
+            self.cache_k = torch.zeros(
+                (
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_kv_heads,
+                    self.head_dim,
+                )
+            ).cuda()
+            self.cache_v = torch.zeros(
+                (
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_kv_heads,
+                    self.head_dim,
+                )
+            ).cuda()

    def forward(
        self,
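
Note: the two branches above duplicate the KV-cache allocation and differ only in the final device call. A shorter equivalent would pass `device=` to `torch.zeros` directly; the following is a hypothetical sketch with stand-in sizes, not a change this PR makes:

import torch

from llama.xpu_utils import is_xpu_available  # module added by this PR

# Stand-in values for the ModelArgs fields referenced above.
max_batch_size, max_seq_len, n_local_kv_heads, head_dim = 1, 2048, 8, 128

cache_shape = (max_batch_size, max_seq_len, n_local_kv_heads, head_dim)
cache_device = "xpu" if is_xpu_available() else "cuda"

# Allocating directly on the target device skips the intermediate host tensor
# that torch.zeros(...).cuda() / .to("xpu") creates and then copies.
cache_k = torch.zeros(cache_shape, device=cache_device)
cache_v = torch.zeros(cache_shape, device=cache_device)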
67 changes: 67 additions & 0 deletions llama/xpu_utils.py
@@ -0,0 +1,67 @@
import torch
import importlib
import importlib.metadata
import os
import warnings
from functools import lru_cache

from packaging import version
from packaging.version import parse


def is_ipex_available():
    def get_major_and_minor_from_version(full_version):
        return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)

    _torch_version = importlib.metadata.version("torch")
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    _ipex_version = "N/A"
    try:
        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
    except importlib.metadata.PackageNotFoundError:
        return False
    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor != ipex_major_and_minor:
        warnings.warn(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
    return True


@lru_cache
def is_xpu_available(check_device=False):
    "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
    if not is_ipex_available():
        return False

    import intel_extension_for_pytorch  # noqa: F401

    if check_device:
        try:
            # Will raise a RuntimeError if no XPU is found
            _ = torch.xpu.device_count()
            return torch.xpu.is_available()
        except RuntimeError:
            return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_ccl_available():
    ccl_version = "N/A"
    try:
        _is_ccl_available = (
            importlib.util.find_spec("torch_ccl") is not None
            or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
        )

        ccl_version = importlib.metadata.version("oneccl_bind_pt")
        print(f"Detected oneccl_bind_pt version {ccl_version}")
    except importlib.metadata.PackageNotFoundError:
        _is_ccl_available = False
        return False
    return _is_ccl_available
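
Note: a quick sanity check for these helpers on a target machine, as a sketch (assumes the file is importable as `llama.xpu_utils`; the backend choice mirrors the logic this PR adds to `Llama.build()`):

from llama.xpu_utils import is_ccl_available, is_ipex_available, is_xpu_available

if __name__ == "__main__":
    # check_device=True also queries the runtime for a real XPU device,
    # rather than only verifying that intel_extension_for_pytorch is importable.
    print("IPEX installed:       ", is_ipex_available())
    print("XPU device available: ", is_xpu_available(check_device=True))
    print("oneCCL bindings found:", is_ccl_available())
    backend = "ccl" if is_ccl_available() and is_xpu_available() else "nccl"
    print("torch.distributed backend Llama.build() would pick:", backend)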