Skip to content

Commit 0bef55b

Browse files
tests: Adding additional test cases for the unowned tensor feature (#3993)
Co-authored-by: cehongwang <wangcehong@gmail.com>
1 parent 99660e6 commit 0bef55b

File tree

2 files changed

+243
-7
lines changed

2 files changed

+243
-7
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,10 +1084,11 @@ def preserve_module_specs(
10841084

10851085
output_node = list(partitioned_module.graph.nodes)[-1]
10861086
for arg in output_node.args:
1087-
target = arg[0].target
1088-
if "_run_on_acc" not in str(target):
1089-
continue
1090-
getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
1087+
for output in arg:
1088+
target = output.target
1089+
if "_run_on_acc" not in str(target):
1090+
continue
1091+
getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
10911092

10921093
# Reset settings object to user specification after fallback to global partitioning mode
10931094
if fast_partitioner_failed:

tests/py/dynamo/runtime/test_pre_allocated_outputs.py

Lines changed: 238 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def forward(self, x):
125125
)
126126
torch._dynamo.reset()
127127

128-
def test_pre_allocated_outputs_unowned_outputs(self):
128+
def test_pre_allocated_outputs_unowned_outputs_py_api_check_no_realloc(self):
129129
class SampleModel(torch.nn.Module):
130130
def forward(self, x):
131131
return torch.softmax(x * 7 + 2, dim=0)
@@ -146,21 +146,256 @@ def forward(self, x):
146146
)
147147

148148
with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
149-
optimized_model(inputs[0])
149+
_ = optimized_model(inputs[0])
150150
output_tensors = [
151151
trt_mod.pre_allocated_outputs
152152
for name, trt_mod in optimized_model.named_children()
153153
if "_run_on_acc" in name
154154
]
155-
optimized_model(inputs[0])
155+
_ = optimized_model(inputs[0])
156156
new_output_tensors = [
157157
trt_mod.pre_allocated_outputs
158158
for name, trt_mod in optimized_model.named_children()
159159
if "_run_on_acc" in name
160160
]
161+
162+
# Run to run, output of intermediate engine is not reallocated
161163
self.assertTrue(output_tensors[0] is new_output_tensors[0])
164+
# Run to run, output of output engine is reallocated
162165
self.assertTrue(output_tensors[1] is not new_output_tensors[1])
163166

167+
@parameterized.expand(
    [
        ("python_runtime", True),
        ("cpp_runtime", False),
    ]
)
def test_pre_allocated_outputs_unowned_outputs_api_check(
    self, _, use_python_runtime
):
    """Check the unowned-output flag reported by each TRT submodule.

    The model is partitioned (aten.add is forced to run in Torch) into an
    intermediate engine (`_run_on_acc_0`) and an output engine
    (`_run_on_acc_2`); only the engine producing the graph outputs should
    report its output tensors as unowned.
    """

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            return torch.softmax(x * 7 + 2, dim=0)

    model = SampleModel().eval().cuda()
    inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
    fx_graph = torch.fx.symbolic_trace(model)

    # Force a graph break on aten.add so two TRT engines are produced.
    optimized_model = torchtrt.compile(
        fx_graph,
        "dynamo",
        inputs[0],
        min_block_size=1,
        pass_through_build_failures=True,
        use_python_runtime=use_python_runtime,
        torch_executed_ops={torch.ops.aten.add.Tensor},
    )

    with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
        _ = optimized_model(inputs[0])
        # assertEqual (rather than assertTrue(all(...))) reports the
        # actual flag values when the check fails.
        self.assertEqual(
            [
                optimized_model._run_on_acc_0.are_output_tensors_unowned(),
                optimized_model._run_on_acc_2.are_output_tensors_unowned(),
            ],
            [False, True],
        )
209+
210+
@parameterized.expand(
    [
        ("python_runtime", True),
        ("cpp_runtime", False),
    ]
)
def test_pre_allocated_outputs_unowned_outputs(self, _, use_python_runtime):
    """End-to-end correctness of unowned outputs under pre-allocation.

    Compares the compiled model against eager Torch and verifies that two
    consecutive pre-allocated-output runs produce identical results.
    """

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            return torch.softmax(x * 7 + 2, dim=0)

    model = SampleModel().eval().cuda()
    inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
    traced = torch.fx.symbolic_trace(model)

    # Break the graph on aten.add so the module is split across engines.
    optimized_model = torchtrt.compile(
        traced,
        "dynamo",
        inputs[0],
        min_block_size=1,
        pass_through_build_failures=True,
        use_python_runtime=use_python_runtime,
        torch_executed_ops={torch.ops.aten.add.Tensor},
    )

    reference = model(inputs[0])

    with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
        first_run = optimized_model(inputs[0])
        second_run = optimized_model(inputs[0])

    # Results are correct
    torch.testing.assert_close(
        reference,
        first_run,
        rtol=5e-03,
        atol=5e-03,
        equal_nan=True,
        check_dtype=True,
    )

    # Results between runs are identical
    torch.testing.assert_close(
        first_run,
        second_run,
        rtol=5e-03,
        atol=5e-03,
        equal_nan=True,
        check_dtype=True,
    )

    torch._dynamo.reset()
263+
264+
def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_py_api_check_no_realloc(
    self,
):
    """Check buffer reallocation across runs for a multi-output model.

    With y, z, a all returned from the graph, every TRT engine feeds a
    graph output, so each engine's pre-allocated output buffers are
    expected to be reallocated (new data pointers) on every run.
    Python-runtime only: inspects `pre_allocated_outputs` directly.
    """

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            y = torch.ops.aten.mul(x, 7)
            z = torch.ops.aten.add(y, 2)
            a = torch.ops.aten.softmax(z, dim=0)
            return y, z, a

    model = SampleModel().eval().cuda()
    inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
    fx_graph = torch.fx.symbolic_trace(model)

    # Force a graph break on aten.add so multiple TRT engines are produced.
    optimized_model = torchtrt.compile(
        fx_graph,
        "dynamo",
        inputs[0],
        min_block_size=1,
        pass_through_build_failures=True,
        use_python_runtime=True,
        torch_executed_ops={torch.ops.aten.add.Tensor},
    )

    with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
        # Result is discarded; only the cached buffer pointers matter here.
        _ = optimized_model(inputs[0])
        output_tensors = [
            [t.data_ptr() for t in trt_mod.pre_allocated_outputs]
            for name, trt_mod in optimized_model.named_children()
            if "_run_on_acc" in name
        ]

        _ = optimized_model(inputs[0])
        new_output_tensors = [
            [t.data_ptr() for t in trt_mod.pre_allocated_outputs]
            for name, trt_mod in optimized_model.named_children()
            if "_run_on_acc" in name
        ]

        # Run to run, output of intermediate engine is reallocated
        # (assertNotEqual reports both pointer lists on failure).
        self.assertNotEqual(output_tensors[0], new_output_tensors[0])
        # Run to run, output of output engine is reallocated
        self.assertNotEqual(output_tensors[1], new_output_tensors[1])
308+
309+
@parameterized.expand(
    [
        ("python_runtime", True),
        ("cpp_runtime", False),
    ]
)
def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_api_check(
    self, _, use_python_runtime
):
    """Check the unowned-output flag for a multi-output model.

    Since y, z, and a are all graph outputs, both TRT engines produce
    graph outputs and both should report their output tensors as unowned.
    """

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            y = torch.ops.aten.mul(x, 7)
            z = torch.ops.aten.add(y, 2)
            a = torch.ops.aten.softmax(z, dim=0)
            return y, z, a

    model = SampleModel().eval().cuda()
    inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
    fx_graph = torch.fx.symbolic_trace(model)

    # Force a graph break on aten.add so two TRT engines are produced.
    optimized_model = torchtrt.compile(
        fx_graph,
        "dynamo",
        inputs[0],
        min_block_size=1,
        pass_through_build_failures=True,
        use_python_runtime=use_python_runtime,
        torch_executed_ops={torch.ops.aten.add.Tensor},
    )

    with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
        _ = optimized_model(inputs[0])
        # assertEqual (rather than assertTrue(all(...))) reports the
        # actual flag values when the check fails.
        self.assertEqual(
            [
                optimized_model._run_on_acc_0.are_output_tensors_unowned(),
                optimized_model._run_on_acc_2.are_output_tensors_unowned(),
            ],
            [True, True],
        )
354+
355+
@parameterized.expand(
    [
        ("python_runtime", True),
        ("cpp_runtime", False),
    ]
)
def test_pre_allocated_outputs_unowned_outputs_multi_outputs(
    self, _, use_python_runtime
):
    """Two pre-allocated-output runs of a multi-output model must agree."""

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            y = torch.ops.aten.mul(x, 7)
            z = torch.ops.aten.add(y, 2)
            a = torch.ops.aten.softmax(z, dim=0)
            return y, z, a

    model = SampleModel().eval().cuda()
    inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
    traced = torch.fx.symbolic_trace(model)

    # Break the graph on aten.add so multiple TRT engines are produced.
    optimized_model = torchtrt.compile(
        traced,
        "dynamo",
        inputs[0],
        min_block_size=1,
        pass_through_build_failures=True,
        use_python_runtime=use_python_runtime,
        torch_executed_ops={torch.ops.aten.add.Tensor},
    )

    with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
        first = optimized_model(inputs[0])
        second = optimized_model(inputs[0])

    # Run-to-run results must be identical with pre-allocated outputs.
    torch.testing.assert_close(
        first,
        second,
        rtol=5e-03,
        atol=5e-03,
        equal_nan=True,
        check_dtype=True,
    )
398+
164399
torch._dynamo.reset()
165400

166401

0 commit comments

Comments (0)