
🐛 [Bug] Unsupported ops : torch.ops.aten.masked_scatter.default (Paligemma2) #3410

Open

@chohk88
Bug Description

When compiling the google/paligemma2-3b-pt-224 model with torch.compile (backend="tensorrt"), torch.ops.aten.masked_scatter.default is reported as an unsupported op and falls back to PyTorch, as shown in the debug log below:

DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %image_features : [num_users=2] = call_method[target=to](args = (%l_image_features_, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%l_inputs_embeds_, %l_special_image_mask_, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.lowering.passes.repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_special_image_mask_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_inputs_embeds_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_image_features_,), kwargs = {})
    %image_features : [num_users=2] = call_method[target=to](args = (%clone_default, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%clone_default_1, %clone_default_2, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_special_image_mask_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_inputs_embeds_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_image_features_,), kwargs = {})
    %image_features : [num_users=2] = call_method[target=to](args = (%clone_default, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%clone_default_1, %clone_default_2, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_special_image_mask_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_inputs_embeds_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_image_features_,), kwargs = {})
    %image_features : [num_users=2] = call_method[target=to](args = (%clone_default, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%clone_default_1, %clone_default_2, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg2_1,), kwargs = {})
    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})
    %clone_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%clone_2,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%clone_1, %clone, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone_2 from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone_1 from graph, since it is a clone node which is the only user of placeholder arg1_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg2_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten._to_copy.default: 2
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Skipping option 0 for aten._to_copy.default: (validator: False, supports dynamic shapes: True)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 1 for converting aten._to_copy.default
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.masked_scatter.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten._to_copy.default: 2
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Skipping option 0 for aten._to_copy.default: (validator: False, supports dynamic shapes: True)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 1 for converting aten._to_copy.default
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.masked_scatter.default + Operator Count: 1

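For reference, the unsupported op can be isolated from the full model with a small module that calls masked_scatter directly. This is a minimal sketch, not part of the original report; the module name, shapes, and options are illustrative assumptions, and compiling it should surface the same "Unsupported or Excluded Nodes: torch.ops.aten.masked_scatter.default" entry in the partitioner log:

import torch
import torch_tensorrt

class MaskedScatterModule(torch.nn.Module):
    def forward(self, inputs_embeds, mask, image_features):
        # Same pattern as the PaliGemma2 forward pass: write the image features
        # into the embedding tensor at the positions selected by the mask.
        return inputs_embeds.masked_scatter(mask, image_features)

inputs_embeds = torch.randn(1, 8, 4, dtype=torch.float16, device="cuda")
mask = torch.zeros(1, 8, 4, dtype=torch.bool, device="cuda")
mask[:, :2, :] = True  # pretend the first two positions are image tokens
image_features = torch.randn(1, 2, 4, dtype=torch.float16, device="cuda")

compiled = torch.compile(
    MaskedScatterModule().cuda(),
    backend="tensorrt",
    options={"min_block_size": 1, "enabled_precisions": {torch.float16}},
)
out = compiled(inputs_embeds, mask, image_features)
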
To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

DEVICE = "cuda:0"

model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)


model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).eval()
model.to(DEVICE).to(torch.float16)
# model.forward = model.forward.to(torch.float16).eval()

processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# model.config.token_healing = False

with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
    print("=============================")
    print("pyt_generation whole text:")
    print(pyt_generation)
    print("=============================")
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

with torch_tensorrt.logging.debug():
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=1023)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export":False,
        },
    )
    
    with torch.inference_mode():
        trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) 
        trt_generation_out = trt_generation[0][input_len:]
        trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
        print(trt_generation)
        print("TensorRT generated text:")
        print(trt_decoded)
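
As a possible interim workaround (an untested sketch, not a confirmed fix), the masked_scatter call could be rewritten with boolean indexing before compiling; whether the resulting aten ops (nonzero / index_put) convert to TensorRT, or simply partition differently, is an assumption that would need to be verified:

import torch

def masked_scatter_equivalent(dest: torch.Tensor, mask: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    # Functional equivalent of dest.masked_scatter(mask, source): the masked
    # positions are filled with elements of `source` taken in row-major order.
    out = dest.clone()
    num_selected = int(mask.sum())  # data-dependent; causes a graph break under torch.compile
    out[mask] = source.reshape(-1)[:num_selected].to(out.dtype)
    return out

# Sanity check against the built-in op (mask and dest assumed to have the same shape):
x = torch.randn(2, 4)
m = torch.rand(2, 4) > 0.5
s = torch.randn(2, 4)
assert torch.equal(x.masked_scatter(m, s), masked_scatter_equivalent(x, m, s))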
