Skip to content

Commit 55c08a8

Browse files
authored
Merge pull request #555 from Xilinx/bump_to_12250739
[AutoBump] Merge with fixes of 1225073 (Jan 28, needs LLVM bump) (161)
2 parents 93a08c5 + 77b7da2 commit 55c08a8

32 files changed

Lines changed: 391 additions & 195 deletions

.github/workflows/ci.yml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,11 @@ jobs:
4545
restore-keys: |
4646
build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-
4747
48-
- name: "Setting up Python"
48+
- name: "Setting up Python" # AMD: python 3.10 and not 3.11
4949
run: |
5050
sudo apt update
51-
sudo apt install software-properties-common -y
52-
sudo add-apt-repository ppa:deadsnakes/ppa -y
53-
sudo apt install python3.11 python3-pip -y
54-
sudo apt-get install python3.11-dev python3.11-venv build-essential -y
51+
sudo apt install python3.10 python3-pip -y
52+
sudo apt-get install python3.10-dev python3.10-venv build-essential -y
5553
5654
- name: Install python deps (torch-${{ matrix.torch-version }})
5755
run: |
@@ -77,10 +75,18 @@ jobs:
7775
key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }}
7876

7977
- name: Integration tests (torch-${{ matrix.torch-version }})
78+
if: ${{ matrix.torch-version == 'nightly' }}
79+
continue-on-error: true
80+
run: |
81+
bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }}
82+
83+
- name: Integration tests (torch-${{ matrix.torch-version }})
84+
if: ${{ matrix.torch-version != 'nightly' }}
8085
run: |
8186
bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }}
8287
8388
- name: Check generated sources (torch-nightly only)
8489
if: ${{ matrix.torch-version == 'nightly' }}
90+
continue-on-error: true
8591
run: |
8692
bash build_tools/ci/check_generated_sources.sh

externals/llvm-project

externals/stablehlo

Submodule stablehlo updated 143 files

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ TypedValue<RankedTensorType> transposeBy(Location loc,
131131
// Get accumulator type for AvgPool2dOp.
132132
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
133133
TypeAttr &accType);
134-
135134
} // namespace tosa
136135
} // namespace mlir
137136

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
549549
}
550550
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
551551
MLIRContext *context = op->getContext();
552-
Type floatDtype = mlir::FloatType::getF64(context);
552+
Type floatDtype = mlir::Float64Type::get(context);
553553
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
554554
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
555555
Value zero =
@@ -569,7 +569,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
569569
}
570570
if (isa<AtenLogicalNotOp>(op)) {
571571
MLIRContext *context = op->getContext();
572-
Type floatDtype = mlir::FloatType::getF64(context);
572+
Type floatDtype = mlir::Float64Type::get(context);
573573
Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
574574
Value zero =
575575
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
@@ -1028,7 +1028,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10281028
Type powType = dtype;
10291029
if (payloadArgs[0].getType().isInteger() ||
10301030
payloadArgs[1].getType().isInteger())
1031-
powType = mlir::FloatType::getF64(op->getContext());
1031+
powType = mlir::Float64Type::get(op->getContext());
10321032
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType);
10331033
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType);
10341034
auto powOp = b.create<math::PowFOp>(loc, lhs, rhs);

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
1313
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1414
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
15+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1516
#include "mlir/IR/Matchers.h"
1617
#include "mlir/Transforms/DialectConversion.h"
1718
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
@@ -2280,9 +2281,9 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
22802281
Type &resultType,
22812282
const llvm::ArrayRef<int64_t> weightShape,
22822283
Value &input, Value &weights, Value &bias,
2283-
const int64_t groups, DenseI64ArrayAttr &pads,
2284-
DenseI64ArrayAttr &strides,
2285-
DenseI64ArrayAttr &dilations) {
2284+
const int64_t groups, DenseI64ArrayAttr pads,
2285+
DenseI64ArrayAttr strides, DenseI64ArrayAttr dilations,
2286+
TypeAttr accType) {
22862287
// Set up constants outside of loop
22872288
const int64_t sizeOfSliceInput = weightShape[1];
22882289
const int64_t sizeOfSliceKernel = weightShape[0] / groups;
@@ -2312,7 +2313,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
23122313
// Create conv
23132314
Value tempConv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(
23142315
rewriter, input.getLoc(), outputType, sliceInput, sliceWeight,
2315-
sliceBias, pads, strides, dilations);
2316+
sliceBias, pads, strides, dilations, accType);
23162317
// Add value to vector
23172318
sliceValues.push_back(tempConv2D);
23182319
}
@@ -2420,6 +2421,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24202421
return rewriter.notifyMatchFailure(op,
24212422
"non-const dilation list unsupported");
24222423

2424+
TypeAttr accType;
2425+
if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy,
2426+
accType)))
2427+
return rewriter.notifyMatchFailure(
2428+
op, "failed to get accumulator type for convolution ops");
2429+
24232430
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
24242431
// Perform the necessary transformations.
24252432
std::optional<Value> nchwToNhwcTransposeConst =
@@ -2523,22 +2530,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25232530
// quantized input is i32, which gets rescaled down to quantized output range.
25242531
SmallVector<int64_t> outputShape = {transposedInputShape[0], outputHDim,
25252532
outputWDim, outputCDim};
2526-
2527-
DenseI64ArrayAttr paddingAttr = rewriter.getDenseI64ArrayAttr(padding);
2528-
DenseI64ArrayAttr strideAttr = rewriter.getDenseI64ArrayAttr(stride);
2529-
DenseI64ArrayAttr dilationAttr = rewriter.getDenseI64ArrayAttr(dilation);
2530-
25312533
Value convOpResult;
25322534
if (groups == 1) {
25332535
// full convolution
25342536
auto convOpTy =
25352537
RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy);
25362538
convOpResult =
25372539
rewriter
2538-
.create<tosa::Conv2DOp>(op->getLoc(),
2539-
getTypeConverter()->convertType(convOpTy),
2540-
transposedInput, transformedWeight, bias,
2541-
paddingAttr, strideAttr, dilationAttr)
2540+
.create<tosa::Conv2DOp>(
2541+
op->getLoc(), getTypeConverter()->convertType(convOpTy),
2542+
transposedInput, transformedWeight, bias,
2543+
rewriter.getDenseI64ArrayAttr(padding),
2544+
rewriter.getDenseI64ArrayAttr(stride),
2545+
rewriter.getDenseI64ArrayAttr(dilation), accType)
25422546
.getResult();
25432547
} else if (weightShape[1] == 1) {
25442548
// depthwise convolution
@@ -2548,14 +2552,18 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25482552
rewriter
25492553
.create<tosa::DepthwiseConv2DOp>(
25502554
op->getLoc(), getTypeConverter()->convertType(convOpTy),
2551-
transposedInput, transformedWeight, bias, paddingAttr,
2552-
strideAttr, dilationAttr)
2555+
transposedInput, transformedWeight, bias,
2556+
rewriter.getDenseI64ArrayAttr(padding),
2557+
rewriter.getDenseI64ArrayAttr(stride),
2558+
rewriter.getDenseI64ArrayAttr(dilation), accType)
25532559
.getResult();
25542560
} else {
25552561
// general group convolution
25562562
convOpResult = createConvInGroups(
25572563
rewriter, op, outputTy, weightShape, transposedInput, transformedWeight,
2558-
bias, groups, paddingAttr, strideAttr, dilationAttr);
2564+
bias, groups, rewriter.getDenseI64ArrayAttr(padding),
2565+
rewriter.getDenseI64ArrayAttr(stride),
2566+
rewriter.getDenseI64ArrayAttr(dilation), accType);
25592567
}
25602568

25612569
std::optional<Value> nhwcToNchwTransposeConst =
@@ -4103,9 +4111,11 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
41034111
}
41044112
}
41054113

4106-
auto result = rewriter.create<tosa::TileOp>(
4107-
op->getLoc(), resultType, reshapedInput,
4108-
rewriter.getDenseI64ArrayAttr(tileOpShape));
4114+
auto tileOpMultiples =
4115+
tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape);
4116+
4117+
auto result = rewriter.create<tosa::TileOp>(op->getLoc(), resultType,
4118+
reshapedInput, tileOpMultiples);
41094119

41104120
rewriter.replaceOp(op, {result.getResult()});
41114121
}
@@ -4298,9 +4308,11 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
42984308
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
42994309
rewriter.getIntegerType(32));
43004310

4311+
auto tileOpMultiples =
4312+
tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
4313+
43014314
auto expandedIndices = rewriter.create<tosa::TileOp>(
4302-
op->getLoc(), tileType, reshapedIndices.getResult(),
4303-
rewriter.getDenseI64ArrayAttr(tileShape));
4315+
op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples);
43044316

43054317
// convert torch style index and dim into tf style indices
43064318
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
@@ -4639,17 +4651,23 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
46394651
if (needsTiling) {
46404652
auto idxType =
46414653
dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
4654+
46424655
// indicesTfConcatTensors has a trailing [1] dim for the final concat.
46434656
auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
46444657
maxRankMaxDimShapeTf.push_back(1);
4658+
46454659
auto tileOpShapeTf(tileOpShape);
46464660
tileOpShapeTf.push_back(1);
4661+
46474662
auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
46484663
idxType.getElementType());
46494664
auto reshapedIdxTensor = indicesTfConcatTensors[i];
4665+
4666+
auto tileOpMultiples =
4667+
tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf);
4668+
46504669
indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
4651-
op->getLoc(), tileOutputTy, reshapedIdxTensor,
4652-
rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
4670+
op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples);
46534671
}
46544672

46554673
// Every index tensor now has the same rank and shape
@@ -6220,12 +6238,14 @@ class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
62206238
op->getLoc(), fillValueMatchedInputRankType, fillValue,
62216239
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
62226240

6241+
auto tileOpMultiples =
6242+
tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape());
6243+
62236244
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
62246245
op->getLoc(),
62256246
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
62266247
fillValueElemTy),
6227-
fillValueMatchedInputRankTensor.getResult(),
6228-
makeShapeTorchCompatible(outType.getShape()));
6248+
fillValueMatchedInputRankTensor.getResult(), tileOpMultiples);
62296249
} else {
62306250
if (failed(torchScalarToTosaTensor(
62316251
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
@@ -6376,7 +6396,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
63766396
}
63776397

63786398
DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
6379-
RankedTensorType::get({rank, 2}, rewriter.getI64Type()),
6399+
RankedTensorType::get({2 * rank}, rewriter.getI64Type()),
63806400
translatePadsList);
63816401

63826402
Value padsList1 = rewriter.create<mlir::tosa::ConstOp>(
@@ -8033,9 +8053,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
80338053
resultType.getElementType()),
80348054
self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
80358055

8056+
auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
8057+
resultShapeIndex0Replaced);
8058+
80368059
auto selfTiled = rewriter.create<tosa::TileOp>(
8037-
op->getLoc(), resultType, selfReshaped.getResult(),
8038-
rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
8060+
op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples);
80398061

80408062
// Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
80418063
auto vec2Reshaped = rewriter.create<tosa::ReshapeOp>(
@@ -8044,9 +8066,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
80448066
resultType.getElementType()),
80458067
vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
80468068

8069+
auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
8070+
resultShapeIndex1Replaced);
8071+
80478072
auto vec2Tiled = rewriter.create<tosa::TileOp>(
8048-
op->getLoc(), resultType, vec2Reshaped.getResult(),
8049-
rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
8073+
op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples);
80508074

80518075
auto result =
80528076
tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(),

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
11+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1112
#include "torch-mlir/Conversion/Utils/Utils.h"
1213
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1314

@@ -566,11 +567,12 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
566567

567568
// [0] -> [0,0,0]
568569
SmallVector<int64_t, 1> tileShape({W}); // {3}
570+
auto tileOpMultiples =
571+
tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
569572
auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
570573
rewriter, op->getLoc(),
571574
GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
572-
tosaFillValuesOneReshapeOp.getResult(),
573-
rewriter.getDenseI64ArrayAttr(tileShape));
575+
tosaFillValuesOneReshapeOp.getResult(), tileOpMultiples);
574576

575577
// [0,0,0] -> [[0,0,0]]
576578
SmallVector<int64_t, 2> newTosaFillValuesShape({N, W}); // {1,3}

lib/Dialect/TMTensor/Transforms/Bufferize.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern<TMTensorOp> {
121121
};
122122

123123
namespace {
124+
125+
static Value materializeToTensor(OpBuilder &builder, TensorType type,
126+
ValueRange inputs, Location loc) {
127+
assert(inputs.size() == 1);
128+
assert(isa<BaseMemRefType>(inputs[0].getType()));
129+
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
130+
}
131+
124132
/// Converts TMTensor operations that work on tensor-type operands or results to
125133
/// work on buffers.
126134
struct TMTensorBufferizePass
@@ -133,7 +141,47 @@ struct TMTensorBufferizePass
133141
void runOnOperation() override {
134142
MLIRContext &context = getContext();
135143
ConversionTarget target(context);
136-
bufferization::BufferizeTypeConverter typeConverter;
144+
// The `BufferizeTypeConverter` was removed upstream in
145+
// https://github.com/llvm/llvm-project/commit/2ff2e871f5e632ea493efaf4f2192f8b18a54ab1,
146+
// so its conversions and materializations are inlined here.
147+
TypeConverter typeConverter;
148+
typeConverter.addConversion([](Type type) { return type; });
149+
// Convert RankedTensorType to MemRefType.
150+
typeConverter.addConversion([](RankedTensorType type) -> Type {
151+
return MemRefType::get(type.getShape(), type.getElementType());
152+
});
153+
// Convert UnrankedTensorType to UnrankedMemRefType.
154+
typeConverter.addConversion([](UnrankedTensorType type) -> Type {
155+
return UnrankedMemRefType::get(type.getElementType(), 0);
156+
});
157+
typeConverter.addArgumentMaterialization(materializeToTensor);
158+
typeConverter.addSourceMaterialization(materializeToTensor);
159+
typeConverter.addTargetMaterialization([](OpBuilder &builder,
160+
BaseMemRefType type,
161+
ValueRange inputs,
162+
Location loc) -> Value {
163+
assert(inputs.size() == 1 && "expected exactly one input");
164+
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
165+
// MemRef to MemRef cast.
166+
assert(inputType != type && "expected different types");
167+
// Ranked to unranked casts must be explicit.
168+
auto rankedDestType = dyn_cast<MemRefType>(type);
169+
if (!rankedDestType)
170+
return nullptr;
171+
bufferization::BufferizationOptions options;
172+
options.bufferAlignment = 0;
173+
FailureOr<Value> replacement = castOrReallocMemRefValue(
174+
builder, inputs[0], rankedDestType, options);
175+
if (failed(replacement))
176+
return nullptr;
177+
return *replacement;
178+
}
179+
if (isa<TensorType>(inputs[0].getType())) {
180+
// Tensor to MemRef cast.
181+
return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
182+
}
183+
llvm_unreachable("only tensor/memref input types supported");
184+
});
137185

138186
// Mark all Standard operations legal.
139187
target.addLegalDialect<arith::ArithDialect, func::FuncDialect,

lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
110110

111111
RewritePatternSet patterns(context);
112112
patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
113-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
114-
std::move(patterns)))) {
113+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
115114
return signalPassFailure();
116115
}
117116
}

0 commit comments

Comments
 (0)