Commit 37d65cf

Update README.md (#1069)
Reformatted README and updated models for maxtext
1 parent ef3fd66 commit 37d65cf

File tree

1 file changed: +22 -22 lines changed


README.md

Lines changed: 22 additions & 22 deletions
@@ -1,5 +1,23 @@
-# JAX Toolbox
+# **JAX Toolbox**
 
+[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md)
+[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status)
+
+JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).
+
+## Frameworks and Supported Models
+We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs.
+
+| Framework | Models | Use cases | Container |
+| :--- | :---: | :---: | :---: |
+| [maxtext](./rosetta/rosetta/projects/maxtext)| GPT, LLaMA, Gemma, Mistral, Mixtral | pretraining | `ghcr.io/nvidia/jax:maxtext` |
+| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
+| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
+| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
+| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
+| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
+
+# Build Pipeline Status
 <table>
 <thead>
 <tr>
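
As an illustration of what the containers in the table above provide (this sketch is not part of the diff itself), a minimal check that JAX sees the GPUs when run inside one of the listed images, e.g. `ghcr.io/nvidia/jax:maxtext`; the matrix sizes are arbitrary:

```python
# Minimal sanity check, assumed to be run inside one of the containers listed
# above (e.g. ghcr.io/nvidia/jax:maxtext); not taken from this commit.
import jax
import jax.numpy as jnp

# List the accelerators JAX can see; on a GPU image this should report CUDA devices.
print(jax.devices())

# Run a small jitted matmul to confirm the GPU backend actually executes work.
@jax.jit
def matmul(a, b):
    return a @ b

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (1024, 1024))
b = jax.random.normal(key, (1024, 1024))
print(matmul(a, b).block_until_ready().shape)
```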
@@ -267,26 +285,9 @@ In all of the above cases, `ghcr.io/nvidia/jax:XXX` points to the most recent
 nightly build of the container for `XXX`. These containers are also tagged as
 `ghcr.io/nvidia/jax:XXX-YYYY-MM-DD`, if a stable reference is required.
 
-## Note
-This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries like: [T5x](https://github.com/google-research/t5x), [PAXML](https://github.com/google/paxml), [Transformer Engine](https://github.com/NVIDIA/TransformerEngine), [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html) and others to come soon.
-
-## Frameworks and Supported Models
-We currently support the following frameworks and models. More details about each model and the available containers can be found in their respective READMEs.
-
-| Framework | Supported Models | Use-cases | Container |
-| :--- | :---: | :---: | :---: |
-| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
-| [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
-| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
-| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
-| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
-| maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` |
-
-We will update this table as new models become available, so stay tuned.
-
 ## Environment Variables
 
-The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning:
+The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning of XLA and NCCL:
 
 | XLA Flags | Value | Explanation |
 | --------- | ----- | ----------- |
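
The hunk above clarifies that the image bakes in XLA and NCCL tuning settings. As a hedged sketch of how such flags are typically consumed (the specific flag shown is only an example and is not prescribed by this commit), XLA flags can also be set or extended through the `XLA_FLAGS` environment variable before JAX is imported:

```python
# Sketch only: extending XLA flags via the XLA_FLAGS environment variable.
# The flag below is illustrative, not a value taken from this commit.
import os

# XLA reads this variable when the backend initializes, so set it before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax
print(jax.devices())
```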
@@ -302,10 +303,10 @@ There are various other XLA flags users can set to improve performance. For a de
 
 For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
 
-## Profiling JAX programs on GPU
+## Profiling
 See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
 
-## FAQ (Frequently Asked Questions)
+## Frequently asked questions (FAQ)
 
 <details>
 <summary>`bus error` when running JAX in a docker container</summary>
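
On the renamed `## Profiling` section in the hunk above: the linked `./docs/profiling.md` page is the authoritative reference, but a minimal sketch of capturing a trace with JAX's built-in profiler (viewable in TensorBoard or Perfetto) could look like the following; the output directory is arbitrary and this is not the workflow mandated by that page:

```python
# Minimal profiling sketch using JAX's built-in profiler; the log directory is
# arbitrary and this is not the procedure from ./docs/profiling.md.
import jax
import jax.numpy as jnp

x = jnp.ones((2048, 2048))

# Everything executed inside the context manager is recorded into the trace.
with jax.profiler.trace("/tmp/jax-trace"):
    y = (x @ x).block_until_ready()

print(y.shape)
```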
@@ -340,7 +341,6 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
 </details>
 
 ## JAX on Public Clouds
-
 * AWS
   * [Add EFA integration](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-efa.html)
   * [SageMaker code sample](https://github.com/aws-samples/aws-samples-for-ray/tree/main/sagemaker/jax_alpa_language_model)
