Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
176 changes: 176 additions & 0 deletions tests/torch_compile/e2e/v1/attention/test_attention_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""E2E compile tests for attention operations on RBLN NPU.

Compiles pure-PyTorch attention models via torch.compile(backend="rbln"),
runs on NPU hardware, and compares output with host-computed reference.

This verifies that the RBLN compiler correctly handles attention-pattern
computations (matmul-scale-mask-softmax-matmul) across different head
configurations that correspond to various TP sizes.
"""

import pytest
import torch

# RBLN NPU accumulates in FP16; expect ~5% relative error vs FP32 host reference
NPU_ATOL = 5e-2
NPU_RTOL = 5e-2

# TP head configs: (n_kv_heads, n_groups, head_dim) simulating TP=1,2,4
# Total query heads is n_kv_heads * n_groups (always 4 here), so the three
# configs cover the GQA group layouts produced by sharding KV heads across
# tensor-parallel ranks while keeping the per-rank query-head count fixed.
TP_CONFIGS = [
    pytest.param(1, 4, 64, id="tp1-kv1-g4-d64"),
    pytest.param(2, 2, 64, id="tp2-kv2-g2-d64"),
    pytest.param(4, 1, 64, id="tp4-kv4-g1-d64"),
]


@pytest.fixture(autouse=True)
def reset_dynamo():
    """Reset torch.compile (dynamo) caches before and after every test.

    Prevents compiled-graph state from one parametrized case leaking into
    the next, so each test compiles its model from a clean slate.
    """
    torch._dynamo.reset()
    yield
    torch._dynamo.reset()


# ---------------------------------------------------------------------------
# Attention models using pure PyTorch ops (compilable to NPU)
# ---------------------------------------------------------------------------

class ScaledDotProductAttention(torch.nn.Module):
    """Standard scaled attention: (Q @ K^T) / sqrt(d) -> softmax -> @ V.

    Expects q, k, v of shape [..., seq_len, head_dim]; returns a tensor of
    q's shape.
    """

    def forward(self, q, k, v):
        # Scale by 1/sqrt(head_dim) (Vaswani et al., 2017). The original
        # implementation omitted the scale despite the class name and the
        # module docstring's "matmul-scale-mask-softmax-matmul" pattern.
        scale = q.shape[-1] ** -0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        return torch.matmul(attn_weights, v)


class MaskedAttention(torch.nn.Module):
    """Attention with an additive mask applied to the logits before softmax."""

    def forward(self, q, k, v, mask):
        # Mask is added to the raw scores; -inf entries zero out after softmax.
        scores = torch.matmul(q, k.transpose(-2, -1)) + mask
        probs = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(probs, v)


class GroupedQueryAttention(torch.nn.Module):
    """GQA: q=[B,H,G,L,D], k/v=[B,H,1,S,D] -> broadcast k/v to G groups."""

    def forward(self, q, k, v):
        # Expand the single KV group to match the query's group count; expand
        # is a view, so no data is copied.
        groups = q.shape[2]
        k_b = k.expand(-1, -1, groups, -1, -1)
        v_b = v.expand(-1, -1, groups, -1, -1)
        scores = torch.matmul(q, k_b.transpose(-2, -1))
        probs = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(probs, v_b)


class CausalAttention(torch.nn.Module):
    """Attention restricted to past/current positions via a -inf upper-triangular mask."""

    def forward(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-2, -1))
        # Build the causal mask on the fly: -inf strictly above the diagonal
        # so position i never attends to positions > i.
        length = q.shape[-2]
        neg_inf = torch.full((length, length), float("-inf"), device=q.device)
        scores = scores + torch.triu(neg_inf, diagonal=1)
        probs = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(probs, v)


# ---------------------------------------------------------------------------
# E2E compile tests
# ---------------------------------------------------------------------------

class TestScaledDotProductAttentionCompile:
    """E2E: plain attention compiled via the rbln backend vs eager host run."""

    @pytest.mark.parametrize("n_kv_heads,n_groups,head_dim", TP_CONFIGS)
    def test_sdpa_matches_host(self, n_kv_heads, n_groups, head_dim):
        torch.manual_seed(42)
        shape = (1, n_kv_heads * n_groups, 8, head_dim)

        q = torch.randn(shape)
        k = torch.randn(shape)
        v = torch.randn(shape)

        model = ScaledDotProductAttention()
        host_out = model(q, k, v)

        npu_out = torch.compile(model, backend="rbln", dynamic=False)(q, k, v)

        assert torch.allclose(npu_out, host_out, atol=NPU_ATOL, rtol=NPU_RTOL), (
            f"Max diff: {(npu_out - host_out).abs().max().item()}"
        )


class TestMaskedAttentionCompile:
    """E2E: additive-mask attention compiled via the rbln backend vs eager host run."""

    @pytest.mark.parametrize("n_kv_heads,n_groups,head_dim", TP_CONFIGS)
    def test_masked_attention_matches_host(self, n_kv_heads, n_groups, head_dim):
        torch.manual_seed(42)
        seq = 8
        shape = (1, n_kv_heads * n_groups, seq, head_dim)

        q = torch.randn(shape)
        k = torch.randn(shape)
        v = torch.randn(shape)
        # Causal-style mask, broadcast over the batch and head dimensions.
        mask = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)[None, None]

        model = MaskedAttention()
        host_out = model(q, k, v, mask)

        npu_out = torch.compile(model, backend="rbln", dynamic=False)(q, k, v, mask)

        assert torch.allclose(npu_out, host_out, atol=NPU_ATOL, rtol=NPU_RTOL), (
            f"Max diff: {(npu_out - host_out).abs().max().item()}"
        )


class TestGroupedQueryAttentionCompile:
    """E2E: grouped-query attention compiled via the rbln backend vs eager host run."""

    @pytest.mark.parametrize("n_kv_heads,n_groups,head_dim", TP_CONFIGS)
    def test_gqa_matches_host(self, n_kv_heads, n_groups, head_dim):
        torch.manual_seed(42)
        seq = 8

        # q carries an explicit group axis; k/v have a single KV group each.
        q = torch.randn(1, n_kv_heads, n_groups, seq, head_dim)
        k = torch.randn(1, n_kv_heads, 1, seq, head_dim)
        v = torch.randn(1, n_kv_heads, 1, seq, head_dim)

        model = GroupedQueryAttention()
        host_out = model(q, k, v)

        npu_out = torch.compile(model, backend="rbln", dynamic=False)(q, k, v)

        assert torch.allclose(npu_out, host_out, atol=NPU_ATOL, rtol=NPU_RTOL), (
            f"Max diff: {(npu_out - host_out).abs().max().item()}"
        )


class TestCausalAttentionCompile:
    """E2E: causal attention compiled via the rbln backend vs eager host run."""

    @pytest.mark.parametrize("n_kv_heads,n_groups,head_dim", TP_CONFIGS)
    def test_causal_attention_matches_host(self, n_kv_heads, n_groups, head_dim):
        torch.manual_seed(42)
        shape = (1, n_kv_heads * n_groups, 8, head_dim)

        q = torch.randn(shape)
        k = torch.randn(shape)
        v = torch.randn(shape)

        model = CausalAttention()
        host_out = model(q, k, v)

        npu_out = torch.compile(model, backend="rbln", dynamic=False)(q, k, v)

        assert torch.allclose(npu_out, host_out, atol=NPU_ATOL, rtol=NPU_RTOL), (
            f"Max diff: {(npu_out - host_out).abs().max().item()}"
        )
Empty file.
26 changes: 26 additions & 0 deletions tests/torch_compile/unit/triton_kernels/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

import vllm_rbln

# Shared meta-tensor shapes for the triton-kernel unit tests.
# NOTE(review): QUERY_SHAPE presumably is (batch, n_kv_heads, n_groups,
# seq_len, head_dim) and KV_CACHE_SHAPE adds cache-layout leading dims —
# TODO confirm against the op schemas registered by vllm_rbln.
QUERY_SHAPE = (2, 1, 4, 3, 32)
KV_SHAPE = (2, 1, 1, 3, 32)
KV_CACHE_SHAPE = (2, 2, 1, 1, 4, 32)

# Every custom attention op expected under torch.ops.rbln_triton_ops:
# a prefill and a decode variant per attention flavor.
ALL_OPS = [
    "attention_naive_prefill",
    "attention_naive_decode",
    "causal_attention_naive_prefill",
    "causal_attention_naive_decode",
    "flash_attention_naive_prefill",
    "flash_attention_naive_decode",
    "flash_causal_attention_naive_prefill",
    "flash_causal_attention_naive_decode",
    "sliding_window_attention_naive_prefill",
    "sliding_window_attention_naive_decode",
]


@pytest.fixture(autouse=True)
def register_triton_ops(monkeypatch):
    """Enable vLLM-model mode and register the RBLN triton ops for every test."""
    # raising=False: tolerate builds of vllm_rbln.envs that lack this attribute.
    monkeypatch.setattr(vllm_rbln.envs, "VLLM_RBLN_USE_VLLM_MODEL", True, raising=False)
    vllm_rbln.register_ops()
106 changes: 106 additions & 0 deletions tests/torch_compile/unit/triton_kernels/test_fake_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import pytest
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from .conftest import KV_CACHE_SHAPE, KV_SHAPE, QUERY_SHAPE


def _meta(shape, dtype=torch.float16):
return torch.empty(shape, dtype=dtype, device="meta")


def _build_attention_meta():
    """Meta args for attention_naive_{prefill,decode}.

    Layout: q, k, v, kv-cache, attention mask, then two (int16 tensor,
    f16 scalar) pairs — presumably sequence indices and scales; TODO
    confirm against the registered op schema.
    """
    specs = [
        (QUERY_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_CACHE_SHAPE, torch.float16),
        ((1, 1, 1, QUERY_SHAPE[-2], 5), torch.float16),
        ((2, 1), torch.int16),
        ((), torch.float16),
        ((2, 1), torch.int16),
        ((), torch.float16),
    ]
    return tuple(_meta(shape, dtype) for shape, dtype in specs)


def _build_causal_attention_meta():
    """Meta args for causal_attention_naive_{prefill,decode}.

    Same as the plain attention layout but without an explicit mask tensor
    (causality is implicit in the kernel).
    """
    specs = [
        (QUERY_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_CACHE_SHAPE, torch.float16),
        ((2, 1), torch.int16),
        ((), torch.float16),
        ((2, 1), torch.int16),
        ((), torch.float16),
    ]
    return tuple(_meta(shape, dtype) for shape, dtype in specs)


def _build_flash_attention_meta():
    """Meta args for flash_attention_naive_{prefill,decode}.

    Like the plain attention layout but the trailing scalar/index arguments
    are ordered differently: f16 scalar, two int16 tensors, f16 scalar.
    """
    specs = [
        (QUERY_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_CACHE_SHAPE, torch.float16),
        ((1, 1, 1, QUERY_SHAPE[-2], 5), torch.float16),
        ((), torch.float16),
        ((2, 1), torch.int16),
        ((2, 1), torch.int16),
        ((), torch.float16),
    ]
    return tuple(_meta(shape, dtype) for shape, dtype in specs)


def _build_flash_causal_attention_meta():
    """Meta args for flash_causal_attention_naive_{prefill,decode}.

    No mask tensor; note the (2, 2) int16 argument, unlike the (2, 1)
    shape used by the other flavors — presumably a per-sequence range;
    TODO confirm against the op schema.
    """
    specs = [
        (QUERY_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_CACHE_SHAPE, torch.float16),
        ((), torch.float16),
        ((2, 2), torch.int16),
        ((2, 1), torch.int16),
        ((), torch.float16),
    ]
    return tuple(_meta(shape, dtype) for shape, dtype in specs)


def _build_sliding_window_attention_meta():
    """Meta args for sliding_window_attention_naive_{prefill,decode}.

    No mask tensor; carries an extra leading int16 argument compared with
    the causal layout — presumably the window specification; TODO confirm
    against the op schema.
    """
    specs = [
        (QUERY_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_SHAPE, torch.float16),
        (KV_CACHE_SHAPE, torch.float16),
        ((2, 1), torch.int16),
        ((2, 1), torch.int16),
        ((), torch.float16),
        ((2, 1), torch.int16),
        ((), torch.float16),
    ]
    return tuple(_meta(shape, dtype) for shape, dtype in specs)


# (op name, meta-argument builder) pairs. Prefill and decode variants of the
# same attention flavor share an argument layout, hence the shared builders.
FAKE_OP_SPECS = [
    ("attention_naive_prefill", _build_attention_meta),
    ("attention_naive_decode", _build_attention_meta),
    ("causal_attention_naive_prefill", _build_causal_attention_meta),
    ("causal_attention_naive_decode", _build_causal_attention_meta),
    ("flash_attention_naive_prefill", _build_flash_attention_meta),
    ("flash_attention_naive_decode", _build_flash_attention_meta),
    ("flash_causal_attention_naive_prefill", _build_flash_causal_attention_meta),
    ("flash_causal_attention_naive_decode", _build_flash_causal_attention_meta),
    ("sliding_window_attention_naive_prefill", _build_sliding_window_attention_meta),
    ("sliding_window_attention_naive_decode", _build_sliding_window_attention_meta),
]


@pytest.mark.parametrize(
    ("op_name", "build_meta"),
    FAKE_OP_SPECS,
    ids=[spec[0] for spec in FAKE_OP_SPECS],
)
def test_fake_op_returns_correct_shape_and_dtype(op_name, build_meta):
    """All fake ops must return a tensor matching the query's shape and dtype."""
    inputs = build_meta()
    query = inputs[0]
    with FakeTensorMode(allow_non_fake_inputs=True):
        # allow_non_fake_inputs lets the meta-device args flow into fake mode.
        result = getattr(torch.ops.rbln_triton_ops, op_name)(*inputs)
        assert result.shape == QUERY_SHAPE
        assert result.dtype == query.dtype
12 changes: 12 additions & 0 deletions tests/torch_compile/unit/triton_kernels/test_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest
import torch

from .conftest import ALL_OPS


@pytest.mark.parametrize("op_name", ALL_OPS)
def test_triton_op_is_registered_in_torch_ops(op_name):
    """Each kernel must be reachable under torch.ops.rbln_triton_ops with a default overload."""
    namespace = torch.ops.rbln_triton_ops
    assert hasattr(namespace, op_name), f"{op_name} not found in rbln_triton_ops"
    overload_packet = getattr(namespace, op_name)
    assert hasattr(overload_packet, "default"), f"{op_name}.default not found"
Loading
Loading