Skip to content

Commit c4e58bb

Browse files
committed
use sampler in train_ddp.py
1 parent 8c8e4e3 commit c4e58bb

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

train_ddp.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ def main() -> None:
4848
rank=0,
4949
# for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
5050
num_replicas=1,
51+
shuffle=True,
5152
)
5253

5354
# This uses the torchdata StatefulDataLoader to be able to checkpoint and
5455
# restore the per worker dataloader position.
5556
trainloader = StatefulDataLoader(
56-
trainset, batch_size=64, shuffle=True, num_workers=2
57+
trainset, batch_size=64, num_workers=2, sampler=sampler
5758
)
5859

5960
def load_state_dict(state_dict):

0 commit comments

Comments
 (0)