Description
When running a tp.Conv operation in float16, the build fails with:

```
MTRTException: failed to run pass pipeline
IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node [tensorrt.convolution] ...)
```

The same code executes successfully in float32.
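For orientation, the failing pattern is a depthwise-grouped 3D tp.Conv built in float16 and evaluated eagerly. The snippet below is only an illustrative distillation of that configuration (the shapes are arbitrary assumptions, and it has not been verified to trigger the error in isolation); the full, verified reproduction is at the end of this report.

```python
# Illustrative distillation only; shapes are hypothetical. See the full repro below.
import numpy as np
import nvtripy as tp

dtype, np_dtype = tp.float16, np.float16  # the same graph passes with float32

x = tp.Tensor(np.random.randn(2, 3, 5, 8, 8).astype(np_dtype), dtype=dtype)
conv = tp.Conv(
    in_channels=3,
    out_channels=3,
    kernel_dims=(2, 1, 1),
    stride=(2, 1, 1),
    padding=[(0, 0), (0, 0), (0, 0)],
    groups=3,  # one group per channel, as in the reproduction code
    bias=False,
    dtype=dtype,
)
# Weight shape is [out_channels, in_channels // groups, *kernel_dims].
conv.weight = tp.Tensor(np.random.randn(3, 1, 2, 1, 1).astype(np_dtype), dtype=dtype)
conv(x).eval()  # float16: TensorRT engine build fails; float32: succeeds
```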
Steps to Reproduce
- Set tp_dtype, np_dtype, and torch_dtype to either float16 or float32 in test_equiv_tripy_cosmos_patch_embed3d(), as shown in the snippet after this list.
- Run: python conv_bug.py
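For reference, these are the three assignments to toggle (taken verbatim from the reproduction code below; the float16 variants are shown in comments):

```python
# Inside test_equiv_tripy_cosmos_patch_embed3d():
tp_dtype = tp.float32        # switch to tp.float16 to reproduce the failure
np_dtype = np.float32        # switch to np.float16
torch_dtype = torch.float32  # switch to torch.float16
```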
Expected Behavior
The convolution should execute successfully in both float32 and float16.
Actual Behavior
- float32 → Passes.
- float16 → Fails during TensorRT compilation with a missing convolution implementation error.
Full Error Trace
```
root@83bc71cc371d:/diffusers/tripy_autoencoder_kl_cosmos# python conv_bug.py
Traceback (most recent call last):
  File "/diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py", line 179, in <module>
    test_equiv_tripy_cosmos_patch_embed3d()
  File "/diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py", line 175, in test_equiv_tripy_cosmos_patch_embed3d
    y_tripy = tripy_patch(x_tripy).eval().tolist()
  File "/usr/local/lib/python3.9/site-packages/nvtripy/utils/function_registry.py", line 383, in wrapper
    return self.find_overload(key, args, kwargs)(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/nvtripy/utils/function_registry.py", line 257, in __call__
    return self.func(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/nvtripy/frontend/tensor.py", line 263, in eval
    compiler.compile(mlir, trace=trace),
  File "/usr/local/lib/python3.9/site-packages/nvtripy/utils/utils.py", line 77, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/nvtripy/backend/mlir/compiler.py", line 83, in compile
    map_error_to_user_code_and_raise(trace, exc, stderr.decode())
  File "/usr/local/lib/python3.9/site-packages/nvtripy/backend/mlir/utils.py", line 251, in map_error_to_user_code_and_raise
    raise_error(
  File "/usr/local/lib/python3.9/site-packages/nvtripy/common/exception.py", line 165, in raise_error
    raise TripyException(msg) from None
nvtripy.common.exception.TripyException:

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:175 in test_equiv_tripy_cosmos_patch_embed3d()
175 | y_tripy = tripy_patch(x_tripy).eval().tolist()
    |

MTRTException: failed to run pass pipeline
IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node [tensorrt.convolution] %t619,%t221;;<out>;;%t620.)
(%t639) error: failed to translate function 'tensorrt_cluster' to a TensorRT engine

This originated from the following operation:

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:111 in _dwt()
111 | hidden_states = hidden_states / math.sqrt(8)
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:119 in _haar()
119 | hidden_states = self._dwt(hidden_states, rescale=True)
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:135 in forward()
135 | return self._haar(hidden_states)
    | ^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:175 in test_equiv_tripy_cosmos_patch_embed3d()
175 | y_tripy = tripy_patch(x_tripy).eval().tolist()
    |

--> /diffusers/tripy_autoencoder_kl_cosmos/conv_bug.py:179 in <module>()
179 | test_equiv_tripy_cosmos_patch_embed3d()
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
```
Standalone Reproduction Code
```python
import numpy as np
import nvtripy as tp
import torch
import math


class TripyCosmosPatchEmbed3d(tp.Module):
    def __init__(self, patch_size: int = 1, patch_method: str = "haar", wavelets=None, groups: int = None, dtype=tp.float32):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        self.wavelets = wavelets
        self._arange = tp.arange(self.wavelets.shape[0], dtype=dtype)
        self.dtype = dtype
        # Note: groups is the initial number of input channels, but we don't pre-create convolutions
        # because the number of channels changes dynamically during _dwt steps

    def _create_conv_weights(self, num_channels: int):
        """Create convolution weights for the current number of channels."""
        n = self.wavelets.shape[0]
        # Create base weights
        hl = tp.flip(self.wavelets, 0)
        hl = tp.reshape(hl, (1, 1, n))
        hh = self.wavelets * ((-1) ** self._arange)
        hh = tp.reshape(hh, (1, 1, n))
        # Repeat for all channels
        hl = tp.repeat(hl, num_channels, dim=0)  # [num_channels, 1, n]
        hh = tp.repeat(hh, num_channels, dim=0)  # [num_channels, 1, n]
        # Create 5D weights for different convolution types
        hl_t = tp.unsqueeze(tp.unsqueeze(hl, 3), 4)  # [num_channels, 1, n, 1, 1] for temporal
        hh_t = tp.unsqueeze(tp.unsqueeze(hh, 3), 4)  # [num_channels, 1, n, 1, 1] for temporal
        hl_h = tp.unsqueeze(tp.unsqueeze(hl, 2), 4)  # [num_channels, 1, 1, n, 1] for height
        hh_h = tp.unsqueeze(tp.unsqueeze(hh, 2), 4)  # [num_channels, 1, 1, n, 1] for height
        hl_w = tp.unsqueeze(tp.unsqueeze(hl, 2), 3)  # [num_channels, 1, 1, 1, n] for width
        hh_w = tp.unsqueeze(tp.unsqueeze(hh, 2), 3)  # [num_channels, 1, 1, 1, n] for width
        return hl_t, hh_t, hl_h, hh_h, hl_w, hh_w

    def _create_conv(self, weight, stride, num_channels: int):
        """Create a convolution with the given weight and parameters."""
        kernel_dims = tuple(int(x) for x in weight.shape[2:])
        conv = tp.Conv(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_dims=kernel_dims,
            stride=stride,
            padding=[(0, 0), (0, 0), (0, 0)],
            groups=num_channels,
            bias=False,
            dtype=self.dtype,
        )
        conv.weight = weight
        return conv

    def _dwt(self, hidden_states, mode="constant", rescale=False):
        """Apply discrete wavelet transform."""
        num_channels = int(hidden_states.shape[1])
        n = self.wavelets.shape[0]
        # Create weights for current number of channels
        hl_t, hh_t, hl_h, hh_h, hl_w, hh_w = self._create_conv_weights(num_channels)
        # Create convolutions
        conv_xl = self._create_conv(hl_t, (2, 1, 1), num_channels)
        conv_xh = self._create_conv(hh_t, (2, 1, 1), num_channels)
        conv_xll = self._create_conv(hl_h, (1, 2, 1), num_channels)
        conv_xlh = self._create_conv(hh_h, (1, 2, 1), num_channels)
        conv_xhl = self._create_conv(hl_h, (1, 2, 1), num_channels)
        conv_xhh = self._create_conv(hh_h, (1, 2, 1), num_channels)
        conv_xlll = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xllh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        conv_xlhl = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xlhh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        conv_xhll = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xhlh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        conv_xhhl = self._create_conv(hl_w, (1, 1, 2), num_channels)
        conv_xhhh = self._create_conv(hh_w, (1, 1, 2), num_channels)
        # Pad input
        pad = (max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1)
        hidden_states = tp.pad(hidden_states, [
            (0, 0),            # batch
            (0, 0),            # channel
            (pad[4], pad[5]),  # depth
            (pad[2], pad[3]),  # height
            (pad[0], pad[1]),  # width
        ], mode="constant", value=0.0)
        # Apply convolutions
        xl = conv_xl(hidden_states)
        xh = conv_xh(hidden_states)
        xll = conv_xll(xl)
        xlh = conv_xlh(xl)
        xhl = conv_xhl(xh)
        xhh = conv_xhh(xh)
        xlll = conv_xlll(xll)
        xllh = conv_xllh(xll)
        xlhl = conv_xlhl(xlh)
        xlhh = conv_xlhh(xlh)
        xhll = conv_xhll(xhl)
        xhlh = conv_xhlh(xhl)
        xhhl = conv_xhhl(xhh)
        xhhh = conv_xhhh(xhh)
        # Concatenate all filtered outputs
        hidden_states = tp.concatenate([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
        if rescale:
            hidden_states = hidden_states / math.sqrt(8)
        return hidden_states

    def _haar(self, hidden_states):
        """Apply Haar wavelet transform."""
        xi, xv = tp.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
        hidden_states = tp.concatenate([tp.repeat(xi, self.patch_size, dim=2), xv], dim=2)
        for _ in range(int(math.log2(self.patch_size))):
            hidden_states = self._dwt(hidden_states, rescale=True)
        return hidden_states

    def _arrange(self, hidden_states):
        """Apply rearrange patch embedding."""
        xi, xv = tp.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
        hidden_states = tp.concatenate([tp.repeat(xi, self.patch_size, dim=2), xv], dim=2)
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p = self.patch_size
        hidden_states = tp.reshape(hidden_states, (batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p))
        hidden_states = tp.permute(hidden_states, (0, 1, 3, 5, 7, 2, 4, 6))
        hidden_states = tp.reshape(hidden_states, (batch_size, num_channels * p * p * p, num_frames // p, height // p, width // p))
        return hidden_states

    def forward(self, hidden_states):
        if self.patch_method == "haar":
            return self._haar(hidden_states)
        elif self.patch_method == "rearrange":
            return self._arrange(hidden_states)
        else:
            raise ValueError(f"Unsupported patch method: {self.patch_method}")


_WAVELETS = {
    "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
    "rearrange": torch.tensor([1.0, 1.0]),
}


def test_equiv_tripy_cosmos_patch_embed3d():
    patch_size = 2
    patch_method = "haar"
    batch = 2
    channels = 3
    time = 4
    height = 8
    width = 8
    tp_dtype = tp.float32
    np_dtype = np.float32
    torch_dtype = torch.float32
    # Use the same wavelets as PyTorch
    wavelets_torch = _WAVELETS[patch_method].clone().float()
    wavelets_np = wavelets_torch.detach().cpu().numpy()
    wavelets_tripy = tp.Tensor(wavelets_np, dtype=tp_dtype)
    # Dummy input
    x_np = np.random.randn(batch, channels, time, height, width).astype(np_dtype)
    x_torch = torch.tensor(x_np, dtype=torch_dtype)
    x_tripy = tp.Tensor(x_np, dtype=tp_dtype)
    # Tripy module
    tripy_patch = TripyCosmosPatchEmbed3d(
        patch_size=patch_size,
        patch_method=patch_method,
        wavelets=wavelets_tripy,
        dtype=tp_dtype,
    )
    # Forward
    y_tripy = tripy_patch(x_tripy).eval().tolist()
    y_tripy = np.array(y_tripy)
    print("Test passed!")


if __name__ == "__main__":
    test_equiv_tripy_cosmos_patch_embed3d()
```