2323
2424from ...configuration_utils import ConfigMixin , register_to_config
2525from ...loaders import FromOriginalModelMixin , PeftAdapterMixin
26- from ...utils import apply_lora_scale , deprecate , logging
26+ from ...utils import apply_lora_scale , logging
2727from ...utils .torch_utils import lru_cache_unless_export , maybe_allow_in_graph
2828from .._modeling_parallel import ContextParallelInput , ContextParallelOutput
2929from ..attention import AttentionMixin , FeedForward
@@ -241,38 +241,21 @@ def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.T
241241 def forward (
242242 self ,
243243 video_fhw : tuple [int , int , int , list [tuple [int , int , int ]]],
244- txt_seq_lens : list [int ] | None = None ,
245244 device : torch .device = None ,
246245 max_txt_seq_len : int | torch .Tensor | None = None ,
247246 ) -> tuple [torch .Tensor , torch .Tensor ]:
248247 """
249248 Args:
250249 video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`):
251250 A list of 3 integers [frame, height, width] representing the shape of the video.
252- txt_seq_lens (`list[int]`, *optional*, **Deprecated**):
253- Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
254251 device: (`torch.device`, *optional*):
255252 The device on which to perform the RoPE computation.
256253 max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
257254 The maximum text sequence length for RoPE computation. This should match the encoder hidden states
258255 sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
259256 """
260- # Handle deprecated txt_seq_lens parameter
261- if txt_seq_lens is not None :
262- deprecate (
263- "txt_seq_lens" ,
264- "0.39.0" ,
265- "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
266- "Please use `max_txt_seq_len` instead. "
267- "The new parameter accepts a single int or tensor value representing the maximum text sequence length." ,
268- standard_warn = False ,
269- )
270- if max_txt_seq_len is None :
271- # Use max of txt_seq_lens for backward compatibility
272- max_txt_seq_len = max (txt_seq_lens ) if isinstance (txt_seq_lens , list ) else txt_seq_lens
273-
274257 if max_txt_seq_len is None :
275- raise ValueError ("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided." )
258+ raise ValueError ("`max_txt_seq_len` must be provided." )
276259
277260 # Validate batch inference with variable-sized images
278261 if isinstance (video_fhw , list ) and len (video_fhw ) > 1 :
@@ -855,7 +838,6 @@ def forward(
855838 encoder_hidden_states_mask : torch .Tensor = None ,
856839 timestep : torch .LongTensor = None ,
857840 img_shapes : list [tuple [int , int , int ]] | None = None ,
858- txt_seq_lens : list [int ] | None = None ,
859841 guidance : torch .Tensor = None , # TODO: this should probably be removed
860842 attention_kwargs : dict [str , Any ] | None = None ,
861843 controlnet_block_samples = None ,
@@ -878,9 +860,6 @@ def forward(
878860 Used to indicate denoising step.
879861 img_shapes (`list[tuple[int, int, int]]`, *optional*):
880862 Image shapes for RoPE computation.
881- txt_seq_lens (`list[int]`, *optional*, **Deprecated**):
882- Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
883- used to compute RoPE sequence length.
884863 guidance (`torch.Tensor`, *optional*):
885864 Guidance tensor for conditional generation.
886865 attention_kwargs (`dict`, *optional*):
@@ -897,16 +876,6 @@ def forward(
897876 If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
898877 `tuple` where the first element is the sample tensor.
899878 """
900- if txt_seq_lens is not None :
901- deprecate (
902- "txt_seq_lens" ,
903- "0.39.0" ,
904- "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
905- "Please use `encoder_hidden_states_mask` instead. "
906- "The mask-based approach is more flexible and supports variable-length sequences." ,
907- standard_warn = False ,
908- )
909-
910879 hidden_states = self .img_in (hidden_states )
911880
912881 timestep = timestep .to (hidden_states .dtype )
0 commit comments