Skip to content

Commit 086de9b

Browse files
committed
CheckpointServer: start in disallowed state + tests
1 parent 866873a commit 086de9b

File tree

2 files changed

+77
-3
lines changed

2 files changed

+77
-3
lines changed

torchft/checkpointing.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
import threading
1818
import urllib.request
1919
from abc import ABC, abstractmethod
20+
from contextlib import contextmanager
2021
from datetime import timedelta
2122
from http.server import BaseHTTPRequestHandler
22-
from typing import Generic, List, Optional, TypeVar
23+
from typing import Generator, Generic, List, Optional, TypeVar
2324

2425
import torch
2526

@@ -87,6 +88,25 @@ def shutdown(self, wait: bool = True) -> None:
8788
"""
8889

8990

91+
@contextmanager
92+
def _timed_acquire(
93+
lock: threading.Lock, timeout: timedelta
94+
) -> Generator[None, None, None]:
95+
"""
96+
Acquire a lock with a timeout.
97+
98+
Args:
99+
lock: the lock to acquire
100+
timeout: the timeout to acquire the lock
101+
"""
102+
if not lock.acquire(timeout=timeout.total_seconds()):
103+
raise TimeoutError(f"timed out acquiring lock after {timeout}")
104+
try:
105+
yield
106+
finally:
107+
lock.release()
108+
109+
90110
class CheckpointServer(CheckpointTransport[T]):
91111
"""
92112
This is an HTTP server that can be used to transfer checkpoints
@@ -106,6 +126,10 @@ def __init__(self, timeout: timedelta) -> None:
106126
self._timeout = timeout
107127
self._state_dict: Optional[T] = None
108128

129+
# We don't allow checkpoints until the first send_checkpoint to avoid
130+
# serving the default step=-1 invalid checkpoint.
131+
self.disallow_checkpoint()
132+
109133
ckpt_server = self
110134

111135
class RequestHandler(BaseHTTPRequestHandler):
@@ -117,7 +141,9 @@ def do_GET(self):
117141
# validate socket timeout is actually set
118142
assert self.connection.gettimeout() == self.timeout
119143

120-
with ckpt_server._checkpoint_lock:
144+
with _timed_acquire(
145+
ckpt_server._checkpoint_lock, ckpt_server._timeout
146+
):
121147
step = ckpt_server._step
122148

123149
if self.path != f"/checkpoint/{step}":

torchft/checkpointing_test.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import threading
78
import urllib.error
89
from datetime import timedelta
910
from unittest import TestCase
1011
from unittest.mock import MagicMock
1112

12-
from torchft.checkpointing import CheckpointServer
13+
from torchft.checkpointing import CheckpointServer, _timed_acquire
1314

1415

1516
class TestCheckpointing(TestCase):
@@ -55,3 +56,50 @@ def test_checkpoint_server(self) -> None:
5556
)
5657

5758
server.shutdown()
59+
60+
def test_checkpoint_server_locking(self) -> None:
61+
server = CheckpointServer(
62+
timeout=timedelta(seconds=10),
63+
)
64+
65+
# server should start up in a disallowed state this will block incoming
66+
# requests until allow_checkpoint is called
67+
self.assertTrue(server._checkpoint_lock.locked())
68+
self.assertTrue(server._disallowed)
69+
self.assertEqual(server._step, -1)
70+
71+
# allow requests
72+
server.allow_checkpoint(1)
73+
74+
self.assertFalse(server._checkpoint_lock.locked())
75+
self.assertFalse(server._disallowed)
76+
self.assertEqual(server._step, 1)
77+
78+
# duplicate allow/disallow is fine
79+
server.allow_checkpoint(2)
80+
self.assertEqual(server._step, 2)
81+
82+
server.disallow_checkpoint()
83+
server.disallow_checkpoint()
84+
self.assertTrue(server._checkpoint_lock.locked())
85+
self.assertTrue(server._disallowed)
86+
87+
server.shutdown()
88+
89+
def test_timed_acquire(self) -> None:
90+
lock = threading.Lock()
91+
92+
with _timed_acquire(lock, timedelta(seconds=10)):
93+
self.assertTrue(lock.locked())
94+
95+
self.assertFalse(lock.locked())
96+
97+
lock.acquire()
98+
99+
with self.assertRaisesRegex(
100+
TimeoutError, r"timed out acquiring lock after 0.0"
101+
):
102+
with _timed_acquire(lock, timedelta(seconds=0.0)):
103+
pass
104+
105+
self.assertTrue(lock.locked())

0 commit comments

Comments
 (0)