
Commit d017ed7

fix(test): enable TRT-RTX refit and engine cache tests
Now that #4181 removed the RTX-specific batch norm workaround that bypassed constant folding, the refit bug (#3752) is resolved: eps constants are no longer created as separate CONSTANT layers on RTX.

Remove the RTX skip decorators from:
- test_dynamo_compile_with_refittable_weight_stripped_engine
- test_dynamo_compile_with_custom_engine_cache
- test_dynamo_compile_change_input_shape

Keep the RTX skip on test_caching_small_model, which fails a timing assertion (cached compilation is slower than uncached on RTX). Update the skip message to reflect the actual reason.

Fix import ordering in test_weight_stripped_engine.py: tensorrt must be imported after torch_tensorrt so the tensorrt_rtx module alias is resolved correctly.

Fixes #3752
1 parent 2233edb commit d017ed7

2 files changed

Lines changed: 3 additions & 18 deletions
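
For context, the RTX feature gate referenced in the commit message follows the pattern sketched below. This is a minimal illustration only (the test class name here is hypothetical); the actual decorators appear in the diff hunks that follow.

    import unittest

    import torch_tensorrt as torch_trt


    class TestEngineCache(unittest.TestCase):  # hypothetical class name for illustration
        # Skip only when the installed build enables the TensorRT-RTX backend.
        @unittest.skipIf(
            torch_trt.ENABLED_FEATURES.tensorrt_rtx,
            "Engine caching compilation time assertion is unreliable with TensorRT-RTX",
        )
        def test_caching_small_model(self):
            ...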


tests/py/dynamo/models/test_engine_cache.py

Lines changed: 1 addition & 12 deletions
@@ -268,11 +268,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
     @unittest.skipIf(
         not importlib.util.find_spec("torchvision"), "torchvision not installed"
     )
-    @unittest.skipIf(
-        torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
-    )
     def test_dynamo_compile_with_custom_engine_cache(self):
         model = models.resnet18(pretrained=True).eval().to("cuda")

@@ -342,11 +337,6 @@ def test_dynamo_compile_with_custom_engine_cache(self):
     @unittest.skipIf(
         not importlib.util.find_spec("torchvision"), "torchvision not installed"
     )
-    @unittest.skipIf(
-        torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
-    )
     def test_dynamo_compile_change_input_shape(self):
         """Runs compilation 3 times, the cache should miss each time"""
         model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -659,8 +649,7 @@ def forward(self, c, d):
     )
     @unittest.skipIf(
         torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
+        "Engine caching compilation time assertion is unreliable with TensorRT-RTX",
     )
     def test_caching_small_model(self):
         from torch_tensorrt.dynamo._refit import refit_module_weights

tests/py/dynamo/models/test_weight_stripped_engine.py

Lines changed: 2 additions & 6 deletions
@@ -4,7 +4,6 @@
 import shutil
 import unittest
 
-import tensorrt as trt
 import torch
 import torch_tensorrt as torch_trt
 from torch.testing._internal.common_utils import TestCase
@@ -13,6 +12,8 @@
 from torch_tensorrt.dynamo._refit import refit_module_weights
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
+import tensorrt as trt  # isort: skip # must import after torch_tensorrt to resolve tensorrt_rtx alias
+
 assertions = unittest.TestCase()
 
 if importlib.util.find_spec("torchvision"):
@@ -272,11 +273,6 @@ def test_engine_caching_saves_weight_stripped_engine(self):
         not importlib.util.find_spec("torchvision"),
         "torchvision is not installed",
     )
-    @unittest.skipIf(
-        torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
-    )
     def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
         pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
         example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
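
A note on the import-order hunk above: the constraint described in the commit message reduces to the sketch below (assuming a TensorRT-RTX build; the alias behaviour is taken from the commit message, not re-verified here).

    # torch_tensorrt must be imported first so that, on RTX builds, the
    # tensorrt_rtx module alias is registered before "tensorrt" is resolved.
    import torch_tensorrt as torch_trt

    import tensorrt as trt  # isort: skip  # must come after torch_tensorrt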
