Skip to content

Commit

Permalink
CheckpointServer: start in disallowed state + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jan 30, 2025
1 parent 866873a commit 086de9b
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
30 changes: 28 additions & 2 deletions torchft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import threading
import urllib.request
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import timedelta
from http.server import BaseHTTPRequestHandler
from typing import Generic, List, Optional, TypeVar
from typing import Generator, Generic, List, Optional, TypeVar

import torch

Expand Down Expand Up @@ -87,6 +88,25 @@ def shutdown(self, wait: bool = True) -> None:
"""


@contextmanager
def _timed_acquire(
lock: threading.Lock, timeout: timedelta
) -> Generator[None, None, None]:
"""
Acquire a lock with a timeout.
Args:
lock: the lock to acquire
timeout: the timeout to acquire the lock
"""
if not lock.acquire(timeout=timeout.total_seconds()):
raise TimeoutError(f"timed out acquiring lock after {timeout}")
try:
yield
finally:
lock.release()


class CheckpointServer(CheckpointTransport[T]):
"""
This is an HTTP server that can be used to transfer checkpoints
Expand All @@ -106,6 +126,10 @@ def __init__(self, timeout: timedelta) -> None:
self._timeout = timeout
self._state_dict: Optional[T] = None

# We don't allow checkpoints until the first send_checkpoint to avoid
# serving the default step=-1 invalid checkpoint.
self.disallow_checkpoint()

ckpt_server = self

class RequestHandler(BaseHTTPRequestHandler):
Expand All @@ -117,7 +141,9 @@ def do_GET(self):
# validate socket timeout is actually set
assert self.connection.gettimeout() == self.timeout

with ckpt_server._checkpoint_lock:
with _timed_acquire(
ckpt_server._checkpoint_lock, ckpt_server._timeout
):
step = ckpt_server._step

if self.path != f"/checkpoint/{step}":
Expand Down
50 changes: 49 additions & 1 deletion torchft/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import urllib.error
from datetime import timedelta
from unittest import TestCase
from unittest.mock import MagicMock

from torchft.checkpointing import CheckpointServer
from torchft.checkpointing import CheckpointServer, _timed_acquire


class TestCheckpointing(TestCase):
Expand Down Expand Up @@ -55,3 +56,50 @@ def test_checkpoint_server(self) -> None:
)

server.shutdown()

def test_checkpoint_server_locking(self) -> None:
server = CheckpointServer(
timeout=timedelta(seconds=10),
)

# server should start up in a disallowed state this will block incoming
# requests until allow_checkpoint is called
self.assertTrue(server._checkpoint_lock.locked())
self.assertTrue(server._disallowed)
self.assertEqual(server._step, -1)

# allow requests
server.allow_checkpoint(1)

self.assertFalse(server._checkpoint_lock.locked())
self.assertFalse(server._disallowed)
self.assertEqual(server._step, 1)

# duplicate allow/disallow is fine
server.allow_checkpoint(2)
self.assertEqual(server._step, 2)

server.disallow_checkpoint()
server.disallow_checkpoint()
self.assertTrue(server._checkpoint_lock.locked())
self.assertTrue(server._disallowed)

server.shutdown()

def test_timed_acquire(self) -> None:
lock = threading.Lock()

with _timed_acquire(lock, timedelta(seconds=10)):
self.assertTrue(lock.locked())

self.assertFalse(lock.locked())

lock.acquire()

with self.assertRaisesRegex(
TimeoutError, r"timed out acquiring lock after 0.0"
):
with _timed_acquire(lock, timedelta(seconds=0.0)):
pass

self.assertTrue(lock.locked())

0 comments on commit 086de9b

Please sign in to comment.