@@ -567,20 +567,21 @@ def postprocess(h):
 
 
 def extract_flux2_klein_context(
-    module,
+    module: nn.Module,
     hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor = None,
+    encoder_hidden_states: torch.Tensor | None = None,
     timestep: torch.LongTensor = None,
     img_ids: torch.Tensor = None,
     txt_ids: torch.Tensor = None,
-    guidance: torch.Tensor = None,
+    guidance: torch.Tensor | None = None,
     joint_attention_kwargs: dict[str, Any] | None = None,
     **kwargs: Any,
 ) -> CacheContext:
     """
     Extract cache context for Flux2Klein model.
 
-    Only caches transformer_blocks output. single_transformer_blocks is always executed.
+    Caches the full transformer output (including single_transformer_blocks).
+    When the cache is reused, single_transformer_blocks is skipped to achieve maximum speedup.
 
     Args:
         module: Flux2Transformer2DModel instance
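
To make the docstring change concrete: under the old behavior a cache hit still paid for `single_transformer_blocks` on every step; under the new behavior a hit reuses the cached full output and skips both block stacks. A minimal sketch of the two paths; `cache_hit`, `run_double`, `run_single`, and `cached_full` are hypothetical names for illustration, not the library's API:

```python
def step(cache_hit: bool, run_double, run_single, cached_full):
    # New behavior: a hit skips BOTH transformer_blocks and
    # single_transformer_blocks by reusing the cached full output.
    if cache_hit:
        return cached_full
    # Miss: run the double-stream blocks, then the single-stream blocks.
    h, c = run_double()
    return run_single(c, h)
```
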
@@ -600,6 +601,9 @@ def extract_flux2_klein_context(
     if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
         raise ValueError("Module must have transformer_blocks")
 
+    # ============================================================================
+    # PREPROCESSING (Flux2-specific)
+    # ============================================================================
     dtype = hidden_states.dtype
 
     num_txt_tokens = encoder_hidden_states.shape[1]
@@ -629,13 +633,19 @@ def extract_flux2_klein_context(
         torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
     )
 
+    # ============================================================================
+    # EXTRACT MODULATED INPUT (for cache decision)
+    # ============================================================================
     block = module.transformer_blocks[0]
 
     norm_hidden_states = block.norm1(hidden_states)
-    norm_hidden_states = (1 + double_stream_mod_img[0][0]) * norm_hidden_states + double_stream_mod_img[0][1]
+    norm_hidden_states = (1 + double_stream_mod_img[0][1]) * norm_hidden_states + double_stream_mod_img[0][0]
 
     modulated_input = norm_hidden_states
 
+    # ============================================================================
+    # DEFINE TRANSFORMER EXECUTION (Flux2-specific)
+    # ============================================================================
     def run_flux2_transformer_blocks():
         h = hidden_states
         c = encoder_hidden_states
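
A note on the one-line fix above: the corrected expression implies the modulation parameters are ordered `(shift, scale)`, i.e. index `[0]` is the additive shift and `[1]` is the multiplicative scale, so the old code was multiplying by the shift and adding the scale. A minimal standalone sketch of the AdaLN-style modulation (shapes are hypothetical, for illustration only):

```python
import torch

batch, seq, dim = 1, 16, 64
x = torch.randn(batch, seq, dim)     # normalized hidden states
shift = torch.randn(batch, 1, dim)   # mod[0]: additive shift
scale = torch.randn(batch, 1, dim)   # mod[1]: multiplicative scale
mod = (shift, scale)                 # assumed (shift, scale) ordering

x_fixed = (1 + mod[1]) * x + mod[0]  # corrected: scale, then shift
x_buggy = (1 + mod[0]) * x + mod[1]  # old bug: shift used as scale
```
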
@@ -650,18 +660,6 @@ def run_flux2_transformer_blocks():
         )
         return (h, c)
 
-    def run_flux2_single_transformer_blocks(c, h):
-        h_concat = torch.cat([c, h], dim=1)
-        for block in module.single_transformer_blocks:
-            h_concat = block(
-                hidden_states=h_concat,
-                encoder_hidden_states=None,
-                temb_mod_params=single_stream_mod,
-                image_rotary_emb=concat_rotary_emb,
-                joint_attention_kwargs=joint_attention_kwargs,
-            )
-        return h_concat[:, num_txt_tokens:, ...]
-
     def run_flux2_full_transformer_with_single(ori_h, ori_c):
         h = ori_h
         c = ori_c
@@ -686,6 +684,9 @@ def run_flux2_full_transformer_with_single(ori_h, ori_c):
         final_hidden_states = h_concat[:, num_txt_tokens:, ...]
         return final_hidden_states, c
 
+    # ============================================================================
+    # DEFINE POSTPROCESSING (Flux2-specific)
+    # ============================================================================
     return_dict = kwargs.get("return_dict", True)
 
     def postprocess(h):
@@ -703,7 +704,6 @@ def postprocess(h):
         run_transformer_blocks=run_flux2_transformer_blocks,
         postprocess=postprocess,
         extra_states={
-            "run_flux2_single_transformer_blocks": run_flux2_single_transformer_blocks,
             "run_flux2_full_transformer_with_single": run_flux2_full_transformer_with_single,
         },
     )
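
Since `run_flux2_single_transformer_blocks` is no longer stored in `extra_states`, any consumer that previously chained the two helpers has to call the combined one. A hedged sketch of that consumer side, assuming `ctx` is the `CacheContext` returned above (the wrapper name and attribute access are hypothetical):

```python
from typing import Any

def apply_full_transformer(ctx: Any, hidden_states, encoder_hidden_states):
    """Hypothetical consumer: run both block stacks via the combined helper."""
    run_full = ctx.extra_states["run_flux2_full_transformer_with_single"]
    # The helper returns (final_hidden_states, encoder_hidden_states).
    final_hidden_states, _ = run_full(hidden_states, encoder_hidden_states)
    return ctx.postprocess(final_hidden_states)
```
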