
Commit 837f796

fix bug1

Signed-off-by: David Chen <530634352@qq.com>

1 parent: f03a26a

2 files changed (+10, -3 lines)


vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py

Lines changed: 9 additions & 2 deletions

@@ -8,6 +8,7 @@
 import json
 import os
 from collections.abc import Iterable
+from contextlib import nullcontext
 from typing import Any

 import numpy as np
@@ -615,8 +616,14 @@ def interrupt(self):
     def _is_cfg_parallel_enabled(self, do_true_cfg: bool) -> bool:
         return do_true_cfg and get_classifier_free_guidance_world_size() > 1

+    def _transformer_cache_context(self, context_name: str):
+        cache_context = getattr(self.transformer, "cache_context", None)
+        if callable(cache_context):
+            return cache_context(context_name)
+        return nullcontext()
+
     def _predict_noise_av(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
-        with self.transformer.cache_context("cond_uncond"):
+        with self._transformer_cache_context("cond_uncond"):
             noise_pred_video, noise_pred_audio = self.transformer(**kwargs)
         return noise_pred_video, noise_pred_audio

@@ -1069,7 +1076,7 @@ def forward(

         timestep = t.expand(latent_model_input.shape[0])

-        with self.transformer.cache_context("cond_uncond"):
+        with self._transformer_cache_context("cond_uncond"):
             noise_pred_video, noise_pred_audio = self.transformer(
                 hidden_states=latent_model_input,
                 audio_hidden_states=audio_latent_model_input,
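
The change in pipeline_ltx2.py carries the actual fix: the pipelines used to call self.transformer.cache_context("cond_uncond") unconditionally, which fails with an AttributeError when the loaded transformer does not expose a cache_context manager (presumably the "bug1" referred to in the commit message). The new _transformer_cache_context helper probes for the attribute and falls back to contextlib.nullcontext(), a no-op context manager, so every call site stays a plain with statement. Below is a minimal, self-contained sketch of the same pattern; CachingTransformer and PlainTransformer are hypothetical stand-ins for illustration only, not vllm_omni classes.

# Minimal sketch of the fallback pattern added in this commit.
# CachingTransformer and PlainTransformer are hypothetical stand-ins.
from contextlib import contextmanager, nullcontext


class CachingTransformer:
    """Stand-in for a transformer that supports named cache contexts."""

    @contextmanager
    def cache_context(self, name: str):
        # A real implementation would scope its feature cache to `name`.
        print(f"entering cache context {name!r}")
        try:
            yield
        finally:
            print(f"leaving cache context {name!r}")


class PlainTransformer:
    """Stand-in for a transformer with no cache_context attribute."""


def transformer_cache_context(transformer, context_name: str):
    # Mirrors _transformer_cache_context from the diff: use the model's own
    # cache_context if present and callable, otherwise a no-op context.
    cache_context = getattr(transformer, "cache_context", None)
    if callable(cache_context):
        return cache_context(context_name)
    return nullcontext()


if __name__ == "__main__":
    for transformer in (CachingTransformer(), PlainTransformer()):
        with transformer_cache_context(transformer, "cond_uncond"):
            pass  # the denoising forward pass would run here

Returning a context manager in both branches keeps the call sites uniform, which is why the second file below only needs its with statement swapped.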

vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py

Lines changed: 1 addition & 1 deletion

@@ -603,7 +603,7 @@ def forward(
         timestep = t.expand(latent_model_input.shape[0])
         video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

-        with self.transformer.cache_context("cond_uncond"):
+        with self._transformer_cache_context("cond_uncond"):
             noise_pred_video, noise_pred_audio = self.transformer(
                 hidden_states=latent_model_input,
                 audio_hidden_states=audio_latent_model_input,
