manager_integ_tests: added Python integration test with lighthouse #27

Merged
merged 1 commit into from
Dec 7, 2024
2 changes: 2 additions & 0 deletions .lintrunner.toml
@@ -1,6 +1,7 @@
[[linter]]
code = 'BLACK-ISORT'
include_patterns = [
'*.py',
'**/*.py',
]
exclude_patterns = []
@@ -46,6 +47,7 @@ command = [
[[linter]]
code = 'PYRE'
include_patterns = [
'*.py',
'**/*.py',
'**/*.pyi',
]
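The added '**/*.py' patterns are what make nested files such as torchft/manager_integ_test.py visible to the BLACK-ISORT and PYRE linters; under common glob semantics a bare '*.py' only covers files at the repository root. A minimal sketch of the difference using pathlib (an illustration only; lintrunner's own matcher is assumed to behave the same way):

from pathlib import Path

# '*.py' matches only Python files at the top level of the tree.
top_level = sorted(Path(".").glob("*.py"))

# '**/*.py' also matches nested files such as torchft/manager_integ_test.py.
recursive = sorted(Path(".").glob("**/*.py"))

print(len(top_level), len(recursive))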
56 changes: 55 additions & 1 deletion src/lib.rs
@@ -9,6 +9,7 @@ pub mod manager;

use core::time::Duration;
use std::env;
use std::sync::Arc;

use anyhow::Result;
use pyo3::exceptions::PyRuntimeError;
@@ -28,6 +29,8 @@ use pyo3::prelude::*;
#[pyclass]
struct Manager {
handle: JoinHandle<Result<()>>,
manager: Arc<manager::Manager>,
_runtime: Runtime,
}

#[pymethods]
@@ -55,10 +58,18 @@ impl Manager {
))
.unwrap();
let handle = runtime.spawn(manager.clone().run());
Self { handle: handle }
Self {
handle: handle,
manager: manager,
_runtime: runtime,
}
})
}

fn address(&self) -> PyResult<String> {
Ok(self.manager.address().to_string())
}

fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
@@ -200,6 +211,48 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
Ok(())
}

#[pyclass]
struct Lighthouse {
lighthouse: Arc<lighthouse::Lighthouse>,
handle: JoinHandle<Result<()>>,
_runtime: Runtime,
}

#[pymethods]
impl Lighthouse {
#[new]
fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
py.allow_threads(move || {
let rt = Runtime::new().unwrap();

let lighthouse = rt
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
bind: bind,
min_replicas: min_replicas,
join_timeout_ms: 100,
quorum_tick_ms: 100,
}))
.unwrap();

Ok(Self {
handle: rt.spawn(lighthouse.clone().run()),
lighthouse: lighthouse,
_runtime: rt,
})
})
}

fn address(&self) -> PyResult<String> {
Ok(self.lighthouse.address().to_string())
}

fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
})
}
}

#[pymodule]
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
@@ -212,6 +265,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {

m.add_class::<Manager>()?;
m.add_class::<ManagerClient>()?;
m.add_class::<Lighthouse>()?;
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

Ok(())
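The new Lighthouse pyclass mirrors the existing Manager binding: the constructor creates a Tokio runtime, blocks on Lighthouse::new, spawns the run loop, and keeps the JoinHandle so shutdown() can abort it. From Python the lifecycle looks roughly like the sketch below (assuming, as the integration test does, that binding to port 0 picks a free port):

from torchft.torchft import Lighthouse

lighthouse = Lighthouse(bind="[::]:0", min_replicas=2)  # bind an ephemeral port
addr = lighthouse.address()  # passed to each Manager as lighthouse_addr
# ... run replicas against addr ...
lighthouse.shutdown()  # aborts the background Tokio task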
5 changes: 1 addition & 4 deletions torchft/manager.py
@@ -39,8 +39,6 @@
from torch.optim import Optimizer

from torchft.checkpointing import CheckpointServer

# pyre-fixme[21]: can't find rust module
from torchft.torchft import Manager as _Manager, ManagerClient

if TYPE_CHECKING:
@@ -121,7 +119,7 @@ def __init__(

store_addr = store_addr or os.environ["MASTER_ADDR"]
store_port = store_port or int(os.environ["MASTER_PORT"])
self._rank: int = rank or int(os.environ["RANK"])
self._rank: int = rank if rank is not None else int(os.environ["RANK"])
rank = self._rank
world_size = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size
@@ -151,7 +149,6 @@ def __init__(

if replica_id is None:
replica_id = str(uuid.uuid4())
# pyre-fixme[16]: can't find rust module
self._manager = _Manager(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
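The rank change above is a behavioral fix rather than a style tweak: with "rank or ...", an explicitly passed rank=0 is falsy and silently falls back to the RANK environment variable. A small standalone illustration (the environment value here is made up for clarity):

import os

os.environ["RANK"] = "3"
rank = 0

print(rank or int(os.environ["RANK"]))  # 3 -- the explicit rank 0 is dropped
print(rank if rank is not None else int(os.environ["RANK"]))  # 0 -- explicit rank wins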
100 changes: 100 additions & 0 deletions torchft/manager_integ_test.py
@@ -0,0 +1,100 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest import TestCase

import torch
import torch.distributed as dist
from torch import nn, optim

from torchft.ddp import DistributedDataParallel
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.process_group import ProcessGroupGloo
from torchft.torchft import Lighthouse


class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(3, 4),
nn.Sigmoid(),
)

def forward(self, x):
return self.model(x)


def train_loop(replica_id: int, lighthouse_address: str) -> None:
store = dist.TCPStore(
host_name="localhost",
port=0,
is_master=True,
wait_for_workers=False,
)

def load_state_dict(state_dict):
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict():
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
}

pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(replica_id),
store_addr="localhost",
store_port=store.port,
rank=0,
world_size=1,
lighthouse_addr=lighthouse_address,
port=19530 + replica_id,
)
m = DistributedDataParallel(manager, MyModel())
optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
criterion = nn.CrossEntropyLoss()

while True:
inputs = torch.rand(2, 3)
labels = torch.randint(4, (2,))

optimizer.zero_grad()
out = m(inputs)
loss = criterion(out, labels)

loss.backward()
optimizer.step()

# TODO: assert weights are equal across replicas

if manager.current_step() >= 5:
break

manager.shutdown()


class ManagerIntegTest(TestCase):
def test_ddp(self):
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
)
num_replicas = 2
futures = []

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id in range(num_replicas):
futures.append(
executor.submit(train_loop, replica_id, lighthouse.address())
)

for fut in as_completed(futures):
fut.result()

lighthouse.shutdown()
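One way the TODO about asserting weight equality across replicas could be closed is sketched below; this is hypothetical and not part of the PR, and it assumes train_loop is changed to return m.state_dict() after the training loop:

# Hypothetical follow-up inside test_ddp, assuming train_loop returns the final state_dict:
state_dicts = [fut.result() for fut in futures]
for key in state_dicts[0]:
    torch.testing.assert_close(state_dicts[0][key], state_dicts[1][key])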
18 changes: 18 additions & 0 deletions torchft/torchft.pyi
@@ -8,3 +8,21 @@ class ManagerClient:
) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ...
def checkpoint_address(self, rank: int) -> str: ...
def should_commit(self, rank: int, step: int, should_commit: bool) -> bool: ...

class Manager:
def __init__(
self,
replica_id: str,
lighthouse_addr: str,
address: str,
bind: str,
store_addr: str,
world_size: int,
) -> None: ...
def address(self) -> str: ...
def shutdown(self) -> None: ...

class Lighthouse:
def __init__(self, bind: str, min_replicas: int) -> None: ...
def address(self) -> str: ...
def shutdown(self) -> None: ...