fix(test): enable TRT-RTX refit and engine cache tests #4192
Open
tp5uiuc wants to merge 4 commits into pytorch:main from
Conversation
tp5uiuc commented on Apr 16, 2026
tp5uiuc force-pushed from d39be11 to fe174e7
tp5uiuc force-pushed from f3e7ccd to 048d538
cehongwang approved these changes on Apr 23, 2026
tp5uiuc force-pushed from a5fd05c to c6497b4
tp5uiuc commented on Apr 29, 2026
```diff
  torch_trt.ENABLED_FEATURES.tensorrt_rtx,
- # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
- "There is bug in refit, so we skip the test for now",
+ "Engine caching compilation time assertion is unreliable with TensorRT-RTX",
```
Confirmed this is fixed in 1.5. TRT-RTX doesn't cache refit graphs/kernels, so on refit we are essentially recompiling kernels. With v1.5, refit kernels are generated ahead of time (AoT), so the caching behavior can be restored.
Now that pytorch#4181 removed the RTX-specific batch norm workaround that bypassed constant folding, the refit bug (pytorch#3752) is resolved: eps constants are no longer created as separate CONSTANT layers on RTX.

Remove the RTX skip decorators from:
- test_dynamo_compile_with_refittable_weight_stripped_engine
- test_dynamo_compile_with_custom_engine_cache
- test_dynamo_compile_change_input_shape

Keep the RTX skip on test_caching_small_model, which fails a timing assertion (cached compilation is slower than uncached on RTX). Update the skip message to reflect the actual reason.

Fix import ordering in test_weight_stripped_engine.py: tensorrt must be imported after torch_tensorrt so the tensorrt_rtx module alias is resolved correctly.

Fixes pytorch#3752
Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
tp5uiuc force-pushed from c6497b4 to ce60a5d
The fast-refit path on TensorRT-RTX was failing with "Fast refit failed on TensorRT-RTX: N of N engine weight(s) had no entry in weight_name_map" for any model containing scalar constants (e.g. batch-norm eps), because `weight_refit_map` values are torch.Tensor (since pytorch#3573) but two consumer call sites still used the old np.ndarray API:

* `_TRTInterpreter._construct_refit_mapping` filtered scalars with `v.size == 1`. `Tensor.size` is a method, so the comparison was always False and `constant_mapping` was always empty; scalar constants never reached the cached `weight_name_map["constant_mapping"]`. Fixed by switching to `v.numel() == 1`.
* `_refit_single_trt_engine_with_gm` rehydrated those values via `torch.from_numpy(val).cuda()`, which raises TypeError on a Tensor. Fixed by using `val.cuda()` directly and renaming the local from `np_weight_type` to `weight_dtype` to reflect the actual type.

With both fixes, the engine-cache hit + fast-refit path now covers scalar constants on TRT-RTX without falling back to GraphModule.forward; the formerly-skipped refit tests pass.
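A minimal, self-contained sketch of the first pitfall (no torch_tensorrt internals involved; variable names are illustrative):

```python
import numpy as np
import torch

v = torch.tensor(1e-5)  # scalar constant, e.g. a batch-norm eps

# On torch.Tensor, .size is a bound method, so comparing it to an int
# is always False -- the old np.ndarray-style filter silently drops everything:
print(v.size == 1)      # False
print(v.numel() == 1)   # True: the correct scalar check for Tensor values

# On np.ndarray, .size is a plain int attribute, which is why the old check
# worked before weight_refit_map switched to torch.Tensor in pytorch#3573:
a = np.array(1e-5)
print(a.size == 1)      # True
```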
tp5uiuc commented on May 6, 2026
```diff
  weight_name_map: dict[str, Any] = {}
  weight_refit_map = self.ctx.weight_refit_map
- constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
+ constant_mapping = {k: v for k, v in weight_refit_map.items() if v.numel() == 1}
```
This path was exercised with
- tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_dynamo_compile_with_custom_engine_cache
- tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_torch_compile_with_default_disk_engine_cache (XFAILED)

in the RTX path only. In the earlier diff, `Tensor.size` is a method (and it was not called), so the comparison was always False and `constant_mapping` was always empty. The current PR fixes this to allow CI to pass.
tp5uiuc commented on May 6, 2026
Comment on lines +180 to +183
```diff
+ weight_dtype = val.dtype
+ val_tensor = val.cuda()
+ trt_dtype = dtype._from(weight_dtype).to(trt.DataType)
+ torch_dtype = dtype._from(weight_dtype).to(torch.dtype)
```
As a result of fix #1 in https://github.com/pytorch/TensorRT/pull/4192/changes#r3192971184 and of the previous commit (0273726), all `constant_mapping` items are torch tensors, not np arrays.
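A sketch of the before/after behavior at this call site (standalone; `val` stands in for a `constant_mapping` value, and the torch_tensorrt-specific dtype conversion from the diff above is elided):

```python
import torch

# A constant_mapping value is now a torch.Tensor (post-#3573), e.g. batch-norm eps:
val = torch.tensor(1e-5)

# Old code assumed np.ndarray and fails on a Tensor:
#   torch.from_numpy(val).cuda()
#   -> TypeError: expected np.ndarray (got Tensor)

# New code keeps the dtype and moves the Tensor to the GPU directly:
weight_dtype = val.dtype
val_tensor = val.cuda() if torch.cuda.is_available() else val
```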
Description
Two related cleanups that together let the previously-skipped RTX refit and engine cache tests run.
Test changes: #4181 removed the RTX-specific batch-norm workaround that bypassed constant folding, so the original refit bug (#3752) signature, "eps as a separate non-refittable layer", is gone. Removing the `@unittest.skipIf(tensorrt_rtx, ...)` decorators on those tests, however, exposed a different latent bug along the fast-refit path that prevented the tests from actually exercising refit (they fell back to GraphModule.forward).

Library fix: in #3573 (June 2025) `ctx.weight_refit_map` switched from `np.ndarray` to `torch.Tensor`, but two consumer call sites kept the old `np.ndarray` API. Both went unnoticed until refit on RTX surfaced them:

1. `_TRTInterpreter._construct_refit_mapping` filtered scalar constants with `v.size == 1`. `Tensor.size` is a method, so the comparison was always False and `constant_mapping` was always empty; scalar constants like batch-norm `eps` never reached the cached `weight_name_map["constant_mapping"]`. Standard TRT happened to mask this because `refitter.get_missing_weights()` does not list these constants; on TRT-RTX, the stricter unset-weights check (fix: detect incomplete fast refit on TRT-RTX via unset weights check, #4198) flagged all of them.
2. `_refit_single_trt_engine_with_gm` rehydrated those values via `torch.from_numpy(val).cuda()`, which raises `TypeError: expected np.ndarray (got Tensor)`. This was hidden behind (1): once `constant_mapping` actually had entries, the TypeError surfaced.

Both fixes are minimal and consistent with the post-#3573 `torch.Tensor` storage contract.

Fixes #3752
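For context, a sketch of the skip pattern the test changes remove; the feature flag and the old message are taken from the diff earlier in this thread, while the class and test body are abbreviated:

```python
import unittest

import torch_tensorrt as torch_trt


class TestEngineCache(unittest.TestCase):
    # Decorators of this shape were removed from the re-enabled tests:
    @unittest.skipIf(
        torch_trt.ENABLED_FEATURES.tensorrt_rtx,
        "There is bug in refit, so we skip the test for now",
    )
    def test_dynamo_compile_with_custom_engine_cache(self):
        ...
```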
Changes
Library
- `_TRTInterpreter.py`: `if v.size == 1` → `if v.numel() == 1` so scalar constants are kept in `constant_mapping`.
- `_refit.py`: scalar `constant_mapping` rehydration uses `val.cuda()` directly instead of `torch.from_numpy(val).cuda()`. Local renamed `np_weight_type` → `weight_dtype` to reflect the actual type.

Tests
- Removed the RTX skip from `test_dynamo_compile_with_refittable_weight_stripped_engine` (test_weight_stripped_engine.py).
- Removed the RTX skips from `test_dynamo_compile_with_custom_engine_cache` and `test_dynamo_compile_change_input_shape` (test_engine_cache.py).
- Kept the RTX skip on `test_caching_small_model`: this test fails a timing assertion (cached compilation is slower than uncached on TRT-RTX). Updated the skip message to reflect the actual reason rather than referencing 🐛 [Bug] TensorRT-RTX Refitter test failed when constant fold is disabled #3752.
- `test_weight_stripped_engine.py`: `import tensorrt as trt` must come after `import torch_tensorrt` so the `tensorrt_rtx` module alias is resolved. Added `# isort: skip` to prevent automated reordering; see the sketch below.
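A minimal sketch of the resulting import order (module names from the PR; the alias-resolution detail is as described above):

```python
import torch_tensorrt  # must come first so the tensorrt_rtx module alias resolves

import tensorrt as trt  # isort: skip
```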
Verification (TRT-RTX, A100, nightly torch_tensorrt_rtx 2.13.0.dev20260505+cu130)

- `test_torch_compile_with_default_disk_engine_cache` (xfail): `AssertionError: Fast refit failed on TensorRT-RTX: 20 of 20 engine weight(s) had no entry in weight_name_map`
- `test_dynamo_compile_with_custom_engine_cache`: `TypeError: expected np.ndarray (got Tensor)` at `_refit.py:181`
- `test_dynamo_compile_change_input_shape`
- `test_caching_small_model`

Type of change
Checklist: