Skip to content

Commit 5203272

Browse files
hzfanapivovarov
authored andcommitted
[Fix][Frontend][TOPI] minor bugs (apache#8622)
* fix * fix * lint
1 parent 0e4edef commit 5203272

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

include/tvm/topi/detail/ravel_unravel.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ using namespace tvm::te;
4444
*/
4545
inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
4646
ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
47-
ICHECK_GT(indices.size(), 0) << "indices must not be empty";
47+
if (indices.size() == 0U) {
48+
return 0;
49+
}
4850
PrimExpr idx;
4951
for (size_t i = 0; i < indices.size(); ++i) {
5052
if (i == 0) {

python/tvm/relay/frontend/pytorch.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,16 @@ def linear(self, inputs, input_types):
14441444
# 0 - input
14451445
# 1 - weight
14461446
bias = inputs[2]
1447-
mm_out = self.matmul(inputs[:2], input_types[:2])
1447+
a_shape = self.infer_shape_with_prelude(inputs[0])
1448+
b_shape = self.infer_shape_with_prelude(inputs[1])
1449+
if len(a_shape) == 2 and len(b_shape) == 2:
1450+
mm_out = _op.nn.dense(inputs[0], inputs[1])
1451+
elif len(b_shape) == 1:
1452+
mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2])
1453+
else:
1454+
mm_out = self.matmul(
1455+
[inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
1456+
)
14481457
if isinstance(bias, _expr.Expr):
14491458
bias_ndims = len(self.infer_shape_with_prelude(bias))
14501459
if bias_ndims == 1:

tests/python/frontend/pytorch/test_forward.py

+6
Original file line numberDiff line numberDiff line change
@@ -1569,8 +1569,10 @@ def forward(self, input, weight):
15691569
return F.linear(input, weight)
15701570

15711571
input2d = torch.rand([2, 2]).float()
1572+
input3d = torch.rand([4, 3, 2]).float()
15721573
weight1d = torch.rand([2]).float()
15731574
weight2d = torch.rand([2, 2]).float()
1575+
weight3x2 = torch.rand([3, 2]).float()
15741576
bias1d = torch.rand([2]).float()
15751577
bias2d = torch.rand([2, 2]).float()
15761578
# 2D input, 2D weight, 1D bias
@@ -1579,9 +1581,12 @@ def forward(self, input, weight):
15791581
verify_model(Linear(), input_data=[input2d, weight2d, bias2d])
15801582
# 2D input, 2D weight, no bias
15811583
verify_model(LinearNoBias(), input_data=[input2d, weight2d])
1584+
verify_model(LinearNoBias(), input_data=[input2d, weight3x2])
15821585
# 2D input, 1D weight, 1D bias is not supported by torch.linear()
15831586
# 2D input, 1D weight, no bias
15841587
verify_model(LinearNoBias(), input_data=[input2d, weight1d])
1588+
# 3D input, 2D weight, no bias
1589+
verify_model(LinearNoBias(), input_data=[input3d, weight3x2])
15851590
# TODO: Add the following cases when matmul(1D, _) is supported by TVM
15861591
# 1D input, 2D weight, 1D bias
15871592
# 1D input, 2D weight, no bias
@@ -3939,6 +3944,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
39393944
test_forward_logsoftmax()
39403945
test_forward_sigmoid()
39413946
test_forward_dense()
3947+
test_forward_linear()
39423948
test_forward_avgpool1d()
39433949
test_forward_avgpool2d()
39443950
test_forward_avgpool3d()

tests/python/relay/test_op_level3.py

+1
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def verify_reshape(shape, newshape, oshape):
293293
verify_reshape((2, 3, 4), (-3, -2), (6, 4))
294294
verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
295295
verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
296+
verify_reshape((1,), (), ())
296297

297298

298299
def test_reshape_fail():

0 commit comments

Comments
 (0)