Commit 03160ee

Fix typo and use sampler in train_ddp.py (#74)
* fix typo
* use sampler in train_ddp.py
1 parent c58ed4c commit 03160ee

2 files changed, +3 -2 lines changed

README.md

+1 -1
@@ -52,7 +52,7 @@ greatly improve efficiency by avoiding stop the world training on errors.
 
 Before proceeding, ensure you have the following installed:
 
-- Rust (with necessaray dependencies)
+- Rust (with necessary dependencies)
 - `protobuf-compiler` and the corresponding development package for Protobuf.
 
 Note that the Rust versions available in many conda environments may be outdated. To install the latest version of Rust, we recommend downloading it directly from the official website as shown in the below command:

train_ddp.py

+2 -1
@@ -48,12 +48,13 @@ def main() -> None:
         rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
         num_replicas=1,
+        shuffle=True,
     )
 
     # This uses the torchdata StatefulDataLoader to be able to checkpoint and
     # restore the per worker dataloader position.
     trainloader = StatefulDataLoader(
-        trainset, batch_size=64, shuffle=True, num_workers=2
+        trainset, batch_size=64, num_workers=2, sampler=sampler
     )
 
     def load_state_dict(state_dict):
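
The change above moves `shuffle=True` from the data loader to the sampler and passes the sampler to the loader. In `torch.utils.data.DataLoader`, `sampler` and `shuffle=True` are mutually exclusive, which is presumably why the flag moves rather than being duplicated. Before this commit the loader appears to have shuffled the full dataset on its own, so the sampler's sharding had no effect on what each replica group read. The following is a simplified, standard-library sketch of what a distributed sampler does conceptually. It is not the torchft or PyTorch implementation, and `shard_indices` and its parameters are hypothetical names chosen for illustration.

```python
import random
from itertools import cycle, islice


def shard_indices(dataset_len, num_shards, shard_rank, shuffle=True, seed=0, epoch=0):
    """Return the dataset indices assigned to one shard (hypothetical helper)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Every shard seeds an identical generator, so all shards agree on
        # the same permutation and their slices never overlap by accident.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by wrapping around so that every shard gets the same sample count.
    per_shard = -(-dataset_len // num_shards)  # ceiling division
    total = per_shard * num_shards
    padded = list(islice(cycle(indices), total))
    # Strided slice: shard r takes positions r, r + num_shards, and so on.
    return padded[shard_rank:total:num_shards]


# Two shards over a 10-sample dataset: together they cover every index once.
shards = [shard_indices(10, num_shards=2, shard_rank=r) for r in range(2)]
```

Because shuffling happens inside the sampler with a shared seed, each shard sees a different, non-overlapping slice of one common permutation. A loader that shuffles independently would instead draw from the whole dataset on every replica.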
