Skip to content

Commit 072217e

Browse files
authored
Fix Orbax error (#548)
This is addressing the following exception that can sometimes happen during checkpointing: ``` FileExistsError: [Errno 17] File exists: '..../10000.orbax-checkpoint-tmp-4' ```
2 parents 0992224 + 4fce415 commit 072217e

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

scripts/train_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_train(tmp_path: pathlib.Path, config_name: str):
1616
config = dataclasses.replace(
1717
_config._CONFIGS_DICT[config_name], # noqa: SLF001
1818
batch_size=2,
19-
checkpoint_base_dir=tmp_path / "checkpoint",
19+
checkpoint_base_dir=str(tmp_path / "checkpoint"),
2020
exp_name="test",
2121
overwrite=False,
2222
resume=False,

src/openpi/training/checkpoints.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
import asyncio
14
import concurrent.futures as futures
25
import dataclasses
36
import logging
@@ -6,6 +9,7 @@
69
from etils import epath
710
import jax
811
import orbax.checkpoint as ocp
12+
import orbax.checkpoint.future as future
913

1014
from openpi.shared import array_typing as at
1115
import openpi.shared.normalize as _normalize
@@ -117,18 +121,12 @@ def __call__(self, directory: epath.Path) -> None: ...
117121
class CallbackHandler(ocp.AsyncCheckpointHandler):
118122
"""A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""
119123

120-
def __init__(self):
121-
self._executor = futures.ThreadPoolExecutor(max_workers=1)
122-
123-
def close(self):
124-
self._executor.shutdown()
125-
126-
def save(self, directory: epath.Path, args: "CallbackSave"):
124+
def save(self, directory: epath.Path, args: CallbackSave):
127125
if jax.process_index() == 0:
128126
args.callback(directory)
129127

130-
async def async_save(self, directory: epath.Path, args: "CallbackSave") -> list[futures.Future]:
131-
return [self._executor.submit(self.save, directory, args)]
128+
async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
129+
return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]
132130

133131
def restore(self, *args, **kwargs):
134132
raise NotImplementedError("CallbackHandler does not support restore")

0 commit comments

Comments
 (0)