@@ -566,6 +566,149 @@ def postprocess(h):
     )
 
 
+def extract_flux2_klein_context(
+    module,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    joint_attention_kwargs: dict[str, Any] | None = None,
+    **kwargs: Any,
+) -> CacheContext:
+    """
+    Extract cache context for the Flux2Klein model.
+
+    Only the output of transformer_blocks is cached; single_transformer_blocks is always executed.
+
+    Args:
+        module: Flux2Transformer2DModel instance
+        hidden_states: Input image hidden states tensor
+        encoder_hidden_states: Input text hidden states tensor
+        timestep: Current diffusion timestep
+        img_ids: Image position IDs for RoPE
+        txt_ids: Text position IDs for RoPE
+        guidance: Optional guidance scale for CFG
+        joint_attention_kwargs: Additional attention kwargs
+
+    Returns:
+        CacheContext with all information needed for generic caching
+    """
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+        raise ValueError("Module must have transformer_blocks")
+
+    dtype = hidden_states.dtype
+
+    num_txt_tokens = encoder_hidden_states.shape[1]
+
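+    # Scale the timestep (and optional guidance) before building the combined time/guidance embedding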
+    timestep = timestep.to(dtype=dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(dtype=dtype) * 1000
+
+    temb = module.time_guidance_embed(timestep, guidance)
+
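+    # Precompute modulation parameters for the double-stream (image/text) and single-stream blocks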
+    double_stream_mod_img = module.double_stream_modulation_img(temb)
+    double_stream_mod_txt = module.double_stream_modulation_txt(temb)
+    single_stream_mod = module.single_stream_modulation(temb)[0]
+
+    hidden_states = module.x_embedder(hidden_states)
+    encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
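+    # Rotary position embeddings for the joint sequence (text tokens first, then image tokens)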
+    if img_ids.ndim == 3:
+        img_ids = img_ids[0]
+    if txt_ids.ndim == 3:
+        txt_ids = txt_ids[0]
+
+    image_rotary_emb = module.pos_embed(img_ids)
+    text_rotary_emb = module.pos_embed(txt_ids)
+    concat_rotary_emb = (
+        torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+        torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+    )
+
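+    # Reproduce the first double-stream block's modulated input; it is exposed as
+    # CacheContext.modulated_input for the generic caching logic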
+    block = module.transformer_blocks[0]
+
+    norm_hidden_states = block.norm1(hidden_states)
+    norm_hidden_states = (1 + double_stream_mod_img[0][0]) * norm_hidden_states + double_stream_mod_img[0][1]
+
+    modulated_input = norm_hidden_states
+
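+    # Runner for the double-stream blocks only; their output is what gets cached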
+    def run_flux2_transformer_blocks():
+        h = hidden_states
+        c = encoder_hidden_states
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return (h, c)
+
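+    # Runner for the single-stream blocks; always executed (even on cache hits) and
+    # returns only the image tokens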
+    def run_flux2_single_transformer_blocks(c, h):
+        h_concat = torch.cat([c, h], dim=1)
+        for block in module.single_transformer_blocks:
+            h_concat = block(
+                hidden_states=h_concat,
+                encoder_hidden_states=None,
+                temb_mod_params=single_stream_mod,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        return h_concat[:, num_txt_tokens:, ...]
+
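+    # Full forward pass through both double-stream and single-stream blocks, starting
+    # from the original embedded states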
+    def run_flux2_full_transformer_with_single(ori_h, ori_c):
+        h = ori_h
+        c = ori_c
+        for block in module.transformer_blocks:
+            c, h = block(
+                hidden_states=h,
+                encoder_hidden_states=c,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        h_concat = torch.cat([c, h], dim=1)
+        for block in module.single_transformer_blocks:
+            h_concat = block(
+                hidden_states=h_concat,
+                encoder_hidden_states=None,
+                temb_mod_params=single_stream_mod,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+        final_hidden_states = h_concat[:, num_txt_tokens:, ...]
+        return final_hidden_states, c
+
+    return_dict = kwargs.get("return_dict", True)
+
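+    # Final output norm and projection, matching the model's forward() return convention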
+    def postprocess(h):
+        h = module.norm_out(h, temb)
+        h = module.proj_out(h)
+        if not return_dict:
+            return (h,)
+        return Transformer2DModelOutput(sample=h)
+
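+    # Package everything for generic caching; the single-block and full runners are
+    # passed through extra_states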
+    return CacheContext(
+        modulated_input=modulated_input,
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        run_transformer_blocks=run_flux2_transformer_blocks,
+        postprocess=postprocess,
+        extra_states={
+            "run_flux2_single_transformer_blocks": run_flux2_single_transformer_blocks,
+            "run_flux2_full_transformer_with_single": run_flux2_full_transformer_with_single,
+        },
+    )
+
+
 # Registry for model-specific extractors
 # Key: Transformer class name
 # Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -576,6 +719,7 @@ def postprocess(h):
576719 "QwenImageTransformer2DModel" : extract_qwen_context ,
577720 "Bagel" : extract_bagel_context ,
578721 "ZImageTransformer2DModel" : extract_zimage_context ,
722+ "Flux2Klein" : extract_flux2_klein_context ,
579723 # Future models:
580724 # "FluxTransformer2DModel": extract_flux_context,
581725 # "CogVideoXTransformer3DModel": extract_cogvideox_context,