Skip to content

Commit 481d2ce

Browse files
committed
Update on "[WIP] Add DiLoCo"
Still WIP but open to feedback on the API ## API Usage ```python # LocalSGD example model = SimpleModel() optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) with LocalSGD(manager, model, optimizer, sync_every=2): for inp, label in dataloader: loss = model(inp).mean() loss.backward() optimizer.step() # DiLoCo example model = SimpleModel() inner_optimizer = torch.optim.AdamW( m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) ) outer_optimizer = torch.optim.SGD( m.parameters(), lr=0.7, momentum=0.9, nesterov=True ) manager = create_autospec(Manager) with DiLoCo(manager, m, inner_optimizer, outer_optimizer, sync_every=2): for inp, label in dataloader: loss = model(inp).mean() loss.backward() inner_optimizer.step() # outer_optimizer is used every 'sync_every' steps ``` ## Changes - Updated `LocalSGD` to be a context manager rather than a `nn.Module` wrapper. This required adding a pre_forward_hook to the model to start the quorum - Added DiLoCo. This is a subclass of LocalSGD since a lot of code is shared - TODO: should be working, but still validating some tests discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk [ghstack-poisoned]
2 parents cfbf7e2 + 68d4059 commit 481d2ce

25 files changed

+1713
-286
lines changed

.pyre_configuration

+4
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,9 @@
66
"import_root": ".",
77
"source": "torchft"
88
}
9+
],
10+
"search_path": [
11+
{"site-package": "torchx"},
12+
{"site-package": "parameterized"}
913
]
1014
}

.torchxconfig

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[cli:run]
2+
component=torchft/torchx.py:hsdp
3+
scheduler=local_cwd

CONTRIBUTING.md

+15
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ make livehtml
9696
The docs will be built in the `docs/build/html` directory and served at http://localhost:8000.
9797
The page will be automatically re-built as long as the process is kept running.
9898

99+
### Running Multiple Replica Local Job
100+
101+
We use torchx to run multiple worker local test jobs. You need to run the
102+
lighthouse first and then you can use torchx to launch as many replica groups as
103+
you want. This uses the [train_ddp.py](./train_ddp.py) script.
104+
105+
```sh
106+
$ torchft_lighthouse --min_replicas 2 --join_timeout_ms 10000 &
107+
$ torchx run -- --replicas 10
108+
```
109+
110+
Once the Lighthouse has started you can view the status of all the workers at the Lighthouse dashboard.
111+
112+
Default address is: http://localhost:29510
113+
99114
## Contributor License Agreement ("CLA")
100115

101116
In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ log = "0.4.22"
1212
prost = "0.13.3"
1313
prost-types = "0.13.3"
1414
pyo3 = {version="0.22.3", features = ["extension-module"]}
15+
rand = "0.8.5"
1516
slog = "2.7.0"
1617
slog-stdlog = "4.1.1"
1718
stderrlog = "0.6.0"

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ greatly improve efficiency by avoiding stop the world training on errors.
5252

5353
Before proceeding, ensure you have the following installed:
5454

55-
- Rust (with necessaray dependencies)
55+
- Rust (with necessary dependencies)
5656
- `protobuf-compiler` and the corresponding development package for Protobuf.
5757

5858
Note that the Rust versions available in many conda environments may be outdated. To install the latest version of Rust, we recommend downloading it directly from the official website as shown in the below command:

output.txt

+616
Large diffs are not rendered by default.

proto/torchft.proto

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ message QuorumMember {
4141
string store_address = 3;
4242
int64 step = 4;
4343
uint64 world_size = 5;
44+
bool shrink_only = 6;
4445
}
4546

4647
message Quorum {
@@ -72,6 +73,7 @@ message ManagerQuorumRequest {
7273
int64 rank = 1;
7374
int64 step = 2;
7475
string checkpoint_server_addr = 3;
76+
bool shrink_only = 4;
7577
}
7678

7779
message ManagerQuorumResponse {

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ dev = [
2727
"pyre-check",
2828
"parameterized",
2929
"expecttest",
30-
"numpy"
30+
"numpy",
31+
"torchx"
3132
]
3233

3334
[tool.maturin]

src/lib.rs

+39-27
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
pub mod lighthouse;
88
pub mod manager;
9+
mod net;
10+
mod retry;
11+
mod timeout;
912

1013
use core::time::Duration;
1114
use std::env;
@@ -46,6 +49,7 @@ impl Manager {
4649
store_addr: String,
4750
world_size: u64,
4851
heartbeat_interval: Duration,
52+
connect_timeout: Duration,
4953
) -> PyResult<Self> {
5054
py.allow_threads(move || {
5155
let runtime = Runtime::new()?;
@@ -58,6 +62,7 @@ impl Manager {
5862
store_addr,
5963
world_size,
6064
heartbeat_interval,
65+
connect_timeout,
6166
))
6267
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
6368
let handle = runtime.spawn(manager.clone().run());
@@ -84,47 +89,47 @@ impl Manager {
8489
struct ManagerClient {
8590
runtime: Runtime,
8691
client: ManagerServiceClient<Channel>,
87-
timeout: Duration,
8892
}
8993

9094
#[pymethods]
9195
impl ManagerClient {
9296
#[new]
93-
fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
97+
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
9498
py.allow_threads(move || {
9599
let runtime = Runtime::new()?;
96100
let client = runtime
97-
.block_on(manager::manager_client_new(addr, timeout))
101+
.block_on(manager::manager_client_new(addr, connect_timeout))
98102
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
99103

100104
Ok(Self {
101105
runtime: runtime,
102106
client: client,
103-
timeout: timeout,
104107
})
105108
})
106109
}
107110

108-
#[pyo3(signature = (rank, step, checkpoint_server_addr, timeout=None))]
109111
fn quorum(
110-
&mut self,
112+
&self,
111113
py: Python<'_>,
112114
rank: i64,
113115
step: i64,
114116
checkpoint_server_addr: String,
115-
timeout: Option<Duration>,
117+
shrink_only: bool,
118+
timeout: Duration,
116119
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
117120
py.allow_threads(move || {
118121
let mut request = tonic::Request::new(ManagerQuorumRequest {
119122
rank: rank,
120123
step: step,
121124
checkpoint_server_addr: checkpoint_server_addr,
125+
shrink_only: shrink_only,
122126
});
123-
// This notifies the server about the timeout but doesn't affect the
124-
// endpoint timeout which we set on client creation.
125-
request.set_timeout(timeout.unwrap_or(self.timeout));
126127

127-
let response = self.runtime.block_on(self.client.quorum(request))?;
128+
// This timeout is processed on the server side so we also enable
129+
// keep alives to detect server health.
130+
request.set_timeout(timeout);
131+
132+
let response = self.runtime.block_on(self.client.clone().quorum(request))?;
128133
let resp = response.into_inner();
129134
Ok((
130135
resp.quorum_id,
@@ -140,47 +145,49 @@ impl ManagerClient {
140145
})
141146
}
142147

143-
#[pyo3(signature = (rank, timeout=None))]
144148
fn checkpoint_address(
145-
&mut self,
149+
&self,
146150
py: Python<'_>,
147151
rank: i64,
148-
timeout: Option<Duration>,
152+
timeout: Duration,
149153
) -> Result<String, StatusError> {
150154
py.allow_threads(move || {
151155
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
152-
// This notifies the server about the timeout but doesn't affect the
153-
// endpoint timeout which we set on client creation.
154-
request.set_timeout(timeout.unwrap_or(self.timeout));
156+
157+
// This timeout is processed on the server side so we also enable
158+
// keep alives to detect server health.
159+
request.set_timeout(timeout);
155160

156161
let response = self
157162
.runtime
158-
.block_on(self.client.checkpoint_address(request))?;
163+
.block_on(self.client.clone().checkpoint_address(request))?;
159164
let resp = response.into_inner();
160165
Ok(resp.checkpoint_server_address)
161166
})
162167
}
163168

164-
#[pyo3(signature = (rank, step, should_commit, timeout=None))]
165169
fn should_commit(
166-
&mut self,
170+
&self,
167171
py: Python<'_>,
168172
rank: i64,
169173
step: i64,
170174
should_commit: bool,
171-
timeout: Option<Duration>,
175+
timeout: Duration,
172176
) -> Result<bool, StatusError> {
173177
py.allow_threads(move || {
174178
let mut request = tonic::Request::new(ShouldCommitRequest {
175179
rank: rank,
176180
step: step,
177181
should_commit: should_commit,
178182
});
183+
179184
// This notifies the server about the timeout but doesn't affect the
180185
// endpoint timeout which we set on client creation.
181-
request.set_timeout(timeout.unwrap_or(self.timeout));
186+
request.set_timeout(timeout);
182187

183-
let response = self.runtime.block_on(self.client.should_commit(request))?;
188+
let response = self
189+
.runtime
190+
.block_on(self.client.clone().should_commit(request))?;
184191
let resp = response.into_inner();
185192
Ok(resp.should_commit)
186193
})
@@ -297,11 +304,16 @@ impl From<Status> for StatusError {
297304
#[pymodule]
298305
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
299306
// setup logging on import
300-
stderrlog::new()
301-
.verbosity(2)
307+
let mut log = stderrlog::new();
308+
log.verbosity(2)
302309
.show_module_names(true)
303-
.timestamp(stderrlog::Timestamp::Millisecond)
304-
.init()
310+
.timestamp(stderrlog::Timestamp::Millisecond);
311+
312+
if env::var("CLICOLOR_FORCE").is_ok() {
313+
log.color(stderrlog::ColorChoice::AlwaysAnsi);
314+
}
315+
316+
log.init()
305317
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
306318

307319
m.add_class::<Manager>()?;

0 commit comments

Comments
 (0)