|
14 | 14 |
|
15 | 15 | """Peft Checkpoint manager unittest.""" |
16 | 16 |
|
17 | | -import os |
18 | | -import tempfile |
19 | 17 | from absl.testing import absltest |
20 | | -from absl.testing import parameterized |
21 | | -from etils import epath |
22 | | -from flax import config as flax_config |
23 | | -from flax import nnx |
24 | | -import jax |
25 | | -import jax.numpy as jnp |
26 | | -import jax.sharding as shd |
27 | | -import numpy as np |
28 | | -import optax |
29 | | -import qwix |
30 | | -from tunix.sft import checkpoint_manager |
| 18 | +from tests.sft import checkpoint_manager_test_lib |
31 | 19 |
|
32 | | -os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' |
33 | 20 |
|
34 | | - |
35 | | -if hasattr(flax_config, 'flax_always_shard_variable'): |
36 | | - flax_config.update('flax_always_shard_variable', False) |
37 | | - |
38 | | - |
def assert_close(path, x, y, atol=1e-5, rtol=1e-5):
  """Asserts that ``x`` and ``y`` are element-wise close within tolerances.

  Args:
    path: Pytree path of the leaf being compared; included in the failure
      message so mismatches can be located in a large state tree.
    x: Actual values.
    y: Desired (expected) values.
    atol: Absolute tolerance.
    rtol: Relative tolerance.
  """
  # Pass tolerances by keyword: assert_allclose's positional order is
  # (actual, desired, rtol, atol), so passing `atol, rtol` positionally
  # silently swapped the two whenever a caller overrode the defaults.
  np.testing.assert_allclose(
      x, y, rtol=rtol, atol=atol, err_msg=f'Mismatch at path: {path}'
  )
43 | | - |
44 | | - |
def assert_not_equal(path, x, y):
  """Asserts that ``x`` and ``y`` differ in at least one element.

  Args:
    path: Pytree path of the leaf being compared; included in the failure
      message.
    x: First array-like operand.
    y: Second array-like operand.
  """
  differs_somewhere = np.any(np.not_equal(x, y))
  np.testing.assert_(
      differs_somewhere, msg=f'Unexpected match at path: {path}'
  )
49 | | - |
50 | | - |
class TestModel(nnx.Module):
  """Tiny two-layer MLP with a residual connection, used as a test fixture.

  Kernel variables carry ('fsdp', 'tp') / ('tp', 'fsdp') partitioning
  annotations so the model can be sharded across a 2-D device mesh.
  """

  def __init__(self, rngs: nnx.Rngs):
    init = nnx.initializers.lecun_normal()
    self.w1 = nnx.Linear(
        in_features=2,
        out_features=4,
        rngs=rngs,
        kernel_init=nnx.with_partitioning(init, ('fsdp', 'tp')),
    )
    self.w2 = nnx.Linear(
        in_features=4,
        out_features=2,
        rngs=rngs,
        kernel_init=nnx.with_partitioning(init, ('tp', 'fsdp')),
    )

  def __call__(self, x):
    # Residual block: project 2 -> 4 -> 2, then add the input back.
    return self.w2(nnx.relu(self.w1(x))) + x
72 | | - |
73 | | - |
def create_sharded_model(model_ctor, rngs, mesh):
  """Builds a model whose parameters are sharded across ``mesh``.

  Args:
    model_ctor: Callable (e.g. an ``nnx.Module`` subclass such as
      ``TestModel``) that constructs the model from an ``nnx.Rngs``.
    rngs: ``nnx.Rngs`` used to initialize the parameters.
    mesh: ``jax.sharding.Mesh`` to shard the parameters over.

  Returns:
    A ``(model, state_sharding)`` tuple: the sharded model and the named
    sharding tree derived from its state.
  """
  @nnx.jit(static_argnums=(0,))
  def _create_sharded_model(model_ctor, rngs):
    model = model_ctor(rngs)
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    # Constrain the freshly initialized state to its annotated partition
    # specs so jit materializes the parameters already sharded, instead of
    # initializing on one device and resharding afterwards.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
    return model, state

  # Run initialization under the mesh context so the sharding constraint
  # above resolves axis names ('fsdp', 'tp') against this mesh.
  with mesh:
    model, state = _create_sharded_model(model_ctor, rngs)
    state_sharding = nnx.get_named_sharding(state, mesh)
    return model, state_sharding
88 | | - |
89 | | - |
class CheckpointManagerTest(parameterized.TestCase):
  """Tests for checkpoint_manager.CheckpointManager save/restore behavior."""

  def setUp(self):
    super().setUp()
    # Prefer absltest's managed tempdir; fall back to the stdlib when the
    # framework cannot create one (e.g. when run outside an absltest runner).
    try:
      self.temp_path = self.create_tempdir().full_path
    except Exception:
      self.temp_path = tempfile.TemporaryDirectory().name
    self.device_count = jax.device_count()
    # 2 x (device_count // 2) mesh with ('fsdp', 'tp') axes, matching the
    # partition annotations on TestModel's kernels. Assumes an even device
    # count >= 2 (XLA_FLAGS above forces 4 host devices).
    self.mesh = jax.sharding.Mesh(
        devices=np.array(jax.devices()).reshape(2, self.device_count // 2),
        axis_names=('fsdp', 'tp'),
    )

  def test_empty_root_directory(self):
    """A manager with no root directory is a no-op for save and restore."""
    cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
    self.assertIsNone(cp_manager.latest_step())
    self.assertFalse(cp_manager.save(1, None))
    self.assertEqual(cp_manager.maybe_restore(None), (0, {}))

  def test_checkpoint_manager_options_none_sets_default(self):
    """options=None falls back to _DEFAULT_CHECKPOINTING_OPTIONS."""
    cp_path = f'{self.temp_path}/{self.id()}'
    cp_manager = checkpoint_manager.CheckpointManager(cp_path, options=None)
    self.assertIsNotNone(cp_manager._checkpoint_manager)
    self.assertEqual(
        cp_manager._checkpoint_manager._options,  # pytype: disable=attribute-error
        checkpoint_manager._DEFAULT_CHECKPOINTING_OPTIONS,
    )

  def test_save(self):
    """Saving a model creates a model_params checkpoint for that step."""
    cp_path = f'{self.temp_path}/{self.id()}'
    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

    # Save the model state.
    self.assertTrue(cp_manager.save(1, model))
    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
    self.assertEqual(cp_manager.latest_step(), 1)

    cp_manager.close()
    model_param_path = epath.Path(cp_path) / '1' / 'model_params'
    # Verify the model params are saved.
    self.assertTrue(model_param_path.exists())

  def test_restore(self):
    """Restoring overwrites mutated params with the saved values."""
    cp_path = f'{self.temp_path}/{self.id()}'
    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
    expected_state = nnx.state(model)

    # Save the model params.
    self.assertTrue(cp_manager.save(1, model))
    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error

    # Change the model state.
    changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
    nnx.update(model, changed_state)

    # Restore the model params.
    self.assertEqual(cp_manager.maybe_restore(model), (1, {}))
    # Check the model params are restored correctly.
    jax.tree.map_with_path(
        assert_close,
        expected_state,
        nnx.state(model),
    )

  def test_restore_different_sharding(self):
    """Restore adopts the target model's sharding, not the saved one."""
    cp_path = f'{self.temp_path}/{self.id()}'
    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
    unsharded_model = TestModel(nnx.Rngs(0))
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)

    # Save the model params.
    self.assertTrue(cp_manager.save(1, unsharded_model))
    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error

    # Restore the model without shardings.
    self.assertEqual(cp_manager.maybe_restore(unsharded_model), (1, {}))
    unsharded_variables = nnx.state(unsharded_model, nnx.Param)
    # Check the model shardings are restored correctly.
    self.assertIsInstance(
        unsharded_variables.w1.kernel.value.sharding,
        jax.sharding.SingleDeviceSharding,
    )
    self.assertIsInstance(
        unsharded_variables.w2.kernel.value.sharding,
        jax.sharding.SingleDeviceSharding,
    )

    # Restore the model with shardings.
    self.assertEqual(cp_manager.maybe_restore(model), (1, {}))
    # Check the model shardings are restored correctly.
    variables = nnx.state(model, nnx.Param)

    self.assertEqual(
        variables.w1.kernel.value.sharding.spec,
        shd.PartitionSpec('fsdp', 'tp'),
    )
    self.assertEqual(
        variables.w2.kernel.value.sharding.spec,
        shd.PartitionSpec('tp', 'fsdp'),
    )

  def test_restore_with_lora(self):
    """LoRA-only save/restore touches LoRA params and nothing else."""
    cp_path = f'{self.temp_path}/{self.id()}'
    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
    lora_provider = qwix.LoraProvider(
        module_path='.*w1',
        rank=4,
        alpha=2.0,
    )
    dummy_model_input = {
        'x': jnp.ones(2, dtype=jnp.int32),
    }
    model = qwix.apply_lora_to_model(model, lora_provider, **dummy_model_input)
    # Clone snapshots so later in-place updates don't alias these references.
    expected_lora_state = nnx.clone(nnx.state(model, nnx.LoRAParam))
    old_non_lora_state = nnx.clone(
        nnx.state(model, (nnx.filterlib.Not(nnx.LoRAParam)))
    )

    # Save the model params.
    self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
    cp_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error

    # Change the model state.
    changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
    nnx.update(model, changed_state)

    # Restore the model lora params.
    self.assertEqual(
        cp_manager.maybe_restore(model, restore_only_lora_params=True),
        (1, {}),
    )
    # Check the model lora params are restored correctly.
    jax.tree.map_with_path(
        assert_close,
        expected_lora_state,
        nnx.state(model, nnx.LoRAParam),
    )
    # Check the rest of the params are not restored.
    jax.tree.map_with_path(
        assert_not_equal,
        old_non_lora_state,
        nnx.state(model, nnx.filterlib.Not(nnx.LoRAParam)),
    )

  def test_save_and_restore_with_custom_metadata(self):
    """custom_metadata saved with a step round-trips through restore."""
    cp_path = f'{self.temp_path}/{self.id()}'
    ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
    custom_metadata = {'foo': 1, 'bar': 2}
    ckpt_manager.save(1, model, custom_metadata=custom_metadata)
    ckpt_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
    restored_step, restored_metadata = ckpt_manager.maybe_restore(model)
    self.assertEqual(restored_step, 1)
    self.assertEqual(restored_metadata, custom_metadata)

  def test_save_and_restore_with_optimizer_state(self):
    """Optimizer state (including hyperparams) is restored from checkpoint."""
    cp_path = f'{self.temp_path}/{self.id()}'
    ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
    optimizer = nnx.Optimizer(
        model,
        optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
        wrt=nnx.Param,
    )
    custom_metadata = {'foo': 1, 'bar': 2}
    ckpt_manager.save(1, model, optimizer, custom_metadata=custom_metadata)
    ckpt_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error

    # Build a second optimizer with a different learning rate to prove that
    # restore actually overwrites its state.
    new_optimizer = nnx.Optimizer(
        model,
        optax.inject_hyperparams(optax.adamw)(learning_rate=1e-5),
        wrt=nnx.Param,
    )
    self.assertEqual(
        new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-5
    )
    restored_step, restored_metadata = ckpt_manager.maybe_restore(
        model, new_optimizer
    )
    self.assertEqual(restored_step, 1)
    self.assertEqual(restored_metadata, custom_metadata)
    jax.tree.map_with_path(
        assert_close,
        nnx.state(new_optimizer, nnx.optimizer.OptState),
        nnx.state(optimizer, nnx.optimizer.OptState),
    )
    # The injected hyperparam reverts to the checkpointed value (1e-3).
    self.assertEqual(
        new_optimizer.opt_state.hyperparams['learning_rate'].value, 1e-3
    )

  def test_restore_without_optimizer(self):
    """A checkpoint saved with an optimizer restores model-only cleanly."""
    cp_path = f'{self.temp_path}/{self.id()}'
    ckpt_manager = checkpoint_manager.CheckpointManager(cp_path)
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
    optimizer = nnx.Optimizer(
        model,
        optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3),
        wrt=nnx.Param,
    )
    ckpt_manager.save(1, model, optimizer)
    ckpt_manager._checkpoint_manager.wait_until_finished()  # pytype: disable=attribute-error
    ckpt_manager.maybe_restore(model)

  @parameterized.parameters(['test_data/checkpoints'])
  def test_restore_with_backward_compatibility(self, ckpt_path):
    # The checkpoints in test_data is saved with StandardSave. The test is to
    # verify the checkpoint manager with PyTreeRestore can still restore the
    # checkpoints saved with StandardSave.
    ckpt_manager = checkpoint_manager.CheckpointManager(
        os.path.join(os.path.dirname(__file__), ckpt_path)
    )
    model, _ = create_sharded_model(TestModel, nnx.Rngs(0), self.mesh)
    expected_state = nnx.state(model)
    # Change the model state.
    changed_state = jax.tree.map(lambda x: x + 1, nnx.state(model))
    nnx.update(model, changed_state)

    # Restore the model params.
    self.assertEqual(ckpt_manager.maybe_restore(model), (1, {}))
    # Check the model params are restored correctly.
    jax.tree.map_with_path(
        assert_close,
        expected_state,
        nnx.state(model),
    )
class CheckpointManagerTest(
    checkpoint_manager_test_lib.BaseCheckpointManagerTest
):
  """Concrete case inheriting the shared checkpoint-manager test suite.

  All test methods come from BaseCheckpointManagerTest in
  checkpoint_manager_test_lib; nothing is overridden locally.
  """

  pass
319 | 25 |
|
320 | 26 |
|
321 | 27 | if __name__ == '__main__': |
|
0 commit comments