Skip to content

Commit ba28383

Browse files
author
宣源
committed
fix multi gpu
1 parent ecc7597 commit ba28383

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

models/qwenimage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ def forward(
572572
encoder_hidden_states_mask: torch.Tensor,
573573
temb: torch.Tensor,
574574
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
575+
transformer_options: dict = {},
575576
) -> Tuple[torch.Tensor, torch.Tensor]:
576577
"""
577578
Forward pass for the transformer block.
@@ -625,6 +626,7 @@ def forward(
625626
encoder_hidden_states=txt_modulated, # Text stream ("context")
626627
encoder_hidden_states_mask=encoder_hidden_states_mask,
627628
image_rotary_emb=image_rotary_emb,
629+
transformer_options=transformer_options,
628630
)
629631

630632
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided

0 commit comments

Comments
 (0)