Skip to content

Commit 4f45c06

Browse files
Varchocopybara-github
authored andcommitted
[SDY][Cleanup] remove apply_sharding_constraints.h file by moving debug args td file options so that constructor is auto-generated via tblgen.
PiperOrigin-RevId: 810599911
1 parent 17e1e18 commit 4f45c06

5 files changed

Lines changed: 18 additions & 73 deletions

File tree

shardy/dialect/sdy/transforms/import/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ cc_library(
4444
"sharding_group_import.cc",
4545
],
4646
hdrs = [
47-
"apply_sharding_constraints.h",
4847
"passes.h",
4948
],
5049
deps = [

shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,19 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "shardy/dialect/sdy/transforms/import/apply_sharding_constraints.h"
17-
1816
#include <cassert>
1917
#include <functional>
20-
#include <memory>
2118

2219
#include "llvm/ADT/STLExtras.h"
23-
#include "llvm/Support/CommandLine.h"
2420
#include "mlir/IR/Builders.h"
2521
#include "mlir/IR/BuiltinAttributes.h"
2622
#include "mlir/IR/BuiltinOps.h"
2723
#include "mlir/IR/MLIRContext.h"
2824
#include "mlir/IR/Value.h"
2925
#include "mlir/IR/ValueRange.h"
30-
#include "mlir/Pass/Pass.h"
3126
#include "mlir/Support/LLVM.h"
3227
#include "shardy/dialect/sdy/ir/dialect.h"
3328
#include "shardy/dialect/sdy/ir/utils.h"
34-
#include "shardy/dialect/sdy/transforms/common/propagation_options.h"
3529
#include "shardy/dialect/sdy/transforms/import/passes.h" // IWYU pragma: keep
3630
#include "shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.h"
3731
#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h"
@@ -239,14 +233,7 @@ struct ApplyShardingConstraintsPass
239233
ApplyShardingConstraintsPass(const ApplyShardingConstraintsPass& other)
240234
: ApplyShardingConstraintsPassBase<ApplyShardingConstraintsPass>(other) {
241235
debugShardingOrigins = other.debugShardingOrigins;
242-
debugPropagationEdgeSharding = other.debugPropagationEdgeSharding;
243-
}
244-
245-
// Constructor to be used when creating the pass programmatically with
246-
// options.
247-
explicit ApplyShardingConstraintsPass(const PropagationOptions& options) {
248-
debugShardingOrigins = options.debugShardingOrigins;
249-
debugPropagationEdgeSharding = options.debugPropagationEdgeSharding;
236+
debugPropagationEdges = other.debugPropagationEdges;
250237
}
251238

252239
void runOnOperation() final {
@@ -255,7 +242,7 @@ struct ApplyShardingConstraintsPass
255242

256243
// Prepare debugging handler for sharding origins and edge sources.
257244
ShardingDebugMappings mappings(debugShardingOrigins,
258-
debugPropagationEdgeSharding);
245+
debugPropagationEdges);
259246
SourceShardingHandler handler(&mappings);
260247
// Prepare the handler and register it to the context.
261248
handler.prepareHandler(moduleOp);
@@ -315,26 +302,9 @@ struct ApplyShardingConstraintsPass
315302
context.registerActionHandler(nullptr);
316303
handler.saveOnModule(moduleOp);
317304
}
318-
319-
Option<bool> debugShardingOrigins{
320-
*this, "debug-sharding-origins",
321-
llvm::cl::desc("whether to save sharding origin information"),
322-
llvm::cl::init(false)};
323-
324-
Option<bool> debugPropagationEdgeSharding{
325-
*this, "debug-propagation-edge-sharding",
326-
llvm::cl::desc("whether to save propagation edge information"),
327-
llvm::cl::init(false)};
328305
};
329306

330307
} // namespace
331308

332-
// This function can be used to create the pass with specific options
333-
// programmatically.
334-
std::unique_ptr<Pass> createApplyShardingConstraintsPass(
335-
const PropagationOptions& options) {
336-
return std::make_unique<ApplyShardingConstraintsPass>(options);
337-
}
338-
339309
} // namespace sdy
340310
} // namespace mlir

shardy/dialect/sdy/transforms/import/apply_sharding_constraints.h

Lines changed: 0 additions & 38 deletions
This file was deleted.

shardy/dialect/sdy/transforms/import/import_pipeline.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ limitations under the License.
2121
#include "mlir/Transforms/Passes.h"
2222
#include "shardy/common/file_utils.h"
2323
#include "shardy/dialect/sdy/transforms/common/propagation_options.h"
24-
#include "shardy/dialect/sdy/transforms/import/apply_sharding_constraints.h"
2524
#include "shardy/dialect/sdy/transforms/import/passes.h"
2625

2726
namespace mlir {
@@ -66,7 +65,9 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
6665
options.dumpDirectory, "before_propagation", dumpIndex++));
6766

6867
pm.addNestedPass<func::FuncOp>(createAddDataFlowEdgesPass());
69-
pm.addPass(createApplyShardingConstraintsPass(options));
68+
pm.addPass(
69+
createApplyShardingConstraintsPass(ApplyShardingConstraintsPassOptions{
70+
options.debugShardingOrigins, options.debugPropagationEdgeSharding}));
7071
// The sharding group import pass must run after applying sharding
7172
// constraints. This ensures we can detect sharding conflicts between group
7273
// members which have pre-propagation shardings due to sharding constraints.

shardy/dialect/sdy/transforms/import/passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,19 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "Modul
105105
dangling case).
106106
}];
107107
let dependentDialects = ["mlir::sdy::SdyDialect"];
108+
109+
let options = [
110+
Option<"debugShardingOrigins", "debug-sharding-origins", "bool",
111+
/*default=*/"false",
112+
"Whether to compute the debug origin shardings for constraints. See "
113+
"`debug-sharding-origins` option in propagation for more info.">,
114+
Option<"debugPropagationEdges",
115+
"debug-propagation-edge-sharding", "bool",
116+
/*default=*/"false",
117+
"Whether to sink the debug propagation edge sharding info. See "
118+
"`debug-propagation-edge-sharding` option in propagation for more "
119+
"info.">
120+
];
108121
}
109122

110123
def ConstantOrScalarSplitterPass : Pass<"sdy-constant-or-scalar-splitter", "func::FuncOp"> {

0 commit comments

Comments
 (0)