5 changes: 3 additions & 2 deletions tests/kernels/attention/test_attention.py
@@ -9,7 +9,8 @@
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes

@@ -442,7 +443,7 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)


@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
@pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64
scale = float(1.0 / (head_size**0.5))
78 changes: 67 additions & 11 deletions tests/kernels/attention/test_mha_attn.py
@@ -3,16 +3,17 @@
"""
Test:

* Tests for MultiHeadAttention layer
* Tests for MMEncoderAttention layer
"""

import itertools
from unittest.mock import patch

import pytest
import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
@@ -42,35 +43,31 @@ def test_mha_attn_platform(device: str):

if device == "cpu":
with (
patch("vllm.attention.layer.current_platform", CpuPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip":
with (
patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
attn = MultiHeadAttention(16, 72, scale=1)
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN


@@ -94,6 +91,10 @@ def ref_attention(

BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
VAR_SEQ_LENS = [
[2, 2],
[2, 3, 4],
]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
@@ -130,7 +131,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)
@@ -151,3 +152,58 @@ def test_mha_attn_forward(
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output)
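For context, the body of `ref_attention` is not part of the hunks shown above. A minimal full-attention reference consistent with how the tests call it (inputs shaped `(batch, seq_len, num_heads, head_size)`, an explicit `scale`) might look like the sketch below; the exact signature is assumed, not taken from the repository.

```python
import torch


def ref_attention(
    query: torch.Tensor,   # (batch, seq_len, num_heads, head_size)
    key: torch.Tensor,     # (batch, seq_len, num_heads, head_size)
    value: torch.Tensor,   # (batch, seq_len, num_heads, head_size)
    scale: float,
) -> torch.Tensor:
    # Put heads before the sequence dim so matmul contracts over head_size.
    q, k, v = (x.transpose(1, 2) for x in (query, key, value))
    # Full (non-causal) attention, as expected for a ViT / mm-encoder layer.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)
    return out.transpose(1, 2)  # back to (batch, seq_len, num_heads, head_size)
```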


@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
var_seq_len: list[int],
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: str,
):
current_platform.seed_everything(0)
torch.set_default_device(device)
torch.set_default_dtype(dtype)

q = torch.randn(1, sum(var_seq_len), num_heads, head_size)
k = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
v = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
cu_seqlens = torch.tensor(
[0] + list(itertools.accumulate(var_seq_len)), dtype=torch.int32
)
scale = 1.0 / head_size**0.5
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(
q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(var_seq_len))
)

assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

ref_output = []
for q_i, k_i, v_i in zip(
torch.split(q, var_seq_len, dim=1),
torch.split(k, var_seq_len, dim=1),
torch.split(v, var_seq_len, dim=1),
):
output_i = ref_attention(
q_i,
k_i,
v_i,
scale=scale,
)
ref_output.append(output_i)
ref_output = torch.cat(ref_output, dim=1)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
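As a quick check of the cu_seqlens bookkeeping the new varlen test relies on: the prefix sum over `var_seq_len` marks the packed-row boundaries, so sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i + 1]` of the flattened batch. A minimal, dependency-free illustration:

```python
import itertools

var_seq_len = [2, 3, 4]
# Same construction as in the test: a leading 0 plus a running total.
cu_seqlens = [0] + list(itertools.accumulate(var_seq_len))
assert cu_seqlens == [0, 2, 5, 9]

# Sequence i covers packed rows [cu_seqlens[i], cu_seqlens[i + 1]).
spans = [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(var_seq_len))]
assert spans == [(0, 2), (2, 5), (5, 9)]
```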
6 changes: 3 additions & 3 deletions tests/v1/tpu/test_mha_attn.py
@@ -3,7 +3,7 @@
"""
Test:

* Tests for MultiHeadAttention layer
* Tests for MMEncoderAttention layer
"""

import pytest
@@ -12,7 +12,7 @@
import torch_xla.core
import torch_xla.core.xla_model

from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform

@@ -69,7 +69,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)
132 changes: 0 additions & 132 deletions vllm/attention/layer.py
@@ -2,12 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""

import functools
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.attention.backends.abstract import (
@@ -16,13 +14,10 @@
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
@@ -36,7 +31,6 @@
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
@@ -412,132 +406,6 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
)


class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix

assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()

# Determine the attention backend
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend

self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)

self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)

self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}

self.fa_version = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(
self._flash_attn_varlen_func, fa_version=self.fa_version
)

logger.info_once(
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)

query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

if self.is_flash_attn_backend:
assert self._flash_attn_varlen_func is not None
cu_seqlens_q = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
)
cu_seqlens_k = torch.arange(
0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
)

out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention

out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(
f"ViT attention hasn't supported {self.attn_backend} backend yet."
)

return out.reshape(bsz, q_len, -1)
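For callers of the class removed here, the updated tests earlier in this PR point at the replacement: `MMEncoderAttention` in `vllm.attention.layers.mm_encoder_attention`, which keeps the batched call signature and, per the new varlen test, also accepts packed sequences via `cu_seqlens`/`max_seqlen`. A hedged migration sketch, with everything beyond what the tests show treated as an assumption:

```python
import torch

# New import path; the old `from vllm.attention.layer import MultiHeadAttention`
# is removed by this change.
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention

num_heads, head_size = 16, 64
attn = MMEncoderAttention(num_heads, head_size, scale=head_size**-0.5)

# Batched path, same shapes MultiHeadAttention accepted:
q = torch.randn(2, 8, num_heads * head_size)
out = attn(q, q, q)  # -> (2, 8, num_heads * head_size)

# Packed variable-length path, mirroring the new varlen test:
seq_lens = [3, 5]
qp = torch.randn(1, sum(seq_lens), num_heads, head_size)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
out_var = attn(qp, qp, qp, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(5))
```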


class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
