examples,docs: adjust ddp example timeout and docs #93

Merged: 1 commit, Jan 31, 2025
README.md (3 additions, 3 deletions)

@@ -13,7 +13,7 @@ Easy Per Step Fault Tolerance for PyTorch
| <a href="https://pytorch.org/torchft/"><b>Documentation</b></a>
| <a href="https://github.com/pytorch-labs/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a>
| <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a>
-|
+|
</p>
<p align="center">
<a href="https://pypi.org/project/torchft-nightly/"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/torchft-nightly"></a>
@@ -98,7 +98,7 @@ when using synchronous training.
You can start a lighthouse server by running:

```sh
-$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000
+$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
```

### Example Training Loop (DDP)
@@ -108,7 +108,7 @@ See [train_ddp.py](./train_ddp.py) for the full example.
Invoke with:

```sh
-$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
+$ TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
```

train.py:
src/lighthouse.rs (3 additions, 3 deletions)

@@ -77,7 +77,7 @@ pub struct LighthouseOpt {
#[structopt(
long = "join_timeout_ms",
default_value = "60000",
-help = "How long to wait for new replicas to join before considering a quorum"
+help = "How long to wait for heartbeating stragglers to join before issuing quorum"
)]
pub join_timeout_ms: u64,

@@ -90,14 +90,14 @@ pub struct LighthouseOpt {
#[structopt(
long = "quorum_tick_ms",
default_value = "100",
-help = "How frequently to check for quorum when waiting for workers."
+help = "How frequently to check for quorum when waiting for stragglers."
)]
pub quorum_tick_ms: u64,

#[structopt(
long = "heartbeat_timeout_ms",
default_value = "5000",
-help = "how long to wait for a heartbeat before considering a replica dead."
+help = "How long to wait for a heartbeat before considering a replica dead."
)]
pub heartbeat_timeout_ms: u64,
}
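The three flags above interact: the lighthouse re-checks for quorum every `quorum_tick_ms`, waits up to `join_timeout_ms` for heartbeating stragglers, and marks a replica dead once `heartbeat_timeout_ms` elapses without a heartbeat. A rough sketch of that relationship in plain Python (the class and the tick arithmetic are illustrative only, not the actual Rust implementation):

```python
from dataclasses import dataclass

@dataclass
class LighthouseTiming:
    # Defaults mirror the structopt defaults in src/lighthouse.rs above.
    join_timeout_ms: int = 60_000      # wait this long for heartbeating stragglers
    quorum_tick_ms: int = 100          # re-check quorum this often while waiting
    heartbeat_timeout_ms: int = 5_000  # no heartbeat for this long => replica is dead

    def max_ticks_while_waiting(self) -> int:
        # Upper bound on quorum checks performed before the join timeout expires.
        return self.join_timeout_ms // self.quorum_tick_ms

# The README example in this PR passes --join_timeout_ms 10000:
readme = LighthouseTiming(join_timeout_ms=10_000)
print(readme.max_ticks_while_waiting())  # 100 ticks of 100 ms each
```

With the old README value of 1000 ms the lighthouse would wait only ten ticks for stragglers, which helps explain why the example raises it to 10000 ms.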
train_ddp.py (9 additions, 1 deletion)

@@ -7,6 +7,7 @@
import logging
import os
import sys
+from datetime import timedelta

import torch
import torch.nn.functional as F
@@ -70,14 +71,21 @@ def state_dict():
}

device = "cuda" if torch.cuda.is_available() else "cpu"
-pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()
+pg = (
+    ProcessGroupBabyNCCL(
+        timeout=timedelta(seconds=5),
+    )
+    if torch.cuda.is_available()
+    else ProcessGroupGloo(timeout=timedelta(seconds=5))
+)

manager = Manager(
pg=pg,
min_replica_size=1,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
+timeout=timedelta(seconds=10),
)

class Net(nn.Module):
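The updated example threads explicit timeouts through both layers: 5 s for the process group's collectives and 10 s for the Manager. A minimal sketch of the `timedelta` pattern used above (values copied from the diff; the ordering check is an assumption about sensible configuration, not something the PR enforces):

```python
from datetime import timedelta

# Timeouts from the updated train_ddp.py above.
pg_timeout = timedelta(seconds=5)        # per-collective process group timeout
manager_timeout = timedelta(seconds=10)  # Manager-level timeout

# timedelta supports direct comparison and conversion to seconds,
# which makes sanity checks on the configuration easy to write.
assert pg_timeout < manager_timeout  # collectives fail before the manager gives up
print(pg_timeout.total_seconds())        # 5.0
print(manager_timeout.total_seconds())   # 10.0
```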