Bug Description
I am trying to quantize already-trained FP16 models to INT8 precision with torch_tensorrt and accelerate inference with TensorRT engines. During this process I ran into several different issues, either inside torch_tensorrt or inside TensorRT itself (I am not entirely sure which).
In most cases the models fail the quantization and/or compilation step.
To Reproduce
- Define several common models (MLP, CNN, Attention, LSTM, Transformer) in torch.
- Randomly initialize model weights.
- Convert the models to FP16 precision and move them to GPU.
- Compile the models using torch_tensorrt:
  - Compile to an FP16 TensorRT engine.
  - Compile and quantize to an INT8 TensorRT engine.
- Compare inference performance and accuracy (a small accuracy-check sketch follows the script below) between:
  - the original FP16 model
  - the FP16 TensorRT-compiled model
  - the INT8 TensorRT-compiled and quantized model
Here is the minimal reproducible code:
```python
# (One can switch between different models and IRs by modifying the comments.)
import contextlib
import time

import torch
import torch.nn as nn
import torch_tensorrt
from torch.utils.data import DataLoader, Dataset
from torch_tensorrt.ts import ptq
class MLPModel(nn.Module):
    def __init__(self, seq_len, in_dim, hidden_sizes):
        super().__init__()
        layers = []
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_dim, hidden_size))
            in_dim = hidden_size
        self.mlp = nn.Sequential(*layers)
        self.end_layer = nn.Linear(in_dim, 48)
        self.active = nn.ReLU()

    def forward(self, x):
        B, T, C = x.shape
        x = x.transpose(0, 1)  # (T, B, C)
        x = self.mlp(x)
        x = self.active(x)
        x = self.end_layer(x)
        x = x.mean(dim=0)  # (B, 48)
        return x
class CNNModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        layers = []
        input_channels = in_dim
        for _ in range(num_layers):
            layers.append(nn.Conv1d(input_channels, 256, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            input_channels = 256
        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, 48)

    def forward(self, x):
        B, T, C = x.shape
        x = x.transpose(1, 2)  # (B, C, T)
        x = self.conv(x)       # (B, 256, T)
        x = self.pool(x)       # (B, 256, 1)
        x = x.squeeze(-1)      # (B, 256)
        x = self.fc(x)         # (B, 48)
        return x
class AttentionModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.MultiheadAttention(embed_dim=in_dim, num_heads=4, batch_first=True))
        self.attention_layers = nn.ModuleList(layers)
        self.fc = nn.Linear(in_dim, 48)

    def forward(self, x):
        B, T, C = x.shape
        for attn in self.attention_layers:
            x, _ = attn(x, x, x)
        x = x.mean(dim=1)  # (B, C)
        x = self.fc(x)     # (B, 48)
        return x
class TransformerModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=in_dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(in_dim, 48)

    def forward(self, x):
        B, T, C = x.shape
        x = self.transformer(x)  # (B, T, C)
        x = x.mean(dim=1)        # (B, C)
        x = self.fc(x)           # (B, 48)
        return x
class LSTMModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=256,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(256, 48)

    def forward(self, x):
        B, T, C = x.shape
        output, (hn, cn) = self.lstm(x)  # output: (B, T, 256)
        x = output.mean(dim=1)           # (B, 256)
        x = self.fc(x)                   # (B, 48)
        return x
@torch.no_grad()
def run_model_with_profiling(model, example_inputs, num_warmup, num_runs, msg=""):
    for _ in range(num_warmup):
        _ = model(*example_inputs)
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(num_runs):
        output = model(*example_inputs)
        torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()
    avg_time = (end_time - start_time) * 1000 / num_runs
    print(f"{msg} {avg_time=:.4f} ms, {output.dtype=}")
    return output
class CalibrationDataset(Dataset):
    def __init__(self, length, shape):
        self.length = length
        self.shape = shape

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return torch.rand(self.shape).half().cuda()
sample_shape = (1024, 512)
calib_num = 256
batch_size = 256
warmup = 5
infer = 10
profile = False

# Switch the IR here: "torch_compile", "torchscript" or "dynamo".
ir = "torch_compile"

seq_len, item_len = 256, 512
model = MLPModel(seq_len, item_len, [1024] * 5)
# model = CNNModel(seq_len, item_len, num_layers=5)
# model = AttentionModel(seq_len, item_len, num_layers=2)
# model = TransformerModel(seq_len, item_len, num_layers=2)
# model = LSTMModel(seq_len, item_len, num_layers=2)

calib_dataset = CalibrationDataset(calib_num, sample_shape)
calib_dataloader = DataLoader(calib_dataset, batch_size=32, shuffle=False)

model = model.eval().half().cuda()
example_input = torch.randn(batch_size, *sample_shape).half().cuda()
print(f">>> model: {model.__class__}")
print(f">>> model: {model.__class__}")
try:
model(example_input)
fp16_output = run_model_with_profiling(model, [example_input, ], warmup, infer, "fp16")
except Exception as e:
print(f">>> original model error: {e}")
else:
print(">>> original model passed")
# 2) Compile to an FP16 TensorRT engine.
try:
    compiled_model = torch_tensorrt.compile(
        model,
        ir=ir,
        inputs=[
            torch_tensorrt.Input(
                min_shape=[batch_size, *sample_shape],
                opt_shape=[batch_size, *sample_shape],
                max_shape=[batch_size, *sample_shape],
                dtype=torch.half,
            )
        ],
        enabled_precisions={torch.float16},
        truncate_long_and_double=True,
    )
    _ = run_model_with_profiling(compiled_model, [example_input], warmup, infer, "trt")
except Exception as e:
    print(f">>> tensorrt compile error: {e}")
else:
    print(">>> tensorrt compile passed")
# 3) Quantize to INT8 (PTQ with an entropy calibrator) and compile to a TensorRT engine.
try:
    quantized_compiled_model = torch_tensorrt.compile(
        model,
        ir=ir,
        inputs=[
            torch_tensorrt.Input(
                min_shape=[batch_size, *sample_shape],
                opt_shape=[batch_size, *sample_shape],
                max_shape=[batch_size, *sample_shape],
                dtype=torch.half,
            )
        ],
        enabled_precisions={torch.int8},
        calibrator=ptq.DataLoaderCalibrator(
            calib_dataloader,
            algo_type=ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
            device=torch.device("cuda:0"),
        ),
        truncate_long_and_double=True,
    )
    int8_output = run_model_with_profiling(
        quantized_compiled_model, [example_input], warmup, infer, "quant + trt"
    )
except Exception as e:
    print(f">>> quantized tensorrt compile error: {e}")
else:
    print(">>> quantized tensorrt compile passed")
```
Expected Behavior
FP16 models compile successfully to INT8 TensorRT engines while maintaining reasonable inference accuracy and performance.
Actual Behavior
In most cases, the compilation fails or the resulting models cannot run correctly. Below is a summary of the results I tested:
| Model | IR | FP16 | FP16 + TRT | INT8 + TRT | Error Log |
| --- | --- | --- | --- | --- | --- |
| MLP | torch_compile | pass | pass | failed | see [1] |
| MLP | torchscript | pass | pass | pass | N/A |
| MLP | dynamo | pass | pass | failed | see [2] |
| CNN | torch_compile | pass | pass | failed | see [3] |
| CNN | torchscript | pass | pass | failed | see [4] |
| CNN | dynamo | pass | pass | failed | see [5] |
| Attention | torch_compile | pass | pass | pass | N/A |
| Attention | torchscript | pass | failed | failed | see [6] |
| Attention | dynamo | pass | failed | failed | see [7] [8] |
| Transformer | torch_compile | pass | pass | pass | N/A |
| Transformer | torchscript | pass | failed | failed | see [9] |
| Transformer | dynamo | pass | failed | failed | see [10] [11] |
| LSTM | torch_compile | pass | pass | pass | N/A |
| LSTM | torchscript | pass | failed | failed | see [12] |
| LSTM | dynamo | pass | failed | failed | see [13] |
The corresponding error logs are attached as error_log.txt (due to the length limit I had to upload them as a file).
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
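(For reference, a minimal sketch of how I understand the debug messages can be turned on, assuming the `torch_tensorrt.logging` context managers documented in recent releases; treat this as an illustration rather than the canonical API.)

```python
import torch
import torch_tensorrt

# Sketch: compile inside the debug-logging context manager so that
# Torch-TensorRT prints its build information. Reuses `model`, `ir` and
# `example_input` from the reproduction script above.
with torch_tensorrt.logging.debug():
    compiled = torch_tensorrt.compile(
        model,
        ir=ir,
        inputs=[example_input],
        enabled_precisions={torch.float16},
    )
```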
- Torch-TensorRT Version (e.g. 1.0.0): 2.6.0+cu124
- PyTorch Version (e.g. 1.0): 2.6.0+cu124
- CPU Architecture: x86_64
- OS (e.g., Linux): Rocky Linux 8.7
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): N/A
- Are you using local sources or building from archives: N/A
- Python version: 3.10.16
- CUDA version: 12.5
- GPU models and configuration: NVIDIA GeForce RTX 4090
- Any other relevant information: N/A
Questions
Am I using torch_tensorrt incorrectly?
Are there any important documentation notes or best practices regarding compilation and quantization that I might have missed?
What is the correct (or officially suggested) way to do this task: given an FP16 model, build an INT8-quantized version and run inference with the TensorRT backend?
Any help would be greatly appreciated! Thank you in advance!
Additional context
N/A