A minimal JAX-based autograd engine implementation, inspired by Andrej Karpathy's micrograd. This is a learning project to understand the internals of automatic differentiation and neural network frameworks.
- JAX-based tensor operations (first sketch below)
- Automatic differentiation (first sketch below)
- Basic neural network operations: relu, sigmoid, and tanh (second sketch below)
- Memory-efficient operations with chunking and gradient checkpointing (third sketch below)
- Multi-device support (fourth sketch below)
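The four sketches below illustrate the underlying JAX primitives rather than this project's own API; treat them as minimal examples, not definitive implementations. First, tensor operations and automatic differentiation: JAX traces ordinary NumPy-style array code and derives the backward pass with `jax.grad`.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    # A tiny computation graph: matrix product, square, reduce to a scalar.
    return jnp.sum(jnp.dot(x, w) ** 2)

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

# jax.grad builds the gradient function automatically; argnums=0 picks w.
df_dw = jax.grad(f, argnums=0)
print(df_dw(w, x).shape)  # (3, 2) -- same shape as w
```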
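Second, the basic neural network operations. `jax.nn` ships `relu` and `sigmoid`, `jnp.tanh` covers the rest, and gradients flow through all of them like any other op:

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

print(jax.nn.relu(x))     # elementwise max(x, 0)
print(jax.nn.sigmoid(x))  # 1 / (1 + exp(-x))
print(jnp.tanh(x))

# Activations differentiate like everything else, e.g. sigmoid'(0) = 0.25.
print(jax.grad(jax.nn.sigmoid)(0.0))  # ~0.25
```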
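Third, memory-efficient execution. This sketch uses the two standard JAX tools for the job, under the assumption that "chunking and checkpointing" refers to them: `jax.checkpoint` (a.k.a. remat) recomputes intermediates during the backward pass instead of storing them, and `lax.map` over chunks keeps only one chunk's activations live at a time. The `layer` function here is a hypothetical stand-in for any memory-hungry computation.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for a memory-hungry layer.
@jax.checkpoint  # drop intermediates on the forward pass, recompute on backward
def layer(x):
    return jnp.tanh(x * 2.0)

def loss(x):
    # Chunking: split the batch and process chunks sequentially with lax.map,
    # so only one chunk's activations are resident at a time.
    chunks = x.reshape(8, -1)  # 8 chunks of 128 elements each
    return jnp.sum(lax.map(layer, chunks))

x = jnp.ones(1024)
print(jax.grad(loss)(x).shape)  # (1024,)
```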
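Finally, multi-device support. A sketch using `jax.pmap`, which replicates a function across the local devices and maps it over a leading axis; on a single-device machine it degenerates to one replica. Newer JAX versions also offer sharding-based APIs, but `pmap` keeps the idea compact:

```python
import jax
import jax.numpy as jnp

print(jax.devices())  # whatever accelerators (or just CPU) the host exposes

# pmap runs one replica per device; the leading axis must match that count.
n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)

partial_sums = jax.pmap(lambda v: jnp.sum(v * v))(x)
print(partial_sums)  # one result per device
```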
I built this to learn about:
- How autograd engines work
- JAX's approach to automatic differentiation
- Efficient tensor operations
- GPU acceleration and parallel processing
This is a learning project, so feedback is very welcome. Feel free to suggest improvements or share your thoughts:
- Twitter/X: @OccupyingM
- Inspired by [micrograd](https://github.com/karpathy/micrograd) by Andrej Karpathy