@@ -1278,7 +1278,7 @@ def update(
12781278 is_master = self ._rank == 0 ,
12791279 )
12801280 # if ranks is None or [], it will use fully broadcast to update to all ranks
1281- ranks_group = dist .new_group (ranks if ranks else None )
1281+ ranks_group = dist .new_group (ranks ) if ranks else None
12821282 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
12831283 self .store_based_barrier (manager_store )
12841284 except Exception as e :
@@ -1309,7 +1309,7 @@ def zmq_handle(device_uuid: str) -> str:
13091309 return socket , socket_paths
13101310
13111311 def _detect_bucket_size (
1312- self , ranks_group : dist .ProcessGroup , * , disable_h2d_buffer : bool = False
1312+ self , ranks_group : dist .ProcessGroup | None , * , disable_h2d_buffer : bool = False
13131313 ) -> tuple [int , bool ]:
13141314 GiB = 1 << 30 # noqa: N806
13151315 # auto detect bucket size
@@ -1428,7 +1428,7 @@ def _update_per_bucket(
14281428 self ,
14291429 checkpoint_name : str ,
14301430 req_func : Callable [[list [tuple [str , str ]]], None ],
1431- ranks_group : dist .ProcessGroup ,
1431+ ranks_group : dist .ProcessGroup | None ,
14321432 ranks : list [int ] | None = None ,
14331433 ):
14341434 assert len (self ._current_global_parameter_metas ) != 0 , "parameter metas is empty"
0 commit comments