Skip to content

Commit ad2d2f6

Browse files
author
Orbax Authors
committed
Validate checkpoint before writing merged metadata
This avoids re-reading the metadata after writing it. PiperOrigin-RevId: 831073963
1 parent 5a0bd47 commit ad2d2f6

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

checkpoint/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Changed
11+
12+
- Validate checkpoints before writing merged OCDBT database using in-memory
13+
state, avoiding additional I/O to re-read metadata.
14+
1015
## [0.11.28] - 2025-11-06
1116

1217
### Added

checkpoint/orbax/checkpoint/_src/serialization/ocdbt_utils.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
3131

3232

3333
async def _validate_params(
34-
directory: epath.Path,
35-
ts_context: ts.Context,
34+
ts_kv_store: ts.KvStore,
3635
use_zarr3: bool,
3736
) -> None:
3837
"""Validates the params present in tensorstore KvStore.
@@ -42,15 +41,9 @@ async def _validate_params(
4241
NOTE: Support for zarr3 will be added later.
4342
4443
Args:
45-
directory: checkpoint location.
46-
ts_context: Tensorstore context.
44+
ts_kv_store: Open kvstore to validate, with transaction if applicable.
4745
use_zarr3: If True, use zarr3 driver, otherwise, use zarr driver.
4846
"""
49-
merged_kvstore_tspec = ts_utils.build_kvstore_tspec(
50-
directory.as_posix(), use_ocdbt=True
51-
)
52-
ts_kv_store = await ts_utils.open_kv_store(merged_kvstore_tspec, ts_context)
53-
5447
# TODO: b/362328389 - Add support for zarr3.
5548
if use_zarr3:
5649
logging.info(
@@ -196,11 +189,11 @@ async def merge_ocdbt_per_process_files(
196189
child.experimental_copy_range_to(parent.with_transaction(txn))
197190
)
198191
await asyncio.gather(*copy_ops)
199-
await txn.commit_async()
200192

201193
# Validate merged params.
202194
if enable_validation:
203-
await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
195+
await _validate_params(parent.with_transaction(txn), use_zarr3=use_zarr3)
196+
await txn.commit_async()
204197

205198

206199
def get_process_index_for_subdir(

checkpoint/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
'jax >= 0.6.0',
3030
'numpy',
3131
'pyyaml',
32-
'tensorstore >= 0.1.71',
32+
'tensorstore >= 0.1.74',
3333
'nest_asyncio',
3434
'aiofiles',
3535
'protobuf',

0 commit comments

Comments
 (0)