Skip to content

Commit c705303

Browse files
authored
#22205: Unify ND and 2D sharding, ND sharding performance improvements (#23152)
### Ticket #22205 ### Problem description We're working on improving ND sharding performance to be on par with regular 2D block sharding. ### What's changed ND sharding was unified with 2D sharding, utilizing the same flow through BufferPageMapping and achieving roughly the same performance. Removed DistributionSpec and the ND sharding utils, as they are no longer needed. The logic for generating the mapping from DistributionSpec was ported to BufferDistributionSpec and adjusted to match the BufferDistributionSpec format. BufferPageMapping was moved to a separate header, with minor adjustments to accommodate padding pages in the middle of a shard. Removed specialized ND sharding code from FDMeshCommandQueue, HWCommandQueue, and tt_metal.cpp (slow dispatch). Minor updates to buffer dispatch and buffer code to accommodate ND sharding. Upgraded tests, removed tests made outdated by this change, and added a performance test. ### Checklist - [x] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/15501368821) - [x] New/Existing tests provide coverage for changes
1 parent 1700432 commit c705303

26 files changed

+514
-1523
lines changed

tests/tt_metal/distributed/test_mesh_socket.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,10 +1887,10 @@ TEST_F(MeshSocketTest, MultiConnectionSingleDeviceConfig) {
18871887
EXPECT_EQ(recv_configs.size(), recv_logical_coords.size());
18881888

18891889
const auto& sender_core_to_core_id =
1890-
send_socket.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id_;
1890+
send_socket.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id;
18911891

18921892
const auto& recv_core_to_core_id =
1893-
recv_socket.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id_;
1893+
recv_socket.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id;
18941894

18951895
for (const auto& connection : socket_connections) {
18961896
const auto& sender = connection.sender_core;
@@ -1987,10 +1987,10 @@ TEST_F(MeshSocketTest2DFabric, MultiConnectionMultiDeviceTest) {
19871987
auto [send_socket_dram, recv_socket_dram] = MeshSocket::create_sockets(md0, md1, socket_config_dram);
19881988

19891989
const auto& sender_core_to_core_id =
1990-
send_socket_l1.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id_;
1990+
send_socket_l1.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id;
19911991

19921992
const auto& recv_core_to_core_id =
1993-
recv_socket_l1.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id_;
1993+
recv_socket_l1.get_config_buffer()->get_backing_buffer()->get_buffer_page_mapping()->core_to_core_id;
19941994

19951995
std::unordered_map<MeshCoordinate, std::vector<sender_socket_md>> sender_configs_per_dev_coord;
19961996
std::unordered_map<MeshCoordinate, std::vector<receiver_socket_md>> recv_configs_per_dev_coord;

tests/tt_metal/tt_metal/api/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ set(UNIT_TESTS_API_SRC
1414
${CMAKE_CURRENT_SOURCE_DIR}/core_coord/test_CoreRangeSet_contains.cpp
1515
${CMAKE_CURRENT_SOURCE_DIR}/core_coord/test_CoreRangeSet_intersects.cpp
1616
${CMAKE_CURRENT_SOURCE_DIR}/core_coord/test_CoreRangeSet_merge.cpp
17-
${CMAKE_CURRENT_SOURCE_DIR}/distribution_spec/test_distribution_spec.cpp
1817
${CMAKE_CURRENT_SOURCE_DIR}/distribution_spec/test_buffer_distribution_spec.cpp
1918
${CMAKE_CURRENT_SOURCE_DIR}/test_banked.cpp
2019
${CMAKE_CURRENT_SOURCE_DIR}/test_bit_utils.cpp

tests/tt_metal/tt_metal/api/distribution_spec/test_buffer_distribution_spec.cpp

Lines changed: 81 additions & 125 deletions
Large diffs are not rendered by default.

tests/tt_metal/tt_metal/api/distribution_spec/test_distribution_spec.cpp

Lines changed: 0 additions & 360 deletions
This file was deleted.

tests/ttnn/unit_tests/gtests/tensor/test_tensor_nd_sharding.cpp

Lines changed: 48 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
#include "ttnn/tensor/tensor.hpp"
77
#include "ttnn/operations/eltwise/binary/binary.hpp"
8-
#include "ttnn/core/tensor/nd_sharding_utils.hpp"
9-
108
#include "ttnn_test_fixtures.hpp"
119

1210
namespace {
@@ -38,13 +36,6 @@ struct NDShardingOpCompatParams {
3836
Shape shard_shape;
3937
CoreCoord grid_size;
4038
};
41-
struct PrepareShardedDataParams {
42-
Shape shape;
43-
Shape shard_shape;
44-
uint32_t num_cores;
45-
46-
std::vector<uint8_t> expected_data;
47-
};
4839
} // namespace
4940

5041
class NDShardingTests
@@ -158,35 +149,59 @@ TEST_P(NdShardingOpCompatTests, TestAdd) {
158149
}
159150
}
160151

161-
class PrepareNdShardedDataTests : public ::testing::TestWithParam<PrepareShardedDataParams> {};
162-
163-
TEST_P(PrepareNdShardedDataTests, PrepareNdShardedData) {
164-
const auto& params = GetParam();
152+
class NDShardingPerfTests : public ttnn::TTNNFixtureWithDevice {};
165153

166-
CoreRangeSet cores(CoreRange(CoreCoord{0, 0}, CoreCoord{0, params.num_cores - 1}));
167-
NdShardSpec nd_shard_spec{params.shard_shape, cores, ShardOrientation::ROW_MAJOR};
168-
MemoryConfig memory_config{BufferType::L1, nd_shard_spec};
169-
TensorLayout tensor_layout(DataType::UINT8, PageConfig(Layout::ROW_MAJOR), memory_config);
170-
TensorSpec tensor_spec(params.shape, tensor_layout);
154+
TEST_F(NDShardingPerfTests, TestBatchShardingPerf) {
155+
CoreRangeSet cores(CoreRange(CoreCoord{0, 0}, CoreCoord{6, 6}));
171156

172-
std::vector<uint8_t> data(params.shape.volume());
173-
for (size_t i = 0; i < data.size(); i++) {
174-
data[i] = static_cast<uint8_t>(i);
175-
}
176-
auto tensor = Tensor::from_vector(data, tensor_spec);
177-
auto tensor_data = std::get<HostStorage>(tensor.get_storage()).buffer.view_as<uint8_t>();
157+
Shape tensor_shape{16, 1024, 1024};
158+
Shape shard_shape_nd_batch{16, 160, 160};
159+
Shape shard_shape_nd_small{1, 64, 64};
160+
Shape2D shard_shape_2d{2368, 160};
178161

179-
auto sharded_data = pack_nd_sharded_data<uint8_t>(tensor_data, tensor_spec);
180-
EXPECT_EQ(sharded_data.size(), params.expected_data.size());
181-
for (size_t i = 0; i < sharded_data.size(); i++) {
182-
EXPECT_EQ(sharded_data[i], static_cast<std::byte>(params.expected_data[i]));
162+
size_t volume = tensor_shape.volume();
163+
std::vector<uint16_t> data(volume);
164+
for (size_t i = 0; i < volume; i++) {
165+
data[i] = static_cast<uint16_t>(i);
183166
}
184167

185-
auto unpacked_data = unpack_nd_sharded_data<std::byte>(sharded_data, tensor_spec);
186-
EXPECT_EQ(unpacked_data.size(), tensor_data.size());
187-
for (size_t i = 0; i < unpacked_data.size(); i++) {
188-
EXPECT_EQ(unpacked_data[i], static_cast<std::byte>(tensor_data[i]));
189-
}
168+
auto measure_to_device_time_ns = [&](const TensorSpec& tensor_spec) -> double {
169+
auto tensor = Tensor::from_vector(data, tensor_spec);
170+
171+
auto start = std::chrono::high_resolution_clock::now();
172+
auto device_tensor = tensor.to_device(device_, tensor_spec.memory_config());
173+
auto end = std::chrono::high_resolution_clock::now();
174+
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
175+
return duration.count();
176+
};
177+
178+
double batch_nd_sharding_time_ns = [&]() {
179+
MemoryConfig memory_config{BufferType::L1, NdShardSpec{shard_shape_nd_batch, cores}};
180+
TensorLayout tensor_layout(DataType::UINT16, PageConfig(Layout::TILE), memory_config);
181+
TensorSpec tensor_spec(tensor_shape, tensor_layout);
182+
return measure_to_device_time_ns(tensor_spec);
183+
}();
184+
185+
double small_shards_nd_sharding_time_ns = [&]() {
186+
MemoryConfig memory_config{BufferType::L1, NdShardSpec{shard_shape_nd_small, cores}};
187+
TensorLayout tensor_layout(DataType::UINT16, PageConfig(Layout::TILE), memory_config);
188+
TensorSpec tensor_spec(tensor_shape, tensor_layout);
189+
return measure_to_device_time_ns(tensor_spec);
190+
}();
191+
192+
double block_2d_sharding_time_ns = [&]() {
193+
MemoryConfig memory_config{TensorMemoryLayout::BLOCK_SHARDED, BufferType::L1, ShardSpec{cores, shard_shape_2d}};
194+
TensorLayout tensor_layout(DataType::UINT16, PageConfig(Layout::TILE), memory_config);
195+
TensorSpec tensor_spec(tensor_shape, tensor_layout);
196+
return measure_to_device_time_ns(tensor_spec);
197+
}();
198+
199+
tt::log_info("Batch ND sharding time: {} ns", batch_nd_sharding_time_ns);
200+
tt::log_info("Small shards ND sharding time: {} ns", small_shards_nd_sharding_time_ns);
201+
tt::log_info("Block 2D sharding time: {} ns", block_2d_sharding_time_ns);
202+
203+
EXPECT_TRUE(batch_nd_sharding_time_ns < block_2d_sharding_time_ns * 4);
204+
EXPECT_TRUE(small_shards_nd_sharding_time_ns < block_2d_sharding_time_ns * 4);
190205
}
191206

192207
INSTANTIATE_TEST_SUITE_P(
@@ -739,111 +754,3 @@ INSTANTIATE_TEST_SUITE_P(
739754
.shard_shape = Shape({1, 1, 32 * 2, 32 * 2}),
740755
.grid_size = CoreCoord{3, 4},
741756
}));
742-
743-
INSTANTIATE_TEST_SUITE_P(
744-
TensorShardingTests,
745-
PrepareNdShardedDataTests,
746-
::testing::Values(
747-
PrepareShardedDataParams{
748-
.shape = Shape({2, 2, 2}),
749-
.shard_shape = Shape({2, 2, 2}),
750-
.num_cores = 1,
751-
.expected_data = {0, 1, 2, 3, 4, 5, 6, 7},
752-
},
753-
PrepareShardedDataParams{
754-
.shape = Shape({2, 2, 2}),
755-
.shard_shape = Shape({2, 2, 2}),
756-
.num_cores = 2,
757-
.expected_data = {0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0},
758-
},
759-
PrepareShardedDataParams{
760-
.shape = Shape({2, 2, 2}),
761-
.shard_shape = Shape({1, 2, 2}),
762-
.num_cores = 2,
763-
.expected_data = {0, 1, 2, 3, 4, 5, 6, 7},
764-
},
765-
PrepareShardedDataParams{
766-
.shape = Shape({2, 2, 2}),
767-
.shard_shape = Shape({2, 1, 2}),
768-
.num_cores = 2,
769-
.expected_data = {0, 1, 4, 5, 2, 3, 6, 7},
770-
},
771-
PrepareShardedDataParams{
772-
.shape = Shape({2, 2, 2}),
773-
.shard_shape = Shape({2, 2, 1}),
774-
.num_cores = 2,
775-
.expected_data = {0, 2, 4, 6, 1, 3, 5, 7},
776-
},
777-
PrepareShardedDataParams{
778-
.shape = Shape({2, 2, 2}),
779-
.shard_shape = Shape({2, 1, 1}),
780-
.num_cores = 2,
781-
.expected_data = {0, 4, 2, 6, 1, 5, 3, 7},
782-
},
783-
PrepareShardedDataParams{
784-
.shape = Shape({2, 2, 2}),
785-
.shard_shape = Shape({2, 1, 1}),
786-
.num_cores = 3,
787-
.expected_data = {0, 4, 3, 7, 1, 5, 0, 0, 2, 6, 0, 0},
788-
},
789-
PrepareShardedDataParams{
790-
.shape = Shape({2, 2, 2}),
791-
.shard_shape = Shape({1, 2, 1}),
792-
.num_cores = 2,
793-
.expected_data = {0, 2, 4, 6, 1, 3, 5, 7},
794-
},
795-
PrepareShardedDataParams{
796-
.shape = Shape({2, 2, 2}),
797-
.shard_shape = Shape({1, 2, 1}),
798-
.num_cores = 3,
799-
.expected_data = {0, 2, 5, 7, 1, 3, 0, 0, 4, 6, 0, 0},
800-
},
801-
PrepareShardedDataParams{
802-
.shape = Shape({2, 2, 2}),
803-
.shard_shape = Shape({1, 1, 2}),
804-
.num_cores = 2,
805-
.expected_data = {0, 1, 4, 5, 2, 3, 6, 7},
806-
},
807-
PrepareShardedDataParams{
808-
.shape = Shape({2, 2, 2}),
809-
.shard_shape = Shape({1, 1, 2}),
810-
.num_cores = 3,
811-
.expected_data = {0, 1, 6, 7, 2, 3, 0, 0, 4, 5, 0, 0},
812-
},
813-
PrepareShardedDataParams{
814-
.shape = Shape({2, 2, 2}),
815-
.shard_shape = Shape({1, 1, 1}),
816-
.num_cores = 2,
817-
.expected_data = {0, 2, 4, 6, 1, 3, 5, 7},
818-
},
819-
PrepareShardedDataParams{
820-
.shape = Shape({2, 2, 2}),
821-
.shard_shape = Shape({1, 1, 1}),
822-
.num_cores = 3,
823-
.expected_data = {0, 3, 6, 1, 4, 7, 2, 5, 0},
824-
},
825-
PrepareShardedDataParams{
826-
.shape = Shape({2, 2, 2}),
827-
.shard_shape = Shape({1, 1, 1}),
828-
.num_cores = 4,
829-
.expected_data = {0, 4, 1, 5, 2, 6, 3, 7},
830-
},
831-
PrepareShardedDataParams{
832-
.shape = Shape({2, 2, 2}),
833-
.shard_shape = Shape({1, 1, 1}),
834-
.num_cores = 5,
835-
.expected_data = {0, 5, 1, 6, 2, 7, 3, 0, 4, 0},
836-
},
837-
PrepareShardedDataParams{
838-
.shape = Shape({3, 3, 3}),
839-
.shard_shape = Shape({2, 2, 2}),
840-
.num_cores = 8,
841-
.expected_data = {/* core 0 */ 0, 1, 3, 4, 9, 10, 12, 13,
842-
/* core 1 */ 2, 0, 5, 0, 11, 0, 14, 0,
843-
/* core 2 */ 6, 7, 0, 0, 15, 16, 0, 0,
844-
/* core 3 */ 8, 0, 0, 0, 17, 0, 0, 0,
845-
/* core 4 */ 18, 19, 21, 22, 0, 0, 0, 0,
846-
/* core 5 */ 20, 0, 23, 0, 0, 0, 0, 0,
847-
/* core 6 */ 24, 25, 0, 0, 0, 0, 0, 0,
848-
/* core 7 */ 26, 0, 0, 0, 0, 0, 0, 0},
849-
}));

tt_metal/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ target_sources(
4040
api/tt-metalium/blockfloat_common.hpp
4141
api/tt-metalium/buffer.hpp
4242
api/tt-metalium/buffer_distribution_spec.hpp
43+
api/tt-metalium/buffer_page_mapping.hpp
4344
api/tt-metalium/buffer_types.hpp
4445
api/tt-metalium/circular_buffer.hpp
4546
api/tt-metalium/circular_buffer_constants.h
@@ -54,7 +55,6 @@ target_sources(
5455
api/tt-metalium/device.hpp
5556
api/tt-metalium/device_pool.hpp
5657
api/tt-metalium/dispatch_core_common.hpp
57-
api/tt-metalium/distribution_spec.hpp
5858
api/tt-metalium/event.hpp
5959
api/tt-metalium/fabric_host_interface.h
6060
api/tt-metalium/fabric_edm_packet_header.hpp

tt_metal/api/tt-metalium/buffer.hpp

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tt-metalium/core_coord.hpp>
3030
#include <tt-metalium/hal_types.hpp>
3131
#include <tt-metalium/sub_device_types.hpp>
32+
#include <tt-metalium/buffer_page_mapping.hpp>
3233
#include <umd/device/tt_core_coordinates.h>
3334
#include <umd/device/tt_soc_descriptor.h>
3435
#include <umd/device/types/xy_pair.h>
@@ -162,20 +163,6 @@ struct ShardedBufferConfig {
162163

163164
bool is_sharded(const TensorMemoryLayout& layout);
164165

165-
struct BufferPageMapping {
166-
std::vector<CoreCoord> all_cores_;
167-
std::vector<uint32_t> core_bank_indices_;
168-
std::vector<std::vector<uint32_t>> core_host_page_indices_;
169-
std::vector<uint32_t> dev_page_to_core_mapping_;
170-
171-
// some dev pages don't have mapping to host (in case of padding)
172-
std::vector<std::optional<uint32_t>> dev_page_to_host_page_mapping_;
173-
std::vector<uint32_t> host_page_to_dev_page_mapping_;
174-
std::unordered_map<CoreCoord, uint32_t> core_to_core_id_;
175-
std::vector<uint32_t> host_page_to_local_shard_page_mapping_;
176-
std::vector<std::array<uint32_t, 2>> core_shard_shape_;
177-
};
178-
179166
struct BufferRegion {
180167
DeviceAddr offset = 0;
181168
DeviceAddr size = 0;
@@ -267,24 +254,7 @@ class Buffer final {
267254
// SHARDED API STARTS HERE
268255
// If buffer contains BufferDistributionSpec, it is considered ND sharded
269256
bool is_nd_sharded() const;
270-
271-
/* BankDataMapping is a struct that provides an explicit mapping of data per bank:
272-
* - banks: Logical coordinates of banks to use
273-
* - bank_mapping_in_bytes: Mapping of data in bytes for each bank; it is a list of ChunkMapping which contains:
274-
* - src: host address offset in bytes
275-
* - dst: bank address offset in bytes
276-
* - size: size of data in bytes
277-
* Some notes:
278-
* - Size of banks and bank_mapping_in_bytes must be equal, with each bank having a corresponding mapping
279-
* - Each TargetData is a list of ChunkMapping which fully describes all data relevant to that bank
280-
* - In Buffer, all ChunkMapping are in bytes and takes into account page size and aligned page size
281-
* - Also see DistributionSpec class for more details about TargetData and ChunkMapping
282-
*/
283-
struct BankDataMapping {
284-
std::vector<CoreCoord> banks;
285-
std::vector<DistributionSpec::TargetData> bank_mapping_in_bytes;
286-
};
287-
BankDataMapping get_bank_data_mapping();
257+
const std::optional<BufferDistributionSpec>& buffer_distribution_spec() const;
288258

289259
// TODO: WILL SEPARATE INTO SHARDED BUFFER CLASS
290260

@@ -354,7 +324,6 @@ class Buffer final {
354324
std::shared_ptr<const BufferPageMapping> buffer_page_mapping_;
355325

356326
std::optional<BufferDistributionSpec> buffer_distribution_spec_;
357-
std::optional<std::vector<DistributionSpec::TargetData>> bank_mapping_in_bytes_ = std::nullopt;
358327

359328
size_t unique_id_ = 0;
360329
static std::atomic<size_t> next_unique_id;

tt_metal/api/tt-metalium/buffer_distribution_spec.hpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,43 @@
66

77
#include <tt-metalium/buffer_types.hpp>
88
#include <tt-metalium/core_coord.hpp>
9-
#include <tt-metalium/distribution_spec.hpp>
9+
#include <tt-metalium/shape.hpp>
1010
#include <tt-metalium/shape2d.hpp>
11+
#include <tt-metalium/buffer_page_mapping.hpp>
1112

1213
namespace tt::tt_metal {
1314

1415
class BufferDistributionSpec {
1516
public:
1617
static BufferDistributionSpec from_shard_spec(
17-
const tt::tt_metal::Shape& tensor_shape,
18-
const tt::tt_metal::Shape& physical_shard_shape,
19-
const Shape2D& page_shape,
20-
const CoreRangeSet& corerangeset,
21-
const ShardOrientation shard_orientation);
22-
23-
tt::tt_metal::Shape get_tensor_shape_in_pages() const { return page_distribution_spec_.get_tensor_shape(); }
24-
tt::tt_metal::Shape get_shard_shape_in_pages() const { return page_distribution_spec_.get_shard_shape(); }
25-
26-
size_t num_dev_pages_per_core() const {
27-
return page_distribution_spec_.get_shard_shape().volume() *
28-
page_distribution_spec_.get_max_num_shards_per_target();
29-
}
18+
tt::tt_metal::Shape tensor_shape,
19+
tt::tt_metal::Shape shard_shape,
20+
tt::tt_metal::Shape2D page_shape,
21+
CoreRangeSet core_range_set,
22+
ShardOrientation shard_orientation);
23+
24+
BufferDistributionSpec(
25+
tt::tt_metal::Shape tensor_shape_in_pages,
26+
tt::tt_metal::Shape shard_shape_in_pages,
27+
CoreRangeSet core_range_set,
28+
ShardOrientation shard_orientation);
29+
30+
tt::tt_metal::Shape get_tensor_shape_in_pages() const { return tensor_shape_in_pages_; }
31+
tt::tt_metal::Shape get_shard_shape_in_pages() const { return shard_shape_in_pages_; }
32+
33+
size_t num_shards() const;
34+
size_t num_shards_per_core() const;
35+
size_t num_dev_pages_per_core() const;
3036
size_t num_cores() const { return cores_.size(); }
3137
const std::vector<CoreCoord>& get_cores() const { return cores_; }
3238

33-
const std::vector<DistributionSpec::TargetData>& get_page_mapping(DistributionSpec::MappingMode mapping_mode);
39+
BufferPageMapping compute_page_mapping() const;
3440

3541
private:
36-
BufferDistributionSpec(const DistributionSpec& page_distribution_spec, const std::vector<CoreCoord>& cores);
42+
tt::tt_metal::Shape tensor_shape_in_pages_;
43+
tt::tt_metal::Shape shard_shape_in_pages_;
44+
ShardOrientation shard_orientation_ = ShardOrientation::ROW_MAJOR;
3745

38-
DistributionSpec page_distribution_spec_;
3946
std::vector<CoreCoord> cores_;
4047
};
4148

0 commit comments

Comments
 (0)