PGTransport: add inplace transport (3x faster) #119

Merged · 1 commit · Feb 25, 2025
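Summary: recv_checkpoint can now receive directly into an existing state dict instead of allocating fresh tensors and staging them through the CPU, which is roughly 3x faster. A quick usage sketch based on the constructor change in this PR (pg is a placeholder for an already-configured torchft ProcessGroup; see the bundled benchmark below for a complete runnable example):

    import torch
    from datetime import timedelta

    from torchft.checkpointing.pg_transport import PGTransport

    # Preallocated tensors to receive into. Passing `state_dict` enables the
    # inplace path; omitting it keeps the old allocate-and-copy behavior.
    state_dict = {"w": torch.zeros(1024, device="cuda")}

    transport = PGTransport(
        pg=pg,  # placeholder: an already-configured torchft ProcessGroup
        timeout=timedelta(seconds=10),
        device=torch.device("cuda"),
        state_dict=lambda: state_dict,
    )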
108 changes: 81 additions & 27 deletions torchft/checkpointing/pg_transport.py
@@ -4,12 +4,17 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Generator, List, Tuple, TypeVar, Union, cast
+from typing import Callable, Generator, Optional, TypeVar, Union, cast

 import torch
 from torch.distributed import Work
 from torch.distributed.tensor import DTensor, _DTensorSpec
-from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
+from torch.utils._pytree import (
+    KeyPath,
+    TreeSpec,
+    tree_flatten_with_path,
+    tree_unflatten,
+)

 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.process_group import ProcessGroup
@@ -32,7 +37,7 @@ class _TensorMeta:
     shape: torch.Size
     dtype: torch.dtype
     storage_offset: int
-    stride: Tuple[int, ...]
+    stride: tuple[int, ...]
     nbytes: int

@@ -61,13 +66,15 @@ class _StateDictMeta:
     Args:
         step: the step of the checkpoint to verify consistency
         treespec: the pytree spec of the state dict
+        paths: the path of each leaf in the state dict
         non_tensor_leaves: the metadata for each tensor in the state dict and any
             non-tensor leaves in the state dict
     """

     step: int
     treespec: TreeSpec
-    non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
+    paths: list[KeyPath]
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]


 @contextmanager
@@ -78,7 +85,7 @@ def _timeit(name: str) -> Generator[None, None, None]:
     logger.info(f"{name} took {dur}s")


-def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
+def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
     return (
         _cast_tensor(tensor, torch.uint8),
         _TensorMeta(
@@ -95,12 +102,16 @@ def _prepare_state_dict(
     state_dict: object,
     step: int,
     device: torch.device,
-) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
-    leaves, treespec = tree_flatten(state_dict)
+) -> tuple[_StateDictMeta, list[torch.Tensor]]:
+    leaves: list[tuple[KeyPath, object]]
+    leaves, treespec = tree_flatten_with_path(state_dict)

-    non_tensor_leaves = []
-    tensors = []
-    for v in leaves:
+    paths: list[KeyPath] = []
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = []
+    tensors: list[torch.Tensor] = []
+    for key_path, v in leaves:
+        paths.append(key_path)
+
         if isinstance(v, DTensor):
             tensor, tensor_meta = _prepare_tensor(v._local_tensor)

@@ -123,6 +134,7 @@ def _prepare_state_dict(
         _StateDictMeta(
             step=step,
             treespec=treespec,
+            paths=paths,
             non_tensor_leaves=non_tensor_leaves,
         ),
         tensors,
@@ -139,6 +151,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     caveat that the cast tensor may be larger than the original tensor due to
     the differences in striding.
     """
+    assert (
+        type(tensor) is torch.Tensor
+    ), f"can only cast standard tensors not {type(tensor)}"
     storage = tensor.untyped_storage()
     ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
     assert ret.untyped_storage() is storage, "storage should be the same"
@@ -150,17 +165,28 @@ class PGTransport(CheckpointTransport[T]):
     This is a checkpoint transport that uses the process group to transfer checkpoints.
     This allows for fast recovery of workers by fetching the current weights
     from an existing worker.

     Args:
-        state_dict: a callable that returns the state dict to be transferred
+        pg: the process group to use for communication
+        timeout: the timeout for communication
+        device: the device to use for tensors
+        state_dict: if specified this function will be called to do an inplace
+            receive into the returned state_dict. This is much faster than
+            having to allocate new tensors and transferring them to the CPU.
     """

     def __init__(
-        self, pg: ProcessGroup, timeout: timedelta, device: torch.device
+        self,
+        pg: ProcessGroup,
+        timeout: timedelta,
+        device: torch.device,
+        state_dict: Optional[Callable[[], object]] = None,
     ) -> None:
-        self._work: List[Work] = []
+        self._work: list[Work] = []
         self._pg = pg
         self._timeout = timeout
         self._device = device
+        self._state_dict = state_dict

     def metadata(self) -> str:
         return "<n/a>"
@@ -169,7 +195,7 @@ def disallow_checkpoint(self) -> None:
         pass

     def send_checkpoint(
-        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+        self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
     ) -> None:
         with _timeit("preparing state_dict"):
             meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
@@ -186,20 +212,29 @@ def send_checkpoint(

         with _timeit("send tensors"):
             for i, t in enumerate(tensors):
+                original_device = t.device
                 t = t.to(self._device)
                 for dst_rank in dst_ranks:
                     work.append(self._pg.send([t], dst_rank, tag=3 + i))

-                # allow 3 concurrent transfers at a time to avoid OOMs
-                while len(work) > (3 * len(dst_ranks)):
-                    work.pop(0).wait(timeout)
+                # if we did a copy we should wait for the work to complete so we
+                # can free the memory to avoid OOMs
+                if original_device == torch.device("cpu"):
+                    for w in work:
+                        w.wait(timeout)
+                    work = []

         for w in work:
             w.wait(timeout)

     def recv_checkpoint(
         self, src_rank: int, metadata: str, step: int, timeout: timedelta
     ) -> T:
+        state_dict = self._state_dict() if self._state_dict else {}
+        state_dict_leaves, _ = tree_flatten_with_path(state_dict)
Inline review comment on the line above:

Contributor: Is tree_flatten_with_path a new one? Is it going to give you the FQN?

Member (Author): It's been there for a few versions of torch -- it gives a path like:

(MappingKey(key='user'), MappingKey(key='optimizer'), MappingKey(key='state.layers.7.feed_forward.w2.weight.step'))
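For reference, a minimal standalone sketch of what tree_flatten_with_path yields (the nested dict below is made up purely for illustration):

    from torch.utils._pytree import tree_flatten_with_path

    state = {"user": {"optimizer": {"step": 3}}}
    leaves, treespec = tree_flatten_with_path(state)
    for key_path, value in leaves:
        # key_path is a tuple of key entries, e.g.
        # (MappingKey(key='user'), MappingKey(key='optimizer'), MappingKey(key='step'))
        print(key_path, value)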
+
+        dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)
+
         len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
         self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
         length = cast(int, len_t.item())
@@ -213,18 +248,34 @@ def recv_checkpoint(
         assert meta.step == step

         i: int = 0
+        works: list[Work] = []

-        def recv(v: _TensorMeta) -> torch.Tensor:
+        def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
             nonlocal i

-            t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
-            # TODO: parallelize receives
-            self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
+            inplace = dst_tensors.get(path)
+            if (
+                isinstance(inplace, torch.Tensor)
+                and inplace.device.type == self._device.type
+            ):
+                if isinstance(inplace, DTensor):
+                    inplace = inplace._local_tensor
+                t = _cast_tensor(inplace, torch.uint8)
+                assert (
+                    t.nbytes == v.nbytes
+                ), "inplace tensor storage must be the same size"
+            else:
+                t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
+
+            work = self._pg.recv([t], src_rank, tag=3 + i)
             i += 1

-            # TODO: allow in place receives to avoid having to copy to cpu to
-            # avoid OOMs
-            t = t.cpu()
+            if inplace is None:
+                # if not inplace we need to copy it to CPU to avoid OOMing
+                work.wait(timeout)
+                t = t.cpu()
+            else:
+                works.append(work)

             return torch.as_strided(
                 t.view(v.dtype),
@@ -234,14 +285,17 @@ def recv(v: _TensorMeta) -> torch.Tensor:
             )

         values = []
-        for v in meta.non_tensor_leaves:
+        for path, v in zip(meta.paths, meta.non_tensor_leaves):
             if isinstance(v, _TensorMeta):
-                values.append(recv(v))
+                values.append(recv(path, v))
             elif isinstance(v, _DTensorMeta):
-                tensor = recv(v.local)
+                tensor = recv(path, v.local)
                 # pyre-fixme[29]: DTensor is not a function
                 values.append(DTensor(tensor, v.spec, requires_grad=False))
             else:
                 values.append(v)

+        for work in works:
+            work.wait(timeout)
+
         return tree_unflatten(values, meta.treespec)
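Aside: the inplace path hinges on _cast_tensor returning a uint8 view that aliases the destination tensor's storage, so the receive writes incoming bytes directly into the preallocated tensor. A minimal standalone sketch of that aliasing behavior (illustrative, not part of the diff):

    import torch

    dst = torch.zeros(4, dtype=torch.float32)  # stand-in for a preallocated checkpoint tensor
    storage = dst.untyped_storage()
    byte_view = torch.tensor(storage, dtype=torch.uint8)  # views the same storage, no copy

    assert byte_view.untyped_storage() is storage
    byte_view.fill_(0xFF)  # writes through to dst; every float becomes a NaN bit pattern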
93 changes: 93 additions & 0 deletions torchft/checkpointing/pg_transport_bench.py
@@ -0,0 +1,93 @@
+import logging
+import sys
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+
+import torch
+import torch.distributed as dist
+
+from torchft.checkpointing.pg_transport import PGTransport, _timeit
+from torchft.process_group import ProcessGroupBabyNCCL
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def main(argv: list[str]) -> None:
+    import argparse
+
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inplace", action="store_true")
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--chunk-size", type=int, default=3_000_000)  # 3MB
+    parser.add_argument("--total-size", type=int, default=12_000_000_000)  # 12GB
+    args = parser.parse_args(argv)
+
+    CHUNK_SIZE: int = args.chunk_size
+    TOTAL_SIZE: int = args.total_size
+    INPLACE: bool = args.inplace
+    DEVICE: str = args.device
+
+    timeout: timedelta = timedelta(seconds=10)
+
+    store = dist.TCPStore(
+        "localhost",
+        0,
+        is_master=True,
+        timeout=timeout,
+        wait_for_workers=False,
+    )
+    store_addr: str = f"localhost:{store.port}"
+
+    def run(rank: int) -> None:
+        torch.cuda.set_device(rank)
+
+        device = torch.device(DEVICE)
+
+        with _timeit("init_pg"):
+            pg = ProcessGroupBabyNCCL(timeout=timeout)
+            pg.configure(store_addr=store_addr, rank=rank, world_size=2)
+
+            t = torch.zeros(10, device=device, dtype=torch.float32)
+            pg.allreduce([t], dist.ReduceOp.SUM).wait(timeout=timeout)
+
+        with _timeit("create state_dict"):
+            state_dict: dict[str, torch.Tensor] = {}
+            for i in range(0, TOTAL_SIZE, CHUNK_SIZE):
+                state_dict[f"chunk/{i}"] = torch.zeros(
+                    CHUNK_SIZE // 4, dtype=torch.float32, device=device
+                )
+
+        def get_state_dict() -> object:
+            return state_dict
+
+        transport = PGTransport(
+            pg=pg,
+            timeout=timeout,
+            device=device,
+            state_dict=get_state_dict if INPLACE else None,
+        )
+        metadata = transport.metadata()
+
+        if rank == 0:
+            with _timeit("send_checkpoint"):
+                transport.send_checkpoint(
+                    dst_ranks=[1],
+                    step=1,
+                    state_dict=state_dict,
+                    timeout=timedelta(seconds=60),
+                )
+        elif rank == 1:
+            with _timeit("recv_checkpoint"):
+                transport.recv_checkpoint(
+                    src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=60)
+                )
+
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        results = executor.map(run, range(2))
+        list(results)
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
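Assuming a host with at least two CUDA devices (run() calls torch.cuda.set_device even when --device cpu is passed), a hypothetical way to compare the two paths from Python:

    from torchft.checkpointing.pg_transport_bench import main

    main(["--device", "cuda"])                # baseline: allocate + copy to CPU
    main(["--device", "cuda", "--inplace"])   # inplace receive into existing tensors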