
[Bug] from_exported_program fails on a loaded ExportedProgram: Matmul reduction length mismatch #18573

@DaTouJun

Description


Expected behavior

from_exported_program works correctly with the exported model loaded from disk.

Actual behavior

/home/guan/miniconda3/envs/tvm/bin/python /home/guan/dev/pycharm/TVM/tvm2/helloworld.py
/home/guan/miniconda3/envs/tvm/lib/python3.11/site-packages/torch/export/pt2_archive/_package.py:682: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1581.)
tensor = torch.frombuffer(
Traceback (most recent call last):
File "/home/guan/dev/pycharm/TVM/tvm2/helloworld.py", line 8, in
mod = from_exported_program(exported_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/guan/miniconda3/envs/tvm/lib/python3.11/site-packages/tvm/relax/frontend/torch/exported_program_translator.py", line 1261, in from_exported_program
return ExportedProgramImporter().from_exported_program(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/guan/miniconda3/envs/tvm/lib/python3.11/site-packages/tvm/relax/frontend/torch/exported_program_translator.py", line 1156, in from_exported_program
self.env[node] = self.convert_map[func_name](node)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/guan/miniconda3/envs/tvm/lib/python3.11/site-packages/tvm/relax/frontend/torch/base_fx_graph_translator.py", line 1109, in _linear
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/guan/miniconda3/envs/tvm/lib/python3.11/site-packages/tvm/relax/block_builder.py", line 328, in emit
return _ffi_api.BlockBuilderEmit(self, expr, name_hint) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.call
File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 1068, in operator()
return builder->Emit(expr, name_hint);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 243, in tvm::relax::BlockBuilderImpl::Emit(tvm::RelaxExpr, tvm::ffi::String)
return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 395, in tvm::relax::BlockBuilderImpl::Emit(tvm::RelaxExpr, bool, tvm::ffi::String)
expr = this->Normalize(expr);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 532, in tvm::relax::Normalizer::Normalize(tvm::RelaxExpr const&)
Expr normalized = this->VisitExpr(expr);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 615, in tvm::relax::Normalizer::VisitExpr(tvm::RelaxExpr const&)
return ExprFunctor::VisitExpr(expr);

File "/home/guan/dev/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
return vtable(n, this, std::forward<Args>(args)...);

File "/home/guan/dev/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>) const
return (*func_[n->type_index() - begin_type_index_])(n, std::forward(args)...);

File "/home/guan/dev/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);

File "/home/guan/dev/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)#9}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>) const
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 654, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); });

File "/home/guan/dev/tvm/3rdparty/tvm-ffi/include/tvm/ffi/container/array.h", line 799, in tvm::ffi::Array<tvm::RelaxExpr, std::enable_if<storage_enabled_vtvm::RelaxExpr, void>::type> tvm::ffi::Array<tvm::RelaxExpr, void>::Map<tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)::{lambda(tvm::RelaxExpr const&)#1}, tvm::RelaxExpr>(tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)::{lambda(tvm::RelaxExpr const&)#1}) const
return Array(MapHelper(data_, fmap));

File "/home/guan/dev/tvm/3rdparty/tvm-ffi/include/tvm/ffi/container/array.h", line 975, in tvm::ffi::ObjectPtrtvm::ffi::Object tvm::ffi::Array<tvm::RelaxExpr, void>::MapHelper<tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)::{lambda(tvm::RelaxExpr const&)#1}, tvm::RelaxExpr>(tvm::ffi::ObjectPtrtvm::ffi::Object, tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)::{lambda(tvm::RelaxExpr const&)#1})
U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it));

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 654, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)::{lambda(tvm::RelaxExpr const&)#1}::operator()(tvm::RelaxExpr const&) const
op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); });

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 563, in tvm::relax::Normalizer::NormalizeArgument(tvm::RelaxExpr const&)
Expr post = ExprFunctor::VisitExpr(arg);

File "/home/guan/dev/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
return vtable(n, this, std::forward<Args>(args)...);

File "/home/guan/dev/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>) const
return (*func_[n->type_index() - begin_type_index_])(n, std::forward(args)...);

File "/home/guan/dev/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);

File "/home/guan/dev/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>)#9}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>) const
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 664, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
auto inferred_sinfo = InferStructInfo(call);

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 847, in tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
return op_map_infer_struct_info_[op](call, ffi::GetRef<BlockBuilder>(this));

File "/home/guan/dev/tvm/src/relax/op/tensor/linear_algebra.cc", line 141, in tvm::relax::InferStructInfoMatmul(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
ctx->ReportFatal(Diagnostic::Error(call)

File "/home/guan/dev/tvm/src/relax/ir/block_builder.cc", line 157, in tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
LOG(FATAL) << diagnostic->message;

File "/home/guan/dev/tvm/include/tvm/runtime/logging.h", line 321, in tvm::runtime::detail::LogFatal::~LogFatal()
GetEntry().Finalize();

File "/home/guan/dev/tvm/include/tvm/runtime/logging.h", line 337, in tvm::runtime::detail::LogFatal::Entry::Finalize()
InternalError error(file_, lineno_, stream_.str());

tvm.error.InternalError: Matmul requires the reduction length of the operands to be equal. However, the LHS lv has shape R.shape([1, 10]), while the RHS lv1 has shape R.shape([784, 128]). The reduction dimensions of T.int64(10) and T.int64(784) are not equal.
[16:08:40] /home/guan/dev/tvm/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!

Process finished with exit code 1
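
For context, here is the shape flow the exported graph should encode, reconstructed from the repro script below (a sketch with torch.randn stand-ins for the trained fc1/fc2 weights, not the importer's actual trace):

import torch

# Shape walk-through of SimpleNet on the export example input (1, 1, 28, 28).
x = torch.randn(1, 1, 28, 28)
x = x.view(x.size(0), -1)              # flatten -> (1, 784)
w1, b1 = torch.randn(128, 784), torch.randn(128)
x = torch.relu(x @ w1.T + b1)          # fc1     -> (1, 128)
w2, b2 = torch.randn(10, 128), torch.randn(10)
x = x @ w2.T + b2                      # fc2     -> (1, 10)
assert x.shape == (1, 10)

# Note: (1, 10) is the network's *output* shape, yet the error reports it
# as the LHS of a matmul against the (784, 128) transposed fc1 weight,
# which suggests the importer wires the linear operands in the wrong order.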

Environment

Python 3.11
TVM v0.22.0

Steps to reproduce

import torch
import os
os.environ['TVM_LIBRARY_PATH'] = '/home/guan/dev/tvm/build'
import tvm as t
from tvm.relax.frontend.torch import from_exported_program

exported_program = torch.export.load("model.pt2")
mod = from_exported_program(exported_program)
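
As a sanity check (a diagnostic sketch using standard torch.export APIs, not part of the original report), the reloaded archive can be run directly in PyTorch to confirm that the saved graph itself is intact:

import torch

# Run the reloaded exported program directly in PyTorch. If this prints
# torch.Size([1, 10]), the archive round-trips correctly and the failure
# is isolated to the TVM frontend.
ep = torch.export.load("model.pt2")
out = ep.module()(torch.randn(1, 1, 28, 28))
print(out.shape)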

The model was created with the following script:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
os.environ['TVM_LIBRARY_PATH'] = '/home/guan/dev/tvm/build'
from tvm.relax.frontend.torch import from_exported_program

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

input_size = 28 * 28
num_classes = 10

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("--- 开始训练 (1 Epoch) ---")

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 1
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)

model.cpu()
model.eval()
example_args = (torch.randn(1, 1, 28, 28).to(torch.device("cpu")),)

exported_program = torch.export.export(model, example_args)
output_path = "model.pt2"
torch.export.save(exported_program, output_path)

mod = from_exported_program(exported_program)
print(mod)
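
To narrow down where the operand order goes wrong, the fx graph that the TVM frontend walks can be dumped with standard torch.export/torch.fx attributes (a diagnostic sketch, not from the original report):

# Print every node the importer will visit; the target and args of the
# linear/addmm calls show which tensors feed each matmul.
for node in exported_program.graph.nodes:
    print(node.op, node.name, node.target, node.args)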

Triage


  • needs-triage
