Skip to content

Commit 2914f22

Browse files
authored
#23224: Rework sharded buffer dispatch, added BufferRegion support for ND sharding (#23462)
### Ticket #23224 ### Problem description We need to achieve parity between 2D and ND sharding by adding support for specifying BufferRegion for ND sharded reads / writes. Currently buffer dispatch internally handles BufferRegion logic, which relies on getting 2D ShardSpec and utilizing the fact that BufferPageMapping isn't arbitrary. Moreover, buffer dispatch logic utilizes ShardSpec to perform some optimizations for sharded buffer dispatch instead of relying on BufferPageMapping. This PR drastically changes sharded buffer dispatch by making it rely only on the provided page mapping (not on the kind of sharding or shard specs), and making BufferRegion handling completely universal and separate from the dispatch. ### What's changed BufferPageMapping used to contain a lot of duplicated and redundant fields with different kinds of mappings; after this PR it only has a single mapping, and it was renamed to UncompressedBufferPageMapping. This PR introduces a new BufferPageMapping which can be created from UncompressedBufferPageMapping, and it basically groups consecutive runs of pages into ranges. This allows performing the same kind of optimizations in buffer dispatch in a generic manner. Now BufferPageMapping handles the logic to apply a BufferRegion and do the appropriate filtering. Buffer now holds and returns BufferPageMapping instead of UncompressedBufferPageMapping. Added a `Buffer::view(BufferRegion)` method which returns a Buffer corresponding to a region of the underlying root buffer while keeping the root buffer alive. `bank_local_page_address` and `sharded_page_address` were removed as obscure and not needed. 
Updated the buffer dispatch logic, drastically simplifying it and removing lots of edge cases. Updated all of the related usages to accommodate the new BufferPageMapping API. Added new tests for BufferRegion reads/writes for ND sharding. ### Checklist - [x] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/15697895877) - [x] [Model regression CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/15661742169) - [x] [Device performance regression CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/15661743686) - [x] [Single-card demo tests CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/15661745385) - [x] New/Existing tests provide coverage for changes
1 parent d61c4d2 commit 2914f22

File tree

21 files changed

+790
-631
lines changed

21 files changed

+790
-631
lines changed

tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_non_blocking.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ TEST_F(DeviceFixture, TensixTestCircularBufferNonBlockingAPIs) {
118118

119119
std::vector<uint32_t> out_buf(data_buffer_size);
120120
for (size_t i = 0; i < n_cbs; i++) {
121-
tt::tt_metal::detail::ReadFromBuffer(master_data_buffers[i], out_buf, false);
121+
tt::tt_metal::detail::ReadFromBuffer(master_data_buffers[i], out_buf);
122122

123123
uint8_t const* raw_data = reinterpret_cast<uint8_t*>(out_buf.data());
124124
for (size_t pages_pushed = 0; pages_pushed < cb_n_pages; pages_pushed++) {
@@ -129,17 +129,4 @@ TEST_F(DeviceFixture, TensixTestCircularBufferNonBlockingAPIs) {
129129
}
130130
}
131131
}
132-
133-
for (size_t i = 0; i < n_cbs; i++) {
134-
tt::tt_metal::detail::ReadFromBuffer(subordinate_data_buffers[i], out_buf, true);
135-
136-
uint8_t const* raw_data = reinterpret_cast<uint8_t*>(out_buf.data());
137-
for (size_t pages_pushed = 0; pages_pushed < cb_n_pages; pages_pushed++) {
138-
for (size_t filled_pages_requested = 0; filled_pages_requested < cb_n_pages; filled_pages_requested++) {
139-
ASSERT_EQ(
140-
static_cast<bool>(raw_data[pages_pushed * cb_n_pages + filled_pages_requested]),
141-
filled_pages_requested <= pages_pushed);
142-
}
143-
}
144-
}
145132
}

tests/tt_metal/tt_metal/api/distribution_spec/test_buffer_distribution_spec.cpp

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
namespace distribution_spec_tests {
1616
using tt::tt_metal::BufferDistributionSpec;
17+
constexpr uint32_t PADDING = tt::tt_metal::UncompressedBufferPageMapping::PADDING;
1718

1819
struct BufferDistributionSpecInputs {
1920
tt::tt_metal::Shape physical_tensor_shape;
@@ -299,7 +300,9 @@ TEST_P(MeshBufferReadWriteTests, WriteReadLoopback) {
299300
}
300301
for (size_t empty_core_idx = expected_page_mapping.size(); empty_core_idx < page_mapping.size();
301302
empty_core_idx++) {
302-
EXPECT_EQ(page_mapping[empty_core_idx], std::vector<uint32_t>(page_mapping[empty_core_idx].size()));
303+
EXPECT_EQ(
304+
page_mapping[empty_core_idx],
305+
std::vector<uint32_t>(page_mapping[empty_core_idx].size(), UncompressedBufferPageMapping::PADDING));
303306
}
304307

305308
for (size_t i = 0; i < cores.size(); i++) {
@@ -313,7 +316,7 @@ TEST_P(MeshBufferReadWriteTests, WriteReadLoopback) {
313316

314317
const auto* result_per_core_ptr = reinterpret_cast<const uint8_t*>(result_per_core.data());
315318
for (size_t core_page = 0; core_page < page_mapping[i].size(); core_page++) {
316-
if (!page_mapping[i][core_page]) {
319+
if (page_mapping[i][core_page] == UncompressedBufferPageMapping::PADDING) {
317320
continue;
318321
}
319322
const auto host_page = page_mapping[i][core_page];
@@ -370,13 +373,13 @@ INSTANTIATE_TEST_SUITE_P(
370373
MeshBufferReadWriteExpected{
371374
.explicit_core_page_mapping = {
372375
{0, 1},
373-
{2, 0},
376+
{2, PADDING},
374377
{3, 4},
375-
{5, 0},
378+
{5, PADDING},
376379
{6, 7},
377-
{8, 0},
380+
{8, PADDING},
378381
{9, 10},
379-
{11, 0},
382+
{11, PADDING},
380383
},
381384
},
382385
},
@@ -395,10 +398,10 @@ INSTANTIATE_TEST_SUITE_P(
395398
},
396399
MeshBufferReadWriteExpected{
397400
.explicit_core_page_mapping = {
398-
{0, 1, 0, 2, 3, 0},
399-
{4, 5, 0, 6, 7, 0},
400-
{8, 9, 0, 10, 11, 0},
401-
{12, 13, 0, 14, 15, 0},
401+
{0, 1, PADDING, 2, 3, PADDING},
402+
{4, 5, PADDING, 6, 7, PADDING},
403+
{8, 9, PADDING, 10, 11, PADDING},
404+
{12, 13, PADDING, 14, 15, PADDING},
402405
},
403406
},
404407
},
@@ -416,8 +419,8 @@ INSTANTIATE_TEST_SUITE_P(
416419
},
417420
MeshBufferReadWriteExpected{
418421
.explicit_core_page_mapping = {
419-
{0, 2, 4, 0, 6, 8, 10, 0},
420-
{1, 3, 5, 0, 7, 9, 11, 0},
422+
{0, 2, 4, PADDING, 6, 8, 10, PADDING},
423+
{1, 3, 5, PADDING, 7, 9, 11, PADDING},
421424
},
422425
},
423426
},
@@ -437,15 +440,15 @@ INSTANTIATE_TEST_SUITE_P(
437440
MeshBufferReadWriteExpected{
438441
.explicit_core_page_mapping = {
439442
{0, 1, 3, 4, 30, 31, 33, 34},
440-
{2, 0, 5, 0, 32, 0, 35, 0},
441-
{6, 7, 9, 10, 0, 0, 0, 0},
442-
{8, 0, 11, 0, 0, 0, 0, 0},
443-
{12, 13, 15, 16, 0, 0, 0, 0},
444-
{14, 0, 17, 0, 0, 0, 0, 0},
445-
{18, 19, 21, 22, 0, 0, 0, 0},
446-
{20, 0, 23, 0, 0, 0, 0, 0},
447-
{24, 25, 27, 28, 0, 0, 0, 0},
448-
{26, 0, 29, 0, 0, 0, 0, 0}
443+
{2, PADDING, 5, PADDING, 32, PADDING, 35, PADDING},
444+
{6, 7, 9, 10, PADDING, PADDING, PADDING, PADDING},
445+
{8, PADDING, 11, PADDING, PADDING, PADDING, PADDING, PADDING},
446+
{12, 13, 15, 16, PADDING, PADDING, PADDING, PADDING},
447+
{14, PADDING, 17, PADDING, PADDING, PADDING, PADDING, PADDING},
448+
{18, 19, 21, 22, PADDING, PADDING, PADDING, PADDING},
449+
{20, PADDING, 23, PADDING, PADDING, PADDING, PADDING, PADDING},
450+
{24, 25, 27, 28, PADDING, PADDING, PADDING, PADDING},
451+
{26, PADDING, 29, PADDING, PADDING, PADDING, PADDING, PADDING}
449452
},
450453
},
451454
},
@@ -464,11 +467,11 @@ INSTANTIATE_TEST_SUITE_P(
464467
},
465468
MeshBufferReadWriteExpected{
466469
.explicit_core_page_mapping = {
467-
{0, 1, 3, 4, 12, 13, 15, 16, 26, 0, 29, 0, 38, 0, 41, 0, 54, 55, 57, 58, 0, 0, 0, 0},
468-
{2, 0, 5, 0, 14, 0, 17, 0, 30, 31, 33, 34, 42, 43, 45, 46, 56, 0, 59, 0, 0, 0, 0, 0},
469-
{6, 7, 9, 10, 18, 19, 21, 22, 32, 0, 35, 0, 44, 0, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0},
470-
{8, 0, 11, 0, 20, 0, 23, 0, 48, 49, 51, 52, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
471-
{24, 25, 27, 28, 36, 37, 39, 40, 50, 0, 53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
470+
{0, 1, 3, 4, 12, 13, 15, 16, 26, PADDING, 29, PADDING, 38, PADDING, 41, PADDING, 54, 55, 57, 58, PADDING, PADDING, PADDING, PADDING},
471+
{2, PADDING, 5, PADDING, 14, PADDING, 17, PADDING, 30, 31, 33, 34, 42, 43, 45, 46, 56, PADDING, 59, PADDING, PADDING, PADDING, PADDING, PADDING},
472+
{6, 7, 9, 10, 18, 19, 21, 22, 32, PADDING, 35, PADDING, 44, PADDING, 47, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING},
473+
{8, PADDING, 11, PADDING, 20, PADDING, 23, PADDING, 48, 49, 51, 52, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING},
474+
{24, 25, 27, 28, 36, 37, 39, 40, 50, PADDING, 53, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING, PADDING}
472475
},
473476
},
474477
}) // Values

tests/ttnn/unit_tests/gtests/tensor/test_tensor_nd_sharding.cpp

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,29 @@ struct NDShardingBufferSizeParams {
4545
size_t expected_num_dev_pages = 0;
4646
size_t expected_aligned_size_per_bank = 0;
4747
};
48-
} // namespace
49-
50-
class NDShardingTests
51-
: public ttnn::TTNNFixtureWithDevice,
52-
public ::testing::WithParamInterface<std::tuple<NDShardingParams, BufferType, ShardOrientation>> {};
53-
54-
TEST_P(NDShardingTests, LoopbackTest) {
55-
const auto& [params, buffer_type, orientation] = GetParam();
5648

49+
TensorSpec get_nd_sharding_tensor_spec(
50+
const NDShardingParams& params, BufferType buffer_type, ShardOrientation orientation, IDevice* device) {
5751
CoreRangeSet cores;
5852
if (buffer_type == BufferType::L1) {
5953
cores = CoreRangeSet(CoreRange(CoreCoord{0, 0}, CoreCoord{6, 6}));
6054
} else {
61-
auto dram_grid_size = device_->dram_grid_size();
55+
auto dram_grid_size = device->dram_grid_size();
6256
cores = CoreRangeSet(CoreRange(CoreCoord{0, 0}, CoreCoord{dram_grid_size.x - 1, dram_grid_size.y - 1}));
6357
}
6458
MemoryConfig memory_config{buffer_type, NdShardSpec{params.shard_shape, cores, orientation}};
6559
TensorLayout tensor_layout(DataType::UINT16, PageConfig(params.layout), memory_config);
66-
TensorSpec tensor_spec(params.shape, tensor_layout);
60+
return TensorSpec(params.shape, tensor_layout);
61+
}
62+
} // namespace
63+
64+
class NDShardingTests
65+
: public ttnn::TTNNFixtureWithDevice,
66+
public ::testing::WithParamInterface<std::tuple<NDShardingParams, BufferType, ShardOrientation>> {};
67+
68+
TEST_P(NDShardingTests, LoopbackTest) {
69+
const auto& [params, buffer_type, orientation] = GetParam();
70+
auto tensor_spec = get_nd_sharding_tensor_spec(params, buffer_type, orientation, device_);
6771

6872
size_t volume = params.shape.volume();
6973
std::vector<uint16_t> data(volume);
@@ -79,6 +83,55 @@ TEST_P(NDShardingTests, LoopbackTest) {
7983
}
8084
}
8185

86+
TEST_P(NDShardingTests, RegionWriteReadTest) {
87+
const auto& [params, buffer_type, orientation] = GetParam();
88+
auto tensor_spec = get_nd_sharding_tensor_spec(params, buffer_type, orientation, device_);
89+
90+
size_t volume = params.shape.volume();
91+
std::vector<uint16_t> data(volume);
92+
for (size_t i = 0; i < data.size(); i++) {
93+
data[i] = static_cast<uint16_t>(i);
94+
}
95+
auto data_tensor = Tensor::from_vector(data, tensor_spec);
96+
auto tensor_data_span = host_buffer::get_as<uint16_t>(data_tensor);
97+
auto tensor_data = std::vector<uint16_t>(tensor_data_span.begin(), tensor_data_span.end());
98+
99+
std::vector<uint16_t> empty_data(volume);
100+
auto tensor = Tensor::from_vector(empty_data, tensor_spec, device_);
101+
102+
auto& storage = std::get<DeviceStorage>(tensor.storage());
103+
auto buffer = storage.get_buffer();
104+
auto page_size = buffer->page_size();
105+
auto device = buffer->device();
106+
107+
size_t region_size = buffer->page_size();
108+
while (buffer->size() % (region_size * 2) == 0) {
109+
region_size *= 2;
110+
}
111+
112+
std::vector<uint16_t> partial_readback_data(tensor_data.size());
113+
std::vector<uint16_t> full_readback_data(tensor_data.size());
114+
115+
for (size_t region = 0; region < buffer->size() / region_size; region++) {
116+
size_t region_offset = region * region_size;
117+
auto buffer_view = buffer->view(BufferRegion{region_offset, region_size});
118+
EnqueueWriteBuffer(
119+
device->command_queue(),
120+
buffer_view,
121+
reinterpret_cast<const std::byte*>(tensor_data.data()) + region_offset,
122+
true);
123+
EnqueueReadBuffer(
124+
device->command_queue(),
125+
buffer_view,
126+
reinterpret_cast<std::byte*>(partial_readback_data.data()) + region_offset,
127+
true);
128+
}
129+
EXPECT_EQ(tensor_data, partial_readback_data);
130+
131+
EnqueueReadBuffer(device->command_queue(), *buffer, full_readback_data.data(), true);
132+
EXPECT_EQ(tensor_data, full_readback_data);
133+
}
134+
82135
class LegacyToNdShardingTests : public ::testing::TestWithParam<LegacyToNdShardingParams> {};
83136

84137
TEST_P(LegacyToNdShardingTests, LegacyToNdSharding) {

tt_metal/api/tt-metalium/buffer.hpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ struct BufferRegion {
171171
BufferRegion(DeviceAddr offset, DeviceAddr size) : offset(offset), size(size) {}
172172
};
173173

174-
class Buffer final {
174+
class Buffer final : public std::enable_shared_from_this<Buffer> {
175175
// Used in public Buffer constructors so they are only callable within Buffer
176176
// Buffer constructors are public so we can call std::make_shared on Buffer
177177
struct Private {
@@ -208,6 +208,11 @@ class Buffer final {
208208
std::optional<bool> bottom_up = std::nullopt,
209209
std::optional<SubDeviceId> sub_device_id = std::nullopt);
210210

211+
// Creates a view of the region of the buffer.
212+
// The view is a new buffer (unless the region is the entire buffer) that shares the same underlying device memory.
213+
// The view keeps the underlying buffer alive as long as the view is alive.
214+
std::shared_ptr<Buffer> view(const BufferRegion& region);
215+
211216
Buffer(const Buffer& other) = delete;
212217
Buffer& operator=(const Buffer& other) = delete;
213218
Buffer(Buffer&& other) = delete;
@@ -245,29 +250,23 @@ class Buffer final {
245250

246251
DeviceAddr page_address(uint32_t bank_id, uint32_t page_index) const;
247252

248-
DeviceAddr bank_local_page_address(uint32_t bank_id, uint32_t page_index) const;
249253
uint32_t alignment() const;
250254
DeviceAddr aligned_page_size() const;
251255
DeviceAddr aligned_size() const;
252256
DeviceAddr aligned_size_per_bank() const;
253257

254258
// SHARDED API STARTS HERE
255-
// If buffer contains BufferDistributionSpec, it is considered ND sharded
256-
bool is_nd_sharded() const;
257259
const std::optional<BufferDistributionSpec>& buffer_distribution_spec() const;
258-
259-
// TODO: WILL SEPARATE INTO SHARDED BUFFER CLASS
260-
261-
DeviceAddr sharded_page_address(uint32_t bank_id, uint32_t page_index) const;
262-
263260
ShardSpecBuffer shard_spec() const;
264261
void set_shard_spec(const ShardSpecBuffer& shard_spec);
265-
266-
// TODO: Consolidate with interleaved and delete this (maybe get from BufferDistributionSpec)
267262
std::optional<uint32_t> num_cores() const;
268-
269263
const std::shared_ptr<const BufferPageMapping>& get_buffer_page_mapping();
270264

265+
// Returns the buffer that owns the underlying device memory.
266+
// Typically returns itself unless the buffer was created with a view method.
267+
std::shared_ptr<Buffer> root_buffer();
268+
BufferRegion root_buffer_region() const { return BufferRegion(root_buffer_offset_, size_); }
269+
271270
std::optional<SubDeviceId> sub_device_id() const { return sub_device_id_; }
272271

273272
size_t unique_id() const { return unique_id_; }
@@ -325,11 +324,17 @@ class Buffer final {
325324

326325
std::optional<BufferDistributionSpec> buffer_distribution_spec_;
327326

327+
// The root buffer is the buffer that owns the underlying device memory.
328+
// The root buffer is populated only when the buffer was created with a view method.
329+
std::shared_ptr<Buffer> root_buffer_;
330+
// Offset of the current view buffer in the root buffer
331+
DeviceAddr root_buffer_offset_ = 0;
332+
328333
size_t unique_id_ = 0;
329334
static std::atomic<size_t> next_unique_id;
330335
};
331336

332-
BufferPageMapping generate_buffer_page_mapping(const Buffer& buffer);
337+
UncompressedBufferPageMapping generate_buffer_page_mapping(const Buffer& buffer);
333338

334339
using HostDataType = std::variant<
335340
const std::shared_ptr<std::vector<uint8_t>>,

tt_metal/api/tt-metalium/buffer_distribution_spec.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class BufferDistributionSpec {
3636
size_t num_cores() const { return cores_.size(); }
3737
const std::vector<CoreCoord>& get_cores() const { return cores_; }
3838

39-
BufferPageMapping compute_page_mapping() const;
39+
UncompressedBufferPageMapping compute_page_mapping() const;
4040

4141
private:
4242
tt::tt_metal::Shape tensor_shape_in_pages_;

0 commit comments

Comments
 (0)