 
 import torch
 import torch.nn as nn
+from diffusers.utils import is_torch_npu_available
 
 from vllm_omni.diffusion.forward_context import get_forward_context
 
@@ -566,6 +567,144 @@ def postprocess(h):
     )
 
 
+def extract_flux_context(
+    module: nn.Module,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    pooled_projections: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor | None = None,
+    joint_attention_kwargs: dict[str, Any] | None = None,
+    **kwargs: Any,
+) -> CacheContext:
+    """
+    Extract cache context for the Flux.1-dev model.
+
+    Only caches transformer_blocks output; single_transformer_blocks is always executed.
+
+    Args:
+        module: FluxTransformer2DModel instance
+        hidden_states: Input image hidden states tensor
+        encoder_hidden_states: Input text hidden states tensor
+        pooled_projections: Pooled text embeddings
+        timestep: Current diffusion timestep
+        img_ids: Image position IDs for RoPE
+        txt_ids: Text position IDs for RoPE
+        guidance: Optional guidance scale for CFG
+        joint_attention_kwargs: Additional attention kwargs
+
+    Returns:
+        CacheContext with all information needed for generic caching
+    """
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+        raise ValueError("Module must have transformer_blocks")
+
+    # ============================================================================
+    # PREPROCESSING (Flux-specific)
+    # ============================================================================
+    dtype = hidden_states.dtype
+    device = hidden_states.device
+    timestep = timestep.to(device=device, dtype=dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(device=device, dtype=dtype) * 1000
+
+    temb = (
+        module.time_text_embed(timestep, pooled_projections)
+        if guidance is None
+        else module.time_text_embed(timestep, guidance, pooled_projections)
+    )
+
+    hidden_states = module.x_embedder(hidden_states)
+    encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    if is_torch_npu_available():
+        freqs_cos, freqs_sin = module.pos_embed(ids.cpu())
+        image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+    else:
+        image_rotary_emb = module.pos_embed(ids)
+
+    # ============================================================================
+    # EXTRACT MODULATED INPUT (for cache decision)
+    # ============================================================================
+    block = module.transformer_blocks[0]
+    norm_output = block.norm1(hidden_states, emb=temb)
+    if isinstance(norm_output, tuple):
+        norm_hidden_states = norm_output[0]
+    else:
+        norm_hidden_states = norm_output
+    modulated_input = norm_hidden_states
+
+    # ============================================================================
+    # DEFINE TRANSFORMER EXECUTION (Flux-specific)
+    # ============================================================================
+    def run_flux_transformer_blocks():
+        h = hidden_states
+        c = encoder_hidden_states
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return (h, c)
+
+    def run_flux_full_transformer_with_single(ori_h, ori_c):
+        h = ori_h
+        c = ori_c
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        for block in module.single_transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return h, c
+
+    # ============================================================================
+    # DEFINE POSTPROCESSING (Flux-specific)
+    # ============================================================================
+    def postprocess(h):
+        h = module.norm_out(h, temb)
+        h = module.proj_out(h)
+        return Transformer2DModelOutput(sample=h)
+
+    # ============================================================================
+    # RETURN CONTEXT
+    # ============================================================================
+    return CacheContext(
+        modulated_input=modulated_input,
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        run_transformer_blocks=run_flux_transformer_blocks,
+        postprocess=postprocess,
+        extra_states={
+            "run_flux_full_transformer_with_single": run_flux_full_transformer_with_single,
+        },
+    )
+
+
 # Registry for model-specific extractors
 # Key: Transformer class name
 # Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
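For orientation, a rough editorial sketch (not part of this diff) of how a first-block-cache style consumer could use the CacheContext produced by extract_flux_context above. The state dict, the 0.1 threshold, the reuse policy, and attribute-style access to the CacheContext fields are illustrative assumptions; the generic cache in vllm_omni may work differently, and per the docstring it still executes single_transformer_blocks on a cache hit rather than reusing the final output wholesale as this sketch does.

def flux_forward_with_cache(ctx, state, threshold=0.1):
    # Hypothetical consumer sketch, not part of this diff.
    # Compare the modulated input against the previous step to decide reuse.
    prev = state.get("modulated_input")
    state["modulated_input"] = ctx.modulated_input

    can_reuse = (
        prev is not None
        and prev.shape == ctx.modulated_input.shape
        and ((ctx.modulated_input - prev).abs().mean() / prev.abs().mean()).item() < threshold
    )

    if can_reuse and "transformer_out" in state:
        # Reuse the transformer output computed at the previous step.
        h = state["transformer_out"]
    else:
        # Recompute the full path: dual-stream blocks, then single-stream blocks.
        run_full = ctx.extra_states["run_flux_full_transformer_with_single"]
        h, _c = run_full(ctx.hidden_states, ctx.encoder_hidden_states)
        state["transformer_out"] = h

    # Project back to the output space and wrap in Transformer2DModelOutput.
    return ctx.postprocess(h)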
@@ -576,6 +715,7 @@ def postprocess(h):
     "QwenImageTransformer2DModel": extract_qwen_context,
     "Bagel": extract_bagel_context,
     "ZImageTransformer2DModel": extract_zimage_context,
+    "FluxTransformer2DModel": extract_flux_context,
     # Future models:
     # "FluxTransformer2DModel": extract_flux_context,
     # "CogVideoXTransformer3DModel": extract_cogvideox_context,
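Also for orientation, a hedged sketch of how the registry above might be consumed: look up the extractor by the transformer's class name and call it with the forward arguments. Both the registry name _CONTEXT_EXTRACTORS and the build_cache_context helper are assumed names for illustration only; the dict's definition line sits outside this hunk.

def build_cache_context(transformer, *args, **kwargs):
    # Illustration only, not part of this diff: dispatch on the class name,
    # following the key convention documented in the registry comments above.
    extractor = _CONTEXT_EXTRACTORS.get(type(transformer).__name__)
    if extractor is None:
        raise NotImplementedError(
            f"No cache-context extractor registered for {type(transformer).__name__}"
        )
    return extractor(transformer, *args, **kwargs)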