GPT2-ZeRO1

This is a JAX implementation of GPT-2 with ZeRO-1 (see the ZeRO paper for more details). In ZeRO-1, we shard the optimizer state across devices while leaving parameters, gradients, and activations unsharded.
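
ZeRO-1 leaves the forward/backward pass untouched; only the optimizer state (e.g. Adam's first and second moments) is partitioned across devices. As a rough illustration of what that placement can look like in JAX (this is a minimal sketch assuming a 1-D device mesh named `data` and toy parameter shapes, not the code in this repo):

```python
# Minimal ZeRO-1-style placement sketch (illustrative, not this repo's code):
# parameters stay replicated on every device, while each optimizer-state
# leaf is sharded across the device mesh along its leading axis.
import jax
import jax.numpy as jnp
import optax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, axis_names=("data",))
replicated = NamedSharding(mesh, P())  # params and gradients: full copy per device

def shard_opt_leaf(leaf):
    # Shard the leading axis across devices when it divides evenly;
    # scalars (e.g. the Adam step count) stay replicated.
    if leaf.ndim > 0 and leaf.shape[0] % len(devices) == 0:
        return jax.device_put(leaf, NamedSharding(mesh, P("data")))
    return jax.device_put(leaf, replicated)

# Toy parameters; the real model's pytree is much larger.
params = jax.device_put(
    {"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}, replicated
)

optimizer = optax.adam(3e-4)
# Adam's mu/nu mirror the parameter pytree, so sharding them is where
# the ZeRO-1 memory saving comes from.
opt_state = jax.tree_util.tree_map(shard_opt_leaf, optimizer.init(params))
```

In a full ZeRO-1 step, each device updates only its shard of the optimizer state and the corresponding parameter slice, then the updated parameters are gathered back onto every device; that part is omitted from the sketch above.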

To run the training script:

    cd GPT2-DDP/gpt2ddp/gpt2ddp
    uv run scripts/train.py

To modify the model/training configuration, see gpt2ddp/core/config.py.
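
The actual fields are defined in gpt2ddp/core/config.py; purely as an illustration of the kind of knobs such a config typically exposes (these field names and defaults are assumptions, not taken from the repo):

```python
# Hypothetical example only -- see gpt2ddp/core/config.py for the real fields.
from dataclasses import dataclass

@dataclass
class GPT2Config:
    n_layer: int = 12        # transformer blocks
    n_head: int = 12         # attention heads per block
    d_model: int = 768       # hidden size
    vocab_size: int = 50257  # GPT-2 BPE vocabulary
    block_size: int = 1024   # context length

@dataclass
class TrainConfig:
    batch_size: int = 8
    learning_rate: float = 3e-4
    num_steps: int = 16
```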

Here's a memory profile of 16 training steps. Compared to my experiments with GPT2-DDP, we see a ~800 MB reduction in max memory use:

[memory profile image]
