Skip to content

Commit 958320c

Browse files
Expand constant expressions on named computations with all constant-like ops.
PiperOrigin-RevId: 804851791
1 parent b54fab5 commit 958320c

2 files changed

Lines changed: 117 additions & 45 deletions

File tree

shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include <cassert>
1717
#include <utility>
1818

19+
#include "llvm/ADT/DenseSet.h"
1920
#include "llvm/ADT/STLExtras.h"
2021
#include "mlir/Dialect/Func/IR/FuncOps.h"
2122
#include "mlir/IR/Builders.h"
@@ -66,28 +67,41 @@ void eraseShardingGroupUsers(Operation* op) {
6667

6768
// A constant preserving op is an op that is considered a constant expression if
6869
// it is pure and all its results can be considered as constant expressions
69-
// given all its operands are constant expressions.
70-
bool isConstantPreserving(Operation* op) {
70+
// given all its operands are constant expressions, for which it holds if the
71+
// given op is either:
72+
// - A broadcast, reshape or slice op.
73+
// - An elementwise op.
74+
// - A named computation whose operations are all constant preserving.
75+
// Assumes the op is not a constant or iota op.
76+
bool isConstantPreserving(
77+
Operation* op,
78+
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
7179
if (isa<stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
7280
stablehlo::SliceOp>(op)) {
7381
return isPure(op);
7482
}
7583
if (isElementwise(op)) {
7684
return isPure(op);
7785
}
86+
if (auto namedComputationOp = dyn_cast<NamedComputationOp>(op)) {
87+
return !nonConstantNamedComputationOps.contains(
88+
namedComputationOp.getName()) &&
89+
isPure(op);
90+
}
7891
return false;
7992
}
8093

8194
// Returns true if the given op is either:
8295
// - A constant or iota op.
83-
// - A constant preserving op. (see isConstantPreserving)
84-
// - All operands are constants, that is, exist in `constantOps`.
85-
bool isConstantExpression(Operation* op,
86-
const llvm::SetVector<Operation*>& constantOps) {
96+
// - A constant preserving op. (see isConstantPreserving) and all operands are
97+
// constants, that is, exist in `constantOps`.
98+
bool isConstantExpression(
99+
Operation* op, const llvm::SetVector<Operation*>& constantOps,
100+
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
87101
if (isa<ConstantOp, stablehlo::IotaOp>(op)) {
88102
return true;
89103
}
90-
return isConstantPreserving(op) &&
104+
return isConstantPreserving(op, nonConstantNamedComputationOps) &&
91105
llvm::all_of(op->getOperands(), [&](Value operand) {
92106
return operand.getDefiningOp() &&
93107
constantOps.contains(operand.getDefiningOp());
@@ -156,14 +170,27 @@ void cloneSubComputationOnOperands(
156170
}
157171

158172
void processOp(Operation* op, llvm::SetVector<Operation*>& constantOps,
159-
llvm::SetVector<Operation*>& scalarExpansionOps) {
173+
llvm::SetVector<Operation*>& scalarExpansionOps,
174+
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
160175
if (isa<ShardingGroupOp>(op)) {
161176
return;
162177
}
163-
if (isConstantExpression(op, constantOps)) {
178+
if (isConstantExpression(op, constantOps, nonConstantNamedComputationOps)) {
164179
constantOps.insert(op);
165180
return;
166181
}
182+
// NOTE: There are cases where the op is a constant expression but may not pass
183+
// the following check, such as constant and iota ops. That is fine because if
184+
// the op is a constant expression, that is a stronger condition than being
185+
// just constant preserving and it does not make the parent named computation
186+
// non-const; and at this point, it is guaranteed that the op is not a constant
187+
// expression.
188+
if (!isConstantPreserving(op, nonConstantNamedComputationOps) &&
189+
!op->hasTrait<OpTrait::IsTerminator>()) {
190+
if (auto namedCompuationOp = op->getParentOfType<NamedComputationOp>()) {
191+
nonConstantNamedComputationOps.insert(namedCompuationOp.getName());
192+
}
193+
}
167194
if (isScalarExpansion(op)) {
168195
scalarExpansionOps.insert(op);
169196
return;
@@ -213,22 +240,45 @@ struct ConstantOrScalarSplitterPass
213240
}
214241

215242
// Then we split constant sub-computations for each non-constant user.
216-
llvm::SetVector<Operation*> constantOps, scalarExpansionOps;
217-
funcOp.walk(
218-
[&](Operation* op) { processOp(op, constantOps, scalarExpansionOps); });
243+
llvm::SmallVector<llvm::SetVector<Operation*>> constantOps;
244+
llvm::SetVector<Operation*> scalarExpansionOps;
245+
llvm::SmallDenseSet<StringRef> nonConstantNamedComputationOps;
246+
constantOps.push_back(llvm::SetVector<Operation*>());
247+
funcOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
248+
if (isa<NamedComputationOp>(op)) {
249+
constantOps.push_back(llvm::SetVector<Operation*>());
250+
return;
251+
}
252+
processOp(op, constantOps.back(), scalarExpansionOps,
253+
nonConstantNamedComputationOps);
254+
if (op->hasTrait<OpTrait::IsTerminator>() &&
255+
isa<NamedComputationOp>(op->getParentOp())) {
256+
for (Operation* op : llvm::reverse(constantOps.back())) {
257+
if (hasOnlyUsersOfType<ShardingGroupOp>(op)) {
258+
eraseShardingGroupUsers(op);
259+
op->erase();
260+
}
261+
}
262+
constantOps.pop_back();
263+
processOp(op->getParentOp(), constantOps.back(), scalarExpansionOps,
264+
nonConstantNamedComputationOps);
265+
return;
266+
}
267+
});
219268

220269
// Since for every op in `constantOps` that has a use that isn't in
221270
// `constantOps`, we replaced the use with a clone of the entire
222271
// sub-computation, we can now erase all ops in `constantOps` as long as we
223272
// iterate in reverse order. Note that we did not clone scalars so we keep
224273
// the original.
225274
for (Operation* op : llvm::concat<Operation* const>(
226-
scalarExpansionOps, llvm::reverse(constantOps))) {
275+
scalarExpansionOps, llvm::reverse(constantOps.back()))) {
227276
if (hasOnlyUsersOfType<ShardingGroupOp>(op)) {
228277
eraseShardingGroupUsers(op);
229278
op->erase();
230279
}
231280
}
281+
constantOps.pop_back();
232282
}
233283

234284
private:

shardy/dialect/sdy/transforms/import/test/constant_or_scalar_splitter.mlir

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,19 @@ func.func @constant_to_named_computation_with_only_constant_ops(%arg0: tensor<8x
119119

120120
// CHECK-LABEL: func @constant_multiple_users_within_named_computation_with_no_arguments_and_with_only_constant_ops
121121
func.func @constant_multiple_users_within_named_computation_with_no_arguments_and_with_only_constant_ops() -> (tensor<8x16xf32>, tensor<8x16xf32>) {
122-
// CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">()
123-
// CHECK-NEXT: %[[CONST:.*]] = sdy.constant dense<1.000000e+00>
124-
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[CONST]]
125-
// CHECK-NEXT: sdy.return %[[NEGATE]]
122+
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">()
123+
// CHECK-NEXT: %[[CONST0:.*]] = sdy.constant dense<1.000000e+00>
124+
// CHECK-NEXT: %[[NEGATE0:.*]] = stablehlo.negate %[[CONST0]]
125+
// CHECK-NEXT: sdy.return %[[NEGATE0]]
126126
// CHECK-NEXT: } : () -> tensor<8x16xf32>
127-
// CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC]]
128-
// CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC]]
127+
// CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">()
128+
// CHECK-NEXT: %[[CONST1:.*]] = sdy.constant dense<1.000000e+00>
129+
// CHECK-NEXT: %[[NEGATE1:.*]] = stablehlo.negate %[[CONST1]]
130+
// CHECK-NEXT: sdy.return %[[NEGATE1]]
131+
// CHECK-NEXT: } : () -> tensor<8x16xf32>
132+
// CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC0]]
133+
// CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC1]]
129134
// CHECK-NEXT: return %[[ABS_0]], %[[ABS_1]]
130-
// TODO(enver): The named computation should be splitted.
131135
%0 = sdy.named_computation<"foo">() () {
132136
%1 = stablehlo.constant dense<1.000000e+00> : tensor<8x16xf32>
133137
%2 = stablehlo.negate %1 : tensor<8x16xf32>
@@ -140,16 +144,21 @@ func.func @constant_multiple_users_within_named_computation_with_no_arguments_an
140144

141145
// CHECK-LABEL: func @constant_to_named_computation_with_one_argument_and_with_only_constant_ops
142146
func.func @constant_to_named_computation_with_one_argument_and_with_only_constant_ops() -> (tensor<8x16xf32>, tensor<8x16xf32>) {
143-
// CHECK-NEXT: %[[CONST:.*]] = sdy.constant dense<1.000000e+00>
144-
// CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[CONST]]) (%arg0: tensor<8x16xf32>) {
145-
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
146-
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[NEGATE]]
147-
// CHECK-NEXT: sdy.return %[[ADD]]
147+
// CHECK-NEXT: %[[CONST0:.*]] = sdy.constant dense<1.000000e+00>
148+
// CHECK-NEXT: %[[CONST1:.*]] = sdy.constant dense<1.000000e+00>
149+
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[CONST0]]) (%arg0: tensor<8x16xf32>) {
150+
// CHECK-NEXT: %[[NEGATE0:.*]] = stablehlo.negate %arg0
151+
// CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[NEGATE0]]
152+
// CHECK-NEXT: sdy.return %[[ADD0]]
148153
// CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
149-
// CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC]]
150-
// CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC]]
154+
// CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[CONST1]]) (%arg0: tensor<8x16xf32>) {
155+
// CHECK-NEXT: %[[NEGATE1:.*]] = stablehlo.negate %arg0
156+
// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[NEGATE1]]
157+
// CHECK-NEXT: sdy.return %[[ADD1]]
158+
// CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
159+
// CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC0]]
160+
// CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC1]]
151161
// CHECK-NEXT: return %[[ABS_0]], %[[ABS_1]]
152-
// TODO(enver): The named computation should be splitted.
153162
%0 = stablehlo.constant dense<1.000000e+00> : tensor<8x16xf32>
154163
%1 = sdy.named_computation<"foo">(%0) (%arg0: tensor<8x16xf32>) {
155164
%2 = stablehlo.negate %arg0 : tensor<8x16xf32>
@@ -165,15 +174,20 @@ func.func @constant_to_named_computation_with_one_argument_and_with_only_constan
165174
func.func @constant_multiple_users_one_to_named_computation_with_one_argument_and_with_only_constant_ops() -> (tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>) {
166175
// CHECK-NEXT: %[[CONST_0:.*]] = sdy.constant dense<1.000000e+00>
167176
// CHECK-NEXT: %[[CONST_1:.*]] = sdy.constant dense<1.000000e+00>
168-
// CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[CONST_0]]) (%arg0: tensor<8x16xf32>) {
169-
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
170-
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[NEGATE]]
171-
// CHECK-NEXT: sdy.return %[[ADD]]
177+
// CHECK-NEXT: %[[CONST_2:.*]] = sdy.constant dense<1.000000e+00>
178+
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[CONST_1]]) (%arg0: tensor<8x16xf32>) {
179+
// CHECK-NEXT: %[[NEGATE0:.*]] = stablehlo.negate %arg0
180+
// CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[NEGATE0]]
181+
// CHECK-NEXT: sdy.return %[[ADD0]]
172182
// CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
173-
// CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC]]
174-
// CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC]]
175-
// CHECK-NEXT: return %[[CONST_1]], %[[ABS_0]], %[[ABS_1]]
176-
// TODO(enver): The named computation should be splitted.
183+
// CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[CONST_2]]) (%arg0: tensor<8x16xf32>) {
184+
// CHECK-NEXT: %[[NEGATE1:.*]] = stablehlo.negate %arg0
185+
// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[NEGATE1]]
186+
// CHECK-NEXT: sdy.return %[[ADD1]]
187+
// CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
188+
// CHECK-NEXT: %[[ABS_0:.*]] = stablehlo.abs %[[NC0]]
189+
// CHECK-NEXT: %[[ABS_1:.*]] = stablehlo.abs %[[NC1]]
190+
// CHECK-NEXT: return %[[CONST_0]], %[[ABS_0]], %[[ABS_1]]
177191
%0 = stablehlo.constant dense<1.000000e+00> : tensor<8x16xf32>
178192
%1 = sdy.named_computation<"foo">(%0) (%arg0: tensor<8x16xf32>) {
179193
%2 = stablehlo.negate %arg0 : tensor<8x16xf32>
@@ -687,17 +701,25 @@ func.func @constant_both_to_named_computation_and_inside_named_computation_and_n
687701
func.func @constant_both_to_named_computation_and_inside_named_computation_and_named_computation_is_constant() -> (tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>) {
688702
// CHECK-NEXT: %[[CONST0:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
689703
// CHECK-NEXT: %[[CONST1:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
690-
// CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[CONST0]]) (%arg0: tensor<8x16xf32>) {
691-
// CHECK-NEXT: %[[CONST2:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
704+
// CHECK-NEXT: %[[CONST2:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
705+
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[CONST1]]) (%arg0: tensor<8x16xf32>) {
692706
// CHECK-NEXT: %[[CONST3:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
693-
// CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[CONST2]] : tensor<8x16xf32>
694-
// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[CONST3]] : tensor<8x16xf32>
695-
// CHECK-NEXT: %[[MULTIPLY:.*]] = stablehlo.multiply %[[ADD0]], %[[ADD1]] : tensor<8x16xf32>
696-
// CHECK-NEXT: sdy.return %[[MULTIPLY]] : tensor<8x16xf32>
707+
// CHECK-NEXT: %[[CONST4:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
708+
// CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg0, %[[CONST3]] : tensor<8x16xf32>
709+
// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %arg0, %[[CONST4]] : tensor<8x16xf32>
710+
// CHECK-NEXT: %[[MULTIPLY0:.*]] = stablehlo.multiply %[[ADD0]], %[[ADD1]] : tensor<8x16xf32>
711+
// CHECK-NEXT: sdy.return %[[MULTIPLY0]] : tensor<8x16xf32>
697712
// CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
698-
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[NC]] : tensor<8x16xf32>
699-
// CHECK-NEXT: return %[[CONST1]], %[[NC]], %[[NEGATE]] : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>
700-
// TODO(enver): The named computation should be splitted as well.
713+
// CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[CONST2]]) (%arg0: tensor<8x16xf32>) {
714+
// CHECK-NEXT: %[[CONST5:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
715+
// CHECK-NEXT: %[[CONST6:.*]] = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
716+
// CHECK-NEXT: %[[ADD2:.*]] = stablehlo.add %arg0, %[[CONST5]] : tensor<8x16xf32>
717+
// CHECK-NEXT: %[[ADD3:.*]] = stablehlo.add %arg0, %[[CONST6]] : tensor<8x16xf32>
718+
// CHECK-NEXT: %[[MULTIPLY1:.*]] = stablehlo.multiply %[[ADD2]], %[[ADD3]] : tensor<8x16xf32>
719+
// CHECK-NEXT: sdy.return %[[MULTIPLY1]] : tensor<8x16xf32>
720+
// CHECK-NEXT: } : (tensor<8x16xf32>) -> tensor<8x16xf32>
721+
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[NC1]] : tensor<8x16xf32>
722+
// CHECK-NEXT: return %[[CONST0]], %[[NC0]], %[[NEGATE]] : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>
701723
%0 = stablehlo.constant dense<1.000000e+00> : tensor<8x16xf32>
702724
%1 = sdy.named_computation<"foo">(%0) (%arg0: tensor<8x16xf32>) {
703725
%2 = stablehlo.constant dense<1.000000e+00> : tensor<8x16xf32>

0 commit comments

Comments
 (0)