Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions diffsynth_engine/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@ def _parse_tuple(value: str) -> Tuple[int, int] | int:
raise ValueError(f"Cannot parse tuple: {value}, format should be '256,256' or '256'")


def _parse_attention_type(attn_type_str: str) -> AttentionType:
"""Convert string to AttentionType enum"""
return AttentionType[attn_type_str.upper()]


def _parse_attention_params(
attn_type: AttentionType,
attn_type: str,
sparge_topk: float | None = None,
) -> AttentionParams | None:
"""Parse attention parameters based on attention type"""
Expand All @@ -42,7 +37,6 @@ def parse_cli_args() -> Dict[str, Any]:

# Define choices
dtype_choices = ["float32", "float16", "bfloat16"]
attn_type_choices = [attn_type.name.lower() for attn_type in AttentionType]

# Model configuration group
model_group = parser.add_argument_group("Model Configuration")
Expand Down Expand Up @@ -107,7 +101,6 @@ def parse_cli_args() -> Dict[str, Any]:
"--attn-type",
type=str,
default="sdpa",
choices=attn_type_choices,
help="Attention type (default: sdpa)",
)
attn_group.add_argument(
Expand Down Expand Up @@ -178,7 +171,7 @@ def parse_cli_args() -> Dict[str, Any]:
args_dict["vae_tile_stride"] = _parse_tuple(args.vae_tile_stride)

# Attention configuration
attn_type = _parse_attention_type(args.attn_type)
attn_type = args.attn_type.lower()
args_dict["attn_type"] = attn_type
args_dict["attn_params"] = _parse_attention_params(attn_type, args.sparge_topk)

Expand Down
17 changes: 8 additions & 9 deletions diffsynth_engine/configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from diffsynth_engine.layers.attention import AttentionType
from diffsynth_engine.layers.attention.ring import RING_ATTN_COMPATIBLE_TYPES
from diffsynth_engine.registry import get_attn_backend
from diffsynth_engine.utils import logging

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -37,7 +37,7 @@ class PipelineConfig:
vae_tile_stride: int | Tuple[int, int] = (192, 192)

# attention
attn_type: AttentionType = AttentionType.SDPA
attn_type: AttentionType | str = AttentionType.SDPA
attn_params: Optional[AttentionParams] = None

# parallelism
Expand All @@ -56,8 +56,9 @@ def from_dict(cls, args_dict: Dict[str, Any]) -> "PipelineConfig":
return cls(**filtered_dict)

def __post_init__(self):
self.attn_type = str(self.attn_type)
init_parallel_config(self)
validate_ring_attention_config(self)
validate_attn_config(self)


def init_parallel_config(config: PipelineConfig):
Expand Down Expand Up @@ -100,10 +101,8 @@ def init_parallel_config(config: PipelineConfig):
logger.warning("setting vae_tiled to True since use_vae_parallel is enabled")


def validate_ring_attention_config(config) -> None:
def validate_attn_config(config: PipelineConfig):
attn_backend = get_attn_backend(config.attn_type)
if config.sp_ring_degree is not None and config.sp_ring_degree > 1:
if config.attn_type not in RING_ATTN_COMPATIBLE_TYPES:
raise ValueError(
f"attention backend {config.attn_type} does not support ring attention "
f"(missing forward_with_lse). Use one of: {', '.join(str(t) for t in RING_ATTN_COMPATIBLE_TYPES)}"
)
if not attn_backend.supports_ring_attention():
raise ValueError(f"Attention backend {config.attn_type!r} does not support ring attention.")
2 changes: 1 addition & 1 deletion diffsynth_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.cuda import set_device

from diffsynth_engine.configs import PipelineConfig
from diffsynth_engine.pipelines.registry import (
from diffsynth_engine.registry import (
get_pipeline_class,
get_pipeline_class_name,
)
Expand Down
6 changes: 3 additions & 3 deletions diffsynth_engine/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from diffsynth_engine.layers.attention import AttentionMetadata, AttentionType
from diffsynth_engine.layers.attention import AttentionMetadata


@dataclass
class ForwardContext:
attn_metadata: Optional["AttentionMetadata"] = None
attn_type: Optional["AttentionType"] = None
attn_type: Optional[str] = None


_forward_context: ForwardContext | None = None
Expand Down Expand Up @@ -43,7 +43,7 @@ def override_forward_context(forward_context: Optional[ForwardContext] = None):
@contextmanager
def set_forward_context(
attn_metadata: Optional["AttentionMetadata"] = None,
attn_type: Optional["AttentionType"] = None,
attn_type: Optional[str] = None,
):
"""A context manager to that stores the current forward context."""
forward_context = ForwardContext(
Expand Down
2 changes: 0 additions & 2 deletions diffsynth_engine/layers/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from .backends.abstract import AttentionMetadata, AttentionType
from .layer import LocalAttention, USPAttention
from .selector import get_attn_backend

__all__ = [
"AttentionType",
"AttentionMetadata",
"LocalAttention",
"USPAttention",
"get_attn_backend",
]
43 changes: 29 additions & 14 deletions diffsynth_engine/layers/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,25 @@

import torch

from diffsynth_engine.utils import logging

class AttentionType(enum.Enum):
SDPA = enum.auto()
FA2 = enum.auto()
FA3 = enum.auto()
FA3_FP8 = enum.auto()
FA4 = enum.auto()
AITER = enum.auto()
AITER_FP8 = enum.auto()
SAGE2 = enum.auto()
SAGE3 = enum.auto()
SPARGE = enum.auto()
logger = logging.get_logger(__name__)


class AttentionType(str, enum.Enum):
SDPA = "sdpa"
FA2 = "fa2"
FA3 = "fa3"
FA3_FP8 = "fa3_fp8"
FA4 = "fa4"
AITER = "aiter"
AITER_FP8 = "aiter_fp8"
SAGE2 = "sage2"
SAGE3 = "sage3"
SPARGE = "sparge"

def __str__(self) -> str:
return self.name.lower()
return self.value


class AttentionBackend(ABC):
Expand All @@ -37,7 +41,7 @@ def check_availability() -> None:

@staticmethod
@abstractmethod
def get_type() -> AttentionType:
def get_type() -> str:
raise NotImplementedError

@staticmethod
Expand All @@ -64,7 +68,18 @@ def get_supported_head_sizes() -> list[int]:
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes
if (not supported_head_sizes) or head_size in supported_head_sizes:
return True

logger.error(
f"Attention backend {cls.get_type()!r} does not support head size {head_size}. "
f"Supported head sizes: {supported_head_sizes}"
)
return False

@classmethod
def supports_ring_attention(cls) -> bool:
return False


@dataclass
Expand Down
8 changes: 4 additions & 4 deletions diffsynth_engine/layers/attention/backends/aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.AITER
def get_type() -> str:
return str(AttentionType.AITER)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand Down Expand Up @@ -92,8 +92,8 @@ def forward_with_lse(

class AiterFP8Backend(AiterBackend):
@staticmethod
def get_type() -> AttentionType:
return AttentionType.AITER_FP8
def get_type() -> str:
return str(AttentionType.AITER_FP8)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand Down
8 changes: 6 additions & 2 deletions diffsynth_engine/layers/attention/backends/flash_attn_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.FA2
def get_type() -> str:
return str(AttentionType.FA2)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand All @@ -40,6 +40,10 @@ def get_impl_cls() -> type["AttentionImpl"]:
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@classmethod
def supports_ring_attention(cls) -> bool:
return True


class FlashAttention2Impl(AttentionImpl):
def __init__(
Expand Down
12 changes: 8 additions & 4 deletions diffsynth_engine/layers/attention/backends/flash_attn_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.FA3
def get_type() -> str:
return str(AttentionType.FA3)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand All @@ -41,6 +41,10 @@ def get_impl_cls() -> type["AttentionImpl"]:
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@classmethod
def supports_ring_attention(cls) -> bool:
return True


class FlashAttention3Impl(AttentionImpl):
def __init__(
Expand Down Expand Up @@ -134,8 +138,8 @@ def forward_with_lse(

class FlashAttention3FP8Backend(FlashAttention3Backend):
@staticmethod
def get_type() -> AttentionType:
return AttentionType.FA3_FP8
def get_type() -> str:
return str(AttentionType.FA3_FP8)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand Down
8 changes: 6 additions & 2 deletions diffsynth_engine/layers/attention/backends/flash_attn_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.FA4
def get_type() -> str:
return str(AttentionType.FA4)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand All @@ -40,6 +40,10 @@ def get_impl_cls() -> type["AttentionImpl"]:
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@classmethod
def supports_ring_attention(cls) -> bool:
return True


class FlashAttention4Impl(AttentionImpl):
def __init__(
Expand Down
8 changes: 6 additions & 2 deletions diffsynth_engine/layers/attention/backends/sage_attn_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.SAGE2
def get_type() -> str:
return str(AttentionType.SAGE2)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand All @@ -39,6 +39,10 @@ def get_impl_cls() -> type["AttentionImpl"]:
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@classmethod
def supports_ring_attention(cls) -> bool:
return True


class SageAttention2Impl(AttentionImpl):
def __init__(
Expand Down
4 changes: 2 additions & 2 deletions diffsynth_engine/layers/attention/backends/sage_attn_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.SAGE3
def get_type() -> str:
return str(AttentionType.SAGE3)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand Down
8 changes: 6 additions & 2 deletions diffsynth_engine/layers/attention/backends/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def check_availability() -> None:
pass

@staticmethod
def get_type() -> AttentionType:
return AttentionType.SDPA
def get_type() -> str:
return str(AttentionType.SDPA)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand All @@ -29,6 +29,10 @@ def get_impl_cls() -> type["AttentionImpl"]:
def get_supported_head_sizes() -> list[int]:
return []

@classmethod
def supports_ring_attention(cls) -> bool:
return True


class SDPAImpl(AttentionImpl):
def __init__(
Expand Down
4 changes: 2 additions & 2 deletions diffsynth_engine/layers/attention/backends/sparge_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def check_availability() -> None:
raise RuntimeError(error_msg)

@staticmethod
def get_type() -> AttentionType:
return AttentionType.SPARGE
def get_type() -> str:
return str(AttentionType.SPARGE)

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
Expand Down
Loading