1 change: 1 addition & 0 deletions docs/user_guide/diffusion_acceleration.md
@@ -51,6 +51,7 @@ The following table shows which models are currently supported by each acceleration
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ |
| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ❌ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ✅ | ❌ | ❌ | ❌ |

### VideoGen

71 changes: 71 additions & 0 deletions vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -345,6 +345,76 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool

return refresh_cache_context

def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[..., None]:
"""Enable cache-dit for FLUX.2-klein-4B pipeline.

Args:
pipeline: The FLUX.2-klein-4B pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A callable that refreshes the cache context for a new number of
        inference steps.
    """
# Build DBCacheConfig for transformer
db_cache_config = _build_db_cache_config(cache_config)
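    # FLUX.2-klein override: always fully compute the first 2 transformer
    # blocks (Fn) before the dual-block cache takes over.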
db_cache_config.Fn_compute_blocks = 2

calibrator = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

# Build ParamsModifier for transformer
modifier = ParamsModifier(
cache_config=db_cache_config,
calibrator_config=calibrator,
)

logger.info(
f"Enabling cache-dit on Flux transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)

# Enable cache-dit using BlockAdapter for transformer
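    # FLUX-style DiTs have a dual-stream stage (transformer_blocks) followed by
    # a single-stream stage (single_transformer_blocks); each block list is
    # paired with its matching forward pattern below.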
    cache_dit.enable_cache(
        BlockAdapter(
            transformer=pipeline.transformer,
            blocks=[
                pipeline.transformer.transformer_blocks,
                pipeline.transformer.single_transformer_blocks,
            ],
            forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
            params_modifiers=[modifier],
        ),
        cache_config=db_cache_config,
    )

def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context for the transformer with new num_inference_steps.

Args:
pipeline: The FLUX.2-klein-4B pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
else:
cache_dit.refresh_context(
pipeline.transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_inference_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy,
total_steps=num_inference_steps,
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)

return refresh_cache_context
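For orientation, a minimal sketch of how this enabler and its returned refresh callback might be driven. `SimpleNamespace` stands in for the real `DiffusionCacheConfig` (which `_build_db_cache_config` likely reads more fields from), and the pipeline object and step count are illustrative assumptions, not part of this change:

```python
from types import SimpleNamespace

# Stand-in for DiffusionCacheConfig; the field names mirror only the attributes
# this enabler reads directly (enable_taylorseer, taylorseer_order, scm_*).
cache_config = SimpleNamespace(
    enable_taylorseer=True,
    taylorseer_order=2,
    scm_steps_mask_policy=None,
    scm_steps_policy=None,
)

# `pipeline` is assumed to be an already-constructed Flux2KleinPipeline.
refresh = enable_cache_for_flux2_klein(pipeline, cache_config)

# When a request arrives with a different step count, refresh the cache context.
refresh(pipeline, num_inference_steps=28, verbose=False)
```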

def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.
@@ -859,6 +929,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"Wan22I2VPipeline": enable_cache_for_wan22,
"Wan22TI2VPipeline": enable_cache_for_wan22,
"FluxPipeline": enable_cache_for_flux,
"Flux2KleinPipeline": enable_cache_for_flux2_klein,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
Expand Down
Loading
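The registry above maps pipeline class names to their enabler functions; the mapping continues beyond this hunk and its identifier is not visible here. A hedged sketch of how dispatch might look, with `PIPELINE_CACHE_ENABLERS` as an assumed name:

```python
# Hypothetical dispatch helper; PIPELINE_CACHE_ENABLERS is an assumed name for
# the mapping shown above, and the error handling is illustrative.
def enable_cache_for_pipeline(pipeline, cache_config):
    enabler = PIPELINE_CACHE_ENABLERS.get(type(pipeline).__name__)
    if enabler is None:
        raise NotImplementedError(
            f"cache-dit is not supported for {type(pipeline).__name__}"
        )
    # Returns the per-pipeline refresh callback, e.g. refresh_cache_context.
    return enabler(pipeline, cache_config)
```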