Skip to content

Commit 5a5de43

Browse files
author
Orbax Authors
committed
#p2p Cleanup P2P restore directories after use
PiperOrigin-RevId: 866984150
1 parent 9cd8fc6 commit 5a5de43

File tree

4 files changed

+72
-11
lines changed

4 files changed

+72
-11
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Composite Checkpoint Manager handling P2P syncing with optional Persistent Fallback."""
1616

17+
import shutil
1718
import threading
1819
import time
1920
from typing import Any, Iterable, Mapping, Optional, Sequence, Union, final
@@ -366,15 +367,27 @@ def _restore_from_local_or_p2p(
366367
return self._local_manager.restore(step, args=args)
367368
else:
368369
logging.info('Step %d not found locally, fetching from P2P.', step)
369-
if self._p2p.fetch(step):
370-
p2p_restore_dir = self._local_directory / constants.P2P_RESTORE_DIR_NAME
371-
return self._local_manager.restore(
372-
step, args=args, directory=p2p_restore_dir
373-
)
374-
else:
375-
raise FileNotFoundError(
376-
f'Failed to fetch step {step} from P2P network.'
377-
)
370+
p2p_restore_dir = self._local_directory / constants.P2P_RESTORE_DIR_NAME
371+
try:
372+
if self._p2p.fetch(step):
373+
return self._local_manager.restore(
374+
step, args=args, directory=p2p_restore_dir
375+
)
376+
else:
377+
raise FileNotFoundError(
378+
f'Failed to fetch step {step} from P2P network.'
379+
)
380+
finally:
381+
if p2p_restore_dir.exists():
382+
logging.info(
383+
'Removing P2P restore directory: %s after restoration is'
384+
' complete',
385+
p2p_restore_dir,
386+
)
387+
try:
388+
shutil.rmtree(str(p2p_restore_dir))
389+
except OSError as e:
390+
logging.exception('Failed to remove P2P restore directory: %s', e)
378391

379392
@override
380393
def restore(

checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,38 @@ def test_restore_no_step_in_p2p_but_in_persistent(self, _):
402402

403403
manager.close()
404404

405+
@mock.patch.object(p2p_cm.shutil, 'rmtree', autospec=True)
406+
@mock.patch.object(multihost, 'process_index', return_value=0)
407+
def test_restore_p2p_cleanup(self, unused_process_index, mock_rmtree):
408+
"""Tests that P2P restore directory is cleaned up after restore."""
409+
self.local_manager_instance.scan_stored_steps.return_value = (0, [])
410+
self.local_manager_instance.all_steps.return_value = []
411+
self.mock_sync_global_data.return_value = []
412+
413+
# P2P fetch succeeds
414+
self.peer_selector_instance.get_source_peer.return_value = (
415+
protocol.PeerDiscoveryInfo(
416+
ip='1.2.3.4', port=5678, process_index=1, steps=[1]
417+
)
418+
)
419+
self.p2p_node_instance.fetch_shard_from_peer.return_value = True
420+
self.local_manager_instance.restore.return_value = {'a': 1}
421+
422+
manager = p2p_cm.CheckpointManager(
423+
self.mesh,
424+
self.abstract_state,
425+
self.local_dir,
426+
)
427+
428+
# Make p2p_restore_dir exist so cleanup is triggered
429+
p2p_restore_dir = self.local_dir / service.constants.P2P_RESTORE_DIR_NAME
430+
p2p_restore_dir.mkdir()
431+
432+
manager.restore(1, args=self.restore_args)
433+
434+
mock_rmtree.assert_called_once_with(str(p2p_restore_dir))
435+
manager.close()
436+
405437

406438
if __name__ == '__main__':
407439
absltest.main()

checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,6 @@ def fetch_shard_from_peer(
256256
)
257257
return False
258258

259-
# TODO(exlin): Remove this directory once the transfer is globally completed
260-
# to save memory space.
261259
stage_dir = self.directory / f'stage_{step}_{stored_process_index}'
262260
if stage_dir.exists():
263261
shutil.rmtree(str(stage_dir))

checkpoint/orbax/checkpoint/experimental/emergency/p2p/service_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,24 @@ def download_side_effect(unused_ip, unused_port, rel_path, dest_path):
333333
mock_move.assert_called_once_with(str(stage_dir / '1'), str(final_dir))
334334
mock_rmtree.assert_called_with(str(stage_dir), ignore_errors=True)
335335

336+
@mock.patch.object(service.shutil, 'rmtree', autospec=True)
337+
@mock.patch.object(service.protocol.TCPClient, 'request', autospec=True)
338+
@mock.patch.object(service.protocol.TCPClient, 'download', autospec=True)
339+
def test_fetch_shard_from_peer_exception_cleanup(
340+
self,
341+
mock_download,
342+
mock_request,
343+
mock_rmtree,
344+
):
345+
"""Tests that stage_dir is cleaned up if an exception occurs."""
346+
mock_request.return_value = [{'rel_path': '1/file1', 'size': 10}]
347+
mock_download.side_effect = OSError('Download failed')
348+
349+
self.assertFalse(self.node.fetch_shard_from_peer('peer', 123, 1, 10))
350+
351+
stage_dir = self.temp_dir / 'stage_1_10'
352+
mock_rmtree.assert_called_with(str(stage_dir), ignore_errors=True)
353+
336354

337355
if __name__ == '__main__':
338356
absltest.main()

0 commit comments

Comments
 (0)