
🐛 [Bug] Unsupported ops : torch.ops.aten.index_put.default - validator: False #3432

@chohk88


Bug Description

When compiling the google/paligemma2-3b-pt-224 model with torch.compile (backend = torch_tensorrt), the following unsupported op is reported:

Supported Nodes:
- torch.ops.aten.unsqueeze.default + Operator Count: 107
- torch.ops.aten.slice.Tensor + Operator Count: 474
- torch.ops.aten._to_copy.default + Operator Count: 334
- torch.ops.aten.expand.default + Operator Count: 55
- torch.ops.aten.reshape.default + Operator Count: 577
- torch.ops.aten.bmm.default + Operator Count: 1
- torch.ops.aten.permute.default + Operator Count: 288
- torch.ops.aten.cat.default + Operator Count: 53
- torch.ops.aten.cos.default + Operator Count: 1
- torch.ops.aten.sin.default + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 344
- torch.ops.aten.pow.Tensor_Scalar + Operator Count: 105
- torch.ops.aten.mean.dim + Operator Count: 105
- torch.ops.aten.add.Tensor + Operator Count: 353
- torch.ops.aten.sqrt.default + Operator Count: 105
- torch.ops.aten.div.Tensor + Operator Count: 106
- torch.ops.aten.mm.default + Operator Count: 183
- torch.ops.aten.neg.default + Operator Count: 52
- torch.ops.aten.clamp.default + Operator Count: 13
- torch.ops.aten.ge.Scalar + Operator Count: 13
- torch.ops.aten.select.int + Operator Count: 13
- torch.ops.aten.sub.Tensor + Operator Count: 13
- torch.ops.aten.remainder.Scalar + Operator Count: 13
- torch.ops.aten.index.Tensor + Operator Count: 26
- torch.ops.aten.clone.default + Operator Count: 78
- torch._C._nn.scaled_dot_product_attention + Operator Count: 26
- torch.ops.aten.gelu.default + Operator Count: 26
- torch.ops.aten.tanh.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.index_put.default + Operator Count: 52

...

DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.index_put.default: 1
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions.
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Skipping option 0 for aten.index_put.default: (validator: False, supports dynamic shapes: False)

...

DEBUG:torch_tensorrt.dynamo._compiler:Submodule in PyTorch: _run_on_gpu_2
 graph():
    %unsqueeze_1 : [num_users=1] = placeholder[target=unsqueeze_1]
    %_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%unsqueeze_1, None, None, torch.float16), kwargs = {})
    %slice_3 : [num_users=1] = placeholder[target=slice_3]
    %_assert_tensor_metadata_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%slice_3, None, None, torch.int64), kwargs = {})
    %mul : [num_users=1] = placeholder[target=mul]
    %_assert_tensor_metadata_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul, None, None, torch.float32), kwargs = {})
    %mul_1 : [num_users=1] = placeholder[target=mul_1]
    %_assert_tensor_metadata_3 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul_1, None, None, torch.float32), kwargs = {})
    %mul_2 : [num_users=1] = placeholder[target=mul_2]
    %_assert_tensor_metadata_5 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul_2, None, None, torch.float16), kwargs = {})
    %mul_4 : [num_users=1] = placeholder[target=mul_4]
    %_assert_tensor_metadata_7 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%mul_4, None, None, torch.float32), kwargs = {})
    %select : [num_users=1] = placeholder[target=select]
    %_assert_tensor_metadata_8 : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%select, None, None, torch.bool), kwargs = {})
    %slice_17 : [num_users=1] = placeholder[target=slice_17]
    %clamp : [num_users=2] = placeholder[target=clamp]
    %reshape_default_12 : [num_users=1] = placeholder[target=reshape_default_12]
    %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%slice_17, [None, None, %clamp], %reshape_default_12), kwargs = {})
    %slice_22 : [num_users=1] = placeholder[target=slice_22]
    %reshape_default_13 : [num_users=1] = placeholder[target=reshape_default_13]
    %index_put_1 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%slice_22, [None, None, %clamp], %reshape_default_13), kwargs = {})
    %select_1 : [num_users=1] = placeholder[target=select_1]
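
For reference, the pattern the validator rejects can be reproduced in plain PyTorch. The graph's index_put calls pass an indices list of the form [None, None, %clamp], i.e. fewer index entries than the (presumably 4-D KV-cache) input has dimensions, which is exactly the case the converter's validator refuses. The shapes below are illustrative assumptions, not the model's actual shapes:

import torch

x = torch.zeros(1, 8, 1024, 64)   # illustrative 4-D tensor (stand-in for %slice_17)
idx = torch.tensor([3, 4])        # 1-D index for dim 2 (stand-in for %clamp)
vals = torch.zeros(1, 8, 2, 64)   # stand-in for %reshape_default_12

# Equivalent to x[:, :, idx] = vals. The indices list has 3 entries
# (two None plus one tensor) against a 4-D input, which is the case the
# Torch-TensorRT converter's validator rejects ("number of index
# dimensions does not match the number of input tensor dimensions").
y = torch.ops.aten.index_put.default(x, [None, None, idx], vals)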

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

DEVICE = "cuda:0"

model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)


model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).eval()
model.to(DEVICE).to(torch.float16)
# model.forward = model.forward.to(torch.float16).eval()

processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# model.config.token_healing = False

# Baseline: run generation with eager PyTorch first
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
    print("=============================")
    print("pyt_generation whole text:")
    print(pyt_generation)
    print("=============================")
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

with torch_tensorrt.logging.debug():
    # Mark the sequence dimension of input_ids as dynamic
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=1023)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export":False,
        },
    )
    
    # Run generation again, now through the TensorRT-compiled forward
    with torch.inference_mode():
        trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) 
        trt_generation_out = trt_generation[0][input_len:]
        trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
        print(trt_generation)
        print("TensorRT generated text:")
        print(trt_decoded)
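
While this converter gap exists, one possible workaround (a sketch, untested here) is to explicitly pin the op to PyTorch via the torch_executed_ops compilation setting, which makes the fallback the partitioner already performs implicitly an explicit choice:

# Hypothetical workaround sketch: keep aten.index_put in PyTorch and let
# Torch-TensorRT compile the remaining subgraphs. This mirrors what the
# partitioner already does when the validator returns False.
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    options={
        "enabled_precisions": {torch.float16},
        "min_block_size": 1,
        "torch_executed_ops": {torch.ops.aten.index_put.default},
    },
)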
