24 changes: 23 additions & 1 deletion vllm_omni/diffusion/cache/teacache/backend.py
@@ -48,7 +48,29 @@ def forward_alias(self, *args, **kwargs):
)


CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache}
def enable_flux2_klein_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
"""
Enable TeaCache for Flux2 Klein model.
"""
teacache_config = TeaCacheConfig(
transformer_type="Flux2Klein",
rel_l1_thresh=config.rel_l1_thresh,
coefficients=config.coefficients,
)
transformer = pipeline.transformer

apply_teacache_hook(transformer, teacache_config)

logger.info(
f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, "
f"transformer_class={teacache_config.transformer_type}"
)


CUSTOM_TEACACHE_ENABLERS = {
"BagelPipeline": enable_bagel_teacache,
"Flux2KleinPipeline": enable_flux2_klein_teacache,
}


class TeaCacheBackend(CacheBackend):
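For context, a minimal sketch of how the CUSTOM_TEACACHE_ENABLERS registry is presumably consulted when TeaCache is enabled for a pipeline. The helper name enable_teacache_for and the generic fallback are assumptions for illustration; the real dispatch lives in TeaCacheBackend and is not shown in this hunk.

from vllm_omni.diffusion.cache.teacache.backend import CUSTOM_TEACACHE_ENABLERS

def enable_teacache_for(pipeline, config):
    # Hypothetical dispatcher: pick a model-specific enabler by pipeline class name.
    enabler = CUSTOM_TEACACHE_ENABLERS.get(type(pipeline).__name__)
    if enabler is not None:
        enabler(pipeline, config)  # e.g. enable_flux2_klein_teacache for Flux2KleinPipeline
    else:
        # Fall back to the generic TeaCache enable path (assumed, not shown here).
        ...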
9 changes: 9 additions & 0 deletions vllm_omni/diffusion/cache/teacache/config.py
@@ -15,6 +15,15 @@
-3.82021401e00,
2.64230861e-01,
],
# Flux2 Klein transformer coefficients
# Same as FLUX.1 (similar dual-stream architecture)
"Flux2Klein": [
4.98651651e02,
-2.83781631e02,
5.58554382e01,
-3.82021401e00,
2.64230861e-01,
],
# Qwen-Image transformer coefficients from ComfyUI-TeaCache
# Tuned specifically for Qwen's dual-stream transformer architecture
# Used for all Qwen-Image Family pipelines, in general
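These per-model coefficients follow the usual TeaCache scheme: a degree-4 polynomial (highest-order term first) that rescales the relative L1 change of the modulated input before it is accumulated and compared against rel_l1_thresh. A minimal illustration with the Flux2Klein values, assuming the conventional np.poly1d evaluation used by TeaCache implementations:

import numpy as np

flux2_klein_coefficients = [
    4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01,
]
rescale = np.poly1d(flux2_klein_coefficients)  # degree-4 polynomial, highest order first
rel_l1 = 0.05                                  # example raw relative L1 distance
print(rescale(rel_l1))                         # rescaled value accumulated against rel_l1_thresh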
144 changes: 144 additions & 0 deletions vllm_omni/diffusion/cache/teacache/extractors.py
@@ -566,6 +566,149 @@ def postprocess(h):
)


def extract_flux2_klein_context(
module,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> CacheContext:
"""
Extract cache context for Flux2Klein model.

Only caches transformer_blocks output. single_transformer_blocks is always executed.

Args:
module: Flux2Transformer2DModel instance
hidden_states: Input image hidden states tensor
encoder_hidden_states: Input text hidden states tensor
timestep: Current diffusion timestep
img_ids: Image position IDs for RoPE
txt_ids: Text position IDs for RoPE
guidance: Optional guidance scale for CFG
joint_attention_kwargs: Additional attention kwargs

Returns:
CacheContext with all information needed for generic caching
"""
from diffusers.models.modeling_outputs import Transformer2DModelOutput

if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
raise ValueError("Module must have transformer_blocks")

dtype = hidden_states.dtype

num_txt_tokens = encoder_hidden_states.shape[1]

timestep = timestep.to(dtype=dtype) * 1000
if guidance is not None:
guidance = guidance.to(dtype=dtype) * 1000

temb = module.time_guidance_embed(timestep, guidance)

double_stream_mod_img = module.double_stream_modulation_img(temb)
double_stream_mod_txt = module.double_stream_modulation_txt(temb)
single_stream_mod = module.single_stream_modulation(temb)[0]

hidden_states = module.x_embedder(hidden_states)
encoder_hidden_states = module.context_embedder(encoder_hidden_states)

if img_ids.ndim == 3:
img_ids = img_ids[0]
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]

image_rotary_emb = module.pos_embed(img_ids)
text_rotary_emb = module.pos_embed(txt_ids)
concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
)

block = module.transformer_blocks[0]

norm_hidden_states = block.norm1(hidden_states)
norm_hidden_states = (1 + double_stream_mod_img[0][0]) * norm_hidden_states + double_stream_mod_img[0][1]

modulated_input = norm_hidden_states

def run_flux2_transformer_blocks():
h = hidden_states
c = encoder_hidden_states
for block in module.transformer_blocks:
c, h = block(
hidden_states=h,
encoder_hidden_states=c,
temb_mod_params_img=double_stream_mod_img,
temb_mod_params_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
return (h, c)

def run_flux2_single_transformer_blocks(c, h):
h_concat = torch.cat([c, h], dim=1)
for block in module.single_transformer_blocks:
h_concat = block(
hidden_states=h_concat,
encoder_hidden_states=None,
temb_mod_params=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
return h_concat[:, num_txt_tokens:, ...]

def run_flux2_full_transformer_with_single(ori_h, ori_c):
h = ori_h
c = ori_c
for block in module.transformer_blocks:
c, h = block(
hidden_states=h,
encoder_hidden_states=c,
temb_mod_params_img=double_stream_mod_img,
temb_mod_params_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
h_concat = torch.cat([c, h], dim=1)
for block in module.single_transformer_blocks:
h_concat = block(
hidden_states=h_concat,
encoder_hidden_states=None,
temb_mod_params=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
final_hidden_states = h_concat[:, num_txt_tokens:, ...]
return final_hidden_states, c

return_dict = kwargs.get("return_dict", True)

def postprocess(h):
h = module.norm_out(h, temb)
h = module.proj_out(h)
if not return_dict:
return (h,)
return Transformer2DModelOutput(sample=h)

return CacheContext(
modulated_input=modulated_input,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
run_transformer_blocks=run_flux2_transformer_blocks,
postprocess=postprocess,
extra_states={
"run_flux2_single_transformer_blocks": run_flux2_single_transformer_blocks,
"run_flux2_full_transformer_with_single": run_flux2_full_transformer_with_single,
},
)


# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -576,6 +719,7 @@
"QwenImageTransformer2DModel": extract_qwen_context,
"Bagel": extract_bagel_context,
"ZImageTransformer2DModel": extract_zimage_context,
"Flux2Klein": extract_flux2_klein_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
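A sketch of how a registered extractor is presumably resolved and invoked by the caching hook. The lookup variable CONTEXT_EXTRACTORS and the exact lookup key (the configured transformer type, "Flux2Klein" here, rather than the literal class name) are assumptions for illustration; only the registry entries themselves appear in this hunk.

def resolve_and_extract(transformer, teacache_config, *args, **kwargs):
    # Hypothetical resolution step inside the hook (names are assumptions).
    extractor = CONTEXT_EXTRACTORS.get(teacache_config.transformer_type)  # e.g. "Flux2Klein"
    if extractor is None:
        raise NotImplementedError(
            f"No TeaCache extractor for {teacache_config.transformer_type}"
        )
    # For Flux2Klein this dispatches to extract_flux2_klein_context defined above,
    # returning a CacheContext whose callables are then driven by the hook.
    return extractor(transformer, *args, **kwargs)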
34 changes: 20 additions & 14 deletions vllm_omni/diffusion/cache/teacache/hook.py
@@ -157,20 +157,26 @@ def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any
ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None
)

# Run transformer blocks using model-specific callable
outputs = ctx.run_transformer_blocks()

# Update context with outputs
ctx.hidden_states = outputs[0]
if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
ctx.encoder_hidden_states = outputs[1]

# Cache residuals for next timestep
state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
if ori_encoder_hidden_states is not None:
state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()

output = ctx.hidden_states
# Handle models with additional blocks (e.g., Flux2 single_transformer_blocks)
if getattr(ctx, "extra_states", None) and "run_flux2_full_transformer_with_single" in ctx.extra_states:
run_full = ctx.extra_states["run_flux2_full_transformer_with_single"]
ctx.hidden_states, ctx.encoder_hidden_states = run_full(ori_hidden_states, ori_encoder_hidden_states)
output = ctx.hidden_states
state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
else:
# Run transformer blocks using model-specific callable
outputs = ctx.run_transformer_blocks()
# Update context with outputs
ctx.hidden_states = outputs[0]
if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
ctx.encoder_hidden_states = outputs[1]

output = ctx.hidden_states

# Cache residuals for next timestep
state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
if ori_encoder_hidden_states is not None:
state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()

# Update state
state.previous_modulated_input = ctx.modulated_input.detach()
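For completeness, the cache/recompute decision that surrounds this branch follows the standard TeaCache recipe: measure the relative L1 change of the modulated input between consecutive steps, rescale it with the model-specific polynomial coefficients, accumulate, and reuse previous_residual (skipping the blocks) while the accumulator stays below rel_l1_thresh. A minimal sketch under those assumptions; the actual logic lives earlier in new_forward and is not part of this hunk.

import numpy as np
import torch

def should_reuse_cache(prev_mod: torch.Tensor, curr_mod: torch.Tensor,
                       coefficients: list[float], rel_l1_thresh: float,
                       accumulated: float) -> tuple[bool, float]:
    # Relative L1 change of the modulated input between consecutive timesteps.
    rel_l1 = ((curr_mod - prev_mod).abs().mean() / prev_mod.abs().mean()).item()
    # Rescale with the model-specific degree-4 polynomial, then accumulate.
    accumulated += float(np.poly1d(coefficients)(rel_l1))
    if accumulated < rel_l1_thresh:
        return True, accumulated  # reuse previous_residual, skip the blocks
    return False, 0.0             # recompute and reset the accumulator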