
_validate_params fails on zero-sized arrays #1309

Open
@hrbigelow

Description


Hi,

@niketkumar @cpgaffney1,

cc @dionhaefner

The following script attempts to save a zero-sized array, but the save fails validation in `_validate_params`.

I believe the problem is that `_validate_params` expects every `foo/.zarray` metadata entry to have a matching data entry `foo/0`. However, this code produces only the TensorStore entries `a/0`, `a/.zarray`, and `z/.zarray`. There is no `z/0`, because the `z` tensor contains no data.
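To make the suspected mismatch concrete, here is a minimal, hypothetical re-creation of the check as I understand it. It is not Orbax's actual `_validate_params` code: every `<param>/.zarray` key counts as a saved param, and a param is reported missing unless a `<param>/0` key also exists. The names `declared_params` and `find_missing_params` are made up for illustration. Run on the three keys listed above, it flags `z`, matching the real error.

```python
# Hypothetical re-creation of the check described above; NOT Orbax's code.


def declared_params(keys):
    """Params that have a '<param>/.zarray' metadata key."""
    suffix = "/.zarray"
    return sorted(k[: -len(suffix)] for k in keys if k.endswith(suffix))


def find_missing_params(keys):
    """Params with a .zarray entry but no '<param>/0' data entry."""
    key_set = set(keys)
    return [p for p in declared_params(key_set) if f"{p}/0" not in key_set]


# Keys reported in this issue for target = {'a': ..., 'z': jnp.zeros((0,))}.
keys = ["a/0", "a/.zarray", "z/.zarray"]
params = declared_params(keys)
missing = find_missing_params(keys)
print(f"{len(missing)}/{len(params)} params are missing: {', '.join(missing)}")
# prints: 1/2 params are missing: z
```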

I'm not sure whether TensorStore is supposed to write a `z/0` entry for a zero-sized array, or what the intended behavior is.
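If TensorStore follows the zarr v2 chunk layout here (an assumption on my part), a missing `z/0` would be expected rather than a write failure. The chunk grid has ceil(shape / chunk size) chunks along each dimension, so a dimension of length 0 has no chunks and therefore no chunk keys. The sketch below computes the chunk keys a zarr v2 array would have if every chunk in the grid were written, plus a hypothetical metadata-aware check that only flags a param when it should have at least one chunk. The helper names and the `chunks` values are illustrative assumptions, not what Orbax or TensorStore actually use.

```python
import itertools
import math


def expected_chunk_keys(shape, chunks, separator="."):
    """Chunk keys a zarr v2 array would have if every chunk were written.

    Assumes rank >= 1 and positive chunk sizes (illustrative only).
    """
    # Number of chunks along each dimension: ceil(shape / chunk size).
    grid = [math.ceil(n / c) for n, c in zip(shape, chunks)]
    # One key per point in the chunk grid, e.g. "0", "1" or "0.1".
    return [separator.join(map(str, idx))
            for idx in itertools.product(*(range(g) for g in grid))]


def find_missing_params_by_metadata(keys, metadata):
    """Hypothetical check: flag a param only if it should have chunks but has none."""
    key_set = set(keys)
    missing = []
    for name, meta in sorted(metadata.items()):
        expected = expected_chunk_keys(meta["shape"], meta["chunks"])
        if expected and not any(f"{name}/{k}" in key_set for k in expected):
            missing.append(name)
    return missing


# Assumed metadata; the real .zarray files may use different chunk sizes.
metadata = {"a": {"shape": [3], "chunks": [3]}, "z": {"shape": [0], "chunks": [1]}}
print(expected_chunk_keys([3], [3]))  # ['0']
print(expected_chunk_keys([0], [1]))  # []
print(find_missing_params_by_metadata(["a/0", "a/.zarray", "z/.zarray"], metadata))  # []
```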

Any insight would be greatly appreciated!

```python
import jax.numpy as jnp
import jax.tree_util as jtu
import tempfile
import orbax.checkpoint as ocp

target = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),  # zero-sized array that triggers the failure
}

orbax_checkpointer = ocp.Checkpointer(
  ocp.PyTreeCheckpointHandler()
)

with tempfile.TemporaryDirectory() as ckpt_path:
  overwrite = True
  save_args = jtu.tree_map(lambda _: ocp.SaveArgs(), target)
  orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
```
```
(jax_env) henry@henry-gs65:orbax$ python flax4309.py 
Traceback (most recent call last):
  File "/home/henry/ai/projects/orbax/flax4309.py", line 18, in <module>
    orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 1004, in finalize
    self._handler_impl.finalize(directory)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 806, in finalize
    asyncio_utils.run_sync(
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 704, in merge_ocdbt_per_process_files
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 625, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/tmp/tmpbxi1zpec.orbax-checkpoint-tmp-0/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '3ef941407cca4f778414e9e92b15dedb',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).
```
