Skip to content

Commit aec61c7

Browse files
authored
train, manager, dashboard: show world size on dashboard, manual replica_id, convergence tweaks (#11)
1 parent 44d2148 commit aec61c7

File tree

6 files changed

+24
-7
lines changed

6 files changed

+24
-7
lines changed

proto/torchft.proto

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ message QuorumMember {
4040
string address = 2;
4141
string store_address = 3;
4242
int64 step = 4;
43+
uint64 world_size = 5;
4344
}
4445

4546
message Quorum {

src/lighthouse.rs

+7
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ mod tests {
463463
address: "".to_string(),
464464
store_address: "".to_string(),
465465
step: 1,
466+
world_size: 1,
466467
},
467468
},
468469
);
@@ -495,6 +496,7 @@ mod tests {
495496
address: "".to_string(),
496497
store_address: "".to_string(),
497498
step: 1,
499+
world_size: 1,
498500
},
499501
},
500502
);
@@ -511,6 +513,7 @@ mod tests {
511513
address: "".to_string(),
512514
store_address: "".to_string(),
513515
step: 1,
516+
world_size: 1,
514517
}],
515518
created: Some(SystemTime::now().into()),
516519
});
@@ -550,6 +553,7 @@ mod tests {
550553
address: "".to_string(),
551554
store_address: "".to_string(),
552555
step: 10,
556+
world_size: 1,
553557
}),
554558
});
555559

@@ -568,12 +572,14 @@ mod tests {
568572
address: "".to_string(),
569573
store_address: "".to_string(),
570574
step: 1,
575+
world_size: 1,
571576
}];
572577
let b = vec![QuorumMember {
573578
replica_id: "1".to_string(),
574579
address: "changed".to_string(),
575580
store_address: "changed".to_string(),
576581
step: 1000,
582+
world_size: 1,
577583
}];
578584

579585
// replica_id is the same
@@ -584,6 +590,7 @@ mod tests {
584590
address: "".to_string(),
585591
store_address: "".to_string(),
586592
step: 1,
593+
world_size: 1,
587594
}];
588595
// replica_id changed
589596
assert!(quorum_changed(&a, &c));

src/manager.rs

+1
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ impl ManagerService for Arc<Manager> {
192192
address: self.address.clone(),
193193
store_address: self.store_address.clone(),
194194
step: req.step,
195+
world_size: self.world_size,
195196
}),
196197
});
197198

templates/status.html

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ <h3>Previous Quorum</h3>
1919
<b>{{ member.replica_id }}</b> <br/>
2020
Step: {{ member.step }} <br/>
2121
Manager: {{ member.address }} <br/>
22-
TCPStore: {{ member.store_address }}
22+
TCPStore: {{ member.store_address }} <br/>
23+
World size: {{ member.world_size }} <br/>
2324

2425
<button hx-post="/replica/{{member.replica_id}}/kill"
2526
hx-trigger="click">

torchft/manager.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
store_addr: Optional[str] = None,
5050
store_port: Optional[int] = None,
5151
lighthouse_addr: Optional[str] = None,
52+
replica_id: Optional[str] = None,
5253
) -> None:
5354
"""
5455
Args:
@@ -62,7 +63,8 @@ def __init__(
6263
world_size: the replica group local world size
6364
store_addr: TCPStore address for this replica group
6465
store_port: TCPStore port for this replica group
65-
ligthouse_addr: if rank==0, the address of the lighthouse server
66+
lighthouse_addr: if rank==0, the address of the lighthouse server
67+
replica_id: if rank==0, the replica_id for this group
6668
"""
6769
self._load_state_dict = load_state_dict
6870
self._state_dict = state_dict
@@ -99,7 +101,8 @@ def __init__(
99101
bind = f"[::]:{port}"
100102
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
101103

102-
replica_id = str(uuid.uuid4())
104+
if replica_id is None:
105+
replica_id = str(uuid.uuid4())
103106
# pyre-fixme[16]: can't find rust module
104107
self._manager = _Manager(
105108
replica_id=replica_id,

train_ddp.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929

3030
def main() -> None:
31+
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
32+
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
33+
3134
transform = transforms.Compose(
3235
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
3336
)
@@ -40,8 +43,8 @@ def main() -> None:
4043
# majority of groups will be available so few batches will be dropped.
4144
sampler = DistributedSampler(
4245
trainset,
43-
replica_group=int(os.environ.get("REPLICA_GROUP_ID", 0)),
44-
num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", 2)),
46+
replica_group=REPLICA_GROUP_ID,
47+
num_replica_groups=NUM_REPLICA_GROUPS,
4548
rank=0,
4649
# for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
4750
num_replicas=1,
@@ -50,7 +53,7 @@ def main() -> None:
5053
# This uses the torchdata StatefulDataLoader to be able to checkpoint and
5154
# restore the per worker dataloader position.
5255
trainloader = StatefulDataLoader(
53-
trainset, batch_size=2, shuffle=True, num_workers=2
56+
trainset, batch_size=64, shuffle=True, num_workers=2
5457
)
5558

5659
def load_state_dict(state_dict):
@@ -68,9 +71,10 @@ def state_dict():
6871

6972
manager = Manager(
7073
pg=pg,
71-
min_replica_size=2,
74+
min_replica_size=1,
7275
load_state_dict=load_state_dict,
7376
state_dict=state_dict,
77+
replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
7478
)
7579

7680
class Net(nn.Module):

0 commit comments

Comments
 (0)