
fix(test): enable TRT-RTX refit and engine cache tests#4192

Open

tp5uiuc wants to merge 4 commits into pytorch:main from tp5uiuc:fix/enable-rtx-refit-tests

Conversation


@tp5uiuc (Contributor) commented Apr 16, 2026

Description

Two related cleanups that together let the previously-skipped RTX refit and engine cache tests run.

Test changes: #4181 removed the RTX-specific batch-norm workaround that bypassed constant folding, so the original refit bug (#3752) signature — "eps as a separate non-refittable layer" — is gone. Removing the @unittest.skipIf(tensorrt_rtx, ...) decorators on those tests, however, exposed a different latent bug along the fast-refit path that prevented the tests from actually exercising refit (they fell back to GraphModule.forward).

Library fix: in #3573 (June 2025) ctx.weight_refit_map switched from np.ndarray to torch.Tensor, but two consumer call sites kept the old np.ndarray API. Both went unnoticed until refit on RTX surfaced them:

  1. _TRTInterpreter._construct_refit_mapping filtered scalar constants with v.size == 1. Tensor.size is a method, so the comparison was always False and constant_mapping was always empty: scalar constants like batch-norm eps never reached the cached weight_name_map["constant_mapping"]. Standard TRT happened to mask this because refitter.get_missing_weights() does not list these constants; on TRT-RTX, the stricter unset_weights check (#4198, "fix: detect incomplete fast refit on TRT-RTX via unset weights check") flagged all of them.
  2. _refit_single_trt_engine_with_gm rehydrated those values via torch.from_numpy(val).cuda(), which raises TypeError: expected np.ndarray (got Tensor). This was hidden behind (1): once constant_mapping actually had entries, the TypeError surfaced. A minimal repro of both pitfalls follows this list.
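A minimal, self-contained sketch of both pitfalls, using plain torch with no Torch-TensorRT internals (the eps tensor is an illustrative stand-in):

```python
import torch

eps = torch.tensor(1e-5)  # stand-in for a scalar constant such as batch-norm eps

# Pitfall 1: Tensor.size is a bound method, so `eps.size == 1` compares a
# method object to an int and is always False, even for a scalar tensor.
print(eps.size == 1)     # False
print(eps.numel() == 1)  # True: the intended scalar check

# Pitfall 2: torch.from_numpy() accepts only np.ndarray; passing a Tensor
# raises "TypeError: expected np.ndarray (got Tensor)".
try:
    torch.from_numpy(eps).cuda()
except TypeError as err:
    print(err)

# Post-#3573 the stored values are already torch.Tensors, so rehydration
# is simply a device move:
val_tensor = eps.cuda() if torch.cuda.is_available() else eps
```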

Both fixes are minimal and consistent with the post-#3573 torch.Tensor storage contract.

Fixes #3752

Changes

Library

  • _TRTInterpreter.py: if v.size == 1 → if v.numel() == 1 so scalar constants are kept in constant_mapping.
  • _refit.py: scalar constant_mapping rehydration uses val.cuda() directly instead of torch.from_numpy(val).cuda(). Local renamed np_weight_type → weight_dtype to reflect the actual type.

Tests

  • Removed RTX skip from test_dynamo_compile_with_refittable_weight_stripped_engine (test_weight_stripped_engine.py)
  • Removed RTX skip from test_dynamo_compile_with_custom_engine_cache and test_dynamo_compile_change_input_shape (test_engine_cache.py)
  • Kept RTX skip on test_caching_small_model: this test fails a timing assertion (cached compilation is slower than uncached on TRT-RTX). Updated the skip message to reflect the actual reason rather than referencing #3752 (🐛 [Bug] TensorRT-RTX Refitter test failed when constant fold is disabled).
  • Fixed import ordering in test_weight_stripped_engine.py: import tensorrt as trt must come after import torch_tensorrt so the tensorrt_rtx module alias is resolved. Added # isort: skip to prevent automated reordering (sketch below).
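A sketch of the resulting import order (the torch_trt alias matches its usage elsewhere in this PR; the surrounding imports are assumptions):

```python
import torch
import torch_tensorrt as torch_trt

# tensorrt must come after torch_tensorrt so that, on RTX builds, this
# import resolves through the tensorrt_rtx module alias instead of
# standard TensorRT.
import tensorrt as trt  # isort: skip
```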

Verification (TRT-RTX, A100, nightly torch_tensorrt_rtx 2.13.0.dev20260505+cu130)

| Test | Before fix | After fix |
| --- | --- | --- |
| test_torch_compile_with_default_disk_engine_cache (xfail) | XPASS, but captured log shows AssertionError: Fast refit failed on TensorRT-RTX: 20 of 20 engine weight(s) had no entry in weight_name_map | XPASS, no fast-refit assertion, no "is not found in weight mapping" warnings |
| test_dynamo_compile_with_custom_engine_cache | FAILED: TypeError: expected np.ndarray (got Tensor) at _refit.py:181 | PASSED |
| test_dynamo_compile_change_input_shape | PASSED | PASSED |
| test_caching_small_model | SKIPPED (timing assertion) | SKIPPED (timing assertion) |

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@meta-cla meta-cla Bot added the cla signed label Apr 16, 2026
@github-actions github-actions Bot added the component: tests label Apr 16, 2026
@github-actions github-actions Bot requested a review from lanluo-nvidia April 16, 2026 07:03
@tp5uiuc tp5uiuc force-pushed the fix/enable-rtx-refit-tests branch from d39be11 to fe174e7 Compare April 16, 2026 17:45
@tp5uiuc tp5uiuc marked this pull request as ready for review April 16, 2026 17:45
@tp5uiuc tp5uiuc force-pushed the fix/enable-rtx-refit-tests branch from f3e7ccd to 048d538 Compare April 23, 2026 18:11
@cehongwang (Collaborator) left a comment

LGTM

Comment thread tests/py/dynamo/models/test_weight_stripped_engine.py Outdated
@tp5uiuc tp5uiuc force-pushed the fix/enable-rtx-refit-tests branch from a5fd05c to c6497b4 Compare April 29, 2026 06:36
```diff
     torch_trt.ENABLED_FEATURES.tensorrt_rtx,
-    # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
-    "There is bug in refit, so we skip the test for now",
+    "Engine caching compilation time assertion is unreliable with TensorRT-RTX",
```
@tp5uiuc (Contributor, Author) commented:

Confirmed this is fixed in 1.5. TRT-RTX doesn't cache refit graphs/kernels, so upon refitting, we are essentially recompiling kernels. Refit kernels are now generated AoT with v1.5, so caching behavior can be restored.

@zewenli98 (Collaborator) left a comment

LGTM. waiting for CI

tp5uiuc added 3 commits May 5, 2026 12:38
Now that pytorch#4181 removed the RTX-specific batch norm workaround that
bypassed constant folding, the refit bug (pytorch#3752) is resolved — eps
constants are no longer created as separate CONSTANT layers on RTX.

Remove the RTX skip decorators from:
- test_dynamo_compile_with_refittable_weight_stripped_engine
- test_dynamo_compile_with_custom_engine_cache
- test_dynamo_compile_change_input_shape

Keep the RTX skip on test_caching_small_model, which fails a timing
assertion (cached compilation is slower than uncached on RTX). Update
the skip message to reflect the actual reason.

Fix import ordering in test_weight_stripped_engine.py: tensorrt must
be imported after torch_tensorrt so the tensorrt_rtx module alias is
resolved correctly.

Fixes pytorch#3752
Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
@tp5uiuc tp5uiuc force-pushed the fix/enable-rtx-refit-tests branch from c6497b4 to ce60a5d Compare May 5, 2026 19:38
The fast-refit path on TensorRT-RTX was failing with
"Fast refit failed on TensorRT-RTX: N of N engine weight(s) had no
entry in weight_name_map" for any model containing scalar constants
(e.g. batch-norm eps), because `weight_refit_map` values are
torch.Tensor (since pytorch#3573) but two consumer call sites still used
the old np.ndarray API:

* _TRTInterpreter._construct_refit_mapping filtered scalars with
  `v.size == 1`. `Tensor.size` is a method, so the comparison was
  always False and `constant_mapping` was always empty -- scalar
  constants never reached the cached `weight_name_map["constant_mapping"]`.
  Fixed by switching to `v.numel() == 1`.

* _refit_single_trt_engine_with_gm rehydrated those values via
  `torch.from_numpy(val).cuda()`, which raises TypeError on a
  Tensor. Fixed by using `val.cuda()` directly and renaming the
  local from `np_weight_type` to `weight_dtype` to reflect the
  actual type.

With both fixes, the engine-cache hit + fast-refit path now
covers scalar constants on TRT-RTX without falling back to
GraphModule.forward; the formerly-skipped refit tests pass.
@github-actions github-actions Bot added the component: conversion, component: core, component: api [Python], and component: dynamo labels May 6, 2026
```diff
 weight_name_map: dict[str, Any] = {}
 weight_refit_map = self.ctx.weight_refit_map
-constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
+constant_mapping = {k: v for k, v in weight_refit_map.items() if v.numel() == 1}
```
@tp5uiuc (Contributor, Author) commented:

This path was exercised, in the RTX path only, with:

  • tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_dynamo_compile_with_custom_engine_cache
  • tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_torch_compile_with_default_disk_engine_cache (XFAILED)

In the earlier diff, Tensor.size is a method (and it was never called), so the comparison was always False and constant_mapping was always empty. The current PR fixes this so CI can pass.

Comment on lines +180 to +183
```python
weight_dtype = val.dtype
val_tensor = val.cuda()
trt_dtype = dtype._from(weight_dtype).to(trt.DataType)
torch_dtype = dtype._from(weight_dtype).to(torch.dtype)
```
@tp5uiuc (Contributor, Author) commented:

As a result of fix #1 in https://github.com/pytorch/TensorRT/pull/4192/changes#r3192971184 and of the previous commit (0273726), all constant_mapping items are torch.Tensors, not np.ndarrays.
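For contrast, a hedged sketch of the old vs. new rehydration at this site (the four diffed lines above and the np_weight_type → weight_dtype rename from the PR body are from the source; everything else is assumed context):

```python
# Before: treated val as an np.ndarray. Since #3573 it is a torch.Tensor,
# so this raised "TypeError: expected np.ndarray (got Tensor)".
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()

# After: val is already a torch.Tensor; move it to the GPU directly.
weight_dtype = val.dtype
val_tensor = val.cuda()
```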


Labels

  • cla signed
  • component: api [Python] (Issues re: Python API)
  • component: conversion (Issues re: Conversion stage)
  • component: core (Issues re: The core compiler)
  • component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
  • component: tests (Issues re: Tests)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛 [Bug] TensorRT-RTX Refitter test failed when constant fold is disabled

3 participants