diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md
index 859f8c0a22..27c691f556 100644
--- a/docs/user_guide/diffusion_acceleration.md
+++ b/docs/user_guide/diffusion_acceleration.md
@@ -51,6 +51,7 @@ The following table shows which models are currently supported by each accelerat
 | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ |
 | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ |
 | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ✅ | ❌ | ❌ | ❌ |
 
 ### VideoGen
 
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index f014d073e8..8f985b60c9 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -346,6 +346,82 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
     return refresh_cache_context
 
 
+def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for the FLUX.2-klein-4B pipeline.
+
+    Args:
+        pipeline: The FLUX.2-klein-4B pipeline instance.
+        cache_config: DiffusionCacheConfig instance with cache configuration.
+
+    Returns:
+        A refresh function that can be called with a new ``num_inference_steps``
+        to update the cache context for the pipeline.
+    """
+    # Build DBCacheConfig for the transformer; force Fn=2 so the first two
+    # blocks are always computed before the cache can take over.
+    db_cache_config = _build_db_cache_config(cache_config)
+    db_cache_config.Fn_compute_blocks = 2
+
+    calibrator = None
+    if cache_config.enable_taylorseer:
+        taylorseer_order = cache_config.taylorseer_order
+        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+    # Build ParamsModifier for transformer
+    modifier = ParamsModifier(
+        cache_config=db_cache_config,
+        calibrator_config=calibrator,
+    )
+
+    logger.info(
+        f"Enabling cache-dit on FLUX.2-klein transformer with BlockAdapter: "
+        f"Fn={db_cache_config.Fn_compute_blocks}, "
+        f"Bn={db_cache_config.Bn_compute_blocks}, "
+        f"W={db_cache_config.max_warmup_steps}"
+    )
+
+    # Enable cache-dit using BlockAdapter for transformer
+    cache_dit.enable_cache(
+        BlockAdapter(
+            transformer=pipeline.transformer,
+            blocks=[
+                pipeline.transformer.transformer_blocks,
+                pipeline.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+            params_modifiers=[modifier],
+        ),
+        cache_config=db_cache_config,
+    )
+
+    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+        """Refresh cache context for the transformer with new num_inference_steps.
+
+        Args:
+            pipeline: The FLUX.2-klein-4B pipeline instance.
+            num_inference_steps: New number of inference steps.
+ """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + Fn_compute_blocks=db_cache_config.Fn_compute_blocks, + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: """Enable cache-dit for StableDiffusion3Pipeline. @@ -859,6 +935,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "Wan22I2VPipeline": enable_cache_for_wan22, "Wan22TI2VPipeline": enable_cache_for_wan22, "FluxPipeline": enable_cache_for_flux, + "Flux2KleinPipeline": enable_cache_for_flux2_klein, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, "StableDiffusion3Pipeline": enable_cache_for_sd3,