diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 0000000..707489c --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,54 @@ +name: Docs + +on: + push: + branches: + - main + pull_request: + +jobs: + build: + runs-on: ubuntu-20.04 + steps: + - name: Setup Python + uses: actions/setup-python@v3 + with: + python-version: "3.10" + architecture: x64 + - name: Checkout + uses: actions/checkout@v3 + - name: Install Dependencies + run: | + set -eux + + sudo apt-get install -y protobuf-compiler + + pip install .[dev] -v + + pip install -r docs/requirements.txt + - name: Build Sphinx Docs + run: | + set -eux + + cd docs + make html + - name: Upload static files as artifact + id: deployment + uses: actions/upload-pages-artifact@v3 + with: + path: docs/build/html/ + + deploy: + runs-on: ubuntu-latest + needs: build + if: ${{ github.ref == 'refs/heads/main' }} + permissions: + pages: write # to deploy to Pages + id-token: write # to verify the deployment originates from an appropriate source + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..670f0b9 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,8 @@ +src/pytorch-sphinx-theme +source/examples_*/ +jupyter_execute/ +build/ + +Dockerfile +my_component.py +my_app.py diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..9bc06cd --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,36 @@ + +# Minimal makefile for Sphinx documentation +# Usage: +# make html +# + +# You can set these variables from the command line. +SPHINXOPTS = -W +SPHINXBUILD = sphinx-build +SPHINXPROJ = torchft +SOURCEDIR = source +BUILDDIR = build +VERSION := $(shell python -c "from importlib.metadata import version; print(version('torchft'))") + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + @echo "Deleting build directory" + rm -rf "$(BUILDDIR)" + rm -rf "$(SOURCEDIR)/examples_apps" "$(SOURCEDIR)/examples_pipelines" + +.PHONY: help Makefile clean livehtml papermill + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +# optional live version +livehtml: + sphinx-autobuild --watch ../torchft --re-ignore ".*(examples_.*|.new|source/.*(Dockerfile|.py))" "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +papermill: html + ./papermill.sh diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..9008607 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,8 @@ +sphinx==5.0.1 +-e git+http://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme +sphinxcontrib.katex +matplotlib +papermill +ipykernel +ipython_genutils +jinja2<=3.0.3 diff --git a/docs/source/_static/img/pytorch-logo-dark.svg b/docs/source/_static/img/pytorch-logo-dark.svg new file mode 100644 index 0000000..5e53000 --- /dev/null +++ b/docs/source/_static/img/pytorch-logo-dark.svg @@ -0,0 +1,33 @@ [33 added lines of SVG markup for the PyTorch logo; the vector path data is not reproduced in this listing] diff --git a/docs/source/_static/img/pytorch-logo-flame.png b/docs/source/_static/img/pytorch-logo-flame.png new file mode 100644 index 0000000..370633f Binary files /dev/null and b/docs/source/_static/img/pytorch-logo-flame.png differ diff --git a/docs/source/checkpointing.rst b/docs/source/checkpointing.rst new file mode 100644 index 0000000..d7d309a --- /dev/null +++ b/docs/source/checkpointing.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.checkpointing + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..b7fe578 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# PyTorch documentation build configuration file, created by +# sphinx-quickstart on Fri Dec 23 13:31:47 2016. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import subprocess +import sys +from importlib.metadata import version + +import pytorch_sphinx_theme +from docutils import nodes +from sphinx import addnodes +from sphinx.util.docfields import TypedField + +FBCODE = "fbcode" in os.getcwd() + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +needs_sphinx = "1.6" + +user_agent = "Mozilla/5.0 (X11; Linux x86_64; rv:25.0) Gecko/20100101 Firefox/25.0 github.com/pytorch-labs/torchft" + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones.
+extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", +] + +html_context = {} + +# coverage options + +coverage_ignore_modules = [] + +# katex options +# +# + +katex_options = r""" +delimiters : [ + {left: "$$", right: "$$", display: true}, + {left: "\\(", right: "\\)", display: false}, + {left: "\\[", right: "\\]", display: true} +] +""" + +napoleon_use_ivar = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffixes as a list of strings: +# +# source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "pytorch/torchft" +copyright = "2024, PyTorch Contributors" +author = "PyTorch Contributors" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +# TODO: change to [:2] at v1.0 +version = "v" + version("torchft") +# The full version, including alpha/beta/rc tags. +# TODO: verify this works as expected +release = "main" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = "en" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# These patterns also affect html_static_path and html_extra_path +exclude_patterns = [ + "examples_*/**/*.ipynb", +] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "pytorch_sphinx_theme" +html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "pytorch_project": "torchft", + "collapse_navigation": False, + "display_version": True, + "logo_only": True, + "analytics_id": "UA-117752657-2", +} + +html_logo = "_static/img/pytorch-logo-dark.svg" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + + +html_css_files = [ + # "https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css", + # "css/torchx.css", +] +html_js_files = [ + # "js/torchx.js", +] + + +def setup(app): + # NOTE: in Sphinx 1.8+ `html_css_files` is an official configuration value + # and can be moved outside of this function (and the setup(app) function + # can be deleted).
+ + # In Sphinx 1.8 it was renamed to `add_css_file`; in 1.7 and earlier it was + # `add_stylesheet` (deprecated in 1.8). + add_css = getattr( + app, "add_css_file", getattr(app, "add_stylesheet", None) + ) # noqa B009 + for css_file in html_css_files: + add_css(css_file) + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = "torchft-doc" + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + ( + master_doc, + "pytorch.tex", + "torchft Documentation", + "Torch Contributors", + "manual", + ) +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [(master_doc, "torchft", "torchft Documentation", [author], 1)] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "torchft", + "torchft Documentation", + author, + "torchft", + "Miscellaneous", + ) +] + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "python": ("https://docs.python.org/", None), + "numpy": ("https://docs.scipy.org/doc/numpy/", None), + "torch": ("https://pytorch.org/docs/stable/", None), +} + +# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- +# See http://stackoverflow.com/a/41184353/3343043 + + +def patched_make_field(self, types, domain, items, **kw): + # `kw` catches `env=None` needed for newer sphinx while maintaining + # backwards compatibility when passed along further down!
+ + def handle_item(fieldarg, content): + par = nodes.paragraph() + par += addnodes.literal_strong("", fieldarg) # Patch: this line added + # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, + # addnodes.literal_strong)) + if fieldarg in types: + par += nodes.Text(" (") + # NOTE: using .pop() here to prevent a single type node from being + # inserted twice into the doctree, which leads to + # inconsistencies later when references are resolved + fieldtype = types.pop(fieldarg) + if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): + typename = "".join(n.astext() for n in fieldtype) + typename = typename.replace("int", "python:int") + typename = typename.replace("long", "python:long") + typename = typename.replace("float", "python:float") + typename = typename.replace("type", "python:type") + par.extend( + self.make_xrefs( + self.typerolename, + domain, + typename, + addnodes.literal_emphasis, + **kw, + ) + ) + else: + par += fieldtype + par += nodes.Text(")") + par += nodes.Text(" -- ") + par += content + return par + + fieldname = nodes.field_name("", self.label) + if len(items) == 1 and self.can_collapse: + fieldarg, content = items[0] + bodynode = handle_item(fieldarg, content) + else: + bodynode = self.list_type() + for fieldarg, content in items: + bodynode += nodes.list_item("", handle_item(fieldarg, content)) + fieldbody = nodes.field_body("", bodynode) + return nodes.field("", fieldname, fieldbody) + + +TypedField.make_field = patched_make_field + + +# -- Options for autosectionlabel + +# add the document to avoid collisions for common titles +autosectionlabel_prefix_document = True diff --git a/docs/source/data.rst b/docs/source/data.rst new file mode 100644 index 0000000..f0bbed0 --- /dev/null +++ b/docs/source/data.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.data + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/ddp.rst b/docs/source/ddp.rst new file mode 100644 index 0000000..84eead1 --- /dev/null +++ b/docs/source/ddp.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.ddp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..f97024f --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,22 @@ +:github_url: https://github.com/pytorch-labs/torchft + +torchft +======== + +This repository implements primitives and E2E solutions for per-step +fault tolerance so you can keep training when errors occur, without interrupting +the entire training job. + +**GETTING STARTED?** See Install and Usage in `the README `_. + +.. toctree:: + :maxdepth: 1 + :caption: Reference + + process_group + manager + optim + ddp + data + checkpointing + parameter_server diff --git a/docs/source/manager.rst b/docs/source/manager.rst new file mode 100644 index 0000000..9f79149 --- /dev/null +++ b/docs/source/manager.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.manager + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/optim.rst b/docs/source/optim.rst new file mode 100644 index 0000000..00eb587 --- /dev/null +++ b/docs/source/optim.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/parameter_server.rst b/docs/source/parameter_server.rst new file mode 100644 index 0000000..03c2a37 --- /dev/null +++ b/docs/source/parameter_server.rst @@ -0,0 +1,4 @@ +..
automodule:: torchft.parameter_server + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/process_group.rst b/docs/source/process_group.rst new file mode 100644 index 0000000..8e6cfe1 --- /dev/null +++ b/docs/source/process_group.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.process_group + :members: + :undoc-members: + :show-inheritance: diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 6c897c2..becd57c 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -4,25 +4,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import urllib.request -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +""" +Checkpointing +============== + +This module implements methods for checkpointing and resuming training from a checkpoint. +""" + +import io +import logging import socket import threading -import logging -import io +import urllib.request +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Callable import torch logger: logging.Logger = logging.getLogger(__name__) -class IPv6HTTPServer(ThreadingHTTPServer): +class _IPv6HTTPServer(ThreadingHTTPServer): address_family = socket.AF_INET6 request_queue_size = 1024 class CheckpointServer: - def __init__(self, state_dict) -> None: + """ + This is an HTTP server that can be used to transfer checkpoints + between workers. + + This allows for fast recovery of workers by fetching the current weights + from an existing worker. + + Args: + state_dict: a callable that returns the state dict to be transferred + """ + + def __init__(self, state_dict: Callable[[], object]) -> None: self._checkpoint_lock = threading.Lock() self._disallowed = False self._step = -1 @@ -58,7 +77,7 @@ def err(self, msg: str) -> None: self.wfile.write(msg.encode()) server_address = ("", 0) - self._server = IPv6HTTPServer(server_address, RequestHandler) + self._server = _IPv6HTTPServer(server_address, RequestHandler) logger.info(f"Started CheckpointServer on {self.address()}...") self._thread = threading.Thread( @@ -70,6 +89,12 @@ def err(self, msg: str) -> None: @classmethod def load_from_address(cls, address: str) -> object: + """ + Loads a checkpoint from the given address. + + Args: + address: the HTTP address to load the checkpoint from + """ logger.info(f"fetching checkpoint from {address}") with urllib.request.urlopen(address) as f: @@ -79,6 +104,14 @@ def load_from_address(cls, address: str) -> object: return torch.load(reader, weights_only=True) def address(self) -> str: + """ + Returns the HTTP address to fetch a checkpoint from this server at the current step. + + Format: http://host:port/checkpoint/1234 + + Returns: + an HTTP address + """ port = self._server.socket.getsockname()[1] return f"http://{socket.gethostname()}:{port}/checkpoint/{self._step}" @@ -89,11 +122,22 @@ def _serve(self) -> None: logger.exception("got exception in checkpoint server") def disallow_checkpoint(self) -> None: + """ + Disallows serving the checkpoint. + + All requests will block until allow_checkpoint is called. + """ if not self._disallowed: self._disallowed = True self._checkpoint_lock.acquire() def allow_checkpoint(self, step: int) -> None: + """ + Allows serving the checkpoint with the specified step number. 
+ + Args: + step: the step number to serve + """ self._step = step if self._disallowed: @@ -101,4 +145,7 @@ def allow_checkpoint(self, step: int) -> None: self._checkpoint_lock.release() def shutdown(self) -> None: + """ + Shutdown the server. + """ self._server.shutdown() diff --git a/torchft/data.py b/torchft/data.py index 8b8869c..f7a26f3 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -4,11 +4,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Data +==== + +This module provides helper classes to implement fault tolerant data loaders. + +We recommend using torchdata's StatefulDataLoader to checkpoint each replica's +dataloader frequently to avoid duplicate batches. +""" + from typing import Optional -from torch.utils import data import torch.distributed as dist +from torch.utils import data + class DistributedSampler(data.distributed.DistributedSampler): """ diff --git a/torchft/ddp.py b/torchft/ddp.py index 18b71a3..c10ad78 100644 --- a/torchft/ddp.py +++ b/torchft/ddp.py @@ -4,19 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Distributed Data Parallel +========================== + +This module implements a DistributedDataParallel wrapper that works with the +Manager to provide fault tolerance. +""" + import os -from typing import Optional, TYPE_CHECKING import sys +from typing import Optional, TYPE_CHECKING from unittest.mock import patch -from torch.nn import parallel import torch +import torch.distributed as dist from torch import nn from torch.distributed.algorithms.join import Joinable -import torch.distributed as dist -from torchft.process_group import ProcessGroup -from torchft.process_group import ProcessGroupGloo -from torchft.process_group import ProcessGroupDummy + +from torch.nn import parallel +from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo if TYPE_CHECKING: from torchft.manager import Manager @@ -28,6 +35,7 @@ class DistributedDataParallel(parallel.DistributedDataParallel): compatible with torchft. Important notes: + * This requires states to be synced on step 0 using an external mechanism rather than an internal broadcast (torchft.Manager will do this). * Using non-basic features of the DDP may cause your model to catch fire as @@ -55,6 +63,11 @@ def _comm_hook( class PureDistributedDataParallel(nn.Module): """ A pure Python reimplementation of the DDP wrapper. + + We recommend using DistributedDataParallel instead of this class. + + This calls one allreduce per gradient tensor and doesn't use a reducer. This + may be very slow for real models. """ def __init__(self, manager: "Manager", module: nn.Module): diff --git a/torchft/manager.py b/torchft/manager.py index c96834d..ae2f15a 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -4,22 +4,43 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Manager +========= + +This module implements the Manager that runs the full fault tolerant training +loop. + +The Manager is responsible for +communicating with the Lighthouse server to figure out +quorum, reconfiguring the ProcessGroups and restoring checkpoint state when +recovering. + +This uses wrapper classes to wrap the standard PyTorch Optimizer and Module +classes to provide fault tolerance. These wrappers are intended to add fault +tolerance with minimal changes to the user's modeling code and training loop. + +This is designed to work with the standard PyTorch DistributedDataParallel module +and Hybrid FSDP. + +"""
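To make the intended flow concrete, here is a minimal sketch of a training loop built on these wrappers. The Manager constructor arguments are elided because they are not shown in this diff; the model, loss, and dataloader are placeholders; and the DistributedDataParallel wrapper is assumed to share the (manager, module) signature that PureDistributedDataParallel shows above:

    from torch import nn, optim

    from torchft.ddp import DistributedDataParallel
    from torchft.manager import Manager
    from torchft.optim import OptimizerWrapper

    manager = Manager(...)  # constructor arguments elided; not shown in this diff

    model = DistributedDataParallel(manager, nn.Linear(10, 10))  # placeholder model
    optimizer = OptimizerWrapper(manager, optim.Adam(model.parameters()))
    criterion = nn.MSELoss()

    for x, y in dataloader:  # placeholder dataloader of (input, target) batches
        # per the docstrings below, zero_grad() drives Manager.step() to
        # (re)establish quorum before the forwards pass
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        # step() only commits to the wrapped optimizer if should_commit() is True
        optimizer.step()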
+ +import logging import os -import uuid import socket -from typing import Dict, Optional, List import time -import logging +import uuid from concurrent.futures import ThreadPoolExecutor from datetime import timedelta +from typing import Dict, List, Optional import torch -from torch.distributed import TCPStore, PrefixStore, Work, ReduceOp +from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work from torch.optim import Optimizer +from torchft.checkpointing import CheckpointServer # pyre-fixme[21]: can't find rust module from torchft.torchft import Manager as _Manager, ManagerClient -from torchft.checkpointing import CheckpointServer logger: logging.Logger = logging.getLogger(__name__) @@ -131,9 +152,29 @@ def __init__( self._should_step = True def shutdown(self) -> None: + """ + Shutdown the manager and checkpoint server. + """ self._ckpt_server.shutdown() def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]: + """ + Allreduce the gradient and return a Future that will be completed when + the gradient is ready. + + This will automatically scale the gradient by 1 / world_size. + + If an error occurs during the allreduce: + + * The Future will be completed with no error; the error is instead tracked asynchronously. + * After the first error, all subsequent allreduce_grad calls will be noops and immediately return. + * The grad tensor must be zeroed before being used as it may be corrupted. + + Args: + grad: the gradient to allreduce + Returns: + a Future that will be completed with the allreduced gradient + """ if self._errored: fut = torch.futures.Future() fut.set_result(grad) @@ -185,6 +226,16 @@ def callback( return fut def step(self) -> None: + """ + .. note:: + We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly. + + Must be called before the forwards pass of each step. + + Computes a new quorum (potentially asynchronously) and readies the + manager for a new step. + """ + if self._should_step: self._step += 1 self._batches_committed += self._participating_replicas @@ -264,6 +315,24 @@ def _apply_pending_state_dict(self) -> None: self._state_dict = None def should_commit(self) -> bool: + """ + .. note:: + We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly. + + Must be called after the backwards pass completes but before stepping the optimizer. + + The optimizer must only be stepped if this returns True. + + This must be called on all workers within a replica group. This uses a + collective to ensure all workers within a replica return the same value. + If an error occurs on any worker, all workers will return False. + Different replica groups may return different values. + + This should only be called once per step. + + Returns: + True if the optimizer should be stepped, False otherwise + """ for work in self._pending_work: # check at the beginning since .wait() may trigger errors if self._errored: @@ -296,10 +365,27 @@ def should_commit(self) -> bool: return should_commit def load_state_dict(self, state_dict: Dict[str, int]) -> None: + """ + Load the state dict from a previous checkpoint. + + This will restore the step count and internal metadata. + + Args: + state_dict: the state dict to load + """ self._step = state_dict["step"] self._batches_committed = state_dict["batches_committed"]
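A sketch of how these two state dict methods compose with a model checkpoint; the save/load helpers, the model variable, and the path handling are illustrative assumptions rather than part of this diff:

    import torch

    def save_checkpoint(path: str, model, manager) -> None:
        # persist the model weights alongside the manager's step metadata
        torch.save({"model": model.state_dict(), "manager": manager.state_dict()}, path)

    def load_checkpoint(path: str, model, manager) -> None:
        ckpt = torch.load(path, weights_only=True)
        model.load_state_dict(ckpt["model"])
        # restores {"step", "batches_committed"} per load_state_dict above
        manager.load_state_dict(ckpt["manager"])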
def state_dict(self) -> Dict[str, int]: + """ + Get the state dict for this manager. + + This can be used to checkpoint the state of the manager so it can be + restored from a previous checkpoint. + + Returns: + the state dict for this manager + """ return {"step": self._step, "batches_committed": self._batches_committed} def current_step(self) -> int: @@ -307,6 +393,9 @@ Get the current step count. This number is incremented on .step() + + Returns: + the current step count """ return self._step @@ -317,5 +406,8 @@ def batches_committed(self) -> int: 10 examples depending on batch size. This number is incremented on .step() + + Returns: + the total number of batches committed """ return self._batches_committed diff --git a/torchft/optim.py b/torchft/optim.py index 8e31ecc..c88782e 100644 --- a/torchft/optim.py +++ b/torchft/optim.py @@ -4,7 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING, Optional +""" +Optimizers +============ + +This module implements an optimizer wrapper that works with the Manager to provide fault tolerance. + +""" + +from typing import Optional, TYPE_CHECKING from torch.optim import Optimizer @@ -13,6 +21,16 @@ class OptimizerWrapper(Optimizer): + """ + This wraps any provided optimizer and, in conjunction with the manager, provides fault tolerance. + + zero_grad() must be called at the start of the forwards pass and step() must + be called at the end of the backwards pass. + + Depending on the state of the manager, the optimizer will either commit the + gradients to the wrapped optimizer or ignore them. + """ + def __init__(self, manager: "Manager", optim: Optimizer) -> None: self.optim = optim self.manager = manager diff --git a/torchft/parameter_server.py b/torchft/parameter_server.py index a1e5424..475f1c7 100644 --- a/torchft/parameter_server.py +++ b/torchft/parameter_server.py @@ -4,14 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from abc import ABC, abstractmethod -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +""" +Parameter Servers +================== + +This module provides a prototype implementation of a fault tolerant parameter server built on the reconfigurable ProcessGroups. +""" + +import json +import logging import socket import threading -import uuid -import logging import urllib.request -import json +import uuid +from abc import ABC, abstractmethod +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from torch.distributed import TCPStore @@ -87,7 +94,7 @@ def do_GET(self): self.connection.close() # hijack thread for the session - ps.handle_session(session_id, store_addr) + ps._handle_session(session_id, store_addr) except Exception: logger.exception( f"got exception in request handler for {self.path}" @@ -109,6 +116,11 @@ def do_GET(self): def address(self) -> str: """ Returns the HTTP address to create a new session on this server. + + Format: http://host:port/new_session + + Returns: + an HTTP address """ port = self._server.socket.getsockname()[1] return f"http://{socket.gethostname()}:{port}/new_session"
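For orientation, a minimal sketch of a concrete subclass implementing the two abstract hooks documented in this diff; the class name, the choice of Gloo, and the allreduce body are illustrative assumptions:

    import torch
    from torch.distributed import ReduceOp

    from torchft.parameter_server import ParameterServer
    from torchft.process_group import ProcessGroup, ProcessGroupGloo

    class GlooParameterServer(ParameterServer):  # hypothetical example
        @classmethod
        def new_process_group(cls) -> ProcessGroup:
            # return an unconfigured PG; the server configures it per session
            return ProcessGroupGloo()

        def forward(self, session_id: str, pg: ProcessGroup) -> None:
            # the server is rank 0 and the client is rank 1 (world_size == 2)
            t = torch.zeros(8)
            pg.allreduce([t], ReduceOp.SUM).wait()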
@@ -125,6 +137,11 @@ def new_process_group(cls) -> ProcessGroup: """ Create a new non-configured ProcessGroup for the ParameterServer to configure when setting up server and client connections. + + Must be implemented by subclasses. + + Returns: + a new ProcessGroup """ ... @@ -150,7 +167,7 @@ def new_session(cls, address: str) -> ProcessGroup: return pg - def handle_session(self, session_id: str, store_addr: str) -> None: + def _handle_session(self, session_id: str, store_addr: str) -> None: pg = self.new_process_group() # parameter server is always rank 0 pg.configure(store_addr, rank=0, world_size=2) @@ -169,6 +186,8 @@ def forward(self, session_id: str, pg: ProcessGroup) -> None: The server rank is 0 and the client rank is 1. + Must be implemented by subclasses. + Args: session_id: a unique uuid for this session pg: the ProcessGroup that's configured for the client. diff --git a/torchft/process_group.py b/torchft/process_group.py index 52aa0d1..e93d0b1 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -4,6 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Process Groups +========================= + +This module implements fault tolerant process groups that can be reconfigured +and resized at runtime. + +These extend the standard PyTorch ProcessGroup API and can be used in most +places that would accept a standard process group. As these can change size at +runtime, users need to take care not to assume a static rank or world size. +""" + import logging import threading from abc import ABC @@ -47,9 +59,9 @@ def _get(queue: mp.Queue, timeout) -> object: return v -def create_store(store_addr: str) -> Store: +def create_store_client(store_addr: str) -> Store: """ - Creates a PrefixStore(TCPStore(...)) from an address in the format: + Creates a PrefixStore(TCPStore(...)) client from an address in the format: host:port/prefix @@ -75,6 +87,17 @@ def __init__(self, *args, **kwargs) -> None: self._group_name = None def configure(self, store_addr: str, rank: int, world_size: int) -> None: + """ + This reconfigures the ProcessGroup to use a new store, rank and world size. + + Every time this is called it must be provided with a unique prefixed + store address, e.g. localhost:1234/my/prefix/1 + + Args: + store_addr: address of the store to use + rank: rank of this process + world_size: world size of this process group + """ raise NotImplementedError("not implemented") def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: @@ -171,7 +194,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: self._pg.abort() self._pg = None - store = create_store(store_addr) + store = create_store_client(store_addr) # TODO: set global timeout self._pg = self.PG_CLASS(store, rank, world_size) @@ -222,7 +245,7 @@ def getBackendName(self) -> str: return "torchft-nccl" -class DummyWork(dist._Work): +class _DummyWork(dist._Work): def __init__(self, result): super().__init__() self.result_ = result @@ -238,10 +261,14 @@ def get_future(self): class ProcessGroupDummy(ProcessGroup): """ - This PG only supports world_size of 1 + This process group discards all data passed to it and returns success. This
This + is intended for rare cases where we want to discard certain operations + without modifying the underlying library. + + This PG only supports world_size of 1. """ - def __init__(self, rank, world): + def __init__(self, rank: int, world: int) -> None: super().__init__(rank, world) assert rank == 0 assert world == 1 @@ -253,7 +280,7 @@ def __init__(self, rank, world): self._work = [] def broadcast(self, tensor_list, opts): - res = DummyWork(tensor_list) + res = _DummyWork(tensor_list) self._work.append(res) return res @@ -261,12 +288,12 @@ def allgather(self, output_tensors, input_tensor, opts): for o, i in zip(output_tensors[0], input_tensor): o.copy_(i) - res = DummyWork(output_tensors) + res = _DummyWork(output_tensors) self._work.append(res) return res def allreduce(self, tensors, opts): - res = DummyWork(tensors) + res = _DummyWork(tensors) self._work.append(res) return res @@ -277,7 +304,7 @@ def getBackendName(self): return "torchft-dummy" -class BabyWork(Work): +class _BabyWork(Work): def __init__( self, pg: "ProcessGroupBaby", @@ -303,7 +330,7 @@ def get_future(self) -> Future: return self._pg._get_future(self._op_id) -class BabyWorkNCCL(BabyWork): +class _BabyWorkNCCL(_BabyWork): def wait(self) -> bool: self._tx.put(("synchronize", self._op_id), timeout=self._timeout) op_id, event = _get(self._rx, self._timeout) @@ -326,7 +353,7 @@ class ProcessGroupBaby(ProcessGroup): """ PG_CLASS: Type[BaseProcessGroup] - WORK_CLASS: Type[BabyWork] = BabyWork + WORK_CLASS: Type[_BabyWork] = _BabyWork def __init__(self, timeout: float = 60.0) -> None: super().__init__(0, 1) @@ -389,7 +416,7 @@ def _worker( future_queue: mp.Queue, ) -> None: try: - store = create_store(store_addr) + store = create_store_client(store_addr) pg = cls.PG_CLASS(store, rank, world_size) @@ -495,6 +522,13 @@ def size(self) -> int: class ProcessGroupBabyGloo(ProcessGroupBaby): + """ + This is a ProcessGroup that runs Gloo in a subprocess. + + For most use cases you should prefer ProcessGroupGloo or + ProcessGroupBabyNCCL. + """ + PG_CLASS = BaseProcessGroupGloo def getBackendName(self): @@ -518,7 +552,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby): """ PG_CLASS = BaseProcessGroupNCCL - WORK_CLASS = BabyWorkNCCL + WORK_CLASS = _BabyWorkNCCL def getBackendName(self): return "torchft-baby-nccl"