Commit 1cb433d

upd
1 parent 979fe87 commit 1cb433d

1 file changed, 18 insertions(+), 18 deletions(-)

vllm_omni/diffusion/cache/teacache/extractors.py

```diff
@@ -567,20 +567,21 @@ def postprocess(h):
 
 
 def extract_flux2_klein_context(
-    module,
+    module: nn.Module,
     hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor = None,
+    encoder_hidden_states: torch.Tensor | None = None,
     timestep: torch.LongTensor = None,
     img_ids: torch.Tensor = None,
     txt_ids: torch.Tensor = None,
-    guidance: torch.Tensor = None,
+    guidance: torch.Tensor | None = None,
     joint_attention_kwargs: dict[str, Any] | None = None,
     **kwargs: Any,
 ) -> CacheContext:
     """
     Extract cache context for Flux2Klein model.
 
-    Only caches transformer_blocks output. single_transformer_blocks is always executed.
+    Caches the full transformer output (including single_transformer_blocks).
+    When the cache is reused, single_transformer_blocks is skipped to achieve maximum speedup.
 
     Args:
         module: Flux2Transformer2DModel instance
```
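The docstring change above records the new caching contract: the cached value now covers the full transformer, double-stream and single-stream stacks alike, so a cache hit can skip both. Below is a minimal sketch of a TeaCache-style reuse decision driven by the `modulated_input` this extractor produces; the helper name, threshold, and state layout are illustrative assumptions, not vllm_omni's actual runtime API.

```python
import torch


def maybe_reuse_cache(
    modulated_input: torch.Tensor,
    prev_modulated_input: torch.Tensor | None,
    cached_output: torch.Tensor | None,
    rel_l1_thresh: float = 0.25,  # assumed value, purely illustrative
) -> torch.Tensor | None:
    """Return the cached full-transformer output if the modulated input
    barely moved since the last computed step, else None.

    Hypothetical helper: the real decision lives in the TeaCache
    runtime, not in this extractor.
    """
    if prev_modulated_input is None or cached_output is None:
        return None  # first step: nothing to compare against
    rel_change = (
        (modulated_input - prev_modulated_input).abs().mean()
        / prev_modulated_input.abs().mean()
    )
    # A small relative change implies the block outputs are close to the
    # previous step's, so the cached result (which already includes
    # single_transformer_blocks) can stand in for a full forward pass.
    return cached_output if rel_change.item() < rel_l1_thresh else None
```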
```diff
@@ -600,6 +601,9 @@ def extract_flux2_klein_context(
     if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
         raise ValueError("Module must have transformer_blocks")
 
+    # ============================================================================
+    # PREPROCESSING (Flux2-specific)
+    # ============================================================================
     dtype = hidden_states.dtype
 
     num_txt_tokens = encoder_hidden_states.shape[1]
@@ -629,13 +633,19 @@ def extract_flux2_klein_context(
         torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
     )
 
+    # ============================================================================
+    # EXTRACT MODULATED INPUT (for cache decision)
+    # ============================================================================
     block = module.transformer_blocks[0]
 
     norm_hidden_states = block.norm1(hidden_states)
-    norm_hidden_states = (1 + double_stream_mod_img[0][0]) * norm_hidden_states + double_stream_mod_img[0][1]
+    norm_hidden_states = (1 + double_stream_mod_img[0][1]) * norm_hidden_states + double_stream_mod_img[0][0]
 
     modulated_input = norm_hidden_states
 
+    # ============================================================================
+    # DEFINE TRANSFORMER EXECUTION (Flux2-specific)
+    # ============================================================================
     def run_flux2_transformer_blocks():
         h = hidden_states
         c = encoder_hidden_states
```
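The one-line change in this hunk fixes an AdaLN-style modulation: the convention is `(1 + scale) * norm(x) + shift`, and the pre-fix code applied the shift as the scale and vice versa. Since `modulated_input` is what the cache compares across timesteps, the swap skewed the reuse signal. A toy illustration follows, assuming `double_stream_mod_img[0]` packs `(shift, scale)` in that order (inferred from the fix itself) and a hidden size of 3072 picked only for the example.

```python
import torch

# Assumed shapes for illustration: (batch, 1, hidden) modulation params
# broadcast over (batch, tokens, hidden) activations.
shift = torch.randn(1, 1, 3072)
scale = torch.randn(1, 1, 3072)
norm_hidden_states = torch.randn(1, 4096, 3072)

# Correct AdaLN modulation: the scale multiplies the normalized
# activations, the shift is added afterwards.
fixed = (1 + scale) * norm_hidden_states + shift

# The pre-fix version exchanged the two roles, distorting the signal
# TeaCache uses to decide whether a step's output can be reused.
buggy = (1 + shift) * norm_hidden_states + scale
assert not torch.allclose(fixed, buggy)
```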
```diff
@@ -650,18 +660,6 @@ def run_flux2_transformer_blocks():
             )
         return (h, c)
 
-    def run_flux2_single_transformer_blocks(c, h):
-        h_concat = torch.cat([c, h], dim=1)
-        for block in module.single_transformer_blocks:
-            h_concat = block(
-                hidden_states=h_concat,
-                encoder_hidden_states=None,
-                temb_mod_params=single_stream_mod,
-                image_rotary_emb=concat_rotary_emb,
-                joint_attention_kwargs=joint_attention_kwargs,
-            )
-        return h_concat[:, num_txt_tokens:, ...]
-
     def run_flux2_full_transformer_with_single(ori_h, ori_c):
         h = ori_h
         c = ori_c
@@ -686,6 +684,9 @@ def run_flux2_full_transformer_with_single(ori_h, ori_c):
         final_hidden_states = h_concat[:, num_txt_tokens:, ...]
         return final_hidden_states, c
 
+    # ============================================================================
+    # DEFINE POSTPROCESSING (Flux2-specific)
+    # ============================================================================
     return_dict = kwargs.get("return_dict", True)
 
     def postprocess(h):
@@ -703,7 +704,6 @@ def postprocess(h):
         run_transformer_blocks=run_flux2_transformer_blocks,
         postprocess=postprocess,
         extra_states={
-            "run_flux2_single_transformer_blocks": run_flux2_single_transformer_blocks,
             "run_flux2_full_transformer_with_single": run_flux2_full_transformer_with_single,
         },
     )
```
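With the standalone single-blocks runner deleted, `extra_states` carries only the combined runner, keeping compute and cache in lockstep: whatever gets cached is exactly what a hit returns. A sketch of the consuming side under that contract; the `ctx.cache` field and `run_cached_step` hook are hypothetical names, not vllm_omni's confirmed `CacheContext` interface.

```python
def run_cached_step(ctx, hidden_states, encoder_hidden_states, reuse: bool):
    # Hypothetical consumer of the extracted CacheContext.
    if reuse:
        # Cache hit: neither transformer_blocks nor
        # single_transformer_blocks runs.
        return ctx.postprocess(ctx.cache["hidden_states"])
    # Cache miss: run double- and single-stream stacks together so the
    # stored value matches what a later hit will return.
    full_run = ctx.extra_states["run_flux2_full_transformer_with_single"]
    new_h, _new_c = full_run(hidden_states, encoder_hidden_states)
    ctx.cache["hidden_states"] = new_h
    return ctx.postprocess(new_h)
```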
