
Commit e4980c6

Generic Pool Large Kernel Optimization (#23162)
### Ticket
N/A

### Problem description
Generic Pool's performance is poor for large kernel sizes.

### What's changed
- YoloV4's expected perf has been increased from 87.8 to 93.5 FPS.
- Generic pool now supports 32-row reductions.
- Fixed a bug in the size of the intermediate / partials CB.
- Fixed a bug in the face dimension passed to unpack tilize.
- For max pool, the in-loop `fill_with_val` was eliminated. This is possible since the junk data left over from previous iterations does not affect the max value.
- `in_cb` initialization has been added for cases where there are not more intermediate reduction chunks than multibuffering chunks (see the sketch below). This is necessary because the compute kernel always processes `max_rows_per_reduction` rows from the `in_cb`, which may include uninitialized data when multibuffering is enabled. When there are enough intermediate reduction chunks, however, the entire `in_cb` gets filled with valid data that cannot contain values larger than the max, so initialization is unnecessary.
- `clear_out_tiles` is now used for buffer initialization as well as for avg pool's in-loop `fill_with_val`, resulting in dramatically better performance in some cases.

Notes:
- Multibuffering does not require an in-loop `fill_with_val`, since one CB only processes a single top-left index at a time, and if necessary the `in_cb` was initialized before the loop.
- Junk data from previous top-left indices is not an issue since all kernel positions have the same number of elements.
- For both average pool and max pool we would not need to initialize the CB with the init value at all, since we know `kernel_HW > max_rows_per_reduction`, except that we are using multibuffering, so there will usually be some dead space. It may be worth turning off multibuffering, but more testing is required.

### Checklist
- [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/15646661523
- [x] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI passes (same failure as main, unrelated to changes): https://github.com/tenstorrent/tt-metal/actions/runs/15646662524
- [x] [Model regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) CI passes (same failure as main, unrelated to changes): https://github.com/tenstorrent/tt-metal/actions/runs/15646665080
- [x] [Device performance regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/15646663480
- [x] [Nightly L2](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml) CI passes: (wormhole) https://github.com/tenstorrent/tt-metal/actions/runs/15646668961 (blackhole) https://github.com/tenstorrent/tt-metal/actions/runs/15646670132
- [x] [Frequent model](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/15646666798
- [x] New/Existing tests provide coverage for changes
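For reference, the chunk-count arithmetic behind the `in_cb` initialization rule above can be checked in isolation. This is a standalone sketch, not kernel code; the variable names mirror the kernel's compile-time arguments, and the window size and multibuffering factor are made-up example values.

```cpp
// Standalone sketch of the intermediate-chunk arithmetic used by the large-kernel
// pool kernels. Example values are hypothetical.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t window_size_hw = 13 * 13;     // e.g. a 13x13 pooling window (large kernel)
    constexpr uint32_t max_rows_for_reduction = 32;  // 32-row reductions enabled by this change
    constexpr uint32_t multi_buffering_factor = 2;   // assumed double buffering of in_cb

    constexpr uint32_t remaining_elems = window_size_hw % max_rows_for_reduction;
    constexpr uint32_t interm_reduction_chunks =
        remaining_elems ? window_size_hw / max_rows_for_reduction + 1 : window_size_hw / max_rows_for_reduction;

    // in_cb only needs an up-front clear when the last chunk is partial AND there are
    // not more chunks than multibuffering slots; otherwise every slot is overwritten
    // with valid rows before the compute kernel consumes it.
    constexpr bool need_to_initialize_in_cb =
        remaining_elems && interm_reduction_chunks <= multi_buffering_factor;

    std::printf("chunks=%u remaining=%u init_in_cb=%d\n",
                interm_reduction_chunks, remaining_elems, (int)need_to_initialize_in_cb);
    return 0;
}
```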
1 parent 613ef42 commit e4980c6

File tree

7 files changed (+139, -56 lines)


models/demos/yolov4/tests/perf/test_perf.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def test_yolov4(
 @pytest.mark.parametrize(
     "batch_size, model_name, expected_perf",
     [
-        (1, "yolov4", 87.8),
+        (1, "yolov4", 93.5),
     ],
 )
 @pytest.mark.models_device_performance_bare_metal

models/experimental/yolov8s_world/tests/test_perf_yolov8s_world.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def test_perf(device, use_pretrained_weight, use_program_cache):
 @pytest.mark.parametrize(
     "batch_size, expected_perf",
     [
-        [1, 79.2],
+        [1, 80.0],
     ],
 )
 @pytest.mark.models_device_performance_bare_metal

ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/pool_2d_multi_core_large_kernel.cpp

Lines changed: 15 additions & 11 deletions
@@ -64,6 +64,7 @@ template <
     uint32_t num_output_tiles,
     bool is_partial_tile,
     uint32_t max_rows_for_reduction,
+    uint32_t unpA_face_r_dim,
     bool neginf_srca_maxpool,
     bool zero_srca_avgpool>
 inline void reduce_h_fused(const uint32_t interm_cb_id, const uint32_t in_scalar_cb_id, const uint32_t out_cb_id) {
@@ -80,7 +81,7 @@ inline void reduce_h_fused(const uint32_t interm_cb_id, const uint32_t in_scalar
         num_output_tiles,
         0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
         num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
-        max_rows_for_reduction);
+        unpA_face_r_dim);
     for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
         reduce_tile_math(c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
     }
@@ -119,6 +120,8 @@ void MAIN {
     constexpr uint32_t interm_cb_id = get_compile_time_arg_val(15);
     constexpr uint32_t in_one_cb_id = get_compile_time_arg_val(16);
     constexpr bool one_scalar_per_core = get_compile_time_arg_val(17);
+    constexpr uint32_t sync_cb_id1 = get_compile_time_arg_val(18);
+    constexpr uint32_t sync_cb_id2 = get_compile_time_arg_val(19);
 
     constexpr bool is_partial_tile = in_c < 32;
     static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
@@ -136,19 +139,18 @@
     constexpr bool neginf_srca_maxpool = (REDUCE_OP == PoolType::MAX) ? true : false;
     constexpr bool zero_srca_avgpool = (REDUCE_OP == PoolType::SUM) ? true : false;
 
+    constexpr uint32_t face_r_dim = 16;
     tilizeA_B_reduce_init<neginf_srca_maxpool, zero_srca_avgpool>(
-        in_cb_id_0,
-        in_scalar_cb_id_0,
-        max_tiles_per_iter,
-        interm_cb_id,
-        num_faces_in_input_tile,
-        max_rows_for_reduction);
+        in_cb_id_0, in_scalar_cb_id_0, max_tiles_per_iter, interm_cb_id, num_faces_in_input_tile, face_r_dim);
 
     constexpr uint32_t remaining_elems = window_size_hw % max_rows_for_reduction;
     constexpr uint32_t interm_reduction_chunks =
         remaining_elems ? window_size_hw / max_rows_for_reduction + 1 : window_size_hw / max_rows_for_reduction;
-    if constexpr (one_scalar_per_core) {
-        cb_wait_front(in_scalar_cb_id_0, 1);
+
+    // wait for initialization to complete
+    cb_wait_front(sync_cb_id1, 2);
+    if constexpr (split_reader) {
+        cb_wait_front(sync_cb_id2, 2);
     }
 
     for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; ++i) {
@@ -171,7 +173,7 @@
                 is_partial_tile,
                 max_rows_for_reduction,
                 split_reader,
-                max_rows_for_reduction,
+                face_r_dim,
                 neginf_srca_maxpool,
                 zero_srca_avgpool>(in_cb_id_0, in_cb_id_1, curr_scalar_cb_id, i, h, interm_cb_id);
         }
@@ -184,6 +186,7 @@
             max_tiles_per_iter,
             is_partial_tile,
             max_rows_for_reduction,
+            face_r_dim,
             neginf_srca_maxpool,
             zero_srca_avgpool>(
             interm_cb_id, REDUCE_OP == PoolType::MAX ? in_scalar_cb_id_0 : in_one_cb_id, out_cb_id);
@@ -200,7 +203,7 @@
                 is_partial_tile,
                 max_rows_for_reduction,
                 split_reader,
-                max_rows_for_reduction,
+                face_r_dim,
                 neginf_srca_maxpool,
                 zero_srca_avgpool>(in_cb_id_0, in_cb_id_1, curr_scalar_cb_id, i, h, interm_cb_id);
         }
@@ -213,6 +216,7 @@
             partial_iter_output_tiles,
             is_partial_tile,
             max_rows_for_reduction,
+            face_r_dim,
             neginf_srca_maxpool,
             zero_srca_avgpool>(interm_cb_id, REDUCE_OP == PoolType::MAX ? in_scalar_cb_id_0 : in_one_cb_id, out_cb_id);
         if constexpr (!one_scalar_per_core) {

ttnn/cpp/ttnn/operations/pool/generic/device/kernels/dataflow/reader_pool2d_sharded_common.hpp

Lines changed: 26 additions & 0 deletions
@@ -20,3 +20,29 @@ ALWI bool fill_with_val(uint32_t begin_addr, uint32_t n, uint16_t val, bool unco
 
     return true;
 }
+
+template <uint32_t cb_id, uint32_t clear_value_cb_id>
+FORCE_INLINE void clear_out_tiles() {
+    constexpr uint32_t tile_size = get_tile_size(cb_id);
+    const uint32_t num_pages = get_local_cb_interface(cb_id).fifo_num_pages;
+    const uint32_t num_tiles = get_local_cb_interface(cb_id).fifo_page_size / tile_size;
+    const uint64_t clear_value_addr = get_noc_addr(get_read_ptr(clear_value_cb_id));
+    uint64_t write_addr = get_noc_addr(get_write_ptr(cb_id));
+
+    for (uint32_t i = 0; i < num_tiles * num_pages; ++i) {
+        noc_async_read(clear_value_addr, write_addr, tile_size);
+        write_addr += tile_size;
+    }
+    noc_async_read_barrier();
+}
+
+template <uint32_t clear_value_cb_id, uint32_t num_tiles>
+FORCE_INLINE void clear_out_tiles(uint64_t write_addr, uint64_t clear_value_addr) {
+    constexpr uint32_t tile_size = get_tile_size(clear_value_cb_id);
+
+    for (uint32_t i = 0; i < num_tiles; ++i) {
+        noc_async_read(clear_value_addr, write_addr, tile_size);
+        write_addr += tile_size;
+    }
+    noc_async_read_barrier();
+}
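These helpers replace per-element `fill_with_val` with repeated reads of a single pre-filled "clear value" tile. A rough host-side analogue of the pattern, using `std::memcpy` where the kernel issues `noc_async_read` (buffer sizes here are illustrative, not the real CB geometry):

```cpp
// Host-side analogue of clear_out_tiles: one prototype tile holds the clear value
// (-inf for max pool, 0 for avg pool, as bf16), and the destination buffer is
// cleared by copying that tile one tile_size at a time.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr std::size_t kTileElems = 32 * 32;         // TILE_HEIGHT * TILE_WIDTH
constexpr std::size_t kTileBytes = kTileElems * 2;  // bf16 is 2 bytes per element

void clear_out_tiles_sim(std::uint8_t* dst, std::size_t num_tiles, const std::uint8_t* clear_tile) {
    for (std::size_t i = 0; i < num_tiles; ++i) {
        std::memcpy(dst + i * kTileBytes, clear_tile, kTileBytes);  // stands in for noc_async_read
    }
    // on device a noc_async_read_barrier() follows the copy loop
}

int main() {
    std::array<std::uint8_t, kTileBytes> clear_tile{};  // bf16 0.0 is all-zero bytes (avg-pool case)
    std::array<std::uint8_t, 4 * kTileBytes> in_cb{};   // pretend CB holding 4 tiles
    clear_out_tiles_sim(in_cb.data(), 4, clear_tile.data());
    return 0;
}
```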

ttnn/cpp/ttnn/operations/pool/generic/device/kernels/dataflow/reader_pool_2d_multi_core_sharded.cpp

Lines changed: 0 additions & 26 deletions
@@ -13,32 +13,6 @@
 #include "debug/dprint_pages.h"
 #endif
 
-template <uint32_t cb_id, uint32_t clear_value_cb_id>
-FORCE_INLINE void clear_out_tiles() {
-    constexpr uint32_t tile_size = get_tile_size(cb_id);
-    const uint32_t num_pages = get_local_cb_interface(cb_id).fifo_num_pages;
-    const uint32_t num_tiles = get_local_cb_interface(cb_id).fifo_page_size / tile_size;
-    const uint64_t clear_value_addr = get_noc_addr(get_read_ptr(clear_value_cb_id));
-    uint64_t write_addr = get_noc_addr(get_write_ptr(cb_id));
-
-    for (uint32_t i = 0; i < num_tiles * num_pages; ++i) {
-        noc_async_read(clear_value_addr, write_addr, tile_size);
-        write_addr += tile_size;
-    }
-    noc_async_read_barrier();
-}
-
-template <uint32_t clear_value_cb_id, uint32_t num_tiles>
-FORCE_INLINE void clear_out_tiles(uint64_t write_addr, uint64_t clear_value_addr) {
-    constexpr uint32_t tile_size = get_tile_size(clear_value_cb_id);
-
-    for (uint32_t i = 0; i < num_tiles; ++i) {
-        noc_async_read(clear_value_addr, write_addr, tile_size);
-        write_addr += tile_size;
-    }
-    noc_async_write_barrier();
-}
-
 /**
  * Pool 2D (Max pool 2D and Avg pool 2D)
  */

ttnn/cpp/ttnn/operations/pool/generic/device/kernels/dataflow/reader_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp

Lines changed: 71 additions & 10 deletions
@@ -44,6 +44,7 @@ void kernel_main() {
     constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(14);
     constexpr uint32_t ceil_pad_w = get_compile_time_arg_val(15);
 
+    constexpr uint32_t TILE_HEIGHT = 32;
     constexpr uint32_t TILE_WIDTH = 32;
     constexpr uint32_t MAX_ELE_PER_REDUCTION = 512;  // TILE_WIDTH * 8 * numbytes
 
@@ -54,8 +55,14 @@
     constexpr uint32_t in_scalar_cb_id_1 = get_compile_time_arg_val(21);
     constexpr uint32_t interm_reduction_cb_id = get_compile_time_arg_val(22);
     constexpr uint32_t in_one_cb_id = get_compile_time_arg_val(23);
+    constexpr uint32_t clear_value_cb_id = get_compile_time_arg_val(24);
+    constexpr bool is_avg_pool = (bool)get_compile_time_arg_val(25);
     constexpr bool one_scalar_per_core = get_compile_time_arg_val(26);
     constexpr uint32_t config_cb_id = get_compile_time_arg_val(27);
+    constexpr uint32_t multi_buffering_factor = get_compile_time_arg_val(28);
+    constexpr uint32_t sync_cb_id1 = get_compile_time_arg_val(29);
+    constexpr uint32_t sync_cb_id2 = get_compile_time_arg_val(30);
+
     constexpr uint32_t in_scalar_cb_id =
         split_reader && reader_id == 1 && !one_scalar_per_core ? in_scalar_cb_id_1 : in_scalar_cb_id_0;
 
@@ -64,21 +71,68 @@
     uint32_t scalar_end = 1;
     uint32_t scalar_value = 0;
 
+    constexpr uint32_t window_size_hw = window_h * window_w;
+    constexpr uint32_t remaining_elems = window_size_hw % max_rows_for_reduction;
+    constexpr uint32_t interm_reduction_chunks =
+        remaining_elems ? window_size_hw / max_rows_for_reduction + 1 : window_size_hw / max_rows_for_reduction;
+    // we only need to initialize the in_cb if we will not fill each multibuffering chunk with max_rows worth of data
+    constexpr bool need_to_initialize_in_cb = remaining_elems && interm_reduction_chunks <= multi_buffering_factor;
+    constexpr uint32_t in_cb_ntiles = in_cb_sz / (TILE_WIDTH * TILE_HEIGHT);  // only use the non-multi buffering size
+
+    // fill the clear cb
+    if constexpr (split_reader) {
+        constexpr uint32_t half_tile = TILE_HEIGHT * TILE_WIDTH / 2;
+        if constexpr (reader_id == 0) {
+            fill_with_val(get_write_ptr(clear_value_cb_id), half_tile, bf16_init_value);
+        } else {
+            fill_with_val(get_write_ptr(clear_value_cb_id) + 2 * half_tile, half_tile, bf16_init_value);  // 2 for bf16
+        }
+    } else {
+        if constexpr (reader_id == 0) {
+            fill_with_val(get_write_ptr(clear_value_cb_id), TILE_HEIGHT * TILE_WIDTH, bf16_init_value);
+        }
+    }
+
+    // ensure the clear CB is full before proceeding
+    if constexpr (reader_id == 0) {
+        cb_push_back(sync_cb_id1, 1);
+        if constexpr (split_reader) {
+            cb_wait_front(sync_cb_id2, 1);
+        }
+    } else {
+        cb_push_back(sync_cb_id2, 1);
+        cb_wait_front(sync_cb_id1, 1);
+    }
+
+    if constexpr (need_to_initialize_in_cb && !is_avg_pool) {  // for avg pool fill_with_val runs in loop, no need to
+                                                               // initialize
+        clear_out_tiles<in_cb_id, clear_value_cb_id>();
+    }
+
     if constexpr (reader_id == 0) {
         constexpr uint32_t bf16_one_u16 = bf16_one_u32 >> 16;
-        // fill interm buffer with init_value
-        fill_with_val(get_write_ptr(interm_reduction_cb_id), in_cb_sz, bf16_init_value);
+        // initialize buffers
+        clear_out_tiles<interm_reduction_cb_id, clear_value_cb_id>();
         if constexpr (one_scalar_per_core) {
-            cb_reserve_back(in_scalar_cb_id_0, 1);
             fill_with_val(get_write_ptr(in_scalar_cb_id_0), TILE_WIDTH, bf16_scalar >> 16);
-            cb_push_back(in_scalar_cb_id_0, 1);
         }
-        if (bf16_scalar != bf16_one_u32 || !one_scalar_per_core) {
-            // Pool operation is not maxpool
+        if constexpr (is_avg_pool) {
+            // for avgpool, we use a one's CB to avoid double division by kernel size for large kernel case.
             fill_with_val(get_write_ptr(in_one_cb_id), TILE_WIDTH, bf16_one_u16);
         }
     }
 
+    // ensure initialization is done before proceeding
+    if constexpr (reader_id == 0) {
+        cb_push_back(sync_cb_id1, 1);
+        if constexpr (split_reader) {
+            cb_wait_front(sync_cb_id2, 2);
+        }
+    } else {
+        cb_push_back(sync_cb_id2, 1);
+        cb_wait_front(sync_cb_id1, 2);
+    }
+
     const uint32_t in_l1_read_base_addr = get_read_ptr(in_shard_cb_id);
     uint32_t reader_indices_l1_addr = get_read_ptr(in_reader_indices_cb_id);
     volatile tt_l1_ptr uint16_t* reader_indices_ptr =
@@ -90,7 +144,6 @@
 
     uint32_t counter = reader_id;
     constexpr uint32_t total_elems_to_reduce = window_h * window_w;
-    constexpr uint32_t remaining_elems = total_elems_to_reduce % max_rows_for_reduction;
     constexpr bool wide_reduction = in_nblocks_c > 1;
     constexpr uint32_t read_bytes =
         wide_reduction ? MAX_ELE_PER_REDUCTION : in_nbytes_c;  // in_cb is MAX_ELE_PER_REDUCTION for wide reductions
@@ -145,9 +198,17 @@
                 cb_push_back(in_cb_id, 1);
                 cb_reserve_back(in_cb_id, 1);
                 out_l1_write_addr = get_write_ptr(in_cb_id);
-                // If next is last chunk, fill whole buffer with the init_value.
-                if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction) {
-                    fill_with_val(out_l1_write_addr, in_cb_sz, bf16_init_value);
+                // If next is last chunk, fill whole buffer with the init_value. note for max pool we do
+                // not need to fill the CB for the partial chunk since as long as we have N>1 chunks we
+                // are guaranteed that the junk data remaining from chunk N-1 will fill the entire CB and
+                // cannot contain values greater than the max value, and if we have N=1 chunks we already
+                // initialized the entire CB with the init value, but for avg pool we need to fill the
+                // entire CB with the init value since the junk data will contribute to the average.
+                if constexpr (is_avg_pool) {
+                    if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction) {
+                        clear_out_tiles<clear_value_cb_id, in_cb_ntiles>(
+                            get_noc_addr(out_l1_write_addr), get_noc_addr(get_read_ptr(clear_value_cb_id)));
+                    }
                 }
             }
         }
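The sync CBs above implement a two-phase barrier on the split-reader path: each reader signals once after filling its half of the clear-value tile and once after buffer initialization, and the compute kernel waits for two pushes per CB before entering its reduction loop. A host-side simulation of that handshake with atomic counters standing in for CB page counts (illustration only, not TT-Metal API usage):

```cpp
// Two-phase reader/compute handshake modelled with atomic counters:
// cb_push_back(cb, 1) becomes an increment, cb_wait_front(cb, n) becomes
// "wait until the counter reaches n".
#include <atomic>
#include <cstdio>
#include <thread>

std::atomic<int> sync_cb1{0};  // pages pushed by reader 0
std::atomic<int> sync_cb2{0};  // pages pushed by reader 1

void wait_front(const std::atomic<int>& cb, int n) {
    while (cb.load(std::memory_order_acquire) < n) { std::this_thread::yield(); }
}

void reader(int reader_id) {
    auto& mine = (reader_id == 0) ? sync_cb1 : sync_cb2;
    auto& theirs = (reader_id == 0) ? sync_cb2 : sync_cb1;

    // phase 1: each reader fills its half of the clear-value tile, then signals
    mine.fetch_add(1, std::memory_order_release);
    wait_front(theirs, 1);  // both halves of the clear-value CB are now valid

    // phase 2: buffers (interm / in_cb / scalars) are initialized, then signal again
    mine.fetch_add(1, std::memory_order_release);
    wait_front(theirs, 2);  // both readers finished initialization
}

void compute() {
    // the compute kernel starts its reduction loop only after both init phases
    wait_front(sync_cb1, 2);
    wait_front(sync_cb2, 2);  // only checked when the split reader is enabled
    std::printf("compute: initialization complete, starting reduction loop\n");
}

int main() {
    std::thread r0(reader, 0), r1(reader, 1), c(compute);
    r0.join();
    r1.join();
    c.join();
    return 0;
}
```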

ttnn/cpp/ttnn/operations/pool/generic/device/pool_multi_core_program_factory.cpp

Lines changed: 25 additions & 7 deletions
@@ -296,9 +296,14 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
     const bool is_large_kernel =
         is_partial_tile ? kernel_size_hw > tt::constants::TILE_HEIGHT / 2 : kernel_size_hw > tt::constants::TILE_HEIGHT;
 
-    // ToDo: enable 32 sticks per tile for reduction for all cases.
+    // TODO: enable 32 sticks per tile for reduction for all cases, we can only support 16 row reductions for
+    // partial tiles, and there is currently a bug forcing us to use 16 row reductions for avg pool when there
+    // is 1 remainder C tile
     const uint32_t max_rows_for_reduction =
-        (!is_partial_tile && !is_large_kernel) ? tt::constants::TILE_HEIGHT : tt::constants::TILE_HEIGHT / 2;
+        !is_partial_tile && !(is_wide_reduction && pool_type == Pool2DType::AVG_POOL2D &&
+                              in_ntiles_c % MAX_TILES_PER_REDUCTION == 1)
+            ? tt::constants::TILE_HEIGHT
+            : tt::constants::TILE_HEIGHT / 2;
     TT_FATAL(nblocks == 1, "Multiple blocks not yet supported");
 
     if (input_shape[3] < tt::constants::TILE_WIDTH) {
@@ -360,14 +365,22 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
     }
 
     uint32_t clear_value_cb_id = 32;
-    if (max_rows_for_reduction == tt::constants::TILE_HEIGHT) {
+    if (max_rows_for_reduction == tt::constants::TILE_HEIGHT || is_large_kernel ||
+        (is_wide_reduction && in_ntiles_c % MAX_TILES_PER_REDUCTION != 0)) {
         // CB storing just "clear value" (-inf for maxpool, 0 for avgpool)
-        // is needed only if we use more then 16 sticks per tile for reduction.
+        // is needed only if we use more then 16 sticks per tile for reduction
+        // or if we use large kernel size.
        clear_value_cb_id = next_cb_index++;
         tt::tt_metal::create_cb(clear_value_cb_id, program, all_cores, tile_size(in_df), 1, in_df);
         log_debug(tt::LogOp, "CB {} :: PS = {}, NP = {}", clear_value_cb_id, tile_size(in_df), 1);
     }
 
+    // CBs for NC/BR synchornization
+    int32_t sync_cb_id1 = next_cb_index++;
+    auto sync_cb1 = tt::tt_metal::create_cb(sync_cb_id1, program, all_cores, 2, 2, tt::DataFormat::UInt16);
+    int32_t sync_cb_id2 = next_cb_index++;
+    auto sync_cb2 = tt::tt_metal::create_cb(sync_cb_id2, program, all_cores, 2, 2, tt::DataFormat::UInt16);
+
     // incoming data is the input cb instead of raw l1/dram addr
     // this input shard has halo and padding inserted.
     const uint32_t raw_in_cb_npages = input.shard_spec().value().shape[0];
@@ -441,7 +454,7 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
     uint32_t max_pool_partials_cb_id = 32;
     if (is_large_kernel) {
         max_pool_partials_cb_id = next_cb_index++;  // max_pool partials
-        const uint32_t max_pool_partials_cb_pagesize = out_cb_pagesize;
+        const uint32_t max_pool_partials_cb_pagesize = in_cb_pagesize;
         const uint32_t max_pool_partials_cb_npages = nblocks;
 
         tt::tt_metal::create_cb(
@@ -540,7 +553,10 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
         clear_value_cb_id,
         (uint32_t)pool_type,
         one_scalar_per_core,
-        config_cb_id};
+        config_cb_id,
+        multi_buffering_factor,
+        sync_cb_id1,
+        sync_cb_id2};
     std::vector<uint32_t> reader1_ct_args = reader0_ct_args;
     reader1_ct_args[8] = 1;  // split reader id for reader1
 
@@ -589,7 +605,9 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
         out_cb_id,
         max_pool_partials_cb_id,
         in_one_cb_id,
-        one_scalar_per_core};
+        one_scalar_per_core,
+        sync_cb_id1,
+        sync_cb_id2};
 
     auto compute_config = tt::tt_metal::ComputeConfig{
         .math_fidelity = MathFidelity::HiFi4,
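The new `max_rows_for_reduction` selection in the first hunk can be restated as a small predicate. A simplified sketch under assumed constants (`kMaxTilesPerReduction` is a stand-in for `MAX_TILES_PER_REDUCTION`):

```cpp
// Simplified restatement of the max_rows_for_reduction choice above.
#include <cstdint>

enum class Pool2DType { MAX_POOL2D, AVG_POOL2D };

constexpr uint32_t kTileHeight = 32;           // tt::constants::TILE_HEIGHT
constexpr uint32_t kMaxTilesPerReduction = 8;  // assumed value of MAX_TILES_PER_REDUCTION

constexpr uint32_t max_rows_for_reduction(
    bool is_partial_tile, bool is_wide_reduction, Pool2DType pool_type, uint32_t in_ntiles_c) {
    // Partial tiles (c < 32) still need 16-row reductions, and so does avg pool when a
    // wide reduction leaves exactly one remainder C tile (the known bug noted in the TODO).
    const bool avg_pool_one_remainder_c_tile = is_wide_reduction && pool_type == Pool2DType::AVG_POOL2D &&
                                               in_ntiles_c % kMaxTilesPerReduction == 1;
    return (!is_partial_tile && !avg_pool_one_remainder_c_tile) ? kTileHeight : kTileHeight / 2;
}

static_assert(max_rows_for_reduction(false, false, Pool2DType::MAX_POOL2D, 16) == 32, "full tile: 32 rows");
static_assert(max_rows_for_reduction(true, false, Pool2DType::MAX_POOL2D, 1) == 16, "partial tile: 16 rows");
static_assert(max_rows_for_reduction(false, true, Pool2DType::AVG_POOL2D, 9) == 16, "avg pool, 1 remainder C tile");

int main() { return 0; }
```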
