GPT2-DDP

This is a JAX implementation of GPT-2 with data parallelism (similar to PyTorch's Distributed Data Parallel). See my blog post for more details on how data parallelism works.
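For intuition, here is a minimal sketch of the data-parallel pattern in JAX using `jax.pmap` and `jax.lax.pmean` (a toy linear model, not this repo's actual training step): every device holds a full replica of the parameters, computes gradients on its own shard of the batch, and the gradients are averaged with an all-reduce so each replica applies the identical update.

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model; a real GPT-2 forward pass would go here.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y, lr):
    # Each device computes gradients on its own sub-batch...
    grads = jax.grad(loss_fn)(params, x, y)
    # ...then gradients are mean-reduced across devices (all-reduce),
    # so every replica takes the same step.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

n_dev = jax.local_device_count()
# Replicate the parameters onto every device.
params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
params = jax.device_put_replicated(params, jax.local_devices())
# Shard the global batch along a leading device axis: (n_dev, per-device batch, ...).
x = jnp.ones((n_dev, 4, 8))
y = jnp.ones((n_dev, 4, 1))
lr = jnp.full((n_dev,), 1e-2)  # pmap'd arguments need the leading device axis too
params = train_step(params, x, y, lr)
```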

To run the training script:

```sh
cd GPT2-DDP/gpt2ddp/gpt2ddp
uv run scripts/train.py
```

To modify the model/training configuration, see `gpt2ddp/core/config.py`.
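As a rough illustration only (the actual fields live in `gpt2ddp/core/config.py` and may differ), a GPT-2 training config typically exposes knobs like these:

```python
# Hypothetical sketch -- field names and defaults are assumptions,
# not the repo's real config.
from dataclasses import dataclass

@dataclass
class GPT2Config:
    n_layers: int = 12         # transformer blocks
    n_heads: int = 12          # attention heads per block
    d_model: int = 768         # residual stream width
    vocab_size: int = 50257    # GPT-2 BPE vocabulary
    context_length: int = 1024 # maximum sequence length

@dataclass
class TrainConfig:
    batch_size: int = 32       # global batch, sharded across devices
    learning_rate: float = 3e-4
    num_steps: int = 16
```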

Here's a memory profile of 16 training steps: [memory-profile screenshot in the original README]
