@@ -31,12 +31,17 @@
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
+from nvidia_resiliency_ext.checkpointing.async_ckpt.core import (
+    AsyncCallsQueue,
+    AsyncRequest,
+    abort_nvrx_checkpoint,
+)
 from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
 from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
     save_state_dict_async_finalize,
     save_state_dict_async_plan,
 )
+from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint
 from nvidia_resiliency_ext.checkpointing.utils import diff
 from tests.checkpointing.unit import TempNamedDir
 from tests.checkpointing.unit.test_utilities import Model, Utils
@@ -92,6 +97,10 @@ def sync_save_checkpoint(self, checkpoint_dir, state_dict, planner):
             planner=planner,
         )
 
+    def async_save_checkpoint_on_rank0(self, checkpoint_dir, state_dict, torch_ckpt_impl):
+        if torch.distributed.get_rank() == 0:
+            torch_ckpt_impl.async_save(state_dict, checkpoint_dir / 'test')
+
     def load_checkpoint(self, checkpoint_dir, state_dict):
         """Loads a checkpoint into the given state_dict."""
         load(
@@ -219,3 +228,83 @@ def test_cached_metadata(self, tmp_path_dist_ckpt, async_queue):
             ), f'{field.name} is different in metadata from non-cached, cached metadata impls'
         ckpt_dir.cleanup()
         async_queue.close()
+
+    def test_async_cp_with_multiple_queue_and_abort(self, tmp_path_dist_ckpt):
+        """
+        Verifies that the async checkpointing backend can be used with multiple async queues.
+        For example, a user may want to save two checkpoints, i.e. one sharded state and one only on rank 0.
+        Also verifies the checkpoint abort functionality and the ability to resume after an abort operation.
+        """
+        Utils.initialize_distributed()
+        model = FSDP(Model((1024, 1024), 8))
+        async_queue_dist = AsyncCallsQueue()
+        ckpt_impl = TorchAsyncCheckpoint(persistent_queue=True)
+        with (
+            TempNamedDir(
+                tmp_path_dist_ckpt / 'async_checkpoint_dist', sync=True
+            ) as async_ckpt_dir_dist,
+            TempNamedDir(
+                tmp_path_dist_ckpt / 'async_checkpoint_no_dist', sync=True
+            ) as async_ckpt_dir_no_dist,
+        ):
+            state_dict = model.state_dict()
+            planner = DefaultSavePlanner()
+
+            # Perform async saves for both the dist CP and non-dist CP use cases.
+            self.async_save_checkpoint(async_ckpt_dir_dist, state_dict, planner, async_queue_dist)
+            self.async_save_checkpoint_on_rank0(async_ckpt_dir_no_dist, state_dict, ckpt_impl)
+            async_queue_dist.maybe_finalize_async_calls(blocking=True, no_dist=False)
+            ckpt_impl.finalize_async_save(blocking=True, no_dist=True)
+
+            # Abort the CP workers to mimic the effect of an in-process restart.
+            abort_nvrx_checkpoint()
+
+            # Validate the state of the async CP workers after the abort operation.
+            async_calls_queue_no_dist = ckpt_impl._get_async_calls_queue()
+            assert (
+                async_calls_queue_no_dist is not None
+            ), "We expect a valid AsyncCallsQueue state"
+            async_process_no_dist = async_calls_queue_no_dist._get_async_caller()
+            if async_process_no_dist is not None:
+                assert (
+                    async_process_no_dist._debug_is_async_process_running() is False
+                ), "After abort, the async process must stop"
+
+            async_process_dist = async_queue_dist._get_async_caller()
+            if async_process_dist is not None:
+                assert (
+                    async_process_dist._debug_is_async_process_running() is False
+                ), "After abort, the async process must stop"
+
+            # Perform async saves for both the dist CP and non-dist CP use cases.
+            # Validate that operations seamlessly resume after an abort operation.
+            self.async_save_checkpoint(async_ckpt_dir_dist, state_dict, planner, async_queue_dist)
+            self.async_save_checkpoint_on_rank0(async_ckpt_dir_no_dist, state_dict, ckpt_impl)
+            async_queue_dist.maybe_finalize_async_calls(blocking=True, no_dist=False)
+            ckpt_impl.finalize_async_save(blocking=True, no_dist=True)
+
+            # Validate the state of the async CP workers after the resume operation.
+            async_calls_queue_no_dist = ckpt_impl._get_async_calls_queue()
+            assert (
+                async_calls_queue_no_dist is not None
+            ), "We expect a valid AsyncCallsQueue object in TorchAsyncCheckpoint after a CP event"
+            async_process_no_dist = async_calls_queue_no_dist._get_async_caller()
+            # For the non-dist CP use case, only rank 0 is expected to trigger an async process.
+            if torch.distributed.get_rank() == 0:
+                assert (
+                    async_process_no_dist is not None
+                ), "We expect a valid AsyncCaller state after a CP event"
+                assert (
+                    async_process_no_dist._debug_is_async_process_running() is True
+                ), "After resume, we expect the async process to be running on rank 0 for a non-dist async save"
+
+            async_process_dist = async_queue_dist._get_async_caller()
+            assert (
+                async_process_dist is not None
+            ), "We expect a valid AsyncCaller state after a CP event"
+            assert (
+                async_process_dist._debug_is_async_process_running() is True
+            ), "After resume, we expect the async process to be running on all ranks for a dist async save"
+
+        async_queue_dist.close()
+        ckpt_impl.close()
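For reference, a minimal standalone sketch of the abort/resume pattern this test exercises, using only the `TorchAsyncCheckpoint` APIs imported in the diff; the checkpoint path and state dict are placeholders, and single-process (rank-0-style) usage is assumed:

import torch

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import abort_nvrx_checkpoint
from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# A persistent queue keeps a long-lived async worker process across save calls.
ckpt_impl = TorchAsyncCheckpoint(persistent_queue=True)
state_dict = {'weight': torch.ones(4)}  # placeholder state dict

# Schedule an async save, then block until it is finalized.
ckpt_impl.async_save(state_dict, '/tmp/ckpt/test')  # placeholder destination
ckpt_impl.finalize_async_save(blocking=True, no_dist=True)

# Abort the checkpoint workers, as an in-process restart would.
abort_nvrx_checkpoint()

# Resume: per the test above, a subsequent save restarts the async worker.
ckpt_impl.async_save(state_dict, '/tmp/ckpt/test')
ckpt_impl.finalize_async_save(blocking=True, no_dist=True)

ckpt_impl.close()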