
Commit 2d1b360

angel-core (The tunix Authors) authored and committed
Refactor tests to common base file to support additional pathways testing.
PiperOrigin-RevId: 914897466
1 parent 9a28b3d commit 2d1b360

3 files changed: 334 additions & 298 deletions

File tree:

tests/sft/checkpoint_manager_test.py (5 additions & 298 deletions)
@@ -14,308 +14,15 @@
 
 """Peft Checkpoint manager unittest."""
 
-import os
-import tempfile
 from absl.testing import absltest
-from absl.testing import parameterized
-from etils import epath
-from flax import config as flax_config
-from flax import nnx
-import jax
-import jax.numpy as jnp
-import jax.sharding as shd
-import numpy as np
-import optax
-import qwix
-from tunix.sft import checkpoint_manager
+from tunix.sft import checkpoint_manager_test_base
 
-os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
 
+class CheckpointManagerTest(
+    checkpoint_manager_test_base.BaseCheckpointManagerTest
+):
 
-if hasattr(flax_config, 'flax_always_shard_variable'):
-  flax_config.update('flax_always_shard_variable', False)
-
-
-def assert_close(path, x, y, atol=1e-5, rtol=1e-5):
-  np.testing.assert_allclose(
-      x, y, atol, rtol, err_msg=f'Mismatch at path: {path}'
-  )
-
-
-def assert_not_equal(path, x, y):
-  np.testing.assert_(
-      np.any(np.not_equal(x, y)), msg=f'Unexpected match at path: {path}'
-  )
-
-
-class TestModel(nnx.Module):
-
-  def __init__(self, rngs: nnx.Rngs):
-    kernel_init_fn = nnx.initializers.lecun_normal()
-    self.w1 = nnx.Linear(
-        in_features=2,
-        out_features=4,
-        rngs=rngs,
-        kernel_init=nnx.with_partitioning(kernel_init_fn, ('fsdp', 'tp')),
-    )
-    self.w2 = nnx.Linear(
-        in_features=4,
-        out_features=2,
-        rngs=rngs,
-        kernel_init=nnx.with_partitioning(kernel_init_fn, ('tp', 'fsdp')),
-    )
-
-  def __call__(self, x):
-    h = nnx.relu(self.w1(x))
-    h = self.w2(h) + x
-    return h
-
-
-def create_sharded_model(model_ctor, rngs, mesh):
-  @nnx.jit(static_argnums=(0,))
-  def _create_sharded_model(model_ctor, rngs):
-    model = model_ctor(rngs)
-    state = nnx.state(model)
-    pspecs = nnx.get_partition_spec(state)
-    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-    nnx.update(model, sharded_state)
-    return model, state
-
-  with mesh:
-    model, state = _create_sharded_model(model_ctor, rngs)
-  state_sharding = nnx.get_named_sharding(state, mesh)
-  return model, state_sharding
-
-
-class CheckpointManagerTest(parameterized.TestCase):
-
-  def setUp(self):
-    super().setUp()
-    try:
-      self.temp_path = self.create_tempdir().full_path
-    except Exception:
-      self.temp_path = tempfile.TemporaryDirectory().name
-    self.device_count = jax.device_count()
-    self.mesh = jax.sharding.Mesh(
-        devices=np.array(jax.devices()).reshape(2, self.device_count // 2),
-        axis_names=('fsdp', 'tp'),
-    )
-
-  def test_empty_root_directory(self):
-    cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
-    self.assertIsNone(cp_manager.latest_step())
-    self.assertFalse(cp_manager.save(1, None))
-    self.assertEqual(cp_manager.maybe_restore(None), (0, {}))
-
-  def test_checkpoint_manager_options_none_sets_default(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=None)
-    self.assertIsNotNone(cp_manager._checkpoint_manager)
-    self.assertEqual(
-        cp_manager._checkpoint_manager._options,  # pytype: disable=attribute-error
-        checkpoint_manager._DEFAULT_CHECKPOINTING_OPTIONS,
-    )
-
-  def test_save(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-
-    # Save the model state.
-    self.assertTrue(cp_manager.save(1, model))
-    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-    self.assertEqual(cp_manager.latest_step(), 1)
-
-    cp_manager.close()
-    model_param_path = epath.Path(cp_path) / '1' / 'model_params'
-    # Verify the model params are saved.
-    self.assertTrue(model_param_path.exists())
-
-  def test_restore(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-    expected_state = nnx.state(model)
-
-    # Save the model params.
-    self.assertTrue(cp_manager.save(1, model))
-    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-
-    # Change the model state.
-    changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
-    nnx.update(model, changed_state)
-
-    # Restore the model params.
-    self.assertEqual(cp_manager.maybe_restore(model), (1, {}))
-    # Check the model params are restored correctly.
-    jax.tree.map_with_path(
-        assert_close,
-        expected_state,
-        nnx.state(model),
-    )
-
-  def test_restore_different_sharding(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
-    unsharded_model = TestModel(nnx.Rngs(0))
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-
-    # Save the model params.
-    self.assertTrue(cp_manager.save(1, unsharded_model))
-    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-
-    # Restore the model without shardings.
-    self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {}))
-    unsharded_variables = nnx.state(unsharded_model, nnx.Param)
-    # Check the model shardings are restored correctly.
-    self.assertIsInstance(
-        unsharded_variables.w1.kernel.value.sharding,
-        jax.sharding.SingleDeviceSharding,
-    )
-    self.assertIsInstance(
-        unsharded_variables.w2.kernel.value.sharding,
-        jax.sharding.SingleDeviceSharding,
-    )
-
-    # Restore the model with shardings.
-    self.assertEqual(cp_manager.maybe_restore(model), (1, {}))
-    # Check the model shardings are restored correctly.
-    variables = nnx.state(model, nnx.Param)
-
-    self.assertEqual(
-        variables.w1.kernel.value.sharding.spec,
-        shd.PartitionSpec('fsdp', 'tp'),
-    )
-    self.assertEqual(
-        variables.w2.kernel.value.sharding.spec,
-        shd.PartitionSpec('tp', 'fsdp'),
-    )
-
-  def test_restore_with_lora(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-    lora_provider = qwix.LoraProvider(
-        module_path='.*w1',
-        rank=4,
-        alpha=2.0,
-    )
-    dummy_model_input = {
-        'x': jnp.ones(2, dtype=jnp.int32),
-    }
-    model = qwix.apply_lora_to_model(model, lora_provider, **dummy_model_input)
-    expected_lora_state = nnx.clone(nnx.state(model, nnx.LoRAParam))
-    old_non_lora_state = nnx.clone(
-        nnx.state(model, (nnx.filterlib.Not(nnx.LoRAParam)))
-    )
-
-    # Save the model params.
-    self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
-    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-
-    # Change the model state.
-    changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
-    nnx.update(model, changed_state)
-
-    # Restore the model lora params.
-    self.assertEqual(
-        cp_manager.maybe_restore(model, restore_only_lora_params=True),
-        (1, {}),
-    )
-    # Check the model lora params are restored correctly.
-    jax.tree.map_with_path(
-        assert_close,
-        expected_lora_state,
-        nnx.state(model, nnx.LoRAParam),
-    )
-    # Check the rest of the params are not restored.
-    jax.tree.map_with_path(
-        assert_not_equal,
-        old_non_lora_state,
-        nnx.state(model, nnx.filterlib.Not(nnx.LoRAParam)),
-    )
-
-  def test_save_and_restore_with_custom_metadata(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-    custom_metadata = {'foo': 1, 'bar': 2}
-    ckpt_manager.save(1, model, custom_metadata=custom_metadata)
-    ckpt_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-    restored_step, restored_metadata = ckpt_manager.maybe_restore(model)
-    self.assertEqual(restored_step, 1)
-    self.assertEqual(restored_metadata, custom_metadata)
-
-  def test_save_and_restore_with_optimizer_state(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-    optimizer = nnx.Optimizer(
-        model,
-        optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
-        wrt=nnx.Param,
-    )
-    custom_metadata = {'foo': 1, 'bar': 2}
-    ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
-    ckpt_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-
-    new_optimizer = nnx.Optimizer(
-        model,
-        optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5),
-        wrt=nnx.Param,
-    )
-    self.assertEqual(
-        new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-5
-    )
-    restored_step, restored_metadata = ckpt_manager.maybe_restore(
-        model, new_optimizer
-    )
-    self.assertEqual(restored_step, 1)
-    self.assertEqual(restored_metadata, custom_metadata)
-    jax.tree.map_with_path(
-        assert_close,
-        nnx.state(new_optimizer, nnx.optimizer.OptState),
-        nnx.state(optimizer, nnx.optimizer.OptState),
-    )
-    self.assertEqual(
-        new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3
-    )
-
-  def test_restore_without_optimizer(self):
-    cp_path = f'{self.temp_path}/{self.id()}'
-    ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-    optimizer = nnx.Optimizer(
-        model,
-        optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
-        wrt=nnx.Param,
-    )
-    ckpt_manager.save(1, model, optimizer)
-    ckpt_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
-    ckpt_manager.maybe_restore(model)
-
-  @parameterized.parameters(['test_data/checkpoints'])
-  def test_restore_with_backward_compatibility(self, ckpt_path):
-    # The checkpoints in test_data are saved with StandardSave. The test
-    # verifies that the checkpoint manager with PyTreeRestore can still
-    # restore checkpoints saved with StandardSave.
-    ckpt_manager = checkpoint_manager.CheckpointManager(
-        os.path.join(os.path.dirname(__file__), ckpt_path)
-    )
-    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
-    expected_state = nnx.state(model)
-    # Change the model state.
-    changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
-    nnx.update(model, changed_state)
-
-    # Restore the model params.
-    self.assertEqual(ckpt_manager.maybe_restore(model), (1, {}))
-    # Check the model params are restored correctly.
-    jax.tree.map_with_path(
-        assert_close,
-        expected_state,
-        nnx.state(model),
-    )
+  pass
 
 
 if __name__ == '__main__':
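
The shared base module referenced by the new import (`checkpoint_manager_test_base`, one of the three files in this commit) is not shown in this excerpt. A minimal, self-contained sketch of the pattern, assuming the base class simply hoists the old fixtures and test methods, illustrates why the leaf file can shrink to a `pass` body: the test runner discovers inherited `test_*` methods on each subclass, so every pathway-specific test file gets the whole suite by subclassing. The fixture and test names below are hypothetical.

# Sketch only: shared-base-class test pattern, collapsed into one file for
# illustration. In the real change the base class lives in a separate module.
from absl.testing import absltest
from absl.testing import parameterized


class BaseCheckpointManagerTest(parameterized.TestCase):
  """Common fixtures and test_* methods hoisted out of the leaf test file."""

  def setUp(self):
    super().setUp()
    # Hypothetical shared fixture, standing in for the temp dir / mesh setup
    # that the removed setUp above created.
    self.temp_path = self.create_tempdir().full_path

  def test_fixture_available(self):  # hypothetical hoisted test method
    self.assertTrue(self.temp_path)


class CheckpointManagerTest(BaseCheckpointManagerTest):
  # The leaf class needs no body of its own: inherited test_* methods are
  # discovered automatically, so an additional pathway can reuse the suite
  # by subclassing the base with different setup.
  pass


if __name__ == '__main__':
  absltest.main()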
