PGTransport: add inplace transport (#119)
d4l3k authored Feb 25, 2025
1 parent 5e65330 commit 6fe4c8e
Showing 6 changed files with 253 additions and 81 deletions.
108 changes: 81 additions & 27 deletions torchft/checkpointing/pg_transport.py
@@ -4,12 +4,17 @@
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Generator, List, Tuple, TypeVar, Union, cast
from typing import Callable, Generator, Optional, TypeVar, Union, cast

import torch
from torch.distributed import Work
from torch.distributed.tensor import DTensor, _DTensorSpec
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
from torch.utils._pytree import (
KeyPath,
TreeSpec,
tree_flatten_with_path,
tree_unflatten,
)

from torchft.checkpointing.transport import CheckpointTransport
from torchft.process_group import ProcessGroup
@@ -32,7 +37,7 @@ class _TensorMeta:
shape: torch.Size
dtype: torch.dtype
storage_offset: int
stride: Tuple[int, ...]
stride: tuple[int, ...]
nbytes: int


@@ -61,13 +66,15 @@ class _StateDictMeta:
Args:
step: the step of the checkpoint to verify consistency
treespec: the pytree spec of the state dict
paths: the path of each leaf in the state dict
non_tensor_leaves: the metadata for each tensor in the state dict and any
non-tensor leaves in the state dict
"""

step: int
treespec: TreeSpec
non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
paths: list[KeyPath]
non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]


@contextmanager
@@ -78,7 +85,7 @@ def _timeit(name: str) -> Generator[None, None, None]:
logger.info(f"{name} took {dur}s")


def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
return (
_cast_tensor(tensor, torch.uint8),
_TensorMeta(
@@ -95,12 +102,16 @@ def _prepare_state_dict(
state_dict: object,
step: int,
device: torch.device,
) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
leaves, treespec = tree_flatten(state_dict)
) -> tuple[_StateDictMeta, list[torch.Tensor]]:
leaves: list[tuple[KeyPath, object]]
leaves, treespec = tree_flatten_with_path(state_dict)

paths: list[KeyPath] = []
non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = []
tensors: list[torch.Tensor] = []
for key_path, v in leaves:
paths.append(key_path)

non_tensor_leaves = []
tensors = []
for v in leaves:
if isinstance(v, DTensor):
tensor, tensor_meta = _prepare_tensor(v._local_tensor)

@@ -123,6 +134,7 @@ def _prepare_state_dict(
_StateDictMeta(
step=step,
treespec=treespec,
paths=paths,
non_tensor_leaves=non_tensor_leaves,
),
tensors,
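A minimal sketch (not part of the diff) of what tree_flatten_with_path returns, using a toy state dict as an assumption; this is why _prepare_state_dict can record a KeyPath for every leaf and recv_checkpoint can later look up the matching destination tensor by the same path:

import torch
from torch.utils._pytree import tree_flatten_with_path

# toy state dict for illustration only
leaves, spec = tree_flatten_with_path({"w": torch.zeros(2), "opt": {"lr": 0.1}})
# each entry pairs a key path (a tuple of key entries identifying a location
# such as state_dict["opt"]["lr"]) with the leaf value at that location; the
# key paths are what _StateDictMeta.paths records on the sender side
for key_path, leaf in leaves:
    print(key_path, leaf)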
@@ -139,6 +151,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
caveat that the cast tensor may be larger than the original tensor due to
the differences in striding.
"""
assert (
type(tensor) is torch.Tensor
), f"can only cast standard tensors not {type(tensor)}"
storage = tensor.untyped_storage()
ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
assert ret.untyped_storage() is storage, "storage should be the same"
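An illustrative sketch (not part of the diff) of the caveat above: _cast_tensor views the whole underlying storage, so for a non-contiguous view the resulting byte tensor can be larger than tensor.nbytes:

import torch
from torchft.checkpointing.pg_transport import _cast_tensor

base = torch.arange(16, dtype=torch.float32)  # 64 bytes of storage
view = base[::2]                              # 8 elements, 32 bytes
raw = _cast_tensor(view, torch.uint8)
# the cast covers the full storage, not just the view's 32 bytes
assert raw.nbytes == base.untyped_storage().nbytes() == 64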
@@ -150,17 +165,28 @@ class PGTransport(CheckpointTransport[T]):
This is a checkpoint transport that uses the process group to transfer checkpoints.
This allows for fast recovery of workers by fetching the current weights
from an existing worker.
Args:
state_dict: a callable that returns the state dict to be transferred
pg: the process group to use for communication
timeout: the timeout for communication
device: the device to use for tensors
state_dict: if specified, this function will be called to do an in-place
receive into the returned state_dict. This is much faster than
having to allocate new tensors and transfer them to the CPU.
"""

def __init__(
self, pg: ProcessGroup, timeout: timedelta, device: torch.device
self,
pg: ProcessGroup,
timeout: timedelta,
device: torch.device,
state_dict: Optional[Callable[[], object]] = None,
) -> None:
self._work: List[Work] = []
self._work: list[Work] = []
self._pg = pg
self._timeout = timeout
self._device = device
self._state_dict = state_dict

def metadata(self) -> str:
return "<n/a>"
@@ -169,7 +195,7 @@ def disallow_checkpoint(self) -> None:
pass

def send_checkpoint(
self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
) -> None:
with _timeit("preparing state_dict"):
meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
@@ -186,20 +212,29 @@ def send_checkpoint(

with _timeit("send tensors"):
for i, t in enumerate(tensors):
original_device = t.device
t = t.to(self._device)
for dst_rank in dst_ranks:
work.append(self._pg.send([t], dst_rank, tag=3 + i))

# allow 3 concurrent transfers at a time to avoid OOMs
while len(work) > (3 * len(dst_ranks)):
work.pop(0).wait(timeout)
# if we did a copy we should wait for the work to complete so we
# can free the memory to avoid OOMs
if original_device == torch.device("cpu"):
for w in work:
w.wait(timeout)
work = []

for w in work:
w.wait(timeout)

def recv_checkpoint(
self, src_rank: int, metadata: str, step: int, timeout: timedelta
) -> T:
state_dict = self._state_dict() if self._state_dict else {}
state_dict_leaves, _ = tree_flatten_with_path(state_dict)

dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)

len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
length = cast(int, len_t.item())
@@ -213,18 +248,34 @@ def recv_checkpoint(
assert meta.step == step

i: int = 0
works: list[Work] = []

def recv(v: _TensorMeta) -> torch.Tensor:
def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
nonlocal i

t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
# TODO: parallelize receives
self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
inplace = dst_tensors.get(path)
if (
isinstance(inplace, torch.Tensor)
and inplace.device.type == self._device.type
):
if isinstance(inplace, DTensor):
inplace = inplace._local_tensor
t = _cast_tensor(inplace, torch.uint8)
assert (
t.nbytes == v.nbytes
), "inplace tensor storage must be the same size"
else:
t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)

work = self._pg.recv([t], src_rank, tag=3 + i)
i += 1

# TODO: allow in place receives to avoid having to copy to cpu to
# avoid OOMs
t = t.cpu()
if inplace is None:
# if not inplace we need to copy it to CPU to avoid OOMing
work.wait(timeout)
t = t.cpu()
else:
works.append(work)

return torch.as_strided(
t.view(v.dtype),
Expand All @@ -234,14 +285,17 @@ def recv(v: _TensorMeta) -> torch.Tensor:
)

values = []
for v in meta.non_tensor_leaves:
for path, v in zip(meta.paths, meta.non_tensor_leaves):
if isinstance(v, _TensorMeta):
values.append(recv(v))
values.append(recv(path, v))
elif isinstance(v, _DTensorMeta):
tensor = recv(v.local)
tensor = recv(path, v.local)
# pyre-fixme[29]: DTensor is not a function
values.append(DTensor(tensor, v.spec, requires_grad=False))
else:
values.append(v)

for work in works:
work.wait(timeout)

return tree_unflatten(values, meta.treespec)
93 changes: 93 additions & 0 deletions torchft/checkpointing/pg_transport_bench.py
@@ -0,0 +1,93 @@
import logging
import sys
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

import torch
import torch.distributed as dist

from torchft.checkpointing.pg_transport import PGTransport, _timeit
from torchft.process_group import ProcessGroupBabyNCCL

logger: logging.Logger = logging.getLogger(__name__)


def main(argv: list[str]) -> None:
import argparse

logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--inplace", action="store_true")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--chunk-size", type=int, default=3_000_000) # 3MB
parser.add_argument("--total-size", type=int, default=12_000_000_000) # 12GB
args = parser.parse_args(argv)

CHUNK_SIZE: int = args.chunk_size
TOTAL_SIZE: int = args.total_size
INPLACE: bool = args.inplace
DEVICE: str = args.device

timeout: timedelta = timedelta(seconds=10)

store = dist.TCPStore(
"localhost",
0,
is_master=True,
timeout=timeout,
wait_for_workers=False,
)
store_addr: str = f"localhost:{store.port}"

def run(rank: int) -> None:
torch.cuda.set_device(rank)

device = torch.device(DEVICE)

with _timeit("init_pg"):
pg = ProcessGroupBabyNCCL(timeout=timeout)
pg.configure(store_addr=store_addr, rank=rank, world_size=2)

t = torch.zeros(10, device=device, dtype=torch.float32)
pg.allreduce([t], dist.ReduceOp.SUM).wait(timeout=timeout)

with _timeit("create state_dict"):
state_dict: dict[str, torch.Tensor] = {}
for i in range(0, TOTAL_SIZE, CHUNK_SIZE):
state_dict[f"chunk/{i}"] = torch.zeros(
CHUNK_SIZE // 4, dtype=torch.float32, device=device
)

def get_state_dict() -> object:
return state_dict

transport = PGTransport(
pg=pg,
timeout=timeout,
device=device,
state_dict=get_state_dict if INPLACE else None,
)
metadata = transport.metadata()

if rank == 0:
with _timeit("send_checkpoint"):
transport.send_checkpoint(
dst_ranks=[1],
step=1,
state_dict=state_dict,
timeout=timedelta(seconds=60),
)
elif rank == 1:
with _timeit("recv_checkpoint"):
transport.recv_checkpoint(
src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=60)
)

with ThreadPoolExecutor(max_workers=2) as executor:
results = executor.map(run, range(2))
list(results)


if __name__ == "__main__":
main(sys.argv[1:])
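One possible way to exercise the benchmark above (an assumption, not something documented in this commit): it spawns both ranks as threads in a single process, so no external launcher is needed, but two visible CUDA devices are required when the device is cuda.

from torchft.checkpointing.pg_transport_bench import main

# smaller total size than the 12GB default to keep the run short;
# flag names follow the argparse definitions above
main(["--inplace", "--device", "cuda", "--total-size", "1000000000"])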