diff --git a/README.md b/README.md index c1a4ae3..aa67f59 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ greatly improve efficiency by avoiding stop the world training on errors. Before proceeding, ensure you have the following installed: -- Rust (with necessaray dependencies) +- Rust (with necessary dependencies) - `protobuf-compiler` and the corresponding development package for Protobuf. Note that the Rust versions available in many conda environments may be outdated. To install the latest version of Rust, we recommend downloading it directly from the official website as shown in the below command: diff --git a/train_ddp.py b/train_ddp.py index 741bb86..c15d7e7 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -48,12 +48,13 @@ def main() -> None: rank=0, # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. num_replicas=1, + shuffle=True, ) # This uses the torchdata StatefulDataLoader to be able to checkpoint and # restore the per worker dataloader position. trainloader = StatefulDataLoader( - trainset, batch_size=64, shuffle=True, num_workers=2 + trainset, batch_size=64, num_workers=2, sampler=sampler ) def load_state_dict(state_dict):