
Deepseek r1 #3211


Merged · 14 commits · May 19, 2025
5 changes: 4 additions & 1 deletion Dockerfile_gaudi
@@ -60,6 +60,8 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor
ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0

# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -95,7 +97,8 @@ RUN cd server && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
-RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
+RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
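Note on the two environment defaults above: PT_HPU_LAZY_MODE=1 keeps PyTorch on Habana's lazy-execution (graph) backend, and PT_HPU_WEIGHT_SHARING=0 turns off HPU weight sharing. Both are ordinary ENV defaults, so they can be overridden at container start without rebuilding the image, e.g. `docker run -e PT_HPU_LAZY_MODE=0 ...` to fall back to eager mode (assuming the usual Habana semantics for these variables).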
10 changes: 9 additions & 1 deletion backends/gaudi/server/text_generation_server/cli.py
@@ -26,6 +26,11 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"


class KVCacheDtype(str, Enum):
fp8_e4m3fn = "fp8_e4m3fn"
fp8_e5m2 = "fp8_e5m2"


@app.command()
def serve(
model_id: str,
@@ -34,6 +39,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
-logger.info(f"quantize={quantize}")
+kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@@ -175,6 +182,7 @@ def terminate_handler(sig, frame):
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,
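For context: Typer exposes the kv_cache_dtype parameter as a --kv-cache-dtype option on the server CLI, so an FP8 KV cache can be requested at launch time. A minimal sketch (the model id is illustrative, not taken from this diff):

text-generation-server serve deepseek-ai/DeepSeek-R1 --kv-cache-dtype fp8_e4m3fn --dtype bfloat16

With the flag omitted, kv_cache_dtype stays None and the cache presumably keeps the serve dtype (bfloat16 by default).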
backends/gaudi/server/text_generation_server/layers/__init__.py
@@ -12,6 +12,7 @@
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear

from text_generation_server.layers.lora import (
LoraLinear,
@@ -27,6 +28,7 @@
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"Fp8Linear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",
backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -10,18 +10,21 @@
SUPPORTS_WINDOWING,
attention,
paged_attention,
paged_attention_mla,
)


# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache, get_kv_scales
+from .kv_cache import KVCache, get_kv_scales, KVCompressCache

__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",
110 changes: 96 additions & 14 deletions backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -11,11 +11,61 @@
SUPPORTS_WINDOWING = False


-def fetch_from_cache(cache, blocks):
-    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
-        return cache[: blocks.size(0)]
-    else:
-        return cache.index_select(0, blocks)
class FP8Matmul(torch.nn.Module):

def __init__(self, scale_other):
super().__init__()
self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
self.scale_other = scale_other

def quant_input(self, x, scale):
return torch.ops.hpu.cast_to_fp8_v2(
x, scale, False, False, torch.float8_e4m3fn
)[0]

def matmul_fp8(
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
):
return torch.ops.hpu.fp8_gemm_v2(
A=x,
trans_A=False,
B=other,
trans_B=False,
D=None,
out_dtype=out_dtype,
A_scale_inv=scale_input_inv,
B_scale_inv=scale_other_inv,
bias=None,
accumulate=False,
)

def forward(self, input, other):
qinput = self.quant_input(input, self.scale_input)
qother = self.quant_input(other, self.scale_other)
output = self.matmul_fp8(
qinput,
qother,
out_dtype=torch.bfloat16,
scale_input_inv=1.0 / self.scale_input,
scale_other_inv=1.0 / self.scale_other,
)
return output
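
For intuition, FP8Matmul is a quantize/dequantize GEMM: both operands are scaled and cast to float8_e4m3fn, and fp8_gemm_v2 folds the inverse scales back in. A rough CPU-only sketch of the same arithmetic in plain torch (no HPU ops; assumes cast_to_fp8_v2 applies its scale as a multiplier before casting, matching the inverse scales handed to the GEMM):

import torch

x = torch.randn(4, 8)
w = torch.randn(8, 16)
s_x, s_w = 1.0, 0.5  # stand-ins for scale_input / scale_other

qx = (x * s_x).to(torch.float8_e4m3fn)  # quant_input(x, s_x)
qw = (w * s_w).to(torch.float8_e4m3fn)  # quant_input(w, s_w)

# fp8_gemm_v2 with A_scale_inv=1/s_x and B_scale_inv=1/s_w amounts to:
out = (qx.float() / s_x) @ (qw.float() / s_w)  # ~= x @ w, up to fp8 rounding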


class FetchFromCache(torch.nn.Module):

def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv

def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
out = cache[: blocks.size(0)]
else:
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
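
FetchFromCache folds dequantization into the block gather: with VLLM_CONTIGUOUS_PA=true (the default) the paged blocks are assumed to be laid out contiguously, so the cheap slice cache[: blocks.size(0)] suffices, while index_select handles the non-contiguous layout. When the cache holds float8_e4m3fn, fetched blocks are cast straight back to bfloat16 with the stored inverse scale, so callers always receive bf16 tensors.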


def attention(
@@ -67,6 +117,7 @@ def paged_attention(
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
@@ -76,19 +127,50 @@
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
-matmul_qk_op=Matmul(),
-matmul_av_op=Matmul(),
+matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
-keys_fetch_func=fetch_from_cache,
-values_fetch_func=fetch_from_cache,
+keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
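
The FP8 path is selected purely from the cache dtype: if the KV cache was allocated as float8_e4m3fn, the QK and AV matmuls swap in FP8Matmul with the layer's key/value scales, and the fetch callbacks carry the matching inverse scales (the *_cpu values being host-side copies of the same scales); otherwise everything stays on the plain bf16 Matmul path and FetchFromCache's cast branch never triggers.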


-__all__ = [
-    "SUPPORTS_WINDOWING",
-    "attention",
-    "paged_attention",
-]
def paged_attention_mla(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
kv_lora_rank: int = 0,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa_mla(
query=query,
key_cache=kv_cache.key,
value_cache=None,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=None,
kv_lora_rank=kv_lora_rank,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, -1)
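
paged_attention_mla follows DeepSeek-style multi-head latent attention (MLA): the cache keeps a single compressed latent per token rather than separate K and V, so value_cache and values_fetch_func are None, and kv_lora_rank tells flat_pa_mla where the latent splits into its compressed-KV part and the RoPE key part. The trailing -1 in the output view accounts for the value head size differing from the query head size.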


__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -50,6 +50,8 @@ def __init__(
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

self.kv_cache = (
torch.zeros(
@@ -101,22 +103,92 @@ def store(
key_cache,
value_cache,
slots,
-kv_scales.key_scale_cpu,
-kv_scales.value_scale_cpu,
+kv_scales.key_scale,
+kv_scales.value_scale,
)


class KVCompressCache(KVCache):
"""
Key-value cache for attention layers.
"""

kv_cache: torch.Tensor

def __init__(
self,
*,
num_blocks: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

self.kv_cache = torch.zeros(
(num_blocks, BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)

@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache.dtype

@property
def key(self):
"""Get the key cache."""

return self.kv_cache

@property
def value(self):
"""Get the value cache."""

return self.kv_cache

def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support

block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
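
KVCompressCache backs the MLA path above: one (num_blocks, BLOCK_SIZE, 1, head_size) tensor serves as both the key and value views, and store() writes only the latent passed in as key (value is ignored by design), quantizing it first when the cache is FP8. The slot arithmetic is the standard paged-cache mapping; purely for illustration, assuming BLOCK_SIZE were 128:

import torch

BLOCK_SIZE = 128  # assumed value, for illustration only
slots = torch.tensor([300, 301])
block_idx = slots // BLOCK_SIZE    # tensor([2, 2])   -> which block
block_offset = slots % BLOCK_SIZE  # tensor([44, 45]) -> position inside the block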


def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
-k_scale: float = 1.0,
-v_scale: float = 1.0,
+k_scale: torch.Tensor,
+v_scale: torch.Tensor,
):
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
)[0]
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
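
paged_reshape_and_cache is the non-MLA counterpart: when the caches are float8_e4m3fn, key and value are quantized with their per-layer scale tensors via cast_to_fp8_v2 before being scattered into their paged blocks (the matching inverse scales are applied on the fetch side). The scales are device tensors here because cast_to_fp8_v2 takes tensor scales, while the *_scale_cpu floats feed the FetchFromCache callbacks.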
