
Commit 7c065e0

#22781: Fixes to support uneven ND sharding (#22782)
### Ticket
#22781

### Problem description
ND sharding currently has issues with uneven shards.

### What's changed
- Added more tests to check uneven shards
- Fixed row-major tensor alignment for ND sharding
- Used the padded tensor shape for the DistributionSpec calculation
- Avoided converting ND sharding to a single bank, because single-bank layouts don't have much support in OPs

### Checklist
- [x] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/15343571295)
- [x] New/Existing tests provide coverage for changes
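To make "uneven shards" concrete before the diffs below: ND sharding splits each dimension independently, and when a dimension is not a multiple of the shard extent the trailing shard is only partially filled. Below is a minimal, illustrative sketch of that arithmetic for one of the new row-major test cases (shape {35, 45, 55}, shard shape {10, 10, 10}); the helper code and output format are hypothetical, only the shapes come from the new tests.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    // Uneven ROW_MAJOR case added by this change: tensor shape {35, 45, 55}, shard shape {10, 10, 10}.
    std::array<int, 3> shape = {35, 45, 55};
    std::array<int, 3> shard = {10, 10, 10};
    for (std::size_t d = 0; d < shape.size(); ++d) {
        int num_shards = (shape[d] + shard[d] - 1) / shard[d];    // ceil division: 4, 5, 6 shards
        int last_extent = shape[d] - (num_shards - 1) * shard[d]; // the trailing shard only covers 5 of 10
        std::printf("dim %zu: %d shards of %d, last shard covers %d\n", d, num_shards, shard[d], last_extent);
    }
    return 0;
}
```

Every dimension ends in a partial shard, which is exactly the situation the alignment and DistributionSpec changes below have to handle.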
1 parent 214fed5 commit 7c065e0

6 files changed (+124, -52 lines)

tests/ttnn/unit_tests/gtests/tensor/test_tensor_nd_sharding.cpp

Lines changed: 80 additions & 0 deletions

@@ -200,6 +200,41 @@ INSTANTIATE_TEST_SUITE_P(
             .shard_shape = Shape({32, 32, 32}),
             .layout = Layout::TILE,
         },
+        NDShardingParams{
+            .shape = Shape({3 * 32 + 5, 4 * 32, 5 * 32}),
+            .shard_shape = Shape({32, 4 * 32, 5 * 32}),
+            .layout = Layout::TILE,
+        },
+        NDShardingParams{
+            .shape = Shape({3 * 32, 4 * 32 + 5, 5 * 32}),
+            .shard_shape = Shape({3 * 32, 32, 5 * 32}),
+            .layout = Layout::TILE,
+        },
+        NDShardingParams{
+            .shape = Shape({3 * 32, 4 * 32, 5 * 32 + 5}),
+            .shard_shape = Shape({3 * 32, 4 * 32, 32}),
+            .layout = Layout::TILE,
+        },
+        NDShardingParams{
+            .shape = Shape({3 * 32, 4 * 32 + 5, 5 * 32 + 5}),
+            .shard_shape = Shape({3 * 32, 32, 32}),
+            .layout = Layout::TILE,
+        },
+        NDShardingParams{
+            .shape = Shape({3 * 32 + 5, 4 * 32, 5 * 32 + 5}),
+            .shard_shape = Shape({32, 4 * 32, 32}),
+            .layout = Layout::TILE,
+        },
+        NDShardingParams{
+            .shape = Shape({3 * 32 + 5, 4 * 32 + 5, 5 * 32}),
+            .shard_shape = Shape({32, 32, 5 * 32}),
+            .layout = Layout::TILE,
+        },
+        NDShardingParams{
+            .shape = Shape({3 * 32 + 5, 4 * 32 + 5, 5 * 32 + 5}),
+            .shard_shape = Shape({32, 32, 32}),
+            .layout = Layout::TILE,
+        },
         NDShardingParams{
             .shape = Shape({30, 40, 50}),
             .shard_shape = Shape({30, 40, 50}),

@@ -239,6 +274,51 @@ INSTANTIATE_TEST_SUITE_P(
             .shape = Shape({30, 40, 50}),
             .shard_shape = Shape({10, 10, 10}),
             .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({3, 4, 5}),
+            .shard_shape = Shape({1, 1, 1}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({35, 40, 50}),
+            .shard_shape = Shape({10, 40, 50}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({30, 45, 50}),
+            .shard_shape = Shape({30, 10, 50}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({30, 40, 55}),
+            .shard_shape = Shape({30, 40, 10}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({35, 45, 50}),
+            .shard_shape = Shape({10, 10, 50}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({35, 40, 55}),
+            .shard_shape = Shape({10, 40, 10}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({30, 45, 55}),
+            .shard_shape = Shape({30, 10, 10}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({35, 45, 55}),
+            .shard_shape = Shape({10, 10, 10}),
+            .layout = Layout::ROW_MAJOR,
+        },
+        NDShardingParams{
+            .shape = Shape({3, 5, 7}),
+            .shard_shape = Shape({2, 2, 2}),
+            .layout = Layout::ROW_MAJOR,
         }),
         ::testing::Values(BufferType::L1, BufferType::DRAM),
         ::testing::Values(ShardOrientation::ROW_MAJOR, ShardOrientation::COL_MAJOR)));

tests/ttnn/unit_tests/tensor/test_tensor_nd_sharding.py

Lines changed: 25 additions & 26 deletions

@@ -15,36 +15,35 @@
         ([3, 4, 5], [3, 4, 5], ttnn.ROW_MAJOR_LAYOUT),  # All data on a single core
         ([3, 4, 5], [3, 4, 1], ttnn.ROW_MAJOR_LAYOUT),  # Each core gets full batch and height dimension
         ([3, 4, 5], [3, 1, 5], ttnn.ROW_MAJOR_LAYOUT),  # Each core gets full batch and width dimension
-        (
-            [3, 4, 5],
-            [1, 4, 5],
-            ttnn.ROW_MAJOR_LAYOUT,
-        ),  # Each core gets full height and width dimension, aka 1 batch per core
+        ([3, 4, 5], [1, 4, 5], ttnn.ROW_MAJOR_LAYOUT),  # Each core gets full height and width dimension
         ([3, 4, 5], [3, 1, 1], ttnn.ROW_MAJOR_LAYOUT),  # Each core gets full batch dimension
         ([3, 4, 5], [1, 4, 1], ttnn.ROW_MAJOR_LAYOUT),  # Each core gets full height dimension
         ([3, 4, 5], [1, 1, 5], ttnn.ROW_MAJOR_LAYOUT),  # Each core gets full width dimension
-        (
-            [3, 4, 5],
-            [1, 1, 1],
-            ttnn.ROW_MAJOR_LAYOUT,
-        ),  # Data is distributed equally across all cores, no dimenions preserved
+        ([3, 4, 5], [1, 1, 1], ttnn.ROW_MAJOR_LAYOUT),  # Data is distributed equally across all cores
         # Tile Layout
-        ([3, 4 * 32, 5 * 32], [3, 4 * 32, 5 * 32], ttnn.TILE_LAYOUT),  # All data on a single core
-        ([3, 4 * 32, 5 * 32], [3, 4 * 32, 32], ttnn.TILE_LAYOUT),  # Each core gets full batch and height dimension
-        ([3, 4 * 32, 5 * 32], [3, 32, 5 * 32], ttnn.TILE_LAYOUT),  # Each core gets full batch and width dimension
-        (
-            [3, 4 * 32, 5 * 32],
-            [1, 4 * 32, 5 * 32],
-            ttnn.TILE_LAYOUT,
-        ),  # Each core gets full height and width dimension, aka 1 batch per core
-        ([3, 4 * 32, 5 * 32], [3, 32, 32], ttnn.TILE_LAYOUT),  # Each core gets full batch dimension
-        ([3, 4 * 32, 5 * 32], [1, 4 * 32, 32], ttnn.TILE_LAYOUT),  # Each core gets full height dimension
-        ([3, 4 * 32, 5 * 32], [1, 32, 5 * 32], ttnn.TILE_LAYOUT),  # Each core gets full width dimension
-        (
-            [3, 4 * 32, 5 * 32],
-            [1, 32, 32],
-            ttnn.TILE_LAYOUT,
-        ),  # Data is distributed equally across all cores, no dimenions preserved
+        ([3, 128, 160], [3, 128, 160], ttnn.TILE_LAYOUT),  # All data on a single core
+        ([3, 128, 160], [3, 128, 32], ttnn.TILE_LAYOUT),  # Each core gets full batch and height dimension
+        ([3, 128, 160], [3, 32, 160], ttnn.TILE_LAYOUT),  # Each core gets full batch and width dimension
+        ([3, 128, 160], [1, 128, 160], ttnn.TILE_LAYOUT),  # Each core gets full height and width dimension
+        ([3, 128, 160], [3, 32, 32], ttnn.TILE_LAYOUT),  # Each core gets full batch dimension
+        ([3, 128, 160], [1, 128, 32], ttnn.TILE_LAYOUT),  # Each core gets full height dimension
+        ([3, 128, 160], [1, 32, 160], ttnn.TILE_LAYOUT),  # Each core gets full width dimension
+        ([3, 128, 160], [1, 32, 32], ttnn.TILE_LAYOUT),  # Data is distributed equally across all cores
+        # Uneven shards
+        ([30, 40, 55], [30, 40, 10], ttnn.ROW_MAJOR_LAYOUT),
+        ([30, 45, 50], [30, 10, 50], ttnn.ROW_MAJOR_LAYOUT),
+        ([35, 40, 50], [10, 40, 50], ttnn.ROW_MAJOR_LAYOUT),
+        ([30, 45, 50], [30, 10, 50], ttnn.ROW_MAJOR_LAYOUT),
+        ([35, 40, 50], [10, 40, 50], ttnn.ROW_MAJOR_LAYOUT),
+        ([35, 45, 50], [10, 10, 50], ttnn.ROW_MAJOR_LAYOUT),
+        ([35, 45, 55], [10, 10, 10], ttnn.ROW_MAJOR_LAYOUT),
+        ([3, 128, 165], [3, 128, 32], ttnn.TILE_LAYOUT),
+        ([3, 130, 160], [3, 32, 160], ttnn.TILE_LAYOUT),
+        ([5, 128, 160], [2, 128, 160], ttnn.TILE_LAYOUT),
+        ([3, 130, 165], [3, 32, 32], ttnn.TILE_LAYOUT),
+        ([5, 128, 165], [2, 128, 32], ttnn.TILE_LAYOUT),
+        ([5, 130, 160], [2, 32, 160], ttnn.TILE_LAYOUT),
+        ([5, 130, 165], [2, 32, 32], ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize("buffer_type", [ttnn.BufferType.L1, ttnn.BufferType.DRAM])

ttnn/core/tensor/layout/page_config.cpp

Lines changed: 14 additions & 13 deletions

@@ -134,21 +134,22 @@ const Tile& TilePageConfig::get_tile() const { return tile_; }
 RowMajorPageConfig::RowMajorPageConfig(const Tile& tile) : tile_(tile) {}
 
 Alignment RowMajorPageConfig::create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const {
-    {
-        if (memory_config.shard_spec().has_value()) {
-            const auto& shard_spec = memory_config.shard_spec().value();
-            if (shard_spec.mode == ShardMode::LOGICAL) {
-                return shard_spec.physical_shard_shape.has_value() ? Alignment(shard_spec.physical_shard_shape.value())
-                                                                   : Alignment({shard_spec.shape[1]});
-            }
-            // TODO: Investigate why we need guard against HEIGHT_SHARDED and merge logic with LOGICAL sharding
-            if (shard_spec.mode == ShardMode::PHYSICAL &&
-                memory_config.memory_layout() != TensorMemoryLayout::HEIGHT_SHARDED) {
-                return Alignment({shard_spec.shape[1]});
-            }
+    if (memory_config.shard_spec().has_value()) {
+        const auto& shard_spec = memory_config.shard_spec().value();
+        if (shard_spec.mode == ShardMode::LOGICAL) {
+            return shard_spec.physical_shard_shape.has_value() ? Alignment(shard_spec.physical_shard_shape.value())
+                                                               : Alignment({shard_spec.shape[1]});
         }
-        return Alignment({1});
+        // TODO: Investigate why we need guard against HEIGHT_SHARDED and merge logic with LOGICAL sharding
+        if (shard_spec.mode == ShardMode::PHYSICAL &&
+            memory_config.memory_layout() != TensorMemoryLayout::HEIGHT_SHARDED) {
+            return Alignment({shard_spec.shape[1]});
+        }
+    } else if (memory_config.nd_shard_spec().has_value()) {
+        const auto& nd_shard_spec = *memory_config.nd_shard_spec();
+        return Alignment({nd_shard_spec.shard_shape[-1]});
     }
+    return Alignment({1});
 }
 
 void RowMajorPageConfig::validate_alignment(
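The new `else if` branch makes the default row-major alignment follow the ND shard width (`nd_shard_spec.shard_shape[-1]`), mirroring what the legacy path above already does with `shard_spec.shape[1]`. A rough sketch of the intended effect, assuming the alignment rounds the innermost (width) dimension of the physical shape up to that value; the helper below is a stand-in, not the real Alignment machinery:

```cpp
#include <cstdio>

// Stand-in for aligning the innermost dimension to the ND shard width.
static int align_up(int value, int alignment) { return ((value + alignment - 1) / alignment) * alignment; }

int main() {
    // Uneven ROW_MAJOR case from the new tests: shape {30, 40, 55}, shard shape {30, 40, 10}.
    int logical_width = 55;
    int shard_width = 10;                                       // nd_shard_spec.shard_shape[-1]
    int aligned_width = align_up(logical_width, shard_width);   // 60, so rows pack into whole shard-width pages
    std::printf("width %d aligned to shard width %d -> %d (last shard carries %d columns of data)\n",
                logical_width, shard_width, aligned_width, logical_width % shard_width);
    return 0;
}
```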

ttnn/core/tensor/layout/tensor_layout.cpp

Lines changed: 2 additions & 1 deletion

@@ -214,8 +214,9 @@ std::optional<std::variant<ShardSpecBuffer, BufferDistributionSpec>> TensorLayou
     }
 
     auto& nd_shard_spec = memory_config_.nd_shard_spec().value();
+    auto padded_shape = compute_padded_shape(shape);
     return BufferDistributionSpec::from_shard_spec(
-        shape, nd_shard_spec.shard_shape, page_shape, nd_shard_spec.grid, nd_shard_spec.orientation);
+        padded_shape, nd_shard_spec.shard_shape, page_shape, nd_shard_spec.grid, nd_shard_spec.orientation);
 }
 
 size_t TensorLayout::compute_packed_buffer_size_bytes(const ttnn::Shape& shape) const {
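Passing `compute_padded_shape(shape)` means the distribution is computed over the physical (padded) extent of the buffer rather than the logical shape. A sketch of why that matters for one of the new TILE cases, assuming the usual rounding of the height/width dims up to the 32-element tile; the helper name is illustrative:

```cpp
#include <cstdio>

// Illustrative: round a height/width dimension up to the 32-element tile extent.
static int pad_to_tile(int dim, int tile = 32) { return ((dim + tile - 1) / tile) * tile; }

int main() {
    // TILE case from the new tests: shape {3 * 32, 4 * 32 + 5, 5 * 32}, shard shape {3 * 32, 32, 5 * 32}.
    int logical_height = 4 * 32 + 5;                  // 133 rows of real data
    int padded_height = pad_to_tile(logical_height);  // 160 rows once tilized
    int shard_height = 32;
    int num_shards = padded_height / shard_height;    // 5 shards over the padded extent
    // Distributing the logical 133 rows instead would mis-size the last shard's pages,
    // since the buffer physically holds the padded 160 rows.
    std::printf("logical %d -> padded %d rows: %d shards, last one holds %d real rows\n",
                logical_height, padded_height, num_shards,
                logical_height - (num_shards - 1) * shard_height);
    return 0;
}
```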

ttnn/core/tensor/tensor_impl.cpp

Lines changed: 3 additions & 2 deletions

@@ -940,8 +940,9 @@ std::array<Shape2D, 2> get_logical_and_physical_shard_shapes(const TensorSpec& t
 
     // TODO: get_logical_shard_shape always returns shard shape from shard spec, which is not correct in physical mode
     // if there is padding
-    if (tensor_spec.memory_config().is_sharded() and
-        (tensor_spec.memory_config().shard_spec().value().mode == ShardMode::LOGICAL or
+    if (tensor_spec.memory_config().is_sharded() &&
+        ((tensor_spec.memory_config().shard_spec().has_value() &&
+          tensor_spec.memory_config().shard_spec().value().mode == ShardMode::LOGICAL) ||
          logical_shape == padded_shape)) {
         return {
             tensor_spec.tensor_layout().get_logical_shard_shape(),
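The reworked condition also guards the `shard_spec().value()` access: an ND-sharded tensor reports `is_sharded()` while the legacy `shard_spec()` optional is empty, and the previous `and`/`or` expression dereferenced it unconditionally. A minimal standalone illustration of the pattern with plain `std::optional` (not the real MemoryConfig API):

```cpp
#include <cstdio>
#include <optional>

enum class ShardMode { LOGICAL, PHYSICAL };

int main() {
    std::optional<ShardMode> shard_spec;  // empty, like an ND-sharded tensor with no legacy shard spec
    bool logical_equals_padded = true;

    // Old shape of the check: shard_spec.value() runs even when the optional is empty and throws.
    // bool old_check = shard_spec.value() == ShardMode::LOGICAL || logical_equals_padded;

    // New shape of the check: value access only happens after has_value() short-circuits to true.
    bool new_check = (shard_spec.has_value() && *shard_spec == ShardMode::LOGICAL) || logical_equals_padded;
    std::printf("guarded check: %s\n", new_check ? "true" : "false");
    return 0;
}
```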

ttnn/core/tensor/tensor_spec.cpp

Lines changed: 0 additions & 10 deletions

@@ -246,16 +246,6 @@ std::optional<MemoryConfig> TensorSpec::populate_legacy_shard_spec_from_nd() con
         return std::nullopt;
     }
 
-    // Detect single bank case
-    if (nd_shard_shape == padded_shape()) {
-        return MemoryConfig::create_with_prepopulated_shard_specs(
-            TensorMemoryLayout::SINGLE_BANK,
-            mem_config.buffer_type(),
-            ShardSpec(nd_shard_spec.grid, physical_shape(), nd_shard_spec.orientation),
-            mem_config.nd_shard_spec(),
-            mem_config.created_with_nd_shard_spec());
-    }
-
     ShardSpec shard_spec(
         nd_shard_spec.grid,
         {nd_shard_shape.volume() / nd_shard_shape[-1], nd_shard_shape[-1]},
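With the single-bank special case removed, an ND shard spec that covers the whole tensor now falls through to the same collapse-to-2D path as every other ND spec: the shard shape is flattened into a legacy {height, width} shard of {volume / last_dim, last_dim}, as the surviving ShardSpec construction above shows. A small sketch of that flattening, using stand-in types rather than the real ShardSpec:

```cpp
#include <array>
#include <cstdio>

int main() {
    // ND shard shape from one of the new tests, e.g. {2, 32, 32}.
    std::array<long, 3> nd_shard_shape = {2, 32, 32};
    long volume = 1;
    for (long d : nd_shard_shape) volume *= d;  // 2048 elements per shard
    long width = nd_shard_shape.back();         // 32
    long height = volume / width;               // 64, i.e. 2 * 32 rows stacked
    std::printf("nd shard {2, 32, 32} -> legacy 2D shard {%ld, %ld}\n", height, width);
    return 0;
}
```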