Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18969,6 +18969,31 @@ def Torch_PrimsSumOp : Torch_Op<"prims.sum", [
}];
}

// Op definition for `prims.xor_sum`: XOR-reduction of `inp` over the
// (optional) dimension list `dims`, with an optional `output_dtype`.
// NOTE(review): this lives in GeneratedTorchOps.td — it is emitted by
// torch_ods_gen.py (see the matching `emit("prims::xor_sum ...")` line in
// that script); hand edits here will be overwritten on regeneration.
def Torch_PrimsXorSumOp : Torch_Op<"prims.xor_sum", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `prims::xor_sum : (Tensor, int[]?, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$inp,
    AnyTorchOptionalListOfTorchIntType:$dims,
    AnyTorchOptionalIntType:$output_dtype
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  // Uses the shared default torch-op assembly format: 3 operands, 1 result.
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult PrimsXorSumOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void PrimsXorSumOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
AllowsTypeRefinement,
ReadOnly
Expand Down
64 changes: 63 additions & 1 deletion lib/Conversion/TorchToStablehlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
auto constType = RankedTensorType::get({}, elementTy);
DenseElementsAttr constAttr = nullptr;
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
AtenLinalgVectorNormOp>(op)) {
AtenLinalgVectorNormOp, PrimsXorSumOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
Expand Down Expand Up @@ -157,6 +157,9 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
} else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
result = stablehlo::MulOp::create(rewriter, op->getLoc(), blockArgumentTy,
*firstArgument, *secondArgument);
} else if (isa<PrimsXorSumOp>(op)) {
result = stablehlo::XorOp::create(rewriter, op->getLoc(), blockArgumentTy,
*firstArgument, *secondArgument);
} else {
op->emitError("unimplemented lowering in "
"createReduceOpWithSingleRegionOp");
Expand Down Expand Up @@ -756,6 +759,64 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
} // namespace

// PrimsXorSumOp
namespace {
// Lowers `torch.prims.xor_sum` to a stablehlo.reduce with an XOR region.
// `dims == None` (or an empty dim list) reduces over every dimension.
// Only integer element types are supported, since XOR is not defined for
// floating point.
template <>
LogicalResult ConvertAtenReductionOp<PrimsXorSumOp>::matchAndRewrite(
    PrimsXorSumOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getInp();
  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
  auto outTy =
      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  if (!inputTy || !outTy) {
    return rewriter.notifyMatchFailure(
        op, "only Tensor types supported in StableHLO");
  }

  auto inputElemTy = inputTy.getElementType();
  if (!isa<mlir::IntegerType>(inputElemTy)) {
    return rewriter.notifyMatchFailure(
        op, "XOR reduction requires integer element type");
  }

  // Collect the reduction dimensions. A `None` or empty `dims` means
  // "reduce over all dimensions" (matching prims semantics).
  SmallVector<int64_t> inputDims;
  SmallVector<int64_t> dims;
  if (failed(checkNotNone(rewriter, op, op.getDims()))) {
    inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
  } else {
    if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(inputDims))) {
      return rewriter.notifyMatchFailure(
          op, "non-const integer `dims` is not supported");
    }
    if (inputDims.empty()) {
      inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
    }
  }
  // Normalize negative dims. Fail the match on an out-of-range dim instead
  // of silently dropping it: dropping a requested dim would build a reduce
  // whose result shape disagrees with `outTy`.
  for (auto d : inputDims) {
    d = toPositiveDim(d, inputTy.getRank());
    if (!isValidDim(d, inputTy.getRank())) {
      return rewriter.notifyMatchFailure(
          op, "reduction dimension is out of range for input rank");
    }
    dims.push_back(d);
  }
  llvm::sort(dims);

  SmallVector<int64_t> reduceResultShape =
      getReduceOutputShape(inputTy.getShape(), dims);

  // Build the stablehlo.reduce; the shared helper emits an XOR region for
  // PrimsXorSumOp and a zero init value (XOR identity).
  Value reduceResult = createReduceOpWithSingleRegionOp(
      op, input,
      RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
      rewriter, /*allowNonFinites=*/false);
  if (!reduceResult) {
    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
  }

  rewriter.replaceOp(op, reduceResult);
  return success();
}
} // namespace

// AtenProdDimIntOp
namespace {
template <>
Expand Down Expand Up @@ -1003,6 +1064,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdDimIntOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(PrimsXorSumOp);
#undef INSERT_ATEN_REDUCTION_OP_PATTERN

#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)")
emit("prims::xor_sum : (Tensor, int[]?, int?) -> (Tensor)")
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")

Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/TorchToStablehlo/reduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,20 @@ func.func @torch.aten.prod.intdim_negative_dim(%arg0: !torch.vtensor<[?,?,?,?],f
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}

// -----

// Verifies that torch.prims.xor_sum over dim 0 of a 1-D si32 tensor lowers
// to a single stablehlo.reduce applying stablehlo.xor, with a zero-valued
// i32 init constant (the XOR identity).
// CHECK-LABEL:  @torch.prims.xor_sum(
// CHECK-SAME:                        %[[ARG0:.*]]: !torch.vtensor<[4],si32>) -> !torch.vtensor<[],si32> {
// CHECK:        %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4],si32> -> tensor<4xi32>
// CHECK:        %[[INIT:.*]] = stablehlo.constant dense<0> : tensor<i32>
// CHECK:        %[[REDUCE:.*]] = stablehlo.reduce(%[[INPUT]] init: %[[INIT]]) applies stablehlo.xor across dimensions = [0] : (tensor<4xi32>, tensor<i32>) -> tensor<i32>
// CHECK:        %[[OUT:.*]] = torch_c.from_builtin_tensor %[[REDUCE]] : tensor<i32> -> !torch.vtensor<[],si32>
// CHECK:        return %[[OUT]] : !torch.vtensor<[],si32>
func.func @torch.prims.xor_sum(%arg0: !torch.vtensor<[4],si32>) -> !torch.vtensor<[],si32> {
  %int0 = torch.constant.int 0
  %dims = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
  %none = torch.constant.none
  %0 = torch.prims.xor_sum %arg0, %dims, %none : !torch.vtensor<[4],si32>, !torch.list<int>, !torch.none -> !torch.vtensor<[],si32>
  return %0 : !torch.vtensor<[],si32>
}