From 4ad6f1b532baaf248fdb7c5c0f906987393eb3a0 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 26 Nov 2024 17:27:45 -0800 Subject: [PATCH] docs: add sphinx documentation and add missing documentation (#18) --- .github/workflows/docs.yaml | 54 +++ docs/.gitignore | 8 + docs/Makefile | 36 ++ docs/requirements.txt | 8 + docs/source/_static/img/pytorch-logo-dark.svg | 33 ++ .../source/_static/img/pytorch-logo-flame.png | Bin 0 -> 1010 bytes docs/source/checkpointing.rst | 4 + docs/source/conf.py | 307 ++++++++++++++++++ docs/source/data.rst | 4 + docs/source/ddp.rst | 4 + docs/source/index.rst | 22 ++ docs/source/manager.rst | 4 + docs/source/optim.rst | 4 + docs/source/parameter_server.rst | 4 + docs/source/process_group.rst | 4 + torchft/checkpointing.py | 61 +++- torchft/data.py | 13 +- torchft/ddp.py | 25 +- torchft/manager.py | 102 +++++- torchft/optim.py | 20 +- torchft/parameter_server.py | 33 +- torchft/process_group.py | 62 +++- 22 files changed, 771 insertions(+), 41 deletions(-) create mode 100644 .github/workflows/docs.yaml create mode 100644 docs/.gitignore create mode 100644 docs/Makefile create mode 100644 docs/requirements.txt create mode 100644 docs/source/_static/img/pytorch-logo-dark.svg create mode 100644 docs/source/_static/img/pytorch-logo-flame.png create mode 100644 docs/source/checkpointing.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/data.rst create mode 100644 docs/source/ddp.rst create mode 100644 docs/source/index.rst create mode 100644 docs/source/manager.rst create mode 100644 docs/source/optim.rst create mode 100644 docs/source/parameter_server.rst create mode 100644 docs/source/process_group.rst diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 0000000..707489c --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,54 @@ +name: Docs + +on: + push: + branches: + - main + pull_request: + +jobs: + build: + runs-on: ubuntu-20.04 + steps: + - name: Setup Python + uses: actions/setup-python@v3 + with: + python-version: "3.10" + architecture: x64 + - name: Checkout + uses: actions/checkout@v3 + - name: Install Dependencies + run: | + set -eux + + sudo apt-get install -y protobuf-compiler + + pip install .[dev] -v + + pip install -r docs/requirements.txt + - name: Build Sphinx Docs + run: | + set -eux + + cd docs + make html + - name: Upload static files as artifact + id: deployment + uses: actions/upload-pages-artifact@v3 + with: + path: docs/build/html/ + + deploy: + runs-on: ubuntu-latest + needs: build + if: ${{ github.ref == 'refs/heads/main' }} + permissions: + pages: write # to deploy to Pages + id-token: write # to verify the deployment originates from an appropriate source + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..670f0b9 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,8 @@ +src/pytorch-sphinx-theme +source/examples_*/ +jupyter_execute/ +build/ + +Dockerfile +my_component.py +my_app.py diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..9bc06cd --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,36 @@ + +# Minimal makefile for Sphinx documentation +# Usage: +# make html +# + +# You can set these variables from the command line. 
+SPHINXOPTS = -W +SPHINXBUILD = sphinx-build +SPHINXPROJ = torchft +SOURCEDIR = source +BUILDDIR = build +VERSION := $(shell python -c "from importlib.metadata import version; print(version('torchft'))") + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + @echo "Deleting build directory" + rm -rf "$(BUILDDIR)" + rm -rf "$(SOURCEDIR)/examples_apps" "$(SOURCEDIR)/examples_pipelines" + +.PHONY: help Makefile clean livehtml papermill + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +# optional live version +livehtml: + sphinx-autobuild --watch ../torchft --re-ignore ".*(examples_.*|.new|source/.*(Dockerfile|.py))" "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +papermill: html + ./papermill.sh diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..9008607 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,8 @@ +sphinx==5.0.1 +-e git+http://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme +sphinxcontrib.katex +matplotlib +papermill +ipykernel +ipython_genutils +jinja2<=3.0.3 diff --git a/docs/source/_static/img/pytorch-logo-dark.svg b/docs/source/_static/img/pytorch-logo-dark.svg new file mode 100644 index 0000000..5e53000 --- /dev/null +++ b/docs/source/_static/img/pytorch-logo-dark.svg @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/img/pytorch-logo-flame.png b/docs/source/_static/img/pytorch-logo-flame.png new file mode 100644 index 0000000000000000000000000000000000000000..370633f2ec2b7cf35a37a283095229de337f46e4 GIT binary patch literal 1010 zcmV3glQ%A$r$!ato85yx2xcuN% zU{ge`&4p|qlJ;swtf%e@ zz;DiuCXiGGsD7_D<0S3;U{}qPsR^JZGPf6Kt+3T~BeSj3Wf;+SQMowdBp+KffJ);p zMZ`p2D?R5k-A|xA3q&nT$F#2_8t5IAa6Of5Zf@3TV4Y8jL3kc0p!y!j`W16|!OrH6 z-w{B5ATfZD4a%z&s!aZ=(f0c%zPja?Q^QXnPCdHVXBXQeMF7(i7jjn9CbP_YuWX>YWrVvU07QQ<=$m^IO+ zTf6*v_#Q?#h8~TvFyH5q7eO-dY;Iy{`CUo>7C1CJm^dfONjxvNDoM%?uJZ7mK*PfA zeZX3e{Ps%o|HeJzmre)^*H-L!C3O)@D5R368;2NExZx+*u z{eCrtbYe*1C0C6yX=}%MznZG2;mklHWeJ~n~p!2hn)}dc`_. + +.. toctree:: + :maxdepth: 1 + :caption: Reference + + process_group + manager + optim + ddp + data + checkpointing + parameter_server diff --git a/docs/source/manager.rst b/docs/source/manager.rst new file mode 100644 index 0000000..9f79149 --- /dev/null +++ b/docs/source/manager.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.manager + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/optim.rst b/docs/source/optim.rst new file mode 100644 index 0000000..00eb587 --- /dev/null +++ b/docs/source/optim.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/parameter_server.rst b/docs/source/parameter_server.rst new file mode 100644 index 0000000..03c2a37 --- /dev/null +++ b/docs/source/parameter_server.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.parameter_server + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/process_group.rst b/docs/source/process_group.rst new file mode 100644 index 0000000..8e6cfe1 --- /dev/null +++ b/docs/source/process_group.rst @@ -0,0 +1,4 @@ +.. 
automodule:: torchft.process_group + :members: + :undoc-members: + :show-inheritance: diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 6c897c2..becd57c 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -4,25 +4,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import urllib.request -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +""" +Checkpointing +============== + +This module implements methods for checkpointing and resuming training from a checkpoint. +""" + +import io +import logging import socket import threading -import logging -import io +import urllib.request +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Callable import torch logger: logging.Logger = logging.getLogger(__name__) -class IPv6HTTPServer(ThreadingHTTPServer): +class _IPv6HTTPServer(ThreadingHTTPServer): address_family = socket.AF_INET6 request_queue_size = 1024 class CheckpointServer: - def __init__(self, state_dict) -> None: + """ + This is an HTTP server that can be used to transfer checkpoints + between workers. + + This allows for fast recovery of workers by fetching the current weights + from an existing worker. + + Args: + state_dict: a callable that returns the state dict to be transferred + """ + + def __init__(self, state_dict: Callable[[], object]) -> None: self._checkpoint_lock = threading.Lock() self._disallowed = False self._step = -1 @@ -58,7 +77,7 @@ def err(self, msg: str) -> None: self.wfile.write(msg.encode()) server_address = ("", 0) - self._server = IPv6HTTPServer(server_address, RequestHandler) + self._server = _IPv6HTTPServer(server_address, RequestHandler) logger.info(f"Started CheckpointServer on {self.address()}...") self._thread = threading.Thread( @@ -70,6 +89,12 @@ def err(self, msg: str) -> None: @classmethod def load_from_address(cls, address: str) -> object: + """ + Loads a checkpoint from the given address. + + Args: + address: the HTTP address to load the checkpoint from + """ logger.info(f"fetching checkpoint from {address}") with urllib.request.urlopen(address) as f: @@ -79,6 +104,14 @@ def load_from_address(cls, address: str) -> object: return torch.load(reader, weights_only=True) def address(self) -> str: + """ + Returns the HTTP address to fetch a checkpoint from this server at the current step. + + Format: http://host:port/checkpoint/1234 + + Returns: + an HTTP address + """ port = self._server.socket.getsockname()[1] return f"http://{socket.gethostname()}:{port}/checkpoint/{self._step}" @@ -89,11 +122,22 @@ def _serve(self) -> None: logger.exception("got exception in checkpoint server") def disallow_checkpoint(self) -> None: + """ + Disallows serving the checkpoint. + + All requests will block until allow_checkpoint is called. + """ if not self._disallowed: self._disallowed = True self._checkpoint_lock.acquire() def allow_checkpoint(self, step: int) -> None: + """ + Allows serving the checkpoint with the specified step number. + + Args: + step: the step number to serve + """ self._step = step if self._disallowed: @@ -101,4 +145,7 @@ def allow_checkpoint(self, step: int) -> None: self._checkpoint_lock.release() def shutdown(self) -> None: + """ + Shutdown the server. 
+ """ self._server.shutdown() diff --git a/torchft/data.py b/torchft/data.py index 8b8869c..f7a26f3 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -4,11 +4,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Data +==== + +This module provides helper classes to implement fault tolerant data loaders. + +We recommend using torchdata's StatefulDataLoader to checkpoint each replica's +dataloader frequently to avoid duplicate batches. +""" + from typing import Optional -from torch.utils import data import torch.distributed as dist +from torch.utils import data + class DistributedSampler(data.distributed.DistributedSampler): """ diff --git a/torchft/ddp.py b/torchft/ddp.py index 18b71a3..c10ad78 100644 --- a/torchft/ddp.py +++ b/torchft/ddp.py @@ -4,19 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Distributed Data Parallel +========================== + +This module implements a DistributedDataParallel wrapper that works with the +Manager to provide fault tolerance. +""" + import os -from typing import Optional, TYPE_CHECKING import sys +from typing import Optional, TYPE_CHECKING from unittest.mock import patch -from torch.nn import parallel import torch +import torch.distributed as dist from torch import nn from torch.distributed.algorithms.join import Joinable -import torch.distributed as dist -from torchft.process_group import ProcessGroup -from torchft.process_group import ProcessGroupGloo -from torchft.process_group import ProcessGroupDummy + +from torch.nn import parallel +from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo if TYPE_CHECKING: from torchft.manager import Manager @@ -28,6 +35,7 @@ class DistributedDataParallel(parallel.DistributedDataParallel): compatible with torchft. Important notes: + * This requires states to be synced on step 0 using an external mechanism rather than an internal broadcast (torchft.Manager will do this). * Using non-basic features of the DDP may cause your model to catch fire as @@ -55,6 +63,11 @@ def _comm_hook( class PureDistributedDataParallel(nn.Module): """ A pure Python reimplementation of the DDP wrapper. + + We recommend using DistributedDataParallel instead of this class. + + This calls one allreduce per gradient tensor and doesn't use a reducer. This + may be very slow for real models. """ def __init__(self, manager: "Manager", module: nn.Module): diff --git a/torchft/manager.py b/torchft/manager.py index c96834d..ae2f15a 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -4,22 +4,43 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Manager +========= + +This module implements the Manager that manages the full fault tolerant training +loop. + +The Manager is responsible for managing the +full training loop, communicating with the Lighthouse server to figure out +quorum, reconfiguring the ProcessGroups and restoring checkpoint state when +recovering. + +This uses wrapper classes to wrap the standard PyTorch Optimizer and Module +classes to provide fault tolerance. These wrappers indented to add fault +tolerance with minimal changes to the users modeling code and training loop. + +This is designed to work with the standard PyTorch DistributedDataParallel module +and Hybrid FSDP. 
+ +""" + +import logging import os -import uuid import socket -from typing import Dict, Optional, List import time -import logging +import uuid from concurrent.futures import ThreadPoolExecutor from datetime import timedelta +from typing import Dict, List, Optional import torch -from torch.distributed import TCPStore, PrefixStore, Work, ReduceOp +from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work from torch.optim import Optimizer +from torchft.checkpointing import CheckpointServer # pyre-fixme[21]: can't find rust module from torchft.torchft import Manager as _Manager, ManagerClient -from torchft.checkpointing import CheckpointServer logger: logging.Logger = logging.getLogger(__name__) @@ -131,9 +152,29 @@ def __init__( self._should_step = True def shutdown(self) -> None: + """ + Shutdown the manager and checkpoint server. + """ self._ckpt_server.shutdown() def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]: + """ + Allreduce the gradient and return a Future that will be completed when + the gradient is ready. + + This will automatically scale the gradient by 1 / world_size. + + If an error occurs during the allreduce: + + * The Future will be completed with no error and instead tracked asynchronously. + * After the first error, all subsequent allreduce_grad calls will be noops and immediately return. + * The grad tensor must be zeroed before being used as it may be corrupted. + + Args: + grad: the gradient to allreduce + Returns: + a Future that will be completed with the allreduced gradient + """ if self._errored: fut = torch.futures.Future() fut.set_result(grad) @@ -185,6 +226,16 @@ def callback( return fut def step(self) -> None: + """ + .. note:: + We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly. + + Must be called before the forwards pass of each step. + + Computes a new quorum (potentially asynchronously) and readies the + manager for a new step. + """ + if self._should_step: self._step += 1 self._batches_committed += self._participating_replicas @@ -264,6 +315,24 @@ def _apply_pending_state_dict(self) -> None: self._state_dict = None def should_commit(self) -> bool: + """ + .. note:: + We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly. + + Must be called after the backwards pass completes but before stepping the optimizer. + + The optimizer must only be stepped if this returns True. + + This must be called on all workers within a replica group. This uses a + collective to ensure all workers within a replica return the same value. + If an error occurs on any worker, all workers will return False. + Different replica groups may return different values. + + This should only be called once per step. + + Returns: + True if the optimizer should be stepped, False otherwise + """ for work in self._pending_work: # check at the beginning of since .wait() may trigger errors if self._errored: @@ -296,10 +365,27 @@ def should_commit(self) -> bool: return should_commit def load_state_dict(self, state_dict: Dict[str, int]) -> None: + """ + Load the state dict from a previous checkpoint. + + This will restore the step count and internal metadata. + + Args: + state_dict: the state dict to load + """ self._step = state_dict["step"] self._batches_committed = state_dict["batches_committed"] def state_dict(self) -> Dict[str, int]: + """ + Get the state dict for this manager. 
+ + This can be used to checkpoint the state of the manager to restore + from a previous checkpoint. + + Returns: + the state dict for this manager + """ return {"step": self._step, "batches_committed": self._batches_committed} def current_step(self) -> int: @@ -307,6 +393,9 @@ def current_step(self) -> int: Get the current step count. This number is incremented on .step() + + Returns: + the current step count """ return self._step @@ -317,5 +406,8 @@ def batches_committed(self) -> int: 10 examples depending on batch size. This number is incremented on .step() + + Returns: + the total number of batches committed """ return self._batches_committed diff --git a/torchft/optim.py b/torchft/optim.py index 8e31ecc..c88782e 100644 --- a/torchft/optim.py +++ b/torchft/optim.py @@ -4,7 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING, Optional +""" +Optimizers +============ + +This module implements an optimizer wrapper that works with the Manager to provide fault tolerance. + +""" + +from typing import Optional, TYPE_CHECKING from torch.optim import Optimizer @@ -13,6 +21,16 @@ class OptimizerWrapper(Optimizer): + """ + This wraps any provided optimizer and in conjunction with the manager will provide fault tolerance. + + zero_grad() must be called at the start of the forwards pass and step() must + be called at the end of the backwards pass. + + Depending on the state of the manager, the optimizer will either commit the + gradients to the wrapped optimizer or ignore them. + """ + def __init__(self, manager: "Manager", optim: Optimizer) -> None: self.optim = optim self.manager = manager diff --git a/torchft/parameter_server.py b/torchft/parameter_server.py index a1e5424..475f1c7 100644 --- a/torchft/parameter_server.py +++ b/torchft/parameter_server.py @@ -4,14 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from abc import ABC, abstractmethod -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +""" +Parameter Servers +================== + +This module provides a prototype implementation of a fault tolerant parameter server built on the reconfigurable ProcessGroups. +""" + +import json +import logging import socket import threading -import uuid -import logging import urllib.request -import json +import uuid +from abc import ABC, abstractmethod +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from torch.distributed import TCPStore @@ -87,7 +94,7 @@ def do_GET(self): self.connection.close() # hijack thread for the session - ps.handle_session(session_id, store_addr) + ps._handle_session(session_id, store_addr) except Exception: logger.exception( f"got exception in request handler for {self.path}" ) @@ -109,6 +116,11 @@ def do_GET(self): def address(self) -> str: """ Returns the HTTP address to create a new session on this server. + + Format: http://host:port/new_session + + Returns: + an HTTP address """ port = self._server.socket.getsockname()[1] return f"http://{socket.gethostname()}:{port}/new_session" @@ -125,6 +137,11 @@ def new_process_group(cls) -> ProcessGroup: """ Create a new non-configured ProcessGroup for the ParameterServer to configure when setting up server and client connections. + + Must be implemented by subclasses. + + Returns: + a new ProcessGroup """ ...
@@ -150,7 +167,7 @@ def new_session(cls, address: str) -> ProcessGroup: return pg - def handle_session(self, session_id: str, store_addr: str) -> None: + def _handle_session(self, session_id: str, store_addr: str) -> None: pg = self.new_process_group() # paramter server is always rank 0 pg.configure(store_addr, rank=0, world_size=2) @@ -169,6 +186,8 @@ def forward(self, session_id: str, pg: ProcessGroup) -> None: The server rank is 0 and the client rank is 1. + Must be implemented by subclasses. + Args: session_id: a unique uuid for this session pg: the ProcessGroup that's configured for the client. diff --git a/torchft/process_group.py b/torchft/process_group.py index 52aa0d1..e93d0b1 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -4,6 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Process Groups +========================= + +This module implements fault tolerant process groups that can be reconfigured +and resized at runtime. + +These extend the standard PyTorch ProcessGroup API and can be used in most +places that would accept a standard process group. As these can change size at +runtime users need to take care to not assume a static rank or world size. +""" + import logging import threading from abc import ABC @@ -47,9 +59,9 @@ def _get(queue: mp.Queue, timeout) -> object: return v -def create_store(store_addr: str) -> Store: +def create_store_client(store_addr: str) -> Store: """ - Creates a PrefixStore(TCPStore(...)) from an address in the format: + Creates a PrefixStore(TCPStore(...)) client from an address in the format: host:port/prefix @@ -75,6 +87,17 @@ def __init__(self, *args, **kwargs) -> None: self._group_name = None def configure(self, store_addr: str, rank: int, world_size: int) -> None: + """ + This reconfigures the ProcessGroup to use a new store, rank and world size. + + Every time this is called it must be provided with a unique prefixed + store address. I.e. localhost:1234/my/prefix/1 + + Args: + store_addr: address of the store to use + rank: rank of this process + world_size: world size of this process group + """ raise NotImplementedError("not implemented") def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: @@ -171,7 +194,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: self._pg.abort() self._pg = None - store = create_store(store_addr) + store = create_store_client(store_addr) # TODO: set global timeout self._pg = self.PG_CLASS(store, rank, world_size) @@ -222,7 +245,7 @@ def getBackendName(self) -> str: return "torchft-nccl" -class DummyWork(dist._Work): +class _DummyWork(dist._Work): def __init__(self, result): super().__init__() self.result_ = result @@ -238,10 +261,14 @@ def get_future(self): class ProcessGroupDummy(ProcessGroup): """ - This PG only supports world_size of 1 + This process group discards all data passed to it and returns success. This + is intended for rare cases where we want to discard certain operations + without modifying the underlying library. + + This PG only supports world_size of 1. 
""" - def __init__(self, rank, world): + def __init__(self, rank: int, world: int) -> None: super().__init__(rank, world) assert rank == 0 assert world == 1 @@ -253,7 +280,7 @@ def __init__(self, rank, world): self._work = [] def broadcast(self, tensor_list, opts): - res = DummyWork(tensor_list) + res = _DummyWork(tensor_list) self._work.append(res) return res @@ -261,12 +288,12 @@ def allgather(self, output_tensors, input_tensor, opts): for o, i in zip(output_tensors[0], input_tensor): o.copy_(i) - res = DummyWork(output_tensors) + res = _DummyWork(output_tensors) self._work.append(res) return res def allreduce(self, tensors, opts): - res = DummyWork(tensors) + res = _DummyWork(tensors) self._work.append(res) return res @@ -277,7 +304,7 @@ def getBackendName(self): return "torchft-dummy" -class BabyWork(Work): +class _BabyWork(Work): def __init__( self, pg: "ProcessGroupBaby", @@ -303,7 +330,7 @@ def get_future(self) -> Future: return self._pg._get_future(self._op_id) -class BabyWorkNCCL(BabyWork): +class _BabyWorkNCCL(_BabyWork): def wait(self) -> bool: self._tx.put(("synchronize", self._op_id), timeout=self._timeout) op_id, event = _get(self._rx, self._timeout) @@ -326,7 +353,7 @@ class ProcessGroupBaby(ProcessGroup): """ PG_CLASS: Type[BaseProcessGroup] - WORK_CLASS: Type[BabyWork] = BabyWork + WORK_CLASS: Type[_BabyWork] = _BabyWork def __init__(self, timeout: float = 60.0) -> None: super().__init__(0, 1) @@ -389,7 +416,7 @@ def _worker( future_queue: mp.Queue, ) -> None: try: - store = create_store(store_addr) + store = create_store_client(store_addr) pg = cls.PG_CLASS(store, rank, world_size) @@ -495,6 +522,13 @@ def size(self) -> int: class ProcessGroupBabyGloo(ProcessGroupBaby): + """ + This is a ProcessGroup that runs Gloo in a subprocess. + + For most use cases you should prefer ProcessGroupGloo or + ProcessGroupBabyNCCL. + """ + PG_CLASS = BaseProcessGroupGloo def getBackendName(self): @@ -518,7 +552,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby): """ PG_CLASS = BaseProcessGroupNCCL - WORK_CLASS = BabyWorkNCCL + WORK_CLASS = _BabyWorkNCCL def getBackendName(self): return "torchft-baby-nccl"