
Commit 3b6cdb3

Fix tensor lifetime issue
1 parent 2233edb commit 3b6cdb3

2 files changed

Lines changed: 149 additions & 3 deletions

File tree

core/runtime/execute_engine.cpp

Lines changed: 21 additions & 3 deletions
@@ -92,7 +92,13 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi
     return false;
   }
 
-void setup_input_tensors(
+// Returns the contiguified input tensors. The caller must keep the returned
+// list alive until enqueueV3 completes: the engine is bound to .data_ptr()s
+// inside it, and freshly-allocated contig copies (from .contiguous() on
+// non-contig inputs) would otherwise be freed and their CUDA addresses
+// recycled by the caching allocator for output tensors, aliasing inputs
+// onto outputs and corrupting reads after the first output write.
+std::list<at::Tensor> setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
     bool cudagraphs_enabled,
@@ -169,6 +175,7 @@ void setup_input_tensors(
           "Failed to bind tensor address for " << name);
     }
   }
+  return formatted_inputs;
 }
 
 std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -255,6 +262,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   // Shape tensor CPU buffers must outlive inferShapes() and enqueueV3()
   std::list<std::vector<int64_t>> inputShapeTensorValues;
 
+  // Contiguified input copies must outlive enqueueV3() to prevent input/output
+  // buffer aliasing via CUDA caching-allocator address reuse (see
+  // setup_input_tensors comment).
+  std::list<at::Tensor> formatted_inputs;
+
   // Intialize inputs and outputs to be available throughout the succeeding scopes
   { // Input Setup
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
@@ -263,7 +275,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
     }
 
-    setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues);
+    formatted_inputs =
+        setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues);
     // Check if input shapes can be inferred.
     int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
     std::vector<char const*> names(io_size);
@@ -389,14 +402,19 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   // Shape tensor CPU buffers must outlive inferShapes() and enqueueV3()
   std::list<std::vector<int64_t>> inputShapeTensorValues;
 
+  // Contiguified input copies must outlive enqueueV3() to prevent input/output
+  // buffer aliasing via CUDA caching-allocator address reuse (see
+  // setup_input_tensors comment).
+  std::list<at::Tensor> formatted_inputs;
+
   { // Input Setup
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
     if (compiled_engine->profile_execution) {
       input_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
     }
 
-    setup_input_tensors(inputs, compiled_engine, false, false, inputShapeTensorValues);
+    formatted_inputs = setup_input_tensors(inputs, compiled_engine, false, false, inputShapeTensorValues);
     // Check if input shapes can be inferred.
     int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
     std::vector<char const*> names(io_size);
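
A minimal, standalone sketch (not part of this commit) of the allocator behavior the new setup_input_tensors comment describes. It assumes a CUDA build of PyTorch; whether the final check prints True depends on the caching allocator's block-reuse policy, which is an implementation detail, so treat the pointer equality as likely rather than guaranteed.

import torch

# A non-contiguous view, like the transposed inputs the runtime receives.
x = torch.randn(1, 128, 4096, device="cuda", dtype=torch.bfloat16).transpose(1, 2)
assert not x.is_contiguous()

xc = x.contiguous()   # fresh CUDA buffer, like the copy setup_input_tensors makes
ptr = xc.data_ptr()   # the address an engine binding would capture
del xc                # the copy dies; its block returns to the caching allocator

# A same-shape, same-dtype allocation, like the engine's output tensor.
out = torch.empty(1, 4096, 128, device="cuda", dtype=torch.bfloat16)
print(out.data_ptr() == ptr)  # often True: the "output" reuses the input's address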
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+"""Regression test for the C++ runtime use-after-free of contiguified
+input copies (core/runtime/execute_engine.cpp).
+
+Bug: setup_input_tensors stashed `.contiguous()` copies in a function-
+local std::list. After return, those copies were freed; the CUDA caching
+allocator recycled their addresses for the engine's output buffer when
+shape+dtype matched. TRT's input bindings then aliased onto outputs,
+corrupting reads-after-writes inside the engine.
+
+Fix: hoist `formatted_inputs` to a scope that outlives `enqueueV3`.
+
+Trigger conditions reproduced here:
+  - Input is non-contiguous → setup_input_tensors calls .contiguous()
+    and allocates a fresh CUDA buffer.
+  - Output shape+dtype matches the input → caching allocator can
+    recycle the freed input buffer for the output.
+  - Model has a residual that re-reads the input at the very end →
+    intermediate scratch writes between the early and late reads of x
+    can corrupt x via the aliased buffer.
+"""
+
+import torch
+import torch.nn as nn
+import torch_tensorrt as torchtrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+
+B, S, H = 1, 4096, 128
+NUM_BLOCKS = 8
+NUM_TRIALS = 3
+
+
+class _ResidualBlock(nn.Module):
+    def __init__(self, d):
+        super().__init__()
+        self.norm = nn.LayerNorm(d)
+        self.lin1 = nn.Linear(d, 4 * d, bias=False)
+        self.lin2 = nn.Linear(4 * d, d, bias=False)
+
+    def forward(self, x):
+        return x + self.lin2(torch.nn.functional.silu(self.lin1(self.norm(x))))
+
+
+class _Model(nn.Module):
+    """Small residual stack. The trailing `+ x` forces TRT's engine to
+    keep the input alive across the entire computation; the residual
+    blocks add intermediate scratch writes between early and late reads.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.blocks = nn.ModuleList([_ResidualBlock(H) for _ in range(NUM_BLOCKS)])
+
+    def forward(self, x):
+        h = x
+        for b in self.blocks:
+            h = b(h)
+        return h + x
+
+
+def _make_noncontig_input(seed=0):
+    g = torch.Generator(device="cuda").manual_seed(seed)
+    base = torch.randn(B, H, S, device="cuda", dtype=torch.bfloat16, generator=g)
+    x = base.transpose(1, 2)
+    assert tuple(x.shape) == (B, S, H)
+    assert not x.is_contiguous()
+    return x
+
+
+class TestInputLifetime(TestCase):
+    @parameterized.expand(
+        [
+            ("cpp_runtime", False),
+            ("python_runtime", True),
+        ]
+    )
+    def test_noncontig_input_matching_output_shape(self, _name, use_python_runtime):
+        torch.manual_seed(0)
+        model = _Model().to(device="cuda", dtype=torch.bfloat16).eval()
+        x = _make_noncontig_input(seed=1)
+
+        with torch.inference_mode():
+            eager_out = model(x)
+
+        torch._dynamo.reset()
+        compiled = torch.compile(
+            model,
+            backend="tensorrt",
+            fullgraph=False,
+            options={
+                "truncate_double": True,
+                "enabled_precisions": {torch.float, torch.half, torch.bfloat16},
+                "min_block_size": 1,
+                "optimization_level": 1,
+                "enable_resource_partitioning": True,
+                "use_python_runtime": use_python_runtime,
+            },
+        )
+
+        with torch.inference_mode():
+            compiled(x)  # compile / warmup
+            for trial in range(NUM_TRIALS):
+                trt_out = compiled(x)
+                diff = (eager_out.float() - trt_out.float()).abs()
+                mean = diff.mean().item()
+                mx = diff.max().item()
+                frac_gt_1 = (diff > 1.0).float().mean().item()
+                # bf16 vs fp32 numerical drift on this small model is well
+                # under these bounds; the pre-fix divergence was mean>1.5,
+                # max>10, >1.0 over 60% of values. Tight bounds keep the
+                # test sensitive while leaving room for legitimate bf16 noise.
+                self.assertLess(
+                    mean,
+                    0.05,
+                    f"trial {trial}: mean_abs_diff {mean:.4f} suggests input/output aliasing "
+                    f"(use-after-free of contiguified input copy in execute_engine.cpp)",
+                )
+                self.assertLess(mx, 5.0, f"trial {trial}: max_abs_diff {mx:.4f}")
+                self.assertLess(
+                    frac_gt_1,
+                    0.005,
+                    f"trial {trial}: {100*frac_gt_1:.2f}% of outputs differ by >1.0",
+                )
+
+
+if __name__ == "__main__":
+    run_tests()
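
An eager-only toy (an illustration, not how the runtime works internally; TensorRT is not involved) of why the test's model ends with `h + x`: once the output buffer aliases the input, any intermediate write through the output clobbers the `x` that the final residual still needs to read.

import torch

x = torch.arange(8, dtype=torch.float32)
out = x               # simulate the output buffer aliasing the input buffer
h = x * 2.0           # the early read of x is still correct
out.copy_(h)          # an intermediate result is written through "out"
result = h + x        # the late re-read of x now sees h instead of the input
print(result)         # 4*x rather than the correct 3*x: silent corruption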
