@@ -85,7 +85,11 @@ def best_step(self) -> Optional[int]:
85
85
86
86
@abc .abstractmethod
87
87
def reload (self ):
88
- """Performs disk reads to ensure internal properties are up to date."""
88
+ """Reloads internal properties.
89
+
90
+ Resets internal cache of checkpoint steps, in case the directory managed
91
+ by this object has been updated externally.
92
+ """
89
93
90
94
@abc .abstractmethod
91
95
def reached_preemption (self , step : int ) -> bool :
@@ -112,186 +116,25 @@ def delete(self, step: int):
112
116
def save (
113
117
self ,
114
118
step : int ,
115
- items : Optional [Union [Any , Mapping [str , Any ]]] = None ,
116
- save_kwargs : Optional [Union [SaveParams , Mapping [str , SaveParams ]]] = None ,
117
- metrics : Optional [PyTree ] = None ,
118
- force : Optional [bool ] = False ,
119
- args : Optional [args_lib .CheckpointArgs ] = None ,
120
- custom_metadata : dict [str , Any ] | None = None ,
119
+ * args ,
120
+ ** kwargs ,
121
121
) -> bool :
122
- """Saves the provided items.
123
-
124
- This method should be called by all hosts - process synchronization and
125
- actions that need to be performed on only one host are managed internally.
126
-
127
- NOTE: The `items` and `save_kwargs` arguments are deprecated, use `args`
128
- instead. Make sure to configure `CheckpointManager` with `item_names`.
129
-
130
- `args` should be a subclass of
131
- `orbax.checkpoint.args.CheckpointArgs`, the specific type of which is used
132
- to indicate what logic is used to save the object. For a typical, PyTree of
133
- arrays, use `StandardSave`/`StandardRestore`.
134
-
135
- When constructing the `CheckpointManager`, if no `item_names` were provided,
136
- it is assumed that we are managing a single object. If `item_names` were
137
- provided, it is assumed that we are managing multiple objects, and `args`
138
- must be `orbax.checkpoint.args.CompositeArgs`. See below for details.
139
-
140
- Example::
141
-
142
- # Single item
143
- mngr = ocp.CheckpointManager(directory)
144
- mngr.save(step, args=ocp.args.StandardSave(my_train_state))
145
-
146
- # Multiple items
147
- mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
148
- mngr.save(step, args=ocp.args.Composite(
149
- state=ocp.args.StandardSave(my_train_state),
150
- meta=ocp.args.JsonSave(my_metadata)
151
- ))
152
-
153
- Args:
154
- step: current step, int
155
- items: a savable object, or a dictionary of object name to savable object.
156
- save_kwargs: save kwargs for a single Checkpointer, or a dictionary of
157
- object name to kwargs needed by the Checkpointer implementation to save
158
- the object.
159
- metrics: a dictionary of metric name (string) to numeric value to be
160
- tracked along with this checkpoint. Required if `options.best_fn` is
161
- set. Allows users to specify a metric value to determine which
162
- checkpoints are best and should be kept (in conjunction with
163
- `options.max_to_keep`).
164
- force: if `True`, this method will attempt to save a checkpoint regardless
165
- of the result of `AbstractCheckpointManager.should_save(step)`. By
166
- default, `save` will only write a checkpoint to disk when the options
167
- permit, e.g. when `step` is in `options.save_interval_steps` or
168
- `options.save_on_steps`. Setting `force=True` will not overwrite
169
- existing checkpoints.
170
- args: `CheckpointArgs` which is used to save checkpointable objects with
171
- the appropriate logic.
172
- custom_metadata: a dictionary of custom metadata to be written to the
173
- checkpoint directory via StepMetadata.
174
-
175
- Returns:
176
- bool indicating whether a save operation was performed.
177
- Raises:
178
- ValueError: if `track_best` was indicated but `metrics` is not provided.
179
- ValueError: directory creation failed.
180
- ValueError: if an item is provided for which no `Checkpointer` is
181
- found.
182
- ValueError: if the checkpoint already exists.
183
- """
122
+ """Saves the given step."""
184
123
185
124
@abc .abstractmethod
186
125
def restore (
187
126
self ,
188
127
step : Optional [int ],
189
- items : Optional [Union [Any , Mapping [str , Any ]]] = None ,
190
- restore_kwargs : Optional [
191
- Union [RestoreParams , Mapping [str , RestoreParams ]]
192
- ] = None ,
193
- directory : Optional [epath .PathLike ] = None ,
194
- args : Optional [args_lib .CheckpointArgs ] = None ,
128
+ * args ,
129
+ ** kwargs ,
195
130
) -> Union [Any , Mapping [str , Any ], args_lib .Composite ]:
196
- """Restores from the given step and provided items.
197
-
198
- This method should be called by all hosts - process synchronization and
199
- actions that need to be performed on only one host are managed internally.
200
-
201
- NOTE: The `items` and `restore_kwargs` arguments are deprecated, use `args`
202
- instead. Make sure to configure `CheckpointManager` with `item_names`.
203
- See `save` docstring for additional details.
204
-
205
- Example::
206
-
207
- # Single item
208
- mngr = ocp.CheckpointManager(directory)
209
- mngr.restore(step, args=ocp.args.StandardRestore(abstract_train_state))
210
-
211
- # Multiple items
212
- mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
213
- mngr.restore(step, args=ocp.args.Composite(
214
- state=ocp.args.StandardRestore(abstract_train_state),
215
- meta=ocp.args.JsonRestore(),
216
- ))
217
- # If it is acceptable to restore without providing additional arguments,
218
- # and if a save has already been performed, it is ok to do the following:
219
- mngr.restore(step, args=ocp.args.Composite(state=None, meta=None))
220
- # If a save has not already been performed, there is no way for Orbax to
221
- # know how to restore the objects. If a save has already been performed,
222
- # it remembers the logic used to save the objects.
223
-
224
- Args:
225
- step: current step, int
226
- items: a restoreable object, or a dictionary of object name to restorable
227
- object.
228
- restore_kwargs: restore kwargs for a single Checkpointer, or a dictionary
229
- of object name to kwargs needed by the Checkpointer implementation to
230
- restore the object.
231
- directory: if provided, uses the given directory rather than the
232
- `directory` property of this class. Can be used to restore checkpoints
233
- from an independent location.
234
- args: `CheckpointArgs` which is used to restore checkpointable objects
235
- with the appropriate logic.
236
-
237
- Returns:
238
- If managing a single item, returns a single checkpointable object.
239
- If managing multiple items, returns ocp.args.Composite, where the keys
240
- are item names, and values are checkpointable objects.
241
- """
131
+ """Restores the given step."""
242
132
243
133
@abc .abstractmethod
244
134
def item_metadata (
245
135
self , step : int
246
136
) -> Union [Any , Mapping [str , Any ], args_lib .Composite ]:
247
- """For all Checkpointers, returns any metadata associated with the item.
248
-
249
- Calls the `metadata` method for each Checkpointer and returns a
250
- mapping of each item name to the restored metadata. If the manager only
251
- manages a single item, a single metadata will be returned instead.
252
-
253
- To avoid errors due to missing CheckpointHandlers, concrete
254
- CheckpointManager constructor must allow mapping from item names to
255
- respective CheckpointHandlers to be input other than via save() and
256
- restore(). Please note that save() and restore() calls automatically
257
- map CheckpointHandlers to respective item names and retain it during the
258
- lifetime of the CheckpointManager instance.
259
-
260
- Example::
261
-
262
- # Single item
263
- mngr = ocp.CheckpointManager(directory)
264
- # No calls to save() or restore() before calling item_metadata().
265
- mngr.item_metadata(step) # Raises error.
266
-
267
- mngr = ocp.CheckpointManager(directory,
268
- item_handlers=ocp.StandardCheckpointHandler)
269
- # No calls to save() or restore() before calling item_metadata().
270
- metadata = mngr.item_metadata(step) # Successful.
271
-
272
- # Multiple items
273
- mngr = ocp.CheckpointManager(directory, item_names=('state', 'extra'))
274
- # No calls to save() or restore() before calling item_metadata().
275
- mngr.item_metadata(step) # Raises error.
276
-
277
- mngr = ocp.CheckpointManager(directory,
278
- item_names=('state', 'extra'),
279
- item_handlers={
280
- 'state': ocp.StandardCheckpointHandler,
281
- 'extra': ocp.PytreeCheckpointHandler,
282
- }
283
- )
284
- # No calls to save() or restore() before calling item_metadata().
285
- metadata = mngr.item_metadata(step) # Successful.
286
-
287
- Metadata may be None for an individual item.
288
-
289
- Args:
290
- step: Step for which to retrieve metadata.
291
-
292
- Returns:
293
- A dictionary mapping name to item metadata, or a single item metadata.
294
- """
137
+ """Returns metadata for all known items."""
295
138
296
139
@abc .abstractmethod
297
140
def metadata (
0 commit comments