Description
I'm encountering a TypeError during LTX-Video inference: FP8Linear.forward() receives 5 positional arguments but accepts only 2-4. This appears to be related to the q8_kernels integration with the diffusers pipeline.
Error Details
```
TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given
```
Full Stack Trace
```
/workspace/LTX-Video/env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Padded dimensions: 960x1280x81
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 55.41it/s]
/workspace/LTX-Video/env/lib/python3.11/site-packages/torch/functional.py:554: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:4314.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
0%| | 0/7 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 8
5 WIDTH = 1280
6 NUM_FRAMES = 81
----> 8 infer(
9 InferenceConfig(
10 pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
11 prompt=PROMPT,
12 height=HEIGHT,
13 width=WIDTH,
14 num_frames=NUM_FRAMES,
15 output_path="/workspace/output.mp4",
16 )
17 )
File /workspace/LTX-Video/ltx_video/inference.py:569, in infer(config)
560 sample = {
561 "prompt": config.prompt,
562 "prompt_attention_mask": None,
563 "negative_prompt": config.negative_prompt,
564 "negative_prompt_attention_mask": None,
565 }
567 generator = torch.Generator(device=device).manual_seed(config.seed)
--> 569 images = pipeline(
570 **pipeline_config,
571 skip_layer_strategy=skip_layer_strategy,
572 generator=generator,
573 output_type="pt",
574 callback_on_step_end=None,
575 height=height_padded,
576 width=width_padded,
577 num_frames=num_frames_padded,
578 frame_rate=config.frame_rate,
579 **sample,
580 media_items=media_item,
581 conditioning_items=conditioning_items,
582 is_video=True,
583 vae_per_channel_normalize=True,
584 image_cond_noise_scale=config.image_cond_noise_scale,
585 mixed_precision=(precision == "mixed_precision"),
586 offload_to_cpu=offload_to_cpu,
587 device=device,
588 enhance_prompt=enhance_prompt,
589 ).images
File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1865, in LTXMultiScalePipeline.__call__(self, downscale_factor, first_pass, second_pass, *args, **kwargs)
1863 kwargs["height"] = downscaled_height
1864 kwargs.update(**first_pass)
-> 1865 result = self.video_pipeline(*args, **kwargs)
1866 latents = result.images
1868 upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1206, in LTXVideoPipeline.__call__(self, height, width, num_frames, frame_rate, prompt, negative_prompt, num_inference_steps, skip_initial_inference_steps, skip_final_inference_steps, timesteps, guidance_scale, cfg_star_rescale, skip_layer_strategy, skip_block_list, stg_scale, rescaling_scale, guidance_timesteps, num_images_per_prompt, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, callback_on_step_end, conditioning_items, decode_timestep, decode_noise_scale, mixed_precision, offload_to_cpu, enhance_prompt, text_encoder_max_tokens, stochastic_sampling, media_items, tone_map_compression_ratio, **kwargs)
1204 # predict noise model_output
1205 with context_manager:
-> 1206 noise_pred = self.transformer(
1207 latent_model_input.to(self.transformer.dtype),
1208 indices_grid=fractional_coords,
1209 encoder_hidden_states=prompt_embeds_batch[indices].to(
1210 self.transformer.dtype
1211 ),
1212 encoder_attention_mask=prompt_attention_mask_batch[indices],
1213 timestep=current_timestep,
1214 skip_layer_mask=skip_layer_mask,
1215 skip_layer_strategy=skip_layer_strategy,
1216 return_dict=False,
1217 )[0]
1219 # perform guidance
1220 if do_spatio_temporal_guidance:
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
File /workspace/LTX-Video/ltx_video/models/transformers/transformer3d.py:478, in Transformer3DModel.forward(self, hidden_states, indices_grid, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, skip_layer_mask, skip_layer_strategy, return_dict)
459 hidden_states = torch.utils.checkpoint.checkpoint(
460 create_custom_forward(block),
461 hidden_states,
(...) 475 **ckpt_kwargs,
476 )
477 else:
--> 478 hidden_states = block(
479 hidden_states,
480 freqs_cis=freqs_cis,
481 attention_mask=attention_mask,
482 encoder_hidden_states=encoder_hidden_states,
483 encoder_attention_mask=encoder_attention_mask,
484 timestep=timestep,
485 cross_attention_kwargs=cross_attention_kwargs,
486 class_labels=class_labels,
487 skip_layer_mask=(
488 skip_layer_mask[block_idx]
489 if skip_layer_mask is not None
490 else None
491 ),
492 skip_layer_strategy=skip_layer_strategy,
493 )
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:380, in create_forwards.<locals>.fused_forward(self, hidden_states, freqs_cis, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, sharding_mesh, skip_layer_mask, skip_layer_strategy)
375 # 1. Prepare GLIGEN inputs
376 cross_attention_kwargs = (
377 cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
378 )
--> 380 attn_output = self.attn1(
381 norm_hidden_states,
382 norm_hidden_states_scales,
383 freqs_cis=freqs_cis,
384 encoder_hidden_states=(
385 encoder_hidden_states if self.only_cross_attention else None
386 ),
387 attention_mask=attention_mask,
388 skip_layer_mask=skip_layer_mask,
389 skip_layer_strategy=skip_layer_strategy,
390 **cross_attention_kwargs,
391 )
392 if gate_msa is not None:
393 attn_output = gate_msa * attn_output
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:67, in attn_forward(self, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, skip_layer_mask, skip_layer_strategy, **cross_attention_kwargs)
59 logger.warning(
60 f"cross_attention_kwargs {unused_kwargs} are not expected by"
61 f" {self.processor.__class__.__name__} and will be ignored."
62 )
63 cross_attention_kwargs = {
64 k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
65 }
---> 67 return self.processor(
68 self,
69 hidden_states,
70 hidden_states_scales,
71 freqs_cis=freqs_cis,
72 encoder_hidden_states=encoder_hidden_states,
73 attention_mask=attention_mask,
74 skip_layer_mask=skip_layer_mask,
75 skip_layer_strategy=skip_layer_strategy,
76 **cross_attention_kwargs,
77 )
File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:158, in create_attn_processor.<locals>.AttnProcessor3_0.__call__(self, attn, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, temb, skip_layer_mask, skip_layer_strategy, *args, **kwargs)
155 else: # if no context provided do self-attention
156 is_self_attention = True
--> 158 query = attn.to_q(
159 hidden_states, hidden_states_scales, False, torch.bfloat16
160 )
161 query = rms_norm_rope(
162 query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight
163 )
165 key = attn.to_k(
166 hidden_states, hidden_states_scales, False, torch.bfloat16
167 )
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given
```
Reproduction Steps
- Set up the LTX-Video environment with the FP8 distilled model
- Configure inference with the following parameters:

```python
PROMPT = "[your_prompt_here]"
HEIGHT = 960
WIDTH = 1280
NUM_FRAMES = 81

infer(
    InferenceConfig(
        pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
        prompt=PROMPT,
        height=HEIGHT,
        width=WIDTH,
        num_frames=NUM_FRAMES,
        output_path="/workspace/output.mp4",
    )
)
```

- Run inference; the error is raised during the transformer forward pass.
Root Cause Analysis
The error occurs in the q8_kernels integration (q8_kernels/integration/diffusers.py), where attn.to_q() is called with four positional arguments:

```python
query = attn.to_q(
    hidden_states, hidden_states_scales, False, torch.bfloat16
)
```

However, FP8Linear.forward() accepts only 2 to 4 positional arguments including self, i.e. at most three after self, so this call passes one argument more than the signature allows. This suggests an interface mismatch between the fused attention processor and the FP8Linear implementation it dispatches to.
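For reference, the failure mode can be reproduced in isolation. This is a minimal sketch assuming only the argument counts from the traceback; FP8LinearLike stands in for the real FP8Linear, and its parameter names are invented:

```python
import torch

class FP8LinearLike(torch.nn.Module):
    # Mirrors the reported limit: forward() takes from 2 to 4 positional
    # arguments, i.e. self, the input, and up to two optional extras.
    # (Parameter names here are illustrative, not q8_kernels' actual API.)
    def forward(self, x, scale=None, out_dtype=None):
        return x.to(out_dtype if out_dtype is not None else torch.bfloat16)

layer = FP8LinearLike()
x = torch.randn(2, 8)
scales = torch.ones(2)

# The fused attention processor passes four positional arguments
# (five counting self), just like the attn.to_q(...) call above:
layer(x, scales, False, torch.bfloat16)
# -> TypeError: FP8LinearLike.forward() takes from 2 to 4 positional
#    arguments but 5 were given
```

Whatever the real parameter names are, FP8Linear.forward() exposes at most three parameters after self while the processor passes four, so either the installed q8_kernels build doesn't match what the integration code expects, or the FP8 path needs a processor that matches its narrower signature.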
Environment Information
- Model Config: configs/ltxv-13b-0.9.8-distilled-fp8.yaml
- Python: 3.11
- PyTorch: 2.7.1+cu126
- q8_kernels: q8_kernels @ file:///workspace/LTX-Video/LTX-Video-Q8-Kernels/dist/q8_kernels-0.0.5-cp311-cp311-linux_x86_64.whl#sha256=54e6ab58fa9acccfd60316e876e0807ced33df91929f59415d9b57a893f0acb7
- Platform: Linux workspace environment
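A quick way to confirm the mismatch against the installed wheel is to print the method's signature. A minimal sketch; the import path below is an assumption, so adjust it to wherever your q8_kernels build actually defines FP8Linear:

```python
import importlib
import inspect

# Hypothetical module path; search the installed package if this
# does not exist in your q8_kernels build.
mod = importlib.import_module("q8_kernels.modules.linear")
FP8Linear = getattr(mod, "FP8Linear")

# Should list at most three parameters after `self`, matching the
# "takes from 2 to 4 positional arguments" in the TypeError.
print(inspect.signature(FP8Linear.forward))
```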