Conversation

@Bluear7878
Contributor

Feat: Implement FBCache (First Block Cache) Optimization for SDXL Pipeline

Overview

This PR introduces the FBCache (First Block Cache) optimization to the Stable Diffusion XL (SDXL) inference pipeline, accelerating inference by skipping redundant computations.

How FBCache Works

At each denoising step, FBCache runs a cheap "proxy" computation and compares its output against the value cached from the previous step; when the two differ by less than a residual-difference threshold, the remaining U-Net blocks are skipped and the previously cached output is reused. FBCache uses the output of the U-Net's first down-sampling block (First Down Block) as this proxy.
The implementation of this forward pass follows the design of the Hugging Face diffusers library to maintain consistency with its pipeline architecture.
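
A minimal sketch of this decision logic (a dict-backed toy buffer; first_block, remaining_blocks, and the buffer keys are illustrative names rather than this PR's actual identifiers, and the exact similarity metric may differ):

import torch

_buffers = {}  # toy stand-in for the PR's cache context

def cached_unet_step(hidden_states: torch.Tensor, first_block, remaining_blocks,
                     residual_diff_threshold: float = 0.12) -> torch.Tensor:
    # Always run the cheap first down block; its output is the proxy for how
    # much this timestep differs from the previous one.
    first_out = first_block(hidden_states)

    prev_first_out = _buffers.get("first_block_output")
    if prev_first_out is not None and prev_first_out.shape == first_out.shape:
        # Relative L1 difference between this step's proxy and the cached one.
        diff = (first_out - prev_first_out).abs().mean() / (prev_first_out.abs().mean() + 1e-8)
        if diff.item() < residual_diff_threshold:
            # Nearly unchanged: reuse the cached final output and skip the rest
            # of the (expensive) U-Net for this timestep.
            return _buffers["final_output"]

    # Cache miss: run the remaining blocks and refresh both buffers.
    out = remaining_blocks(first_out)
    _buffers["first_block_output"] = first_out.detach()
    _buffers["final_output"] = out
    return out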

Performance Improvement (Acceleration)

With this optimization, inference speed for a 50-step SDXL generation on an NVIDIA A5000 GPU was significantly improved.

  • Before (No Cache): $\approx$ 7 seconds (attached image: "no cache (7s) sdxl")
  • After (FBCache): $\approx$ 2 seconds (attached image: "cache (2s) sdxl")

This is an approximately 3.5x inference speedup, indicating that the cache skips the majority of U-Net computations in the later timesteps. The attached images show visually identical output, confirming that the speedup does not come at the cost of quality.
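
For reference, enabling the cache looks roughly like this (a usage sketch; the import path of apply_cache_on_pipe depends on where this PR places it):

import time

import torch
from diffusers import StableDiffusionXLPipeline
# from <this PR's cache module> import apply_cache_on_pipe

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Enable FBCache on this pipeline instance.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)

start = time.time()
image = pipe("an astronaut riding a horse", num_inference_steps=50).images[0]
print(f"50-step generation took {time.time() - start:.1f}s")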

@tonera commented Dec 16, 2025

First: this is very useful and awesome. That said, I ran into a few problems that are easy to miss during development but can bite later, along with suggested fixes:

  1. Global side-effect / class pollution
  • The original apply_cache_on_pipe() approach patched pipe.__class__.__call__, which affects all instances of that pipeline class in the same process.
  • Even if you usually only run one pipeline instance, this is a hidden coupling and makes debugging much harder if a second pipeline exists.

Fix: make the patch instance-isolated by creating a per-instance subclass and patching only that subclass’ __call__.

import functools

from diffusers import DiffusionPipeline

# cache_context, create_cache_context, and apply_cache_on_unet come from this PR's cache module.

def apply_cache_on_pipe(pipe: DiffusionPipeline, *, residual_diff_threshold=0.12, verbose=False):
    # Wrap pipeline __call__ with cache context (instance-isolated).
    # NOTE: special methods like __call__ are resolved on the *type*, not the instance dict,
    # so to avoid patching the shared class, we create a per-instance subclass and patch that.
    if not getattr(pipe, "_fbcache_call_isolated", False):
        base_cls = pipe.__class__
        original_call = base_cls.__call__

        patched_cls = type(f"{base_cls.__name__}FBCachePatched_{id(pipe)}", (base_cls,), {})

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context(create_cache_context()):
                return original_call(self, *args, **kwargs)

        patched_cls.__call__ = new_call
        pipe.__class__ = patched_cls
        pipe._fbcache_call_isolated = True

    pipe._is_cached = True
    apply_cache_on_unet(pipe.unet, residual_diff_threshold=residual_diff_threshold, verbose=verbose)
    return pipe
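
A quick sanity check of the isolation (a sketch; loading two pipelines just to show that the shared base class stays untouched):

from diffusers import StableDiffusionXLPipeline

pipe_a = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe_b = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

apply_cache_on_pipe(pipe_a)

assert type(pipe_a) is not StableDiffusionXLPipeline   # pipe_a now uses its own patched subclass
assert type(pipe_b) is StableDiffusionXLPipeline       # pipe_b and the base class are untouched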
  2. down_intrablock_additional_residuals is often a tuple, but the code mutated it with .pop(0)
  • Diffusers frequently passes tuples here.
  • Calling .pop(0) on a tuple will crash on T2I-Adapter / some residual paths.

Fix: normalize to a list early (only when adapter mode is active), then safely pop.

is_adapter = down_intrablock_additional_residuals is not None
# diffusers may pass tuples here; we mutate via pop(0), so ensure it's a list.
if is_adapter and not isinstance(down_intrablock_additional_residuals, list):
    down_intrablock_additional_residuals = list(down_intrablock_additional_residuals)
  3. Python-side overhead in residual accumulation
  • On cache misses, the original implementation used repeated tuple concatenation for down_block_res_samples and ControlNet residual merges, which creates lots of temporary tuples.
  • This doesn’t change correctness, but it wastes CPU time and allocations.

Fix: use a list + extend() and convert to tuple only at the call boundary where Diffusers expects tuples.

down_block_res_samples_base = [sample]
# ...
if not can_use_cache:
    # ...
    down_block_res_samples = list(down_block_res_samples_base)
    down_block_res_samples.extend(first_block_res_samples)
    # ...
    down_block_res_samples.extend(res_samples)

    if is_controlnet:
        down_block_res_samples = [r + a for r, a in zip(down_block_res_samples, down_block_additional_residuals)]

    # ...
    res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
    res_samples_tuple = tuple(res_samples)
    # pass res_hidden_states_tuple=res_samples_tuple
  4. Robustness in cache-hit path
  • If the cache says “hit” but the final_output buffer isn’t available, returning None would break downstream.
  • Device/dtype mismatches in cached tensors can also crash inside normalization/linear layers.

Fix: add a safe fallback to “miss” if final_output is missing, and enforce device/dtype alignment when returning cached output.

if can_use_cache:
    # ...
    sample = get_buffer("final_output")
    # Safety: if cache buffer isn't ready, fall back to full compute.
    if sample is None:
        can_use_cache = False
        if self.verbose:
            print("[SDXL] Cache buffer missing; fallback to cache-miss path.")
    if can_use_cache and isinstance(sample, torch.Tensor):
        if sample.device != target_device:
            sample = sample.to(device=target_device)
        if sample.dtype != target_dtype:
            sample = sample.to(dtype=target_dtype)
