8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -177,10 +177,10 @@ def _refit_single_trt_engine_with_gm(
constant_mapping_with_type = {}

for constant_name, val in constant_mapping.items():
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
weight_dtype = val.dtype
val_tensor = val.cuda()
trt_dtype = dtype._from(weight_dtype).to(trt.DataType)
torch_dtype = dtype._from(weight_dtype).to(torch.dtype)
Comment on lines +180 to +183 (Contributor Author):
As a result of fix #1 in https://github.com/pytorch/TensorRT/pull/4192/changes#r3192971184 and of the previous commit (0273726), all constant_mapping items are torch tensors, not np arrays.
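A minimal standalone sketch (illustrative, not from the PR) of why the old torch.from_numpy path breaks once the mapping holds tensors:

```python
import numpy as np
import torch

arr = np.ones(3)
t = torch.ones(3)

torch.from_numpy(arr)      # old assumption: values are np arrays -> fine
try:
    torch.from_numpy(t)    # values are now torch tensors -> TypeError
except TypeError as e:
    print(e)               # "expected np.ndarray (got Tensor)"
```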

constant_mapping_with_type[constant_name] = (
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
trt_dtype,
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -486,7 +486,7 @@ def _save_weight_mapping(self) -> None:
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
weight_refit_map = self.ctx.weight_refit_map
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.numel() == 1}
Contributor Author:
This path was exercised, in the RTX path only, with

  • tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_dynamo_compile_with_custom_engine_cache
  • tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_torch_compile_with_default_disk_engine_cache (XFAILED)

In the earlier diff, v.size referred to the torch.Tensor.size method (and it was never called), so the comparison v.size == 1 was always False and constant_mapping was always empty. This PR changes it to v.numel() == 1, which fixes the check and allows CI to pass.
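A standalone sketch of the bug (illustrative, not from the PR): numpy's .size is an int attribute, while torch.Tensor.size is a method, so comparing it to 1 without calling it is always False:

```python
import numpy as np
import torch

arr = np.ones(1)
t = torch.ones(1)

print(arr.size == 1)     # True: ndarray.size is the element count (an int)
print(t.size == 1)       # False: Tensor.size is a bound method, never equal to 1
print(t.size() == 1)     # False too: size() returns torch.Size([1]), not an int
print(t.numel() == 1)    # True: numel() is the torch equivalent of ndarray.size
```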

net = self.ctx.net
for i in range(net.num_layers):
layer = net[i]
13 changes: 1 addition & 12 deletions tests/py/dynamo/models/test_engine_cache.py
@@ -268,11 +268,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
)
def test_dynamo_compile_with_custom_engine_cache(self):
model = models.resnet18(pretrained=True).eval().to("cuda")

@@ -342,11 +337,6 @@ def test_dynamo_compile_with_custom_engine_cache(self):
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
)
def test_dynamo_compile_change_input_shape(self):
"""Runs compilation 3 times, the cache should miss each time"""
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -659,8 +649,7 @@ def forward(self, c, d):
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
lanluo-nvidia marked this conversation as resolved.
Contributor Author:
Confirmed this is fixed in 1.5. TRT-RTX doesn't cache refit graphs/kernels, so upon refitting we were essentially recompiling kernels. With v1.5, refit kernels are generated AoT, so caching behavior can be restored.

)
def test_caching_small_model(self):
from torch_tensorrt.dynamo._refit import refit_module_weights
6 changes: 3 additions & 3 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -4,7 +4,6 @@
import shutil
import unittest

import tensorrt as trt
import torch
import torch_tensorrt as torch_trt
from torch.testing._internal.common_utils import TestCase
@@ -13,6 +12,8 @@
from torch_tensorrt.dynamo._refit import refit_module_weights
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

import tensorrt as trt # isort: skip # must import after torch_tensorrt to resolve tensorrt_rtx alias

assertions = unittest.TestCase()

if importlib.util.find_spec("torchvision"):
@@ -274,8 +275,7 @@ def test_engine_caching_saves_weight_stripped_engine(self):
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
)
def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")