failed to legalize operation 'torch.aten.cumsum' #3866

Open
@DavidGinten

Description

Hello,

I encountered the following error when trying to run this model with IREE Turbine: https://huggingface.co/facebook/detr-resnet-50
The error does not originate in IREE Turbine itself, though, but in the lowering of the torch aten cumsum operation to TMTensor.

Error:

Failure while executing pass pipeline (pm.run()): failed to legalize operation 'torch.aten.cumsum' that was explicitly marked illegal.

The problem seems to occur here:

if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
  return rewriter.notifyMatchFailure(
      op, "unimplemented: only constant dim value is supported");

To reproduce:

import torch
import numpy as np
from transformers import AutoModelForObjectDetection
from iree.turbine import aot
from iree.compiler.ir import Context
from iree.compiler.passmanager import PassManager

model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")

model.eval()

args = (np.random.randn(1, 3, 224, 224).astype(np.float32),)
args_torch = tuple(torch.from_numpy(x) for x in args)

exported_model: aot.ExportOutput = aot.export(
    model, args=args_torch, dynamic_shapes=None, strict_export=False
)

context: Context = exported_model.mlir_module.context
with context:
    pm = PassManager.parse("builtin.module(torch-to-iree)")
    pm.run(exported_model.mlir_module.operation)
