2020
2121from .slurm_install_strategy import NeMoLauncherSlurmInstallStrategy
2222
23- REQUIRE_ENV_VARS = [
24- "NCCL_SOCKET_IFNAME" ,
25- "NCCL_IB_GID_INDEX" ,
26- "NCCL_IB_TC" ,
27- "NCCL_IB_QPS_PER_CONNECTION" ,
28- "UCX_IB_GID_INDEX" ,
29- "NCCL_IB_ADAPTIVE_ROUTING" ,
30- "NCCL_IB_SPLIT_DATA_ON_QPS" ,
31- "NCCL_IBEXT_DISABLE" ,
32- ]
33-
3423
3524class NeMoLauncherSlurmCommandGenStrategy (SlurmCommandGenStrategy ):
3625 """
@@ -50,10 +39,8 @@ def gen_exec_command(
5039 num_nodes : int ,
5140 nodes : List [str ],
5241 ) -> str :
53- # Ensure required environment variables are included
54- for key in REQUIRE_ENV_VARS :
55- if key not in extra_env_vars :
56- extra_env_vars [key ] = self .slurm_system .global_env_vars [key ]
42+ final_env_vars = self ._override_env_vars (self .default_env_vars , env_vars )
43+ final_env_vars = self ._override_env_vars (final_env_vars , extra_env_vars )
5744
5845 launcher_path = os .path .join (
5946 self .install_path ,
@@ -67,7 +54,7 @@ def gen_exec_command(
6754 self .final_cmd_args ["base_results_dir" ] = output_path
6855 self .final_cmd_args ["training.model.data.index_mapping_dir" ] = output_path
6956 self .final_cmd_args ["launcher_scripts_path" ] = os .path .join (launcher_path , "launcher_scripts" )
70- for key , value in extra_env_vars .items ():
57+ for key , value in final_env_vars .items ():
7158 self .final_cmd_args [f"env_vars.{ key } " ] = value
7259 self .final_cmd_args ["cluster.partition" ] = self .slurm_system .default_partition
7360 nodes = self .slurm_system .parse_nodes (nodes )
@@ -96,7 +83,7 @@ def gen_exec_command(
9683 tokenizer_path = extra_cmd_args .split ("training.model.tokenizer.model=" )[1 ].split (" " )[0 ]
9784 full_cmd += " " + f"container_mounts=[{ tokenizer_path } :{ tokenizer_path } ]"
9885
99- env_vars_str = " " .join (f"{ key } ={ value } " for key , value in extra_env_vars .items ())
86+ env_vars_str = " " .join (f"{ key } ={ value } " for key , value in final_env_vars .items ())
10087 full_cmd = f"{ env_vars_str } { full_cmd } " if env_vars_str else full_cmd
10188
10289 return full_cmd .strip ()
@@ -130,13 +117,19 @@ def _generate_cmd_args_str(self, args: Dict[str, str], nodes: List[str]) -> str:
130117 Returns:
131118 str: A string of command-line arguments.
132119 """
133- arg_str_parts = []
120+ cmd_arg_str_parts = []
121+ env_var_str_parts = []
122+
134123 for key , value in args .items ():
135- formatted_key = f"+{ key } " if key .startswith ("env_vars." ) else key
136- arg_str_parts .append (f"{ formatted_key } ={ value } " )
124+ if key .startswith ("env_vars." ):
125+ if isinstance (value , str ) and "," in value :
126+ value = f"\\ '{ value } \\ '"
127+ env_var_str_parts .append (f"+{ key } ={ value } " )
128+ else :
129+ cmd_arg_str_parts .append (f"{ key } ={ value } " )
137130
138131 if nodes :
139132 nodes_str = "," .join (nodes )
140- arg_str_parts .append (f"+cluster.nodelist=\\ '{ nodes_str } \\ '" )
133+ cmd_arg_str_parts .append (f"+cluster.nodelist=\\ '{ nodes_str } \\ '" )
141134
142- return " " .join (arg_str_parts )
135+ return " " .join (cmd_arg_str_parts + env_var_str_parts )
0 commit comments