
Conv Fails in Float16 but Passes in Float32 #694

@Mgluhovskoi

Description


When running a tp.Conv operation in float16, the build fails with:

MTRTException: failed to run pass pipeline
IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node [tensorrt.convolution] ...)

The same code executes successfully in float32.
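
For context, the convolutions created inside the module are depthwise (grouped) 3D convolutions whose weights are derived from Haar wavelet filters at runtime. The sketch below distills one such configuration from the full reproduction further down; the shapes and values are illustrative and it has not been verified to fail on its own, but it shows the exact tp.Conv parameters involved.

import numpy as np
import nvtripy as tp

# Illustrative depthwise 3D conv in float16, mirroring _create_conv() in the repro.
channels, n = 3, 2  # n = length of the Haar wavelet filter
weight_np = np.full((channels, 1, n, 1, 1), 1 / np.sqrt(2), dtype=np.float16)
conv = tp.Conv(
    in_channels=channels,
    out_channels=channels,
    kernel_dims=(n, 1, 1),
    stride=(2, 1, 1),                  # stride 2 along the temporal axis
    padding=[(0, 0), (0, 0), (0, 0)],
    groups=channels,                   # depthwise: one filter per channel
    bias=False,
    dtype=tp.float16,
)
conv.weight = tp.Tensor(weight_np, dtype=tp.float16)

x = tp.Tensor(np.random.randn(2, channels, 4, 8, 8).astype(np.float16), dtype=tp.float16)
y = conv(x).eval()  # in the full module, this is the step that fails in float16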


Steps to Reproduce

  1. In test_equiv_tripy_cosmos_patch_embed3d(), set tp_dtype, np_dtype, and torch_dtype to float16 (reproduces the failure) or float32 (passes), as shown in the snippet below.
  2. Run:
    python conv_bug.py
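
For reference, the dtype assignments to flip are near the top of the test; the float16 variant looks like this:

    tp_dtype = tp.float16
    np_dtype = np.float16
    torch_dtype = torch.float16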

Expected Behavior

The convolution should execute successfully in both float32 and float16.


Actual Behavior

  • float32 → Passes.
  • float16 → Fails during TensorRT engine building with "Could not find any implementation for node [tensorrt.convolution]" (Error Code 10).

Full Error Trace

root@83bc71cc371d:/diffusers/tripy_autoencoder_kl_cosmos# python conv_bug.py 
Traceback (most recent call last):
  File "/diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py", line 179, in <module>
    test_equiv_tripy_cosmos_patch_embed3d()
  File "/diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py", line 175, in test_equiv_tripy_cosmos_patch_embed3d
    y_tripy = tripy_patch(x_tripy).eval().tolist()
  File "/usr/local/lib/python3.9/site-packages/nvtripy/utils/function_registry.py", line 383, in wrapper
    return self.find_overload(key, args, kwargs)(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/nvtripy/utils/function_registry.py", line 257, in __call__
    return self.func(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/nvtripy/frontend/tensor.py", line 263, in eval
    compiler.compile(mlir, trace=trace),
  File "/usr/local/lib/python3.9/site-packages/nvtripy/utils/utils.py", line 77, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/nvtripy/backend/mlir/compiler.py", line 83, in compile
    map_error_to_user_code_and_raise(trace, exc, stderr.decode())
  File "/usr/local/lib/python3.9/site-packages/nvtripy/backend/mlir/utils.py", line 251, in map_error_to_user_code_and_raise
    raise_error(
  File "/usr/local/lib/python3.9/site-packages/nvtripy/common/exception.py", line 165, in raise_error
    raise TripyException(msg) from None
nvtripy.common.exception.TripyException: 

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:175 in test_equiv_tripy_cosmos_patch_embed3d()
  175 |     y_tripy = tripy_patch(x_tripy).eval().tolist()
      | 
MTRTException: failed to run pass pipeline
    IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node [tensorrt.convolution] %t619,%t221;;<out>;;%t620.)
    (%t639) error: failed to translate function 'tensorrt_cluster' to a TensorRT engine

    This originated from the following operation:

    --> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:111 in _dwt()
      111 |             hidden_states = hidden_states / math.sqrt(8)
          |                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    --> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:119 in _haar()
      119 |             hidden_states = self._dwt(hidden_states, rescale=True)
          |                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
    --> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:135 in forward()
      135 |             return self._haar(hidden_states)
          |                    ^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
    --> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:175 in test_equiv_tripy_cosmos_patch_embed3d()
      175 |     y_tripy = tripy_patch(x_tripy).eval().tolist()
          | 
    --> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:179 in <module>()
      179 |     test_equiv_tripy_cosmos_patch_embed3d()
          |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

Standalone Reproduction Code

import numpy as np
import nvtripy as tp
import torch
import math

class TripyCosmosPatchEmbed3d(tp.Module):
    def __init__(self, patch_size: int = 1, patch_method: str = "haar", wavelets=None, groups: int = None, dtype=tp.float32):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        self.wavelets = wavelets
        self._arange = tp.arange(self.wavelets.shape[0], dtype=dtype)
        self.dtype = dtype
        # Note: groups is the initial number of input channels, but we don't pre-create convolutions
        # because the number of channels changes dynamically during _dwt steps

    def _create_conv_weights(self, num_channels: int):
        """Create convolution weights for the current number of channels."""
        n = self.wavelets.shape[0]
        
        # Create base weights
        hl = tp.flip(self.wavelets, 0)
        hl = tp.reshape(hl, (1, 1, n))
        hh = self.wavelets * ((-1) ** self._arange)
        hh = tp.reshape(hh, (1, 1, n))
        
        # Repeat for all channels
        hl = tp.repeat(hl, num_channels, dim=0)  # [num_channels, 1, n]
        hh = tp.repeat(hh, num_channels, dim=0)  # [num_channels, 1, n]
        
        # Create 5D weights for different convolution types
        hl_t = tp.unsqueeze(tp.unsqueeze(hl, 3), 4)  # [num_channels, 1, n, 1, 1] for temporal
        hh_t = tp.unsqueeze(tp.unsqueeze(hh, 3), 4)  # [num_channels, 1, n, 1, 1] for temporal
        hl_h = tp.unsqueeze(tp.unsqueeze(hl, 2), 4)  # [num_channels, 1, 1, n, 1] for height
        hh_h = tp.unsqueeze(tp.unsqueeze(hh, 2), 4)  # [num_channels, 1, 1, n, 1] for height
        hl_w = tp.unsqueeze(tp.unsqueeze(hl, 2), 3)  # [num_channels, 1, 1, 1, n] for width
        hh_w = tp.unsqueeze(tp.unsqueeze(hh, 2), 3)  # [num_channels, 1, 1, 1, n] for width
        
        return hl_t, hh_t, hl_h, hh_h, hl_w, hh_w

    def _create_conv(self, weight, stride, num_channels: int):
        """Create a convolution with the given weight and parameters."""
        kernel_dims = tuple(int(x) for x in weight.shape[2:])
        conv = tp.Conv(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_dims=kernel_dims,
            stride=stride,
            padding=[(0, 0), (0, 0), (0, 0)],
            groups=num_channels,
            bias=False,
            dtype=self.dtype,
        )
        conv.weight = weight
        return conv

    def _dwt(self, hidden_states, mode="constant", rescale=False):
        """Apply discrete wavelet transform."""
        num_channels = int(hidden_states.shape[1])
        n = self.wavelets.shape[0]
        
        # Create weights for current number of channels
        hl_t, hh_t, hl_h, hh_h, hl_w, hh_w = self._create_conv_weights(num_channels)
        
        # Create convolutions
        conv_xl = self._create_conv(hl_t, (2, 1, 1), num_channels)
        conv_xh = self._create_conv(hh_t, (2, 1, 1), num_channels)
        conv_xll = self._create_conv(hl_h, (1, 2, 1), num_channels)
        conv_xlh = self._create_conv(hh_h, (1, 2, 1), num_channels)
        conv_xhl = self._create_conv(hl_h, (1, 2, 1), num_channels)
        conv_xhh = self._create_conv(hh_h, (1, 2, 1), num_channels)
        conv_xlll = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xllh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        conv_xlhl = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xlhh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        conv_xhll = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xhlh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        conv_xhhl = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xhhh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        
        # Pad input
        pad = (max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1)
        hidden_states = tp.pad(hidden_states, [
            (0, 0),  # batch
            (0, 0),  # channel
            (pad[4], pad[5]),  # depth
            (pad[2], pad[3]),  # height
            (pad[0], pad[1]),  # width
        ], mode="constant", value=0.0)
        
        # Apply convolutions
        xl = conv_xl(hidden_states)
        xh = conv_xh(hidden_states)
        xll = conv_xll(xl)
        xlh = conv_xlh(xl)
        xhl = conv_xhl(xh)
        xhh = conv_xhh(xh)
        xlll = conv_xlll(xll)
        xllh = conv_xllh(xll)
        xlhl = conv_xlhl(xlh)
        xlhh = conv_xlhh(xlh)
        xhll = conv_xhll(xhl)
        xhlh = conv_xhlh(xhl)
        xhhl = conv_xhhl(xhh)
        xhhh = conv_xhhh(xhh)
        
        # Concatenate all filtered outputs
        hidden_states = tp.concatenate([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
        
        if rescale:
            hidden_states = hidden_states / math.sqrt(8)
        return hidden_states

    def _haar(self, hidden_states):
        """Apply Haar wavelet transform."""
        xi, xv = tp.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
        hidden_states = tp.concatenate([tp.repeat(xi, self.patch_size, dim=2), xv], dim=2)
        for _ in range(int(math.log2(self.patch_size))):
            hidden_states = self._dwt(hidden_states, rescale=True)
        return hidden_states
    
    def _arrange(self, hidden_states):
        """Apply rearrange patch embedding."""
        xi, xv = tp.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
        hidden_states = tp.concatenate([tp.repeat(xi, self.patch_size, dim=2), xv], dim=2)
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p = self.patch_size
        hidden_states = tp.reshape(hidden_states, (batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p))
        hidden_states = tp.permute(hidden_states, (0, 1, 3, 5, 7, 2, 4, 6))
        hidden_states = tp.reshape(hidden_states, (batch_size, num_channels * p * p * p, num_frames // p, height // p, width // p))
        return hidden_states

    def forward(self, hidden_states):
        if self.patch_method == "haar":
            return self._haar(hidden_states)
        elif self.patch_method == "rearrange":
            return self._arrange(hidden_states)
        else:
            raise ValueError(f"Unsupported patch method: {self.patch_method}")
        


_WAVELETS = {
    "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
    "rearrange": torch.tensor([1.0, 1.0]),
}

def test_equiv_tripy_cosmos_patch_embed3d():
    patch_size = 2
    patch_method = "haar"
    batch = 2
    channels = 3
    time = 4
    height = 8
    width = 8
    tp_dtype = tp.float32
    np_dtype = np.float32
    torch_dtype = torch.float32
    # Use the same wavelets as PyTorch
    wavelets_torch = _WAVELETS[patch_method].clone().float()
    wavelets_np = wavelets_torch.detach().cpu().numpy()
    wavelets_tripy = tp.Tensor(wavelets_np, dtype=tp_dtype)
    # Dummy input
    x_np = np.random.randn(batch, channels, time, height, width).astype(np_dtype)
    x_torch = torch.tensor(x_np, dtype=torch_dtype)
    x_tripy = tp.Tensor(x_np, dtype=tp_dtype)
    # Tripy module
    tripy_patch = TripyCosmosPatchEmbed3d(
        patch_size=patch_size,
        patch_method=patch_method,
        wavelets=wavelets_tripy,
        dtype=tp_dtype,
    )
    # Forward
    y_tripy = tripy_patch(x_tripy).eval().tolist()
    y_tripy = np.array(y_tripy)
    print("Test passed!")

if __name__ == "__main__":
    test_equiv_tripy_cosmos_patch_embed3d()
