
Commit 8025273: [Feat] support TeaCache for FLUX1.dev
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent 6c19f3e

File tree: 3 files changed, +180 −15 lines
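For orientation: TeaCache skips a transformer pass whenever the timestep-modulated input has barely changed since the previous step. Below is a minimal sketch of that decision, assuming only the two knobs this commit surfaces (`rel_l1_thresh` and `coefficients`); the actual bookkeeping lives in the hook state in `hook.py`, not in a standalone helper like this one.

```python
import numpy as np
import torch

def should_reuse_cache(modulated_input: torch.Tensor,
                       previous_modulated_input: torch.Tensor,
                       accumulated_distance: float,
                       coefficients: list[float],
                       rel_l1_thresh: float) -> tuple[bool, float]:
    # Relative L1 change of the first block's modulated input between steps.
    rel_l1 = ((modulated_input - previous_modulated_input).abs().mean()
              / previous_modulated_input.abs().mean()).item()
    # Polynomial rescaling (model-specific coefficients) maps the input change
    # to an estimated output change, accumulated across skipped steps.
    accumulated_distance += float(np.polyval(coefficients, rel_l1))
    if accumulated_distance < rel_l1_thresh:
        return True, accumulated_distance   # reuse cached residual
    return False, 0.0                       # recompute and reset
```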

vllm_omni/diffusion/cache/teacache/backend.py
Lines changed: 20 additions & 1 deletion

```diff
@@ -48,7 +48,26 @@ def forward_alias(self, *args, **kwargs):
     )


-CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache}
+def enable_flux_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
+    """
+    Enable TeaCache for Flux (Flux1) model.
+    """
+    teacache_config = TeaCacheConfig(
+        transformer_type="FluxTransformer2DModel",
+        rel_l1_thresh=config.rel_l1_thresh,
+        coefficients=config.coefficients,
+    )
+    transformer = pipeline.transformer
+
+    apply_teacache_hook(transformer, teacache_config)
+
+    logger.info(
+        f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, "
+        f"transformer_class={teacache_config.transformer_type}"
+    )
+
+
+CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache, "FluxPipeline": enable_flux_teacache}


 class TeaCacheBackend(CacheBackend):
```
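A hedged usage sketch of the new enabler. The import paths, the `DiffusionCacheConfig` constructor arguments, and the coefficient values below are assumptions; the diff only shows that the config exposes `rel_l1_thresh` and `coefficients` and that dispatch is keyed by pipeline class name.

```python
import torch
from diffusers import FluxPipeline

# Import paths assumed; the diff shows only module-internal names.
from vllm_omni.diffusion.cache.teacache.backend import CUSTOM_TEACACHE_ENABLERS
from vllm_omni.diffusion import DiffusionCacheConfig  # path assumed

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder values; the tuned FLUX polynomial coefficients are not in this diff.
config = DiffusionCacheConfig(rel_l1_thresh=0.4, coefficients=[1.0, 0.0])

# Dispatch by pipeline class name, matching the registry keys in the diff;
# for a FluxPipeline this resolves to enable_flux_teacache.
CUSTOM_TEACACHE_ENABLERS[type(pipe).__name__](pipe, config)

image = pipe("a cat holding a sign", num_inference_steps=28).images[0]
```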

vllm_omni/diffusion/cache/teacache/extractors.py
Lines changed: 140 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+from diffusers.utils import is_torch_npu_available
 
 from vllm_omni.diffusion.forward_context import get_forward_context
 
@@ -566,6 +567,144 @@ def postprocess(h):
     )


+def extract_flux_context(
+    module: nn.Module,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    pooled_projections: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor | None = None,
+    joint_attention_kwargs: dict[str, Any] | None = None,
+    **kwargs: Any,
+) -> CacheContext:
+    """
+    Extract cache context for Flux1-dev model.
+
+    Only caches transformer_blocks output. single_transformer_blocks is always executed.
+
+    Args:
+        module: FluxTransformer2DModel instance
+        hidden_states: Input image hidden states tensor
+        encoder_hidden_states: Input text hidden states tensor
+        pooled_projections: Pooled text embeddings
+        timestep: Current diffusion timestep
+        img_ids: Image position IDs for RoPE
+        txt_ids: Text position IDs for RoPE
+        guidance: Optional guidance scale for CFG
+        joint_attention_kwargs: Additional attention kwargs
+
+    Returns:
+        CacheContext with all information needed for generic caching
+    """
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+        raise ValueError("Module must have transformer_blocks")
+
+    # ============================================================================
+    # PREPROCESSING (Flux-specific)
+    # ============================================================================
+    dtype = hidden_states.dtype
+    device = hidden_states.device
+    timestep = timestep.to(device=device, dtype=dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(device=device, dtype=dtype) * 1000
+
+    temb = (
+        module.time_text_embed(timestep, pooled_projections)
+        if guidance is None
+        else module.time_text_embed(timestep, guidance, pooled_projections)
+    )
+
+    hidden_states = module.x_embedder(hidden_states)
+    encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    if is_torch_npu_available():
+        freqs_cos, freqs_sin = module.pos_embed(ids.cpu())
+        image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+    else:
+        image_rotary_emb = module.pos_embed(ids)
+
+    # ============================================================================
+    # EXTRACT MODULATED INPUT (for cache decision)
+    # ============================================================================
+    block = module.transformer_blocks[0]
+    norm_output = block.norm1(hidden_states, emb=temb)
+    if isinstance(norm_output, tuple):
+        norm_hidden_states = norm_output[0]
+    else:
+        norm_hidden_states = norm_output
+    modulated_input = norm_hidden_states
+
+    # ============================================================================
+    # DEFINE TRANSFORMER EXECUTION (Flux-specific)
+    # ============================================================================
+    def run_flux_transformer_blocks():
+        h = hidden_states
+        c = encoder_hidden_states
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return (h, c)
+
+    def run_flux_full_transformer_with_single(ori_h, ori_c):
+        h = ori_h
+        c = ori_c
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        for block in module.single_transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return h, c
+
+    # ============================================================================
+    # DEFINE POSTPROCESSING (Flux-specific)
+    # ============================================================================
+    def postprocess(h):
+        h = module.norm_out(h, temb)
+        h = module.proj_out(h)
+        return Transformer2DModelOutput(sample=h)
+
+    # ============================================================================
+    # RETURN CONTEXT
+    # ============================================================================
+    return CacheContext(
+        modulated_input=modulated_input,
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        run_transformer_blocks=run_flux_transformer_blocks,
+        postprocess=postprocess,
+        extra_states={
+            "run_flux_full_transformer_with_single": run_flux_full_transformer_with_single,
+        },
+    )
+
+
 # Registry for model-specific extractors
 # Key: Transformer class name
 # Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -576,6 +715,7 @@ def postprocess(h):
     "QwenImageTransformer2DModel": extract_qwen_context,
     "Bagel": extract_bagel_context,
     "ZImageTransformer2DModel": extract_zimage_context,
+    "FluxTransformer2DModel": extract_flux_context,
     # Future models:
     # "FluxTransformer2DModel": extract_flux_context,
     # "CogVideoXTransformer3DModel": extract_cogvideox_context,
```

vllm_omni/diffusion/cache/teacache/hook.py
Lines changed: 20 additions & 14 deletions

```diff
@@ -157,20 +157,26 @@ def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any
             ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None
         )
 
-        # Run transformer blocks using model-specific callable
-        outputs = ctx.run_transformer_blocks()
-
-        # Update context with outputs
-        ctx.hidden_states = outputs[0]
-        if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
-            ctx.encoder_hidden_states = outputs[1]
-
-        # Cache residuals for next timestep
-        state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
-        if ori_encoder_hidden_states is not None:
-            state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
-
-        output = ctx.hidden_states
+        # Handle models with additional blocks
+        if getattr(ctx, "extra_states", None) and "run_flux_full_transformer_with_single" in ctx.extra_states:
+            run_full = ctx.extra_states["run_flux_full_transformer_with_single"]
+            ctx.hidden_states, ctx.encoder_hidden_states = run_full(ori_hidden_states, ori_encoder_hidden_states)
+            output = ctx.hidden_states
+            state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
+        else:
+            # Run transformer blocks using model-specific callable
+            outputs = ctx.run_transformer_blocks()
+            # Update context with outputs
+            ctx.hidden_states = outputs[0]
+            if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
+                ctx.encoder_hidden_states = outputs[1]
+
+            output = ctx.hidden_states
+
+            # Cache residuals for next timestep
+            state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
+            if ori_encoder_hidden_states is not None:
+                state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
 
         # Update state
         state.previous_modulated_input = ctx.modulated_input.detach()
```
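One design note on the branch above: the otherwise generic hook keys the full-run path on a Flux-specific string, and the flux branch stores only `previous_residual`, plausibly because Flux's `postprocess` consumes only the image stream. A hedged sketch of the contract a future two-stack extractor would have to meet (helper name hypothetical):

```python
def make_extra_states(run_all_blocks):
    # Hypothetical helper illustrating the contract: any extractor that wants
    # the full-run branch in hook.py must publish its callable under this
    # exact key, since the lookup there is currently hard-coded to the Flux
    # name; a new model would either reuse the key or generalize the lookup.
    return {"run_flux_full_transformer_with_single": run_all_blocks}
```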
