|
| 1 | +"""Regression test for the C++ runtime use-after-free of contiguified |
| 2 | +input copies (core/runtime/execute_engine.cpp). |
| 3 | +
|
| 4 | +Bug: setup_input_tensors stashed `.contiguous()` copies in a function- |
| 5 | +local std::list. After return, those copies were freed; the CUDA caching |
| 6 | +allocator recycled their addresses for the engine's output buffer when |
| 7 | +shape+dtype matched. TRT's input bindings then aliased onto outputs, |
| 8 | +corrupting reads-after-writes inside the engine. |
| 9 | +
|
| 10 | +Fix: hoist `formatted_inputs` to a scope that outlives `enqueueV3`. |
| 11 | +
|
| 12 | +Trigger conditions reproduced here: |
| 13 | + - Input is non-contiguous → setup_input_tensors calls .contiguous() |
| 14 | + and allocates a fresh CUDA buffer. |
| 15 | + - Output shape+dtype matches the input → caching allocator can |
| 16 | + recycle the freed input buffer for the output. |
| 17 | + - Model has a residual that re-reads the input at the very end → |
| 18 | + intermediate scratch writes between the early and late reads of x |
| 19 | + can corrupt x via the aliased buffer. |
| 20 | +""" |
| 21 | + |
| 22 | +import torch |
| 23 | +import torch.nn as nn |
| 24 | +import torch_tensorrt as torchtrt |
| 25 | +from parameterized import parameterized |
| 26 | +from torch.testing._internal.common_utils import TestCase, run_tests |
| 27 | + |
| 28 | + |
| 29 | +B, S, H = 1, 4096, 128 |
| 30 | +NUM_BLOCKS = 8 |
| 31 | +NUM_TRIALS = 3 |
| 32 | + |
| 33 | + |
| 34 | +class _ResidualBlock(nn.Module): |
| 35 | + def __init__(self, d): |
| 36 | + super().__init__() |
| 37 | + self.norm = nn.LayerNorm(d) |
| 38 | + self.lin1 = nn.Linear(d, 4 * d, bias=False) |
| 39 | + self.lin2 = nn.Linear(4 * d, d, bias=False) |
| 40 | + |
| 41 | + def forward(self, x): |
| 42 | + return x + self.lin2(torch.nn.functional.silu(self.lin1(self.norm(x)))) |
| 43 | + |
| 44 | + |
| 45 | +class _Model(nn.Module): |
| 46 | + """Small residual stack. The trailing `+ x` forces TRT's engine to |
| 47 | + keep the input alive across the entire computation; the residual |
| 48 | + blocks add intermediate scratch writes between early and late reads. |
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self): |
| 52 | + super().__init__() |
| 53 | + self.blocks = nn.ModuleList([_ResidualBlock(H) for _ in range(NUM_BLOCKS)]) |
| 54 | + |
| 55 | + def forward(self, x): |
| 56 | + h = x |
| 57 | + for b in self.blocks: |
| 58 | + h = b(h) |
| 59 | + return h + x |
| 60 | + |
| 61 | + |
| 62 | +def _make_noncontig_input(seed=0): |
| 63 | + g = torch.Generator(device="cuda").manual_seed(seed) |
| 64 | + base = torch.randn(B, H, S, device="cuda", dtype=torch.bfloat16, generator=g) |
| 65 | + x = base.transpose(1, 2) |
| 66 | + assert tuple(x.shape) == (B, S, H) |
| 67 | + assert not x.is_contiguous() |
| 68 | + return x |
| 69 | + |
| 70 | + |
| 71 | +class TestInputLifetime(TestCase): |
| 72 | + @parameterized.expand( |
| 73 | + [ |
| 74 | + ("cpp_runtime", False), |
| 75 | + ("python_runtime", True), |
| 76 | + ] |
| 77 | + ) |
| 78 | + def test_noncontig_input_matching_output_shape(self, _name, use_python_runtime): |
| 79 | + torch.manual_seed(0) |
| 80 | + model = _Model().to(device="cuda", dtype=torch.bfloat16).eval() |
| 81 | + x = _make_noncontig_input(seed=1) |
| 82 | + |
| 83 | + with torch.inference_mode(): |
| 84 | + eager_out = model(x) |
| 85 | + |
| 86 | + torch._dynamo.reset() |
| 87 | + compiled = torch.compile( |
| 88 | + model, |
| 89 | + backend="tensorrt", |
| 90 | + fullgraph=False, |
| 91 | + options={ |
| 92 | + "truncate_double": True, |
| 93 | + "enabled_precisions": {torch.float, torch.half, torch.bfloat16}, |
| 94 | + "min_block_size": 1, |
| 95 | + "optimization_level": 1, |
| 96 | + "enable_resource_partitioning": True, |
| 97 | + "use_python_runtime": use_python_runtime, |
| 98 | + }, |
| 99 | + ) |
| 100 | + |
| 101 | + with torch.inference_mode(): |
| 102 | + compiled(x) # compile / warmup |
| 103 | + for trial in range(NUM_TRIALS): |
| 104 | + trt_out = compiled(x) |
| 105 | + diff = (eager_out.float() - trt_out.float()).abs() |
| 106 | + mean = diff.mean().item() |
| 107 | + mx = diff.max().item() |
| 108 | + frac_gt_1 = (diff > 1.0).float().mean().item() |
| 109 | + # bf16 vs fp32 numerical drift on this small model is well |
| 110 | + # under these bounds; the pre-fix divergence was mean>1.5, |
| 111 | + # max>10, >1.0 over 60% of values. Tight bounds keep the |
| 112 | + # test sensitive while leaving room for legitimate bf16 noise. |
| 113 | + self.assertLess( |
| 114 | + mean, |
| 115 | + 0.05, |
| 116 | + f"trial {trial}: mean_abs_diff {mean:.4f} suggests input/output aliasing " |
| 117 | + f"(use-after-free of contiguified input copy in execute_engine.cpp)", |
| 118 | + ) |
| 119 | + self.assertLess(mx, 5.0, f"trial {trial}: max_abs_diff {mx:.4f}") |
| 120 | + self.assertLess( |
| 121 | + frac_gt_1, |
| 122 | + 0.005, |
| 123 | + f"trial {trial}: {100*frac_gt_1:.2f}% of outputs differ by >1.0", |
| 124 | + ) |
| 125 | + |
| 126 | + |
| 127 | +if __name__ == "__main__": |
| 128 | + run_tests() |
0 commit comments