Skip to content

Commit

Permalink
Adds reduce_scatter into torchft (#102)
Browse files Browse the repository at this point in the history
* initial commit for reduce_scatter

* fixes reduce_scatter function signature, refactors test and adds reduce_scatter test

* fixes test

* adds explicit NotImplementedError to reduce_scatter in gloo, simplify the test suite

* fix tests after merge

* add explicit error for ProcessGroupGloo

* notimplementederror->runtimeerror
  • Loading branch information
allenwang28 authored Feb 10, 2025
1 parent 8f0d125 commit e55542a
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 23 deletions.
113 changes: 107 additions & 6 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
AllreduceOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
Work,
)
from torch.futures import Future
Expand Down Expand Up @@ -159,6 +160,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
opts.rootRank = root
return self.broadcast([tensor], opts)

# pyre-fixme[14]: inconsistent override
def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: ReduceScatterOptions,
) -> Work:
"""
Reduces, then scatters a list of tensors to all processes in a group.
See torch.distributed.reduce_scatter for more details.
"""
raise NotImplementedError("not implemented")

def size(self) -> int:
raise NotImplementedError("not implemented")

Expand Down Expand Up @@ -267,6 +282,14 @@ def allgather(
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
return self.parent.broadcast(tensor_list, opts)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: object,
) -> Work:
return self.parent.reduce_scatter(output_tensors, input_tensors, opts)

def size(self) -> int:
return self.parent.size()

Expand Down Expand Up @@ -295,6 +318,25 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
def getBackendName(self) -> str:
return "torchft-gloo"

# pyre-fixme[14,15]: inconsistent override
def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: ReduceScatterOptions,
) -> None:
"""
This function is a placeholder for the reduce_scatter operation in the
ProcessGroupGloo class. However, this operation is not supported by the
Gloo backend, and thus, calling this function will raise a
RuntimeError.
Raises:
RuntimeError: Always raised since reduce_scatter is not
supported by ProcessGroupGloo.
"""
raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")


class ProcessGroupNCCL(ProcessGroupWrapper):
"""
Expand Down Expand Up @@ -354,11 +396,6 @@ def __init__(self, rank: int, world: int) -> None:
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
self.configure_count += 1

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensor_list)
self._work.append(res)
return res

def allgather(
self,
output_tensors: List[List[torch.Tensor]],
Expand All @@ -377,6 +414,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
self._work.append(res)
return res

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensor_list)
self._work.append(res)
return res

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: object,
) -> Work:
for o, i in zip(output_tensors, input_tensors[0]):
o.copy_(i)

res = _DummyWork(output_tensors)
self._work.append(res)
return res

def size(self) -> int:
return self._world

Expand Down Expand Up @@ -970,6 +1025,25 @@ def broadcast(

return self._run_func("broadcast", tensor_list, opts)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: ReduceScatterOptions,
) -> Work:
assert isinstance(output_tensors, list), "input must be list"
assert isinstance(input_tensors, list), "input must be list"

for tensor in output_tensors:
if not tensor.is_shared():
tensor.share_memory_()

for tensor_list in input_tensors:
for tensor in tensor_list:
if not tensor.is_shared():
tensor.share_memory_()
return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

def size(self) -> int:
return self._world_size

Expand All @@ -992,7 +1066,15 @@ def safe_args(cls, args: T) -> T:
return tuple(cls.safe_args(arg) for arg in args)
elif isinstance(args, list):
return [cls.safe_args(arg) for arg in args]
elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
elif isinstance(
args,
(
AllreduceOptions,
AllgatherOptions,
BroadcastOptions,
ReduceScatterOptions,
),
):
return cls.from_torch(args)
else:
return args
Expand Down Expand Up @@ -1038,6 +1120,25 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGrou
def getBackendName(self) -> str:
return "torchft-baby-gloo"

# pyre-fixme[15]: inconsistent override
def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: ReduceScatterOptions,
) -> None:
"""
This function is a placeholder for the reduce_scatter operation in the
ProcessGroupGloo class. However, this operation is not supported by the
Gloo backend, and thus, calling this function will raise a
RuntimeError.
Raises:
RuntimeError: Always raised since reduce_scatter is not
supported by ProcessGroupGloo.
"""
raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")


class ProcessGroupBabyNCCL(ProcessGroupBaby):
"""
Expand Down
55 changes: 38 additions & 17 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AllreduceOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
_resolve_process_group,
)
from torch.distributed import (
Expand Down Expand Up @@ -94,18 +95,28 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
("broadcast", (tensor_list, BroadcastOptions())),
("broadcast_one", (input_tensor, 0)),
(
"reduce_scatter",
(output_tensors[0], [[input_tensor]], ReduceScatterOptions()),
),
]
works: Dict[str, dist._Work] = {}
for coll_str, args in collectives:
coll = getattr(pg, coll_str)
work = coll(*args)
works[coll_str] = work
work.wait()
fut = work.get_future()
fut.wait()

# Check that all tensor arguments have the expected shapes and dtypes
check_tensors(args)
for coll_str, args in collectives:
try:
coll = getattr(pg, coll_str)
work = coll(*args)
works[coll_str] = work
work.wait()
fut = work.get_future()
fut.wait()
# Check that all tensor arguments have the expected shapes and dtypes
check_tensors(args)
except RuntimeError as e:
if f"does not support {coll_str}" in str(e):
# Skip collectives that are not supported by the backend.
continue
raise e

print(works)
return works
Expand Down Expand Up @@ -306,7 +317,7 @@ def test_baby_nccl_2gpu(self) -> None:

store_addr: str = f"localhost:{store.port}/prefix"

def run(rank: int) -> Tuple[torch.Tensor, Work]:
def run(rank: int) -> Tuple[ProcessGroupBabyNCCL, torch.Tensor, Work]:
a = ProcessGroupBabyNCCL(
timeout=timedelta(seconds=10.0),
)
Expand All @@ -318,19 +329,29 @@ def run(rank: int) -> Tuple[torch.Tensor, Work]:
at = torch.tensor([rank + 1], device="cuda")

a_work = a.allreduce([at], ReduceOp.SUM)
return at, a_work
return a, at, a_work

with ThreadPoolExecutor(max_workers=2) as executor:
a_fut = executor.submit(run, 0)
b_fut = executor.submit(run, 1)

at, a_work = a_fut.result()
bt, b_work = b_fut.result()

a_work.wait()
b_work.get_future().wait()
a, at, a_work = a_fut.result()
b, bt, b_work = b_fut.result()

torch.testing.assert_close(at.cpu(), bt.cpu())
try:
a_work.wait()
b_work.get_future().wait()
torch.testing.assert_close(at.cpu(), bt.cpu())
finally:
# cleanup - first ensure that babywork is deleted before shutting down PGs
# note futures must be deleted as they hold references to babywork
del a_fut
del b_fut
del a_work
del b_work
gc.collect()
b.shutdown()
a.shutdown()

def test_device_mesh(self) -> None:
os.environ["MASTER_ADDR"] = "localhost"
Expand Down

0 comments on commit e55542a

Please sign in to comment.