This is a JAX implementation of GPT-2 with ZeRO-1 (see the ZeRO paper for details). ZeRO stage 1 shards the optimizer state across data-parallel workers while leaving parameters, gradients, and activations unsharded (i.e., fully replicated on every device).
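As a minimal sketch of the idea (not the repo's actual code; the mesh, axis name, and `shard_opt_state` helper are illustrative), optimizer-state sharding can be expressed in JAX with `NamedSharding`: parameters are placed fully replicated, while each optimizer-state leaf is split along its leading dimension across the data-parallel axis.

```python
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all local devices: plain data parallelism.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

replicated = NamedSharding(mesh, P())      # params and grads: a full copy per device
sharded = NamedSharding(mesh, P("data"))   # optimizer state: split along the leading axis

def shard_opt_state(opt_state):
    """Shard each optimizer-state leaf across the "data" axis when its leading
    dimension divides evenly by the device count; otherwise replicate it."""
    def place(leaf):
        if leaf.ndim > 0 and leaf.shape[0] % mesh.size == 0:
            return jax.device_put(leaf, sharded)
        return jax.device_put(leaf, replicated)
    return jax.tree_util.tree_map(place, opt_state)

# Parameters stay unsharded; only the optimizer moments (and similar state) are sharded.
params = jax.device_put({"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}, replicated)
tx = optax.adamw(3e-4)
opt_state = shard_opt_state(tx.init(params))
```

At each step, every worker updates only its shard of the optimizer state and the corresponding parameter slice, then the updated parameters are all-gathered so every device keeps a full copy.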
To run the training script:

```bash
cd GPT2-DDP/gpt2ddp/gpt2ddp
uv run scripts/train.py
```
To modify the model/training configuration, see gpt2ddp/core/config.py.
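The real fields are defined in that file; as a rough illustration only, a GPT-2-style config typically bundles model and training hyperparameters along these lines (field names and values below are hypothetical GPT-2 small defaults, not necessarily what the repo uses):

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:              # hypothetical; see gpt2ddp/core/config.py for the real fields
    vocab_size: int = 50257   # GPT-2 BPE vocabulary
    block_size: int = 1024    # context length
    n_layer: int = 12         # transformer blocks (GPT-2 small)
    n_head: int = 12
    n_embd: int = 768

@dataclass
class TrainConfig:            # hypothetical
    batch_size: int = 8
    learning_rate: float = 3e-4
    num_steps: int = 16       # matches the 16-step profile below
```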
Here's a memory profile over 16 training steps. Compared with my GPT2-DDP experiments, peak memory usage drops by roughly 800 MB:
