Skip to content

Commit d21d345

Browse files
committed
fix format error
1 parent cca085c commit d21d345

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

checkpoint_engine/ps.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,9 @@ def _update_per_bucket(
778778
if p2p_update:
779779
# p2p store need to register buffer to let other ranks read
780780
p2p_ipc_buffer_name = "__ipc_buffer__"
781-
self._p2p_store.register_named_tensors({p2p_ipc_buffer_name: buffer if disable_h2d_buffer else h2d_buffer})
781+
self._p2p_store.register_named_tensors(
782+
{p2p_ipc_buffer_name: buffer if disable_h2d_buffer else h2d_buffer}
783+
)
782784
handle = reduce_tensor(buffer)
783785

784786
buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -826,7 +828,12 @@ def _update_per_bucket(
826828
if disable_h2d_buffer:
827829
if p2p_update:
828830
assert bucket == receiver_rank_buckets[i][1]
829-
self._copy_to_buffer(checkpoint_name, bucket, buffer_b, receiver_rank_buckets[i][0] if p2p_update else None)
831+
self._copy_to_buffer(
832+
checkpoint_name,
833+
bucket,
834+
buffer_b,
835+
receiver_rank_buckets[i][0] if p2p_update else None,
836+
)
830837
else:
831838
buffer_b.data.copy_(h2d_buffer[: bucket.size])
832839
dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)

0 commit comments

Comments
 (0)