manager_integ_tests: added recovery test #28

Merged
merged 1 commit into from Dec 8, 2024
28 changes: 15 additions & 13 deletions src/lib.rs
@@ -44,9 +44,9 @@ impl Manager {
bind: String,
store_addr: String,
world_size: u64,
) -> Self {
) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = Runtime::new().unwrap();
let runtime = Runtime::new()?;
let manager = runtime
.block_on(manager::Manager::new(
replica_id,
@@ -56,13 +56,13 @@ store_addr,
store_addr,
world_size,
))
.unwrap();
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let handle = runtime.spawn(manager.clone().run());
Self {
Ok(Self {
handle: handle,
manager: manager,
_runtime: runtime,
}
})
})
}

@@ -89,7 +89,7 @@ impl ManagerClient {
#[new]
fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = Runtime::new().unwrap();
let runtime = Runtime::new()?;
let client = runtime
.block_on(manager::manager_client_new(addr, timeout))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -193,14 +193,16 @@ fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
}

#[pyfunction]
fn lighthouse_main(py: Python<'_>) {
reset_python_signals(py).unwrap();
fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
reset_python_signals(py)?;

let mut args = env::args();
args.next(); // discard binary arg
let opt = lighthouse::LighthouseOpt::from_iter(args);
let rt = Runtime::new().unwrap();
rt.block_on(lighthouse_main_async(opt)).unwrap();
let rt = Runtime::new()?;
rt.block_on(lighthouse_main_async(opt))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(())
}

async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
@@ -223,7 +225,7 @@ impl Lighthouse {
#[new]
fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
py.allow_threads(move || {
let rt = Runtime::new().unwrap();
let rt = Runtime::new()?;

let lighthouse = rt
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
@@ -232,7 +234,7 @@
join_timeout_ms: 100,
quorum_tick_ms: 100,
}))
.unwrap();
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

Ok(Self {
handle: rt.spawn(lighthouse.clone().run()),
@@ -261,7 +263,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
.show_module_names(true)
.timestamp(stderrlog::Timestamp::Millisecond)
.init()
.unwrap();
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

m.add_class::<Manager>()?;
m.add_class::<ManagerClient>()?;
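The src/lib.rs changes above replace `.unwrap()` calls with `PyResult` propagation, so tokio runtime creation and manager/lighthouse construction failures reach Python as exceptions instead of panicking inside the extension module. A rough sketch of how that is expected to look from the Python side; the import path and the invalid bind address are assumptions, and only the `Lighthouse(bind=..., min_replicas=...)` signature comes from the diff:

```python
# Hedged sketch: error surfacing after the PyResult changes above.
# The module path is an assumption; the bad bind address is made up.
from torchft.torchft import Lighthouse  # assumed location of the Rust extension

try:
    # Trigger the error path that previously would have panicked via .unwrap().
    lighthouse = Lighthouse(bind="not-an-address", min_replicas=2)
except (RuntimeError, OSError) as e:
    # PyRuntimeError maps to Python's RuntimeError; io errors from
    # Runtime::new()? map to OSError.
    print(f"lighthouse construction failed cleanly: {e}")
```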
10 changes: 7 additions & 3 deletions torchft/manager.py
@@ -140,6 +140,7 @@ def __init__(
wait_for_workers=False,
)
self._pg = pg
self._manager = None

if rank == 0:
hostname = socket.gethostname()
@@ -148,7 +149,8 @@
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

if replica_id is None:
replica_id = str(uuid.uuid4())
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
self._manager = _Manager(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
@@ -180,6 +182,8 @@ def shutdown(self) -> None:
Shutdown the manager and checkpoint server.
"""
self._ckpt_server.shutdown()
if self._manager is not None:
self._manager.shutdown()

def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
"""
@@ -364,7 +368,7 @@ def _async_quorum(self) -> None:
self._participating_rank = None

if quorum_id != self._quorum_id:
logger.info(f"reconfiguring for quorum_id {quorum_id}")
logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
# We use the replica rank and world as we want all replicas in the PG.
self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
@@ -373,7 +377,7 @@
# See manager.rs for healing conditions
if heal:
self._healing = True
logger.info("healing required")
logger.info(f"{replica_rank}= healing required")

logger.info(f"fetching checkpoint server address from {address}")
primary_client = ManagerClient(address, timeout=self._timeout)
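The torchft/manager.py changes initialize `self._manager` to `None` before the rank check, so `shutdown()` can be called safely on replicas that never construct the Rust-backed `_Manager` (its construction appears to be gated on rank 0). A minimal sketch of that guard pattern, with a hypothetical stand-in class instead of the real `_Manager`:

```python
class _FakeRustManager:
    """Stand-in for the Rust _Manager; only rank 0 creates one."""

    def shutdown(self) -> None:
        print("rust manager shut down")


class ManagerLike:
    def __init__(self, rank: int) -> None:
        # Initialize to None up front so shutdown() is always safe to call,
        # mirroring the change to torchft/manager.py.
        self._manager = None
        if rank == 0:
            self._manager = _FakeRustManager()

    def shutdown(self) -> None:
        # The real shutdown() also stops the checkpoint server first.
        if self._manager is not None:
            self._manager.shutdown()


ManagerLike(rank=1).shutdown()  # no-op, no AttributeError
ManagerLike(rank=0).shutdown()  # shuts down the rank-0 manager
```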
208 changes: 152 additions & 56 deletions torchft/manager_integ_test.py
@@ -1,4 +1,6 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from typing import Set, Tuple
from unittest import TestCase

import torch
@@ -24,63 +26,108 @@ def forward(self, x):
return self.model(x)


def train_loop(replica_id: int, lighthouse_address: str) -> None:
store = dist.TCPStore(
host_name="localhost",
port=0,
is_master=True,
wait_for_workers=False,
)

def load_state_dict(state_dict):
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict():
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
}

pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(replica_id),
store_addr="localhost",
store_port=store.port,
rank=0,
world_size=1,
lighthouse_addr=lighthouse_address,
port=19530 + replica_id,
)
m = DistributedDataParallel(manager, MyModel())
optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
criterion = nn.CrossEntropyLoss()

while True:
inputs = torch.rand(2, 3)
labels = torch.randint(4, (2,))

optimizer.zero_grad()
out = m(inputs)
loss = criterion(out, labels)

loss.backward()
optimizer.step()

# TODO: assert weights are equal across replicas

if manager.current_step() >= 5:
break

manager.shutdown()
class InjectedFailure(Exception):
pass


class FailureInjector:
def __init__(self) -> None:
self._failures: Set[int] = set()
self.count = 0

def fail_at(self, step: int) -> "FailureInjector":
self._failures.add(step)
return self

def check(self, step: int) -> None:
if step in self._failures:
self.count += 1
self._failures.remove(step)
print(f"injecting failure {step=}")
raise InjectedFailure(f"injected failure {step=}")


def worker_manager(
replica_id: int,
lighthouse_address: str,
failure_injector: FailureInjector,
attempts: int = 3,
) -> None:
for i in range(attempts):
try:
print(f"starting worker {replica_id} attempt {i}")
return train_loop(
replica_id, lighthouse_address, failure_injector=failure_injector
)
except InjectedFailure as e:
print("got injected failure", i, e)
if i == attempts - 1:
raise
continue


def train_loop(
replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
) -> None:
with ExitStack() as stack:
store = dist.TCPStore(
host_name="localhost",
port=0,
is_master=True,
wait_for_workers=False,
)

def load_state_dict(state_dict):
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict():
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
}

pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(replica_id),
store_addr="localhost",
store_port=store.port,
rank=0,
world_size=1,
lighthouse_addr=lighthouse_address,
port=19530 + replica_id,
)
stack.callback(manager.shutdown)

m = DistributedDataParallel(manager, MyModel())
optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
criterion = nn.CrossEntropyLoss()

while True:
print(f"worker {replica_id} starting step {manager.current_step()}")
inputs = torch.rand(2, 3)
labels = torch.randint(4, (2,))

optimizer.zero_grad()
out = m(inputs)
loss = criterion(out, labels)

loss.backward()
optimizer.step()

if manager.current_step() >= 5:
# return state_dict so we can check consistency
return state_dict()

failure_injector.check(manager.current_step())


class ManagerIntegTest(TestCase):
def test_ddp(self):
def test_ddp_healthy(self):
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
@@ -90,11 +137,60 @@ def test_ddp(self):

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id in range(num_replicas):
failure_injector = FailureInjector()
futures.append(
executor.submit(
worker_manager,
replica_id,
lighthouse.address(),
failure_injector=failure_injector,
)
)

state_dicts = []

for fut in as_completed(futures):
state_dicts.append(fut.result())

lighthouse.shutdown()

for state_dict in state_dicts:
torch.testing.assert_close(state_dict, state_dicts[0])

def test_ddp_recovery(self):
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
)
num_replicas = 2
futures = []

failure_injectors = [
FailureInjector(),
FailureInjector().fail_at(2),
]

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id, failure_injector in zip(
range(num_replicas), failure_injectors
):
futures.append(
executor.submit(train_loop, replica_id, lighthouse.address())
executor.submit(
worker_manager,
replica_id,
lighthouse.address(),
failure_injector=failure_injector,
)
)

state_dicts = []

for fut in as_completed(futures):
fut.result()
state_dicts.append(fut.result())

lighthouse.shutdown()

for state_dict in state_dicts:
torch.testing.assert_close(state_dict, state_dicts[0])

self.assertEqual(failure_injectors[1].count, 1)
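The new `test_ddp_recovery` drives two replicas, arms one `FailureInjector` to fail at step 2, lets `worker_manager` retry the whole training loop, and finally checks that both replicas end with identical state dicts and that the injected failure fired exactly once. A self-contained sketch of just the injector/retry mechanics; the `train` stub is hypothetical and replaces the real `train_loop` so the example runs without torchft:

```python
from typing import Set


class InjectedFailure(Exception):
    pass


class FailureInjector:
    def __init__(self) -> None:
        self._failures: Set[int] = set()
        self.count = 0

    def fail_at(self, step: int) -> "FailureInjector":
        self._failures.add(step)
        return self

    def check(self, step: int) -> None:
        if step in self._failures:
            self.count += 1
            self._failures.remove(step)
            raise InjectedFailure(f"injected failure {step=}")


def train(failure_injector: FailureInjector, steps: int = 5) -> str:
    # Stand-in for train_loop: run a few steps, checking the injector each time.
    for step in range(steps):
        failure_injector.check(step)
    return "state"


def worker(failure_injector: FailureInjector, attempts: int = 3) -> str:
    # Mirrors worker_manager: retry the whole loop when the injected failure
    # fires, re-raising only if every attempt fails.
    for i in range(attempts):
        try:
            return train(failure_injector)
        except InjectedFailure:
            if i == attempts - 1:
                raise
    raise AssertionError("unreachable")


injector = FailureInjector().fail_at(2)
assert worker(injector) == "state"   # second attempt succeeds
assert injector.count == 1           # the failure fired exactly once
```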