Skip to content

Commit 39d8e44

Browse files
committed
[TOSA] Add Tosa_Shape type and ConstShapeOp
Adds: 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values) 2. const_shape operator 3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes 4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator) 5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks 5. changed TileOp's multiples from attribute to input, of !tosa.shape type. 6. add folder for tosa ConstShape operator Signed-off-by: Jerry Ge <[email protected]> Signed-off-by: Tai Ly <[email protected]> Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
1 parent 749bdc8 commit 39d8e44

File tree

18 files changed

+441
-32
lines changed

18 files changed

+441
-32
lines changed

mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
33
add_mlir_interface(TosaInterfaces)
44

55
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
6+
mlir_tablegen(TosaOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
7+
mlir_tablegen(TosaOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
68
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
79
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
810
add_public_tablegen_target(MLIRTosaAttributesIncGen)
911

1012
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
1113
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
1214
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
13-

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
4545
let cppNamespace = "mlir::tosa";
4646
let hasConstantMaterializer = 1;
4747
let useDefaultAttributePrinterParser = 1;
48+
let useDefaultTypePrinterParser = 1;
4849
}
4950

5051
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

+41
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,55 @@ template <typename ConcreteType>
9090
class TosaElementwiseOperator
9191
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
9292

93+
LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
94+
/// This class verifies that tosa shape operands are compile time resolvable
95+
template <typename ConcreteType>
96+
class TosaResolvableShapeOperands
97+
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
98+
public:
99+
static LogicalResult verifyTrait(Operation *op) {
100+
return verifyTosaResolvableShapeOperands(op);
101+
}
102+
};
103+
104+
LogicalResult verifyTosaShapeOperator(Operation *op);
105+
/// This class indicates that op operates on tosa shape types
106+
template <typename ConcreteType>
107+
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
108+
public:
109+
static LogicalResult verifyTrait(Operation *op) {
110+
return verifyTosaShapeOperator(op);
111+
}
112+
};
113+
114+
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
115+
/// This class indicates that op operates on tosa shape types
116+
template <typename ConcreteType>
117+
class TosaShapeOperatorWithSameRanks
118+
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
119+
public:
120+
static LogicalResult verifyTrait(Operation *op) {
121+
return verifyTosaShapeOperatorWithSameRanks(op);
122+
}
123+
};
124+
93125
} // namespace tosa
94126
} // namespace OpTrait
95127

128+
namespace tosa {
129+
130+
bool isa_tosa_shape_type(mlir::Type t);
131+
132+
} // namespace tosa
133+
96134
} // namespace mlir
97135

98136
#define GET_ATTRDEF_CLASSES
99137
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
100138

139+
#define GET_TYPEDEF_CLASSES
140+
#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc"
141+
101142
#define GET_OP_CLASSES
102143
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
103144

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
2323

2424
include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
2525
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
26+
include "mlir/Dialect/Tosa/IR/TosaTypes.td"
2627

2728
//===----------------------------------------------------------------------===//
2829
// TOSA Spec Section 2.2
@@ -1689,12 +1690,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
16891690

16901691
let arguments = (ins
16911692
Tosa_Tensor:$input1,
1692-
DenseI64ArrayAttr:$multiples);
1693+
Tosa_Shape:$multiples);
16931694

16941695
let results = (outs
16951696
Tosa_Tensor:$output
16961697
);
16971698

1699+
let extraClassDeclaration = [{
1700+
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
1701+
}];
1702+
16981703
let hasFolder = 1;
16991704
let hasVerifier = 1;
17001705
}
@@ -2106,4 +2111,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
21062111

21072112
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
21082113

2114+
include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
2115+
21092116
#endif // TOSA_OPS
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines shape operators for the TOSA dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef TOSA_SHAPE_OPS
14+
#define TOSA_SHAPE_OPS
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
include "mlir/Interfaces/SideEffectInterfaces.td"
19+
include "mlir/Interfaces/InferTypeOpInterface.td"
20+
include "mlir/Interfaces/LoopLikeInterface.td"
21+
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
22+
23+
include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
24+
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
25+
include "mlir/Dialect/Tosa/IR/TosaTypes.td"
26+
27+
// Op trait: operator has operands and results with TOSA shape type
28+
def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
29+
let cppNamespace = "mlir::OpTrait::tosa";
30+
}
31+
32+
class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
33+
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
34+
35+
let assemblyFormat =
36+
"operands attr-dict `:` functional-type(operands, results)";
37+
38+
let hasFolder = 1;
39+
}
40+
41+
// op trait: shape operator has same ranks for operands and results
42+
def TosaShapeOperatorWithSameRanks : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
43+
let cppNamespace = "mlir::OpTrait::tosa";
44+
}
45+
46+
class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
47+
: Tosa_ShapeOp<mnemonic, !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
48+
}
49+
50+
//===----------------------------------------------------------------------===//
51+
// Operator: ConstShape
52+
//===----------------------------------------------------------------------===//
53+
def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
54+
let summary = "Constant Shape op.";
55+
56+
let description = [{
57+
A node containing constant data for use as the input to an shape operation. May
58+
hold data only in index data type.
59+
60+
Example:
61+
62+
```mlir
63+
// Generic form
64+
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
65+
```
66+
}];
67+
68+
let arguments = (ins
69+
IndexElementsAttr:$value
70+
);
71+
72+
let results = (outs
73+
Tosa_Shape:$output
74+
);
75+
76+
let hasVerifier = 1;
77+
}
78+
79+
#endif // TOSA_SHAPE_OPS
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===-- TosaTypes.td - TOSA type definitions ---------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines the type definitions for the TOSA dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef TOSA_TYPES
14+
#define TOSA_TYPES
15+
16+
include "mlir/IR/AttrTypeBase.td"
17+
include "mlir/IR/OpBase.td"
18+
19+
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
20+
21+
//===----------------------------------------------------------------------===//
22+
// Tosa Type Definitions.
23+
//===----------------------------------------------------------------------===//
24+
25+
// The base class for Tosa dialect types.
26+
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
27+
: TypeDef<Tosa_Dialect, name, traits> {
28+
let mnemonic = typeMnemonic;
29+
}
30+
31+
//===----------------------------------------------------------------------===//
32+
// ShapeType
33+
//===----------------------------------------------------------------------===//
34+
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
35+
let summary = "Shape with static rank and Index element type";
36+
let description = [{
37+
Syntax:
38+
39+
```
40+
shape-type ::= `shape` `<` rank `>`
41+
```
42+
Values with shape type represents a shape with a fixed rank and a list of dimensions.
43+
Rank must be zero or a positive integer.
44+
Each dimension is represented by the builtin Index type.
45+
46+
Examples:
47+
48+
```mlir
49+
// Shape with rank of four, for example, [1, 1, 8, 16]:
50+
!tosa.shape<4>
51+
52+
// Shape with rank of one, for example, [16]:
53+
!tosa.shape<1>
54+
55+
// Shape with rank zero, for example, [] (i.e., shape of scalar values):
56+
!tosa.shape<0>
57+
```
58+
}];
59+
let parameters = (ins
60+
"int":$rank
61+
);
62+
let builders = [
63+
TypeBuilder<(ins "int":$rank)>
64+
];
65+
let assemblyFormat = "`<` $rank `>`";
66+
67+
let genVerifyDecl = 1;
68+
}
69+
70+
def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
71+
72+
// Whether a Tosa Shape type has a rank equal to the specified rank.
73+
class IsTosaShapeOfRankPred<int rank> : And<[
74+
IsTosaShapeType,
75+
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
76+
]>;
77+
78+
class TosaShapeOfRank<int rank> :
79+
Type<IsTosaShapeOfRankPred<rank>,
80+
"Tosa shape type of rank " # rank
81+
>;
82+
83+
def Rank1TosaShape : TosaShapeOfRank<1>;
84+
def Rank2TosaShape : TosaShapeOfRank<2>;
85+
def Rank4TosaShape : TosaShapeOfRank<4>;
86+
87+
#endif // TOSA_TYPES

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
18861886
auto elementTy = inputTy.getElementType();
18871887
int64_t rank = inputTy.getRank();
18881888

1889-
ArrayRef<int64_t> multiples = op.getMultiples();
1889+
SmallVector<int64_t> multiples;
1890+
if (failed(op.getConstantMultiples(multiples)))
1891+
return failure();
18901892

18911893
// Broadcast the newly added dimensions to their appropriate multiple.
18921894
SmallVector<int64_t, 2> genericShape;

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
5555
target.addLegalOp<tosa::ApplyScaleOp>();
5656
target.addLegalOp<tosa::IfOp>();
5757
target.addLegalOp<tosa::ConstOp>();
58+
target.addLegalOp<tosa::ConstShapeOp>();
5859
target.addLegalOp<tosa::WhileOp>();
5960
target.addLegalOp<tosa::ConcatOp>();
6061
target.addLegalOp<tosa::SliceOp>();

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

+16-3
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
808808

809809
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
810810

811+
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
812+
811813
#define REDUCE_FOLDER(OP) \
812814
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
813815
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
985987
}
986988

987989
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
988-
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
989-
if (allOnes && getInput1().getType() == getType())
990-
return getInput1();
990+
if (getInput1().getType() == getType()) {
991+
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
992+
adaptor.getMultiples())) {
993+
if (multiples.isSplat() &&
994+
multiples.getSplatValue<APInt>().getSExtValue() == 1)
995+
return getInput1();
996+
if (auto int_array_attr =
997+
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
998+
if (llvm::all_of(int_array_attr.getValues<APInt>(),
999+
[](APInt v) { return v.getSExtValue() == 1; }))
1000+
return getInput1();
1001+
}
1002+
}
1003+
}
9911004
return {};
9921005
}
9931006

0 commit comments

Comments
 (0)