fix(test): enable TRT-RTX refit and engine cache tests #4192
base: main
```diff
@@ -486,7 +486,7 @@ def _save_weight_mapping(self) -> None:
         sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
-        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
+        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.numel() == 1}
```
**Contributor (Author):** This path was exercised with …
```diff
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
```
```diff
@@ -268,11 +268,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
     @unittest.skipIf(
         not importlib.util.find_spec("torchvision"), "torchvision not installed"
     )
-    @unittest.skipIf(
-        torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
-    )
     def test_dynamo_compile_with_custom_engine_cache(self):
         model = models.resnet18(pretrained=True).eval().to("cuda")
```
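For reference, below is a minimal sketch of the kind of disk-backed cache that `test_dynamo_compile_with_custom_engine_cache` plugs into compilation. It assumes the `BaseEngineCache` interface (`save`/`load` hooks) from `torch_tensorrt.dynamo._engine_cache`, following the upstream engine-caching example; the `MyEngineCache` name and on-disk layout are illustrative, not part of this PR.

```python
import os
from typing import Optional

from torch_tensorrt.dynamo._engine_cache import BaseEngineCache


class MyEngineCache(BaseEngineCache):
    """Stores serialized engine blobs on disk, keyed by graph hash."""

    def __init__(self, cache_dir: str) -> None:
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def save(self, hash: str, blob: bytes) -> None:
        # Persist the packed engine blob under its hash.
        with open(os.path.join(self.cache_dir, f"{hash}.bin"), "wb") as f:
            f.write(blob)

    def load(self, hash: str) -> Optional[bytes]:
        # Return the cached blob on a hit; None forces a fresh build.
        path = os.path.join(self.cache_dir, f"{hash}.bin")
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return f.read()
```

Such a cache would be handed to `torch_trt.dynamo.compile` via `custom_engine_cache=`, together with `cache_built_engines=True` and `reuse_cached_engines=True`.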
```diff
@@ -342,11 +337,6 @@ def test_dynamo_compile_with_custom_engine_cache(self):
     @unittest.skipIf(
         not importlib.util.find_spec("torchvision"), "torchvision not installed"
     )
-    @unittest.skipIf(
-        torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
-    )
     def test_dynamo_compile_change_input_shape(self):
         """Runs compilation 3 times, the cache should miss each time"""
         model = models.resnet18(pretrained=True).eval().to("cuda")
```
```diff
@@ -659,8 +649,7 @@ def forward(self, c, d):
     )
     @unittest.skipIf(
         torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-        # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-        "There is bug in refit, so we skip the test for now",
+        "Engine caching compilation time assertion is unreliable with TensorRT-RTX",
```
**Contributor (Author)** (marked resolved by lanluo-nvidia): Confirmed this is fixed in TRT-RTX 1.5. TRT-RTX doesn't cache refit graphs/kernels, so upon refitting we are essentially recompiling kernels. With v1.5, refit kernels are generated ahead of time (AoT), so caching behavior can be restored.
```diff
     )
     def test_caching_small_model(self):
         from torch_tensorrt.dynamo._refit import refit_module_weights
```
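For context on the path `test_caching_small_model` exercises, here is a minimal sketch of refitting a compiled module's weights, assuming the `immutable_weights=False` compile option and the `refit_module_weights` helper imported above; the tiny `torch.nn.Linear` model and the input shapes are illustrative.

```python
import torch
import torch_tensorrt as torch_trt
from torch_tensorrt.dynamo._refit import refit_module_weights

# Two modules with identical structure but different weights.
model = torch.nn.Linear(4, 4).eval().cuda()
model2 = torch.nn.Linear(4, 4).eval().cuda()
inputs = (torch.randn(2, 4).cuda(),)

# Build a refittable engine: immutable_weights=False keeps the weights
# patchable, so the cached engine can be reused across weight changes.
exp_program = torch.export.export(model, inputs)
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs=list(inputs),
    immutable_weights=False,
)

# Swap in model2's weights without rebuilding the engine.
exp_program2 = torch.export.export(model2, inputs)
trt_gm = refit_module_weights(trt_gm, exp_program2, arg_inputs=list(inputs))
```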
**Comment:** As a result of fix #1 in https://github.com/pytorch/TensorRT/pull/4192/changes#r3192971184 and the previous commit (0273726), all `constant_mapping` items are torch tensors, not NumPy arrays.
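To make the `.size` → `.numel()` change concrete: `numpy.ndarray.size` is a plain integer attribute, while `torch.Tensor.size` is a method, so the old comparison could never match once the items became tensors. A quick illustration:

```python
import numpy as np
import torch

arr = np.array(3.0)    # zero-dim NumPy array
t = torch.tensor(3.0)  # zero-dim torch tensor

print(arr.size)      # 1 -- ndarray.size is an int attribute
print(t.size)        # <built-in method size of Tensor ...> -- a method
print(t.size == 1)   # False for every tensor: a method compared to an int
print(t.numel())     # 1 -- the element count for a torch.Tensor
```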