@@ -130,24 +130,39 @@ bool ShardSolver::resolveStep() {
130130 // << "\n Consumer layout " <<
131131 // consumerLayouts[consumerId]
132132 // << "\n\n";
133+ if (reshardOnEdge) {
134+ // TODO(odjuricic): This should read from results of previous
135+ // resolve instead of accepting all.
136+ //
137+ assert (producerId <=
138+ std::numeric_limits<decltype (Path::producerId)>::max ());
139+ assert (consumerId <=
140+ std::numeric_limits<decltype (Path::consumerId)>::max ());
141+ paths.push_back (Path (producerId, consumerId));
142+ edgeProducerBitset.set (producerId);
143+ edgeConsumerBitset.set (consumerId);
144+ continue ;
145+ }
146+
133147 llvm::Expected<bool > shardCompatible = checkShardCompatible (
134148 producerOp->getResult (0 ), producerLayouts[producerId], consumerOp,
135149 consumerLayouts[consumerId]);
136150
137- if (!shardCompatible) {
138- std::string error = llvm::toString (shardCompatible.takeError ());
139- if (!errorCount.count (error)) {
140- errorCount.insert ({error, 0 });
141- }
142- errorCount[error]++;
143- } else if (reshardOnEdge) {
151+ if (shardCompatible && shardCompatible.get ()) {
144152 assert (producerId <=
145153 std::numeric_limits<decltype (Path::producerId)>::max ());
146154 assert (consumerId <=
147155 std::numeric_limits<decltype (Path::consumerId)>::max ());
148156 paths.push_back (Path (producerId, consumerId));
149157 edgeProducerBitset.set (producerId);
150158 edgeConsumerBitset.set (consumerId);
159+
160+ } else {
161+ std::string error = llvm::toString (shardCompatible.takeError ());
162+ if (!errorCount.count (error)) {
163+ errorCount.insert ({error, 0 });
164+ }
165+ errorCount[error]++;
151166 }
152167 }
153168 }
@@ -157,10 +172,6 @@ bool ShardSolver::resolveStep() {
157172
158173 // No valid paths found for this edge, mark it for resharding.
159174 //
160- if (reshardOnEdge) {
161- return false ;
162- }
163-
164175 if (!insertReshard (edge)) {
165176 return false ;
166177 }
@@ -236,14 +247,6 @@ bool ShardSolver::preprocessFirstOp() {
236247 // Add constraint check
237248 Operation *firstOp = shardSpecs->front ().op ;
238249
239- // if (llvm::isa<tt::ttnn::MatmulOp>(firstOp)) {
240- // auto loc = llvm::dyn_cast<NameLoc>(firstOp->getLoc());
241- // if (loc.getName() == "index_262.dc.matmul.3") {
242- // llvm::errs() << loc;
243- // }
244- // llvm::errs() << "MatmulOp\n";
245- // }
246-
247250 if (memReconfigEdges.count (
248251 Edge (firstOp->getOperand (0 ).getDefiningOp (), firstOp, 0 )) > 0 ) {
249252 return true ;
@@ -296,17 +299,7 @@ bool ShardSolver::preprocessFirstOp() {
296299 }
297300
298301 if (!hasValidLayout) {
299- // Print all consumer and producer layouts:
300- //
301- // firstOp->emitError() << "No valid output layout found for DRAM input!";
302-
303- // llvm::errs() << "First op layouts: " << firstOpLayouts.size() << "\n";
304- // for (auto layout : firstOpLayouts) {
305- // llvm::errs() << "\t" << layout << "\n";
306- // }
307-
308302 // Insert reshard edge for the first op to start the chain.
309- //
310303 Edge shardChainInputEdge = Edge (firstOp->getOperand (0 ).getDefiningOp (),
311304 firstOp, 0 /* operandIndex*/ );
312305
0 commit comments