Commit 6faa897
fix(refit): use torch.Tensor APIs in scalar constant_mapping path
The fast-refit path on TensorRT-RTX was failing with "Fast refit failed on TensorRT-RTX: N of N engine weight(s) had no entry in weight_name_map" for any model containing scalar constants (e.g. batch-norm eps). Since #3573, `weight_refit_map` values are torch.Tensor, but two consumer call sites still used the old np.ndarray API:

* `_TRTInterpreter._construct_refit_mapping` filtered scalars with `v.size == 1`. `Tensor.size` is a method, so the comparison was always False and `constant_mapping` was always empty; scalar constants never reached the cached `weight_name_map["constant_mapping"]`. Fixed by switching to `v.numel() == 1`.

* `_refit_single_trt_engine_with_gm` rehydrated those values via `torch.from_numpy(val).cuda()`, which raises TypeError on a Tensor. Fixed by using `val.cuda()` directly and renaming the local from `np_weight_type` to `weight_dtype` to reflect the actual type.

With both fixes, the engine-cache-hit + fast-refit path now covers scalar constants on TRT-RTX without falling back to GraphModule.forward, and the formerly-skipped refit tests pass.
1 parent ce60a5d commit 6faa897

2 files changed: 5 additions & 5 deletions

py/torch_tensorrt/dynamo/_refit.py (4 additions & 4 deletions)
@@ -177,10 +177,10 @@ def _refit_single_trt_engine_with_gm(
     constant_mapping_with_type = {}

     for constant_name, val in constant_mapping.items():
-        np_weight_type = val.dtype
-        val_tensor = torch.from_numpy(val).cuda()
-        trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
+        weight_dtype = val.dtype
+        val_tensor = val.cuda()
+        trt_dtype = dtype._from(weight_dtype).to(trt.DataType)
+        torch_dtype = dtype._from(weight_dtype).to(torch.dtype)
         constant_mapping_with_type[constant_name] = (
             val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
             trt_dtype,
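The rehydration bug fixed above can be reproduced without a GPU: `torch.from_numpy` only accepts an np.ndarray and raises TypeError when handed a torch.Tensor, which is what `weight_refit_map` has stored since #3573. A minimal sketch (using the plain Tensor in place of `.cuda()` so it runs on CPU):

```python
import torch

# Since #3573, constant_mapping values are torch.Tensor, not np.ndarray.
val = torch.tensor(1e-5)  # e.g. a batch-norm eps scalar constant

# Old path: assumed NumPy input, fails on a Tensor.
failed = False
try:
    torch.from_numpy(val)
except TypeError:
    failed = True

# Fixed path: the value already is a Tensor, so use it directly
# (the real code moves it to the GPU with val.cuda()).
val_tensor = val

print(failed)               # True: from_numpy rejects a Tensor
print(val_tensor.numel())   # 1
```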

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (1 addition & 1 deletion)
@@ -486,7 +486,7 @@ def _save_weight_mapping(self) -> None:
         sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
-        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
+        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.numel() == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
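The filter bug is easy to demonstrate: `ndarray.size` is an integer attribute, but `Tensor.size` is a bound method, so comparing it to 1 is always False and the comprehension silently produced an empty dict. A minimal sketch with hypothetical map keys:

```python
import numpy as np
import torch

arr = np.array(1e-5)
t = torch.tensor(1e-5)

# ndarray.size is an int; Tensor.size is a method object.
print(arr.size == 1)   # True
print(t.size == 1)     # False -- a method never equals 1
print(t.numel() == 1)  # True: the correct Tensor API

# With the fixed predicate, scalar constants are kept
# (keys here are illustrative, not the real map contents).
refit_map = {"bn.eps": t, "conv.weight": torch.ones(3, 3)}
constant_mapping = {k: v for k, v in refit_map.items() if v.numel() == 1}
print(list(constant_mapping))  # ['bn.eps']
```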
