This is a JAX implementation of GPT-2 with data parallelism (similar to PyTorch's Distributed Data Parallel). See my blog post for more details on how data parallelism works.
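Below is a minimal sketch (not this repo's actual code) of how DDP-style data parallelism is commonly expressed in JAX: parameters are replicated across devices, each device computes gradients on its own shard of the batch, and `jax.lax.pmean` averages the gradients, playing the role of DDP's all-reduce. All names and hyperparameters here are illustrative assumptions.

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model standing in for the GPT-2 forward pass.
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across devices -- the analogue of DDP's all-reduce.
    grads = jax.lax.pmean(grads, axis_name="batch")
    # Plain SGD update for illustration; a real loop would use an optimizer.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    loss = jax.lax.pmean(loss, axis_name="batch")
    return params, loss

if __name__ == "__main__":
    n_dev = jax.local_device_count()
    params = {"w": jnp.ones((8, 1)), "b": jnp.zeros((1,))}
    # Replicate parameters onto every device; shard the batch along axis 0.
    params = jax.device_put_replicated(params, jax.local_devices())
    x = jnp.ones((n_dev, 4, 8))   # (devices, per-device batch, features)
    y = jnp.ones((n_dev, 4, 1))
    params, loss = train_step(params, x, y)
```

Because every device applies the same averaged gradient, the replicated parameter copies stay in sync step after step, which is exactly the invariant DDP maintains.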
To run the training script:

```bash
cd GPT2-DDP/gpt2ddp/gpt2ddp
uv run scripts/train.py
```
To modify the model/training configuration, see `gpt2ddp/core/config.py`.
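As a rough, hypothetical illustration of what such a configuration module might contain (the actual fields and defaults live in `gpt2ddp/core/config.py` and may differ), a dataclass-based layout often looks like this; the model values shown are the standard GPT-2 small hyperparameters:

```python
from dataclasses import dataclass

@dataclass
class GPT2Config:
    # GPT-2 small defaults, shown for illustration only.
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    block_size: int = 1024

@dataclass
class TrainConfig:
    # Hypothetical training knobs; adjust to match the real config file.
    per_device_batch_size: int = 8
    learning_rate: float = 3e-4
    num_steps: int = 10_000
```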
