Commit 6faa897
fix(refit): use torch.Tensor APIs in scalar constant_mapping path
The fast-refit path on TensorRT-RTX was failing with
"Fast refit failed on TensorRT-RTX: N of N engine weight(s) had no
entry in weight_name_map" for any model containing scalar constants
(e.g. batch-norm eps), because `weight_refit_map` values are
torch.Tensor (since #3573) but two consumer call sites still used
the old np.ndarray API:
* _TRTInterpreter._construct_refit_mapping filtered scalars with
  `v.size == 1`. `Tensor.size` is a method, so the comparison was
  always False and `constant_mapping` was always empty -- scalar
  constants never reached the cached `weight_name_map["constant_mapping"]`.
  Fixed by switching to `v.numel() == 1`.
* _refit_single_trt_engine_with_gm rehydrated those values via
  `torch.from_numpy(val).cuda()`, which raises TypeError on a
  Tensor. Fixed by using `val.cuda()` directly and renaming the
  local from `np_weight_type` to `weight_dtype` to reflect the
  actual type.
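
The first bullet's failure mode can be reproduced in isolation. The sketch below is illustrative only, not the project's code: `FakeTensor` is a hypothetical stand-in that mimics just the two `torch.Tensor` members involved (`size` as a method, `numel()` as an element count) so the snippet runs without torch installed, and the keys `bn.eps` and `conv.weight` are made-up names. A real `torch.Tensor` behaves the same way for this comparison.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (assumption, not the real class)."""

    def __init__(self, shape):
        self._shape = tuple(shape)

    def size(self):
        # On a tensor, `size` is a method, unlike the np.ndarray `size` attribute.
        return self._shape

    def numel(self):
        # Total element count; an empty shape (0-dim scalar) yields 1.
        count = 1
        for dim in self._shape:
            count *= dim
        return count


# Made-up map contents: one scalar constant and one ordinary weight.
weight_refit_map = {
    "bn.eps": FakeTensor(()),
    "conv.weight": FakeTensor((8, 3, 3, 3)),
}

# Old filter: compares a bound method to 1, which is never equal.
old_constant_mapping = {
    k: v for k, v in weight_refit_map.items() if v.size == 1
}

# Fixed filter: compares the element count to 1.
new_constant_mapping = {
    k: v for k, v in weight_refit_map.items() if v.numel() == 1
}

print(sorted(old_constant_mapping))
print(sorted(new_constant_mapping))
```

The old filter silently selects nothing instead of raising, which is why the bug only surfaced later as missing `weight_name_map` entries.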
With both fixes, the engine-cache hit + fast-refit path now
covers scalar constants on TRT-RTX without falling back to
GraphModule.forward; the formerly-skipped refit tests pass.

1 parent ce60a5d commit 6faa897
2 files changed
Lines changed: 5 additions & 5 deletions
File 1 (name not captured): lines 180-183 replaced, 4 deletions and 4 additions; diff content not captured.
File 2 (name not captured): line 489 replaced, 1 deletion and 1 addition; diff content not captured.