"From scratch" implementations of image processing algorithms using JAX
- Learn about the algorithms (my background is in NLP, so this is a new area for me)
- Refresh my knowledge of the latest changes in JAX
- Have fun
- These are not production-quality implementations; the focus is on simplicity and on understanding the algorithms
Everything was run on Ubuntu 24.04 with CUDA 12.7
Setting up and running
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
jupyter lab

Or, using Docker:
docker run --runtime=nvidia -p 8888:8888 -it jax-cuda
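Before opening the notebooks, a quick sanity check can confirm that JAX actually picked up the GPU. This is a minimal sketch that assumes requirements.txt installs a CUDA-enabled JAX build; the helper names describe_backend and smoke_test are hypothetical and not part of this repo. It only uses jax.default_backend, jax.devices and jax.numpy.

```python
# Hypothetical setup check (not part of the repo): is JAX running on the GPU?
import jax
import jax.numpy as jnp


def describe_backend() -> str:
    """Return the default JAX backend name, e.g. 'gpu' when CUDA is used, 'cpu' otherwise."""
    return jax.default_backend()


def smoke_test() -> float:
    """Run a tiny computation on the default device: 0^2 + 1^2 + 2^2 + 3^2."""
    x = jnp.arange(4.0)
    return float(jnp.sum(x * x))


if __name__ == "__main__":
    print("JAX version:", jax.__version__)
    print("Default backend:", describe_backend())
    print("Devices:", jax.devices())
    print("Smoke test:", smoke_test())
```

If the backend prints cpu on a CUDA machine, the installed jax wheel is probably CPU-only; the official install instructions use pip install -U "jax[cuda12]" for CUDA 12 systems.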