56 changes: 39 additions & 17 deletions python/sglang/multimodal_gen/configs/models/adapter/base.py
@@ -8,13 +8,29 @@

@dataclass
class AdapterArchConfig(ArchConfig):
hidden_size: int = 0
num_attention_heads: int = 0
num_channels_latents: int = 0


@dataclass
class AdapterConfig(ModelConfig[AdapterArchConfig]):
arch_config: AdapterArchConfig = field(default_factory=AdapterArchConfig)
_internal_config_fields = (
"_fsdp_shard_conditions",
"_compile_conditions",
"_supported_attention_backends",
"param_names_mapping",
"reverse_param_names_mapping",
"exclude_lora_layers",
"boundary_ratio",
)

# sglang-diffusion Adapter-specific parameters
prefix: str = ""
_fsdp_shard_conditions: list = field(default_factory=list)
_compile_conditions: list = field(default_factory=list)

# convert weights name from HF-format to SGLang-dit-format
param_names_mapping: dict = field(default_factory=dict)

# Reverse mapping for saving checkpoints: custom -> hf
reverse_param_names_mapping: dict = field(default_factory=dict)
_supported_attention_backends: set[AttentionBackendEnum] = field(
default_factory=lambda: {
@@ -29,25 +45,31 @@ class AdapterArchConfig(ArchConfig):
AttentionBackendEnum.SAGE_ATTN_3,
}
)

hidden_size: int = 0
num_attention_heads: int = 0
num_channels_latents: int = 0
exclude_lora_layers: list[str] = field(default_factory=list)
boundary_ratio: float | None = None

def __post_init__(self) -> None:
def refresh_model_config(self) -> None:
if hasattr(self.arch_config, "_fsdp_shard_conditions"):
self._fsdp_shard_conditions = list(self.arch_config._fsdp_shard_conditions)
if hasattr(self.arch_config, "_compile_conditions"):
self._compile_conditions = list(self.arch_config._compile_conditions)
if hasattr(self.arch_config, "param_names_mapping"):
self.param_names_mapping = dict(self.arch_config.param_names_mapping)
if hasattr(self.arch_config, "reverse_param_names_mapping"):
self.reverse_param_names_mapping = dict(
self.arch_config.reverse_param_names_mapping
)
if hasattr(self.arch_config, "_supported_attention_backends"):
self._supported_attention_backends = set(
self.arch_config._supported_attention_backends
)
if hasattr(self.arch_config, "exclude_lora_layers"):
self.exclude_lora_layers = list(self.arch_config.exclude_lora_layers)
if hasattr(self.arch_config, "boundary_ratio"):
self.boundary_ratio = self.arch_config.boundary_ratio
if not self._compile_conditions:
self._compile_conditions = self._fsdp_shard_conditions.copy()


@dataclass
class AdapterConfig(ModelConfig):
arch_config: AdapterArchConfig = field(default_factory=AdapterArchConfig)

# sglang-diffusion Adapter-specific parameters
prefix: str = ""

@staticmethod
def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
"""Add CLI arguments for AdapterConfig fields"""
@@ -31,7 +31,8 @@ class LTX2ConnectorArchConfig(AdapterArchConfig):

@dataclass
class LTX2ConnectorConfig(AdapterConfig):

arch_config: AdapterArchConfig = field(default_factory=LTX2ConnectorArchConfig)
arch_config: LTX2ConnectorArchConfig = field(
default_factory=LTX2ConnectorArchConfig
)

prefix: str = "LTX2"
63 changes: 52 additions & 11 deletions python/sglang/multimodal_gen/configs/models/base.py
@@ -1,24 +1,30 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
import dataclasses
from dataclasses import dataclass, field, fields
from typing import Any, Dict
from typing import Any, ClassVar, Dict, Generic, TypeVar

from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)

ArchConfigT = TypeVar("ArchConfigT", bound="ArchConfig")


# 1. ArchConfig contains all fields from diffuser's/transformer's config.json (i.e. all fields related to the architecture of the model)
# 2. ArchConfig should be inherited & overridden by each model arch_config
# 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
@dataclass
class ArchConfig:
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=list
) # mapping from huggingface weight names to custom names
extra_attrs: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
self.refresh_derived_fields()

def refresh_derived_fields(self) -> None:
pass

def __getattr__(self, name: str):
d = object.__getattribute__(self, "__dict__")
extras = d.get("extra_attrs")
@@ -41,13 +47,17 @@ def __setattr__(self, key, value):


@dataclass
class ModelConfig:
class ModelConfig(Generic[ArchConfigT]):
# Every model config parameter can be categorized into either ArchConfig or everything else
# Diffuser/Transformer parameters
arch_config: ArchConfig = field(default_factory=ArchConfig)
arch_config: ArchConfigT = field(default_factory=ArchConfig) # type: ignore[arg-type]

# sglang-diffusion-specific parameters here
# i.e. STA, quantization, teacache
_internal_config_fields: ClassVar[tuple[str, ...]] = ()

def __post_init__(self) -> None:
self.refresh_derived_fields()

def __getattr__(self, name):
# Only called if 'name' is not found in ModelConfig directly
@@ -67,6 +77,27 @@ def __setstate__(self, state):
# Restore instance attributes from the unpickled state
self.__dict__.update(state)

def refresh_model_config(self) -> None:
pass

def refresh_derived_fields(self) -> None:
self.arch_config.refresh_derived_fields()
self.refresh_model_config()

@classmethod
def internal_config_fields(cls) -> set[str]:
internal_fields: set[str] = set()
for base in reversed(cls.__mro__):
internal_fields.update(getattr(base, "_internal_config_fields", ()))
return internal_fields

def to_user_dict(self) -> dict[str, Any]:
state = dataclasses.asdict(self)
state.pop("arch_config", None)
for field_name in self.internal_config_fields():
state.pop(field_name, None)
return state

# This should be used only when loading from transformers/diffusers
def update_model_arch(self, source_model_dict: dict[str, Any]) -> None:
"""
@@ -77,15 +108,26 @@ def update_model_arch(self, source_model_dict: dict[str, Any]) -> None:
for key, value in source_model_dict.items():
setattr(arch_config, key, value)

if hasattr(arch_config, "__post_init__"):
arch_config.__post_init__()
legacy_post_init = type(arch_config).__dict__.get("__post_init__")
if (
legacy_post_init is not None
and legacy_post_init is not ArchConfig.__post_init__
):
legacy_post_init(arch_config)
Comment on lines +111 to +116 (Contributor, medium)

The check type(arch_config).__dict__.get("__post_init__") is fragile because it only looks at the immediate class's dictionary. If a custom ArchConfig inherits its __post_init__ from a parent class (other than the base ArchConfig), it won't be detected here. A more robust check would be to compare the bound method's underlying function to the base implementation.

        if arch_config.__post_init__.__func__ is not ArchConfig.__post_init__:
            arch_config.__post_init__()
        else:
            arch_config.refresh_derived_fields()
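
To make the reviewer's scenario concrete, here is a minimal repro (hypothetical class names, not part of the PR): the legacy __post_init__ lives on an intermediate class, so it is missing from the leaf class's own __dict__, while the suggested __func__ comparison still detects the override.

from dataclasses import dataclass

from sglang.multimodal_gen.configs.models.base import ArchConfig


@dataclass
class LegacyArchConfig(ArchConfig):
    def __post_init__(self) -> None:
        super().__post_init__()
        # Stand-in for old-style derived-field logic.
        self.extra_attrs["legacy_hook_ran"] = True


@dataclass
class LeafArchConfig(LegacyArchConfig):
    pass


leaf = LeafArchConfig()
print("__post_init__" in type(leaf).__dict__)                   # False: the __dict__ check misses the inherited hook
print(leaf.__post_init__.__func__ is ArchConfig.__post_init__)  # False: the __func__ check still flags it as overridden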

else:
arch_config.refresh_derived_fields()
self.refresh_model_config()

def update_model_config(self, source_model_dict: dict[str, Any]) -> None:
assert (
"arch_config" not in source_model_dict
), "Source model config shouldn't contain arch_config."

valid_fields = {f.name for f in fields(self)}
valid_fields = {
f.name
for f in fields(self)
if f.name != "arch_config" and f.name not in self.internal_config_fields()
}

for key, value in source_model_dict.items():
if key in valid_fields:
@@ -96,5 +138,4 @@ def update_model_config(self, source_model_dict: dict[str, Any]) -> None:
)
raise AttributeError(f"Invalid field: {key}")

if hasattr(self, "__post_init__"):
self.__post_init__()
self.refresh_derived_fields()
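
Taken together, the refactor replaces __post_init__ hooks with refresh_derived_fields / refresh_model_config so that derived values can be recomputed after update_model_arch or update_model_config. Below is a minimal sketch of the intended lifecycle (not part of the PR), using a hypothetical ToyArchConfig/ToyConfig pair and assuming ArchConfig's custom __setattr__ (not fully shown in this diff) assigns declared fields directly.

from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig


@dataclass
class ToyArchConfig(ArchConfig):
    num_heads: int = 8
    head_dim: int = 64
    hidden_size: int = 0  # derived

    def refresh_derived_fields(self) -> None:
        super().refresh_derived_fields()
        # Derived fields are recomputed here (not in __post_init__), so that
        # update_model_arch() can re-run them after config.json is applied.
        self.hidden_size = self.num_heads * self.head_dim


@dataclass
class ToyConfig(ModelConfig[ToyArchConfig]):
    arch_config: ToyArchConfig = field(default_factory=ToyArchConfig)
    # Names listed here are stripped from to_user_dict(); DiTConfig below uses
    # the same mechanism for its internal mappings and shard conditions.
    _internal_config_fields = ("hidden_size",)

    prefix: str = "toy"
    hidden_size: int = 0

    def refresh_model_config(self) -> None:
        # Mirror the derived arch field onto the user-facing config.
        self.hidden_size = self.arch_config.hidden_size


cfg = ToyConfig()
print(cfg.hidden_size)               # 512 (8 * 64), computed via refresh_derived_fields()

cfg.update_model_arch({"num_heads": 16})
print(cfg.arch_config.hidden_size)   # 1024: refresh_derived_fields() re-ran on the arch config
print(cfg.hidden_size)               # 1024: refresh_model_config() re-mirrored it

print(cfg.to_user_dict())            # no 'arch_config' or 'hidden_size' keys; 'prefix': 'toy' remains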
@@ -31,12 +31,14 @@ class MOVADualTowerArchConfig(DiTArchConfig):
pooled_adaln: bool = False
eps: float = 1e-6

def __post_init__(self):
super().__post_init__()
def refresh_derived_fields(self):
super().refresh_derived_fields()
self.hidden_size = self.visual_hidden_dim
self.num_attention_heads = self.visual_hidden_dim // self.head_dim


@dataclass
class MOVADualTowerConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=MOVADualTowerArchConfig)
arch_config: MOVADualTowerArchConfig = field(
default_factory=MOVADualTowerArchConfig
)
70 changes: 49 additions & 21 deletions python/sglang/multimodal_gen/configs/models/dits/base.py
@@ -11,17 +11,34 @@

@dataclass
class DiTArchConfig(ArchConfig):
hidden_size: int = 0
num_attention_heads: int = 0
num_channels_latents: int = 0


@dataclass
class DiTConfig(ModelConfig[DiTArchConfig]):
arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)
_internal_config_fields = (
"_fsdp_shard_conditions",
"_compile_conditions",
"_supported_attention_backends",
"stacked_params_mapping",
"param_names_mapping",
"lora_param_names_mapping",
"reverse_param_names_mapping",
"exclude_lora_layers",
"boundary_ratio",
)

# sglang-diffusion DiT-specific parameters
prefix: str = ""
quant_config: QuantizationConfig | None = None
stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)
_fsdp_shard_conditions: list = field(default_factory=list)
_compile_conditions: list = field(default_factory=list)

# convert weights name from HF-format to SGLang-dit-format
param_names_mapping: dict = field(default_factory=dict)

# convert weights name from misc-format to HF-format
# usually applicable if the LoRA is trained with official repo implementation
lora_param_names_mapping: dict = field(default_factory=dict)

# Reverse mapping for saving checkpoints: custom -> hf
reverse_param_names_mapping: dict = field(default_factory=dict)
_supported_attention_backends: set[AttentionBackendEnum] = field(
default_factory=lambda: {
@@ -37,26 +54,37 @@ class DiTArchConfig(ArchConfig):
AttentionBackendEnum.SAGE_ATTN_3,
}
)

hidden_size: int = 0
num_attention_heads: int = 0
num_channels_latents: int = 0
exclude_lora_layers: list[str] = field(default_factory=list)
boundary_ratio: float | None = None

def __post_init__(self) -> None:
def refresh_model_config(self) -> None:
if hasattr(self.arch_config, "stacked_params_mapping"):
self.stacked_params_mapping = list(self.arch_config.stacked_params_mapping)
if hasattr(self.arch_config, "_fsdp_shard_conditions"):
self._fsdp_shard_conditions = list(self.arch_config._fsdp_shard_conditions)
if hasattr(self.arch_config, "_compile_conditions"):
self._compile_conditions = list(self.arch_config._compile_conditions)
if hasattr(self.arch_config, "param_names_mapping"):
self.param_names_mapping = dict(self.arch_config.param_names_mapping)
if hasattr(self.arch_config, "lora_param_names_mapping"):
self.lora_param_names_mapping = dict(
self.arch_config.lora_param_names_mapping
)
if hasattr(self.arch_config, "reverse_param_names_mapping"):
self.reverse_param_names_mapping = dict(
self.arch_config.reverse_param_names_mapping
)
if hasattr(self.arch_config, "_supported_attention_backends"):
self._supported_attention_backends = set(
self.arch_config._supported_attention_backends
)
if hasattr(self.arch_config, "exclude_lora_layers"):
self.exclude_lora_layers = list(self.arch_config.exclude_lora_layers)
if hasattr(self.arch_config, "boundary_ratio"):
self.boundary_ratio = self.arch_config.boundary_ratio
if not self._compile_conditions:
self._compile_conditions = self._fsdp_shard_conditions.copy()


@dataclass
class DiTConfig(ModelConfig):
arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)

# sglang-diffusion DiT-specific parameters
prefix: str = ""
quant_config: QuantizationConfig | None = None

@staticmethod
def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
"""Add CLI arguments for DiTConfig fields"""
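
Since ModelConfig.internal_config_fields() unions the _internal_config_fields tuples across the MRO, a model-specific config only has to list the names it adds on top of DiTConfig. A small sketch with a hypothetical subclass (not part of this PR):

from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.dits.base import DiTConfig


@dataclass
class MyDitConfig(DiTConfig):
    # Only the new internal name needs to be listed; DiTConfig's tuple is
    # merged in by internal_config_fields() as it walks the MRO.
    _internal_config_fields = ("my_private_cache",)

    my_private_cache: dict = field(default_factory=dict)


hidden = MyDitConfig.internal_config_fields()
print("boundary_ratio" in hidden)    # True, inherited from DiTConfig
print("my_private_cache" in hidden)  # True, contributed by the subclass
# to_user_dict() on an instance would drop arch_config plus every name in `hidden`.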
@@ -38,13 +38,13 @@ class ErnieImageArchConfig(DiTArchConfig):
default_factory=lambda: [_is_transformer_layer]
)

def __post_init__(self):
super().__post_init__()
def refresh_derived_fields(self):
super().refresh_derived_fields()
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels


@dataclass
class ErnieImageDitConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=ErnieImageArchConfig)
arch_config: ErnieImageArchConfig = field(default_factory=ErnieImageArchConfig)
prefix: str = "ernieimage"
7 changes: 3 additions & 4 deletions python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -101,16 +101,15 @@ class FluxArchConfig(DiTArchConfig):
}
)

def __post_init__(self):
super().__post_init__()
def refresh_derived_fields(self):
super().refresh_derived_fields()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels


@dataclass
class FluxConfig(DiTConfig):

arch_config: DiTArchConfig = field(default_factory=FluxArchConfig)
arch_config: FluxArchConfig = field(default_factory=FluxArchConfig)

prefix: str = "Flux"
@@ -25,15 +25,15 @@ class GlmImageArchConfig(DiTArchConfig):
}
)

def __post_init__(self):
super().__post_init__()
def refresh_derived_fields(self):
super().refresh_derived_fields()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels


@dataclass
class GlmImageDitConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=GlmImageArchConfig)
arch_config: GlmImageArchConfig = field(default_factory=GlmImageArchConfig)

prefix: str = "glmimage"
6 changes: 3 additions & 3 deletions python/sglang/multimodal_gen/configs/models/dits/helios.py
@@ -66,15 +66,15 @@ class HeliosArchConfig(DiTArchConfig):
is_amplify_history: bool = False
history_scale_mode: str = "per_head"

def __post_init__(self):
super().__post_init__()
def refresh_derived_fields(self):
super().refresh_derived_fields()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels


@dataclass
class HeliosConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=HeliosArchConfig)
arch_config: HeliosArchConfig = field(default_factory=HeliosArchConfig)

prefix: str = "Helios"