
[BUG]FP8Linear.forward() argument mismatch in LTX-Video inference #231

@ighoshsubho

Description


I'm hitting a TypeError during LTX-Video inference: FP8Linear.forward() receives 5 positional arguments but accepts only 2 to 4. This appears to be related to the q8_kernels integration with the diffusers pipeline.

Error Details

TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given

Full Stack Trace

/workspace/LTX-Video/env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Padded dimensions: 960x1280x81
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 55.41it/s]
/workspace/LTX-Video/env/lib/python3.11/site-packages/torch/functional.py:554: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:4314.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  0%|          | 0/7 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 8
      5 WIDTH = 1280
      6 NUM_FRAMES = 81
----> 8 infer(
      9     InferenceConfig(
     10         pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
     11         prompt=PROMPT,
     12         height=HEIGHT,
     13         width=WIDTH,
     14         num_frames=NUM_FRAMES,
     15         output_path="/workspace/output.mp4",
     16     )
     17 )

File /workspace/LTX-Video/ltx_video/inference.py:569, in infer(config)
    560 sample = {
    561     "prompt": config.prompt,
    562     "prompt_attention_mask": None,
    563     "negative_prompt": config.negative_prompt,
    564     "negative_prompt_attention_mask": None,
    565 }
    567 generator = torch.Generator(device=device).manual_seed(config.seed)
--> 569 images = pipeline(
    570     **pipeline_config,
    571     skip_layer_strategy=skip_layer_strategy,
    572     generator=generator,
    573     output_type="pt",
    574     callback_on_step_end=None,
    575     height=height_padded,
    576     width=width_padded,
    577     num_frames=num_frames_padded,
    578     frame_rate=config.frame_rate,
    579     **sample,
    580     media_items=media_item,
    581     conditioning_items=conditioning_items,
    582     is_video=True,
    583     vae_per_channel_normalize=True,
    584     image_cond_noise_scale=config.image_cond_noise_scale,
    585     mixed_precision=(precision == "mixed_precision"),
    586     offload_to_cpu=offload_to_cpu,
    587     device=device,
    588     enhance_prompt=enhance_prompt,
    589 ).images

File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1865, in LTXMultiScalePipeline.__call__(self, downscale_factor, first_pass, second_pass, *args, **kwargs)
   1863 kwargs["height"] = downscaled_height
   1864 kwargs.update(**first_pass)
-> 1865 result = self.video_pipeline(*args, **kwargs)
   1866 latents = result.images
   1868 upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1206, in LTXVideoPipeline.__call__(self, height, width, num_frames, frame_rate, prompt, negative_prompt, num_inference_steps, skip_initial_inference_steps, skip_final_inference_steps, timesteps, guidance_scale, cfg_star_rescale, skip_layer_strategy, skip_block_list, stg_scale, rescaling_scale, guidance_timesteps, num_images_per_prompt, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, callback_on_step_end, conditioning_items, decode_timestep, decode_noise_scale, mixed_precision, offload_to_cpu, enhance_prompt, text_encoder_max_tokens, stochastic_sampling, media_items, tone_map_compression_ratio, **kwargs)
   1204 # predict noise model_output
   1205 with context_manager:
-> 1206     noise_pred = self.transformer(
   1207         latent_model_input.to(self.transformer.dtype),
   1208         indices_grid=fractional_coords,
   1209         encoder_hidden_states=prompt_embeds_batch[indices].to(
   1210             self.transformer.dtype
   1211         ),
   1212         encoder_attention_mask=prompt_attention_mask_batch[indices],
   1213         timestep=current_timestep,
   1214         skip_layer_mask=skip_layer_mask,
   1215         skip_layer_strategy=skip_layer_strategy,
   1216         return_dict=False,
   1217     )[0]
   1219 # perform guidance
   1220 if do_spatio_temporal_guidance:

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

File /workspace/LTX-Video/ltx_video/models/transformers/transformer3d.py:478, in Transformer3DModel.forward(self, hidden_states, indices_grid, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, skip_layer_mask, skip_layer_strategy, return_dict)
    459         hidden_states = torch.utils.checkpoint.checkpoint(
    460             create_custom_forward(block),
    461             hidden_states,
   (...)    475             **ckpt_kwargs,
    476         )
    477     else:
--> 478         hidden_states = block(
    479             hidden_states,
    480             freqs_cis=freqs_cis,
    481             attention_mask=attention_mask,
    482             encoder_hidden_states=encoder_hidden_states,
    483             encoder_attention_mask=encoder_attention_mask,
    484             timestep=timestep,
    485             cross_attention_kwargs=cross_attention_kwargs,
    486             class_labels=class_labels,
    487             skip_layer_mask=(
    488                 skip_layer_mask[block_idx]
    489                 if skip_layer_mask is not None
    490                 else None
    491             ),
    492             skip_layer_strategy=skip_layer_strategy,
    493         )

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:380, in create_forwards.<locals>.fused_forward(self, hidden_states, freqs_cis, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, sharding_mesh, skip_layer_mask, skip_layer_strategy)
    375 # 1. Prepare GLIGEN inputs
    376 cross_attention_kwargs = (
    377     cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    378 )
--> 380 attn_output = self.attn1(
    381     norm_hidden_states,
    382     norm_hidden_states_scales,
    383     freqs_cis=freqs_cis,
    384     encoder_hidden_states=(
    385         encoder_hidden_states if self.only_cross_attention else None
    386     ),
    387     attention_mask=attention_mask,
    388     skip_layer_mask=skip_layer_mask,
    389     skip_layer_strategy=skip_layer_strategy,
    390     **cross_attention_kwargs,
    391 )
    392 if gate_msa is not None:
    393     attn_output = gate_msa * attn_output

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:67, in attn_forward(self, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, skip_layer_mask, skip_layer_strategy, **cross_attention_kwargs)
     59     logger.warning(
     60         f"cross_attention_kwargs {unused_kwargs} are not expected by"
     61         f" {self.processor.__class__.__name__} and will be ignored."
     62     )
     63 cross_attention_kwargs = {
     64     k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
     65 }
---> 67 return self.processor(
     68     self,
     69     hidden_states,
     70     hidden_states_scales,
     71     freqs_cis=freqs_cis,
     72     encoder_hidden_states=encoder_hidden_states,
     73     attention_mask=attention_mask,
     74     skip_layer_mask=skip_layer_mask,
     75     skip_layer_strategy=skip_layer_strategy,
     76     **cross_attention_kwargs,
     77 )

File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:158, in create_attn_processor.<locals>.AttnProcessor3_0.__call__(self, attn, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, temb, skip_layer_mask, skip_layer_strategy, *args, **kwargs)
    155 else:  # if no context provided do self-attention
    156     is_self_attention = True
--> 158     query = attn.to_q(
    159         hidden_states, hidden_states_scales, False, torch.bfloat16
    160     )
    161     query = rms_norm_rope(
    162         query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight
    163     )
    165     key = attn.to_k(
    166         hidden_states, hidden_states_scales, False, torch.bfloat16
    167     )

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given

Reproduction Steps

  1. Set up the LTX-Video environment with the FP8 distilled model
  2. Configure inference with the following parameters:
    PROMPT = [your_prompt_here]
    HEIGHT = 960
    WIDTH = 1280
    NUM_FRAMES = 81
    
    infer(
        InferenceConfig(
            pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
            prompt=PROMPT,
            height=HEIGHT,
            width=WIDTH,
            num_frames=NUM_FRAMES,
            output_path="/workspace/output.mp4",
        )
    )
  3. Run inference; the error occurs during the transformer forward pass

Root Cause Analysis

The error is raised from the q8_kernels attention processor, which calls attn.to_q() with four positional arguments:

query = attn.to_q(
    hidden_states, hidden_states_scales, False, torch.bfloat16
)

Python's argument count includes self, so this call passes five positional arguments in total, while FP8Linear.forward() accepts at most four (self plus three caller-supplied arguments). The attention processor therefore expects a forward() signature with one more positional parameter than the installed FP8Linear provides, which points to a version mismatch between the q8_kernels integration code and the FP8Linear module in the installed wheel.
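
For illustration, here is a minimal, self-contained sketch of the arity mismatch. The FP8Linear signature below is a guess that only mirrors the "2 to 4 positional arguments" reported by the TypeError; the real module lives in q8_kernels:

import torch
import torch.nn as nn

class FP8Linear(nn.Module):
    # Hypothetical signature: counting self, forward() accepts 2 to 4
    # positional arguments, matching the TypeError in the trace above.
    def forward(self, x, x_scales=None, out_dtype=torch.bfloat16):
        return x  # placeholder; the real kernel performs an FP8 matmul

layer = FP8Linear()
x = torch.zeros(1, 8)

layer(x, None, torch.bfloat16)         # 4 positional args incl. self: fine
layer(x, None, False, torch.bfloat16)  # 5 positional args incl. self:
# TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments
# but 5 were given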

Environment Information

  • Model Config: configs/ltxv-13b-0.9.8-distilled-fp8.yaml
  • Python: 3.11
  • PyTorch: 2.7.1+cu126
  • q8_kernels: q8_kernels @ file:///workspace/LTX-Video/LTX-Video-Q8-Kernels/dist/q8_kernels-0.0.5-cp311-cp311-linux_x86_64.whl#sha256=54e6ab58fa9acccfd60316e876e0807ced33df91929f59415d9b57a893f0acb7
  • Platform: Linux workspace environment
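
A quick way to confirm what the installed wheel actually exposes is to print the forward() signature of an FP8Linear instance from the loaded pipeline. This is a diagnostic sketch, not part of the repro; it assumes the LTXVideoPipeline object built by the inference script is available as pipeline:

import inspect

# Walk the transformer's modules and print the forward() signature of the
# first FP8Linear found; `pipeline` is the LTXVideoPipeline instance built
# by the inference script (assumption).
for name, module in pipeline.transformer.named_modules():
    if type(module).__name__ == "FP8Linear":
        print(name, inspect.signature(module.forward))
        break

If the printed signature accepts fewer positional parameters than the call site in q8_kernels/integration/diffusers.py passes, that confirms the wheel and the integration code are out of sync.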
