Skip to content

[TorchToStablehlo] Add lowering for prims.xor_sum to stablehlo.reduce…#4501

Draft
BruceXinXin wants to merge 1 commit intollvm:mainfrom
BruceXinXin:bruce_add_xor_sum_lowering
Draft

[TorchToStablehlo] Add lowering for prims.xor_sum to stablehlo.reduce…#4501
BruceXinXin wants to merge 1 commit intollvm:mainfrom
BruceXinXin:bruce_add_xor_sum_lowering

Conversation

@BruceXinXin
Copy link

[TorchToStablehlo] Add prims.xor_sumstablehlo.reduce lowering

Summary

Adds end-to-end support for lowering prims::xor_sum to StableHLO.

Changes

  • torch_ods_gen.py: Register prims::xor_sum : (Tensor, int[]?, int?) -> (Tensor).
  • GeneratedTorchOps.td: Add Torch_PrimsXorSumOp definition (auto-generated).
  • Reduction.cpp:
    • Add PrimsXorSumOp to the zero-init branch in createInitialValueForReduceOp.
    • Add stablehlo::XorOp case in createReduceOpWithSingleRegionOp.
    • Implement ConvertAtenReductionOp<PrimsXorSumOp>::matchAndRewrite with
      dimension normalization and integer type validation.
    • Register the pattern via INSERT_ATEN_REDUCTION_OP_PATTERN.
  • reduction.mlir: Add FileCheck lit test for torch.prims.xor_sum on si32.

Test

… with XOR

Add support for lowering `prims::xor_sum` to StableHLO's `stablehlo.reduce`
using `stablehlo.xor` as the reduction body. The op reduces integer tensors
along specified dimensions using bitwise XOR, with an initial value of zero.

- Define `PrimsXorSumOp` in GeneratedTorchOps.td via torch_ods_gen.py
- Implement `ConvertAtenReductionOp<PrimsXorSumOp>` in Reduction.cpp
- Register the pattern in `populateReductionOpPatternsAndLegality`
- Add lit test for the new lowering
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant