Skip to content

Commit a7b900c

Browse files
specture724 and Copilot committed
misc: fix pr issues
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: specture724 <149605198+specture724@users.noreply.github.com>
1 parent b323615 commit a7b900c

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

checkpoint_engine/ps.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1265,11 +1265,12 @@ def _update_per_bucket(
12651265
brank = bcast_rank_map[receiver_rank]
12661266
dist.broadcast(buffer_b, src=brank)
12671267
resp = socket.recv()
1268+
ret_code = torch.tensor(0, device="cuda")
12681269
if resp != b"":
12691270
logger.error(
12701271
f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
12711272
)
1272-
ret_code = torch.tensor(1, device="cuda")
1273+
ret_code.fill_(1)
12731274
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
12741275
torch.cuda.synchronize()
12751276
if ret_code.item() != 0:

tests/test_update.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ def checker_proc_with_error(
4646
):
4747
torch.cuda.set_device(rank)
4848
named_tensors = {name: tensor.cuda() for name, tensor in named_tensors.items()}
49+
_ = named_tensors
4950
_zmq_ctx = zmq.Context()
5051

5152
def trigger_error(socket_paths: list[tuple[str, str]]):
@@ -59,7 +60,7 @@ def trigger_error(socket_paths: list[tuple[str, str]]):
5960
)
6061

6162
def error_run(weights: list[tuple[str, torch.Tensor]]):
62-
weights = weights # Do some fake processing
63+
_ = weights # Do some fake processing
6364
time.sleep(random.uniform(0.1, 0.5))
6465
if rank == 0:
6566
raise RuntimeError("Intentional Error for testing.")
@@ -207,7 +208,6 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
207208
sys.exit(1)
208209
assert len(sys.argv) > 2
209210
test_type = sys.argv[1]
210-
world_size = get_world_size()
211211
rank_list = json.loads(sys.argv[2])
212212
if test_type == "test_no_error" or test_type == "long_test_no_error":
213213
run(checker_proc, rank_list, need_error=False)

0 commit comments

Comments
 (0)