@@ -959,8 +959,9 @@ ReduceScatterProgramArtifacts build_ring_reduce_scatter_minimal_async_program_ar
959959 src_node_id, dst_node_id, link, program, {mux_logical_core});
960960 }
961961 tt::tt_metal::SetRuntimeArgs (program, mux_kernel_id, {mux_logical_core}, mux_rt_args);
962+
963+ termination_master_logical_core = *((termination_master_core_iter++)->begin ());
962964 }
963- termination_master_logical_core = *((termination_master_core_iter++)->begin ());
964965
965966 for (uint32_t worker = 0 ; worker < num_workers_per_direction; worker++) {
966967 auto core = *((worker_core_iter++)->begin ());
@@ -1046,41 +1047,42 @@ ReduceScatterProgramArtifacts build_ring_reduce_scatter_minimal_async_program_ar
10461047 termination_master_virtual_core,
10471048 program,
10481049 writer_rt_args);
1049- if (intermediate_is_sharded) {
1050- shard_builder::extend_sharding_run_time_args (intermediate_tensor, writer_rt_args);
1051- }
1052- if (output_is_sharded) {
1053- shard_builder::extend_sharding_run_time_args (output_tensor, writer_rt_args);
1054- }
1055- if (!num_mux_cores_per_direction_per_link) {
1056- if (dir ) { // forward
1057- writer_rt_args. push_back (forward_coord. has_value ());
1058- if (forward_coord.has_value ()) {
1059- const auto src_fabric_node_id = mesh_device-> get_fabric_node_id (sender_device_coord);
1060- const auto dst_fabric_node_id = mesh_device->get_fabric_node_id (forward_coord. value () );
1061- tt::tt_fabric::append_fabric_connection_rt_args (
1062- src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
1063- }
1064- writer_rt_args. push_back ( false );
1065- } else {
1066- writer_rt_args. push_back ( false );
1067- writer_rt_args.push_back (backward_coord. has_value () );
1068- if (backward_coord.has_value ()) {
1069- const auto src_fabric_node_id = mesh_device-> get_fabric_node_id (sender_device_coord);
1070- const auto dst_fabric_node_id = mesh_device->get_fabric_node_id (backward_coord. value () );
1071- tt::tt_fabric::append_fabric_connection_rt_args (
1072- src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
1073- }
1050+ }
1051+ if (intermediate_is_sharded) {
1052+ shard_builder::extend_sharding_run_time_args (intermediate_tensor, writer_rt_args);
1053+ }
1054+ if (output_is_sharded) {
1055+ shard_builder::extend_sharding_run_time_args (output_tensor, writer_rt_args);
1056+ }
1057+ if (!num_mux_cores_per_direction_per_link ) {
1058+ if (dir) { // forward
1059+ writer_rt_args. push_back (forward_coord.has_value ());
1060+ if (forward_coord. has_value ()) {
1061+ const auto src_fabric_node_id = mesh_device->get_fabric_node_id (sender_device_coord );
1062+ const auto dst_fabric_node_id = mesh_device-> get_fabric_node_id (forward_coord. value ());
1063+ tt::tt_fabric::append_fabric_connection_rt_args (
1064+ src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
1065+ }
1066+ writer_rt_args. push_back ( false );
1067+ } else {
1068+ writer_rt_args.push_back (false );
1069+ writer_rt_args. push_back (backward_coord.has_value ());
1070+ if (backward_coord. has_value ()) {
1071+ const auto src_fabric_node_id = mesh_device->get_fabric_node_id (sender_device_coord );
1072+ const auto dst_fabric_node_id = mesh_device-> get_fabric_node_id (backward_coord. value ());
1073+ tt::tt_fabric::append_fabric_connection_rt_args (
1074+ src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
10741075 }
10751076 }
1076- tt::tt_metal::SetRuntimeArgs (program, writer_kernel_id, {core}, writer_rt_args);
1077-
1078- std::vector<uint32_t > reduce_rt_args = {
1079- start_tiles_read, // start_tiles_read
1080- start_tiles_to_read, // start_tiles_to_read
1081- dir}; // dir
1082- tt::tt_metal::SetRuntimeArgs (program, sender_reduce_kernel_id, {core}, reduce_rt_args);
10831077 }
1078+ tt::tt_metal::SetRuntimeArgs (program, writer_kernel_id, {core}, writer_rt_args);
1079+
1080+ std::vector<uint32_t > reduce_rt_args = {
1081+ start_tiles_read, // start_tiles_read
1082+ start_tiles_to_read, // start_tiles_to_read
1083+ dir}; // dir
1084+ tt::tt_metal::SetRuntimeArgs (program, sender_reduce_kernel_id, {core}, reduce_rt_args);
1085+ }
10841086 }
10851087 }
10861088
0 commit comments