diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 82d11ec0737a..fbe7b31e2925 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -18969,6 +18969,31 @@ def Torch_PrimsSumOp : Torch_Op<"prims.sum", [
   }];
 }
 
+def Torch_PrimsXorSumOp : Torch_Op<"prims.xor_sum", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::xor_sum : (Tensor, int[]?, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$inp,
+    AnyTorchOptionalListOfTorchIntType:$dims,
+    AnyTorchOptionalIntType:$output_dtype
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsXorSumOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void PrimsXorSumOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
   AllowsTypeRefinement,
   ReadOnly
diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index 82c9d4659461..b9e1929c78da 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -47,7 +47,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   auto constType = RankedTensorType::get({}, elementTy);
   DenseElementsAttr constAttr = nullptr;
   if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
-          AtenLinalgVectorNormOp>(op)) {
+          AtenLinalgVectorNormOp, PrimsXorSumOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
           constType, {APFloat::getZero(
@@ -157,6 +157,9 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
   } else if (isa<AtenProdDimIntOp>(op)) {
     result = stablehlo::MulOp::create(rewriter, op->getLoc(), blockArgumentTy,
                                       *firstArgument, *secondArgument);
+  } else if (isa<PrimsXorSumOp>(op)) {
+    result = stablehlo::XorOp::create(rewriter, op->getLoc(), blockArgumentTy,
+                                      *firstArgument, *secondArgument);
   } else {
     op->emitError("unimplemented lowering in "
                   "createReduceOpWithSingleRegionOp");
@@ -756,6 +759,64 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
 }
 } // namespace
 
+// PrimsXorSumOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<PrimsXorSumOp>::matchAndRewrite(
+    PrimsXorSumOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getInp();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  auto outTy =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  if (!inputTy || !outTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+
+  auto inputElemTy = inputTy.getElementType();
+  if (!isa<mlir::IntegerType>(inputElemTy)) {
+    return rewriter.notifyMatchFailure(
+        op, "XOR reduction requires integer element type");
+  }
+
+  SmallVector<int64_t> inputDims;
+  SmallVector<int64_t> dims;
+  if (failed(checkNotNone(rewriter, op, op.getDims()))) {
+    inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
+  } else {
+    if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(inputDims))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const integer `dims` is not supported");
+    }
+    if (inputDims.empty()) {
+      inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
+    }
+  }
+  for (auto d : inputDims) {
+    d = toPositiveDim(d, inputTy.getRank());
+    if (isValidDim(d, inputTy.getRank())) {
+      dims.push_back(d);
+    }
+  }
+  llvm::sort(dims.begin(), dims.end());
+
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputTy.getShape(), dims);
+
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input,
+      RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
+      rewriter, /*allowNonFinites=*/false);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+  }
+
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
+
 // AtenProdDimIntOp
 namespace {
 template <>
@@ -1003,6 +1064,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdDimIntOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(PrimsXorSumOp);
 #undef INSERT_ATEN_REDUCTION_OP_PATTERN
 
 #define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp)                      \
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 7d5e65c21cef..2e00a01e9755 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1319,6 +1319,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
     emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
     emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)")
+    emit("prims::xor_sum : (Tensor, int[]?, int?) -> (Tensor)")
     emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
     emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")
 
diff --git a/test/Conversion/TorchToStablehlo/reduction.mlir b/test/Conversion/TorchToStablehlo/reduction.mlir
index 924f17f4c9b1..730ce9b37bd9 100644
--- a/test/Conversion/TorchToStablehlo/reduction.mlir
+++ b/test/Conversion/TorchToStablehlo/reduction.mlir
@@ -31,3 +31,20 @@ func.func @torch.aten.prod.intdim_negative_dim(%arg0: !torch.vtensor<[?,?,?,?],f
 // CHECK:         return %[[VAL_3]] : !torch.vtensor<[?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL:  @torch.prims.xor_sum(
+// CHECK-SAME:   %[[ARG0:.*]]: !torch.vtensor<[4],si32>) -> !torch.vtensor<[],si32> {
+// CHECK:        %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4],si32> -> tensor<4xi32>
+// CHECK:        %[[INIT:.*]] = stablehlo.constant dense<0> : tensor<i32>
+// CHECK:        %[[REDUCE:.*]] = stablehlo.reduce(%[[INPUT]] init: %[[INIT]]) applies stablehlo.xor across dimensions = [0] : (tensor<4xi32>, tensor<i32>) -> tensor<i32>
+// CHECK:        %[[OUT:.*]] = torch_c.from_builtin_tensor %[[REDUCE]] : tensor<i32> -> !torch.vtensor<[],si32>
+// CHECK:        return %[[OUT]] : !torch.vtensor<[],si32>
+func.func @torch.prims.xor_sum(%arg0: !torch.vtensor<[4],si32>) -> !torch.vtensor<[],si32> {
+  %int0 = torch.constant.int 0
+  %dims = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %0 = torch.prims.xor_sum %arg0, %dims, %none : !torch.vtensor<[4],si32>, !torch.list<int>, !torch.none -> !torch.vtensor<[],si32>
+  return %0 : !torch.vtensor<[],si32>
+}