Skip to content

Commit e6acf2a

Browse files
committed
[circle-mlir/dialect] Enable TransposeOp IR
This will enable TransposeOp IR.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <saehie.park@gmail.com>
1 parent 2f0f776 commit e6acf2a

4 files changed

Lines changed: 294 additions & 0 deletions

File tree

circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,4 +418,41 @@ def CIR_NoValueOp : Op<CIR_Dialect, "no_value", [ConstantLike, Pure]> {
418418
}];
419419
}
420420

421+
// Circle "transpose" op: permutes the dimensions of $input according to the
// 1-D $perm tensor. Traits restrict input rank to at most 5 and perm rank to
// exactly 1; shape inference is provided via CIR_ShapeInferenceOpInterface.
def CIR_TransposeOp : CIR_Op<"transpose", [
    Pure,
    DeclareOpInterfaceMethods<CIR_ShapeInferenceOpInterface>,
    CIR_OperandHasRankAtMost<0, 5>,
    CIR_OperandHasRank<1, 1>,
    PredOpTrait<"input and output must have same element type", CIR_TCresVTEtIsSameAsOp<0, 0>>/*,
    SameOperandsAndResultsScale*/]> {
  let summary = "Transpose operator";

  let description = [{
    Returns the Transpose of x
  }];

  // Quantized element types are commented out for now — see the quantization
  // TODOs in the op implementation.
  let arguments = (ins
    CIR_TensorOf<[I32, F32, I8, UI8, /*QI8, QUI8, CIR_Quint8,*/ I1, I64/*, QI16*/]>:$input,
    CIR_TensorOf<[I32]>:$perm
  );

  let results = (outs
    CIR_TensorOf<[I32, F32, I8, UI8, /*QI8, QUI8, CIR_Quint8,*/ I1, I64/*, QI16*/]>:$output
  );

  // C++ verifier (TransposeOp::verify) and folder (TransposeOp::fold) are
  // implemented in ops/TransposeOp.h.
  let hasVerifier = 1;

  let hasFolder = 1;

  // Delegates result-type inference to BuildTransposeOp.
  let builders = [
    OpBuilder<(ins "Value":$input, "Value":$perm),
    [{ BuildTransposeOp(&$_builder, $_state, input, perm); }]>
  ];

  let extraClassDeclaration = [{
    // Quantized axes are verified in the Verify function.
    bool RequiredSameQuantizedAxes() { return false; }
  }];
}
457+
421458
#endif // CIRCLE_OPS

circle-mlir/circle-mlir/lib/dialect/src/CircleDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ void ConstBytesAttr::print(mlir::AsmPrinter &printer) const
446446
#include "ops/ConstOp.h"
447447
#include "ops/CustomOp.h"
448448
#include "ops/NoValueOp.h"
449+
#include "ops/TransposeOp.h"
449450

450451
#include "mlir/CircleOpsDialect.cc.inc"
451452
#include "mlir/CircleOpsEnums.cc.inc"

circle-mlir/circle-mlir/lib/dialect/src/ShapeInference.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,57 @@ void CustomOp::inferShapes()
149149
}
150150
}
151151

152+
//===----------------------------------------------------------------------===//
153+
// TransposeOp
154+
//===----------------------------------------------------------------------===//
155+
156+
void TransposeOp::inferShapes()
157+
{
158+
TransposeOp op = *this;
159+
auto output_type = op.getOutput().getType().cast<ShapedType>();
160+
if (output_type.hasStaticShape())
161+
return;
162+
163+
auto input_type = op.getInput().getType().cast<ShapedType>();
164+
auto perm_type = op.getPerm().getType().cast<ShapedType>();
165+
166+
if (input_type.hasStaticShape() && perm_type.hasStaticShape())
167+
{
168+
if (perm_type.getNumElements() != input_type.getRank())
169+
{
170+
return;
171+
}
172+
}
173+
174+
mlir::DenseIntElementsAttr perm;
175+
if (!matchPattern(op.getPerm(), m_Constant(&perm)))
176+
{
177+
return;
178+
}
179+
180+
llvm::SmallVector<int64_t, 4> perm_list;
181+
for (const auto &perm_element : perm.getValues<APInt>())
182+
{
183+
const int64_t val = perm_element.getSExtValue();
184+
perm_list.push_back(val);
185+
}
186+
187+
// Get transposed shape and set it to the output type
188+
if (input_type.hasStaticShape() && !output_type.hasStaticShape())
189+
{
190+
llvm::SmallVector<int64_t, 4> transposed_shape;
191+
for (int64_t axis : perm_list)
192+
{
193+
transposed_shape.push_back(input_type.getDimSize(axis));
194+
}
195+
196+
dumpShape<TransposeOp>(op, transposed_shape);
197+
198+
auto inferred_type =
199+
mlir::Circle::GetTypeFromTensorShape(transposed_shape, input_type.getElementType());
200+
getResult().setType(inferred_type);
201+
}
202+
}
203+
152204
} // namespace Circle
153205
} // namespace mlir
circle-mlir/circle-mlir/lib/dialect/src/ops/TransposeOp.h

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
// from tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
19+
20+
#ifndef __CIRCLE_MLIR_DIALECT_OPS_TRANSPOSE_OP_H__
21+
#define __CIRCLE_MLIR_DIALECT_OPS_TRANSPOSE_OP_H__
22+
23+
#include "circle-mlir/dialect/CircleDialect.h"
24+
25+
namespace mlir
26+
{
27+
namespace Circle
28+
{
29+
30+
//===----------------------------------------------------------------------===//
31+
// TransposeOp
32+
//===----------------------------------------------------------------------===//
33+
34+
namespace
35+
{
36+
37+
// Computes the permutation of a constant `input_tensor` according to `perm`.
// The function recursively traverses the dimensions of the output tensor in
// a row-major order and writes the value in the output tensor into
// `new_values`.
// `input_indices` is scratch space holding the current multi-dimensional
// index into the input; one slot per dimension, updated in place during the
// recursion.
void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
                        ArrayRef<int64_t> output_shape, int num_dimensions, int output_axis,
                        std::vector<uint64_t> *input_indices, std::vector<Attribute> *new_values)
{
  // Refer to the implementation of `Transpose` function in
  // tensorflow/lite/kernels/internal/reference/reference_ops.h
  assert(output_axis < num_dimensions);
  // Output axis `output_axis` reads from input axis `perm[output_axis]`.
  const int input_axis = perm[output_axis];
  for (int i = 0; i < output_shape[output_axis]; ++i)
  {
    // Update the input indices on `input_axis`.
    input_indices->at(input_axis) = i;
    // Write the value from `input_tensor` if it is the last axis or
    // recurse into the next axis.
    const bool is_last_axis = output_axis == num_dimensions - 1;
    if (is_last_axis)
    {
      new_values->push_back(input_tensor.getValues<Attribute>()[*input_indices]);
    }
    else
    {
      ComputePermutation(input_tensor, perm, output_shape, num_dimensions, output_axis + 1,
                         input_indices, new_values);
    }
  }
}
67+
68+
} // namespace
69+
70+
// Constant-folds the transpose when both the input tensor and the permutation
// are constant attributes; returns nullptr (no fold) otherwise.
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor)
{
  auto operands = adaptor.getOperands();
  assert(operands.size() == 2);
  auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  if (!input_tensor || !perm_tensor)
    return nullptr;

  // Do not try to fold elements attr of a quant type because
  // DenseElementsAttr does not support it.
  if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
    return nullptr;

  // Folding assumes verified IR: perm is 1-D with one entry per input dim.
  assert(perm_tensor.getShapedType().getRank() == 1);
  const int num_dimensions = input_tensor.getShapedType().getRank();
  assert(perm_tensor.getShapedType().getNumElements() == num_dimensions);

  ArrayRef<int64_t> input_shape = input_tensor.getShapedType().getShape();
  auto output_type = getType().cast<ShapedType>();

  // Materialize the permutation and derive the output shape from it.
  SmallVector<int32_t, 4> perm;
  SmallVector<int64_t, 4> output_shape;
  for (int i = 0; i < num_dimensions; ++i)
  {
    perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
    output_shape.push_back(input_shape[perm[i]]);

    // Check that the derived output shape matches the static shape.
    assert(!output_type.hasStaticShape() || output_type.getShape()[i] == output_shape[i]);
  }

  // Walk the output in row-major order, copying the matching input elements.
  std::vector<Attribute> new_values;
  new_values.reserve(input_tensor.getShapedType().getNumElements());
  std::vector<uint64_t> input_indices(num_dimensions);
  ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                     /*output_axis=*/0, &input_indices, &new_values);
  auto result_type =
    mlir::Circle::GetTypeFromTensorShape(output_shape, output_type.getElementType());
  return DenseElementsAttr::get(result_type, new_values);
}
111+
112+
// Verifies TransposeOp:
//  - perm element count equals input rank (when both shapes are static)
//  - each constant perm value is unique and within [0, rank)
//  - the static output shape equals the permuted input shape
// Checks that need a constant perm are skipped when perm is not constant.
mlir::LogicalResult TransposeOp::verify()
{
  TransposeOp op = *this;
  auto input_type = op.getInput().getType().cast<ShapedType>();
  auto perm_type = op.getPerm().getType().cast<ShapedType>();
  auto output_type = op.getOutput().getType().cast<ShapedType>();
  if (input_type.hasStaticShape() && perm_type.hasStaticShape())
  {
    if (perm_type.getNumElements() != input_type.getRank())
    {
      return op.emitOpError("perm tensor elements size is not equal to input tensor rank");
    }
  }

  // The remaining checks require a constant perm.
  mlir::DenseIntElementsAttr perm;
  if (!matchPattern(op.getPerm(), m_Constant(&perm)))
  {
    return success();
  }

  int index = 0;
  llvm::SmallVector<int64_t, 4> axes;
  for (const auto &axis_int : perm.getValues<APInt>())
  {
    const int64_t axis = axis_int.getSExtValue();
    if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank()))
    {
      return op.emitOpError(llvm::formatv("perm[{0}] must be in [0, rank)", index));
    }
    // std::find stops at the first match; std::count would always scan the
    // whole vector even after a duplicate is found.
    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
    {
      return op.emitOpError(llvm::formatv("perm[{0}] cannot have duplicated axis", index));
    }
    axes.push_back(axis);
    index++;
  }

  // With fully static shapes, the output must be the permuted input shape.
  if (input_type.hasStaticShape() && output_type.hasStaticShape())
  {
    llvm::SmallVector<int64_t, 4> transposed_shape;
    for (int64_t axis : axes)
    {
      transposed_shape.push_back(input_type.getDimSize(axis));
    }
    auto expected_output_type =
      mlir::Circle::GetTypeFromTensorShape(transposed_shape, input_type.getElementType());
    if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type)))
    {
      return op.emitOpError(
        llvm::formatv("expect output type {0}, got {1}", expected_output_type, output_type));
    }
  }

  // TODO enable quantization

  return success();
}
169+
170+
// Builds a TransposeOp with an inferred result type: a static shape when the
// input is ranked and perm is a non-empty constant, an unranked tensor of the
// input's element type otherwise.
static void BuildTransposeOp(OpBuilder *builder, OperationState &result, Value input, Value perm)
{
  // Output size is only known if input is ranked and perm is a constant.
  auto input_type = input.getType().cast<TensorType>();
  mlir::DenseIntElementsAttr perm_const;
  if (!input_type.hasRank() || !matchPattern(perm, m_Constant(&perm_const)) || perm_const.empty())
  {
    TransposeOp::build(*builder, result, UnrankedTensorType::get(input_type.getElementType()),
                       input, perm);
    return;
  }

  const auto perm_value_it = perm_const.value_begin<APInt>();

  const ArrayRef<int64_t> input_shape = input_type.getShape();
  SmallVector<int64_t, 4> output_shape(input_shape.size());

  // size_t index avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < output_shape.size(); ++i)
  {
    const APInt perm_val = perm_value_it[i];
    output_shape[i] = input_shape[perm_val.getSExtValue()];
  }

  auto element_type = input_type.getElementType();

  // TODO enable quantization

  TransposeOp::build(*builder, result,
                     mlir::Circle::GetTypeFromTensorShape(output_shape, element_type), input, perm);
}
200+
201+
} // namespace Circle
202+
} // namespace mlir
203+
204+
#endif // __CIRCLE_MLIR_DIALECT_OPS_TRANSPOSE_OP_H__

0 commit comments

Comments
 (0)