Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ checkpoint_config:
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-405B-checkpoints/0/items"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ checkpoint_config:
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ checkpoint_config:
path: "gs://orbax-benchmarks/checkpoints/llama-3.1-8B-checkpoints/0/items"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ checkpoint_config:
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-4-fsdp-8-tensor-4/abstract_state.json"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkBenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ checkpoint_config:
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-256-data-4-fsdp-8-tensor-4/abstract_state.json"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkBenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ checkpoint_config:
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-32-data-8-fsdp-2-tensor-1/abstract_state.json"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkBenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ checkpoint_config:
sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-32-data-8-fsdp-2-tensor-1/abstract_state.json"

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_benchmark.V1Benchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkBenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# These keys must match the attributes of the `BenchmarkOptions` class
# associated with the `Benchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ mesh_config:
# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1_resharding_benchmark.V1ReshardingBenchmark"
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1ReshardingBenchmarkOptions` class
# associated with the `V1ReshardingBenchmark` generator.
# These keys must match the attributes of the `ReshardingBenchmarkOptions` class
# associated with the `ReshardingBenchmark` generator.
async_enabled: true
use_ocdbt: true
use_zarr3: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib


def get_metrics_to_measure(options: V1BenchmarkOptions) -> list[str]:
def get_metrics_to_measure(options: BenchmarkOptions) -> list[str]:
"""Returns the list of metrics to measure."""
metrics = ["time", "rss", "io"]
if options.metric_tracemalloc_enabled:
Expand All @@ -42,8 +42,8 @@ def get_metrics_to_measure(options: V1BenchmarkOptions) -> list[str]:
# 1. Define the Options Dataclass for this specific benchmark
# ==============================================================================
@dataclasses.dataclass(frozen=True)
class V1BenchmarkOptions(benchmarks_core.BenchmarkOptions):
"""Configuration options for benchmarks targeting V1BenchmarkHandler.
class BenchmarkOptions(benchmarks_core.BenchmarkOptions):
"""Configuration options for benchmarks targeting BenchmarkHandler.

Each attribute can be a single value or a list of values to create
a parameter sweep.
Expand Down Expand Up @@ -121,12 +121,12 @@ def clear_pytree(pytree: Any) -> Any:
# ==============================================================================
# 2. Implement the Benchmark Generator
# ==============================================================================
@benchmarks_core.benchmark_options(V1BenchmarkOptions)
class V1Benchmark(benchmarks_core.BenchmarksGenerator):
"""A concrete generator for `orbax.checkpoint.V1BenchmarkHandler`.
@benchmarks_core.benchmark_options(BenchmarkOptions)
class Benchmark(benchmarks_core.BenchmarksGenerator):
"""A concrete generator for `orbax.checkpoint.BenchmarkHandler`.

This class provides the specific test logic for benchmarking the
V1BenchmarkHandler with various configurations.
BenchmarkHandler with various configurations.
"""

def test_fn(
Expand All @@ -149,7 +149,7 @@ def test_fn(
abstract_pytree = jax.tree.map(ocp.arrays.to_shape_dtype_struct, pytree)
save_path = context.path / "ckpt"
options = context.options
assert isinstance(options, V1BenchmarkOptions)
assert isinstance(options, BenchmarkOptions)

logging.info("Benchmark options: %s", pprint.pformat(options))
metrics_to_measure = get_metrics_to_measure(options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,46 @@
from absl.testing import parameterized
from etils import epath
import jax.numpy as jnp
from orbax.checkpoint._src.testing.benchmarks import v1_benchmark
from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
from orbax.checkpoint._src.testing.benchmarks.v1 import benchmark


V1BenchmarkOptions = v1_benchmark.V1BenchmarkOptions
V1Benchmark = v1_benchmark.V1Benchmark
BenchmarkOptions = benchmark.BenchmarkOptions
Benchmark = benchmark.Benchmark


class V1BenchmarkTest(parameterized.TestCase):
class BenchmarkTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.directory = epath.Path(self.create_tempdir().full_path)

@parameterized.parameters(
dict(
options=V1BenchmarkOptions(use_ocdbt=False, use_zarr3=True),
options=BenchmarkOptions(use_ocdbt=False, use_zarr3=True),
expected_len=1,
),
dict(
options=V1BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=True),
options=BenchmarkOptions(use_ocdbt=[False, True], use_zarr3=True),
expected_len=2,
),
dict(
options=V1BenchmarkOptions(
options=BenchmarkOptions(
use_ocdbt=[False, True], use_zarr3=[False, True]
),
expected_len=4,
),
)
def test_generate_benchmarks(self, options, expected_len):
generator = V1Benchmark(
generator = Benchmark(
checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
options=options,
)
benchmarks = generator.generate()
self.assertLen(benchmarks, expected_len)
for benchmark in benchmarks:
self.assertIsInstance(benchmark.options, V1BenchmarkOptions)
for b in benchmarks:
self.assertIsInstance(b.options, BenchmarkOptions)

@parameterized.product(
use_ocdbt=(False, True),
Expand All @@ -82,17 +82,17 @@ def test_benchmark_test_fn(
):
return

generator = V1Benchmark(
generator = Benchmark(
checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
options=V1BenchmarkOptions(),
options=BenchmarkOptions(),
)

pytree = {
'a': jnp.arange(10),
'b': {'c': jnp.ones((5, 5))},
}

test_options = V1BenchmarkOptions(
test_options = BenchmarkOptions(
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
use_compression=use_compression,
Expand All @@ -116,7 +116,7 @@ def test_benchmark_test_fn(

self.assertIsInstance(result, benchmarks_core.TestResult)
# Check for expected metrics keys based on _metrics_to_measure
# in v1_benchmark.py and the metric.measure calls.
# in benchmark.py and the metric.measure calls.
# The benchmark records "save_blocking", "save_background", "load".
# Metric "time" is always added.

Expand Down
Loading
Loading