Commit 3ddcbb3

32 Bit Indexing Support for MPWI (#35491)
### Ticket
#27845

### Problem description
MPWI (max_pool2d_with_indices) currently only supports HW < 2^16.

### What's changed
- MPWI has been updated to support sizes up to HW < 2^32 (a minimal sketch of the size thresholds follows the checklist).
- FP32 DST is used for these cases.
- A check has been added to sliding window to ensure the per-core HW does not exceed the uint16 limit, since sliding window still uses uint16 config tensors.

### Checklist
- [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884522066
- [x] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884517960
- [x] [Model regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884549209 (same failure as main)
- [x] [Device performance regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884545637 (same failure as main)
- [x] [Nightly L2](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884525760
- [x] [Nightly Blackhole](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-nightly-tests.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884532202 (same failures as main)
- [x] [Frequent model](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/21884537542 (same failures as main)
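The size thresholds above can be summarized with a small sketch (illustrative only; `pick_index_dtype` is a hypothetical helper, not part of the ttnn API — the real selection happens inside the MPWI op):

```python
# Illustrative sketch of the index-width choice described in this PR.
def pick_index_dtype(input_h: int, input_w: int) -> str:
    hw = input_h * input_w
    if hw < 2**16:
        return "uint16"  # original path: 16-bit indices
    if hw < 2**32:
        return "uint32"  # new path added by this PR: 32-bit indices, FP32 DST
    raise ValueError(f"HW={hw} is not supported (must be < 2^32)")
```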
1 parent bc6d286 commit 3ddcbb3

File tree

11 files changed: +299, -174 lines

tests/sweep_framework/sweep_utils/max_pool2d_with_indices_common.py

Lines changed: 12 additions & 3 deletions
@@ -15,6 +15,7 @@ def validate_indices(input_tensor, torch_indices, ttnn_indices, kernel_size, str
     """
     Validate indices using logic from test_mpwi.py
     Note input tensors should be in [N, H, W, C] format
+    Supports both uint16 and uint32 index tensors (indices should be converted to int64 before calling)
     Returns (indices_valid, tie_breaking_differences, actual_errors, value_differences, window_violations)
     """
     batch_size, input_h, input_w, channels = input_tensor.shape
@@ -215,10 +216,18 @@ def run_max_pool2d_with_indices(
     )
 
     ttnn_output_torch = ttnn.to_torch(ttnn_output)
-    # convert indexes to int64 for compatability with torch
+
+    # convert indexes to int64 for compatibility with torch
     ttnn_indices_torch = ttnn.to_torch(ttnn_indices, dtype=torch.int64)
-    # manually fix the wrapping since TTNN uint16 tensors get converted to int16 torch tensors, even when data type is specified as int64
-    ttnn_indices_torch = torch.where(ttnn_indices_torch < 0, ttnn_indices_torch + 65536, ttnn_indices_torch)
+
+    # manually fix the wrapping since TTNN uint16/uint32 tensors get converted to int16/int32 torch tensors
+    # even when data type is specified as int64
+    if ttnn_indices.dtype == ttnn.uint16:
+        # uint16: wraps at 65536 (2^16)
+        ttnn_indices_torch = torch.where(ttnn_indices_torch < 0, ttnn_indices_torch + 65536, ttnn_indices_torch)
+    elif ttnn_indices.dtype == ttnn.uint32:
+        # uint32: wraps at 4294967296 (2^32)
+        ttnn_indices_torch = torch.where(ttnn_indices_torch < 0, ttnn_indices_torch + 4294967296, ttnn_indices_torch)
 
     torch_output, torch_indices = torch.nn.functional.max_pool2d(
         torch_input,
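The wrap-around correction in the hunk above exists because the unsigned index tensors come back from `ttnn.to_torch` reinterpreted as signed tensors of the same width, even when int64 is requested. A standalone sketch of the arithmetic (plain torch, not the ttnn conversion itself):

```python
import torch

# Simulate a uint16 index (e.g. 40000) being reinterpreted as int16: it goes negative,
# and adding 2^16 restores the original value. The uint32/int32 case is analogous with 2^32.
raw = torch.tensor([40000, 7000], dtype=torch.int64)         # true flat indices
as_int16 = ((raw + 2**15) % 2**16) - 2**15                   # reinterpretation: 40000 -> -25536
fixed = torch.where(as_int16 < 0, as_int16 + 2**16, as_int16)
assert torch.equal(fixed, raw)
```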

tests/ttnn/unit_tests/operations/pool/test_mpwi.py

Lines changed: 76 additions & 21 deletions
@@ -82,9 +82,9 @@ def test_mpwi_20_core_C_dims(device, in_c):
         [4, 64, 30, 40, 4, 8, 1, 1, 2, 4, 1, 1, False],
     ],
 )
-@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat16])
+@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
 @skip_with_watcher("Test is not passing with watcher enabled github issue #37195")
-def test_mpwi_kernel_sizes(device, ttnn_dtype, input_spec):
+def test_mpwi_small_kernel_sizes(device, ttnn_dtype, input_spec):
     (
         in_n,
         in_c,
@@ -129,11 +129,6 @@ def test_mpwi_kernel_sizes(device, ttnn_dtype, input_spec):
     [
         # Contains following parameters
         # [batch_size, input_channels, input_height, input_width, kernel_height, kernel_width, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, ceil_mode]
-        # DILATION / MULTI-BATCH CASES
-        [2, 40, 100, 100, 3, 3, 2, 2, 0, 1, 2, 2, True],
-        [3, 56, 85, 85, 3, 3, 3, 3, 1, 0, 2, 2, False],
-        [4, 24, 56, 64, 3, 3, 2, 1, 1, 1, 3, 2, True],
-        # LARGE KERNEL CASES
         [2, 64, 159, 159, 13, 13, 2, 2, 6, 6, 2, 2, True],
         [2, 40, 100, 100, 9, 9, 2, 2, 0, 1, 2, 2, True],
         [3, 56, 85, 85, 8, 8, 3, 3, 1, 0, 2, 2, False],
@@ -146,7 +141,7 @@ def test_mpwi_kernel_sizes(device, ttnn_dtype, input_spec):
 )
 @pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
 @skip_with_watcher("Test is not passing with watcher enabled github issue #37195")
-def test_mpwi_general(device, ttnn_dtype, input_spec):
+def test_mpwi_large_kernel_sizes(device, ttnn_dtype, input_spec):
     (
         in_n,
         in_c,
@@ -198,21 +193,28 @@ def test_mpwi_general(device, ttnn_dtype, input_spec):
     )
 
 
-@pytest.mark.skip(reason="DRAM slicing with return_indices is not yet supported")
 @pytest.mark.parametrize(
     "input_spec",
     [
         # Contains following parameters
-        # [batch_size, input_channels, input_height, input_width, kernel_height, kernel_width, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, ceil_mode, num_slices]
-        # DILATION / MULTI-BATCH CASES
-        [2, 40, 1024, 1024, 3, 3, 2, 2, 0, 1, 2, 2, True, 8],
-        [3, 56, 512, 512, 3, 3, 3, 3, 1, 0, 2, 2, False, 8],
-        [4, 24, 768, 768, 3, 3, 2, 1, 1, 1, 3, 2, True, 8],
+        # [batch_size, input_channels, input_height, input_width, kernel_height, kernel_width, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, ceil_mode]
+        [3, 16, 80, 80, 9, 9, 3, 3, 3, 1, 2, 2, True],
+        [2, 48, 60, 60, 6, 6, 2, 2, 2, 0, 2, 2, False],
+        [4, 56, 65, 55, 5, 5, 1, 2, 1, 1, 1, 2, False],
+        [4, 24, 56, 64, 3, 3, 2, 1, 0, 1, 3, 2, True],
     ],
 )
 @pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
+@pytest.mark.parametrize(
+    "sharding_scheme",
+    [
+        ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+        ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+    ],
+)
 @skip_with_watcher("Test is not passing with watcher enabled github issue #37195")
-def test_mpwi_dram_slice(device, ttnn_dtype, input_spec):
+def test_mpwi_general(device, ttnn_dtype, sharding_scheme, input_spec):
     (
         in_n,
         in_c,
@@ -227,9 +229,11 @@ def test_mpwi_dram_slice(device, ttnn_dtype, input_spec):
         dilation_h,
         dilation_w,
         ceil_mode,
-        num_slices,
     ) = input_spec
-    dram_slice_config = ttnn.Conv2dSliceConfig(num_slices=num_slices, slice_type=ttnn.Conv2dDRAMSliceWidth)
+
+    if sharding_scheme == ttnn.TensorMemoryLayout.WIDTH_SHARDED and ttnn_dtype == ttnn.bfloat8_b:
+        pytest.skip("this case runs OOM")
+
     run_max_pool2d_with_indices(
         in_n,
         in_c,
@@ -245,10 +249,61 @@ def test_mpwi_dram_slice(device, ttnn_dtype, input_spec):
         dilation_w,
         ttnn_dtype,
         device,
-        ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+        sharding=sharding_scheme,
+        ceil_mode=ceil_mode,
+        memory_config=None,
+        run_twice=True,
+        config_tensor_in_dram=True,
+    )
+
+
+@pytest.mark.parametrize(
+    "input_spec",
+    [
+        [1, 32, 384, 384, 3, 3, 1, 1, 1, 1, 1, 1, False],
+        [1, 48, 350, 350, 5, 5, 1, 1, 2, 2, 1, 1, False],
+        [1, 64, 350, 350, 6, 6, 1, 1, 3, 3, 1, 1, False],
+        [3, 32, 300, 300, 7, 7, 1, 1, 3, 3, 1, 1, False],
+        [2, 48, 300, 300, 9, 9, 2, 2, 4, 4, 1, 1, False],
+    ],
+)
+@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat16])
+@skip_with_watcher("Test is not passing with watcher enabled github issue #37195")
+def test_mpwi_32_bit_index(device, ttnn_dtype, input_spec):
+    (
+        in_n,
+        in_c,
+        in_h,
+        in_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
         ceil_mode,
-        None, # no memory_config
-        False, # not in place
-        dram_slice_config=dram_slice_config,
+    ) = input_spec
+
+    run_max_pool2d_with_indices(
+        in_n,
+        in_c,
+        in_h,
+        in_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        ttnn_dtype,
+        device,
+        sharding=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+        ceil_mode=ceil_mode,
+        memory_config=None,
+        run_twice=True,
         config_tensor_in_dram=True,
     )
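The `test_mpwi_32_bit_index` shapes above are chosen so that the flattened input H*W no longer fits in uint16, which is what exercises the new 32-bit index path (assuming, per the PR description, that the index range is governed by HW):

```python
# All test_mpwi_32_bit_index input shapes exceed the uint16 index range (2^16 - 1 = 65535).
for h, w in [(384, 384), (350, 350), (300, 300)]:
    print(h, w, h * w)   # 147456, 122500, 90000
    assert h * w >= 2**16
```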

tt_metal/hw/inc/api/compute/compute_kernel_api.h

Lines changed: 1 addition & 3 deletions
@@ -595,14 +595,12 @@ ALWI void topk_tile_init() { MATH((llk_math_eltwise_unary_sfpu_topk_init<true>()
  * acquired state via *acquire_dst* call. This call is blocking and is only
  * available on the compute engine.
  *
- * Only a reduction of 9 rows is supported at this time.
- *
  * | Argument  | Description                                                                  | Type       | Valid Range                                            | Required |
  * |-----------|------------------------------------------------------------------------------|------------|--------------------------------------------------------|----------|
  * | idst      | The index of the tile in DST register containing the data to be reduced     | uint32_t   | Must be less than the size of the DST register buffer | True     |
  * | idst_idx  | The index of the tile in DST register containing the indices of the data    | uint32_t   | Must be less than the size of the DST register buffer | True     |
  * | chunk     | The index of the intra-kernel "chunk" of data for large kernel accumulation | uint32_t   | 0 to UINT_MAX                                          | False    |
- * | num_rows  | The number of rows to use for the MaxPool operation                         | uint32_t   | {9}                                                    | False    |
+ * | num_rows  | The number of rows to use for the MaxPool operation                         | uint32_t   | <= 32, but note either 9 or 32 rows will be reduced    | False    |
  * | layout    | The data layout of the data in DST                                          | DataLayout | TILE or ROW_MAJOR                                      | False    |
  * | accumulate| Whether to accumulate results for large kernels                              | bool       | true, false                                            | False    |
  * | ITERATIONS| The number of iterations to perform (unused)                                | int        | 1 to 8                                                 | False    |

ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/compute_mpwi.cpp

Lines changed: 16 additions & 10 deletions
@@ -73,6 +73,9 @@ void kernel_main() {
     constexpr uint32_t kernel_h = get_compile_time_arg_val(34);
     constexpr uint32_t kernel_w = get_compile_time_arg_val(35);
     constexpr uint32_t clear_value_cb_id = get_compile_time_arg_val(36);
+    constexpr uint32_t indexes_32_bit = get_compile_time_arg_val(37);
+
+    constexpr DataFormat copy_format = indexes_32_bit ? DataFormat::UInt32 : DataFormat::UInt16;
 
     constexpr uint32_t mpwi_cb_tile_idx = 0;
     constexpr uint32_t data_dst_idx = 0;
@@ -111,8 +114,8 @@ void kernel_main() {
 
     uint32_t current_idx_col;
     uint32_t current_idx_row;
-    const uint16_t start_row = (uint16_t)get_arg_val<uint32_t>(2);
-    const uint16_t start_col = (uint16_t)get_arg_val<uint32_t>(3);
+    const uint32_t start_row = get_arg_val<uint32_t>(2);
+    const uint32_t start_col = get_arg_val<uint32_t>(3);
     current_idx_col = start_col;
     current_idx_row = start_row;
 
@@ -127,7 +130,6 @@ void kernel_main() {
     }
 
     unary_op_init_common(in_cb_id_0, in_cb_id_0);
-    copy_tile_to_dst_init_short(in_cb_id_0);
     max_reduce_with_indices_init<ckernel::DataLayout::ROW_MAJOR>();
 
     // if max out sticks is non-zero then this will be used as the number of out sticks for every core
@@ -144,6 +146,7 @@ void kernel_main() {
         tile_regs_acquire();
         uint32_t intra_kernel_h = 0;
         uint32_t intra_kernel_w = 0;
+        copy_tile_to_dst_init_short(compute_tmp_idx_cb_id);
         reconfig_data_format_srca(compute_tmp_idx_cb_id);
         if (first_iteration) { // move the initial indexes from the reader to DST
             cb_wait_front(in_idx_cb_id, 1);
@@ -159,24 +162,27 @@ void kernel_main() {
             // clear the accumulation tiles since they will contain garbage data which is partially loaded
             // since max SFPU offset if 62 DST rows, but 4 rows are loaded each time so we load 2 rows of
            // DST tiles 1 and 3 during the reduction of tiles 0 and 2
+            copy_tile_to_dst_init_short(clear_value_cb_id);
             reconfig_data_format_srca(clear_value_cb_id);
             copy_tile(clear_value_cb_id, mpwi_cb_tile_idx, data_accum_dst_idx);
 
             // make a copy of the initial indexes to be used for restoring between C blocks
-            copy_dest_values<DataFormat::UInt16>(index_dst_idx, index_temp_dst_idx);
+            copy_dest_values<copy_format>(index_dst_idx, index_temp_dst_idx);
         }
 
         for (uint32_t chunk = 0; chunk < interm_reduction_chunks; chunk++) {
            bool last_chunk = chunk == interm_reduction_chunks - 1;
 
            cb_wait_front(curr_in_cb_id, 1);
+            copy_tile_to_dst_init_short(curr_in_cb_id);
            reconfig_data_format_srca(curr_in_cb_id);
            copy_tile(curr_in_cb_id, mpwi_cb_tile_idx, data_dst_idx);
 
            // increments happen between every chunk within a C block, and between C blocks
            bool increment_needed = false;
            if (last_c_block && last_chunk) { // increment for the next kernel position
                increment_needed = true;
+                copy_tile_to_dst_init_short(compute_tmp_idx_cb_id);
                reconfig_data_format_srca(compute_tmp_idx_cb_id);
                // update the current index column
                if (current_idx_col + stride_w + eff_kernel_w > in_w_padded) {
@@ -198,6 +204,7 @@ void kernel_main() {
            } else if (is_large_kernel) { // only need to increment within C block if multiple chunks
                if (!last_chunk) { // increment for the next chunk within the same C block
                    increment_needed = true;
+                    copy_tile_to_dst_init_short(compute_tmp_idx_cb_id);
                    reconfig_data_format_srca(compute_tmp_idx_cb_id);
                    if (intra_kernel_w + sticks_per_chunk < kernel_w) { // move right in this row
                        intra_kernel_w += sticks_per_chunk;
@@ -210,23 +217,22 @@ void kernel_main() {
                }
            }
            if (!increment_needed) {
-                copy_dest_values<DataFormat::UInt16>(index_dst_idx, index_scratch_out_dst_idx);
+                copy_dest_values<copy_format>(index_dst_idx, index_scratch_out_dst_idx);
            } else {
                // we allow overflow here for negative values as this only occurs in padding regions
                add_int_tile_init();
-                add_int_tile<DataFormat::UInt16>(index_dst_idx, inc_dst_idx, index_scratch_out_dst_idx);
+                add_int_tile<copy_format>(index_dst_idx, inc_dst_idx, index_scratch_out_dst_idx);
                max_reduce_with_indices_init<ckernel::DataLayout::ROW_MAJOR>();
            }
 
-            // TODO # 27845: implement accumulation for <=9 MPWI SFPU so we can use this version for large kernels
-            // as well
+            // TODO implement accumulation for <=9 MPWI SFPU so we can use this version for large kernels as well
            constexpr uint32_t max_mpwi_kernel_size = window_size_hw <= 9 ? 9 : 32;
            max_reduce_with_indices<max_mpwi_kernel_size, ckernel::DataLayout::ROW_MAJOR, is_large_kernel>(
                data_dst_idx, index_dst_idx, chunk);
 
            if constexpr (is_large_kernel) {
                if (!last_chunk) {
-                    copy_dest_values<DataFormat::UInt16>(index_scratch_out_dst_idx, index_dst_idx);
+                    copy_dest_values<copy_format>(index_scratch_out_dst_idx, index_dst_idx);
                }
            }
 
@@ -236,7 +242,7 @@ void kernel_main() {
        // After all chunks: if not last C block, restore base indices for next C block
        if constexpr (is_large_kernel) {
            if (!last_c_block) {
-                copy_dest_values<DataFormat::UInt16>(index_temp_dst_idx, index_scratch_out_dst_idx);
+                copy_dest_values<copy_format>(index_temp_dst_idx, index_scratch_out_dst_idx);
            }
        }
 
