|
| 1 | +# Copyright 2026 The Orbax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from absl.testing import absltest |
| 16 | +from absl.testing import parameterized |
| 17 | +from etils import epath |
| 18 | +import jax |
| 19 | +import jax.numpy as jnp |
| 20 | +import numpy as np |
| 21 | +from orbax.checkpoint._src.testing.benchmarks import safetensors_benchmark |
| 22 | +from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs |
| 23 | +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core |
| 24 | +from orbax.checkpoint.experimental import v1 as ocp_v1 |
| 25 | +import safetensors.numpy as safe_np |
| 26 | + |
| 27 | +SafetensorsBenchmarkOptions = safetensors_benchmark.SafetensorsBenchmarkOptions |
| 28 | +SafetensorsBenchmark = safetensors_benchmark.SafetensorsBenchmark |
| 29 | + |
| 30 | + |
| 31 | +class SafetensorsBenchmarkTest(parameterized.TestCase): |
| 32 | + |
| 33 | + def setUp(self): |
| 34 | + super().setUp() |
| 35 | + self.test_dir = epath.Path(self.create_tempdir().full_path) |
| 36 | + self.checkpoint_path = self.test_dir / 'fake_checkpoint.safetensors' |
| 37 | + |
| 38 | + self.dummy_pytree = { |
| 39 | + 'tensor_a': jnp.ones((32, 1024), dtype=jnp.float32), |
| 40 | + 'scalar': jnp.ones((), dtype=jnp.float32), |
| 41 | + 'vector': jnp.ones((1024,), dtype=jnp.float32), |
| 42 | + } |
| 43 | + |
| 44 | + save_pytree = jax.tree.map(np.array, self.dummy_pytree) |
| 45 | + safe_np.save_file(save_pytree, str(self.checkpoint_path)) |
| 46 | + |
| 47 | + def test_benchmark_test_fn_sharded_load(self): |
| 48 | + # 1. Setup Benchmark Generator |
| 49 | + generator = SafetensorsBenchmark( |
| 50 | + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], |
| 51 | + options=SafetensorsBenchmarkOptions(), |
| 52 | + ) |
| 53 | + |
| 54 | + # 2. Create Test Context |
| 55 | + devices = np.array(jax.devices()) |
| 56 | + if devices.size == 1: |
| 57 | + devices = devices.reshape(1, 1) |
| 58 | + else: |
| 59 | + devices = devices.reshape(1, devices.size) # Keep it simple for this test |
| 60 | + mesh = jax.sharding.Mesh(devices, ('data', 'model')) |
| 61 | + options = SafetensorsBenchmarkOptions( |
| 62 | + checkpoint_path=str(self.checkpoint_path) |
| 63 | + ) |
| 64 | + |
| 65 | + context = benchmarks_core.TestContext( |
| 66 | + pytree={}, # Unused in this test_fn implementation |
| 67 | + path=self.checkpoint_path, |
| 68 | + options=options, |
| 69 | + mesh=mesh, |
| 70 | + ) |
| 71 | + |
| 72 | + # 3. Run the Benchmark Test Function |
| 73 | + result = generator.test_fn(context) |
| 74 | + |
| 75 | + # 4. Verify Benchmark Metrics |
| 76 | + self.assertIsInstance(result, benchmarks_core.TestResult) |
| 77 | + self.assertIn('metadata_load_time_duration', result.metrics.results) |
| 78 | + self.assertIn('data_load_sharded_time_duration', result.metrics.results) |
| 79 | + |
| 80 | + # 5. Verify Loaded Content by Reloading |
| 81 | + octx = ocp_v1.Context( |
| 82 | + checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS |
| 83 | + ) |
| 84 | + with octx: |
| 85 | + metadata = ocp_v1.pytree_metadata(self.checkpoint_path) |
| 86 | + abstract_state = metadata.metadata |
| 87 | + restored_pytree = ocp_v1.load_pytree(self.checkpoint_path, abstract_state) |
| 88 | + |
| 89 | + self.assertEqual( |
| 90 | + jax.tree_util.tree_structure(restored_pytree), |
| 91 | + jax.tree_util.tree_structure(self.dummy_pytree), |
| 92 | + ) |
| 93 | + jax.tree.map( |
| 94 | + self.assertTrue, |
| 95 | + jax.tree.map( |
| 96 | + lambda a, b: np.array_equal(np.array(a), np.array(b)), |
| 97 | + restored_pytree, |
| 98 | + self.dummy_pytree, |
| 99 | + ), |
| 100 | + ) |
| 101 | + jax.tree.map( |
| 102 | + self.assertEqual, |
| 103 | + jax.tree.map(lambda a: a.shape, restored_pytree), |
| 104 | + jax.tree.map(lambda a: a.shape, self.dummy_pytree), |
| 105 | + ) |
| 106 | + jax.tree.map( |
| 107 | + self.assertEqual, |
| 108 | + jax.tree.map(lambda a: a.dtype, restored_pytree), |
| 109 | + jax.tree.map(lambda a: a.dtype, self.dummy_pytree), |
| 110 | + ) |
| 111 | + |
| 112 | + def test_benchmark_test_fn_rank_aware_sharding(self): |
| 113 | + # 1. Setup Benchmark Generator |
| 114 | + generator = SafetensorsBenchmark( |
| 115 | + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], |
| 116 | + options=SafetensorsBenchmarkOptions(), |
| 117 | + ) |
| 118 | + |
| 119 | + # 2. Create Test Context |
| 120 | + devices = np.array(jax.devices()) |
| 121 | + # Reshape devices to be 2D for the mesh axis names ('data', 'model') |
| 122 | + num_devices = devices.size |
| 123 | + if num_devices == 1: |
| 124 | + devices = devices.reshape(1, 1) |
| 125 | + elif num_devices == 2: |
| 126 | + devices = devices.reshape(1, 2) |
| 127 | + elif num_devices % 2 == 0: |
| 128 | + devices = devices.reshape(2, num_devices // 2) |
| 129 | + else: # Fallback for odd numbers, should not happen in typical test envs |
| 130 | + devices = devices.reshape(1, num_devices) |
| 131 | + mesh = jax.sharding.Mesh(devices, ('data', 'model')) |
| 132 | + options = SafetensorsBenchmarkOptions( |
| 133 | + checkpoint_path=str(self.checkpoint_path) |
| 134 | + ) |
| 135 | + |
| 136 | + context = benchmarks_core.TestContext( |
| 137 | + pytree={}, # Unused |
| 138 | + path=self.checkpoint_path, |
| 139 | + options=options, |
| 140 | + mesh=mesh, |
| 141 | + ) |
| 142 | + |
| 143 | + # 3. Run the Benchmark Test Function |
| 144 | + result = generator.test_fn(context) |
| 145 | + |
| 146 | + # 4. Verify Benchmark Metrics |
| 147 | + self.assertIsInstance(result, benchmarks_core.TestResult) |
| 148 | + self.assertIn('metadata_load_time_duration', result.metrics.results) |
| 149 | + self.assertIn('data_load_sharded_time_duration', result.metrics.results) |
| 150 | + |
| 151 | + # 5. Verify Loaded Content by Reloading |
| 152 | + octx = ocp_v1.Context( |
| 153 | + checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS |
| 154 | + ) |
| 155 | + with octx: |
| 156 | + metadata = ocp_v1.pytree_metadata(self.checkpoint_path) |
| 157 | + abstract_state = metadata.metadata |
| 158 | + # Note: Sharding is not applied here, loading as is from the file. |
| 159 | + restored_pytree = ocp_v1.load_pytree(self.checkpoint_path, abstract_state) |
| 160 | + |
| 161 | + self.assertEqual( |
| 162 | + jax.tree_util.tree_structure(restored_pytree), |
| 163 | + jax.tree_util.tree_structure(self.dummy_pytree), |
| 164 | + ) |
| 165 | + jax.tree.map( |
| 166 | + self.assertTrue, |
| 167 | + jax.tree.map( |
| 168 | + lambda a, b: np.array_equal(np.array(a), np.array(b)), |
| 169 | + restored_pytree, |
| 170 | + self.dummy_pytree, |
| 171 | + ), |
| 172 | + ) |
| 173 | + jax.tree.map( |
| 174 | + self.assertEqual, |
| 175 | + jax.tree.map(lambda a: a.shape, restored_pytree), |
| 176 | + jax.tree.map(lambda a: a.shape, self.dummy_pytree), |
| 177 | + ) |
| 178 | + jax.tree.map( |
| 179 | + self.assertEqual, |
| 180 | + jax.tree.map(lambda a: a.dtype, restored_pytree), |
| 181 | + jax.tree.map(lambda a: a.dtype, self.dummy_pytree), |
| 182 | + ) |
| 183 | + |
| 184 | + |
| 185 | +if __name__ == '__main__': |
| 186 | + absltest.main() |
0 commit comments