@@ -981,11 +981,6 @@ def update(
981981 return
982982 self ._update_per_bucket (checkpoint_name , req_func , ranks )
983983
984- logger .info (
985- f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } done. "
986- f"Current device allocated { self .device_manager .device_module .memory_allocated () / 1024 / 1024 } MB, "
987- f"reserved { self .device_manager .device_module .memory_reserved () / 1024 / 1024 } MB."
988- )
989984 except Exception as e :
990985 logger .exception (
991986 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } "
@@ -996,6 +991,11 @@ def update(
996991 dist .destroy_process_group ()
997992
998993 self .device_manager .device_module .empty_cache ()
994+ logger .info (
995+ f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } done. "
996+ f"Current device allocated { self .device_manager .device_module .memory_allocated () / 1024 / 1024 } MB, "
997+ f"reserved { self .device_manager .device_module .memory_reserved () / 1024 / 1024 } MB."
998+ )
999999
10001000 def _bind_zmq_socket (self ) -> tuple [zmq .Socket , list [tuple [str , str ]]]:
10011001 def zmq_handle (device_uuid : str ) -> str :
@@ -1225,7 +1225,7 @@ def _update_per_bucket(
12251225 socket .send_pyobj (handle )
12261226
12271227 gidx = 0
1228- ret_code = torch .tensor ( 0 , device = self .device_manager .device_type )
1228+ ret_code = torch .zeros (() , device = self .device_manager .device_type , dtype = torch . int64 )
12291229 bcast_rank_map = _get_bcast_rank_map (self ._world_size , ranks )
12301230 for i in range (max_len ):
12311231 if i < len (receiver_rank_buckets ) and not disable_h2d_buffer :
@@ -1268,8 +1268,9 @@ def _update_per_bucket(
12681268 self .device_manager .device_module .synchronize ()
12691269 if ret_code .item () != 0 :
12701270 # quit early if any rank failed
1271- socket .send_pyobj (RuntimeError ("Failed to update weights due to remote errors" ))
1272- raise RuntimeError ("Failed to update weights due to remote errors" )
1271+ exception = RuntimeError ("Failed to update weights due to remote errors" )
1272+ socket .send_pyobj (exception )
1273+ raise exception
12731274 socket .send_pyobj (_to_named_tensor (bucket .items , gidx % 2 * bucket_size ))
12741275 gidx += 1
12751276
0 commit comments