Skip to content

Commit c421a44

Browse files
cpgaffney1 authored and Orbax Authors committed
Add a resharding benchmark. This benchmark only loads the checkpoint repeatedly. The source checkpoint is expected to have been generated on a different topology and/or different sharding. It relies on a sharding config file to dictate the new shardings for the loaded checkpoint.
PiperOrigin-RevId: 874406266
1 parent eb178fb commit c421a44

File tree

9 files changed

+368
-65
lines changed

9 files changed

+368
-65
lines changed

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/checkpoint_generation.py

Lines changed: 42 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -127,34 +127,28 @@ def _partition_axis_name(offset: int) -> str:
127127

128128

129129

130-
def _get_abstract_state(
131-
config: configs.CheckpointConfig,
130+
def get_abstract_state_with_generated_shardings(pytree_metadata: Any) -> Any:
  """Builds an abstract state from metadata, attaching generated shardings.

  Args:
    pytree_metadata: Pytree of array metadata. Each leaf is converted with
      `abstract_arrays.to_shape_dtype_struct`.

  Returns:
    A pytree of `jax.ShapeDtypeStruct` built from the converted leaves, where
    each leaf carries the sharding returned for it by
    `sharding_utils.construct_maximal_shardings`.
  """

  def _attach_sharding(struct: Any, sharding: Any) -> jax.ShapeDtypeStruct:
    # Rebuild the struct so that shape/dtype are kept and only the sharding
    # is replaced.
    return jax.ShapeDtypeStruct(struct.shape, struct.dtype, sharding=sharding)

  structs = jax.tree.map(
      abstract_arrays.to_shape_dtype_struct, pytree_metadata
  )
  generated = sharding_utils.construct_maximal_shardings(structs)
  return jax.tree.map(_attach_sharding, structs, generated)
142+
143+
144+
def get_abstract_state_from_sharding_config(
145+
sharding_config_path: epath.Path,
146+
metadata: Any,
132147
*,
133-
use_ocdbt: bool,
134-
devices: list[jax.Device] | None = None,
148+
devices: list[jax.Device],
135149
) -> Any:
136-
"""Loads sharding configuration from a JSON file."""
137-
path = epath.Path(config.path)
138-
devices = devices or jax.devices()
139-
with checkpointer.Checkpointer(
140-
pytree_checkpoint_handler.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt)
141-
) as ckptr:
142-
metadata = ckptr.metadata(path).item_metadata
143-
144-
if config.sharding_config_path is None:
145-
abstract_state = jax.tree.map(
146-
abstract_arrays.to_shape_dtype_struct, metadata.tree
147-
)
148-
shardings = sharding_utils.construct_maximal_shardings(abstract_state)
149-
return jax.tree.map(
150-
lambda sds, sharding: jax.ShapeDtypeStruct(
151-
sds.shape, sds.dtype, sharding=sharding
152-
),
153-
abstract_state,
154-
shardings,
155-
)
156-
157-
path = epath.Path(config.sharding_config_path)
150+
"""Loads abstract state from a JSON file."""
151+
path = epath.Path(sharding_config_path)
158152
parsed_config = json.loads(path.read_text())
159153
flat_abstract_state = {}
160154
for k, v in parsed_config.items():
@@ -169,9 +163,28 @@ def _get_abstract_state(
169163
spec=jax.sharding.PartitionSpec(*v['sharding']['spec']),
170164
),
171165
)
172-
return tree_utils.from_flat_dict(
173-
flat_abstract_state, metadata.tree, sep='.'
174-
)
166+
return tree_utils.from_flat_dict(flat_abstract_state, metadata, sep='.')
167+
168+
169+
def _get_abstract_state(
    config: configs.CheckpointConfig,
    *,
    use_ocdbt: bool,
    devices: list[jax.Device] | None = None,
) -> Any:
  """Creates abstract state for a provided CheckpointConfig.

  Reads the item metadata of the checkpoint at `config.path` and turns its
  tree into an abstract state, either with generated shardings or with the
  shardings described by `config.sharding_config_path`.

  Args:
    config: Checkpoint configuration. `config.path` is the checkpoint whose
      metadata is read; `config.sharding_config_path`, when set, selects the
      sharding config file to apply.
    use_ocdbt: Passed to `PyTreeCheckpointHandler` when reading the metadata.
    devices: Devices forwarded to `get_abstract_state_from_sharding_config`.
      Falls back to `jax.devices()` when None or empty.

  Returns:
    The abstract state pytree for the checkpoint.
  """
  path = epath.Path(config.path)
  devices = devices or jax.devices()
  with checkpointer.Checkpointer(
      pytree_checkpoint_handler.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt)
  ) as ckptr:
    metadata = ckptr.metadata(path).item_metadata

  if config.sharding_config_path is None:
    return get_abstract_state_with_generated_shardings(metadata.tree)
  # Both branches operate on the metadata *tree*. The flat sharding config is
  # unflattened against this target, so it must be the tree itself rather
  # than the enclosing item-metadata object.
  return get_abstract_state_from_sharding_config(
      epath.Path(config.sharding_config_path), metadata.tree, devices=devices
  )
175188

176189

177190
def load_checkpoint(config: configs.CheckpointConfig) -> Any:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/config_parsing.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -84,12 +84,6 @@ def _validate_config(config: Dict[str, Any]) -> None:
8484
if key not in config:
8585
raise ValueError(f'Missing required key in YAML config: {key}')
8686

87-
if 'checkpoint_config' not in config and 'checkpoint_configs' not in config:
88-
raise ValueError(
89-
'Missing required key in YAML config: checkpoint_config or'
90-
' checkpoint_configs'
91-
)
92-
9387
if not isinstance(config['benchmarks'], list):
9488
raise ValueError("'benchmarks' must be a list.")
9589

@@ -137,10 +131,12 @@ def create_test_suite_from_config(
137131
checkpoint_configs = [
138132
config_lib.CheckpointConfig(**cc) for cc in config['checkpoint_configs']
139133
]
140-
else:
134+
elif 'checkpoint_config' in config:
141135
checkpoint_configs = [
142136
config_lib.CheckpointConfig(**config['checkpoint_config'])
143137
]
138+
else:
139+
checkpoint_configs = [config_lib.CheckpointConfig()]
144140

145141
if 'mesh_configs' in config:
146142
mesh_configs = [

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/config_parsing_test.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -109,16 +109,6 @@ def test_missing_required_keys(self, key_to_remove):
109109
):
110110
config_parsing._validate_config(config)
111111

112-
def test_missing_checkpoint_config_and_configs(self):
113-
config = self._get_valid_config()
114-
del config['checkpoint_config']
115-
with self.assertRaisesRegex(
116-
ValueError,
117-
'Missing required key in YAML config: checkpoint_config or'
118-
' checkpoint_configs',
119-
):
120-
config_parsing._validate_config(config)
121-
122112
def test_benchmarks_not_list(self):
123113
config = self._get_valid_config()
124114
config['benchmarks'] = {}
@@ -366,6 +356,29 @@ def test_valid_creation_with_checkpoint_configs(self, mock_import, mock_load):
366356
],
367357
)
368358

359+
@mock.patch.object(config_parsing, '_load_yaml_config', autospec=True)
360+
@mock.patch.object(config_parsing, '_import_class', autospec=True)
361+
def test_valid_creation_no_checkpoint_config(self, mock_import, mock_load):
362+
yaml_content = """
363+
suite_name: No Checkpoint Config
364+
benchmarks:
365+
-
366+
generator: MockGenerator
367+
options:
368+
param1: 10
369+
"""
370+
mock_load.return_value = yaml.safe_load(yaml_content)
371+
mock_import.return_value = MockGenerator
372+
373+
test_suite = config_parsing.create_test_suite_from_config('fake.yaml')
374+
375+
self.assertLen(test_suite._benchmarks_generators, 1)
376+
# Defaults to a single empty CheckpointConfig.
377+
self.assertEqual(
378+
test_suite._benchmarks_generators[0]._checkpoint_configs,
379+
[config_lib.CheckpointConfig()],
380+
)
381+
369382
@mock.patch.object(config_parsing, '_load_yaml_config', autospec=True)
370383
@mock.patch.object(config_parsing, '_import_class', autospec=True)
371384
def test_generator_import_fail(self, mock_import, mock_load):

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,6 @@ class CheckpointConfig:
8989
sharding_config_path: str | None = None
9090

9191
def __post_init__(self):
92-
if self.path is None and self.spec is None:
93-
raise ValueError('Either path or spec must be provided.')
9492
if self.path is not None and self.spec is not None:
9593
raise ValueError('Only one of path or spec can be provided.')
9694
if self.sharding_config_path is not None and self.path is None:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ class TestContext:
7373
"""Input object passed to each test function, providing pre-configured components for the test run.
7474
7575
Attributes:
76-
pytree: The generated or loaded checkpoint data.
76+
pytree: The generated or loaded checkpoint data. May be None.
7777
path: The test directory path.
7878
options: The specific BenchmarkOptions for this test variant.
7979
mesh: The mesh used for sharding the checkpoint data.
@@ -82,7 +82,7 @@ class TestContext:
8282
local_path: The local path to store the checkpoint data.
8383
"""
8484

85-
pytree: Any
85+
pytree: Any | None
8686
path: epath.Path
8787
options: BenchmarkOptions # The specific options for this test variant.
8888
mesh: jax.sharding.Mesh | None = None
@@ -165,20 +165,22 @@ def run(self, repeat_index: int | None = None) -> TestResult:
165165
):
166166
multihost.sync_global_processes("benchmark:setup_test_directory")
167167

168-
if self.checkpoint_config.path is None:
169-
data = checkpoint_generation.generate_checkpoint(
168+
if self.checkpoint_config.path is not None:
169+
pytree = checkpoint_generation.load_checkpoint(self.checkpoint_config)
170+
elif self.checkpoint_config.spec is not None:
171+
pytree = checkpoint_generation.generate_checkpoint(
170172
self.checkpoint_config, mesh=self.mesh
171173
)
172174
else:
173-
data = checkpoint_generation.load_checkpoint(self.checkpoint_config)
175+
pytree = None
174176

175177
with benchmark_metrics.measure(
176178
"sync_global_processes:benchmark:setup_pytree"
177179
):
178180
multihost.sync_global_processes("benchmark:setup_pytree")
179181

180182
context = TestContext(
181-
pytree=data,
183+
pytree=pytree,
182184
path=path,
183185
options=self.options,
184186
mesh=self.mesh,

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core_test.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,45 @@ def test_fn(context):
169169
mock_create_mesh.assert_called_once_with(mesh_config)
170170
self.assertEqual(mock_metrics_report.call_count, 2)
171171

172+
@mock.patch.object(directory_setup, 'setup_test_directory')
173+
@mock.patch.object(checkpoint_generation, 'generate_checkpoint')
174+
@mock.patch.object(checkpoint_generation, 'load_checkpoint')
175+
@mock.patch.object(metric_lib.Metrics, 'report')
176+
def test_run_with_empty_checkpoint_config(
177+
self,
178+
mock_metrics_report,
179+
mock_load_checkpoint,
180+
mock_generate_checkpoint,
181+
mock_setup_test_directory,
182+
):
183+
path = epath.Path(self.create_tempdir().full_path)
184+
mock_setup_test_directory.return_value = path
185+
options = MyBenchmarkOptions()
186+
187+
def test_fn(context):
188+
self.assertIsNone(context.pytree)
189+
self.assertEqual(context.path, path)
190+
self.assertEqual(context.options, options)
191+
self.assertIsNone(context.mesh)
192+
return core.TestResult(metrics=metric_lib.Metrics())
193+
194+
ckpt_config = configs.CheckpointConfig()
195+
benchmark = core.Benchmark(
196+
test_fn=test_fn,
197+
checkpoint_config=ckpt_config,
198+
options=options,
199+
name='test_benchmark',
200+
)
201+
202+
benchmark.run()
203+
204+
mock_setup_test_directory.assert_called_once_with(
205+
'test_benchmark', None, None
206+
)
207+
mock_generate_checkpoint.assert_not_called()
208+
mock_load_checkpoint.assert_not_called()
209+
self.assertEqual(mock_metrics_report.call_count, 2)
210+
172211
@mock.patch.object(directory_setup, 'setup_test_directory')
173212
@mock.patch.object(checkpoint_generation, 'generate_checkpoint')
174213
@mock.patch.object(device_mesh, 'create_mesh')

checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1_benchmark.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
2929

3030

31-
def _metrics_to_measure(options: V1BenchmarkOptions) -> list[str]:
31+
def get_metrics_to_measure(options: V1BenchmarkOptions) -> list[str]:
3232
"""Returns the list of metrics to measure."""
3333
metrics = ["time", "rss", "io"]
3434
if options.metric_tracemalloc_enabled:
@@ -73,9 +73,10 @@ class V1BenchmarkOptions(benchmarks_core.BenchmarkOptions):
7373
metric_tensorstore_enabled: bool = False
7474
use_replica_parallel: bool | Sequence[bool] = False
7575
enable_replica_parallel_separate_folder: bool | Sequence[bool] = False
76+
chunk_byte_size: int | None | Sequence[int | None] = None
7677
enable_trace: bool = False
7778

78-
def is_valid(self):
79+
def is_valid(self) -> bool:
7980
assert isinstance(self.use_replica_parallel, bool)
8081
assert isinstance(self.enable_replica_parallel_separate_folder, bool)
8182
if self.enable_replica_parallel_separate_folder and (
@@ -89,6 +90,9 @@ def context(self) -> ocp.Context:
8990
return ocp.Context(
9091
array_options=ocp.options.ArrayOptions(
9192
saving=ocp.options.ArrayOptions.Saving(
93+
storage_options=ocp.options.ArrayOptions.Saving.StorageOptions(
94+
chunk_byte_size=self.chunk_byte_size,
95+
),
9296
use_ocdbt=self.use_ocdbt,
9397
use_zarr3=self.use_zarr3,
9498
use_replica_parallel=self.use_replica_parallel,
@@ -107,6 +111,13 @@ def context(self) -> ocp.Context:
107111
)
108112

109113

114+
def clear_pytree(pytree: Any) -> Any:
  """Deletes every `jax.Array` leaf of `pytree` to free up memory.

  Args:
    pytree: Arbitrary pytree. Leaves that are not `jax.Array` instances are
      not modified.

  Returns:
    The result of mapping over `pytree` with every leaf replaced by None.
  """

  def _release(leaf: Any) -> None:
    # Only jax.Array leaves are deleted; every leaf maps to None regardless.
    if isinstance(leaf, jax.Array):
      leaf.delete()
    return None

  return jax.tree.map(_release, pytree)
119+
120+
110121
# ==============================================================================
111122
# 2. Implement the Benchmark Generator
112123
# ==============================================================================
@@ -118,12 +129,6 @@ class V1Benchmark(benchmarks_core.BenchmarksGenerator):
118129
V1BenchmarkHandler with various configurations.
119130
"""
120131

121-
def _clear_pytree(self, pytree: Any) -> Any:
122-
"""Clears the pytree to free up memory."""
123-
return jax.tree.map(
124-
lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
125-
)
126-
127132
def test_fn(
128133
self, context: benchmarks_core.TestContext
129134
) -> benchmarks_core.TestResult:
@@ -147,7 +152,7 @@ def test_fn(
147152
assert isinstance(options, V1BenchmarkOptions)
148153

149154
logging.info("Benchmark options: %s", pprint.pformat(options))
150-
metrics_to_measure = _metrics_to_measure(options)
155+
metrics_to_measure = get_metrics_to_measure(options)
151156

152157
with ocp.Context(context=options.context):
153158
if options.enable_trace:
@@ -162,15 +167,15 @@ def test_fn(
162167
ocp.save_pytree(save_path, pytree)
163168
with metrics.measure("save_background", metrics_to_measure):
164169
pass
165-
context.pytree = self._clear_pytree(context.pytree)
170+
context.pytree = clear_pytree(context.pytree)
166171
if options.enable_trace:
167172
jax.profiler.stop_trace()
168173

169174
if options.enable_trace:
170175
jax.profiler.start_trace(context.path / "trace_load")
171176
with metrics.measure("load", metrics_to_measure):
172177
restored_pytree = ocp.load_pytree(save_path, abstract_pytree)
173-
self._clear_pytree(restored_pytree)
178+
clear_pytree(restored_pytree)
174179
if options.enable_trace:
175180
jax.profiler.stop_trace()
176181

0 commit comments

Comments
 (0)