diff --git a/vllm_omni/diffusion/cache/teacache/backend.py b/vllm_omni/diffusion/cache/teacache/backend.py
index bf328d43d8..132dcce0d3 100644
--- a/vllm_omni/diffusion/cache/teacache/backend.py
+++ b/vllm_omni/diffusion/cache/teacache/backend.py
@@ -48,7 +48,29 @@ def forward_alias(self, *args, **kwargs):
     )
 
 
-CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache}
+def enable_flux2_klein_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
+    """
+    Enable TeaCache for Flux2 Klein model.
+    """
+    teacache_config = TeaCacheConfig(
+        transformer_type="Flux2Klein",
+        rel_l1_thresh=config.rel_l1_thresh,
+        coefficients=config.coefficients,
+    )
+    transformer = pipeline.transformer
+
+    apply_teacache_hook(transformer, teacache_config)
+
+    logger.info(
+        f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, "
+        f"transformer_class={teacache_config.transformer_type}"
+    )
+
+
+CUSTOM_TEACACHE_ENABLERS = {
+    "BagelPipeline": enable_bagel_teacache,
+    "Flux2KleinPipeline": enable_flux2_klein_teacache,
+}
 
 
 class TeaCacheBackend(CacheBackend):
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index 30b1745f47..9fc7dc5b18 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -15,6 +15,15 @@
         -3.82021401e00,
         2.64230861e-01,
     ],
+    # Flux2 Klein transformer coefficients
+    # Same as FLUX.1 (similar dual-stream architecture)
+    "Flux2Klein": [
+        4.98651651e02,
+        -2.83781631e02,
+        5.58554382e01,
+        -3.82021401e00,
+        2.64230861e-01,
+    ],
     # Qwen-Image transformer coefficients from ComfyUI-TeaCache
     # Tuned specifically for Qwen's dual-stream transformer architecture
     # Used for all Qwen-Image Family pipelines, in general
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index 7802979191..c1475fc1ac 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -566,6 +566,149 @@ def postprocess(h):
     )
 
 
+def extract_flux2_klein_context(
+    module: nn.Module,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor | None = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor | None = None,
+    joint_attention_kwargs: dict[str, Any] | None = None,
+    **kwargs: Any,
+) -> CacheContext:
+    """
+    Extract cache context for Flux2Klein model.
+
+    Caches the full transformer output (including single_transformer_blocks).
+    When cache is reused, single_transformer_blocks is skipped to achieve maximum speedup.
+
+    Args:
+        module: Flux2Transformer2DModel instance
+        hidden_states: Input image hidden states tensor
+        encoder_hidden_states: Input text hidden states tensor
+        timestep: Current diffusion timestep
+        img_ids: Image position IDs for RoPE
+        txt_ids: Text position IDs for RoPE
+        guidance: Optional guidance scale for CFG
+        joint_attention_kwargs: Additional attention kwargs
+
+    Returns:
+        CacheContext with all information needed for generic caching
+    """
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+        raise ValueError("Module must have transformer_blocks")
+
+    # ============================================================================
+    # PREPROCESSING (Flux2-specific)
+    # ============================================================================
+    dtype = hidden_states.dtype
+
+    num_txt_tokens = encoder_hidden_states.shape[1]
+
+    timestep = timestep.to(dtype=dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(dtype=dtype) * 1000
+
+    temb = module.time_guidance_embed(timestep, guidance)
+
+    double_stream_mod_img = module.double_stream_modulation_img(temb)
+    double_stream_mod_txt = module.double_stream_modulation_txt(temb)
+    single_stream_mod = module.single_stream_modulation(temb)[0]
+
+    hidden_states = module.x_embedder(hidden_states)
+    encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+    if img_ids.ndim == 3:
+        img_ids = img_ids[0]
+    if txt_ids.ndim == 3:
+        txt_ids = txt_ids[0]
+
+    image_rotary_emb = module.pos_embed(img_ids)
+    text_rotary_emb = module.pos_embed(txt_ids)
+    concat_rotary_emb = (
+        torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+        torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+    )
+
+    # ============================================================================
+    # EXTRACT MODULATED INPUT (for cache decision)
+    # ============================================================================
+    block = module.transformer_blocks[0]
+
+    norm_hidden_states = block.norm1(hidden_states)
+    norm_hidden_states = (1 + double_stream_mod_img[0][1]) * norm_hidden_states + double_stream_mod_img[0][0]
+
+    modulated_input = norm_hidden_states
+
+    # ============================================================================
+    # DEFINE TRANSFORMER EXECUTION (Flux2-specific)
+    # ============================================================================
+    def run_flux2_transformer_blocks():
+        h = hidden_states
+        c = encoder_hidden_states
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return (h, c)
+
+    def run_flux2_full_transformer_with_single(ori_h, ori_c):
+        h = ori_h
+        c = ori_c
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        h_concat = torch.cat([c, h], dim=1)
+        for block in module.single_transformer_blocks:
+            h_concat = block(
+                hidden_states=h_concat,
+                encoder_hidden_states=None,
+                temb_mod_params=single_stream_mod,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        final_hidden_states = h_concat[:, num_txt_tokens:, ...]
+        return final_hidden_states, c
+
+    # ============================================================================
+    # DEFINE POSTPROCESSING (Flux2-specific)
+    # ============================================================================
+    return_dict = kwargs.get("return_dict", True)
+
+    def postprocess(h):
+        h = module.norm_out(h, temb)
+        h = module.proj_out(h)
+        if not return_dict:
+            return (h,)
+        return Transformer2DModelOutput(sample=h)
+
+    return CacheContext(
+        modulated_input=modulated_input,
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        run_transformer_blocks=run_flux2_transformer_blocks,
+        postprocess=postprocess,
+        extra_states={
+            "run_flux2_full_transformer_with_single": run_flux2_full_transformer_with_single,
+        },
+    )
+
+
 # Registry for model-specific extractors
 # Key: Transformer class name
 # Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -576,6 +719,7 @@ def postprocess(h):
     "QwenImageTransformer2DModel": extract_qwen_context,
     "Bagel": extract_bagel_context,
     "ZImageTransformer2DModel": extract_zimage_context,
+    "Flux2Klein": extract_flux2_klein_context,
     # Future models:
     # "FluxTransformer2DModel": extract_flux_context,
     # "CogVideoXTransformer3DModel": extract_cogvideox_context,
diff --git a/vllm_omni/diffusion/cache/teacache/hook.py b/vllm_omni/diffusion/cache/teacache/hook.py
index 65f764c43b..d38148366e 100644
--- a/vllm_omni/diffusion/cache/teacache/hook.py
+++ b/vllm_omni/diffusion/cache/teacache/hook.py
@@ -157,20 +157,26 @@ def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any
                 ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None
             )
 
-            # Run transformer blocks using model-specific callable
-            outputs = ctx.run_transformer_blocks()
-
-            # Update context with outputs
-            ctx.hidden_states = outputs[0]
-            if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
-                ctx.encoder_hidden_states = outputs[1]
-
-            # Cache residuals for next timestep
-            state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
-            if ori_encoder_hidden_states is not None:
-                state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
-
-            output = ctx.hidden_states
+            # Handle models with additional blocks (e.g., Flux2 single_transformer_blocks)
+            if getattr(ctx, "extra_states", None) and "run_flux2_full_transformer_with_single" in ctx.extra_states:
+                run_full = ctx.extra_states["run_flux2_full_transformer_with_single"]
+                ctx.hidden_states, ctx.encoder_hidden_states = run_full(ori_hidden_states, ori_encoder_hidden_states)
+                output = ctx.hidden_states
+                state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
+            else:
+                # Run transformer blocks using model-specific callable
+                outputs = ctx.run_transformer_blocks()
+                # Update context with outputs
+                ctx.hidden_states = outputs[0]
+                if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
+                    ctx.encoder_hidden_states = outputs[1]
+
+                output = ctx.hidden_states
+
+                # Cache residuals for next timestep
+                state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
+                if ori_encoder_hidden_states is not None:
+                    state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
 
             # Update state
             state.previous_modulated_input = ctx.modulated_input.detach()
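
The registry added in backend.py is keyed by pipeline class name, so the new enabler is presumably selected by looking up "Flux2KleinPipeline". The sketch below illustrates that lookup; it is a hypothetical dispatch helper for illustration only, and the `enable_teacache_for` function, the `pipeline` object, and the way `cache_config` is constructed are assumptions, not part of this diff.

    # Hypothetical sketch: route a pipeline to its registered TeaCache enabler.
    # Only CUSTOM_TEACACHE_ENABLERS and enable_flux2_klein_teacache come from this
    # change; the helper below and its call site are illustrative assumptions.
    from vllm_omni.diffusion.cache.teacache.backend import CUSTOM_TEACACHE_ENABLERS


    def enable_teacache_for(pipeline, cache_config):
        enabler = CUSTOM_TEACACHE_ENABLERS.get(type(pipeline).__name__)
        if enabler is None:
            raise ValueError(f"No TeaCache enabler registered for {type(pipeline).__name__}")
        # For a Flux2KleinPipeline this resolves to enable_flux2_klein_teacache,
        # which wraps pipeline.transformer with the TeaCache forward hook.
        enabler(pipeline, cache_config)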