diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/p2p_checkpoint_manager_benchmark.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/p2p_checkpoint_manager_benchmark.yaml new file mode 100644 index 000000000..1dc28bfcb --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/p2p_checkpoint_manager_benchmark.yaml @@ -0,0 +1,52 @@ +# The name for the entire test suite run. +suite_name: "P2P CheckpointManager Benchmark" + +mesh_configs: + - mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"] + # ICI: Within a slice. Assuming 8 devices per slice. + # DCN: Across slices. + ici_parallelism: {"fsdp": 1, "tensor": 1, "data": 1} + dcn_parallelism: {"data": 1} # num_slices on the axis at replica_axis_index + process_is_granule: true + - mesh_axes: ["data", "model", "tensor", "fsdp"] + ici_parallelism: {"data": 1, "model": 1} + dcn_parallelism: {"data": 4, "model": 1} + - mesh_axes: ["data", "model", "tensor", "fsdp"] + ici_parallelism: {"data": 1, "model": 16} + dcn_parallelism: {"data": 2, "model": 1} + allow_split_physical_axes: true + - mesh_axes: ["data", "model", "tensor", "fsdp"] + ici_parallelism: {"data": 2, "model": 8} + dcn_parallelism: {"data": 2, "model": 1} + allow_split_physical_axes: true + - mesh_axes: ["data", "model", "tensor", "fsdp"] + ici_parallelism: {"data": 2, "model": 4} + dcn_parallelism: {"data": 2, "model": 1} + allow_split_physical_axes: true + +checkpoint_config: + spec: + a_1d: {dtype: "float32", shape: [32], sharding: [null]} + b_1d: {dtype: "float32", shape: [32], sharding: ["tensor"]} + c_2d: {dtype: "float32", shape: [32, 32], sharding: [null, "tensor"]} + d_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", null]} + e_2d: {dtype: "float32", shape: [32, 32], sharding: ["tensor", "fsdp"]} + f_2d: {dtype: "float32", shape: [32, 32], sharding: ["fsdp", "tensor"]} + g_2d: {dtype: "float32", shape: [32, 32], sharding: 
[null, null]} + h_3d: {dtype: "float32", shape: [32, 32, 32], sharding: ["tensor", null, "fsdp"]} + i_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "tensor"]} + j_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, "fsdp"]} + k_3d: {dtype: "float32", shape: [32, 32, 32], sharding: [null, null, null]} + custom_array: {dtype: "float32", shape: [8192, 64], sharding: ["tensor", null]} + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark" + options: + persistent_save_interval_steps: [2] + persistent_max_to_keep: [5] + local_save_interval_steps: [2] + local_max_to_keep: 2 + replica_axis_index: 0 + train_steps: 5 + experimental_orbax_use_distributed_process_id: true + experimental_use_distributed_id_for_mesh_consistency: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py index fb8a6c465..dd35236fa 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py @@ -159,6 +159,11 @@ def run(self, repeat_index: int | None = None) -> TestResult: path = directory_setup.setup_test_directory( self.name, self.output_dir, repeat_index ) + local_path = None + if self.local_directory is not None: + local_path = epath.Path(self.local_directory) / name + if repeat_index is not None: + local_path = local_path / f"repeat_{repeat_index}" with benchmark_metrics.measure( "sync_global_processes:benchmark:setup_test_directory" @@ -185,7 +190,7 @@ def run(self, repeat_index: int | None = None) -> TestResult: options=self.options, mesh=self.mesh, repeat_index=repeat_index, - local_path=self.local_directory, + local_path=local_path, ) test_context_summary = self._build_test_context_summary(context) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py 
b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py index f02c6406b..62fbdebbc 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py @@ -40,8 +40,8 @@ def setup_test_directory( path = path / f"repeat_{repeat_index}" logging.info("Setting up test directory at: %s", path) if jax.process_index() == 0: - if path.exists(): + if path.exists() and not base_path.startswith("gs://"): logging.warning("Test directory %s already exists. Deleting it.", path) path.rmtree() - path.mkdir(parents=True, exist_ok=False) + path.mkdir(parents=True, exist_ok=True) return path diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/p2p_checkpoint_manager_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/p2p_checkpoint_manager_benchmark.py new file mode 100644 index 000000000..1e4748723 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/p2p_checkpoint_manager_benchmark.py @@ -0,0 +1,345 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for P2P CheckpointManager. + +This module contains benchmarks for +orbax.checkpoint.experimental.emergency.p2p.checkpoint_manager.CheckpointManager. 
+""" + +from collections.abc import Sequence +import dataclasses +import inspect +from typing import Any + +from absl import logging +from etils import epath +import jax +from orbax.checkpoint import checkpoint_utils +from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint._src.testing.benchmarks.core import mesh_utils +from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib +from orbax.checkpoint._src.testing.benchmarks.core import pytree_utils +from orbax.checkpoint._src.tree import utils +from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib +from orbax.checkpoint.experimental.emergency.p2p import checkpoint_manager as p2p_checkpoint_manager +from orbax.checkpoint.experimental.emergency.p2p import options as p2p_options + + +# ============================================================================== +# 1. Define the Options Dataclass +# ============================================================================== +@dataclasses.dataclass(frozen=True) +class P2pBenchmarkOptions(benchmarks_core.BenchmarkOptions): + """Configuration options for benchmarks targeting P2P CheckpointManager. + + Attributes: + persistent_save_interval_steps: The interval at which persistent checkpoints + should be saved. + persistent_max_to_keep: The maximum number of persistent checkpoints to + keep. + local_save_interval_steps: The interval at which local checkpoints should be + saved. + local_max_to_keep: The maximum number of local checkpoints to keep. + replica_axis_index: The index of the replica axis in the global mesh. + train_steps: The number of training steps to run. 
+ """ + + persistent_save_interval_steps: int | Sequence[int] = 5 + persistent_max_to_keep: int | Sequence[int] = 5 + local_save_interval_steps: int | Sequence[int] = 2 + local_max_to_keep: int | Sequence[int] = 2 + replica_axis_index: int | Sequence[int] = 0 + train_steps: int | Sequence[int] = 10 + experimental_use_distributed_id_for_mesh_consistency: ( + bool | Sequence[bool] + ) = True  # Neither experimental_* flag is read in this module. + experimental_orbax_use_distributed_process_id: bool | Sequence[bool] = True + + +# ============================================================================== +# 2. Implement the Benchmark Generator +# ============================================================================== +def _create_checkpoint_manager( + local_directory: epath.Path, + persistent_directory: epath.Path, + global_mesh: jax.sharding.Mesh, + abstract_state: Any, + options: P2pBenchmarkOptions, +) -> p2p_checkpoint_manager.CheckpointManager: + """Creates a P2P CheckpointManager configured from `options`.""" + return p2p_checkpoint_manager.CheckpointManager( + local_directory=local_directory, + persistent_directory=persistent_directory, + global_mesh=global_mesh, + abstract_state=abstract_state, + options=p2p_options.CheckpointManagerOptions( + local=p2p_options.LocalCheckpointOptions( + save_interval_steps=options.local_save_interval_steps, + max_to_keep=options.local_max_to_keep, + ), + persistent=p2p_options.PersistentCheckpointOptions( + save_interval_steps=options.persistent_save_interval_steps, + max_to_keep=options.persistent_max_to_keep, + ), + replica_axis_index=options.replica_axis_index, + ), + ) + + +def _delete_checkpoints( + manager: p2p_checkpoint_manager.CheckpointManager, + step: int, + local_directory: epath.Path, + delete_before_restore: str = 'local_p0', +): + """Removes the local `step` dir as selected by `delete_before_restore`.""" + step_dir = local_directory / str(step) + if delete_before_restore == 'local_p0': + if multihost.process_index() == 0 and step_dir.exists(): + logging.info( + 'Process 0: removing local checkpoint to 
trigger P2P restore.' + ) + step_dir.rmtree() + manager.reload() + elif delete_before_restore == 'local_all': + if step_dir.exists(): + logging.info( + 'All processes: removing local checkpoint to trigger GCS restore.' + ) + step_dir.rmtree() + manager.reload() + elif delete_before_restore == 'none': + logging.info('Skipping deletion of local checkpoint for local restore.') + else: + raise ValueError( + f'Invalid delete_before_restore: {delete_before_restore}' + ) + + +def _restore_and_validate( + manager: p2p_checkpoint_manager.CheckpointManager, + metrics: metric_lib.Metrics, + pytree: Any, + abstract_pytree: Any, + step: int, + restore_args: Any, + test_name: str = '', +): + """Restores `step`, asserts it equals `pytree`, then reloads `manager`.""" + prefix = f'{test_name}_' if test_name else '' + # Wait for save to complete on all hosts. + with metrics.measure(f'{prefix}sync_global_processes_{step}'): + multihost.sync_global_processes(f'{prefix}save_completed_{step}') + + with metrics.measure(f'{prefix}restore_{step}'): + restored = manager.restore( + step, + args=p2p_args_lib.Composite( + state=pytree_checkpoint_handler.PyTreeRestoreArgs( + restore_args=restore_args, + item=abstract_pytree, + ) + ), + )['state'] + logging.info('Assert Restored Pytree') + pytree_utils.assert_pytree_equal(pytree, restored) + with metrics.measure(f'{prefix}reload_after_restore_{step}'): + manager.reload() + + +@benchmarks_core.benchmark_options(P2pBenchmarkOptions) +class P2pCheckpointManagerBenchmark(benchmarks_core.BenchmarksGenerator): + """A generator for benchmarking P2P CheckpointManager.""" + + def _run_test( + self, + test_name: str, + context: benchmarks_core.TestContext, + metrics: metric_lib.Metrics, + abstract_pytree: Any, + restore_args: Any, + delete_before_restore: str = 'local_p0', + ): + """Runs one save/restore scenario, timing each phase in `metrics`.""" + logging.info('Running test: %s', test_name) + pytree = context.pytree + persistent_directory = context.path / test_name / 'persistent_p2p_ckpt' + if context.local_path is not 
None: + local_path = epath.Path(context.local_path) / test_name / 'local_p2p_ckpt' + local_directory = epath.Path(local_path) + else: + local_directory = ( + context.path + / test_name + / 'local_p2p_ckpt' + / f'process_{multihost.process_index()}' + ) + options = context.options + mesh = context.mesh + assert isinstance(options, P2pBenchmarkOptions) + + with metrics.measure(f'{test_name}_create_directories'): + if jax.process_index() == 0: + persistent_directory.mkdir(parents=True, exist_ok=True) + local_directory.mkdir(parents=True, exist_ok=True) + multihost.sync_global_processes(f'{test_name}_create_directories') + + with metrics.measure(f'{test_name}_create_checkpoint_manager'): + manager = _create_checkpoint_manager( + local_directory=local_directory, + persistent_directory=persistent_directory, + global_mesh=mesh, + abstract_state=abstract_pytree, + options=options, + ) + + step = manager.latest_step() + if step is not None: + logging.info('Latest step in test %s: %d', test_name, step) + + with metrics.measure(f'{test_name}_restore_and_validate_{step}'): + _restore_and_validate( + manager, + metrics, + pytree, + abstract_pytree, + step, + restore_args, + test_name=test_name, + ) + + start_step = step + 1 if step is not None else 0 + with metrics.measure(f'{test_name}_train_loop'): + for step in range(start_step, options.train_steps): + logging.info('Test %s: Training step %d', test_name, step) + with metrics.measure(f'{test_name}_save_{step}'): + manager.save( + step, + args=p2p_args_lib.Composite( + state=pytree_checkpoint_handler.PyTreeSaveArgs(pytree) + ), + ) + with metrics.measure(f'{test_name}_wait_until_finished_{step}'): + manager.wait_until_finished() + + if step % options.local_save_interval_steps == 0 and step != 0: + with metrics.measure(f'{test_name}_restore_and_validate_{step}'): + _delete_checkpoints( + manager, + step, + local_directory, + delete_before_restore=delete_before_restore, + ) + _restore_and_validate( + manager, + metrics, + 
+ pytree, + abstract_pytree, + step, + restore_args, + test_name=test_name, + ) + + manager.close() + + def test_local_restore( + self, + context: benchmarks_core.TestContext, + metrics: metric_lib.Metrics, + abstract_pytree: Any, + restore_args: Any, + ): + self._run_test( + 'test_local_restore', + context, + metrics, + abstract_pytree, + restore_args, + delete_before_restore='none', + ) + + def test_p2p_restore( + self, + context: benchmarks_core.TestContext, + metrics: metric_lib.Metrics, + abstract_pytree: Any, + restore_args: Any, + ): + self._run_test( + 'test_p2p_restore', + context, + metrics, + abstract_pytree, + restore_args, + delete_before_restore='local_p0', + ) + + def test_gcs_restore( + self, + context: benchmarks_core.TestContext, + metrics: metric_lib.Metrics, + abstract_pytree: Any, + restore_args: Any, + ): + self._run_test( + 'test_gcs_restore', + context, + metrics, + abstract_pytree, + restore_args, + delete_before_restore='local_all', + ) + + def test_fn( + self, context: benchmarks_core.TestContext + ) -> benchmarks_core.TestResult: + """Runs every `test_*` method except `test_fn`; returns the metrics.""" + metrics = metric_lib.Metrics() + pytree = context.pytree + options = context.options + mesh = context.mesh + assert isinstance(options, P2pBenchmarkOptions) + + if mesh is None: + raise ValueError( + 'Mesh must be provided for P2pCheckpointManagerBenchmark' + ) + if not multihost.is_runtime_to_distributed_ids_initialized(): + multihost.initialize_runtime_to_distributed_ids() + + if not multihost.is_distributed_to_device_ids_initialized(): + multihost.initialize_distributed_to_device_ids() + + mesh_utils.pretty_log_mesh('Global Mesh: ', mesh) + + with metrics.measure('create_abstract_pytree'): + abstract_pytree = jax.tree.map(utils.to_shape_dtype_struct, pytree) + logging.info('abstract_pytree: %r', abstract_pytree) + + with metrics.measure('create_restore_args'): + restore_args = checkpoint_utils.construct_restore_args(abstract_pytree) 
logging.info('restore_args: %r', restore_args) + + tests_to_run = [] + for name, method in inspect.getmembers(self, predicate=inspect.ismethod): + if name.startswith('test_') and name != 'test_fn': + tests_to_run.append(method) + + for test in tests_to_run: + test(context, metrics, abstract_pytree, restore_args) + + return benchmarks_core.TestResult(metrics=metrics) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/p2p_checkpoint_manager_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/p2p_checkpoint_manager_benchmark_test.py new file mode 100644 index 000000000..0274932cf --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/p2p_checkpoint_manager_benchmark_test.py @@ -0,0 +1,466 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.testing.benchmarks import p2p_checkpoint_manager_benchmark +from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint._src.testing.benchmarks.core import mesh_utils +from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib +from orbax.checkpoint.experimental.emergency.p2p import checkpoint_manager as p2p_checkpoint_manager +from orbax.checkpoint.experimental.emergency.p2p import options as p2p_options + + +P2pBenchmarkOptions = p2p_checkpoint_manager_benchmark.P2pBenchmarkOptions +P2pCheckpointManagerBenchmark = ( + p2p_checkpoint_manager_benchmark.P2pCheckpointManagerBenchmark +) + + +class P2pCheckpointManagerBenchmarkTest(parameterized.TestCase): + + def test_test_fn_runs_benchmark_and_saves_metrics(self): + mock_checkpoint_manager_cls = self.enter_context( + mock.patch.object( + p2p_checkpoint_manager, 'CheckpointManager', autospec=True + ) + ) + mock_sync_global_processes = self.enter_context( + mock.patch.object(multihost, 'sync_global_processes', autospec=True) + ) + mock_is_runtime_to_distributed_ids_initialized = self.enter_context( + mock.patch.object( + multihost, + 'is_runtime_to_distributed_ids_initialized', + autospec=True, + ) + ) + mock_initialize_runtime_to_distributed_ids = self.enter_context( + mock.patch.object( + multihost, 'initialize_runtime_to_distributed_ids', autospec=True + ) + ) + mock_is_distributed_to_device_ids_initialized = self.enter_context( + mock.patch.object( + multihost, 'is_distributed_to_device_ids_initialized', autospec=True + ) + ) + mock_initialize_distributed_to_device_ids = self.enter_context( + mock.patch.object( 
+ multihost, 'initialize_distributed_to_device_ids', autospec=True + ) + ) + mock_pretty_log_mesh = self.enter_context( + mock.patch.object(mesh_utils, 'pretty_log_mesh', autospec=True) + ) + self.enter_context( + mock.patch.object(multihost, 'process_index', return_value=0) + ) + benchmark = P2pCheckpointManagerBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=P2pBenchmarkOptions(), + ) + mesh_shape = (jax.device_count(), 1) + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(mesh_shape), ('data', 'model') + ) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('data', 'model') + ) + pytree = { + 'a': jax.device_put(np.arange(16).reshape((4, 4)), sharding), + } + mock_checkpoint_manager = mock_checkpoint_manager_cls.return_value + mock_checkpoint_manager.restore.return_value = {'state': pytree} + mock_checkpoint_manager.latest_step.return_value = None + test_dir = os.path.join(self.create_tempdir().full_path, 'test_test_fn') + os.makedirs(test_dir, exist_ok=True) + context = benchmarks_core.TestContext( + pytree=pytree, + path=epath.Path(test_dir), + options=P2pBenchmarkOptions(train_steps=2, local_save_interval_steps=1), + mesh=mesh, + ) + mock_is_runtime_to_distributed_ids_initialized.return_value = False + mock_is_distributed_to_device_ids_initialized.return_value = False + + td = epath.Path(test_dir) + (td / 'test_local_restore' / 'local_p2p_ckpt' / 'process_0' / '1').mkdir( + parents=True, exist_ok=True + ) + (td / 'test_p2p_restore' / 'local_p2p_ckpt' / 'process_0' / '1').mkdir( + parents=True, exist_ok=True + ) + (td / 'test_gcs_restore' / 'local_p2p_ckpt' / 'process_0' / '1').mkdir( + parents=True, exist_ok=True + ) + + result = benchmark.test_fn(context) + + with self.subTest('multihost initialization'): + mock_is_runtime_to_distributed_ids_initialized.assert_called_once() + mock_initialize_runtime_to_distributed_ids.assert_called_once() + 
mock_is_distributed_to_device_ids_initialized.assert_called_once() + mock_initialize_distributed_to_device_ids.assert_called_once() + with self.subTest('sync and mesh calls'): + mock_sync_global_processes.assert_called() + mock_pretty_log_mesh.assert_called_once() + with self.subTest('benchmark result type'): + self.assertIsInstance(result, benchmarks_core.TestResult) + with self.subTest('metrics timings'): + self.assertIn( + 'create_abstract_pytree_time_duration', result.metrics.results + ) + self.assertIn('create_restore_args_time_duration', result.metrics.results) + self.assertIn( + 'test_local_restore_create_directories_time_duration', + result.metrics.results, + ) + self.assertIn( + 'test_local_restore_create_checkpoint_manager_time_duration', + result.metrics.results, + ) + self.assertIn( + 'test_local_restore_train_loop_time_duration', result.metrics.results + ) + self.assertIn( + 'test_local_restore_save_0_time_duration', result.metrics.results + ) + self.assertIn( + 'test_local_restore_wait_until_finished_0_time_duration', + result.metrics.results, + ) + self.assertIn( + 'test_local_restore_save_1_time_duration', result.metrics.results + ) + self.assertIn( + 'test_local_restore_wait_until_finished_1_time_duration', + result.metrics.results, + ) + self.assertIn( + 'test_local_restore_restore_and_validate_1_time_duration', + result.metrics.results, + ) + with self.subTest('checkpoint manager calls'): + num_tests = 3 # local, p2p, gcs + self.assertEqual( + mock_checkpoint_manager_cls.call_count, + num_tests, + ) + self.assertEqual(mock_checkpoint_manager.save.call_count, num_tests * 2) + self.assertEqual( + mock_checkpoint_manager.wait_until_finished.call_count, num_tests * 2 + ) + self.assertEqual(mock_checkpoint_manager.restore.call_count, num_tests) + self.assertEqual(mock_checkpoint_manager.close.call_count, num_tests) + # p2p and gcs restore cause dir deletion, so +1 reload each. + # 1 reload after restore for each of 3 tests. + # so 1*3 + 2 = 5 reloads. 
+ self.assertEqual(mock_checkpoint_manager.reload.call_count, 5) + + def test_generate_benchmarks_creates_multiple_benchmark_configs(self): + options = P2pBenchmarkOptions( + persistent_save_interval_steps=[5, 10], + local_save_interval_steps=[2, 4], + replica_axis_index=[0, 1], + ) + benchmark = P2pCheckpointManagerBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=options, + ) + benchmarks = benchmark.generate() + + self.assertLen(benchmarks, 8) + for b in benchmarks: + self.assertIsInstance(b.options, P2pBenchmarkOptions) + + @parameterized.parameters( + dict( + options=P2pBenchmarkOptions( + persistent_save_interval_steps=10, + persistent_max_to_keep=2, + local_save_interval_steps=3, + local_max_to_keep=3, + replica_axis_index=1, + train_steps=5, + ) + ), + dict( + options=P2pBenchmarkOptions( + train_steps=1, local_save_interval_steps=1 + ) + ), + ) + def test_test_fn_applies_benchmark_options_correctly(self, options): + mock_checkpoint_manager_cls = self.enter_context( + mock.patch.object( + p2p_checkpoint_manager, 'CheckpointManager', autospec=True + ) + ) + mock_checkpoint_manager_options_cls = self.enter_context( + mock.patch.object( + p2p_options, + 'CheckpointManagerOptions', + autospec=True, + ) + ) + mock_sync_global_processes = self.enter_context( + mock.patch.object(multihost, 'sync_global_processes', autospec=True) + ) + mock_is_runtime_to_distributed_ids_initialized = self.enter_context( + mock.patch.object( + multihost, + 'is_runtime_to_distributed_ids_initialized', + autospec=True, + ) + ) + mock_initialize_runtime_to_distributed_ids = self.enter_context( + mock.patch.object( + multihost, 'initialize_runtime_to_distributed_ids', autospec=True + ) + ) + mock_is_distributed_to_device_ids_initialized = self.enter_context( + mock.patch.object( + multihost, 'is_distributed_to_device_ids_initialized', autospec=True + ) + ) + mock_initialize_distributed_to_device_ids = self.enter_context( + mock.patch.object( + 
multihost, 'initialize_distributed_to_device_ids', autospec=True + ) + ) + mock_pretty_log_mesh = self.enter_context( + mock.patch.object(mesh_utils, 'pretty_log_mesh', autospec=True) + ) + self.benchmark = P2pCheckpointManagerBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=options, + ) + mesh_shape = (jax.device_count(), 1) + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(mesh_shape), ('data', 'model') + ) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('data', 'model') + ) + pytree = { + 'a': jax.device_put(np.arange(16).reshape((4, 4)), sharding), + } + mock_checkpoint_manager = mock_checkpoint_manager_cls.return_value + mock_checkpoint_manager.restore.return_value = {'state': pytree} + mock_checkpoint_manager.latest_step.return_value = None + test_dir = os.path.join(self.create_tempdir().full_path, 'test_test_fn') + os.makedirs(test_dir, exist_ok=True) + context = benchmarks_core.TestContext( + pytree=pytree, + path=epath.Path(test_dir), + options=options, + mesh=mesh, + ) + mock_is_runtime_to_distributed_ids_initialized.return_value = False + mock_is_distributed_to_device_ids_initialized.return_value = False + + self.benchmark.test_fn(context) + + with self.subTest('multihost initialization'): + mock_is_runtime_to_distributed_ids_initialized.assert_called_once() + mock_initialize_runtime_to_distributed_ids.assert_called_once() + mock_is_distributed_to_device_ids_initialized.assert_called_once() + mock_initialize_distributed_to_device_ids.assert_called_once() + with self.subTest('options propagation'): + self.assertEqual(mock_checkpoint_manager_cls.call_count, 3) + mock_checkpoint_manager_options_cls.assert_called_with( + local=mock.ANY, + persistent=mock.ANY, + replica_axis_index=options.replica_axis_index, + ) + local_options = mock_checkpoint_manager_options_cls.call_args[1]['local'] + self.assertEqual( + local_options.save_interval_steps, options.local_save_interval_steps + ) 
+ self.assertEqual(local_options.max_to_keep, options.local_max_to_keep) + persistent_options = mock_checkpoint_manager_options_cls.call_args[1][ + 'persistent' + ] + self.assertEqual( + persistent_options.save_interval_steps, + options.persistent_save_interval_steps, + ) + self.assertEqual( + persistent_options.max_to_keep, options.persistent_max_to_keep + ) + with self.subTest('mesh setup calls'): + mock_sync_global_processes.assert_called() + mock_pretty_log_mesh.assert_called_once() + + +class HelperFunctionsTest(parameterized.TestCase): + + @mock.patch.object(p2p_checkpoint_manager, 'CheckpointManager', autospec=True) + @mock.patch.object(p2p_options, 'CheckpointManagerOptions', autospec=True) + def test_create_checkpoint_manager( + self, + mock_checkpoint_manager_options_cls, + mock_checkpoint_manager_cls, + ): + local_dir = epath.Path('/tmp/local') + persistent_dir = epath.Path('/tmp/persistent') + mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',)) + abstract_state = {'a': jax.ShapeDtypeStruct(shape=(4,), dtype=np.int32)} + options = P2pBenchmarkOptions( + local_save_interval_steps=2, + local_max_to_keep=3, + persistent_save_interval_steps=5, + persistent_max_to_keep=6, + replica_axis_index=1, + ) + + p2p_checkpoint_manager_benchmark._create_checkpoint_manager( + local_dir, persistent_dir, mesh, abstract_state, options + ) + + mock_checkpoint_manager_options_cls.assert_called_once_with( + local=p2p_options.LocalCheckpointOptions( + save_interval_steps=2, + max_to_keep=3, + ), + persistent=p2p_options.PersistentCheckpointOptions( + save_interval_steps=5, + max_to_keep=6, + ), + replica_axis_index=1, + ) + mock_checkpoint_manager_cls.assert_called_once_with( + local_directory=local_dir, + persistent_directory=persistent_dir, + global_mesh=mesh, + abstract_state=abstract_state, + options=mock_checkpoint_manager_options_cls.return_value, + ) + + +class DeleteCheckpointsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.manager 
= mock.create_autospec( + p2p_checkpoint_manager.CheckpointManager, instance=True + ) + self.local_directory = epath.Path(self.create_tempdir().full_path) + + @mock.patch.object(multihost, 'process_index', return_value=0) + def test_delete_checkpoints_local_p0(self, mock_process_index): + del mock_process_index + step = 0 + step_dir = self.local_directory / str(step) + step_dir.mkdir() + self.assertTrue(step_dir.exists()) + p2p_checkpoint_manager_benchmark._delete_checkpoints( + self.manager, step, self.local_directory, 'local_p0' + ) + self.assertFalse(step_dir.exists()) + self.manager.reload.assert_called_once() + + @mock.patch.object(multihost, 'process_index', return_value=1) + def test_delete_checkpoints_local_p0_non_p0(self, mock_process_index): + del mock_process_index + step = 0 + step_dir = self.local_directory / str(step) + step_dir.mkdir() + self.assertTrue(step_dir.exists()) + p2p_checkpoint_manager_benchmark._delete_checkpoints( + self.manager, step, self.local_directory, 'local_p0' + ) + self.assertTrue(step_dir.exists()) + self.manager.reload.assert_not_called() + + @mock.patch.object(multihost, 'process_index', return_value=0) + def test_delete_checkpoints_local_all(self, mock_process_index): + del mock_process_index + step = 0 + step_dir = self.local_directory / str(step) + step_dir.mkdir() + self.assertTrue(step_dir.exists()) + p2p_checkpoint_manager_benchmark._delete_checkpoints( + self.manager, step, self.local_directory, 'local_all' + ) + self.assertFalse(step_dir.exists()) + self.manager.reload.assert_called_once() + + @mock.patch.object(multihost, 'process_index', return_value=0) + def test_delete_checkpoints_none(self, mock_process_index): + del mock_process_index + step = 0 + step_dir = self.local_directory / str(step) + step_dir.mkdir() + self.assertTrue(step_dir.exists()) + p2p_checkpoint_manager_benchmark._delete_checkpoints( + self.manager, step, self.local_directory, 'none' + ) + self.assertTrue(step_dir.exists()) + 
self.manager.reload.assert_not_called() + + +class RestoreAndValidateTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.manager = mock.create_autospec( + p2p_checkpoint_manager.CheckpointManager, instance=True + ) + + @mock.patch.object(multihost, 'sync_global_processes', autospec=True) + @mock.patch( + 'orbax.checkpoint._src.testing.benchmarks.p2p_checkpoint_manager_benchmark.pytree_utils.assert_pytree_equal', + autospec=True, + ) + def test_restore_and_validate_succeeds( + self, + mock_assert_pytree_equal, + mock_sync_global_processes, + ): + metrics = metric_lib.Metrics() + pytree = {'a': np.array([1, 2, 3])} + abstract_pytree = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), pytree + ) + step = 0 + restore_args = {'a': mock.MagicMock()} + self.manager.restore.return_value = {'state': pytree} + + p2p_checkpoint_manager_benchmark._restore_and_validate( + self.manager, + metrics, + pytree, + abstract_pytree, + step, + restore_args=restore_args, + ) + + mock_sync_global_processes.assert_called_once_with('save_completed_0') + self.manager.reload.assert_called_once() + self.manager.restore.assert_called_once() + mock_assert_pytree_equal.assert_called_once_with(pytree, pytree) + + +if __name__ == '__main__': + absltest.main()