Bug Description
Attempting to torch.compile the google/paligemma2-3b-pt-224 model with the Torch-TensorRT backend (backend="tensorrt") reports unsupported ops, as shown below:
Supported Nodes:
- torch.ops.aten.unsqueeze.default + Operator Count: 107
- torch.ops.aten.slice.Tensor + Operator Count: 474
- torch.ops.aten._to_copy.default + Operator Count: 334
- torch.ops.aten.expand.default + Operator Count: 55
- torch.ops.aten.reshape.default + Operator Count: 577
- torch.ops.aten.bmm.default + Operator Count: 1
- torch.ops.aten.permute.default + Operator Count: 288
- torch.ops.aten.cat.default + Operator Count: 53
- torch.ops.aten.cos.default + Operator Count: 1
- torch.ops.aten.sin.default + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 344
- torch.ops.aten.pow.Tensor_Scalar + Operator Count: 105
- torch.ops.aten.mean.dim + Operator Count: 105
- torch.ops.aten.add.Tensor + Operator Count: 353
- torch.ops.aten.sqrt.default + Operator Count: 105
- torch.ops.aten.div.Tensor + Operator Count: 106
- torch.ops.aten.mm.default + Operator Count: 183
- torch.ops.aten.neg.default + Operator Count: 52
- torch.ops.aten.clamp.default + Operator Count: 13
- torch.ops.aten.ge.Scalar + Operator Count: 13
- torch.ops.aten.select.int + Operator Count: 13
- torch.ops.aten.sub.Tensor + Operator Count: 13
- torch.ops.aten.remainder.Scalar + Operator Count: 13
- torch.ops.aten.index.Tensor + Operator Count: 26
- torch.ops.aten.clone.default + Operator Count: 78
- torch._C._nn.scaled_dot_product_attention + Operator Count: 26
- torch.ops.aten.gelu.default + Operator Count: 26
- torch.ops.aten.tanh.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.index_put.default + Operator Count: 52
...
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.index_put.default: 1
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions.
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Skipping option 0 for aten.index_put.default: (validator: False, supports dynamic shapes: False)
...
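The validator message above indicates the lone index_put converter rejects cases where the index list has fewer dimensions than the input tensor. Below is a minimal sketch that I would expect to hit the same rejection (my own reduction; the module name, shapes, and values are illustrative, not taken from the model):

import torch
import torch_tensorrt  # registers the "tensorrt" backend for torch.compile

class IndexPutRepro(torch.nn.Module):
    def forward(self, x, idx, v):
        # x is rank 4, but only dim 2 is indexed, so the index list
        # [None, None, idx] does not cover every input dimension and the
        # aten.index_put.default converter validator should return False.
        return torch.ops.aten.index_put.default(x, [None, None, idx], v)

x = torch.zeros(1, 8, 16, 64, device="cuda", dtype=torch.float16)
idx = torch.tensor([3], device="cuda")
v = torch.ones(1, 8, 1, 64, device="cuda", dtype=torch.float16)

mod = torch.compile(IndexPutRepro().cuda().eval(), backend="tensorrt",
                    options={"min_block_size": 1, "debug": True})
mod(x, idx, v)  # expect index_put under "Unsupported or Excluded Nodes"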
DEBUG:torch_tensorrt.dynamo._compiler:Submodule in PyTorch: _run_on_gpu_2
graph():
    %unsqueeze_1 : [num_users=1] = placeholder[target=unsqueeze_1]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%unsqueeze_1, None, None, torch.float16), kwargs = {})
    %slice_3 : [num_users=1] = placeholder[target=slice_3]
    %_assert_tensor_metadata_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%slice_3, None, None, torch.int64), kwargs = {})
    %mul : [num_users=1] = placeholder[target=mul]
    %_assert_tensor_metadata_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul, None, None, torch.float32), kwargs = {})
    %mul_1 : [num_users=1] = placeholder[target=mul_1]
    %_assert_tensor_metadata_3 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul_1, None, None, torch.float32), kwargs = {})
    %mul_2 : [num_users=1] = placeholder[target=mul_2]
    %_assert_tensor_metadata_5 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul_2, None, None, torch.float16), kwargs = {})
    %mul_4 : [num_users=1] = placeholder[target=mul_4]
    %_assert_tensor_metadata_7 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul_4, None, None, torch.float32), kwargs = {})
    %select : [num_users=1] = placeholder[target=select]
    %_assert_tensor_metadata_8 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%select, None, None, torch.bool), kwargs = {})
    %slice_17 : [num_users=1] = placeholder[target=slice_17]
    %clamp : [num_users=2] = placeholder[target=clamp]
    %reshape_default_12 : [num_users=1] = placeholder[target=reshape_default_12]
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%slice_17, [None, None, %clamp], %reshape_default_12), kwargs = {})
    %slice_22 : [num_users=1] = placeholder[target=slice_22]
    %reshape_default_13 : [num_users=1] = placeholder[target=reshape_default_13]
    %index_put_1 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%slice_22, [None, None, %clamp], %reshape_default_13), kwargs = {})
    %select_1 : [num_users=1] = placeholder[target=select_1]
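For context, the two index_put nodes in this submodule index only the third dimension of what looks like a rank-4 key/value cache ([None, None, %clamp]), consistent with the sliding-window KV-cache write in the Gemma2 attention path; this would also explain the clamp/ge/sub/remainder nodes in the supported list. A rough reconstruction of the producing pattern (my assumption from the graph, not the exact transformers source):

import torch

def sliding_window_cache_write(cache, states, cache_position, window):
    # cache: (batch, heads, window, head_dim). Clamping the write position
    # to the window yields the aten.clamp.default node; the slice assignment
    # lowers to aten.index_put.default(cache, [None, None, pos], states),
    # i.e. the unsupported pattern flagged above.
    pos = cache_position.clamp(0, window - 1)
    cache[:, :, pos] = states
    return cache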
To Reproduce
Steps to reproduce the behavior:
import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image
DEVICE = "cuda:0"
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).eval()
model.to(DEVICE).to(torch.float16)
# model.forward = model.forward.to(torch.float16).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]
# model.config.token_healing = False
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
print("=============================")
print("pyt_generation whole text:")
print(pyt_generation)
print("=============================")
print("=============================")
print("PyTorch generated text:")
print(pyt_decoded)
print("=============================")
with torch_tensorrt.logging.debug():
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=1023)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export": False,
        },
    )
with torch.inference_mode():
    trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    trt_generation_out = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
print(trt_generation)
print("TensorRT generated text:")
print(trt_decoded)
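As a possible workaround until the converter handles this broadcasting case, the op can be excluded explicitly so the fallback to eager PyTorch is deliberate rather than the result of a failed validator. Untested sketch, reusing the compile call from the repro above (torch_executed_ops is the standard Torch-TensorRT compilation setting for this):

model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    options={
        "enabled_precisions": {torch.float16},
        "min_block_size": 1,
        # Run aten.index_put.default in eager PyTorch instead of relying on
        # the converter validator to partition it out.
        "torch_executed_ops": {"torch.ops.aten.index_put.default"},
    },
)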