
[CheckpointServer] use streaming transfers #36

Open
@d4l3k

Description

The CheckpointServer currently uses torch.save/torch.load, which requires allocating the entire serialized checkpoint as a single buffer in memory. We want to use streaming transfers instead, so that no full serialized copy is held at once and the CPU memory required is minimized.
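A minimal sketch of the bounded-memory idea. It assumes the payload arrives on a file-like object, for example an HTTP response body. Instead of `read()`-ing the whole body into one buffer, it copies in fixed-size chunks, so at most `chunk_size` bytes are in flight at once. `copy_exact` and `CHUNK_SIZE` are hypothetical names, and the 1 MiB default is an arbitrary choice.

```python
import io
from typing import BinaryIO

CHUNK_SIZE = 1 << 20  # hypothetical default: 1 MiB per read


def copy_exact(src: BinaryIO, dst: BinaryIO, nbytes: int, chunk_size: int = CHUNK_SIZE) -> None:
    """Copy exactly `nbytes` from `src` to `dst`, holding at most `chunk_size`
    bytes in memory at a time instead of the whole payload."""
    remaining = nbytes
    while remaining > 0:
        # read() may return fewer bytes than requested (e.g. from a socket), so loop.
        chunk = src.read(min(chunk_size, remaining))
        if not chunk:
            raise EOFError(f"stream ended with {remaining} bytes still expected")
        dst.write(chunk)
        remaining -= len(chunk)


# Usage: relay a 10 kB payload using 4 KiB reads.
src = io.BytesIO(b"x" * 10_000)
dst = io.BytesIO()
copy_exact(src, dst, 10_000, chunk_size=4096)
assert dst.getvalue() == b"x" * 10_000
```

Peak memory here is bounded by `chunk_size` rather than by the checkpoint size. The destination could equally be a preallocated tensor buffer instead of another stream.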

It would also be nice to add checksums to these transfers to detect any data corruption introduced by the network.
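`zlib.crc32` accepts a running value as its second argument. This means a checksum can be computed chunk by chunk while data streams, without buffering the whole payload. A minimal sketch follows. `crc32_stream` is a hypothetical helper, and how the sender transmits the expected value (header, trailer, per tensor) is left open.

```python
import io
import zlib
from typing import BinaryIO


def crc32_stream(src: BinaryIO, chunk_size: int = 1 << 16) -> int:
    """CRC-32 of everything left in `src`, computed one chunk at a time."""
    crc = 0
    while True:
        chunk = src.read(chunk_size)
        if not chunk:
            return crc
        crc = zlib.crc32(chunk, crc)  # the running value carries across chunks


payload = bytes(range(256)) * 1000
expected = zlib.crc32(payload)  # e.g. computed by the sender
assert crc32_stream(io.BytesIO(payload), chunk_size=4096) == expected

corrupted = bytearray(payload)
corrupted[12345] ^= 0x01  # a single bit flipped in transit
assert crc32_stream(io.BytesIO(bytes(corrupted))) != expected
```

CRC-32 catches accidental corruption such as flipped bits. It does not protect against deliberate tampering; that would need a keyed hash such as `hmac`.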

Relevant existing code: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing.py#L72

The algorithm is described at: https://gist.github.com/d4l3k/b68094d649a076384967788c9b0a5f08
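The gist is not reproduced here. As a stand-in, the sketch below shows one common way to stream a state dict, using a wire format of our own choosing. It uses pickle's `persistent_id`/`persistent_load` hooks, similar to how torch.save handles storages, so that only a small metadata pickle is built, with each tensor replaced by an ID. The metadata is sent behind an 8-byte length header, and then each tensor's raw bytes are written straight from its buffer. On the receiving side, tensors are preallocated and filled in place.

To keep the sketch runnable with only the standard library, `array.array` stands in for `torch.Tensor`. `write_stream`, `read_stream`, `_TensorPickler` and `_TensorUnpickler` are hypothetical names, not torchft APIs.

```python
import array
import io
import pickle
import struct
from typing import Any, BinaryIO, Dict, List

# Assumption: array.array stands in for torch.Tensor so this runs on stdlib only.
Tensor = array.array


class _TensorPickler(pickle.Pickler):
    """Pickles everything except tensors, which become small persistent IDs."""

    def __init__(self, f: BinaryIO) -> None:
        super().__init__(f)
        self.tensors: List[Tensor] = []
        self._index_of: Dict[int, int] = {}  # id(tensor) -> index; shared tensors sent once

    def persistent_id(self, obj: Any) -> Any:
        if not isinstance(obj, Tensor):
            return None  # pickle this object normally
        if id(obj) not in self._index_of:
            self._index_of[id(obj)] = len(self.tensors)
            self.tensors.append(obj)
        return (self._index_of[id(obj)], obj.typecode, len(obj))


class _TensorUnpickler(pickle.Unpickler):
    """Recreates each tensor as a preallocated buffer to be filled afterwards."""

    def __init__(self, f: BinaryIO) -> None:
        super().__init__(f)
        self.tensors: Dict[int, Tensor] = {}

    def persistent_load(self, pid: Any) -> Tensor:
        index, typecode, count = pid
        if index not in self.tensors:
            self.tensors[index] = array.array(typecode, [0]) * count
        return self.tensors[index]


def write_stream(state_dict: Any, f: BinaryIO) -> None:
    meta = io.BytesIO()
    pickler = _TensorPickler(meta)
    pickler.dump(state_dict)  # small: tensor bytes are not included
    f.write(struct.pack("<Q", len(meta.getvalue())))  # 8-byte metadata length
    f.write(meta.getvalue())
    for t in pickler.tensors:
        # Write directly from the tensor's own buffer; no serialized copy is built.
        f.write(memoryview(t).cast("B"))


def _read_exact_into(f: BinaryIO, view: memoryview) -> None:
    """Fill `view` completely, tolerating short reads (e.g. from sockets)."""
    pos = 0
    while pos < len(view):
        n = f.readinto(view[pos:])
        if not n:
            raise EOFError("stream ended early")
        pos += n


def read_stream(f: BinaryIO) -> Any:
    header = bytearray(8)
    _read_exact_into(f, memoryview(header))
    (meta_len,) = struct.unpack("<Q", header)
    meta = bytearray(meta_len)
    _read_exact_into(f, memoryview(meta))
    unpickler = _TensorUnpickler(io.BytesIO(meta))  # unpickling: trusted peers only
    state_dict = unpickler.load()
    for index in range(len(unpickler.tensors)):
        # Payloads arrive in index order; fill each preallocated tensor in place.
        _read_exact_into(f, memoryview(unpickler.tensors[index]).cast("B"))
    return state_dict
```

With real tensors, the persistent ID would presumably also need dtype, shape, strides, storage offset and device. The raw bytes would then come from the tensor's storage rather than from an `array.array`. Like torch.load without `weights_only`, unpickling the metadata is only safe between trusted peers.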

Existing tests: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing_test.py#L15

Overview of work:

  1. copy the write_state_dict and read_state_dict implementations into checkpointing.py
  2. replace the existing torch.save/torch.load calls with them
  3. add unit tests for write_state_dict/read_state_dict covering all the different kinds of torch tensors (different data types, strided tensors, storage offsets, scalars, nested structures, etc.)
  4. optionally add a zlib.crc32 checksum to read_state_dict/write_state_dict
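Item 3 calls out strided tensors and storage offsets. One subtlety such tests should catch is that a view's logical elements differ from its underlying buffer, so naively writing the raw buffer is wrong. The sketch below models this with 1-D `memoryview` slices over an `array.array`, which stand in for tensor views. It checks a small matrix of dtypes × layouts. `roundtrip_view` and `check_all_cases` are hypothetical helpers, not the real tests.

```python
import array


def roundtrip_view(view: memoryview) -> list:
    """Hypothetical encoder/decoder: record the element format, write only the
    view's logical elements as contiguous bytes, then decode them back."""
    fmt = view.format         # element type code, e.g. "d" for float64
    payload = view.tobytes()  # gathers logical elements in order, honoring offset/stride
    out = array.array(fmt)
    out.frombytes(payload)
    return out.tolist()


def check_all_cases() -> int:
    """Round-trip every (dtype, view layout) pair; return how many were checked."""
    checked = 0
    for typecode in ("b", "i", "q", "f", "d"):  # a few integer and float dtypes
        base = array.array(typecode, range(10))
        layouts = {
            "contiguous": memoryview(base),
            "offset": memoryview(base)[3:],     # starts 3 elements into the buffer
            "strided": memoryview(base)[1::3],  # offset 1, every 3rd element
            "reversed": memoryview(base)[::-1], # negative stride
            "empty": memoryview(base)[5:5],
        }
        for name, view in layouts.items():
            assert roundtrip_view(view) == view.tolist(), (typecode, name)
            checked += 1
    return checked


strided = memoryview(array.array("d", range(10)))[1::3]
assert not strided.c_contiguous  # raw buffer != logical contents
assert roundtrip_view(strided) == [1.0, 4.0, 7.0]
```

Gathering logical elements is simple, but it duplicates data when several views share one storage. torch.save instead preserves view relationships, so an implementation might record storage offset and strides and write each storage once. Either way, tests like these should pin down which behavior is intended.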
