Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for allgather_into_tensor_coalesced and reduce_scatter_tensor_coalesced #114

Merged
merged 3 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 178 additions & 5 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ def allgather(
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
"""
Performs an allgather operation on coalesced tensors.

See torch.distributed.allgather_coalesced for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allreduce(
self,
Expand Down Expand Up @@ -212,6 +226,20 @@ def reduce_scatter(
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
"""
Performs a reduce-scatter operation on coalesced tensors.

See torch.distributed.reduce_scatter_tensor for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
"""
Expand Down Expand Up @@ -336,10 +364,20 @@ def allgather(
self,
output_tensors: List[List[torch.Tensor]],
input_tensor: List[torch.Tensor],
opts: object,
opts: AllgatherOptions,
) -> Work:
return self.parent.allgather(output_tensors, input_tensor, opts)

def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
return self.parent.allgather_into_tensor_coalesced(
output_tensors, input_tensors, opts
)

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return self.parent.allreduce(tensors, opts)

Expand Down Expand Up @@ -377,6 +415,16 @@ def reduce_scatter(
) -> Work:
return self.parent.reduce_scatter(output_tensors, input_tensors, opts)

def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
return self.parent.reduce_scatter_tensor_coalesced(
output_tensors, input_tensors, opts
)

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
return self.parent.send(tensors, dst_rank, tag)

Expand All @@ -402,8 +450,15 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
self._timeout = timeout

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.GLOO)
# pyre-fixme[16]: no attribute ProcessGroupGloo
return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
backend_class._set_sequence_number_for_group()
pg._register_backend(
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-gloo"
Expand All @@ -427,6 +482,28 @@ def reduce_scatter(
"""
raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")

# pyre-fixme[15]: inconsistent override
def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> None:
"""
This function is a placeholder for the reduce_scatter_tensor_coalesced
operation in the ProcessGroupGloo class.
However, this operation is not supported by the
Gloo backend, and thus, calling this function will raise a
RuntimeError.

Raises:
RuntimeError: Always raised since reduce_scatter is not
supported by ProcessGroupGloo.
"""
raise RuntimeError(
"ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
)


class ProcessGroupNCCL(ProcessGroupWrapper):
"""
Expand All @@ -440,8 +517,15 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
"""

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
# pyre-fixme[16]: no attribute ProcessGroupNCCL
return BaseProcessGroupNCCL(store, rank, world_size)
backend_class = BaseProcessGroupNCCL(store, rank, world_size)
backend_class._set_sequence_number_for_group()
pg._register_backend(
torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-nccl"
Expand Down Expand Up @@ -499,6 +583,19 @@ def allgather(
self._work.append(res)
return res

def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
for o, i in zip(output_tensors, input_tensors):
o.copy_(i)

res = _DummyWork(output_tensors)
self._work.append(res)
return res

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensors)
self._work.append(res)
Expand Down Expand Up @@ -548,6 +645,19 @@ def reduce_scatter(
self._work.append(res)
return res

def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
for o, i in zip(output_tensors, input_tensors):
o.copy_(i)

res = _DummyWork(output_tensors)
self._work.append(res)
return res

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
return _DummyWork(None)

Expand Down Expand Up @@ -1134,6 +1244,20 @@ def allgather(
_maybe_share_tensors(input_tensor)
return self._run_func("allgather", output_tensors, input_tensor, opts)

def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
_assert_list(output_tensors)
_assert_list(input_tensors)
_maybe_share_tensors(output_tensors)
_maybe_share_tensors(input_tensors)
return self._run_func(
"allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
)

def allreduce(
self,
tensors: List[torch.Tensor],
Expand Down Expand Up @@ -1200,6 +1324,20 @@ def reduce_scatter(
_maybe_share_tensors(input_tensors)
return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
_assert_list(output_tensors)
_assert_list(input_tensors)
_maybe_share_tensors(output_tensors)
_maybe_share_tensors(input_tensors)
return self._run_func(
"reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
)

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
_assert_list(tensors)
_maybe_share_tensors(tensors)
Expand Down Expand Up @@ -1278,8 +1416,14 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.GLOO)
# pyre-fixme[16]: no attribute ProcessGroupGloo
return BaseProcessGroupGloo(store, rank, world_size)
backend_class = BaseProcessGroupGloo(store, rank, world_size)
pg._register_backend(
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-baby-gloo"
Expand All @@ -1303,6 +1447,28 @@ def reduce_scatter(
"""
raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")

# pyre-fixme[15]: inconsistent override
def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> None:
"""
This function is a placeholder for the reduce_scatter_tensor_coalesced
operation in the ProcessGroupBabyGloo class.
However, this operation is not supported by the
Gloo backend, and thus, calling this function will raise a
RuntimeError.

Raises:
RuntimeError: Always raised since reduce_scatter is not
supported by ProcessGroupBabyGloo.
"""
raise RuntimeError(
"ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
)


class ProcessGroupBabyNCCL(ProcessGroupBaby):
"""
Expand All @@ -1322,8 +1488,15 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
# pyre-fixme[16]: no attribute ProcessGroupNCCL
return BaseProcessGroupNCCL(store, rank, world_size)
backend_class = BaseProcessGroupNCCL(store, rank, world_size)
backend_class._set_sequence_number_for_group()
pg._register_backend(
torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-baby-nccl"
Expand Down
Loading