Fix typo and use sampler in train_ddp.py #74

Merged: 2 commits, Jan 16, 2025
README.md: 1 addition, 1 deletion

```diff
@@ -52,7 +52,7 @@ greatly improve efficiency by avoiding stop the world training on errors.
 
 Before proceeding, ensure you have the following installed:
 
-- Rust (with necessaray dependencies)
+- Rust (with necessary dependencies)
 - `protobuf-compiler` and the corresponding development package for Protobuf.
 
 Note that the Rust versions available in many conda environments may be outdated. To install the latest version of Rust, we recommend downloading it directly from the official website as shown in the below command:
```
train_ddp.py: 2 additions, 1 deletion

```diff
@@ -48,12 +48,13 @@ def main() -> None:
         rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
         num_replicas=1,
+        shuffle=True,
     )
 
     # This uses the torchdata StatefulDataLoader to be able to checkpoint and
     # restore the per worker dataloader position.
     trainloader = StatefulDataLoader(
-        trainset, batch_size=64, shuffle=True, num_workers=2
+        trainset, batch_size=64, num_workers=2, sampler=sampler
     )
 
     def load_state_dict(state_dict):
```
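For context on why the loader-level `shuffle=True` moves onto the sampler: PyTorch's `DataLoader` (and `StatefulDataLoader`, which mirrors its constructor) treats `shuffle` and `sampler` as mutually exclusive, so once a `DistributedSampler` is passed, shuffling must be configured on the sampler itself. Below is a minimal, self-contained sketch of the resulting pattern; the toy `TensorDataset`, the epoch count, and the `set_epoch` loop are illustrative assumptions, not part of this PR.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader

# Toy stand-in for the PR's `trainset` (hypothetical data, CIFAR-like shapes).
trainset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))

sampler = DistributedSampler(
    trainset,
    rank=0,
    num_replicas=1,  # replica groups of size 1 for DDP, as in the diff
    shuffle=True,    # shuffling now lives on the sampler
)

# Passing `sampler=` means `shuffle=` must be left unset on the loader;
# DataLoader raises a ValueError if both are given.
trainloader = StatefulDataLoader(
    trainset, batch_size=64, num_workers=2, sampler=sampler
)

for epoch in range(2):  # placeholder epoch count
    # Without set_epoch(), a shuffling DistributedSampler yields the same
    # permutation every epoch; calling it reseeds the shuffle per epoch.
    sampler.set_epoch(epoch)
    for inputs, labels in trainloader:
        pass  # training step would go here
```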
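The in-diff comment notes that `StatefulDataLoader` is used so the per-worker dataloader position can be checkpointed and restored. A hedged sketch of that behavior, using its documented `state_dict()`/`load_state_dict()` methods; the toy dataset and the two-batch cutoff are illustrative only:

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(10))
loader = StatefulDataLoader(dataset, batch_size=2, num_workers=2)

it = iter(loader)
next(it)
next(it)                         # consume two batches
snapshot = loader.state_dict()   # captures the in-progress position

# A fresh loader (e.g., after a restart) resumes mid-epoch instead of
# replaying the batches that were already consumed.
resumed = StatefulDataLoader(dataset, batch_size=2, num_workers=2)
resumed.load_state_dict(snapshot)
remaining = [batch for batch in resumed]  # only the unseen batches
```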