-
Notifications
You must be signed in to change notification settings - Fork 45
Expand file tree
/
Copy pathabstract.py
More file actions
112 lines (86 loc) · 2.66 KB
/
Copy pathabstract.py
File metadata and controls
112 lines (86 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Adapted from https://github.com/vllm-project/vllm
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar
import torch
class AttentionType(enum.Enum):
    """Identifiers for the attention kernels a backend can report.

    Member values are assigned by ``enum.auto()`` in declaration order, so
    the order of the members below determines their numeric values.
    """

    SDPA = enum.auto()
    FA2 = enum.auto()
    FA3 = enum.auto()
    FA3_FP8 = enum.auto()
    FA4 = enum.auto()
    AITER = enum.auto()
    AITER_FP8 = enum.auto()
    SAGE2 = enum.auto()
    SAGE3 = enum.auto()
    SPARGE = enum.auto()
    MINDIE = enum.auto()

    def __str__(self) -> str:
        """Return the member name in lower case (e.g. ``"fa3_fp8"``)."""
        lowered_name = self.name.lower()
        return lowered_name
class AttentionBackend(ABC):
    """Interface that every diffusion attention backend implements.

    Every hook is a static method (plus one class method), so a backend is
    queried through its class rather than through an instance.
    """

    @staticmethod
    @abstractmethod
    def check_availability() -> None:
        """Check whether this backend can be used.

        NOTE(review): the abstract signature returns ``None``; presumably an
        unavailable backend is signalled by raising -- confirm against the
        concrete implementations.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_type() -> AttentionType:
        """Return the ``AttentionType`` member identifying this backend."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        """Return the ``AttentionMetadataBuilder`` subclass for this backend."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes this backend supports.

        A falsy result (such as an empty list) is interpreted by
        ``supports_head_size`` as "no restriction".
        """
        raise NotImplementedError()

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Tell whether ``head_size`` is accepted by this backend.

        Args:
            head_size: The head size to look up.

        Returns:
            ``True`` when ``get_supported_head_sizes`` returns a falsy value
            (no restriction declared) or when ``head_size`` is one of the
            declared sizes; ``False`` otherwise.
        """
        allowed_sizes = cls.get_supported_head_sizes()
        if not allowed_sizes:
            # Nothing declared: every head size is accepted.
            return True
        return head_size in allowed_sizes
@dataclass
class AttentionMetadata:
    """Base dataclass for attention metadata; it declares no fields itself."""
# Type variable for attention metadata classes. The bound restricts it to
# AttentionMetadata or one of its subclasses; it parameterizes the generic
# AttentionMetadataBuilder and AttentionImpl classes in this module.
T = TypeVar("T", bound=AttentionMetadata)
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    The class is generic over ``T``, the ``AttentionMetadata`` subclass that
    ``build`` produces, so ``AttentionMetadataBuilder[SomeMetadata]`` tells a
    type checker that ``build`` yields ``SomeMetadata``.
    """

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the builder; concrete subclasses must override this.

        Raises:
            NotImplementedError: Always, when this base implementation runs.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, **kwargs) -> T:
        """Build the attention metadata object.

        Args:
            **kwargs: Builder-specific inputs; this abstract signature does
                not constrain them.

        Returns:
            An instance of the metadata type ``T`` this builder is
            parameterized with (``AttentionMetadata`` or a subclass).

        Raises:
            NotImplementedError: Always, when this base implementation runs.
        """
        # The return annotation is T rather than the base AttentionMetadata
        # so that the Generic[T] parameter is actually carried to callers.
        raise NotImplementedError
class AttentionImpl(ABC, Generic[T]):
    """Abstract base for attention implementations.

    The class is generic over ``T``, the ``AttentionMetadata`` subclass that
    ``forward`` accepts as ``attn_metadata``.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        num_kv_heads: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Configure the implementation; concrete subclasses must override.

        The argument meanings below are read off the parameter names; this
        abstract signature does not itself enforce them.

        Args:
            num_heads: Number of attention heads.
            head_size: Size of each head.
            softmax_scale: Optional softmax scaling factor; defaults to
                ``None``.
            causal: Causal-attention flag; defaults to ``False``.
            num_kv_heads: Optional number of key/value heads; defaults to
                ``None``.
            **extra_impl_args: Additional implementation-specific options.

        Raises:
            NotImplementedError: Always, when this base implementation runs.
        """
        raise NotImplementedError()

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: T | None = None,
    ) -> torch.Tensor:
        """Run attention over the given tensors.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            attn_metadata: Optional metadata of type ``T``; defaults to
                ``None``.

        Returns:
            The attention output as a ``torch.Tensor``.

        Raises:
            NotImplementedError: Always, when this base implementation runs.
        """
        raise NotImplementedError()