
Commit 95e1140: Fix tests

1 parent 83c8e8e

File tree

3 files changed: +44 −30 lines changed

lib/Dialect/TTNN/Analysis/ShardSolver.cpp (22 additions, 29 deletions)

@@ -130,24 +130,39 @@ bool ShardSolver::resolveStep() {
         // << "\n Consumer layout " <<
         // consumerLayouts[consumerId]
         // << "\n\n";
+        if (reshardOnEdge) {
+          // TODO(odjuricic): This should read from results of previous
+          // resolve instead of accepting all.
+          //
+          assert(producerId <=
+                 std::numeric_limits<decltype(Path::producerId)>::max());
+          assert(consumerId <=
+                 std::numeric_limits<decltype(Path::consumerId)>::max());
+          paths.push_back(Path(producerId, consumerId));
+          edgeProducerBitset.set(producerId);
+          edgeConsumerBitset.set(consumerId);
+          continue;
+        }
+
         llvm::Expected<bool> shardCompatible = checkShardCompatible(
             producerOp->getResult(0), producerLayouts[producerId], consumerOp,
             consumerLayouts[consumerId]);

-        if (!shardCompatible) {
-          std::string error = llvm::toString(shardCompatible.takeError());
-          if (!errorCount.count(error)) {
-            errorCount.insert({error, 0});
-          }
-          errorCount[error]++;
-        } else if (reshardOnEdge) {
+        if (shardCompatible && shardCompatible.get()) {
           assert(producerId <=
                  std::numeric_limits<decltype(Path::producerId)>::max());
           assert(consumerId <=
                  std::numeric_limits<decltype(Path::consumerId)>::max());
           paths.push_back(Path(producerId, consumerId));
           edgeProducerBitset.set(producerId);
           edgeConsumerBitset.set(consumerId);
+
+        } else {
+          std::string error = llvm::toString(shardCompatible.takeError());
+          if (!errorCount.count(error)) {
+            errorCount.insert({error, 0});
+          }
+          errorCount[error]++;
         }
       }
     }
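
For context, the rewritten success check leans on the three-way behavior of llvm::Expected<bool>: the constraint query can fail with an llvm::Error, succeed with a false verdict, or succeed with a true verdict, and only the last case records a path. A minimal, self-contained sketch of that pattern (checkCompatible is a hypothetical stand-in for checkShardCompatible):

#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical stand-in: succeeds with a yes/no verdict, or fails with a
// descriptive error, mirroring checkShardCompatible's contract.
static llvm::Expected<bool> checkCompatible(bool verdict, bool fail) {
  if (fail) {
    return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                   "constraint query failed");
  }
  return verdict;
}

int main() {
  llvm::Expected<bool> compatible =
      checkCompatible(/*verdict=*/true, /*fail=*/false);
  if (compatible && compatible.get()) {
    // Query ran and the layouts are compatible: accept the path.
    llvm::outs() << "compatible\n";
  } else if (!compatible) {
    // Query failed: the Error must be consumed via takeError(), otherwise
    // assertion-enabled LLVM builds abort when the Expected is destroyed.
    llvm::errs() << llvm::toString(compatible.takeError()) << "\n";
  } else {
    // Query ran but answered false: incompatible, no error to consume.
    llvm::outs() << "incompatible\n";
  }
  return 0;
}

An Expected carrying an error must be drained before destruction, which is why the failure branch above (and the commit's else branch) funnels it through takeError() into an error string.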
@@ -157,10 +172,6 @@ bool ShardSolver::resolveStep() {

       // No valid paths found for this edge, mark it for resharding.
       //
-      if (reshardOnEdge) {
-        return false;
-      }
-
       if (!insertReshard(edge)) {
         return false;
       }
@@ -236,14 +247,6 @@ bool ShardSolver::preprocessFirstOp() {
   // Add constraint check
   Operation *firstOp = shardSpecs->front().op;

-  // if (llvm::isa<tt::ttnn::MatmulOp>(firstOp)) {
-  //   auto loc = llvm::dyn_cast<NameLoc>(firstOp->getLoc());
-  //   if (loc.getName() == "index_262.dc.matmul.3") {
-  //     llvm::errs() << loc;
-  //   }
-  //   llvm::errs() << "MatmulOp\n";
-  // }
-
   if (memReconfigEdges.count(
           Edge(firstOp->getOperand(0).getDefiningOp(), firstOp, 0)) > 0) {
     return true;
@@ -296,17 +299,7 @@ bool ShardSolver::preprocessFirstOp() {
   }

   if (!hasValidLayout) {
-    // Print all consumer and producer layouts:
-    //
-    // firstOp->emitError() << "No valid output layout found for DRAM input!";
-
-    // llvm::errs() << "First op layouts: " << firstOpLayouts.size() << "\n";
-    // for (auto layout : firstOpLayouts) {
-    //   llvm::errs() << "\t" << layout << "\n";
-    // }
-
     // Insert reshard edge for the first op to start the chain.
-    //
     Edge shardChainInputEdge = Edge(firstOp->getOperand(0).getDefiningOp(),
                                     firstOp, 0 /*operandIndex*/);

test/lit.cfg.py (4 additions, 0 deletions)

@@ -46,6 +46,10 @@
 # system_desc_path: The system desc that is to be used to generate the binary files.
 config.system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")

+# This needs to be done optimizer subdirectories only.
+lit_config.parallelism_groups["optimizer"] = 1
+config.parallelism_group = "optimizer"
+
 # set features based on system
 system_desc = None
 if config.system_desc_path:

As written, every test discovered under this lit config joins the "optimizer" parallelism group, and the group's limit of 1 forces those tests to run one at a time; the in-diff comment notes that this should eventually apply to the optimizer subdirectories only.

test/unittests/Optimizer/TestShardSolver.cpp (18 additions, 1 deletion)

@@ -17,6 +17,7 @@
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

 #include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

 using namespace mlir::tt::ttnn;

@@ -42,7 +43,13 @@ class ShardSolverBase : public ::testing::Test {
   }

   mlir::RankedTensorType getTensorRankedType() {
-    return mlir::RankedTensorType::get(getTensorShape(), builder.getF32Type());
+    return mlir::RankedTensorType::get(
+        getTensorShape(), builder.getF32Type(),
+        TTNNLayoutAttr::get(&context, getTensorShape(), builder.getF32Type(),
+                            BufferType::DRAM,
+                            mlir::tt::GridAttr::get(&context, {1, 1}),
+                            mlir::tt::ttnn::TensorMemoryLayoutAttr::get(
+                                &context, TensorMemoryLayout::Interleaved)));
   }

   mlir::Value createEmptyTensor() {
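
For context, the third argument to mlir::RankedTensorType::get is the type's optional encoding attribute; attaching a TTNNLayoutAttr there is what gives the test tensors a concrete DRAM-interleaved layout. A short sketch of reading that encoding back out (a hypothetical helper, not part of the test):

// Recover the TTNN layout from a tensor type's encoding attribute;
// returns a null attribute when no TTNN encoding is attached.
static TTNNLayoutAttr getTTNNLayout(mlir::RankedTensorType type) {
  return mlir::dyn_cast_or_null<TTNNLayoutAttr>(type.getEncoding());
}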
@@ -63,6 +70,10 @@ class ShardSolverBase : public ::testing::Test {
         mlir::TypeRange(input), mlir::TypeRange(output));
     func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(), "test",
                                               funcType);
+    func->setAttr(
+        mlir::tt::DeviceAttr::name,
+        mlir::tt::DeviceAttr::get(
+            &context, mlir::tt::SystemDescAttr::getDefault(&context)));

     mlir::Block *block = func.addEntryBlock();
     block->addArgument(getTensorRankedType(), builder.getUnknownLoc());
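
A consumer of the new attribute would read it back through the standard mlir::Operation accessor; a brief sketch (the readback site is an assumption, not part of this diff):

// Fetch the DeviceAttr back off the test function; null if never set.
auto device =
    func->getAttrOfType<mlir::tt::DeviceAttr>(mlir::tt::DeviceAttr::name);
assert(device && "expected the test func to carry a device attribute");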
@@ -225,6 +236,12 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) {
                              TTNNLayoutAttr const &producerLayout,
                              mlir::Operation *consumerOp,
                              TTNNLayoutAttr const &consumerLayout) {
+        // Interleaved to sharded is always supported.
+        //
+        if (producerLayout.hasInterleavedDRAMTensorMemoryLayout()) {
+          return true;
+        }
+
         // Simple shard compat assumption. Try to keep same shard layout.
         //
         if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
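
Distilled, the mock's rule is now: a DRAM-interleaved producer pairs with any consumer, and everything else must keep the same memory layout across the edge. As a standalone predicate (hypothetical name; it covers only the two checks visible in this hunk):

// Hypothetical distillation of the test's compatibility callback.
static bool layoutsCompatible(TTNNLayoutAttr producer,
                              TTNNLayoutAttr consumer) {
  // Interleaved-DRAM producers can feed a consumer in any layout.
  if (producer.hasInterleavedDRAMTensorMemoryLayout()) {
    return true;
  }
  // Otherwise insist on matching memory layouts on both sides.
  return producer.getMemLayout() == consumer.getMemLayout();
}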
