attention model failed to convert, "error: unknown: unsupported by backend contract: module initializers" #3915

Open
@heshuju

Description

The attention model below fails to convert to StableHLO IR.

Question: how can I solve this problem?

Python script:

import torch
import torch.nn as nn
import torch.nn.init as init
import torch_mlir
class AttentionModel(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(AttentionModel, self).__init__()
        self.attention = nn.MultiheadAttention(embed_size, num_heads)

    def forward(self, query, key, value):
        attn_output, attn_weights = self.attention(query, key, value)
        return attn_output, attn_weights

embed_size = 64
num_heads = 8

model = AttentionModel(embed_size, num_heads)

model.eval() 
seq_len = 10
batch_size = 1 

query = torch.randn(seq_len, batch_size, embed_size) 
key = torch.randn(seq_len, batch_size, embed_size)   
value = torch.randn(seq_len, batch_size, embed_size) 

module = torch_mlir.compile(model, (query, key, value), output_type=torch_mlir.OutputType.STABLEHLO)

with open("model.mlir", "w") as f:
    f.write(str(module))
```text
Traceback (most recent call last):
  File "/home/x/work/kneron_mlir/torch_example/attention/attention.py", line 37, in <module>
    module = torch_mlir.compile(model, (query, key, value), output_type=torch_mlir.OutputType.STABLEHLO)
  File "/home/x/work/kneron_mlir/torch-build/python_packages/torch_mlir/torch_mlir/__init__.py", line 451, in compile
    run_pipeline_with_repro_report(
  File "/home/x/work/kneron_mlir/torch-build/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:

python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: module initializers
note: unknown: see current operation: "torch.initialize.global_slots"(%6, %7, %8, %9) <{slotSymNames = [@attention.in_proj_weight, @attention.in_proj_bias, @attention.out_proj.weight, @attention.out_proj.bias]}> : (!torch.tensor<[192,64],f32>, !torch.tensor<[192],f32>, !torch.tensor<[64,64],f32>, !torch.tensor<[64],f32>) -> ()
note: unknown: this is likely due to InlineGlobalSlots being unable to inline a global slot
```
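The four global slots named in the note are the learnable parameters of `nn.MultiheadAttention(64, 8)`. They are the packed query/key/value input projection (`3 * 64 = 192` rows) plus the output projection. This suggests the slots that InlineGlobalSlots could not inline are simply the model's own weights.

The stdlib-only sketch below pulls the slot names and shapes out of the note and compares them with the parameter shapes PyTorch uses when `kdim == vdim == embed_dim` (the default). The helper names `uninlined_slots` and `mha_param_shapes` are hypothetical.

```python
import re

# The "see current operation" note copied from the diagnostics above.
NOTE = ('"torch.initialize.global_slots"(%6, %7, %8, %9) <{slotSymNames = '
        '[@attention.in_proj_weight, @attention.in_proj_bias, '
        '@attention.out_proj.weight, @attention.out_proj.bias]}> : '
        '(!torch.tensor<[192,64],f32>, !torch.tensor<[192],f32>, '
        '!torch.tensor<[64,64],f32>, !torch.tensor<[64],f32>) -> ()')


def uninlined_slots(note):
    """Pair each @slot symbol with its tensor shape, in the order they appear."""
    names = re.findall(r"@([\w.]+)", note)
    shapes = [tuple(int(d) for d in dims.split(","))
              for dims in re.findall(r"!torch\.tensor<\[([\d,]+)\]", note)]
    return dict(zip(names, shapes))


def mha_param_shapes(embed_dim, prefix="attention"):
    """Parameter shapes of nn.MultiheadAttention when kdim == vdim == embed_dim."""
    return {
        f"{prefix}.in_proj_weight": (3 * embed_dim, embed_dim),  # packed Q, K, V
        f"{prefix}.in_proj_bias": (3 * embed_dim,),
        f"{prefix}.out_proj.weight": (embed_dim, embed_dim),
        f"{prefix}.out_proj.bias": (embed_dim,),
    }


for name, shape in uninlined_slots(NOTE).items():
    print(name, shape)
print(uninlined_slots(NOTE) == mha_param_shapes(64))  # True
```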

With `use_tracing=True`, the same model converts to StableHLO IR successfully:

```python
module = torch_mlir.compile(model, (query, key, value), output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
```
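Since `use_tracing=True` works for this model, one stopgap while the scripting path is unresolved is to retry with tracing only when this particular error appears. The sketch below assumes the compiler error's message contains `module initializers`, as in the diagnostics above. `compile_with_fallback` is a hypothetical helper. The compile function is passed in as an argument, so the logic can be exercised without torch-mlir installed.

```python
def compile_with_fallback(compile_fn, model, example_args, error_type=Exception, **kwargs):
    """Try the default (scripting) path first; retry with tracing only for
    the 'module initializers' backend-contract failure."""
    try:
        return compile_fn(model, example_args, **kwargs)
    except error_type as exc:
        # Anything other than this specific failure is unrelated, so re-raise it.
        if "module initializers" not in str(exc):
            raise
    # Retry outside the except block so the first error is not chained;
    # use_tracing=True is the option that worked in this report.
    return compile_fn(model, example_args, **{**kwargs, "use_tracing": True})
```

For this issue, the helper would be called as `compile_with_fallback(torch_mlir.compile, model, (query, key, value), error_type=torch_mlir.compiler_utils.TorchMlirCompilerError, output_type=torch_mlir.OutputType.STABLEHLO)`. Matching on the message keeps unrelated compiler errors visible.

Keep in mind that tracing records only the operations executed for the example inputs. Any data-dependent control flow is therefore frozen to that one path, which is worth checking before relying on this for other models.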
