Commit d3ddf1a

Add missing configs for gemma3 (#37775)
### Problem description

Currently this model is failing on the stable branch.

### What's changed

Add the CCL configs that are missing from this model.

### Checklist

- [x] [(T3K) T3000 demo tests](https://github.com/tenstorrent/tt-metal/actions/runs/21956362359) - CI passes
1 parent f371ab7 · commit d3ddf1a

File tree

1 file changed: +54 −0

models/demos/gemma3/tt/model_config.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -863,6 +863,60 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len):
         )  # TODO: try out 3 for short axis and 4 for long axis (TG only) <- should work but untested in model
         self.ccl_dtype = ttnn.bfloat8_b

+        # model specific CCL configs
+        default_ln_ag = {"num_links": 1, "chunks_per_sync": 10, "num_workers_per_link": 2}
+        default_agmm = {"num_links": 1, "chunks_per_sync": 10, "num_workers_per_link": 2}
+        default_mlp_rs = {
+            "num_links": self.num_reduce_scatter_links,
+            "chunks_per_sync": 10,
+            "num_workers_per_link": 2,
+            "rs_memory_config": ttnn.DRAM_MEMORY_CONFIG,
+        }
+        default_sampling_force_argmax = {
+            "allow_force_argmax": False,
+            "num_links": 1,
+            "chunks_per_sync": 10,
+            "num_workers_per_link": 2,
+            "topology": ttnn.Topology.Linear,
+        }
+        model_specific_ccl_configs = {
+            "Llama-3.1-8B": {
+                "attn_ln_ag": {"num_links": 4, "chunks_per_sync": 10, "num_workers_per_link": 1},
+                "ffn_ln_ag": {"num_links": 4, "chunks_per_sync": 25, "num_workers_per_link": 1},
+                "attn_agmm": {"num_links": 4, "chunks_per_sync": 1, "num_workers_per_link": 1},
+                "mlp_rs": {
+                    "num_links": 4,
+                    "chunks_per_sync": 1,
+                    "num_workers_per_link": 1,
+                    "rs_memory_config": ttnn.L1_MEMORY_CONFIG,
+                },
+                "sampling_force_argmax": {
+                    "allow_force_argmax": True,
+                    "num_links": 4,
+                    "chunks_per_sync": 10,
+                    "num_workers_per_link": 2,
+                    "topology": ttnn.Topology.Ring,
+                },
+            }
+        }
+        # Model-specific CCL configs are tuned for Galaxy (TG) with 4 links
+        # Only apply them on Galaxy, otherwise use defaults
+        executed_on_galaxy = ttnn.cluster.get_cluster_type() == ttnn.cluster.ClusterType.GALAXY
+        if executed_on_galaxy and self.base_model_name in model_specific_ccl_configs:
+            self.model_config["ATTN_LN_AG_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["attn_ln_ag"]
+            self.model_config["FFN_LN_AG_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["ffn_ln_ag"]
+            self.model_config["ATTN_AGMM_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["attn_agmm"]
+            self.model_config["MLP_RS_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["mlp_rs"]
+            self.model_config["SAMPLING_AG_CONFIG"] = model_specific_ccl_configs[self.base_model_name][
+                "sampling_force_argmax"
+            ]
+        else:
+            self.model_config["ATTN_LN_AG_CONFIG"] = default_ln_ag
+            self.model_config["FFN_LN_AG_CONFIG"] = default_ln_ag
+            self.model_config["ATTN_AGMM_CONFIG"] = default_agmm
+            self.model_config["MLP_RS_CONFIG"] = default_mlp_rs
+            self.model_config["SAMPLING_AG_CONFIG"] = default_sampling_force_argmax
+
         logger.info(f"Attention grid: {attn_input_grid}")
         logger.info(f"MLP grid: {mlp_core_grid}")
         logger.info(f"MLP prefill grids @ 32: w1/w3: {mlp1_3_grid(32)}, w2: {mlp2_grid(32)}")
```
