|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import asyncio |
1 | 4 | import concurrent.futures as futures |
2 | 5 | import dataclasses |
3 | 6 | import logging |
|
6 | 9 | from etils import epath |
7 | 10 | import jax |
8 | 11 | import orbax.checkpoint as ocp |
| 12 | +import orbax.checkpoint.future as future |
9 | 13 |
|
10 | 14 | from openpi.shared import array_typing as at |
11 | 15 | import openpi.shared.normalize as _normalize |
@@ -117,18 +121,12 @@ def __call__(self, directory: epath.Path) -> None: ... |
117 | 121 | class CallbackHandler(ocp.AsyncCheckpointHandler): |
118 | 122 | """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring.""" |
119 | 123 |
|
120 | | - def __init__(self): |
121 | | - self._executor = futures.ThreadPoolExecutor(max_workers=1) |
122 | | - |
123 | | - def close(self): |
124 | | - self._executor.shutdown() |
125 | | - |
126 | | - def save(self, directory: epath.Path, args: "CallbackSave"): |
| 124 | + def save(self, directory: epath.Path, args: CallbackSave): |
127 | 125 | if jax.process_index() == 0: |
128 | 126 | args.callback(directory) |
129 | 127 |
|
130 | | - async def async_save(self, directory: epath.Path, args: "CallbackSave") -> list[futures.Future]: |
131 | | - return [self._executor.submit(self.save, directory, args)] |
| 128 | + async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]: |
| 129 | + return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))] |
132 | 130 |
|
133 | 131 | def restore(self, *args, **kwargs): |
134 | 132 | raise NotImplementedError("CallbackHandler does not support restore") |
|
0 commit comments