
Commit 979fe87 (1 parent: 6c19f3e)

[Feat] support TeaCache for Flux2 klein

Signed-off-by: Lancer <maruixiang6688@gmail.com>

File tree

4 files changed: +196 additions, -15 deletions

vllm_omni/diffusion/cache/teacache/backend.py
Lines changed: 23 additions & 1 deletion

@@ -48,7 +48,29 @@ def forward_alias(self, *args, **kwargs):
     )
 
 
-CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache}
+def enable_flux2_klein_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
+    """
+    Enable TeaCache for the Flux2 Klein model.
+    """
+    teacache_config = TeaCacheConfig(
+        transformer_type="Flux2Klein",
+        rel_l1_thresh=config.rel_l1_thresh,
+        coefficients=config.coefficients,
+    )
+    transformer = pipeline.transformer
+
+    apply_teacache_hook(transformer, teacache_config)
+
+    logger.info(
+        f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, "
+        f"transformer_class={teacache_config.transformer_type}"
+    )
+
+
+CUSTOM_TEACACHE_ENABLERS = {
+    "BagelPipeline": enable_bagel_teacache,
+    "Flux2KleinPipeline": enable_flux2_klein_teacache,
+}
 
 
 class TeaCacheBackend(CacheBackend):
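The registry above keys on the pipeline class name, so the backend can pick the right enabler at setup time. A minimal sketch of that dispatch, assuming a hypothetical maybe_enable_teacache helper (the helper is illustrative and not part of this diff):

# Hypothetical dispatch helper; `pipeline` and `config` come from the caller.
def maybe_enable_teacache(pipeline: Any, config: DiffusionCacheConfig) -> bool:
    enabler = CUSTOM_TEACACHE_ENABLERS.get(type(pipeline).__name__)
    if enabler is None:
        return False  # fall back to the generic TeaCacheBackend path
    enabler(pipeline, config)  # e.g. enable_flux2_klein_teacache for Flux2KleinPipeline
    return True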

vllm_omni/diffusion/cache/teacache/config.py
Lines changed: 9 additions & 0 deletions

@@ -15,6 +15,15 @@
         -3.82021401e00,
         2.64230861e-01,
     ],
+    # Flux2 Klein transformer coefficients
+    # Same as FLUX.1 (similar dual-stream architecture)
+    "Flux2Klein": [
+        4.98651651e02,
+        -2.83781631e02,
+        5.58554382e01,
+        -3.82021401e00,
+        2.64230861e-01,
+    ],
     # Qwen-Image transformer coefficients from ComfyUI-TeaCache
     # Tuned specifically for Qwen's dual-stream transformer architecture
     # Used for all Qwen-Image Family pipelines, in general
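For reference, TeaCache uses these per-model lists as a fitted polynomial that rescales the relative L1 change of the modulated input before it is accumulated against rel_l1_thresh. A minimal sketch of that rescaling, assuming the lists are ordered highest-order coefficient first as np.polyval expects (the rescaled_rel_l1 helper is illustrative, not part of this diff):

import numpy as np
import torch

# Illustrative helper: rescale the relative L1 distance between the current
# and previous modulated inputs with the fitted degree-4 polynomial
# (5 coefficients per model, as in the table above).
def rescaled_rel_l1(curr: torch.Tensor, prev: torch.Tensor, coefficients: list[float]) -> float:
    rel_l1 = ((curr - prev).abs().mean() / prev.abs().mean()).item()
    # np.polyval takes the highest-order coefficient first.
    return float(np.polyval(coefficients, rel_l1))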

vllm_omni/diffusion/cache/teacache/extractors.py
Lines changed: 144 additions & 0 deletions

@@ -566,6 +566,149 @@ def postprocess(h):
     )
 
 
+def extract_flux2_klein_context(
+    module,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    joint_attention_kwargs: dict[str, Any] | None = None,
+    **kwargs: Any,
+) -> CacheContext:
+    """
+    Extract the cache context for the Flux2Klein model.
+
+    Only caches the transformer_blocks output; single_transformer_blocks is always executed.
+
+    Args:
+        module: Flux2Transformer2DModel instance
+        hidden_states: Input image hidden states tensor
+        encoder_hidden_states: Input text hidden states tensor
+        timestep: Current diffusion timestep
+        img_ids: Image position IDs for RoPE
+        txt_ids: Text position IDs for RoPE
+        guidance: Optional guidance scale for CFG
+        joint_attention_kwargs: Additional attention kwargs
+
+    Returns:
+        CacheContext with all information needed for generic caching
+    """
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+        raise ValueError("Module must have transformer_blocks")
+
+    dtype = hidden_states.dtype
+
+    num_txt_tokens = encoder_hidden_states.shape[1]
+
+    timestep = timestep.to(dtype=dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(dtype=dtype) * 1000
+
+    temb = module.time_guidance_embed(timestep, guidance)
+
+    double_stream_mod_img = module.double_stream_modulation_img(temb)
+    double_stream_mod_txt = module.double_stream_modulation_txt(temb)
+    single_stream_mod = module.single_stream_modulation(temb)[0]
+
+    hidden_states = module.x_embedder(hidden_states)
+    encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+    if img_ids.ndim == 3:
+        img_ids = img_ids[0]
+    if txt_ids.ndim == 3:
+        txt_ids = txt_ids[0]
+
+    image_rotary_emb = module.pos_embed(img_ids)
+    text_rotary_emb = module.pos_embed(txt_ids)
+    concat_rotary_emb = (
+        torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+        torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+    )
+
+    block = module.transformer_blocks[0]
+
+    norm_hidden_states = block.norm1(hidden_states)
+    norm_hidden_states = (1 + double_stream_mod_img[0][0]) * norm_hidden_states + double_stream_mod_img[0][1]
+
+    modulated_input = norm_hidden_states
+
+    def run_flux2_transformer_blocks():
+        h = hidden_states
+        c = encoder_hidden_states
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return (h, c)
+
+    def run_flux2_single_transformer_blocks(c, h):
+        h_concat = torch.cat([c, h], dim=1)
+        for block in module.single_transformer_blocks:
+            h_concat = block(
+                hidden_states=h_concat,
+                encoder_hidden_states=None,
+                temb_mod_params=single_stream_mod,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return h_concat[:, num_txt_tokens:, ...]
+
+    def run_flux2_full_transformer_with_single(ori_h, ori_c):
+        h = ori_h
+        c = ori_c
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        h_concat = torch.cat([c, h], dim=1)
+        for block in module.single_transformer_blocks:
+            h_concat = block(
+                hidden_states=h_concat,
+                encoder_hidden_states=None,
+                temb_mod_params=single_stream_mod,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        final_hidden_states = h_concat[:, num_txt_tokens:, ...]
+        return final_hidden_states, c
+
+    return_dict = kwargs.get("return_dict", True)
+
+    def postprocess(h):
+        h = module.norm_out(h, temb)
+        h = module.proj_out(h)
+        if not return_dict:
+            return (h,)
+        return Transformer2DModelOutput(sample=h)
+
+    return CacheContext(
+        modulated_input=modulated_input,
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        run_transformer_blocks=run_flux2_transformer_blocks,
+        postprocess=postprocess,
+        extra_states={
+            "run_flux2_single_transformer_blocks": run_flux2_single_transformer_blocks,
+            "run_flux2_full_transformer_with_single": run_flux2_full_transformer_with_single,
+        },
+    )
+
+
 # Registry for model-specific extractors
 # Key: Transformer class name
 # Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -576,6 +719,7 @@ def postprocess(h):
     "QwenImageTransformer2DModel": extract_qwen_context,
     "Bagel": extract_bagel_context,
     "ZImageTransformer2DModel": extract_zimage_context,
+    "Flux2Klein": extract_flux2_klein_context,
     # Future models:
     # "FluxTransformer2DModel": extract_flux_context,
     # "CogVideoXTransformer3DModel": extract_cogvideox_context,

vllm_omni/diffusion/cache/teacache/hook.py
Lines changed: 20 additions & 14 deletions

@@ -157,20 +157,26 @@ def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
             ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None
         )
 
-        # Run transformer blocks using model-specific callable
-        outputs = ctx.run_transformer_blocks()
-
-        # Update context with outputs
-        ctx.hidden_states = outputs[0]
-        if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
-            ctx.encoder_hidden_states = outputs[1]
-
-        # Cache residuals for next timestep
-        state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
-        if ori_encoder_hidden_states is not None:
-            state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
-
-        output = ctx.hidden_states
+        # Handle models with additional blocks (e.g., Flux2 single_transformer_blocks)
+        if getattr(ctx, "extra_states", None) and "run_flux2_full_transformer_with_single" in ctx.extra_states:
+            run_full = ctx.extra_states["run_flux2_full_transformer_with_single"]
+            ctx.hidden_states, ctx.encoder_hidden_states = run_full(ori_hidden_states, ori_encoder_hidden_states)
+            output = ctx.hidden_states
+            state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
+        else:
+            # Run transformer blocks using model-specific callable
+            outputs = ctx.run_transformer_blocks()
+            # Update context with outputs
+            ctx.hidden_states = outputs[0]
+            if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
+                ctx.encoder_hidden_states = outputs[1]
+
+            output = ctx.hidden_states
+
+            # Cache residuals for next timestep
+            state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
+            if ori_encoder_hidden_states is not None:
+                state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
 
         # Update state
         state.previous_modulated_input = ctx.modulated_input.detach()