diff --git a/README.md b/README.md index 2868eac..68d92bb 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ JMP is written in pure Python, but depends on C++ code via JAX and NumPy. Because JAX installation is different depending on your CUDA version, JMP does not list JAX as a dependency in `requirements.txt`. -First, follow [these instructions](https://github.com/google/jax#installation) +First, follow [these instructions](https://github.com/jax-ml/jax#installation) to install JAX with the relevant accelerator support. Then, install JMP using pip: @@ -212,4 +212,4 @@ https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/. [0]: https://arxiv.org/abs/1710.03740 [1]: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html [Haiku]: https://github.com/deepmind/dm-haiku -[JAX]: https://github.com/google/jax +[JAX]: https://github.com/jax-ml/jax