|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import threading |
7 | 8 | import urllib.error
|
8 | 9 | from datetime import timedelta
|
9 | 10 | from unittest import TestCase
|
10 | 11 | from unittest.mock import MagicMock
|
11 | 12 |
|
12 |
| -from torchft.checkpointing import CheckpointServer |
| 13 | +from torchft.checkpointing import CheckpointServer, _timed_acquire |
13 | 14 |
|
14 | 15 |
|
15 | 16 | class TestCheckpointing(TestCase):
|
@@ -55,3 +56,50 @@ def test_checkpoint_server(self) -> None:
|
55 | 56 | )
|
56 | 57 |
|
57 | 58 | server.shutdown()
|
| 59 | + |
| 60 | + def test_checkpoint_server_locking(self) -> None: |
| 61 | + server = CheckpointServer( |
| 62 | + timeout=timedelta(seconds=10), |
| 63 | + ) |
| 64 | + |
| 65 | + # server should start up in a disallowed state this will block incoming |
| 66 | + # requests until allow_checkpoint is called |
| 67 | + self.assertTrue(server._checkpoint_lock.locked()) |
| 68 | + self.assertTrue(server._disallowed) |
| 69 | + self.assertEqual(server._step, -1) |
| 70 | + |
| 71 | + # allow requests |
| 72 | + server.allow_checkpoint(1) |
| 73 | + |
| 74 | + self.assertFalse(server._checkpoint_lock.locked()) |
| 75 | + self.assertFalse(server._disallowed) |
| 76 | + self.assertEqual(server._step, 1) |
| 77 | + |
| 78 | + # duplicate allow/disallow is fine |
| 79 | + server.allow_checkpoint(2) |
| 80 | + self.assertEqual(server._step, 2) |
| 81 | + |
| 82 | + server.disallow_checkpoint() |
| 83 | + server.disallow_checkpoint() |
| 84 | + self.assertTrue(server._checkpoint_lock.locked()) |
| 85 | + self.assertTrue(server._disallowed) |
| 86 | + |
| 87 | + server.shutdown() |
| 88 | + |
| 89 | + def test_timed_acquire(self) -> None: |
| 90 | + lock = threading.Lock() |
| 91 | + |
| 92 | + with _timed_acquire(lock, timedelta(seconds=10)): |
| 93 | + self.assertTrue(lock.locked()) |
| 94 | + |
| 95 | + self.assertFalse(lock.locked()) |
| 96 | + |
| 97 | + lock.acquire() |
| 98 | + |
| 99 | + with self.assertRaisesRegex( |
| 100 | + TimeoutError, r"timed out acquiring lock after 0.0" |
| 101 | + ): |
| 102 | + with _timed_acquire(lock, timedelta(seconds=0.0)): |
| 103 | + pass |
| 104 | + |
| 105 | + self.assertTrue(lock.locked()) |
0 commit comments