diff --git a/tests/nightly/tg/ccl/test_minimal_all_gather_async.py b/tests/nightly/tg/ccl/test_minimal_all_gather_async.py
index 185d6643b43b..1b65615bf0ae 100644
--- a/tests/nightly/tg/ccl/test_minimal_all_gather_async.py
+++ b/tests/nightly/tg/ccl/test_minimal_all_gather_async.py
@@ -7,18 +7,8 @@

 import ttnn
 from tests.nightly.t3000.ccl.test_minimal_all_gather_async import run_all_gather_impl
-from tests.ttnn.multidevice_perf_tests.sweep_all_gather_hyperparameters_t3000 import get_max_chunks_per_sync
 from models.common.utility_functions import skip_for_blackhole, skip_for_wormhole_b0
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
-from models.perf.benchmarking_utils import BenchmarkProfiler
-from tracy import signpost
-
-
-def create_fabric_router_config(max_payload_size):
-    """Helper to create FabricRouterConfig with custom max payload size."""
-    config = ttnn._ttnn.fabric.FabricRouterConfig()
-    config.max_packet_payload_size_bytes = max_payload_size
-    return config


 @skip_for_blackhole("This test is for wormhole")
@@ -564,155 +554,3 @@ def test_all_gather_async_wan_galaxy_4x32(
     # torch_reference = torch_input.repeat([devices, 1, 1, 1])
     # eq, output = comp_equal(torch_output, torch_reference)
     # assert eq, f"Output mismatch between torch and ttnn all-gather: {output}"
-
-
-@pytest.mark.parametrize(
-    "ag_output_shape, dim, cluster_axis, ag_input_dtype, layout, mem_config_input, mem_config_ag",
-    [
-        ([1, 1, 9472, 5120], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
-        ([1, 1, 9472, 256], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
-        ([1, 1, 9472, 128], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
-        ([1, 1, 118, 128], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
-    ],
-    ids=[
-        "spatial_activation",
-        "layernorm_stats",
-        "rmsnorm_stats_spatial",
-        "rmsnorm_stats_prompt",
-    ],
-)
-@pytest.mark.parametrize(
-    "device_params, all_gather_topology",
-    [
-        (
-            {
-                "fabric_config": ttnn.FabricConfig.FABRIC_1D_RING,
-                "fabric_router_config": create_fabric_router_config(8192),
-                "trace_region_size": 1000000,
-            },
-            ttnn.Topology.Ring,
-        ),
-    ],
-    indirect=["device_params"],
-    ids=["fabric_ring"],
-)
-@pytest.mark.parametrize("num_links", [2], ids=lambda v: f"{v}links")
-@pytest.mark.parametrize("chunks_per_sync", [320], ids=lambda v: f"{v}chunks")
-@pytest.mark.parametrize("num_workers_per_link", [3], ids=lambda v: f"{v}workers")
-@pytest.mark.parametrize("num_buffers_per_channel", [4], ids=lambda v: f"{v}buffers")
-@pytest.mark.parametrize("num_iters, warmup_iters", [(75, 10)])
-@pytest.mark.parametrize("mesh_device", [(4, 8)], indirect=True)
-def test_all_gather_wan(
-    mesh_device,
-    ag_output_shape,
-    dim,
-    cluster_axis,
-    ag_input_dtype,
-    layout,
-    mem_config_input,
-    mem_config_ag,
-    num_links,
-    chunks_per_sync,
-    num_workers_per_link,
-    num_buffers_per_channel,
-    all_gather_topology,
-    num_iters,
-    warmup_iters,
-):
-    from loguru import logger
-
-    # Create input tensor
-    mesh_shape = tuple(mesh_device.shape)
-    input_shape = ag_output_shape
-    num_devices = mesh_shape[cluster_axis]
-
-    torch.manual_seed(2005)
-    torch_input = torch.rand(input_shape, dtype=torch.bfloat16)
-
-    shard_dims = (None, dim) if cluster_axis == 1 else (dim, None)
-    tt_input = ttnn.from_torch(
-        torch_input,
-        layout=layout,
-        dtype=ag_input_dtype,
-        memory_config=mem_config_input,
-        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=shard_dims, mesh_shape=mesh_shape),
-        device=mesh_device,
-    )
-
-    # AllGather config
-    if chunks_per_sync == "MAX":
-        chunks_per_sync_val = get_max_chunks_per_sync(num_devices, ag_output_shape, num_links)
-    else:
-        chunks_per_sync_val = chunks_per_sync
-
-    # Compile Run
-    logger.info("Compiling op")
-    tt_output = ttnn.all_gather(
-        tt_input,
-        dim=dim,
-        cluster_axis=cluster_axis,
-        topology=all_gather_topology,
-        num_links=num_links,
-        memory_config=mem_config_ag,
-        chunks_per_sync=chunks_per_sync_val,
-        num_workers_per_link=num_workers_per_link,
-        num_buffers_per_channel=num_buffers_per_channel,
-    )
-    ttnn.synchronize_device(mesh_device)
-
-    # Check output
-    errors = []
-    for dev, tt_out in enumerate(ttnn.get_device_tensors(tt_output)):
-        eq, mess = comp_pcc(torch_input, ttnn.to_torch(tt_out))
-        if not eq:
-            errors.append(f"Device {dev}: {mess}")
-    assert not errors, f"PCC check failed on {len(errors)} device(s):\n" + "\n".join(errors)
-
-    ################## TRACE RUN #######################
-
-    # Capture trace
-    logger.info("Capturing trace")
-
-    def capture_trace(n_iters):
-        trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
-        for i in range(n_iters):
-            _ = ttnn.all_gather(
-                tt_input,
-                dim=dim,
-                cluster_axis=cluster_axis,
-                topology=all_gather_topology,
-                num_links=num_links,
-                memory_config=mem_config_ag,
-                chunks_per_sync=chunks_per_sync_val,
-                num_workers_per_link=num_workers_per_link,
-                num_buffers_per_channel=num_buffers_per_channel,
-            )
-        ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
-        ttnn.synchronize_device(mesh_device)
-        return trace_id
-
-    if warmup_iters > 0:
-        trace_id_warmup = capture_trace(warmup_iters)
-    trace_id = capture_trace(num_iters)
-
-    # Run the op
-    logger.info("Starting Trace perf test...")
-    profiler = BenchmarkProfiler()
-    profiler.start("all-gather-async-trace-warmup")
-    if warmup_iters > 0:
-        ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
-        ttnn.release_trace(mesh_device, trace_id_warmup)
-        ttnn.synchronize_device(mesh_device)
-    profiler.end("all-gather-async-trace-warmup")
-
-    profiler.start("all-gather-async-trace")
-    signpost("start")
-    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
-    ttnn.release_trace(mesh_device, trace_id)
-    ttnn.synchronize_device(mesh_device)
-    signpost("stop")
-    profiler.end("all-gather-async-trace")
-    time_taken = profiler.get_duration("all-gather-async-trace")
-    logger.info(f"Time taken e2e: {time_taken} s")
-    logger.info(f"Time per iter e2e: {time_taken / num_iters} s")
-    logger.info(f"Time per iter e2e: {time_taken / num_iters * 1e6} us")
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_default_program_factory.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_default_program_factory.cpp
index 909a4865d94d..266c21b6e334 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_default_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_default_program_factory.cpp
@@ -383,12 +383,6 @@ AllGatherProgramArtifacts build_all_gather_async_minimal_default_program_artifac
     // L1 Scratch CB Creation
     const size_t packet_size_bytes = tt::tt_fabric::get_tt_fabric_channel_buffer_size_bytes();
     uint32_t l1_scratch_cb_page_size_bytes = page_size;
-    TT_FATAL(
-        packet_size_bytes >= l1_scratch_cb_page_size_bytes,
-        "Fabric packet size ({} bytes) must be >= tensor page size ({} bytes). "
-        "Increase max_packet_payload_size_bytes in FabricRouterConfig.",
-        packet_size_bytes,
-        l1_scratch_cb_page_size_bytes);

     // scatter-write currently supports 4 distinct noc addresses
     uint32_t max_target_noc_addresses_per_packet = 4;
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_reader.cpp
index 32809e0a358a..7cc45203ae9a 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_reader.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_reader.cpp
@@ -111,30 +111,28 @@ void kernel_main() {
     uint32_t tiles_read = input_tile_id_start;
     uint32_t tiles_to_read = input_tile_id_end;
     uint32_t output_tile_id_start = 0;
-    {
-        for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
-            while (tiles_read < tiles_to_read) {
-                uint32_t tiles_remaining_to_read = tiles_to_read - tiles_read;
-                uint32_t num_tiles_to_read = std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);
-
-                cb_reserve_back(cb_output_id, num_tiles_to_write_per_packet);
-                size_t l1_write_addr = get_write_ptr(cb_output_id);
-                for (uint32_t j = 0; j < num_tiles_to_read; ++j) {
-                    uint32_t tile_id = output_tile_id_start + tiles_read;
-                    uint64_t noc_read_addr = get_noc_addr(tile_id, input_tensor_addrgen);
-                    noc_async_read(noc_read_addr, l1_write_addr, page_size);
-
-                    l1_write_addr += page_size;
-                    tiles_read++;
-                }
-
-                noc_async_read_barrier();
-                cb_push_back(cb_output_id, num_tiles_to_write_per_packet);
+    for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
+        while (tiles_read < tiles_to_read) {
+            uint32_t tiles_remaining_to_read = tiles_to_read - tiles_read;
+            uint32_t num_tiles_to_read = std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);
+
+            cb_reserve_back(cb_output_id, num_tiles_to_write_per_packet);
+            size_t l1_write_addr = get_write_ptr(cb_output_id);
+            for (uint32_t j = 0; j < num_tiles_to_read; ++j) {
+                uint32_t tile_id = output_tile_id_start + tiles_read;
+                uint64_t noc_read_addr = get_noc_addr(tile_id, input_tensor_addrgen);
+                noc_async_read(noc_read_addr, l1_write_addr, page_size);
+
+                l1_write_addr += page_size;
+                tiles_read++;
             }
-            tiles_read = input_tile_id_start;
-            tiles_to_read = input_tile_id_end;
-            output_tile_id_start += input_tensor_Wt * input_tensor_Ht;
+
+            noc_async_read_barrier();
+            cb_push_back(cb_output_id, num_tiles_to_write_per_packet);
         }
+        tiles_read = input_tile_id_start;
+        tiles_to_read = input_tile_id_end;
+        output_tile_id_start += input_tensor_Wt * input_tensor_Ht;
     }

     uint32_t slices_received = 0;
@@ -158,37 +156,9 @@ void kernel_main() {
         }
     }

-    // Split forwarding for even-sized ring devices: each direction handles only its half of the split slice
-    // This must match the writer's split-forwarding logic
-    //
-    // Take a 4-device ring as an example: num_targets_forward=2, num_targets_backward=1
-    // Without split-forwarding:
-    //   - Forward receives 2 slices: device at -1 hop, device at -2 hops (opposite)
-    //   - Backward receives 1 slice: device at +1 hop
-    // With split-forwarding:
-    //   - Forward receives 2 slices: device at -1 hop (full), device at -2 hops (FIRST HALF)
-    //   - Backward receives 2 slices: device at +1 hop (full), device at +2 hops (SECOND HALF)
-    // Note: backward needs to receive an ADDITIONAL slice for the second half!
-    bool split_forwarding_enabled = false;
-    if constexpr (topology == Topology::Ring) {
-        if (ring_size % 2 == 0 && ring_size > 2) {  // if ring size is even, we need to write the first half of the
-                                                    // tiles, otherwise we write the entire packet
-            split_forwarding_enabled = true;
-            // Match writer's special case: backward worker forwards half slice when num_targets_backward_direction == 1
-            if (direction == 1) {
-                // slices_expected = 2;  // Receive additional slice for second half of split data
-                writes_expected++;  // Forward the first slice (device at +1 hop)
-            }
-        }
-    }
-
     uint32_t chunk_count = 0;
     uint32_t sem_target = 0;
-    uint32_t slices_forwarded = 0;  // Track forwarded slices separately for split-forwarding
     while (slices_received < slices_expected) {
-        // Check if this is the last slice and split forwarding applies
-        bool is_last_slice = (slices_received == slices_expected - 1);
-        bool is_split_slice = split_forwarding_enabled && is_last_slice;
         // Do i expect more from the backward direction?
         // In the linear case, I expect num_targets_backward_direction slices from the left
         // In the ring case, I expect num_targets_backward_direction slices from the right, (keep in mind this differs
@@ -216,21 +186,9 @@
         // Direction == forward: Should I forward what I got from the right to my left?
         // In the linear case, if I have any targets to my left, always forward
         // In the ring case, if I have received on the right less than my targets on the left, forward
-        bool should_forward = false;
-        if constexpr (topology == Topology::Linear) {
-            should_forward = (writes_expected > 0);
-        } else if constexpr (topology == Topology::Ring) {
-            should_forward = ((slices_received + 1) < (writes_expected + 1));
-        }
-
-        // CRITICAL: For split-forwarding, the writer's split condition is based on FORWARDED slice count
-        // (slice_writes == writes_expected - 1), not RECEIVED slice count.
-        // We must match the writer's logic to avoid CB deadlock.
-        bool is_last_forwarded_slice = should_forward && (slices_forwarded == writes_expected - 1);
-        bool is_split_forwarded_slice = split_forwarding_enabled && is_last_forwarded_slice;
-
-        if (should_forward) {
-            // read the next slice out of memory, and put it in CB for writer to forward
+        if ((topology == Topology::Linear && writes_expected > 0) ||
+            (topology == Topology::Ring && ((slices_received + 1) < (writes_expected + 1)))) {
+            // read the next backward slice out of memory, and put it in CB
             tiles_read = input_tile_id_start;
             tiles_to_read = input_tile_id_end;

@@ -239,36 +197,6 @@
             uint32_t row_offset = start_row_offset;
             uint32_t slice_Wt = input_tensor_Wt;
             uint32_t stride_Wt = output_tensor_Wt;
-
-            // For split-forwarding: each direction only handles its half
-            // Forward (direction==0): first half, Backward (direction==1): second half
-            // Use is_split_forwarded_slice (based on forwarded count) to match writer's logic
-            if (is_split_forwarded_slice) {
-                uint32_t total_tiles = input_tile_id_end - input_tile_id_start;
-                uint32_t first_half_tiles = total_tiles / 2;
-
-                if (direction == 0) {
-                    // Forward reader: only process first half
-                    tiles_to_read = input_tile_id_start + first_half_tiles;
-                } else {
-                    // Backward reader: only process second half
-                    tiles_read = input_tile_id_start + first_half_tiles;
-
-                    // Adjust row/column position to skip first half
-                    uint32_t tiles_to_skip = first_half_tiles;
-                    while (tiles_to_skip > 0) {
-                        if (tiles_to_skip < slice_Wt - pages_read_in_row) {
-                            pages_read_in_row += tiles_to_skip;
-                            tiles_to_skip = 0;
-                        } else {
-                            tiles_to_skip -= (slice_Wt - pages_read_in_row);
-                            row_offset += stride_Wt;
-                            pages_read_in_row = 0;
-                        }
-                    }
-                }
-            }
-
             if constexpr (gather_dim == 3) {
                 output_tile_id_start = actual_sender_chip_id * input_tensor_Wt;
             } else if constexpr (gather_dim == 2) {
@@ -281,105 +209,61 @@ void kernel_main() {
             }
             uint32_t num_channels_processed_in_current_batch = 0;

-            {
-                for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
-                    chunk_count = 0;
-                    while (tiles_read < tiles_to_read) {
-                        if (chunk_count % chunks_per_sync == 0) {
-                            noc_semaphore_wait_min(
-                                reinterpret_cast<volatile tt_l1_ptr uint32_t*>(out_ready_sem), sem_target + 1);
-                            sem_target++;
-                        }
-                        chunk_count++;
-
-                        uint32_t tiles_remaining_to_read = tiles_to_read - tiles_read;
-                        uint32_t num_tiles_to_read = std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);
-
-                        cb_reserve_back(cb_output_id, num_tiles_to_write_per_packet);
-                        size_t l1_write_addr = get_write_ptr(cb_output_id);
-                        for (uint32_t j = 0; j < num_tiles_to_read; ++j) {
-                            uint32_t tile_id = output_tile_id_start + row_offset + pages_read_in_row;
-                            uint64_t noc_read_addr = get_noc_addr(tile_id, output_tensor_addrgen);
-                            noc_async_read(noc_read_addr, l1_write_addr, page_size);
-
-                            l1_write_addr += page_size;
-                            tiles_read++;
-
-                            pages_read_in_row++;
-                            if (pages_read_in_row >= slice_Wt) {
-                                row_offset += stride_Wt;
-                                pages_read_in_row = 0;
-                            }
-                        }
-
-                        noc_async_read_barrier();
-                        cb_push_back(cb_output_id, num_tiles_to_write_per_packet);
-                    }
-                    num_channels_processed_in_current_batch++;
-                    if (gather_dim == 1 && num_channels_processed_in_current_batch == input_tensor_C) {
-                        output_tile_id_start +=
-                            output_tensor_Wt * output_tensor_Ht * (output_tensor_C - input_tensor_C + 1);
-                    } else {
-                        output_tile_id_start += output_tensor_Wt * output_tensor_Ht;
+            for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
+                chunk_count = 0;
+                while (tiles_read < tiles_to_read) {
+                    if (chunk_count % chunks_per_sync == 0) {
+                        noc_semaphore_wait_min(
+                            reinterpret_cast<volatile tt_l1_ptr uint32_t*>(out_ready_sem), sem_target + 1);
+                        sem_target++;
                     }
+                    chunk_count++;
-                    if (num_channels_processed_in_current_batch == input_tensor_C) {
-                        num_channels_processed_in_current_batch = 0;
-                    }
+
+                    uint32_t tiles_remaining_to_read = tiles_to_read - tiles_read;
+                    uint32_t num_tiles_to_read = std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);
+
+                    cb_reserve_back(cb_output_id, num_tiles_to_write_per_packet);
+                    size_t l1_write_addr = get_write_ptr(cb_output_id);
+                    for (uint32_t j = 0; j < num_tiles_to_read; ++j) {
+                        uint32_t tile_id = output_tile_id_start + row_offset + pages_read_in_row;
+                        uint64_t noc_read_addr = get_noc_addr(tile_id, output_tensor_addrgen);
+                        noc_async_read(noc_read_addr, l1_write_addr, page_size);
+
+                        l1_write_addr += page_size;
+                        tiles_read++;
-                    // Reset for next batch, but respect split slice boundaries
-                    // Use is_split_forwarded_slice (based on forwarded count) to match writer's logic
-                    tiles_read = input_tile_id_start;
-                    tiles_to_read = input_tile_id_end;
-                    pages_read_in_row = start_pages_read_in_row;
-                    row_offset = start_row_offset;
-                    if (is_split_forwarded_slice) {
-                        uint32_t total_tiles = input_tile_id_end - input_tile_id_start;
-                        uint32_t first_half_tiles = total_tiles / 2;
-                        if (direction == 0) {
-                            tiles_read = input_tile_id_start;
-                            tiles_to_read = input_tile_id_start + first_half_tiles;
-                        } else {
-                            tiles_read = input_tile_id_start + first_half_tiles;
-                            tiles_to_read = input_tile_id_end;
-                            // Re-adjust position for second half
-                            uint32_t tiles_to_skip = first_half_tiles;
-                            while (tiles_to_skip > 0) {
-                                if (tiles_to_skip < slice_Wt - pages_read_in_row) {
-                                    pages_read_in_row += tiles_to_skip;
-                                    tiles_to_skip = 0;
-                                } else {
-                                    tiles_to_skip -= (slice_Wt - pages_read_in_row);
-                                    row_offset += stride_Wt;
-                                    pages_read_in_row = 0;
-                                }
-                            }
+
+                        pages_read_in_row++;
+                        if (pages_read_in_row >= slice_Wt) {
+                            row_offset += stride_Wt;
+                            pages_read_in_row = 0;
                         }
-                    } else {
-                        tiles_read = input_tile_id_start;
-                        tiles_to_read = input_tile_id_end;
                     }
+
+                    noc_async_read_barrier();
+                    cb_push_back(cb_output_id, num_tiles_to_write_per_packet);
                 }
+                num_channels_processed_in_current_batch++;
+                if (gather_dim == 1 && num_channels_processed_in_current_batch == input_tensor_C) {
+                    output_tile_id_start +=
+                        output_tensor_Wt * output_tensor_Ht * (output_tensor_C - input_tensor_C + 1);
+                } else {
+                    output_tile_id_start += output_tensor_Wt * output_tensor_Ht;
+                }
+
+                if (num_channels_processed_in_current_batch == input_tensor_C) {
+                    num_channels_processed_in_current_batch = 0;
+                }
+
+                pages_read_in_row = start_pages_read_in_row;
+                row_offset = start_row_offset;
+                tiles_read = input_tile_id_start;
+                tiles_to_read = input_tile_id_end;
             }
-            slices_forwarded++;  // Track forwarded slices for split-forwarding logic
         } else {
-            // Not forwarding - just wait for semaphores indicating data has arrived
             for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
                 chunk_count = 0;
                 tiles_read = input_tile_id_start;
                 tiles_to_read = input_tile_id_end;
-
-                // For split slices, each direction only waits for its half's semaphores
-                if (is_split_slice) {
-                    uint32_t total_tiles = input_tile_id_end - input_tile_id_start;
-                    uint32_t first_half_tiles = total_tiles / 2;
-                    if (direction == 0) {
-                        tiles_to_read = input_tile_id_start + first_half_tiles;
-                    } else {
-                        tiles_read = input_tile_id_start + first_half_tiles;
-                    }
-                }
-
                 while (tiles_read < tiles_to_read) {
                     if (chunk_count % chunks_per_sync == 0) {
                         noc_semaphore_wait_min(
@@ -391,6 +275,8 @@ void kernel_main() {
                 uint32_t num_tiles_to_read = std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);
                 tiles_read += num_tiles_to_read;
             }
+            tiles_read = input_tile_id_start;
+            tiles_to_read = input_tile_id_end;
         }
     }

diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_writer.cpp
index 9eb31d807e88..f0a4661edb34 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_writer.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_default_writer.cpp
@@ -273,13 +273,6 @@ void kernel_main() {
     }

     uint32_t slice_writes = 0;
-    bool split_forwarding_enabled = false;
-    if constexpr (topology == Topology::Ring) {
-        if (ring_size % 2 == 0 && ring_size > 2) {  // if ring size is even, we need to write the first half of the
-                                                    // tiles, otherwise we write the entire packet
-            split_forwarding_enabled = true;
-        }
-    }

     // Write out the local slice to both DRAM and forward and backward
     uint32_t pages_read_in_row = start_pages_read_in_row;
@@ -334,13 +327,11 @@
         safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem, 0);
     uint32_t num_channels_processed_in_current_batch = 0;
     uint32_t chunk_count = 0;
-
-    for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
-        chunk_count = 0;
-        while (tiles_read < tiles_to_read) {
-            uint32_t tiles_remaining_to_read = tiles_to_read - tiles_read;
-            uint32_t tiles_to_put_in_current_packet =
-                std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);
+    for (uint32_t bh_idx = 0; bh_idx < input_batch_head_count; bh_idx++) {
+        chunk_count = 0;
+        while (tiles_read < tiles_to_read) {
+            uint32_t tiles_remaining_to_read = tiles_to_read - tiles_read;
+            uint32_t tiles_to_put_in_current_packet = std::min(tiles_remaining_to_read, num_tiles_to_write_per_packet);

             cb_wait_front(cb_output_id, num_tiles_to_write_per_packet);
             size_t l1_read_addr = get_read_ptr(cb_output_id);
@@ -384,6 +375,7 @@
                     noc_async_write(l1_read_addr + i * page_size, local_noc_addrs[i], page_size);
                 }
                 noc_async_write_barrier();
+
             } else {
                 if constexpr (num_targets_forward_direction) {
                     if (tiles_to_put_in_current_packet > 1) {
@@ -406,9 +398,9 @@
                 }

             tiles_read += tiles_to_put_in_current_packet;
-                noc_async_writes_flushed();
+            noc_async_writes_flushed();

-                cb_pop_front(cb_output_id, num_tiles_to_write_per_packet);
+            cb_pop_front(cb_output_id, num_tiles_to_write_per_packet);

             chunk_count++;
             if (chunk_count % chunks_per_sync == 0) {
@@ -440,16 +432,16 @@
                 tile_id_start += output_tensor_Wt * output_tensor_Ht;
             }

-            if (num_channels_processed_in_current_batch == input_tensor_C) {
-                num_channels_processed_in_current_batch = 0;
-            }
-
-            tiles_read = input_tile_id_start;
-            tiles_to_read = input_tile_id_end;
-            pages_read_in_row = start_pages_read_in_row;
-            row_offset = start_row_offset;
+        if (num_channels_processed_in_current_batch == input_tensor_C) {
+            num_channels_processed_in_current_batch = 0;
         }

+        tiles_read = input_tile_id_start;
+        tiles_to_read = input_tile_id_end;
+        pages_read_in_row = start_pages_read_in_row;
+        row_offset = start_row_offset;
+    }
+
     // increment locally
     if constexpr (fuse_op) {
         if (direction == 1) {
@@ -471,12 +463,8 @@
     } else if constexpr (topology == Topology::Ring) {
        if (direction == 1) {
             writes_expected = num_targets_backward_direction - 1;
-            if (split_forwarding_enabled) {
-                writes_expected++;  // Backward worker will also forward 1 slice (but only half of it)
-            }
         } else {
             writes_expected = num_targets_forward_direction - 1;
-            // For 4-device ring, forward worker will only send half of last slice
         }
     }

@@ -491,10 +479,6 @@
     // In the linear case, I expect num_targets_forward_direction slices from the right, and check if I have a
     // neighbor to the left
     // In the ring case, I expect to write to the left num_backward_target times
-
-    // Check if this is the last slice for split-forwarding
-    bool is_last_slice = (slice_writes == writes_expected - 1);
-
     int slice_chip_id;
     uint32_t actual_slice_chip_id;
     if (direction == 1) {
@@ -514,32 +498,6 @@
         uint32_t pages_read_in_row = start_pages_read_in_row;
         uint32_t slice_Wt = input_tensor_Wt;
         uint32_t stride_Wt = output_tensor_Wt;
-
-        if (split_forwarding_enabled && is_last_slice) {
-            uint32_t total_tiles = input_tile_id_end - input_tile_id_start;
-            uint32_t first_half_tiles = total_tiles / 2;
-
-            if (direction == 0) {
-                // Forward worker: only forward first half
-                tiles_to_read = input_tile_id_start + first_half_tiles;
-            } else {
-                // Backward worker: skip first half, forward second half
-                tiles_read = input_tile_id_start + first_half_tiles;
-
-                // Adjust starting position for tiles
-                uint32_t tiles_to_skip = first_half_tiles;
-                while (tiles_to_skip > 0) {
-                    if (tiles_to_skip < slice_Wt - pages_read_in_row) {
-                        pages_read_in_row += tiles_to_skip;
-                        tiles_to_skip = 0;
-                    } else {
-                        tiles_to_skip -= (slice_Wt - pages_read_in_row);
-                        row_offset += stride_Wt;
-                        pages_read_in_row = 0;
-                    }
-                }
-            }
-        }
         if constexpr (gather_dim == 3) {
             tile_id_start = actual_slice_chip_id * input_tensor_Wt;
         } else if constexpr (gather_dim == 2) {
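
Note: the deleted test_all_gather_wan measured all-gather performance through ttnn trace capture/replay. For reference, a minimal sketch of that pattern is reproduced below, using only the ttnn APIs that appear in the removed code (begin_trace_capture, end_trace_capture, execute_trace, release_trace, synchronize_device); the helper name and the ag_kwargs argument are illustrative assumptions, not part of the original test.

    import ttnn

    def run_traced_all_gather(mesh_device, tt_input, n_iters, ag_kwargs):
        # Capture: record n_iters back-to-back all_gather launches into a trace
        trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
        for _ in range(n_iters):
            _ = ttnn.all_gather(tt_input, **ag_kwargs)
        ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
        ttnn.synchronize_device(mesh_device)

        # Replay: re-execute the captured command stream, then release it
        ttnn.execute_trace(mesh_device, trace_id, blocking=False)
        ttnn.release_trace(mesh_device, trace_id)
        ttnn.synchronize_device(mesh_device)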