File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -778,7 +778,9 @@ def _update_per_bucket(
778778 if p2p_update :
779779 # p2p store need to register buffer to let other ranks read
780780 p2p_ipc_buffer_name = "__ipc_buffer__"
781- self ._p2p_store .register_named_tensors ({p2p_ipc_buffer_name : buffer if disable_h2d_buffer else h2d_buffer })
781+ self ._p2p_store .register_named_tensors (
782+ {p2p_ipc_buffer_name : buffer if disable_h2d_buffer else h2d_buffer }
783+ )
782784 handle = reduce_tensor (buffer )
783785
784786 buckets_by_receiver_rank : dict [int , list [H2DBucket ]] = defaultdict (list )
@@ -826,7 +828,12 @@ def _update_per_bucket(
826828 if disable_h2d_buffer :
827829 if p2p_update :
828830 assert bucket == receiver_rank_buckets [i ][1 ]
829- self ._copy_to_buffer (checkpoint_name , bucket , buffer_b , receiver_rank_buckets [i ][0 ] if p2p_update else None )
831+ self ._copy_to_buffer (
832+ checkpoint_name ,
833+ bucket ,
834+ buffer_b ,
835+ receiver_rank_buckets [i ][0 ] if p2p_update else None ,
836+ )
830837 else :
831838 buffer_b .data .copy_ (h2d_buffer [: bucket .size ])
832839 dist .broadcast (buffer_b , src = receiver_rank , group = ranks_group )
You can’t perform that action at this time.
0 commit comments