Skip to content

Commit 1bfca4a

Browse files
committed
misc
1 parent 5a26fbf commit 1bfca4a

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

checkpoint_engine/ps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ def update(
10971097
timeout=timeout,
10981098
is_master=self._rank == 0,
10991099
)
1100-
# if both ranks is None or [], it will use fully broadcast to update to all ranks
1100+
# if ranks is None or [], it will use fully broadcast to update to all ranks
11011101
ranks_group = dist.new_group(ranks if ranks else None)
11021102
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
11031103
self.store_based_barrier(manager_store)

tests/test_update.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
8282
try:
8383
trigger_error(socket_paths)
8484
except RuntimeError as e:
85-
assert str(e) == "Failed to update weights due to remote errors"
85+
assert str(e) == "Some workers failed to update weights"
8686

8787

8888
def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -177,7 +177,7 @@ def run(
177177
],
178178
),
179179
("test_with_remote_error", [[]]),
180-
# ("long_test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
180+
("test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
181181
],
182182
)
183183
def test_update(test_name: str, rank_list: list[list[int]] | None):

0 commit comments

Comments
 (0)