attention model failed to convert, "error: unknown: unsupported by backend contract: module initializers" #3915

Open
@heshuju

An attention model fails to convert to StableHLO IR.
Question:
How can I solve this problem?
Python script:

import torch
import torch.nn as nn
import torch.nn.init as init
import torch_mlir


class AttentionModel(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(AttentionModel, self).__init__()
        self.attention = nn.MultiheadAttention(embed_size, num_heads)

    def forward(self, query, key, value):
        attn_output, attn_weights = self.attention(query, key, value)
        return attn_output, attn_weights

embed_size = 64
num_heads = 8

model = AttentionModel(embed_size, num_heads)

model.eval() 
seq_len = 10
batch_size = 1 

query = torch.randn(seq_len, batch_size, embed_size) 
key = torch.randn(seq_len, batch_size, embed_size)   
value = torch.randn(seq_len, batch_size, embed_size) 

module = torch_mlir.compile(model, (query, key, value), output_type=torch_mlir.OutputType.STABLEHLO)

with open("model.mlir", "w") as f:
    f.write(str(module))

Error output:

Traceback (most recent call last):
  File "/home/x/work/kneron_mlir/torch_example/attention/attention.py", line 37, in <module>
    module = torch_mlir.compile(model, (query, key, value), output_type=torch_mlir.OutputType.STABLEHLO)
  File "/home/x/work/kneron_mlir/torch-build/python_packages/torch_mlir/torch_mlir/__init__.py", line 451, in compile
    run_pipeline_with_repro_report(
  File "/home/x/work/kneron_mlir/torch-build/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:


python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: module initializers
note: unknown: see current operation: "torch.initialize.global_slots"(%6, %7, %8, %9) <{slotSymNames = [@attention.in_proj_weight, @attention.in_proj_bias, @attention.out_proj.weight, @attention.out_proj.bias]}> : (!torch.tensor<[192,64],f32>, !torch.tensor<[192],f32>, !torch.tensor<[64,64],f32>, !torch.tensor<[64],f32>) -> ()
note: unknown: this is likely due to InlineGlobalSlots being unable to inline a global slot
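
For reference, the four slot names in the diagnostic look like the parameters of the nn.MultiheadAttention submodule. The sketch below uses plain PyTorch only (no torch_mlir) to list that layer's parameters so they can be compared against the `slotSymNames` and tensor types in the error; the `attention.` prefix is added by hand to mirror the attribute name used in the script above.

```python
import torch.nn as nn

# Same configuration as in the script above: embed_size = 64, num_heads = 8.
attention = nn.MultiheadAttention(embed_dim=64, num_heads=8)

# Map each parameter name to its shape for comparison with the diagnostic.
param_shapes = {name: tuple(p.shape) for name, p in attention.named_parameters()}

for name, shape in param_shapes.items():
    # Prefix added manually to match the slot names (@attention.<name>).
    print(f"attention.{name}: {shape}")
```

The shapes line up with the error message: the packed query/key/value projection is 3 * 64 = 192 rows, which is where `!torch.tensor<[192,64],f32>` and `!torch.tensor<[192],f32>` come from. So the slots that InlineGlobalSlots could not inline appear to be exactly the weights and biases of the attention layer.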

With use_tracing=True, the model converts to StableHLO IR successfully:
module = torch_mlir.compile(model, (query, key, value), output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
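
Since tracing records only the operations executed for the example inputs, it may be worth checking that a traced copy of the model still matches eager execution before relying on this workaround. The following is a minimal sketch using plain `torch.jit.trace`; it assumes that `use_tracing=True` goes through TorchScript tracing (as the option name suggests) and does not call torch_mlir itself. The seed and tolerance values are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


class AttentionModel(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_size, num_heads)

    def forward(self, query, key, value):
        attn_output, attn_weights = self.attention(query, key, value)
        return attn_output, attn_weights


torch.manual_seed(0)  # arbitrary seed, only for repeatability
model = AttentionModel(64, 8).eval()

# Inputs shaped (seq_len, batch_size, embed_size), as in the script above.
query, key, value = (torch.randn(10, 1, 64) for _ in range(3))

# Trace the model with the same example inputs used for compilation.
traced = torch.jit.trace(model, (query, key, value))

with torch.no_grad():
    eager_out, eager_weights = model(query, key, value)
    traced_out, traced_weights = traced(query, key, value)

# Compare eager and traced results within a small tolerance.
outputs_match = torch.allclose(eager_out, traced_out, atol=1e-5)
weights_match = torch.allclose(eager_weights, traced_weights, atol=1e-5)
print("outputs match:", outputs_match)
print("weights match:", weights_match)
```

This check only covers the input shapes that were traced. Other sequence lengths or batch sizes would need their own comparison, because a trace can bake in shape-dependent values.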
