Skip to content

Commit d39be11

Browse files
committed
fix(test): enable TRT-RTX refit and engine cache tests
Now that #4181 removed the RTX-specific batch norm workaround that bypassed constant folding, the refit bug (#3752) is resolved — eps constants are no longer created as separate CONSTANT layers on RTX.

Remove the RTX skip decorators from:
- test_dynamo_compile_with_refittable_weight_stripped_engine
- test_dynamo_compile_with_custom_engine_cache
- test_dynamo_compile_change_input_shape

Keep the RTX skip on test_caching_small_model, which fails a timing assertion (cached compilation is slower than uncached on RTX). Update the skip message to reflect the actual reason.

Fix import ordering in test_weight_stripped_engine.py: tensorrt must be imported after torch_tensorrt so the tensorrt_rtx module alias is resolved correctly.

Fixes #3752
1 parent 112e670 commit d39be11

2 files changed

Lines changed: 3 additions & 18 deletions

File tree

tests/py/dynamo/models/test_engine_cache.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
271271
@unittest.skipIf(
272272
not importlib.util.find_spec("torchvision"), "torchvision not installed"
273273
)
274-
@unittest.skipIf(
275-
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
276-
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
277-
"There is bug in refit, so we skip the test for now",
278-
)
279274
def test_dynamo_compile_with_custom_engine_cache(self):
280275
model = models.resnet18(pretrained=True).eval().to("cuda")
281276

@@ -347,11 +342,6 @@ def test_dynamo_compile_with_custom_engine_cache(self):
347342
@unittest.skipIf(
348343
not importlib.util.find_spec("torchvision"), "torchvision not installed"
349344
)
350-
@unittest.skipIf(
351-
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
352-
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
353-
"There is bug in refit, so we skip the test for now",
354-
)
355345
def test_dynamo_compile_change_input_shape(self):
356346
"""Runs compilation 3 times, the cache should miss each time"""
357347
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -673,8 +663,7 @@ def forward(self, c, d):
673663
)
674664
@unittest.skipIf(
675665
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
676-
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
677-
"There is bug in refit, so we skip the test for now",
666+
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
678667
)
679668
def test_caching_small_model(self):
680669
from torch_tensorrt.dynamo._refit import refit_module_weights

tests/py/dynamo/models/test_weight_stripped_engine.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import shutil
55
import unittest
66

7-
import tensorrt as trt
87
import torch
98
import torch_tensorrt as torch_trt
109
from torch.testing._internal.common_utils import TestCase
@@ -13,6 +12,8 @@
1312
from torch_tensorrt.dynamo._refit import refit_module_weights
1413
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
1514

15+
import tensorrt as trt # isort: skip # must import after torch_tensorrt to resolve tensorrt_rtx alias
16+
1617
assertions = unittest.TestCase()
1718

1819
if importlib.util.find_spec("torchvision"):
@@ -277,11 +278,6 @@ def test_engine_caching_saves_weight_stripped_engine(self):
277278
not importlib.util.find_spec("torchvision"),
278279
"torchvision is not installed",
279280
)
280-
@unittest.skipIf(
281-
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
282-
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
283-
"There is bug in refit, so we skip the test for now",
284-
)
285281
def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
286282
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
287283
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)

0 commit comments

Comments (0)