Skip to content

Commit 9b6b273

Browse files
hseok-oh and zetwhite authored
[onert] Support constant input BatchMatMul (#15033)
This commit adds support for constant input BatchMatMul. It includes BatchMatMul constant input tests.

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
Co-authored-by: SeungHui Youn <61981457+zetwhite@users.noreply.github.com>
1 parent 659a05f commit 9b6b273

5 files changed

Lines changed: 53 additions & 9 deletions

File tree

runtime/compute/cker/include/cker/operation/BatchMatMul.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class BatchMatMul
4444
/**
4545
* @brief Prepare temporary area for calculation
4646
*/
47-
void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y)
47+
void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y,
48+
bool rhs_const)
4849
{
4950
if (adj_x)
5051
{
@@ -75,18 +76,19 @@ class BatchMatMul
7576

7677
_temp_rhs.resize(_temp_rhs_shape.FlatSize());
7778
}
79+
80+
_rhs_constant = rhs_const;
7881
}
7982

8083
void operator()(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape,
8184
const float *rhs_data, bool adj_x, bool adj_y, const Shape & /*output_shape*/,
8285
float *output_data)
8386
{
84-
// Assume lhs and rhs is not constant
85-
// TODO Handle constant input
86-
87-
if (!adj_y)
87+
// Don't need transpose if rhs is constant and already transposed
88+
if (!adj_y && !(_rhs_constant && _rhs_transposed))
8889
{
8990
transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
91+
_rhs_transposed = true;
9092
}
9193

9294
if (adj_x)
@@ -144,6 +146,8 @@ class BatchMatMul
144146
Shape _temp_lhs_shape;
145147
std::vector<float> _temp_rhs;
146148
Shape _temp_rhs_shape;
149+
bool _rhs_constant = false;
150+
bool _rhs_transposed = false;
147151
};
148152

149153
} // namespace cker

runtime/compute/cker/include/cker/operation/Einsum.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,8 @@ class Einsum
903903

904904
// LaunchBatchMatMul::Launch(lhs, rhs, adj_x, adj_y, bcast, &output_reshaped);
905905
BatchMatMul batchMatMul;
906-
batchMatMul.prepare(lhs.shape, rhs.shape, adj_x, adj_y);
906+
// Set rhs is not constant: don't use optimization
907+
batchMatMul.prepare(lhs.shape, rhs.shape, adj_x, adj_y, false);
907908
batchMatMul(lhs.shape, lhs.base<float>(), rhs.shape, rhs.base<float>(), adj_x, adj_y,
908909
output_reshaped.shape, output_reshaped.base<float>());
909910
}

runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void BatchMatMulLayer::batchMatMulFloat32()
3939

4040
// TODO implement for constant input
4141

42-
batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
42+
batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y, _rhs->is_constant());
4343
batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
4444
_adj_y, output_shape, getBuffer<float>(_output));
4545
}

runtime/onert/core/src/ir/OperationValidator.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,10 @@ void OperationValidator::visit(const operation::BatchMatMul &node)
117117
const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
118118
const auto output_index(node.getOutputs().at(0));
119119

120-
// Constant lhs and rhs is not implemented yet
121-
OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
120+
// RHS can be constant, but LHS is not constant
121+
// If one of inputs is constant, it must be RHS
122+
// If two inputs are constant, BatchMatMul is optimized into constant by compiler
123+
OP_REQUIRES(!isConstant(lhs_index));
122124

123125
// Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
124126
OP_REQUIRES(isValidType(

runtime/tests/nnfw_api/src/GenModelTests/one_op_tests/BatchMatMul.test.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,40 @@ TEST_F(GenModelTest, neg_OneOp_BatchMatMul_InvalidType)
4848

4949
SUCCEED();
5050
}
51+
52+
TEST_F(GenModelTest, OneOp_BatchMatMul_Const)
53+
{
54+
CircleGen cgen;
55+
int lhs = cgen.addTensor({{1, 2, 3}, circle::TensorType::TensorType_FLOAT32});
56+
std::vector<float> const_data{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
57+
uint32_t rhs_buf = cgen.addBuffer(const_data);
58+
int rhs = cgen.addTensor({{1, 3, 4}, circle::TensorType::TensorType_FLOAT32, rhs_buf});
59+
int out = cgen.addTensor({{1, 2, 4}, circle::TensorType::TensorType_FLOAT32});
60+
cgen.addOperatorBatchMatMul({{lhs, rhs}, {out}}, false, false);
61+
cgen.setInputsAndOutputs({lhs}, {out});
62+
_context = std::make_unique<GenModelTestContext>(cgen.finish());
63+
_context->addTestCase(TestCaseData{}
64+
.addInput<float>({1, 2, 3, 4, 5, 6})
65+
.addOutput<float>({74, 80, 86, 92, 173, 188, 203, 218}));
66+
_context->setBackends({"cpu"});
67+
68+
SUCCEED();
69+
}
70+
71+
TEST_F(GenModelTest, neg_OneOp_BatchMatMul_InvalidConst)
72+
{
73+
// LHS constant is not allowed
74+
CircleGen cgen;
75+
std::vector<float> const_data{1, 2, 3, 4, 5, 6};
76+
uint32_t lhs_buf = cgen.addBuffer(const_data);
77+
int lhs = cgen.addTensor({{1, 2, 3}, circle::TensorType::TensorType_FLOAT32, lhs_buf});
78+
int rhs = cgen.addTensor({{1, 3, 4}, circle::TensorType::TensorType_FLOAT32});
79+
int out = cgen.addTensor({{1, 2, 4}, circle::TensorType::TensorType_FLOAT32});
80+
cgen.addOperatorBatchMatMul({{lhs, rhs}, {out}}, false, false);
81+
cgen.setInputsAndOutputs({lhs}, {out});
82+
_context = std::make_unique<GenModelTestContext>(cgen.finish());
83+
_context->setBackends({"cpu"});
84+
_context->expectFailModelLoad();
85+
86+
SUCCEED();
87+
}

0 commit comments

Comments
 (0)