From 656f2b142f63b0f3ca22f129d950e400751e8823 Mon Sep 17 00:00:00 2001
From: wuzhongjian
Date: Thu, 5 Feb 2026 10:20:15 +0800
Subject: [PATCH 1/3] [feature]: support FLUX.2-klein cache_dit

Signed-off-by: wuzhongjian
---
 docs/user_guide/diffusion_acceleration.md     |  1 +
 .../diffusion/cache/cache_dit_backend.py      | 71 +++++++++++++++++++
 2 files changed, 72 insertions(+)

diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md
index 859f8c0a22..27c691f556 100644
--- a/docs/user_guide/diffusion_acceleration.md
+++ b/docs/user_guide/diffusion_acceleration.md
@@ -51,6 +51,7 @@ The following table shows which models are currently supported by each accelerat
 | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ |
 | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ |
 | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ✅ | ❌ | ❌ | ❌ |
 
 ### VideoGen
 
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index f014d073e8..d71e3a46de 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -345,6 +345,76 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
 
     return refresh_cache_context
 
+def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for FLUX.2-klein-4B pipeline.
+
+    Args:
+        pipeline: The FLUX.2-klein-4B pipeline instance.
+        cache_config: DiffusionCacheConfig instance with cache configuration.
+    """
+    # Build DBCacheConfig for transformer
+    db_cache_config = _build_db_cache_config(cache_config)
+    db_cache_config.Fn_compute_blocks = 2
+
+    calibrator = None
+    if cache_config.enable_taylorseer:
+        taylorseer_order = cache_config.taylorseer_order
+        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+    # Build ParamsModifier for transformer
+    modifier = ParamsModifier(
+        cache_config=db_cache_config,
+        calibrator_config=calibrator,
+    )
+
+    logger.info(
+        f"Enabling cache-dit on Flux transformer with BlockAdapter: "
+        f"Fn={db_cache_config.Fn_compute_blocks}, "
+        f"Bn={db_cache_config.Bn_compute_blocks}, "
+        f"W={db_cache_config.max_warmup_steps}"
+    )
+
+    # Enable cache-dit using BlockAdapter for transformer
+    cache_dit.enable_cache(
+        (
+            BlockAdapter(
+                transformer=pipeline.transformer,
+                blocks=[
+                    pipeline.transformer.transformer_blocks,
+                    pipeline.transformer.single_transformer_blocks,
+                ],
+                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+                params_modifiers=[modifier],
+            )
+        ),
+        cache_config=db_cache_config,
+    )
+
+    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+        """Refresh cache context for the transformer with new num_inference_steps.
+
+        Args:
+            pipeline: The FLUX.2-klein-4B pipeline instance.
+            num_inference_steps: New number of inference steps.
+        """
+        if cache_config.scm_steps_mask_policy is None:
+            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+        else:
+            cache_dit.refresh_context(
+                pipeline.transformer,
+                cache_config=DBCacheConfig().reset(
+                    num_inference_steps=num_inference_steps,
+                    steps_computation_mask=cache_dit.steps_mask(
+                        mask_policy=cache_config.scm_steps_mask_policy,
+                        total_steps=num_inference_steps,
+                    ),
+                    steps_computation_policy=cache_config.scm_steps_policy,
+                ),
+                verbose=verbose,
+            )
+
+    return refresh_cache_context
 
 def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
     """Enable cache-dit for StableDiffusion3Pipeline.
 
@@ -859,6 +929,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
     "Wan22I2VPipeline": enable_cache_for_wan22,
     "Wan22TI2VPipeline": enable_cache_for_wan22,
     "FluxPipeline": enable_cache_for_flux,
+    "Flux2KleinPipeline": enable_cache_for_flux2_klein,
     "LongCatImagePipeline": enable_cache_for_longcat_image,
     "LongCatImageEditPipeline": enable_cache_for_longcat_image,
     "StableDiffusion3Pipeline": enable_cache_for_sd3,

From 6e04607e73cf6fa9b521ef9d725379cd7e25f6a2 Mon Sep 17 00:00:00 2001
From: wuzhongjian
Date: Thu, 5 Feb 2026 10:41:24 +0800
Subject: [PATCH 2/3] [feature]: support FLUX.2-klein cache_dit

Signed-off-by: wuzhongjian
---
 vllm_omni/diffusion/cache/cache_dit_backend.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index d71e3a46de..6822bbcb75 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -345,6 +345,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
 
     return refresh_cache_context
 
+
 def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
     """Enable cache-dit for FLUX.2-klein-4B pipeline.
 
@@ -409,6 +410,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
                         mask_policy=cache_config.scm_steps_mask_policy,
                         total_steps=num_inference_steps,
                     ),
+                    Fn_compute_blocks=db_cache_config.Fn_compute_blocks,
                     steps_computation_policy=cache_config.scm_steps_policy,
                 ),
                 verbose=verbose,
@@ -416,6 +418,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
 
     return refresh_cache_context
 
+
 def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
     """Enable cache-dit for StableDiffusion3Pipeline.
 

From b6268560625ec7451f345c5460c7327e1c5b68fe Mon Sep 17 00:00:00 2001
From: wuzhongjian
Date: Fri, 6 Feb 2026 14:26:40 +0800
Subject: [PATCH 3/3] [feature]: support FLUX.2-klein cache_dit

Signed-off-by: wuzhongjian
---
 vllm_omni/diffusion/cache/cache_dit_backend.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index 6822bbcb75..8f985b60c9 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -352,6 +352,10 @@ def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[[
     Args:
         pipeline: The FLUX.2-klein-4B pipeline instance.
         cache_config: DiffusionCacheConfig instance with cache configuration.
+
+    Returns:
+        A refresh function that can be called with a new ``num_inference_steps``
+        to update the cache context for the pipeline.
""" # Build DBCacheConfig for transformer db_cache_config = _build_db_cache_config(cache_config)