@@ -1939,7 +1939,11 @@ DenseElementsAttr tile(DenseElementsAttr inputValues, ShapedType outputType) {
19391939}
19401940
19411941struct TosaFoldConstantTile : public TosaFoldConstantBase <tosa::TileOp> {
1942- using TosaFoldConstantBase::TosaFoldConstantBase;
1942+
1943+ TosaFoldConstantTile (MLIRContext *ctxt, bool foldSplatOrSingleUseOnly,
1944+ int maxSizeToFold)
1945+ : TosaFoldConstantBase<tosa::TileOp>(ctxt, foldSplatOrSingleUseOnly),
1946+ maxSizeToFold (maxSizeToFold) {}
19431947
19441948 LogicalResult matchAndRewrite (tosa::TileOp op,
19451949 PatternRewriter &rewriter) const override {
@@ -1958,10 +1962,22 @@ struct TosaFoldConstantTile : public TosaFoldConstantBase<tosa::TileOp> {
19581962 foldSplatOrSingleUseOnly)
19591963 return failure ();
19601964
1965+ assert (maxSizeToFold >= 0 && " maxSizeToFold should be non-negative" );
1966+ if (maxSizeToFold > 0 ) {
1967+ if (!outputType.hasStaticShape ())
1968+ return failure ();
1969+ const int64_t numOfElements = outputType.getNumElements ();
1970+ if (numOfElements * (outputType.getElementTypeBitWidth () / 8 ) >
1971+ static_cast <int64_t >(maxSizeToFold))
1972+ return failure ();
1973+ }
1974+
19611975 rewriter.replaceOpWithNewOp <tosa::ConstOp>(op, outputType,
19621976 tile (inputValues, outputType));
19631977 return success ();
19641978 }
1979+
1980+ const int maxSizeToFold;
19651981};
19661982
19671983// / Getting the axes position of the element which is located
@@ -2275,7 +2291,8 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
22752291 patterns.add <TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly );
22762292 patterns.add <TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly );
22772293 if (options.enableTileFolding )
2278- patterns.add <TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly );
2294+ patterns.add <TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly ,
2295+ options.maxTileFoldSize );
22792296}
22802297
22812298void mlir::tosa::populateTosaConstantReduction (MLIRContext *ctx,
0 commit comments