Skip to content

Commit

Permalink
train, manager, dashboard: show world size on dashboard, manual replica_id, convergence tweaks (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k authored Nov 11, 2024
1 parent 44d2148 commit aec61c7
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 7 deletions.
1 change: 1 addition & 0 deletions proto/torchft.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ message QuorumMember {
string address = 2;
string store_address = 3;
int64 step = 4;
uint64 world_size = 5;
}

message Quorum {
Expand Down
7 changes: 7 additions & 0 deletions src/lighthouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
},
},
);
Expand Down Expand Up @@ -495,6 +496,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
},
},
);
Expand All @@ -511,6 +513,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
}],
created: Some(SystemTime::now().into()),
});
Expand Down Expand Up @@ -550,6 +553,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 10,
world_size: 1,
}),
});

Expand All @@ -568,12 +572,14 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
}];
let b = vec![QuorumMember {
replica_id: "1".to_string(),
address: "changed".to_string(),
store_address: "changed".to_string(),
step: 1000,
world_size: 1,
}];

// replica_id is the same
Expand All @@ -584,6 +590,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
}];
// replica_id changed
assert!(quorum_changed(&a, &c));
Expand Down
1 change: 1 addition & 0 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ impl ManagerService for Arc<Manager> {
address: self.address.clone(),
store_address: self.store_address.clone(),
step: req.step,
world_size: self.world_size,
}),
});

Expand Down
3 changes: 2 additions & 1 deletion templates/status.html
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ <h3>Previous Quorum</h3>
<b>{{ member.replica_id }}</b> <br/>
Step: {{ member.step }} <br/>
Manager: {{ member.address }} <br/>
TCPStore: {{ member.store_address }}
TCPStore: {{ member.store_address }} <br/>
World size: {{ member.world_size }} <br/>

<button hx-post="/replica/{{member.replica_id}}/kill"
hx-trigger="click">
Expand Down
7 changes: 5 additions & 2 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
store_addr: Optional[str] = None,
store_port: Optional[int] = None,
lighthouse_addr: Optional[str] = None,
replica_id: Optional[str] = None,
) -> None:
"""
Args:
Expand All @@ -62,7 +63,8 @@ def __init__(
world_size: the replica group local world size
store_addr: TCPStore address for this replica group
store_port: TCPStore port for this replica group
ligthouse_addr: if rank==0, the address of the lighthouse server
lighthouse_addr: if rank==0, the address of the lighthouse server
replica_id: if rank==0, the replica_id for this group
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
Expand Down Expand Up @@ -99,7 +101,8 @@ def __init__(
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

replica_id = str(uuid.uuid4())
if replica_id is None:
replica_id = str(uuid.uuid4())
# pyre-fixme[16]: can't find rust module
self._manager = _Manager(
replica_id=replica_id,
Expand Down
12 changes: 8 additions & 4 deletions train_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@


def main() -> None:
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))

transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
Expand All @@ -40,8 +43,8 @@ def main() -> None:
# majority of groups will be available so few batches will be dropped.
sampler = DistributedSampler(
trainset,
replica_group=int(os.environ.get("REPLICA_GROUP_ID", 0)),
num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", 2)),
replica_group=REPLICA_GROUP_ID,
num_replica_groups=NUM_REPLICA_GROUPS,
rank=0,
# for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
num_replicas=1,
Expand All @@ -50,7 +53,7 @@ def main() -> None:
# This uses the torchdata StatefulDataLoader to be able to checkpoint and
# restore the per worker dataloader position.
trainloader = StatefulDataLoader(
trainset, batch_size=2, shuffle=True, num_workers=2
trainset, batch_size=64, shuffle=True, num_workers=2
)

def load_state_dict(state_dict):
Expand All @@ -68,9 +71,10 @@ def state_dict():

manager = Manager(
pg=pg,
min_replica_size=2,
min_replica_size=1,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
)

class Net(nn.Module):
Expand Down

0 comments on commit aec61c7

Please sign in to comment.