Skip to content

Commit 8d7cd50

Browse files
author
Wei
authored
cherry-pick newest change from master branch (#1502)
1 parent 9333ae7 commit 8d7cd50

19 files changed

+451
-203
lines changed

docs/_sources/tutorials/ptq.rst.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f
136136
If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
137137
From here not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
138138
in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on
139-
CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq
139+
CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq
140140

141141
.. _writing_ptq_python:
142142

@@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
194194
calibrator=calibrator)
195195
196196
If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
197-
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
198-
and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py
197+
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py
198+
and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py
199199

200200
Citations
201201
^^^^^^^^^^^

examples/fx/hugging_face_torchdynamo_example.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
)
1616
from transformers import BertConfig, ReformerConfig, XLNetModel, XLNetConfig
1717

18-
import torchdynamo
19-
from torchdynamo.optimizations import backends
20-
from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
21-
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
22-
from torchdynamo.testing import collect_results
23-
from torchdynamo.testing import same
18+
import torch._dynamo as torchdynamo
19+
from torch._dynamo.optimizations import backends
20+
from torch._dynamo.optimizations.training import aot_autograd_debug_strategy1
21+
from torch._dynamo.optimizations.training import aot_autograd_speedup_strategy
22+
from torch._dynamo.testing import collect_results
23+
from torch._dynamo.testing import same
2424

2525
torch.backends.cuda.matmul.allow_tf32 = True
2626

examples/fx/torchdynamo_example.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from dataclasses import dataclass, field, replace
44

55
import torch
6-
import torchdynamo
6+
import torch._dynamo as torchdynamo
77
import torchvision
88
from torch_tensorrt.fx.lower import compile
99
from torch_tensorrt.fx.utils import LowerPrecision
10-
from torchdynamo.optimizations import backends
10+
from torch._dynamo.optimizations import backends
1111

1212
"""
1313
The purpose of this example is to demonstrate the lowering flow to TRT and Torchdynamo

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
trt_transposed_linear,
2626
trt_transposed_matmul,
2727
)
28+
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
2829

2930
_LOGGER: logging.Logger = logging.getLogger(__name__)
3031

@@ -3371,6 +3372,9 @@ def acc_ops_gelu(
33713372
name: str,
33723373
) -> Union[TRTTensor, Sequence[TRTTensor]]:
33733374
input_val = kwargs["input"]
3375+
approximate = kwargs["approximate"]
3376+
if approximate is not "none":
3377+
raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
33743378
if not isinstance(input_val, TRTTensor):
33753379
raise RuntimeError(
33763380
f"GELU received input {input_val} that is not part "

py/torch_tensorrt/fx/lower.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def compile(
3232
module: nn.Module,
3333
input,
34+
min_acc_module_size: int = 10,
3435
max_batch_size: int = 2048,
3536
max_workspace_size=1 << 25,
3637
explicit_batch_dimension=False,
@@ -51,6 +52,7 @@ def compile(
5152
module: Original module for lowering.
5253
input: Input for module.
5354
max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
55+
min_acc_module_size: Minimal number of nodes for an accelerated submodule
5456
max_workspace_size: Maximum size of workspace given to TensorRT.
5557
explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
5658
lower_precision: lower_precision config given to TRTModule.
@@ -70,6 +72,7 @@ def compile(
7072

7173
lower_setting = LowerSetting(
7274
max_batch_size=max_batch_size,
75+
min_acc_module_size=min_acc_module_size,
7376
max_workspace_size=max_workspace_size,
7477
explicit_batch_dimension=explicit_batch_dimension,
7578
lower_precision=lower_precision,
@@ -268,6 +271,7 @@ def __call__(
268271
module: nn.Module,
269272
inputs: Input,
270273
additional_inputs: Optional[Input] = None,
274+
fp16_conversion_fn: Optional[Callable[[Input], Input]] = None,
271275
) -> nn.Module:
272276
lower_setting = self.lower_pass_manager_builder.lower_setting
273277
atol = lower_setting.correctness_atol
@@ -284,9 +288,26 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
284288
== LowerPrecision.FP16
285289
):
286290
module.half()
287-
inputs = tuple(
288-
x.half() if x is not None and x.dtype == torch.float32 else x
289-
for x in inputs
291+
# A custom conversion function can be passed to the lowerer to
292+
# handle inputs with custom types. By default, just handle
293+
# tensors and NoneType.
294+
if fp16_conversion_fn is None:
295+
conversion_fn = (
296+
lambda x: x.half()
297+
if x is not None and x.dtype == torch.float32
298+
else x
299+
)
300+
else:
301+
conversion_fn = fp16_conversion_fn
302+
303+
inputs = tuple(conversion_fn(x) for x in inputs)
304+
if lower_setting.is_aten:
305+
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
306+
inputs, additional_inputs
307+
)
308+
else:
309+
pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
310+
inputs, additional_inputs
290311
)
291312
if lower_setting.is_aten:
292313
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(

py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,51 @@
11
import torch
22
import torch.nn as nn
33
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
4+
from parameterized import param, parameterized
45
from torch.testing._internal.common_utils import run_tests
56
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
67

78

89
class TestCatConverter(AccTestCase):
9-
def test_cat(self):
10+
@parameterized.expand(
11+
[
12+
param("cat", torch.cat),
13+
param("concat", torch.concat),
14+
]
15+
)
16+
def test_cat(self, _, op):
1017
class Cat(nn.Module):
1118
def forward(self, x, y, z):
12-
return torch.cat((x, y, z), 1)
19+
return op((x, y, z), 1)
1320

1421
inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
1522
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
1623

17-
def test_cat_neg(self):
24+
@parameterized.expand(
25+
[
26+
param("cat", torch.cat),
27+
param("concat", torch.concat),
28+
]
29+
)
30+
def test_cat_neg(self, _, op):
1831
class Cat(nn.Module):
1932
def forward(self, x, y, z):
20-
return torch.cat((x, y, z), -1)
33+
return op((x, y, z), -1)
2134

2235
inputs = [torch.randn(1, 2, 3), torch.randn(1, 2, 3), torch.randn(1, 2, 2)]
2336
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
2437

25-
def test_cat_with_dynamic_shape(self):
38+
@parameterized.expand(
39+
[
40+
param("cat", torch.cat),
41+
param("concat", torch.concat),
42+
]
43+
)
44+
def test_cat_with_dynamic_shape(self, _, op):
2645
class Cat(nn.Module):
2746
def forward(self, x, y):
2847
x = x + y
29-
return torch.cat((x, y), 0)
48+
return op((x, y), 0)
3049

3150
input_specs = [
3251
InputTensorSpec(
@@ -42,11 +61,17 @@ def forward(self, x, y):
4261
]
4362
self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat})
4463

45-
def test_cat_with_dynamic_shape_four_dimensions(self):
64+
@parameterized.expand(
65+
[
66+
param("cat", torch.cat),
67+
param("concat", torch.concat),
68+
]
69+
)
70+
def test_cat_with_dynamic_shape_four_dimensions(self, _, op):
4671
class Cat(nn.Module):
4772
def forward(self, x, y):
4873
x = x + y
49-
return torch.cat((x, y), 0)
74+
return op((x, y), 0)
5075

5176
input_specs = [
5277
InputTensorSpec(
@@ -63,6 +88,14 @@ def forward(self, x, y):
6388

6489
self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat})
6590

91+
def test_concat(self):
92+
class Cat(nn.Module):
93+
def forward(self, x, y, z):
94+
return torch.concat((x, y, z), 1)
95+
96+
inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
97+
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
98+
6699

67100
if __name__ == "__main__":
68101
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py

+33
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,39 @@ def forward(self, x):
5757
TestModule(), input_specs, expected_ops={acc_ops.gelu}
5858
)
5959

60+
def test_gelu_module(self):
61+
class TestModule(nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
self.gelu = torch.nn.GELU()
65+
66+
def forward(self, x):
67+
return self.gelu(x)
68+
69+
inputs = [torch.randn(3, 10, 20)]
70+
self.run_test(
71+
TestModule(),
72+
inputs,
73+
expected_ops={acc_ops.gelu},
74+
test_implicit_batch_dim=False,
75+
)
76+
77+
def test_gelu_module_throw(self):
78+
class TestModule(nn.Module):
79+
def __init__(self):
80+
super().__init__()
81+
self.gelu = torch.nn.GELU(approximate="tanh")
82+
83+
def forward(self, x):
84+
return self.gelu(x)
85+
86+
inputs = [torch.randn(3, 10, 20)]
87+
self.run_test_with_assert_error(
88+
TestModule(),
89+
inputs,
90+
expect_error=RuntimeError,
91+
)
92+
6093

6194
if __name__ == "__main__":
6295
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py

-30
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,6 @@
66

77

88
class TestNewOnesConverter(AccTestCase):
9-
def test_newone(self):
10-
class TestModule(nn.Module):
11-
def forward(self, x):
12-
return x.new_ones((3, 5), dtype=torch.float16)
13-
14-
inputs = [torch.randn(1, 10)]
15-
self.run_test(
16-
TestModule(),
17-
inputs,
18-
expected_ops={acc_ops.new_ones},
19-
test_implicit_batch_dim=False,
20-
)
21-
229
def test_newone_no_dtype(self):
2310
class TestModule(nn.Module):
2411
def forward(self, x):
@@ -47,23 +34,6 @@ def forward(self, x):
4734

4835

4936
class TestNewOnesConverterWithDynamicShape(AccTestCase):
50-
def test_newone(self):
51-
class TestModule(nn.Module):
52-
def forward(self, x):
53-
return x.new_ones((3, 5), dtype=torch.float16)
54-
55-
input_specs = [
56-
InputTensorSpec(
57-
shape=(-1, -1, -1, -1),
58-
dtype=torch.float32,
59-
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
60-
),
61-
]
62-
63-
self.run_test_with_dynamic_shape(
64-
TestModule(), input_specs, expected_ops={acc_ops.new_ones}
65-
)
66-
6737
def test_newone_no_dtype(self):
6838
class TestModule(nn.Module):
6939
def forward(self, x):

py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py

+42-41
Original file line numberDiff line numberDiff line change
@@ -271,47 +271,48 @@ def forward(self, x):
271271
precision=LowerPrecision.FP16,
272272
)
273273

274-
# tensor.int()
275-
def test_int(self):
276-
class To(torch.nn.Module):
277-
def forward(self, x):
278-
x = x.int()
279-
# we do not expect int to be output type, so add an extra layer
280-
x = x.float()
281-
return x
282-
283-
input = torch.randn(2, 2)
284-
inputs = [
285-
input,
286-
]
287-
self.run_test(
288-
To(),
289-
inputs,
290-
expected_ops={acc_ops.to_dtype},
291-
test_implicit_batch_dim=False,
292-
precision=LowerPrecision.FP32,
293-
)
294-
295-
# tensor.int()
296-
def test_int_with_dynamic_shape_four_dimensions(self):
297-
class To(torch.nn.Module):
298-
def forward(self, x):
299-
x = x.int()
300-
# we do not expect int to be output type, so add an extra layer
301-
x = x.float()
302-
return x
303-
304-
input_specs = [
305-
InputTensorSpec(
306-
shape=(-1, -1, -1, -1),
307-
dtype=torch.int,
308-
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
309-
),
310-
]
311-
312-
self.run_test_with_dynamic_shape(
313-
To(), input_specs, expected_ops={acc_ops.to_dtype}
314-
)
274+
# TODO: Re-enable in the future. This test does not work with TRT 8.5.
275+
# The test is a rare case. We need to remove it in graph maybe.
276+
# def test_int(self):
277+
# class To(torch.nn.Module):
278+
# def forward(self, x):
279+
# x = x.int()
280+
# # we do not expect int to be output type, so add an extra layer
281+
# x = x.float()
282+
# return x
283+
284+
# input = torch.randn(2, 2)
285+
# inputs = [
286+
# input,
287+
# ]
288+
# self.run_test(
289+
# To(),
290+
# inputs,
291+
# expected_ops={acc_ops.to_dtype},
292+
# test_implicit_batch_dim=False,
293+
# precision=LowerPrecision.FP32,
294+
# )
295+
296+
# # tensor.int()
297+
# def test_int_with_dynamic_shape_four_dimensions(self):
298+
# class To(torch.nn.Module):
299+
# def forward(self, x):
300+
# x = x.int()
301+
# # we do not expect int to be output type, so add an extra layer
302+
# x = x.float()
303+
# return x
304+
305+
# input_specs = [
306+
# InputTensorSpec(
307+
# shape=(-1, -1, -1, -1),
308+
# dtype=torch.int,
309+
# shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
310+
# ),
311+
# ]
312+
313+
# self.run_test_with_dynamic_shape(
314+
# To(), input_specs, expected_ops={acc_ops.to_dtype}
315+
# )
315316

316317

317318
if __name__ == "__main__":

0 commit comments

Comments
 (0)