@@ -223,32 +223,32 @@ def _common_all_gather_object(
223223 object_list [i ] = _tensor_to_object (tensor , tensor_size )
224224
225225
226+ def use_backend (backend : str | None ):
227+ global _BACKEND_INSTANCE
228+
229+ if not backend :
230+ return
231+
232+ mapping = {
233+ "vllm_nccl" : ".nccl.DistributedNccl" ,
234+ "vllm_hccl" : ".hccl.DistributedHccl" ,
235+ }
236+ if backend not in mapping :
237+ raise ValueError (f"Unsupported custom backend: { backend } " )
238+
239+ module_path , class_name = mapping [backend ].rsplit ("." , 1 )
240+ module = importlib .import_module (module_path , "checkpoint_engine.distributed" )
241+ backend_class = getattr (module , class_name )
242+ _BACKEND_INSTANCE = backend_class ()
243+
244+
def init_process_group(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    timeout: timedelta = timedelta(seconds=300),
):
    """Initialize the process group on the currently selected backend.

    Args:
        host: Rendezvous host address for the process group.
        port: Rendezvous port on *host*.
        rank: Rank of this process within the group.
        world_size: Total number of participating processes.
        timeout: Maximum time to wait for initialization (default 300 s).
    """
    # Delegates to the module-level backend instance; select a custom
    # backend first via use_backend() if one is desired.
    # NOTE(review): assumes _BACKEND_INSTANCE is already set — confirm the
    # module provides a default when use_backend() was never called.
    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
254254
0 commit comments