@@ -863,6 +863,60 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len):
         )  # TODO: try out 3 for short axis and 4 for long axis (TG only) <- should work but untested in model
         self.ccl_dtype = ttnn.bfloat8_b

+        # Model-specific CCL configs and their defaults
+        default_ln_ag = {"num_links": 1, "chunks_per_sync": 10, "num_workers_per_link": 2}
+        default_agmm = {"num_links": 1, "chunks_per_sync": 10, "num_workers_per_link": 2}
+        default_mlp_rs = {
+            "num_links": self.num_reduce_scatter_links,
+            "chunks_per_sync": 10,
+            "num_workers_per_link": 2,
+            "rs_memory_config": ttnn.DRAM_MEMORY_CONFIG,
+        }
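+        # The default reduce-scatter output lives in DRAM; the tuned Galaxy entry below keeps it in L1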
+        default_sampling_force_argmax = {
+            "allow_force_argmax": False,
+            "num_links": 1,
+            "chunks_per_sync": 10,
+            "num_workers_per_link": 2,
+            "topology": ttnn.Topology.Linear,
+        }
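+        # Per-model overrides (ln_ag: layernorm all-gather, agmm: all-gather matmul, mlp_rs: MLP reduce-scatter)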
+        model_specific_ccl_configs = {
+            "Llama-3.1-8B": {
+                "attn_ln_ag": {"num_links": 4, "chunks_per_sync": 10, "num_workers_per_link": 1},
+                "ffn_ln_ag": {"num_links": 4, "chunks_per_sync": 25, "num_workers_per_link": 1},
+                "attn_agmm": {"num_links": 4, "chunks_per_sync": 1, "num_workers_per_link": 1},
+                "mlp_rs": {
+                    "num_links": 4,
+                    "chunks_per_sync": 1,
+                    "num_workers_per_link": 1,
+                    "rs_memory_config": ttnn.L1_MEMORY_CONFIG,
+                },
+                "sampling_force_argmax": {
+                    "allow_force_argmax": True,
+                    "num_links": 4,
+                    "chunks_per_sync": 10,
+                    "num_workers_per_link": 2,
+                    "topology": ttnn.Topology.Ring,
+                },
+            }
+        }
+        # Model-specific CCL configs are tuned for Galaxy (TG) with 4 links.
+        # Only apply them on Galaxy; otherwise use the defaults.
+        executed_on_galaxy = ttnn.cluster.get_cluster_type() == ttnn.cluster.ClusterType.GALAXY
+        if executed_on_galaxy and self.base_model_name in model_specific_ccl_configs:
+            self.model_config["ATTN_LN_AG_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["attn_ln_ag"]
+            self.model_config["FFN_LN_AG_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["ffn_ln_ag"]
+            self.model_config["ATTN_AGMM_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["attn_agmm"]
+            self.model_config["MLP_RS_CONFIG"] = model_specific_ccl_configs[self.base_model_name]["mlp_rs"]
+            self.model_config["SAMPLING_AG_CONFIG"] = model_specific_ccl_configs[self.base_model_name][
+                "sampling_force_argmax"
+            ]
+        else:
+            self.model_config["ATTN_LN_AG_CONFIG"] = default_ln_ag
+            self.model_config["FFN_LN_AG_CONFIG"] = default_ln_ag
+            self.model_config["ATTN_AGMM_CONFIG"] = default_agmm
+            self.model_config["MLP_RS_CONFIG"] = default_mlp_rs
+            self.model_config["SAMPLING_AG_CONFIG"] = default_sampling_force_argmax
+
         logger.info(f"Attention grid: {attn_input_grid}")
         logger.info(f"MLP grid: {mlp_core_grid}")
         logger.info(f"MLP prefill grids @ 32: w1/w3: {mlp1_3_grid(32)}, w2: {mlp2_grid(32)}")
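
Read in isolation, the gating above is just a dict lookup with a fallback. A minimal off-device sketch of that pattern (pick_ccl_configs is a hypothetical helper name, not part of this change):

    # Sketch of the fallback pattern in the hunk above; pick_ccl_configs is
    # a hypothetical helper, not part of the actual change.
    def pick_ccl_configs(base_model_name, on_galaxy, tuned_configs, defaults):
        # Tuned per-model CCL configs only apply on Galaxy; any other
        # cluster, or a model without a tuned entry, gets the defaults.
        if on_galaxy and base_model_name in tuned_configs:
            return tuned_configs[base_model_name]
        return defaults

    # Example: off-Galaxy, even a tuned model name resolves to the defaults.
    defaults = {"attn_ln_ag": {"num_links": 1, "chunks_per_sync": 10, "num_workers_per_link": 2}}
    tuned = {"Llama-3.1-8B": {"attn_ln_ag": {"num_links": 4, "chunks_per_sync": 10, "num_workers_per_link": 1}}}
    assert pick_ccl_configs("Llama-3.1-8B", on_galaxy=False, tuned_configs=tuned, defaults=defaults) == defaults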