Skip to content

Commit 986425f

Browse files
committed
Fix reduce_scatter_minimial_async merge conflict bugs
1 parent a0b141a commit 986425f

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

models/tt_transformers/tt/generator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def prefill_forward_text(
412412
local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx]
413413
if "image_grid_thw" in local_kwargs:
414414
local_kwargs["image_grid_thw"] = local_kwargs["image_grid_thw"][idx]
415-
local_kwargs["user_id"] = user_id
416415

417416
if sampling_enabled:
418417
sampling_executed = True
@@ -511,11 +510,11 @@ def prefill_forward_text(
511510
last_token_idx_relative = last_token_idx - num_cached_tokens
512511
ttnn.synchronize_device(self.model[model_id].mesh_device)
513512

514-
if res["hidden_states"] is not None:
513+
if "hidden_states" in res:
515514
output_tensor[idx] = self.model[model_id].process_output_prefill_hidden_states(
516515
res["hidden_states"], last_token_idx=(last_token_idx_relative % 32)
517516
)
518-
elif res["logits"] is not None:
517+
elif res["sampling"]:
519518
tt_tokens = res["logits"][0]
520519
tt_log_probs = res["logits"][1]
521520
tokens_host = ttnn.to_torch(ttnn.get_device_tensors(tt_tokens)[0]).reshape(-1)[

models/tt_transformers/tt/model_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,10 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len):
13381338
self.set_tg_attention_config()
13391339

13401340
self.is_multichip = self.num_devices > 1
1341+
self.num_reduce_scatter_links = 1
1342+
self.num_all_gather_links = (
1343+
2 if self.is_galaxy else 1
1344+
) # TODO: try out 3 for short axis and 4 for long axis (TG only) <- should work but untested in model
13411345
self.ccl_dtype = ttnn.bfloat8_b
13421346

13431347
# model specific CCL configs

ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_minimal_async/device/reduce_scatter_minimal_async_program.cpp

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -959,8 +959,9 @@ ReduceScatterProgramArtifacts build_ring_reduce_scatter_minimal_async_program_ar
959959
src_node_id, dst_node_id, link, program, {mux_logical_core});
960960
}
961961
tt::tt_metal::SetRuntimeArgs(program, mux_kernel_id, {mux_logical_core}, mux_rt_args);
962+
963+
termination_master_logical_core = *((termination_master_core_iter++)->begin());
962964
}
963-
termination_master_logical_core = *((termination_master_core_iter++)->begin());
964965

965966
for (uint32_t worker = 0; worker < num_workers_per_direction; worker++) {
966967
auto core = *((worker_core_iter++)->begin());
@@ -1046,41 +1047,42 @@ ReduceScatterProgramArtifacts build_ring_reduce_scatter_minimal_async_program_ar
10461047
termination_master_virtual_core,
10471048
program,
10481049
writer_rt_args);
1049-
if (intermediate_is_sharded) {
1050-
shard_builder::extend_sharding_run_time_args(intermediate_tensor, writer_rt_args);
1051-
}
1052-
if (output_is_sharded) {
1053-
shard_builder::extend_sharding_run_time_args(output_tensor, writer_rt_args);
1054-
}
1055-
if (!num_mux_cores_per_direction_per_link) {
1056-
if (dir) { // forward
1057-
writer_rt_args.push_back(forward_coord.has_value());
1058-
if (forward_coord.has_value()) {
1059-
const auto src_fabric_node_id = mesh_device->get_fabric_node_id(sender_device_coord);
1060-
const auto dst_fabric_node_id = mesh_device->get_fabric_node_id(forward_coord.value());
1061-
tt::tt_fabric::append_fabric_connection_rt_args(
1062-
src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
1063-
}
1064-
writer_rt_args.push_back(false);
1065-
} else {
1066-
writer_rt_args.push_back(false);
1067-
writer_rt_args.push_back(backward_coord.has_value());
1068-
if (backward_coord.has_value()) {
1069-
const auto src_fabric_node_id = mesh_device->get_fabric_node_id(sender_device_coord);
1070-
const auto dst_fabric_node_id = mesh_device->get_fabric_node_id(backward_coord.value());
1071-
tt::tt_fabric::append_fabric_connection_rt_args(
1072-
src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
1073-
}
1050+
}
1051+
if (intermediate_is_sharded) {
1052+
shard_builder::extend_sharding_run_time_args(intermediate_tensor, writer_rt_args);
1053+
}
1054+
if (output_is_sharded) {
1055+
shard_builder::extend_sharding_run_time_args(output_tensor, writer_rt_args);
1056+
}
1057+
if (!num_mux_cores_per_direction_per_link) {
1058+
if (dir) { // forward
1059+
writer_rt_args.push_back(forward_coord.has_value());
1060+
if (forward_coord.has_value()) {
1061+
const auto src_fabric_node_id = mesh_device->get_fabric_node_id(sender_device_coord);
1062+
const auto dst_fabric_node_id = mesh_device->get_fabric_node_id(forward_coord.value());
1063+
tt::tt_fabric::append_fabric_connection_rt_args(
1064+
src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
1065+
}
1066+
writer_rt_args.push_back(false);
1067+
} else {
1068+
writer_rt_args.push_back(false);
1069+
writer_rt_args.push_back(backward_coord.has_value());
1070+
if (backward_coord.has_value()) {
1071+
const auto src_fabric_node_id = mesh_device->get_fabric_node_id(sender_device_coord);
1072+
const auto dst_fabric_node_id = mesh_device->get_fabric_node_id(backward_coord.value());
1073+
tt::tt_fabric::append_fabric_connection_rt_args(
1074+
src_fabric_node_id, dst_fabric_node_id, link, program, {core}, writer_rt_args);
10741075
}
10751076
}
1076-
tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, {core}, writer_rt_args);
1077-
1078-
std::vector<uint32_t> reduce_rt_args = {
1079-
start_tiles_read, // start_tiles_read
1080-
start_tiles_to_read, // start_tiles_to_read
1081-
dir}; // dir
1082-
tt::tt_metal::SetRuntimeArgs(program, sender_reduce_kernel_id, {core}, reduce_rt_args);
10831077
}
1078+
tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, {core}, writer_rt_args);
1079+
1080+
std::vector<uint32_t> reduce_rt_args = {
1081+
start_tiles_read, // start_tiles_read
1082+
start_tiles_to_read, // start_tiles_to_read
1083+
dir}; // dir
1084+
tt::tt_metal::SetRuntimeArgs(program, sender_reduce_kernel_id, {core}, reduce_rt_args);
1085+
}
10841086
}
10851087
}
10861088

0 commit comments

Comments
 (0)