checkpointing/PGTransport: add NCCL/Gloo transport for checkpoints #110

Merged · 1 commit · Feb 14, 2025
43 changes: 32 additions & 11 deletions torchft/checkpointing/http_transport_test.py
@@ -6,26 +6,23 @@

import urllib.error
from datetime import timedelta
from typing import Any, Dict
from unittest import TestCase
from typing import Dict
from unittest import TestCase, skipUnless
from unittest.mock import MagicMock

import torch
from parameterized import parameterized

from torchft.checkpointing.http_transport import HTTPTransport
from torchft.checkpointing.http_transport_bench import main as bench_main
from torchft.checkpointing.transport import CheckpointTransport
from torchft.checkpointing.transport_test import (
assertStateDictEqual,
run_multi_recovery_test,
)


class TestHTTPTransport(TestCase):
def assertStateDictEqual(self, a: Dict[str, object], b: Dict[str, object]) -> None:
for k, v1 in a.items():
v2 = b[k]
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
torch.testing.assert_close(v1.cpu(), v2.cpu())
else:
self.assertEqual(v1, v2)

@parameterized.expand(
[
("no chunks", 0),
@@ -59,7 +56,7 @@ def test_checkpoint_server(self, name: str, num_chunks: int) -> None:
out = server.recv_checkpoint(
src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
)
self.assertStateDictEqual(out, expected)
assertStateDictEqual(self, out, expected)

# test timeout
with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"):
@@ -114,6 +111,30 @@ def test_checkpoint_server_locking(self) -> None:

server.shutdown()

def test_multi_http_transport_cpu(self) -> None:
device = torch.device("cpu")

def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
return HTTPTransport(
timeout=timedelta(seconds=10),
num_chunks=0,
)

run_multi_recovery_test(self, init, device=device)

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
@skipUnless(torch.cuda.is_available(), "CUDA is not available")
def test_multi_http_transport_cuda(self) -> None:
device = torch.device("cuda")

def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
return HTTPTransport(
timeout=timedelta(seconds=10),
num_chunks=0,
)

run_multi_recovery_test(self, init, device=device)

def test_benchmark(self) -> None:
bench_main(
[
247 changes: 247 additions & 0 deletions torchft/checkpointing/pg_transport.py
@@ -0,0 +1,247 @@
import logging
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Generator, List, Tuple, TypeVar, Union, cast

import torch
from torch.distributed import Work
from torch.distributed.tensor import DTensor, _DTensorSpec
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

from torchft.checkpointing.transport import CheckpointTransport
from torchft.process_group import ProcessGroup

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class _TensorMeta:
"""
This is the metadata for a tensor that is used to transfer checkpoints.
It contains the shape, the dtype, the storage offset and the stride of the
tensor.

This must be pickleable so that it can be sent over the wire.
"""

shape: torch.Size
dtype: torch.dtype
storage_offset: int
stride: Tuple[int, ...]
nbytes: int
Review comment (Member): Not important now, but I wonder if we need to store quantization information, and I'm also wondering how that's handled in DTensor, if you know.

Reply (Member Author): @H-Huang do you know how quantized information is stored? Is it a different tensor subclass or just packed into the storage?



@dataclass
class _DTensorMeta:
"""
This is the metadata for a DTensor that is used to transfer checkpoints.
It contains the metadata for the local tensor and the spec of the DTensor.

This must be pickleable so that it can be sent over the wire.
"""

local: _TensorMeta
spec: _DTensorSpec


@dataclass
class _StateDictMeta:
Review comment (H-Huang, Member, Feb 14, 2025): I was thinking the _StateDictMeta and dataclasses could probably use some more comments, since that's pretty important in determining how we serialize/deserialize and being able to update them. I'm also curious how DCP handles this metadata when it transfers, and whether we have an existing structure we can use? @fegin @LucasLLC do you know?

Reply (Member Author): Added comments and updated field names to make it clearer.

Reply (Contributor): @H-Huang similar to what it does here. Some logic is the same, and we can see if we want to unify the code in the future.

"""
This is the metadata for a state dict that is used to transfer checkpoints.
It contains the step, the pytree spec of the state dict and the metadata for
each tensor in the state dict.

This must be pickleable so that it can be sent over the wire.

Args:
step: the step of the checkpoint to verify consistency
treespec: the pytree spec of the state dict
non_tensor_leaves: the metadata for each tensor in the state dict and any
non-tensor leaves in the state dict
"""

step: int
treespec: TreeSpec
non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
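
As a rough sketch of the pytree machinery that _StateDictMeta and _prepare_state_dict rely on (editorial example, not part of the diff; the toy state_dict below is hypothetical):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Toy state dict with one tensor leaf and one non-tensor leaf.
state_dict = {"w": torch.zeros(2, 2), "lr": 0.01}

# tree_flatten splits the structure into a flat list of leaves plus a TreeSpec
# describing the containers; _prepare_state_dict replaces each tensor leaf with
# a _TensorMeta and sends the raw tensor bytes separately over the process group.
leaves, treespec = tree_flatten(state_dict)

# The receiver rebuilds the original structure once it has all the leaves back,
# mirroring what recv_checkpoint does with tree_unflatten.
restored = tree_unflatten(leaves, treespec)
assert set(restored.keys()) == set(state_dict.keys())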


@contextmanager
def _timeit(name: str) -> Generator[None, None, None]:
start = time.perf_counter()
yield
dur = time.perf_counter() - start
logger.info(f"{name} took {dur}s")


def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
return (
_cast_tensor(tensor, torch.uint8),
Review comment (Contributor): IIUC, we're casting to uint8 to reduce memory pressure / speed up the transfer, but should we be concerned about any precision loss? I see that transport_test.py verifies closeness/correctness through run_multi_recovery_test, but it isn't making sense to me!

Reply (daulet-askarov, Feb 13, 2025): This just reinterprets the tensor as a bunch of bytes (hence the uint8) backed by the same UntypedStorage range of bytes. No bytes are modified, so there is no loss of precision: https://pytorch.org/docs/stable/storage.html#untyped-storage-api. @d4l3k I presume if you don't do this cast and just pass the original tensor object with its original dtype, then you do lose precision? Or is it just inconvenient on the recv side to interpret the tensor with its original dtype right away?

Reply (d4l3k, Member Author, Feb 14, 2025): In most cases it doesn't matter and should result in a byte-identical output. There are a few trade-offs:

  • Pro: not all tensor types are supported by NCCL (torch.uint16, for instance, can't be sent via NCCL), so casting to uint8 allows us to support any dtype.
  • Con: arguably it's better to use the non-storage option to avoid sending duplicate/extra bytes for strided/offset tensors. If you have two tensors sharing the same underlying storage, or a strided tensor, this implementation ends up sending twice as much data.

Reply (Member): Can we just do tensor.view(torch.uint8) instead?

Reply (d4l3k, Member Author): .view doesn't work for strided tensors -- I'm not sure we need to support those, but I think I'll leave it as is for now.

_TensorMeta(
shape=tensor.shape,
dtype=tensor.dtype,
storage_offset=cast(int, tensor.storage_offset()),
stride=tensor.stride(),
nbytes=tensor.untyped_storage().nbytes(),
),
)


def _prepare_state_dict(
state_dict: object,
step: int,
device: torch.device,
) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
leaves, treespec = tree_flatten(state_dict)

non_tensor_leaves = []
tensors = []
for v in leaves:
if isinstance(v, DTensor):
tensor, tensor_meta = _prepare_tensor(v._local_tensor)

tensors.append(tensor)

non_tensor_leaves.append(
_DTensorMeta(
local=tensor_meta,
spec=v._spec,
)
)
elif isinstance(v, torch.Tensor):
tensor, tensor_meta = _prepare_tensor(v)
tensors.append(tensor)
non_tensor_leaves.append(tensor_meta)
else:
non_tensor_leaves.append(v)

return (
_StateDictMeta(
step=step,
treespec=treespec,
non_tensor_leaves=non_tensor_leaves,
),
tensors,
)


def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Casts the underlying storage to a tensor of the given dtype.

The returned tensor will be of size ``storage.nbytes``.

This works for all datatypes and supports strided/offset tensors with the
caveat that the cast tensor may be larger than the original tensor due to
the differences in striding.
"""
storage = tensor.untyped_storage()
ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
assert ret.untyped_storage() is storage, "storage should be the same"
return ret
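
To make the precision question from the review thread on _prepare_tensor concrete, here is a minimal sketch (editorial, not part of the diff; assumes a recent PyTorch) showing that the uint8 cast is a pure reinterpretation of the underlying bytes and round-trips losslessly:

import torch

x = torch.randn(4, 5, dtype=torch.bfloat16)

# Reinterpret the same UntypedStorage as raw bytes; no values are converted.
raw = torch.tensor(x.untyped_storage(), dtype=torch.uint8, device=x.device)

# Rebuild the original view from the byte buffer plus the saved metadata,
# mirroring what recv_checkpoint does with _TensorMeta.
y = torch.as_strided(
    raw.view(torch.bfloat16),
    size=x.shape,
    stride=x.stride(),
    storage_offset=x.storage_offset(),
)

assert torch.equal(x, y)  # byte-identical round trip, so no precision loss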


class PGTransport(CheckpointTransport[T]):
"""
This is a checkpoint transport that uses the process group to transfer checkpoints.
This allows for fast recovery of workers by fetching the current weights
from an existing worker.

Args:
    pg: the process group used to send and receive checkpoints
    timeout: the timeout for checkpoint transfer operations
    device: the device used to stage tensors during the transfer
"""

def __init__(
self, pg: ProcessGroup, timeout: timedelta, device: torch.device
) -> None:
self._work: List[Work] = []
self._pg = pg
self._timeout = timeout
self._device = device

def metadata(self) -> str:
return "<n/a>"

def disallow_checkpoint(self) -> None:
Review comment (Contributor): Should this be implemented?

Reply (Member Author): We don't support any async/out-of-band calls via PG, so nothing needs to be done here. For HTTP we need this to avoid serving a checkpoint during the optimizer step, but since send is synchronous in PG we don't need any additional synchronization.

pass

def send_checkpoint(
self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
) -> None:
with _timeit("preparing state_dict"):
meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)

work = []

with _timeit("send pickle"):
Review comment (H-Huang, Member, Feb 14, 2025): So we have send_object_list and recv_object_list (https://github.com/pytorch/pytorch/blob/8b5ee275fb455156a944445fb92c43731369ace3/torch/distributed/distributed_c10d.py#L3181), which is what we use in PP to exchange shape metadata between stages to preallocate the recv buffers. It is pretty similar, since it pickles, then sends the object sizes, then the object data. I think yours may be more efficient since there are only 2 additional sends of metadata and the rest are the actual data. But I wanted to flag it in case we want to consolidate some logic!

Reply (d4l3k, Member Author, Feb 14, 2025): Thanks! Yeah, I missed that before -- good to know, though right now I think this implementation is a bit more performant since it sends the same data to multiple receivers. If we wrap the underlying PG we should also be able to use the broadcast_object_list variant, which should give us the best of both worlds. I'm planning a follow-up PR, since to make a subworld we need some underlying improvements in how we calculate the recovering workers.

buf = pickle.dumps(meta)
len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device)
buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device)
for dst_rank in dst_ranks:
work.append(self._pg.send([len_t], dst_rank, tag=1))
work.append(self._pg.send([buf_t], dst_rank, tag=2))

with _timeit("send tensors"):
for i, t in enumerate(tensors):
t = t.to(self._device)
for dst_rank in dst_ranks:
work.append(self._pg.send([t], dst_rank, tag=3 + i))

# allow 3 concurrent transfers at a time to avoid OOMs
while len(work) > (3 * len(dst_ranks)):
Review comment (Contributor): Is this used to avoid OOM?

Reply (Member Author): Yes, it's to avoid OOM when transferring between devices.

work.pop(0).wait(timeout)

for w in work:
w.wait(timeout)

def recv_checkpoint(
self, src_rank: int, metadata: str, step: int, timeout: timedelta
) -> T:
len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
length = cast(int, len_t.item())

assert length > 0, f"invalid metadata length {length=}"

buf = torch.empty(length, dtype=torch.uint8, device=self._device)
self._pg.recv([buf], src_rank, tag=2).wait(timeout)

meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes())
assert meta.step == step

i: int = 0

def recv(v: _TensorMeta) -> torch.Tensor:
nonlocal i

t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
# TODO: parallelize receives
self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
i += 1

# TODO: allow in place receives to avoid having to copy to cpu to
# avoid OOMs
Review comment (Contributor): We should be able to avoid transferring _TensorMeta and _DTensorMeta, and avoid to(cpu), if we first call state_dict() to get the state_dict structure and then traverse the state_dict and send/recv the tensors directly.

Reply (d4l3k, Member Author, Feb 14, 2025): Yeah, we totally can do that in most cases -- I wanted to make that refactor in a follow-up PR since I still need to figure out how to do it, and this may be a decent fallback if we have some weird state dict.

t = t.cpu()

return torch.as_strided(
t.view(v.dtype),
size=v.shape,
stride=v.stride,
storage_offset=v.storage_offset,
)

values = []
for v in meta.non_tensor_leaves:
if isinstance(v, _TensorMeta):
values.append(recv(v))
elif isinstance(v, _DTensorMeta):
tensor = recv(v.local)
# pyre-fixme[29]: DTensor is not a function
values.append(DTensor(tensor, v.spec, requires_grad=False))
else:
values.append(v)

return tree_unflatten(values, meta.treespec)
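
For orientation, a rough usage sketch of the new transport (editorial, not part of the diff). It assumes two ranks that already share a configured torchft ProcessGroup; the variables pg and rank are placeholders, and the toy state dict is hypothetical. See pg_transport_test.py below for a complete ProcessGroupGloo/ProcessGroupBabyNCCL setup.

from datetime import timedelta
from typing import Dict

import torch

from torchft.checkpointing.pg_transport import PGTransport

# Both ranks construct the transport over the same process group `pg`;
# `rank` is this process's rank within that group (placeholders here).
transport = PGTransport[Dict[str, object]](
    pg, timeout=timedelta(seconds=10), device=torch.device("cpu")
)

state_dict = {"weight": torch.randn(4, 4), "step_count": 7}

if rank == 0:
    # Healthy worker: push the live weights to the recovering rank 1.
    transport.send_checkpoint(
        dst_ranks=[1], step=1, state_dict=state_dict, timeout=timedelta(seconds=30)
    )
else:
    # Recovering worker: metadata is unused by PGTransport (metadata() returns "<n/a>").
    recovered = transport.recv_checkpoint(
        src_rank=0, metadata="<n/a>", step=1, timeout=timedelta(seconds=30)
    )
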
57 changes: 57 additions & 0 deletions torchft/checkpointing/pg_transport_test.py
@@ -0,0 +1,57 @@
from datetime import timedelta
from typing import Dict
from unittest import TestCase, skipUnless

import torch
from torch.distributed import TCPStore

from torchft.checkpointing.pg_transport import PGTransport
from torchft.checkpointing.transport import CheckpointTransport
from torchft.checkpointing.transport_test import run_multi_recovery_test
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo


class PGTransportTest(TestCase):
def test_pg_transport_gloo(self) -> None:
store: TCPStore = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
device: torch.device = torch.device("cpu")

def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
pg = ProcessGroupGloo()
pg.configure(
store_addr=f"localhost:{store.port}/prefix",
rank=rank,
world_size=world_size,
)

return PGTransport[Dict[str, object]](
pg, timeout=timedelta(seconds=10), device=device
)

run_multi_recovery_test(self, init, device=device)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
def test_pg_transport_baby_nccl(self) -> None:
store: TCPStore = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
device: torch.device = torch.device("cuda")

def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
torch.cuda.set_device(rank)

pg = ProcessGroupBabyNCCL()
pg.configure(
store_addr=f"localhost:{store.port}/prefix",
rank=rank,
world_size=world_size,
)

return PGTransport[Dict[str, object]](
pg, timeout=timedelta(seconds=10), device=device
)

run_multi_recovery_test(self, init, device=device)