Adds more collectives to ProcessGroups #108

Merged: 21 commits, Feb 12, 2025
torchft/process_group.py: 249 changes (213 additions, 36 deletions)
@@ -54,7 +54,10 @@
)
from torch.distributed.distributed_c10d import (
AllgatherOptions,
AllreduceCoalescedOptions,
AllreduceOptions,
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
@@ -107,40 +110,66 @@ def __init__(self, *args: object, **kwargs: object) -> None:

self._group_name: Optional[str] = None

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
# pyre-fixme[14]: inconsistent override
def allgather(
self,
output_tensors: List[List[torch.Tensor]],
input_tensor: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
"""
This reconfigures the ProcessGroup to use a new store, rank and world size.

Every time this is called it must be provided with a unique prefixed
store address. I.e. localhost:1234/my/prefix/1

This function will block until the underlying ProcessGroup is created.
If an error occurs this will throw.
Gathers tensors from the whole group in a list.

Args:
store_addr: address of the store to use
rank: rank of this process
world_size: world size of this process group
See torch.distributed.all_gather for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allreduce(
self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
self,
tensors: List[torch.Tensor],
opts: Union[AllreduceOptions, ReduceOp],
) -> Work:
"""
Reduces the tensor data across all machines in such a way that all get the final result.

See torch.distributed.all_reduce for more details.
"""
raise NotImplementedError("not implemented")
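
A minimal illustration of the semantics this method mirrors (not part of the diff): the sketch below runs torch.distributed.all_reduce on a single-rank gloo group. The host and port are arbitrary assumptions; with N ranks every rank would end up holding the elementwise sum across ranks.

import torch
import torch.distributed as dist

if __name__ == "__main__":
    # Single-rank gloo group; the port is arbitrary and assumed to be free.
    store = dist.TCPStore("localhost", 29501, world_size=1, is_master=True)
    dist.init_process_group("gloo", store=store, rank=0, world_size=1)

    t = torch.arange(4, dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    # Unchanged for world_size == 1; with more ranks each element becomes the
    # sum of that element across all ranks.
    print(t)
    dist.destroy_process_group()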

def allreduce_coalesced(
self,
tensors: List[torch.Tensor],
opts: AllreduceCoalescedOptions,
) -> Work:
"""
Performs an all_reduce operation in a coalesced manner.

See torch.distributed.all_reduce_coalesced for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allgather(
def alltoall_base(
self,
output_tensors: List[List[torch.Tensor]],
input_tensor: List[torch.Tensor],
opts: AllgatherOptions,
output_buffer: torch.Tensor,
input_buffer: torch.Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
opts: AllToAllOptions,
) -> Work:
"""
Gathers tensors from the whole group in a list.
Performs an all_to_all operation.

See torch.distributed.all_gather for more details.
See torch.distributed.all_to_all_single for more details.
"""
raise NotImplementedError("not implemented")
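
The split-size convention is the subtle part of this API, so here is a small, self-contained illustration (not part of the diff) of how a flat send buffer would be carved up for a hypothetical two-rank call; it only uses torch.split and creates no process group.

import torch

# Rank 0's flat send buffer for an imagined world_size == 2 call.
input_buffer = torch.arange(6, dtype=torch.float32)
input_split_sizes = [2, 4]  # elements 0:2 go to rank 0, elements 2:6 go to rank 1

chunks = list(torch.split(input_buffer, input_split_sizes))
# chunks[r] is the slice this rank sends to rank r; the receive side lays out
# output_buffer the same way using output_split_sizes, one slice per peer.
assert [c.numel() for c in chunks] == input_split_sizes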

def barrier(self, opts: BarrierOptions) -> Work:
"""
Synchronizes all processes.

See torch.distributed.barrier for more details.
"""
raise NotImplementedError("not implemented")

@@ -160,6 +189,15 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
opts.rootRank = root
return self.broadcast([tensor], opts)

# pyre-fixme[14]: inconsistent override
def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
"""
Receives a list of tensors from the process with rank `src_rank`.

See torch.distributed.recv for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def reduce_scatter(
self,
@@ -174,6 +212,32 @@ def reduce_scatter(
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
"""
Sends a list of tensors to the process with rank `dst_rank`.

See torch.distributed.send for more details.
"""
raise NotImplementedError("not implemented")
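
To show how the send/recv pair above is meant to fit together, here is a hedged sketch (not part of the diff). It assumes `pg` is an already-configured concrete subclass of the ProcessGroup base class defined in this file and that the returned Work is waited on before the buffer is read.

import torch

from torchft.process_group import ProcessGroup

def exchange(pg: ProcessGroup, rank: int) -> torch.Tensor:
    # Rank 0 sends a small tensor to rank 1 with tag 0; rank 1 receives into a
    # preallocated buffer of the same shape and dtype.
    t = torch.arange(4, dtype=torch.float32)
    if rank == 0:
        pg.send([t], 1, 0).wait()
    else:
        t.zero_()
        pg.recv([t], 0, 0).wait()
    return t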

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
"""
This reconfigures the ProcessGroup to use a new store, rank and world size.

Every time this is called it must be provided with a unique prefixed
store address. I.e. localhost:1234/my/prefix/1

This function will block until the underlying ProcessGroup is created.
If an error occurs this will throw.

Args:
store_addr: address of the store to use
rank: rank of this process
world_size: world size of this process group
"""
raise NotImplementedError("not implemented")
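
Putting configure together with one of the collectives above, a minimal sketch of the intended lifecycle could look like this (the concrete subclass, store address, and choice of ReduceOp are illustrative assumptions, not prescribed by this PR):

import torch
import torch.distributed as dist

from torchft.process_group import ProcessGroup

def reconfigure_and_allreduce(
    pg: ProcessGroup, store_addr: str, rank: int, world_size: int
) -> torch.Tensor:
    # store_addr must be a unique prefixed address, e.g. "localhost:1234/my/prefix/1".
    pg.configure(store_addr, rank, world_size)  # blocks until the backend PG exists
    t = torch.ones(3)
    pg.allreduce([t], dist.ReduceOp.SUM).wait()  # each element becomes world_size * 1.0
    return t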

def size(self) -> int:
raise NotImplementedError("not implemented")

@@ -268,9 +332,6 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
raise NotImplementedError("not implemented")

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return self.parent.allreduce(tensors, opts)

def allgather(
self,
output_tensors: List[List[torch.Tensor]],
@@ -279,9 +340,35 @@ def allgather(
) -> Work:
return self.parent.allgather(output_tensors, input_tensor, opts)

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return self.parent.allreduce(tensors, opts)

def allreduce_coalesced(
self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
) -> Work:
return self.parent.allreduce_coalesced(tensors, opts)

def alltoall_base(
self,
output_buffer: torch.Tensor,
input_buffer: torch.Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
opts: AllToAllOptions,
) -> Work:
return self.parent.alltoall_base(
output_buffer, input_buffer, output_split_sizes, input_split_sizes, opts
)

def barrier(self, opts: BarrierOptions) -> Work:
return self.parent.barrier(opts)

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
return self.parent.broadcast(tensor_list, opts)

def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
return self.parent.recv(tensors, src_rank, tag)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
Expand All @@ -290,6 +377,9 @@ def reduce_scatter(
) -> Work:
return self.parent.reduce_scatter(output_tensors, input_tensors, opts)

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
return self.parent.send(tensors, dst_rank, tag)

def size(self) -> int:
return self.parent.size()
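
The forwarding methods above are deliberately mechanical: a concrete wrapper only has to build the backend group in _create_pg, and every collective is delegated to self.parent unchanged. A hypothetical Gloo-backed subclass might look like the sketch below; the ProcessGroupWrapper import name and the BaseProcessGroupGloo constructor arguments are assumptions, not part of this diff.

from torch.distributed import ProcessGroupGloo as BaseProcessGroupGloo
from torch.distributed import Store
from torch.distributed.distributed_c10d import ProcessGroup as BaseProcessGroup

from torchft.process_group import ProcessGroupWrapper  # wrapper base name assumed

class MyGlooWrapper(ProcessGroupWrapper):
    # Only backend construction is specific to this subclass; all of the
    # collectives defined above forward to self.parent.
    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        return BaseProcessGroupGloo(store, rank, world_size)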

@@ -414,11 +504,37 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
self._work.append(res)
return res

def allreduce_coalesced(
self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
) -> Work:
res = _DummyWork(tensors)
self._work.append(res)
return res

def alltoall_base(
self,
output_buffer: torch.Tensor,
input_buffer: torch.Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
opts: AllToAllOptions,
) -> Work:
output_buffer.copy_(input_buffer)
res = _DummyWork([output_buffer])
self._work.append(res)
return res

def barrier(self, opts: BarrierOptions) -> Work:
return _DummyWork(None)

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensor_list)
self._work.append(res)
return res

def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
return _DummyWork(None)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
@@ -432,6 +548,9 @@ def reduce_scatter(
self._work.append(res)
return res

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
return _DummyWork(None)

def size(self) -> int:
return self._world
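
Because every method returns a _DummyWork that resolves immediately, this class is convenient in unit tests. A hedged sketch follows; the ProcessGroupDummy name and its (rank, world) constructor are assumptions about how this file exposes the dummy group.

import torch
from torch.distributed.distributed_c10d import AllToAllOptions

from torchft.process_group import ProcessGroupDummy  # name and constructor assumed

def test_dummy_alltoall_copies_input() -> None:
    pg = ProcessGroupDummy(0, 1)  # rank 0 in a world of size 1 (signature assumed)
    inp = torch.arange(4, dtype=torch.float32)
    out = torch.zeros_like(inp)
    pg.alltoall_base(out, inp, [4], [4], AllToAllOptions()).wait()
    # The dummy backend simply copies the input buffer into the output buffer.
    assert torch.equal(out, inp)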

@@ -979,19 +1098,6 @@ def _assert_alive(self) -> None:
if not p.is_alive():
raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

def allreduce(
self,
tensors: List[torch.Tensor],
opts: Union[dist.AllreduceOptions, dist.ReduceOp],
) -> Work:
assert isinstance(tensors, list), "input must be list"

for tensor in tensors:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("allreduce", tensors, opts)

def allgather(
self,
output_tensors: List[List[torch.Tensor]],
@@ -1012,6 +1118,56 @@ def allgather(

return self._run_func("allgather", output_tensors, input_tensor, opts)

def allreduce(
self,
tensors: List[torch.Tensor],
opts: Union[dist.AllreduceOptions, dist.ReduceOp],
) -> Work:
assert isinstance(tensors, list), "input must be list"

for tensor in tensors:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("allreduce", tensors, opts)

def allreduce_coalesced(
self,
tensors: List[torch.Tensor],
opts: Union[dist.AllreduceCoalescedOptions, dist.ReduceOp],
) -> Work:
assert isinstance(tensors, list), "input must be list"

for tensor in tensors:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("allreduce_coalesced", tensors, opts)

def alltoall_base(
self,
output_buffer: torch.Tensor,
input_buffer: torch.Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
opts: AllToAllOptions,
) -> Work:
if not output_buffer.is_shared():
output_buffer.share_memory_()
if not input_buffer.is_shared():
input_buffer.share_memory_()
return self._run_func(
"alltoall_base",
output_buffer,
input_buffer,
output_split_sizes,
input_split_sizes,
opts,
)

def barrier(self, opts: BarrierOptions) -> Work:
return self._run_func("barrier", opts)

def broadcast(
self,
tensor_list: List[torch.Tensor],
@@ -1025,6 +1181,15 @@ def broadcast(

return self._run_func("broadcast", tensor_list, opts)

def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
assert isinstance(tensors, list), "input must be list"

for tensor in tensors:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("recv", tensors, src_rank, tag)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
@@ -1044,6 +1209,15 @@ def reduce_scatter(
tensor.share_memory_()
return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
assert isinstance(tensors, list), "input must be list"

for tensor in tensors:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("send", tensors, dst_rank, tag)

def size(self) -> int:
return self._world_size
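
Each method above moves its tensors into shared memory before dispatching to the worker process via _run_func, because the worker has to mutate the caller's buffers in place rather than a pickled copy. The standalone snippet below (not the torchft worker itself) illustrates that shared-memory behaviour with torch.multiprocessing.

import torch
import torch.multiprocessing as mp

def _child_fill(t: torch.Tensor) -> None:
    t.fill_(42.0)  # writes land in the shared storage, visible to the parent

if __name__ == "__main__":
    t = torch.zeros(4)
    t.share_memory_()  # same call the collectives above make before _run_func
    p = mp.Process(target=_child_fill, args=(t,))
    p.start()
    p.join()
    assert torch.equal(t, torch.full((4,), 42.0))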

@@ -1069,8 +1243,11 @@ def safe_args(cls, args: T) -> T:
elif isinstance(
args,
(
AllreduceOptions,
AllgatherOptions,
AllreduceOptions,
AllreduceCoalescedOptions,
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
ReduceScatterOptions,
),