Skip to content

Commit 88370e2

Browse files
authored
fix: make ranks_group None when ranks is None in PS engine (#70)
1 parent 4350e70 commit 88370e2

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

checkpoint_engine/ps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ def update(
12781278
is_master=self._rank == 0,
12791279
)
12801280
# if ranks is None or [], it will use fully broadcast to update to all ranks
1281-
ranks_group = dist.new_group(ranks if ranks else None)
1281+
ranks_group = dist.new_group(ranks) if ranks else None
12821282
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
12831283
self.store_based_barrier(manager_store)
12841284
except Exception as e:
@@ -1309,7 +1309,7 @@ def zmq_handle(device_uuid: str) -> str:
13091309
return socket, socket_paths
13101310

13111311
def _detect_bucket_size(
1312-
self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
1312+
self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
13131313
) -> tuple[int, bool]:
13141314
GiB = 1 << 30 # noqa: N806
13151315
# auto detect bucket size
@@ -1428,7 +1428,7 @@ def _update_per_bucket(
14281428
self,
14291429
checkpoint_name: str,
14301430
req_func: Callable[[list[tuple[str, str]]], None],
1431-
ranks_group: dist.ProcessGroup,
1431+
ranks_group: dist.ProcessGroup | None,
14321432
ranks: list[int] | None = None,
14331433
):
14341434
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"

0 commit comments

Comments
 (0)