Skip to content

Commit 56477a9

Browse files
Add simplification for XorOp. Simplification covers canonicalization of const to rhs, fold to zero when both constant operands are equal, and noop when one of the operands is constant zero
PiperOrigin-RevId: 899855076
1 parent 05a3944 commit 56477a9

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

third_party/stablehlo/temporary.patch

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,88 @@
1+
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
2+
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
3+
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
4+
@@ -2100,12 +2100,78 @@
5+
/////////
6+
// XorOp
7+
8+
-// CHECK-LABEL: @xor_cst_on_rhs
9+
-func.func @xor_cst_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> {
10+
+// CHECK-LABEL: @xor_cst_false_on_rhs
11+
+// CHECK-SAME: [[ARG:%.+]]: tensor<2xi1>
12+
+func.func @xor_cst_false_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> {
13+
+ // CHECK-NOT: stablehlo-constant
14+
%cst = stablehlo.constant dense<false> : tensor<2xi1>
15+
+ // CHECK-NOT: stablehlo-xor
16+
%0 = stablehlo.xor %cst, %arg0 : tensor<2xi1>
17+
- // CHECK: stablehlo.xor %arg0, %c : tensor<2xi1>
18+
+ // CHECK: return [[ARG]] : tensor<2xi1>
19+
return %0 : tensor<2xi1>
20+
+}
21+
+
22+
+// -----
23+
+
24+
+// CHECK-LABEL: @xor_cst_true_on_rhs
25+
+// CHECK-SAME: [[ARG:%.+]]: tensor<2xi1>
26+
+func.func @xor_cst_true_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> {
27+
+ // CHECK: [[CST:%.+]] = stablehlo.constant dense<true> : tensor<2xi1>
28+
+ %cst = stablehlo.constant dense<true> : tensor<2xi1>
29+
+ // CHECK: [[RES:%.+]] = stablehlo.xor [[ARG]], [[CST]] : tensor<2xi1>
30+
+ %0 = stablehlo.xor %cst, %arg0 : tensor<2xi1>
31+
+ // CHECK: return [[RES]]
32+
+ return %0 : tensor<2xi1>
33+
+}
34+
+
35+
+// -----
36+
+
37+
+// CHECK-LABEL: @xor_cst_int_on_rhs
38+
+// CHECK-SAME: [[ARG:%.+]]: tensor<2xi32>
39+
+func.func @xor_cst_int_on_rhs(%arg0: tensor<2xi32>) -> tensor<2xi32> {
40+
+ // CHECK: [[CST:%.+]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
41+
+ %cst = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
42+
+ // CHECK: [[RES:%.+]] = stablehlo.xor [[ARG]], [[CST]] : tensor<2xi32>
43+
+ %0 = stablehlo.xor %cst, %arg0 : tensor<2xi32>
44+
+ // CHECK: return [[RES]]
45+
+ return %0 : tensor<2xi32>
46+
+}
47+
+
48+
+// -----
49+
+
50+
+// CHECK-LABEL: @xor_same_lhs_rhs
51+
+// CHECK-SAME: [[ARG:%.+]]: tensor<2xi32>
52+
+func.func @xor_same_lhs_rhs(%arg0: tensor<2xi32>) -> tensor<2xi32> {
53+
+ // CHECK: [[RES:%.+]] = stablehlo.constant dense<0> : tensor<2xi32>
54+
+ %0 = stablehlo.xor %arg0, %arg0 : tensor<2xi32>
55+
+ // CHECK: return [[RES]]
56+
+ return %0 : tensor<2xi32>
57+
+}
58+
+
59+
+// -----
60+
+
61+
+// CHECK-LABEL: @xor_zero_like_lhs
62+
+// CHECK-SAME: [[ARG:%.+]]: tensor<3xi32>
63+
+func.func @xor_zero_like_lhs(%arg0: tensor<3xi32>) -> tensor<3xi32> {
64+
+ // CHECK-NOT: stablehlo.constant
65+
+ %0 = stablehlo.constant dense<0> : tensor<3xi32>
66+
+ // CHECK-NOT: stablehlo.xor
67+
+ %1 = stablehlo.xor %0, %arg0 : tensor<3xi32>
68+
+ // CHECK: return [[ARG]]
69+
+ return %1 : tensor<3xi32>
70+
+}
71+
+
72+
+// -----
73+
+
74+
+// CHECK-LABEL: @xor_zero_like_rhs
75+
+// CHECK-SAME: [[ARG:%.+]]: tensor<3xi32>
76+
+func.func @xor_zero_like_rhs(%arg0: tensor<3xi32>) -> tensor<3xi32> {
77+
+ // CHECK-NOT: stablehlo.constant
78+
+ %0 = stablehlo.constant dense<0> : tensor<3xi32>
79+
+ // CHECK-NOT: stablehlo.xor
80+
+ %1 = stablehlo.xor %arg0, %0 : tensor<3xi32>
81+
+ // CHECK: return [[ARG]]
82+
+ return %1 : tensor<3xi32>
83+
}
84+
85+
// -----
186
diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp
287
--- stablehlo/stablehlo/transforms/VhloToVersion.cpp
388
+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp
@@ -21,4 +106,32 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
21106

22107
return true;
23108
}
109+
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
110+
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
111+
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
112+
@@ -682,10 +682,17 @@
113+
def XorOp_CanonicalizeConstantToRhs
114+
: CanonicalizeConstantToRhs<StableHLO_XorOp>;
115+
116+
-// To consider: xor(X, X) -> 0
117+
-//
118+
-// It's unclear if this is beneficial on hardware vs. adding another constant.
119+
-//
120+
-// def XorOp_FoldToZero
121+
-// : Pat<(StableHLO_XorOp AnyStaticShapeTensor:$operand, $operand),
122+
-// (StableHLO_ConstantLike<"0"> $operand)>;
123+
+// Pattern: xor(X, X) -> 0
124+
+def XorOp_FoldToZero
125+
+ : Pat<(StableHLO_XorOp AnyStaticShapeTensor:$operand, $operand),
126+
+ (StableHLO_ConstantLike<"0"> $operand)>;
127+
+
128+
+// Pattern: xor(X, 0) -> X
129+
+def XorOp_RemoveNoop
130+
+ : Pat<(StableHLO_XorOp:$op $lhs, (ConstantLikeMatcher IntZero:$value)),
131+
+ (replaceWithValue $lhs)>;
132+
+
133+
+
134+
+
135+
+
136+
+
24137

0 commit comments

Comments
 (0)