 def main() -> None:
+    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
+    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+
     transform = transforms.Compose(
         [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
     )
@@ -40,8 +43,8 @@ def main() -> None:
     # majority of groups will be available so few batches will be dropped.
     sampler = DistributedSampler(
         trainset,
-        replica_group=int(os.environ.get("REPLICA_GROUP_ID", 0)),
-        num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", 2)),
+        replica_group=REPLICA_GROUP_ID,
+        num_replica_groups=NUM_REPLICA_GROUPS,
         rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
         num_replicas=1,
@@ -50,7 +53,7 @@ def main() -> None:
     # This uses the torchdata StatefulDataLoader to be able to checkpoint and
     # restore the per worker dataloader position.
     trainloader = StatefulDataLoader(
-        trainset, batch_size=2, shuffle=True, num_workers=2
+        trainset, batch_size=64, shuffle=True, num_workers=2
     )

     def load_state_dict(state_dict):
@@ -68,9 +71,10 @@ def state_dict():

     manager = Manager(
         pg=pg,
-        min_replica_size=2,
+        min_replica_size=1,
         load_state_dict=load_state_dict,
         state_dict=state_dict,
+        replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
     )

 class Net(nn.Module):
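
The net effect of the diff is that each replica group is parameterized entirely by the REPLICA_GROUP_ID and NUM_REPLICA_GROUPS environment variables: the sampler uses them to shard the dataset across groups, and the Manager registers under a unique replica_id derived from the group id. As a rough illustration of how those variables might be supplied when running several groups on one machine, here is a launcher sketch; the script name train_ddp.py, the direct Python invocation, and the process handling are all assumptions, and any torchft rendezvous setup (lighthouse, torchrun, etc.) is assumed to happen elsewhere.

```python
# Hypothetical local launcher: start one copy of the example per replica group,
# each with its own REPLICA_GROUP_ID. The launch command below is a placeholder;
# the real example may additionally need torchrun and a torchft lighthouse,
# which are not shown here.
import os
import subprocess
import sys

NUM_REPLICA_GROUPS = 2
LAUNCH_CMD = [sys.executable, "train_ddp.py"]  # placeholder command for illustration

procs = []
for group_id in range(NUM_REPLICA_GROUPS):
    env = os.environ.copy()
    env["REPLICA_GROUP_ID"] = str(group_id)              # selects this group's data shard and replica_id
    env["NUM_REPLICA_GROUPS"] = str(NUM_REPLICA_GROUPS)  # total number of fault-tolerant groups
    procs.append(subprocess.Popen(LAUNCH_CMD, env=env))

# With min_replica_size=1, training can keep going while one group is down,
# so we simply wait on every process we started.
for p in procs:
    p.wait()
```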
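
The checkpoint/restore comment refers to torchdata's StatefulDataLoader, whose state_dict() and load_state_dict() methods capture and restore each worker's position in the dataset. A self-contained sketch of that behavior, using a toy TensorDataset purely for illustration:

```python
# Standalone sketch of StatefulDataLoader checkpoint/restore; the toy dataset
# and sizes here are illustrative only, not part of the example script.
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(1000, dtype=torch.float32).unsqueeze(1))
    loader = StatefulDataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

    it = iter(loader)
    next(it)  # consume one batch

    # Capture the per-worker position so training can resume mid-epoch.
    snapshot = loader.state_dict()

    # A fresh loader (e.g. after a restart) picks up from the saved position.
    restored = StatefulDataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
    restored.load_state_dict(snapshot)
    next(iter(restored))  # continues where the snapshot left off
```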