Skip to content

Commit 2146934

Browse files
abhishek002002Orbax Authors
authored and committed
Improve Safetensors loading performance with file pooling and parallelization.
PiperOrigin-RevId: 870960223
1 parent 6f4d675 commit 2146934

File tree

9 files changed

+1375
-26
lines changed

9 files changed

+1375
-26
lines changed

checkpoint/orbax/checkpoint/_src/path/async_path.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,17 @@ async def open_file(
159159
path: epath.Path, mode: str = 'rb'
160160
) -> AsyncIterator[AsyncFile]:
161161
"""Async context manager for opening files."""
162-
f = await asyncio.to_thread(path.open, mode=mode)
163-
try:
164-
yield AsyncFile(f)
165-
finally:
166-
await asyncio.to_thread(f.close)
162+
f_or_cm = await asyncio.to_thread(path.open, mode=mode)
163+
if hasattr(f_or_cm, 'read'):
164+
f = f_or_cm
165+
try:
166+
yield AsyncFile(f)
167+
finally:
168+
await asyncio.to_thread(f.close)
169+
else: # It is a context manager
170+
cm = f_or_cm
171+
f = await asyncio.to_thread(cm.__enter__)
172+
try:
173+
yield AsyncFile(f)
174+
finally:
175+
await asyncio.to_thread(cm.__exit__, None, None, None)

checkpoint/orbax/checkpoint/_src/path/async_path_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import contextlib
17+
import io
18+
from unittest import mock
1619

1720
from absl.testing import absltest
1821
from absl.testing import parameterized
@@ -135,6 +138,24 @@ async def read_chunk(offset, size):
135138

136139
asyncio.run(_test())
137140

141+
def test_open_returns_context_manager_handled(self):
  """Checks `open_file` when `path.open` returns a context manager.

  The mocked `open` returns a context manager that has no `read` attribute,
  so `open_file` must enter it to obtain the file object and must exit it
  once the `async with` block finishes.
  """
  test_file = self.test_dir / 'test.txt'
  test_file.write_text('hello world')
  # Records each time the fake context manager is exited so that cleanup can
  # be asserted, not just the read path.
  exit_calls = []

  @contextlib.contextmanager
  def open_mock(mode):
    del mode
    try:
      yield io.BytesIO(b'hello world')
    finally:
      exit_calls.append(True)

  async def _test():
    with mock.patch.object(
        test_file, 'open', return_value=open_mock('rb')
    ):
      async with async_path.open_file(test_file, 'rb') as f:
        self.assertEqual(await f.read(), b'hello world')
        # Still inside the block: the context manager must remain open.
        self.assertEqual(exit_calls, [])
      # Leaving the block must exit the context manager exactly once.
      self.assertEqual(exit_calls, [True])

  asyncio.run(_test())
158+
138159

139160
if __name__ == '__main__':
140161
absltest.main()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
suite_name: "Safetensors Load Benchmark"
2+
3+
mesh_configs:
4+
# Case 1: v5litepod-8, num_slices=2 (4 processes, 4 chips/process)
5+
- mesh_axes: ["data", "model"]
6+
ici_parallelism: {"data": 4, "model": 1}
7+
dcn_parallelism: {"data": 4, "model": 1}
8+
# Case 2: v5litepod-8, num_slices=1 (2 processes, 4 chips/process)
9+
- mesh_axes: ["data", "model"]
10+
ici_parallelism: {"data": 4, "model": 1}
11+
dcn_parallelism: {"data": 2, "model": 1}
12+
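# Case 5: v5litepod-16, num_slices=1 (4 processes, 4 chips/process), ICI-only
# NOTE(review): data=8 x model=16 below requires 128 devices, which does not match the 16 chips described above; confirm the intended topology.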
13+
- mesh_axes: ["data", "model"]
14+
ici_parallelism: {"data": 8, "model": 16}
15+
dcn_parallelism: null
16+
- mesh_axes: ["data", "model"]
17+
ici_parallelism: {"data": 1, "model": 16}
18+
dcn_parallelism: null
19+
- mesh_axes: ["data", "model"]
20+
ici_parallelism: {"data": 1, "model": 32}
21+
dcn_parallelism: null
22+
- mesh_axes: ["data", "model"]
23+
ici_parallelism: {"data": 1, "model": 64}
24+
dcn_parallelism: null
25+
26+
checkpoint_config:
27+
spec:
28+
array: {dtype: "float32", shape: [1024, 2048], sharding: ["data", "model"]}
29+
30+
benchmarks:
31+
- generator: "orbax.checkpoint._src.testing.benchmarks.safetensors_benchmark.SafetensorsBenchmark"
32+
options:
33+
checkpoint_path: "gs://safetensor-kimi-central/test-model-kimi"
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Benchmarks for SafetensorsLayout (V1)."""
16+
17+
import asyncio
18+
import dataclasses
19+
20+
from absl import logging
21+
from etils import epath
22+
import jax
23+
from orbax.checkpoint._src.arrays import sharding as sharding_utils
24+
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
25+
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
26+
from orbax.checkpoint.experimental import v1 as ocp_v1
27+
28+
29+
# ==============================================================================
30+
# 1. Define the Options Dataclass
31+
# ==============================================================================
32+
@dataclasses.dataclass(frozen=True)
class SafetensorsBenchmarkOptions(benchmarks_core.BenchmarkOptions):
  """Configuration options for benchmarks targeting SafetensorsLayout.

  Attributes:
    checkpoint_path: Location of the Safetensors checkpoint that the benchmark
      loads (for example a local path or a `gs://` URI). Defaults to None; the
      benchmark cannot run until it is set.
  """

  checkpoint_path: str | None = None
41+
42+
43+
# ==============================================================================
44+
# 2. Implement the Benchmark Generator
45+
# ==============================================================================
46+
@benchmarks_core.benchmark_options(SafetensorsBenchmarkOptions)
class SafetensorsBenchmark(benchmarks_core.BenchmarksGenerator):
  """A generator for benchmarking SafetensorsLayout."""

  def test_fn(
      self, context: benchmarks_core.TestContext
  ) -> benchmarks_core.TestResult:
    """Times a metadata read and a sharded load using the V1 API.

    Nothing is saved here; the checkpoint at `options.checkpoint_path` is only
    read.

    Args:
      context: Test context. `context.options` must be a
        `SafetensorsBenchmarkOptions`, and `context.mesh` supplies the devices
        used to build the target shardings.

    Returns:
      A `TestResult` carrying the `metadata_load` and `data_load_sharded`
      measurements.

    Raises:
      ValueError: If `options.checkpoint_path` is not set.
    """
    metrics = metric_lib.Metrics()
    options = context.options
    assert isinstance(options, SafetensorsBenchmarkOptions)
    if options.checkpoint_path is None:
      # Fail with a clear message instead of passing None to epath.Path.
      raise ValueError(
          'SafetensorsBenchmarkOptions.checkpoint_path must be set.'
      )

    load_path = epath.Path(options.checkpoint_path)
    logging.info('Benchmarking Load from: %s', load_path)
    mesh = context.mesh

    async def _timed_load():
      octx = ocp_v1.Context(
          checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS
      )
      with octx:
        # METRIC 1: Header/Index parsing (Metadata)
        with metrics.measure('metadata_load'):
          logging.info('Step 1: Parsing Safetensors metadata...')
          metadata = ocp_v1.pytree_metadata(load_path)
          abstract_state = metadata.metadata

        # METRIC 2: The actual data transfer (The sharded load)
        # This timing also covers building the target shardings below.
        with metrics.measure('data_load_sharded'):
          logging.info('Step 2: Starting sharded data transfer...')

          shardings = sharding_utils.construct_maximal_shardings(
              abstract_state, list(mesh.devices.flatten())
          )
          sharded_abstract_state = jax.tree.map(
              lambda sds, sharding: jax.ShapeDtypeStruct(
                  sds.shape, sds.dtype, sharding=sharding
              ),
              abstract_state,
              shardings,
          )

          restored_pytree = ocp_v1.load_pytree(
              load_path, sharded_abstract_state
          )

        # Report where the first restored leaf landed; an empty tree has
        # nothing to report.
        leaves = jax.tree_util.tree_leaves(restored_pytree)
        if leaves:
          first_leaf = leaves[0]
          logging.info(
              'SUCCESS: Load complete. First leaf shape: %s, on devices: %s',
              first_leaf.shape,
              first_leaf.devices(),
          )
        else:
          logging.warning('Load complete, but the restored pytree is empty.')
        return restored_pytree

    # asyncio.run creates and closes its own event loop. It replaces
    # asyncio.get_event_loop(), which is deprecated (and eventually an error)
    # when no loop is current.
    asyncio.run(_timed_load())

    return benchmarks_core.TestResult(metrics=metrics)
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from absl.testing import parameterized
17+
from etils import epath
18+
import jax
19+
import jax.numpy as jnp
20+
import numpy as np
21+
from orbax.checkpoint._src.testing.benchmarks import safetensors_benchmark
22+
from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs
23+
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
24+
from orbax.checkpoint.experimental import v1 as ocp_v1
25+
import safetensors.numpy as safe_np
26+
27+
SafetensorsBenchmarkOptions = safetensors_benchmark.SafetensorsBenchmarkOptions
28+
SafetensorsBenchmark = safetensors_benchmark.SafetensorsBenchmark
29+
30+
31+
class SafetensorsBenchmarkTest(parameterized.TestCase):
  """Runs SafetensorsBenchmark.test_fn against a small local checkpoint."""

  def setUp(self):
    super().setUp()
    self.test_dir = epath.Path(self.create_tempdir().full_path)
    self.checkpoint_path = self.test_dir / 'fake_checkpoint.safetensors'

    self.dummy_pytree = {
        'tensor_a': jnp.ones((32, 1024), dtype=jnp.float32),
        'scalar': jnp.ones((), dtype=jnp.float32),
        'vector': jnp.ones((1024,), dtype=jnp.float32),
    }

    # Write the fixture as NumPy arrays through the safetensors NumPy API.
    save_pytree = jax.tree.map(np.array, self.dummy_pytree)
    safe_np.save_file(save_pytree, str(self.checkpoint_path))

  def _run_benchmark(self, mesh_shape):
    """Runs `test_fn` once on a ('data', 'model') mesh of `mesh_shape`.

    Args:
      mesh_shape: 2-tuple whose product equals the number of available
        devices.

    Returns:
      The `TestResult` returned by `SafetensorsBenchmark.test_fn`.
    """
    generator = SafetensorsBenchmark(
        checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
        options=SafetensorsBenchmarkOptions(),
    )
    devices = np.array(jax.devices()).reshape(mesh_shape)
    mesh = jax.sharding.Mesh(devices, ('data', 'model'))
    options = SafetensorsBenchmarkOptions(
        checkpoint_path=str(self.checkpoint_path)
    )
    context = benchmarks_core.TestContext(
        pytree={},  # Unused by this test_fn implementation.
        path=self.checkpoint_path,
        options=options,
        mesh=mesh,
    )
    return generator.test_fn(context)

  def _assert_benchmark_metrics(self, result):
    """Asserts that `result` is a TestResult with both timing metrics."""
    self.assertIsInstance(result, benchmarks_core.TestResult)
    self.assertIn('metadata_load_time_duration', result.metrics.results)
    self.assertIn('data_load_sharded_time_duration', result.metrics.results)

  def _assert_checkpoint_matches_dummy_pytree(self):
    """Reloads the checkpoint (no sharding applied) and compares it to the fixture."""
    octx = ocp_v1.Context(
        checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS
    )
    with octx:
      metadata = ocp_v1.pytree_metadata(self.checkpoint_path)
      abstract_state = metadata.metadata
      restored_pytree = ocp_v1.load_pytree(self.checkpoint_path, abstract_state)

    self.assertEqual(
        jax.tree_util.tree_structure(restored_pytree),
        jax.tree_util.tree_structure(self.dummy_pytree),
    )
    # assert_array_equal reports the mismatching values on failure, unlike
    # assertTrue on a bare boolean.
    jax.tree.map(
        lambda a, b: np.testing.assert_array_equal(np.array(a), np.array(b)),
        restored_pytree,
        self.dummy_pytree,
    )
    jax.tree.map(
        self.assertEqual,
        jax.tree.map(lambda a: a.shape, restored_pytree),
        jax.tree.map(lambda a: a.shape, self.dummy_pytree),
    )
    jax.tree.map(
        self.assertEqual,
        jax.tree.map(lambda a: a.dtype, restored_pytree),
        jax.tree.map(lambda a: a.dtype, self.dummy_pytree),
    )

  def test_benchmark_test_fn_sharded_load(self):
    # Put every device on the 'model' axis; this also covers one device.
    num_devices = len(jax.devices())
    result = self._run_benchmark((1, num_devices))

    self._assert_benchmark_metrics(result)
    self._assert_checkpoint_matches_dummy_pytree()

  def test_benchmark_test_fn_rank_aware_sharding(self):
    # Use a 2 x N/2 mesh when there are more than two devices and the count is
    # even; otherwise fall back to a single 'data' row.
    num_devices = len(jax.devices())
    if num_devices > 2 and num_devices % 2 == 0:
      mesh_shape = (2, num_devices // 2)
    else:
      mesh_shape = (1, num_devices)
    result = self._run_benchmark(mesh_shape)

    self._assert_benchmark_metrics(result)
    self._assert_checkpoint_matches_dummy_pytree()
183+
184+
185+
if __name__ == '__main__':
186+
absltest.main()

0 commit comments

Comments
 (0)