Skip to content

Commit 83ad365

Browse files
[TTNN.JIT] Add Support for Width and Height Sharded Tensors (#5969)
### Ticket #5836 #5837 ### Problem description TTNN JIT needs to support width and height sharded tensors. #5437 added the required support for representing these layouts and generating DMA in D2M. This PR is focused on the JIT frontend and end-to-end execution. ### Changes in JIT Frontend - Emit height sharded and width sharded `TTNNLayoutAttr`s - Set `exact_grid` on all generated `TTNNLayoutAttr` to avoid grid collapsing when generating a TTNN flat buffer - Check for unsupported tensor layouts and raise exceptions - Centralize duplicate tensor layout translation code between the AST parser and graph tracer into `tensor_translator.py` ### Changes in Core D2M - Modified `TTIRToD2M` pass: - Match height and width sharded tensors and create virtual grids to represent the shard distribution - Modified `D2MToTTNN` pass: - Set the correct grid on `ttnn.generic` ops when the corresponding `d2m.generic` has a virtual grid - Include a fix for virtual grids from @bgrady-tt in 36148f7 ### Changes to TTNN Dialect and Flatbuffer Generation The problem: TTNN JIT needs to represent already fully specified tensors using the TTNN dialect, while the TTNN compiler gets to pick the properties of the tensor. This is problematic for height and width sharding since TTNN JIT needs to represent exact physical grids (e.g. `6x2` and **NOT** `3x4`), while the TTNN compiler can use a virtual grid (`12x1` or `1x12` in this case) and **choose** a desired physical collapsing. The solution: Added an `exact_grid` optional parameter on `TTNNLayoutAttr`. If this parameter is set to true, the grid is not collapsed and is recorded in the flatbuffer as-is. 
### Tests - IR tests for virtual grid and metal layout generation - Sweeping of all grids in `test_layouts.py` - Height and width sharded layouts added to all existing tests in `test_eltwise.py` and `test_eltwise_composite.py` - Closes #5836 - Closes #5837 ### Checklist - [X] New/Existing tests provide coverage for changes --------- Co-authored-by: Brett Grady <[email protected]>
1 parent e1a14da commit 83ad365

File tree

19 files changed

+1008
-249
lines changed

19 files changed

+1008
-249
lines changed

include/ttmlir/Dialect/D2M/Utils/VirtualGrid.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ createCoreVirtMaps(mlir::MLIRContext *context,
6363
if (is2DWidthSharded) {
6464
forwardMapExprs = {d1.floorDiv(gridColStride), d1 % gridColStride, d2, d3};
6565
// note: inverse map results are (deviceIndex, gridY, gridX)
66-
inverseMapExprs = {zero, zero, d0 * gridRowStride + d1};
66+
inverseMapExprs = {zero, zero, d0 * gridColStride + d1};
6767
} else if (is2DHeightSharded) {
6868
forwardMapExprs = {d0 % gridRowStride, d0.floorDiv(gridRowStride), d2, d3};
69-
inverseMapExprs = {zero, d1 * gridColStride + d0, zero};
69+
inverseMapExprs = {zero, d1 * gridRowStride + d0, zero};
7070
}
7171
auto forward = mlir::AffineMap::get(2 * rank, 0, forwardMapExprs, context);
7272
auto inverse = mlir::AffineMap::get(rank, 0, inverseMapExprs, context);

include/ttmlir/Dialect/TTCore/IR/TTCoreOpsTypes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ def TTCore_ShardLayoutAttr : TTCore_Attr<"ShardLayout", "shard", [TTCore_DeviceL
294294
}
295295

296296
unsigned getRank() const { return getStride().size(); }
297+
298+
// Computes the implied physical grid shape using the grid shape and index map (if present)
299+
llvm::SmallVector<int64_t> getPhysicalGridShape(ShapedType tensorType) const;
297300
}];
298301
}
299302

include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,8 +644,9 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
644644
AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref,
645645
OptionalParameter<"TensorMemoryLayoutAttr", "TTNN tensor memory layout">:$mem_layout,
646646
OptionalParameter<"ttcore::TensorMeshAttr", "TT tensor mesh attr">:$tensor_mesh,
647-
OptionalParameter<"bool", "A status flag, asking the users to ignore the physical layout. This is used to model a sharded layout with unspecified shard shape.">:$ignorePhysicalLayout);
648-
let assemblyFormat = "`<` $linear`,` $grid`,` (`mesh` `=` $tensor_mesh^ `,`)? $memref (`,` $mem_layout^)? (`,` $ignorePhysicalLayout^)? `>`";
647+
OptionalParameter<"bool", "A status flag, asking the users to ignore the physical layout. This is used to model a sharded layout with unspecified shard shape.">:$ignorePhysicalLayout,
648+
OptionalParameter<"bool", "A status flag indicating that the grid should be treated as the exact physical grid and not a virtual grid to be collapsed">:$exactGrid);
649+
let assemblyFormat = "`<` $linear`,` $grid`,` (`mesh` `=` $tensor_mesh^ `,`)? $memref (`,` $mem_layout^)? (`,` `exactGrid` `=` $exactGrid^)? (`,` `ignorePhysicalLayout` `=` $ignorePhysicalLayout^)?`>`";
649650
let extraClassDeclaration = [{
650651
static TTNNLayoutAttr get(::mlir::MLIRContext *context,
651652
AffineMap linear,

include/ttmlir/Target/Utils/MLIRToFlatbuffer.h

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,8 @@ inline flatbuffers::Offset<::tt::target::ttnn::MemoryDesc>
868868
toFlatbuffer(FlatbufferObjectCache &cache, mlir::MemRefType memref,
869869
ttcore::TensorMeshAttr tensorMesh, ttnn::BufferType bufferType,
870870
ttnn::TensorMemoryLayoutAttr memLayoutAttr,
871-
ttcore::GridAttr shardGrid, ttcore::GridAttr deviceGrid) {
871+
ttcore::GridAttr shardGrid, ttcore::GridAttr deviceGrid,
872+
bool exactGrid) {
872873
auto shapeInt64 = memref.getShape();
873874
std::vector<int32_t> shape(shapeInt64.begin(), shapeInt64.end());
874875
ttcore::DataType dtype = ttcore::DataType::Float32;
@@ -905,21 +906,36 @@ toFlatbuffer(FlatbufferObjectCache &cache, mlir::MemRefType memref,
905906
shape[0] *= tileShape.y();
906907
shape[1] *= tileShape.x();
907908

908-
auto coreRangeSetAttr = ttnn::CoreRangeSetAttr::get(
909-
ctx, llvm::map_to_vector(
910-
ttcore::utils::toCoreRangeSet(
911-
shardGrid.getShape(),
912-
ttnn::optimizer_utils::
913-
createSingleDeviceVirtualToPhysicalAffineMap(
914-
ctx, memLayoutAttr.getValue(),
915-
deviceGrid.getShape())),
916-
[ctx](const auto &range) {
917-
const auto [loc, size] = range;
918-
return ttnn::CoreRangeAttr::get(
919-
ctx, ttnn::CoreCoordAttr::get(ctx, loc[0], loc[1]),
920-
ttnn::CoreCoordAttr::get(ctx, loc[0] + size[0] - 1,
921-
loc[1] + size[1] - 1));
922-
}));
909+
ttnn::CoreRangeSetAttr coreRangeSetAttr;
910+
// For TTNN JIT, the sharding core range has already been set by the user.
911+
// This means that for height and width sharding, the grid in the
912+
// TTNNLayoutAttr is not a virtual Mx1 or 1xN grid that would require
913+
// collapsing. This is required to distinguish between 3x4 and 2x6 as
914+
// sharding grids for height and width sharding.
915+
// Note that the core coord is (X,Y) but the grid shape is (Y,X).
916+
if (exactGrid) {
917+
coreRangeSetAttr = ttnn::CoreRangeSetAttr::get(
918+
ctx, ttnn::CoreRangeAttr::get(
919+
ctx, ttnn::CoreCoordAttr::get(ctx, 0, 0),
920+
ttnn::CoreCoordAttr::get(ctx, shardGrid.getShape()[1] - 1,
921+
shardGrid.getShape()[0] - 1)));
922+
} else {
923+
coreRangeSetAttr = ttnn::CoreRangeSetAttr::get(
924+
ctx, llvm::map_to_vector(
925+
ttcore::utils::toCoreRangeSet(
926+
shardGrid.getShape(),
927+
ttnn::optimizer_utils::
928+
createSingleDeviceVirtualToPhysicalAffineMap(
929+
ctx, memLayoutAttr.getValue(),
930+
deviceGrid.getShape())),
931+
[ctx](const auto &range) {
932+
const auto [loc, size] = range;
933+
return ttnn::CoreRangeAttr::get(
934+
ctx, ttnn::CoreCoordAttr::get(ctx, loc[0], loc[1]),
935+
ttnn::CoreCoordAttr::get(ctx, loc[0] + size[0] - 1,
936+
loc[1] + size[1] - 1));
937+
}));
938+
}
923939
shardSpecAttr = ttnn::ShardSpecAttr::get(
924940
ctx, coreRangeSetAttr, ttnn::ShapeAttr::get(ctx, shape),
925941
ttnn::ShardOrientationAttr::get(ctx,
@@ -951,7 +967,8 @@ ttnnLayoutAttrToFlatbuffer(FlatbufferObjectCache &cache,
951967
*cache.fbb, toFlatbuffer(cache, ttcore::OOBVal::Undef),
952968
toFlatbuffer(cache, layoutAttr.getMemref(), layoutAttr.getTensorMesh(),
953969
layoutAttr.getBufferType(), layoutAttr.getMemLayout(),
954-
layoutAttr.getGrid(), deviceAttr.getWorkerGrid()));
970+
layoutAttr.getGrid(), deviceAttr.getWorkerGrid(),
971+
layoutAttr.getExactGrid()));
955972
}
956973

957974
inline ::tt::target::ttnn::CoreType

lib/Conversion/D2MToTTNN/D2MToTTNN.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "ttmlir/Conversion/D2MToTTNN/D2MToTTNN.h"
66

7+
#include "ttmlir/AffineMapUtils.h"
78
#include "ttmlir/Asserts.h"
89
#include "ttmlir/Dialect/D2M/IR/D2MOps.h"
910
#include "ttmlir/Dialect/TTCore/IR/TTCoreOpsTypes.h"
@@ -117,13 +118,32 @@ class D2MGenericRewriter : public OpConversionPattern<d2m::GenericOp> {
117118
auto device = ttcore::lookupDevice(op->getParentOp());
118119
TT_assert(device);
119120

120-
// TTNN grids are (Width, Height), while D2M grids are (Height, Width).
121-
ttcore::GridAttr grid = op.getGrid();
121+
ttcore::GridAttr opGrid = op.getGrid();
122+
llvm::SmallVector<int64_t> endCoreRange;
123+
if (!opGrid.getMapping().isEmpty()) {
124+
// The genericOp has a virtual grid. We need to recover the original
125+
// physical grid.
126+
auto output = op.getOutputs()[0];
127+
mlir::ShapedType outputType =
128+
mlir::cast<mlir::ShapedType>(output.getType());
129+
auto shardLayout = mlir::dyn_cast<ttcore::ShardLayoutAttr>(
130+
ttcore::getDeviceLayout(outputType));
131+
TT_assertv(shardLayout, "Expected shardLayoutAttr for the output of a "
132+
"generic op with a virtual grid.");
133+
134+
auto physicalGridShape = shardLayout.getPhysicalGridShape(outputType);
135+
// TTNN grids are (Width, Height), while D2M grids are (Height, Width).
136+
endCoreRange = {physicalGridShape[1] - 1, physicalGridShape[0] - 1};
137+
} else {
138+
// TTNN grids are (Width, Height), while D2M grids are (Height, Width).
139+
endCoreRange = {opGrid.getShape()[1] - 1, opGrid.getShape()[0] - 1};
140+
}
141+
122142
ttnn::CoreRangeSetAttr coreRangeSet = ttnn::CoreRangeSetAttr::get(
123-
ctx, ttnn::CoreRangeAttr::get(
124-
ctx, ttnn::CoreCoordAttr::get(ctx, 0, 0),
125-
ttnn::CoreCoordAttr::get(ctx, grid.getShape()[1] - 1,
126-
grid.getShape()[0] - 1)));
143+
ctx,
144+
ttnn::CoreRangeAttr::get(
145+
ctx, ttnn::CoreCoordAttr::get(ctx, 0, 0),
146+
ttnn::CoreCoordAttr::get(ctx, endCoreRange[0], endCoreRange[1])));
127147

128148
llvm::SmallVector<Value> ios(size);
129149
llvm::SmallVector<Value> cbs(size);

lib/Conversion/TTIRToD2M/TTIRToD2M.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "ttmlir/Dialect/D2M/IR/D2M.h"
1313
#include "ttmlir/Dialect/D2M/IR/D2MGenericRegionOps.h"
1414
#include "ttmlir/Dialect/D2M/Utils/Utils.h"
15+
#include "ttmlir/Dialect/D2M/Utils/VirtualGrid.h"
1516
#include "ttmlir/Utils.h"
1617

1718
#include "ttmlir/Dialect/D2M/IR/D2M.h"
@@ -64,13 +65,6 @@ class D2MNamedRewriterCommon {
6465
mlir::cast<ttcore::TileType>(ttnnLayout.getElementType()).getHeight() ==
6566
ttcore::TileType::getDefaultShape()[0] &&
6667
"Only default tile shape is supported");
67-
bool isBlockSharded = ttnnLayout.hasL1BufferType() &&
68-
ttnnLayout.getMemLayout().getValue() ==
69-
ttnn::TensorMemoryLayout::BlockSharded;
70-
bool isInterleaved = ttnnLayout.hasInterleavedDRAMTensorMemoryLayout();
71-
assert((isBlockSharded || isInterleaved) &&
72-
"Only block sharded L1 or interleaved DRAM tensor memory layouts "
73-
"are supported");
7468
}
7569

7670
RankedTensorType
@@ -102,15 +96,35 @@ class D2MNamedRewriterCommon {
10296
dimAlignments[dimAlignments.size() - 1] = 32;
10397
dimAlignments[dimAlignments.size() - 2] = 32;
10498

99+
bool needVirtualGrid = ttnnLayout.getMemLayout().getValue() ==
100+
ttnn::TensorMemoryLayout::HeightSharded ||
101+
ttnnLayout.getMemLayout().getValue() ==
102+
ttnn::TensorMemoryLayout::WidthSharded;
103+
AffineMap indexAffineMap = AffineMap::get(rewriter.getContext());
104+
llvm::SmallVector<int64_t> ttnnGridShape(ttnnLayout.getGrid().getShape());
105+
llvm::SmallVector<int64_t> optimalGrid = ttnnGridShape;
106+
if (needVirtualGrid) {
107+
if (ttnnLayout.getMemLayout().getValue() ==
108+
ttnn::TensorMemoryLayout::HeightSharded) {
109+
optimalGrid = {ttnnGridShape[0] * ttnnGridShape[1], 1};
110+
} else if (ttnnLayout.getMemLayout().getValue() ==
111+
ttnn::TensorMemoryLayout::WidthSharded) {
112+
optimalGrid = {1, ttnnGridShape[0] * ttnnGridShape[1]};
113+
}
114+
auto [fwdMap, _] = ttmlir::d2m::utils::grids::createCoreVirtMaps(
115+
rewriter.getContext(), optimalGrid, ttnnGridShape);
116+
indexAffineMap = fwdMap;
117+
}
118+
105119
auto metalLayout = ttcore::MetalLayoutAttr::get(
106120
rewriter.getContext(), tensorType.getShape(), ttcore::OOBVal::Undef,
107-
memSpace, memLayout, collapsedIntervals, dimAlignments);
121+
memSpace, memLayout, collapsedIntervals, dimAlignments, indexAffineMap);
108122

109123
llvm::SmallVector<int64_t> unshardedShape =
110124
metalLayout.getPhysicalShape(ttcore::TileType::getDefaultShape());
111125

112126
llvm::SmallVector<int64_t> shardedShape = metalLayout.getDeviceShape(
113-
ttnnLayout.getGrid().getShape(), ttcore::TileType::getDefaultShape());
127+
optimalGrid, ttcore::TileType::getDefaultShape());
114128

115129
Type elementType = ttnnLayout.getElementType();
116130
return mlir::RankedTensorType::get(shardedShape, elementType, metalLayout);

lib/Dialect/TTCore/IR/TTCoreOpsTypes.cpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,27 @@ ChipDescAttr::getDstLogicalSizeTiles(Type type, bool fullSyncEn,
547547
return nDstTiles;
548548
}
549549

550+
static llvm::SmallVector<int64_t>
551+
getPhysicalGridShapeFromShapeAndMap(ShapedType shapedType, AffineMap map) {
552+
auto shape = shapedType.getShape();
553+
554+
// find bounds of the physical grid by transforming the virtual grid using
555+
// index map
556+
std::pair<int64_t, int64_t> ybounds = {0, 0};
557+
std::pair<int64_t, int64_t> xbounds = {0, 0};
558+
ttmlir::utils::sample(shape, [&](SmallVector<int64_t, 8> point) {
559+
auto virtualPoint = map.compose(point);
560+
ybounds = {std::min(ybounds.first, virtualPoint[0]),
561+
std::max(ybounds.second, virtualPoint[0])};
562+
xbounds = {std::min(xbounds.first, virtualPoint[1]),
563+
std::max(xbounds.second, virtualPoint[1])};
564+
});
565+
566+
TT_assertv((ybounds.first == 0 && xbounds.first == 0),
567+
"Physical grid shape must start at y=0,x=0.");
568+
return {ybounds.second + 1, xbounds.second + 1};
569+
}
570+
550571
ShardLayoutAttr ShardLayoutAttr::get(mlir::MLIRContext *context,
551572
ArrayRef<int64_t> shape,
552573
uint64_t elementSize, uint32_t buffers) {
@@ -583,6 +604,12 @@ mlir::AffineMap ShardLayoutAttr::getAffineMap() const {
583604
getContext());
584605
}
585606

607+
llvm::SmallVector<int64_t>
608+
ShardLayoutAttr::getPhysicalGridShape(ShapedType tensorType) const {
609+
return getPhysicalGridShapeFromShapeAndMap(tensorType,
610+
this->getCoreVirtualizationMap());
611+
}
612+
586613
InterleavedLayoutAttr InterleavedLayoutAttr::get(mlir::MLIRContext *context,
587614
ArrayRef<int64_t> shape,
588615
uint64_t elementSize) {
@@ -882,23 +909,7 @@ MetalLayoutAttr::getPhysicalShape(ArrayRef<int64_t> tileShape) const {
882909

883910
llvm::SmallVector<int64_t>
884911
MetalLayoutAttr::getPhysicalGridShape(ShapedType tensorType) const {
885-
auto shape = tensorType.getShape();
886-
887-
// find bounds of the physical grid by transforming the virtual grid using
888-
// index map
889-
std::pair<int64_t, int64_t> ybounds = {0, 0};
890-
std::pair<int64_t, int64_t> xbounds = {0, 0};
891-
ttmlir::utils::sample(shape, [&](SmallVector<int64_t, 8> point) {
892-
auto virtualPoint = getIndexAffineMap().compose(point);
893-
ybounds = {std::min(ybounds.first, virtualPoint[0]),
894-
std::max(ybounds.second, virtualPoint[0])};
895-
xbounds = {std::min(xbounds.first, virtualPoint[1]),
896-
std::max(xbounds.second, virtualPoint[1])};
897-
});
898-
899-
TT_assertv((ybounds.first == 0 && xbounds.first == 0),
900-
"Physical grid shape must start at y=0,x=0.");
901-
return {ybounds.second + 1, xbounds.second + 1};
912+
return getPhysicalGridShapeFromShapeAndMap(tensorType, getIndexAffineMap());
902913
}
903914

904915
// Takes various shape fields and returns the expected physical shape, which

lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,8 @@ TTNNLayoutAttr TTNNLayoutAttr::withDataType(ttcore::DataType dataType) {
449449
getContext(), getLinear(), getGrid(),
450450
ttcore::buildMemRef<BufferType, BufferTypeAttr>(
451451
getContext(), getScalarShardShape(), elementType, getBufferType()),
452-
getMemLayout(), getTensorMesh(), getIgnorePhysicalLayout());
452+
getMemLayout(), getTensorMesh(), getIgnorePhysicalLayout(),
453+
getExactGrid());
453454
}
454455

455456
// Construct a new TTNNLayoutAttr
@@ -492,7 +493,8 @@ TTNNLayoutAttr TTNNLayoutAttr::withBufferType(BufferType memorySpace) {
492493
getContext(), getLinear(), grid,
493494
mlir::tt::ttcore::buildMemRef<BufferType, BufferTypeAttr>(
494495
getContext(), getScalarShardShape(), getElementType(), memorySpace),
495-
memLayoutAttr, getTensorMesh(), getIgnorePhysicalLayout());
496+
memLayoutAttr, getTensorMesh(), getIgnorePhysicalLayout(),
497+
getExactGrid());
496498
}
497499

498500
// Construct a new TTNNLayoutAttr
@@ -510,7 +512,7 @@ TTNNLayoutAttr::withMemoryLayout(TensorMemoryLayoutAttr memLayoutAttr) {
510512
getContext(), getScalarShardShape(),
511513
getElementType(), getBufferType()),
512514
memLayoutAttr, getTensorMesh(),
513-
getIgnorePhysicalLayout());
515+
getIgnorePhysicalLayout(), getExactGrid());
514516
}
515517

516518
// Construct a new TTNNLayoutAttr
@@ -542,7 +544,8 @@ TTNNLayoutAttr::withShardShape(llvm::SmallVector<int64_t> shardShape) {
542544
getContext(), getLinear(), getGrid(),
543545
mlir::tt::ttcore::buildMemRef<BufferType, BufferTypeAttr>(
544546
getContext(), shardShape, getElementType(), getBufferType()),
545-
getMemLayout(), getTensorMesh(), getIgnorePhysicalLayout());
547+
getMemLayout(), getTensorMesh(), getIgnorePhysicalLayout(),
548+
getExactGrid());
546549
}
547550

548551
// Construct a new TTNNLayoutAttr
@@ -578,7 +581,7 @@ TTNNLayoutAttr
578581
TTNNLayoutAttr::withIgnorePhysicalLayout(bool ignorePhysicalLayout) {
579582
return TTNNLayoutAttr::get(getContext(), getLinear(), getGrid(), getMemref(),
580583
getMemLayout(), getTensorMesh(),
581-
ignorePhysicalLayout);
584+
ignorePhysicalLayout, getExactGrid());
582585
};
583586

584587
TTNNLayoutAttr TTNNLayoutAttr::get(::mlir::MLIRContext *context,
@@ -588,7 +591,8 @@ TTNNLayoutAttr TTNNLayoutAttr::get(::mlir::MLIRContext *context,
588591
ttcore::TensorMeshAttr tensor_mesh) {
589592
return TTNNLayoutAttr::get(context, linear, grid, memref, mem_layout,
590593
tensor_mesh,
591-
/*ignorePhysicalLayout=*/false);
594+
/*ignorePhysicalLayout=*/false,
595+
/*exactGrid=*/false);
592596
}
593597

594598
// Construct a new TTNNLayoutAttr
@@ -641,14 +645,15 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
641645
mlir::tt::ttcore::buildMemRef<BufferType, BufferTypeAttr>(
642646
context, shardShape, elementType, bufferType);
643647
return get(context, linear, grid, memRefType, memLayoutAttr, tensorMesh,
644-
ignorePhysicalLayout);
648+
ignorePhysicalLayout, /*exactGrid=*/false);
645649
}
646650

647651
llvm::LogicalResult TTNNLayoutAttr::verify(
648652
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, AffineMap,
649653
mlir::tt::ttcore::GridAttr grid, MemRefType memref,
650654
TensorMemoryLayoutAttr memLayout,
651-
mlir::tt::ttcore::TensorMeshAttr tensorMesh, bool ignorePhysicalLayout) {
655+
mlir::tt::ttcore::TensorMeshAttr tensorMesh, bool ignorePhysicalLayout,
656+
bool exactGrid) {
652657
BufferType bufferType =
653658
mlir::cast<BufferTypeAttr>(memref.getMemorySpace()).getValue();
654659

python/TTNNModule.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ void populateTTNNModule(nb::module_ &m) {
110110
[](MlirContext ctx, MlirAffineMap linear, MlirAttribute grid,
111111
MlirType memref, std::optional<unsigned> memLayout = std::nullopt,
112112
std::optional<tt::ttcore::TensorMeshAttr> tensorMesh =
113-
std::nullopt) {
113+
std::nullopt,
114+
std::optional<bool> exactGrid = std::nullopt) {
114115
tt::ttnn::TensorMemoryLayoutAttr memLayoutAttr;
115116
if (memLayout.has_value()) {
116117
memLayoutAttr = tt::ttnn::TensorMemoryLayoutAttr::get(
@@ -125,10 +126,12 @@ void populateTTNNModule(nb::module_ &m) {
125126
unwrap(ctx), mlir::cast<AffineMap>(unwrap(linear)),
126127
mlir::cast<tt::ttcore::GridAttr>(unwrap(grid)),
127128
mlir::cast<MemRefType>(unwrap(memref)), memLayoutAttr,
128-
tensorMeshAttr));
129+
tensorMeshAttr, /*ignorePhysicalLayout=*/false,
130+
exactGrid.value_or(false)));
129131
},
130132
nb::arg("ctx"), nb::arg("linear"), nb::arg("grid"), nb::arg("memref"),
131-
nb::arg("memLayout") = nb::none(), nb::arg("tensorMesh") = nb::none())
133+
nb::arg("memLayout") = nb::none(), nb::arg("tensorMesh") = nb::none(),
134+
nb::arg("exactGrid") = nb::none())
132135
.def_static(
133136
"get",
134137
[](MlirContext ctx, std::vector<std::int64_t> shape, MlirType type,

0 commit comments

Comments
 (0)