You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).
7
+
8
+
## Frameworks and Supported Models
9
+
We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs.
@@ -267,26 +285,9 @@ In all of the above cases, `ghcr.io/nvidia/jax:XXX` points to the most recent
267
285
nightly build of the container for `XXX`. These containers are also tagged as
268
286
`ghcr.io/nvidia/jax:XXX-YYYY-MM-DD`, if a stable reference is required.
269
287
270
-
## Note
271
-
This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries like: [T5x](https://github.com/google-research/t5x), [PAXML](https://github.com/google/paxml), [Transformer Engine](https://github.com/NVIDIA/TransformerEngine), [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html) and others to come soon.
272
-
273
-
## Frameworks and Supported Models
274
-
We currently support the following frameworks and models. More details about each model and the available containers can be found in their respective READMEs.
We will update this table as new models become available, so stay tuned.
286
-
287
288
## Environment Variables
288
289
289
-
The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning:
290
+
The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning of XLA and NCCL:
290
291
291
292
| XLA Flags | Value | Explanation |
292
293
| --------- | ----- | ----------- |
@@ -302,10 +303,10 @@ There are various other XLA flags users can set to improve performance. For a de
302
303
303
304
For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
304
305
305
-
## Profiling JAX programs on GPU
306
+
## Profiling
306
307
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
307
308
308
-
## FAQ (Frequently Asked Questions)
309
+
## Frequently asked questions (FAQ)
309
310
310
311
<details>
311
312
<summary>`bus error` when running JAX in a docker container</summary>
@@ -340,7 +341,6 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
0 commit comments