Description
The CheckpointServer currently uses torch.save/torch.load, which requires materializing the entire serialized checkpoint as one buffer in memory. We want to use streaming transfers instead, to minimize the amount of CPU memory required.
It would also be nice to add checksums to these transfers so we can detect any data corruption introduced by the network.
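As a minimal sketch of the streaming-plus-checksum idea (not torchft's code), the helper below copies a stream in fixed-size chunks and folds each chunk into a running `zlib.crc32`. Peak memory therefore stays around one chunk no matter how large the checkpoint is. `copy_with_crc32` and its 1 MiB default chunk size are hypothetical choices made for illustration.

```python
import io
import zlib
from typing import BinaryIO


def copy_with_crc32(src: BinaryIO, dst: BinaryIO, chunk_size: int = 1 << 20) -> int:
    """Copy src to dst chunk by chunk and return the CRC32 of everything copied.

    Hypothetical helper: only one chunk is held at a time, so extra memory is
    roughly chunk_size rather than the total size of the stream.
    """
    crc = 0
    while True:
        chunk = src.read(chunk_size)
        if not chunk:  # EOF
            break
        dst.write(chunk)
        # zlib.crc32 takes the running value as its second argument, so the
        # checksum is built incrementally without buffering the whole stream.
        crc = zlib.crc32(chunk, crc)
    return crc


# Usage: push 256 KiB through an in-memory stand-in for a network stream.
payload = bytes(range(256)) * 1024
received = io.BytesIO()
checksum = copy_with_crc32(io.BytesIO(payload), received, chunk_size=4096)
```

On the receiving side, the same loop would recompute the CRC32 and compare it with a value the sender transmits after the data. A mismatch means the transfer should be rejected or retried. How that trailer is carried is left open here.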
Relevant existing code: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing.py#L72
The algorithm is described at: https://gist.github.com/d4l3k/b68094d649a076384967788c9b0a5f08
Existing tests: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing_test.py#L15
Overview of work:
- copy the write_state_dict and read_state_dict implementations into checkpointing.py
- replace the existing torch.save/torch.load calls with them
- add unit tests for write_state_dict/read_state_dict covering all the different kinds of torch tensors (different dtypes, strided views, storage offsets, scalars, nested structures, etc.)
- optionally add a checksum to write_state_dict/read_state_dict using zlib.crc32
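One possible shape for a streaming write/read pair with the optional zlib.crc32 checksum is sketched below, under loudly stated assumptions. Tensors are modeled as plain `bytes` payloads so the example runs without PyTorch. The wire format (a length-prefixed pickled skeleton, then each payload with a u64 length and a u32 CRC32 trailer) is invented for illustration and is not necessarily the algorithm in the linked gist. A real write_state_dict would also need to record each tensor's dtype, shape, stride and storage offset. `write_framed_state`, `read_framed_state` and `_Ref` are hypothetical names.

```python
import io
import pickle
import struct
import zlib
from dataclasses import dataclass
from typing import Any, BinaryIO, List

# Hypothetical wire format (illustration only):
#   u64 header_len | pickled (num_payloads, skeleton)
#   then per payload: u64 length | raw bytes | u32 CRC32 of those bytes
_U64 = struct.Struct("<Q")
_U32 = struct.Struct("<I")


@dataclass(frozen=True)
class _Ref:
    index: int  # position of the extracted payload in the stream


def _extract(obj: Any, payloads: List[Any]) -> Any:
    # bytes/bytearray leaves stand in for tensor storage in this sketch.
    if isinstance(obj, (bytes, bytearray)):
        payloads.append(obj)
        return _Ref(len(payloads) - 1)
    if isinstance(obj, dict):
        return {k: _extract(v, payloads) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_extract(v, payloads) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_extract(v, payloads) for v in obj)
    return obj  # ints, floats, strings, etc. go into the pickle as-is


def _restore(obj: Any, payloads: List[bytes]) -> Any:
    if isinstance(obj, _Ref):
        return payloads[obj.index]
    if isinstance(obj, dict):
        return {k: _restore(v, payloads) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_restore(v, payloads) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_restore(v, payloads) for v in obj)
    return obj


def _read_exact(f: BinaryIO, n: int) -> bytes:
    # Loop because a raw stream (e.g. a socket) may return short reads.
    parts, remaining = [], n
    while remaining:
        part = f.read(remaining)
        if not part:
            raise EOFError(f"stream ended with {remaining} bytes missing")
        parts.append(part)
        remaining -= len(part)
    return b"".join(parts)


def write_framed_state(obj: Any, f: BinaryIO) -> None:
    payloads: List[Any] = []
    skeleton = _extract(obj, payloads)
    header = pickle.dumps((len(payloads), skeleton))
    f.write(_U64.pack(len(header)))
    f.write(header)
    for data in payloads:
        # Payloads go straight to the stream, so the serialized form of the
        # whole state is never assembled into one buffer.
        f.write(_U64.pack(len(data)))
        f.write(data)
        f.write(_U32.pack(zlib.crc32(data)))


def read_framed_state(f: BinaryIO) -> Any:
    # pickle.loads is only safe on data from trusted peers.
    (header_len,) = _U64.unpack(_read_exact(f, _U64.size))
    count, skeleton = pickle.loads(_read_exact(f, header_len))
    payloads: List[bytes] = []
    for i in range(count):
        (length,) = _U64.unpack(_read_exact(f, _U64.size))
        data = _read_exact(f, length)
        (expected,) = _U32.unpack(_read_exact(f, _U32.size))
        if zlib.crc32(data) != expected:
            raise ValueError(f"CRC32 mismatch in payload {i}")
        payloads.append(data)
    return _restore(skeleton, payloads)


# Usage: round-trip a small nested "state dict" through an in-memory stream.
buf = io.BytesIO()
state = {"weight": b"\x00\x01" * 8, "opt": [b"abc", (3, b"xyz")], "step": 7}
write_framed_state(state, buf)
buf.seek(0)
restored = read_framed_state(buf)
```

Keeping the payload bytes out of the pickle keeps the header small and lets each payload be written or read directly. A per-payload checksum also pinpoints which payload was corrupted, rather than only reporting that the transfer failed.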