-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Open
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
Expected behavior
TVM should compile the model correctly.
Actual behavior
For the attached model, TVM crashes in `relax.transform.LegalizeOps` while legalizing `nn.prelu`:
Traceback (most recent call last):
File "/home/ubuntu/Documents/DLCompiler-test/tvm/1126/bugs/onnx_output4/test1.py", line 52, in <module>
test(onnx_model)
File "/home/ubuntu/Documents/DLCompiler-test/tvm/1126/bugs/onnx_output4/test1.py", line 45, in test
tvm_model = relax.transform.LegalizeOps()(tvm_model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/ir/transform.py", line 167, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in std::_Function_handler<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext), tvm::relax::transform::LegalizeOps(tvm::ffi::Optional<tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Function, void>, void>, tvm::ffi::Optional<tvm::ffi::Array<tvm::ffi::String, void>, void>, bool)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&, tvm::IRModule&&, tvm::transform::PassContext&&)
File "<unknown>", line 0, in tvm::relax::transform::LegalizeOps(tvm::ffi::Optional<tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Function, void>, void>, tvm::ffi::Optional<tvm::ffi::Array<tvm::ffi::String, void>, void>, bool)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const
File "<unknown>", line 0, in tvm::relax::LegalizeMutator::Transform()
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelaxExpr const&, tvm::ffi::Optional<tvm::ffi::Array<tvm::relax::Var, void>, void>)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::IfNode const*)
File "<unknown>", line 0, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
File "<unknown>", line 0, in tvm::relax::LegalizeMutator::VisitExpr_(tvm::relax::CallNode const*)
File "python/tvm_ffi/cython/function.pxi", line 1058, in tvm_ffi.core.tvm_ffi_callback
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/transform/legalize_ops/nn.py", line 493, in _nn_prelu
return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis)
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/block_builder.py", line 361, in call_te
tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs)
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/utils.py", line 355, in gen_call_tir_inputs
te_out = func(*te_args, **te_kwargs)
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/te/tag.py", line 57, in tagged_fdecl
return fdecl(*args, **kwargs)
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/topi/nn/elemwise.py", line 130, in prelu
assert len(slope.shape) == 1
AssertionError
I am not sure whether this is a bug in TVM. This issue appears to be the same as one that has already been fixed.
Environment
OS: Ubuntu 20.04
TVM: 0.23.dev0 (f4e28d3)
onnxruntime: 1.23.2
Steps to reproduce
This bug can be reproduced with the following code and the attached model. As the code shows, the model runs successfully under onnxruntime.
from typing import Dict, List, Literal, Optional
import sys
import os
import numpy as np
import onnx
import onnxruntime
from onnx import ModelProto, TensorProto, helper
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import argparse
import pickle
def test(
    model: ModelProto,
    inputs: Optional[Dict[str, np.ndarray]] = None,
    ir_version: Optional[int] = 8,
    opset: Optional[int] = 14,
) -> None:
    """Run an ONNX model through onnxruntime, then import and legalize it in TVM Relax.

    The model is first executed with onnxruntime to confirm it is runnable.
    It is then imported with ``from_onnx`` and passed through
    ``DecomposeOpsForInference`` followed by ``LegalizeOps``. The reported
    crash occurs in that last pass.

    Args:
        model: ONNX model to test. Its ``ir_version`` and the version of its
            first ``opset_import`` entry are overwritten in place.
        inputs: Feed dict for onnxruntime. When ``None`` (the default), the
            inputs are loaded from ``inputs.pkl`` in the current directory.
        ir_version: IR version to stamp on the model, or ``None`` to keep the
            model's own value.
        opset: Opset version to stamp on the model and pass to ``from_onnx``,
            or ``None`` to keep the model's own value.

    If onnxruntime fails to run the model, the error is printed and the
    process exits with status 1.
    """
    # Configure model format.
    if ir_version is not None:
        model.ir_version = ir_version
    if opset is not None:
        # NOTE(review): assumes entry 0 is the default ("") ONNX domain.
        model.opset_import[0].version = opset

    # Honor caller-supplied inputs; only fall back to the pickle file when none
    # were given.
    if inputs is None:
        # pickle.load can execute arbitrary code: only load trusted files.
        with open("inputs.pkl", "rb") as fp:
            inputs = pickle.load(fp)

    # Run the model through onnxruntime to confirm it is executable.
    try:
        ort_session = onnxruntime.InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_session.run([], inputs)
    except Exception as e:
        # Script boundary: report the failure and abort the repro.
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)

    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
if __name__ == "__main__":
    # Entry point: load the attached model and run the repro on it.
    test(onnx.load("11.onnx"))
# Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug