Commit 83f52ce

Reorganize pre-training doc

* Moves the pre-training documentation to be a sub-section of the Run MaxText section;
* Adds more information beyond just dataset configuration to the pre-training guide;
* Adds some extra content to the individual data pipeline guides.

1 parent a4bd69d, commit 83f52ce

13 files changed: 492 additions & 232 deletions

docs/development.md

Lines changed: 1 addition & 0 deletions
@@ -7,4 +7,5 @@ hidden:
 ---
 development/update_dependencies.md
 development/contribute_docs.md
+development/hlo_diff_testing.md
 ```

docs/development/hlo_diff_testing.md

Lines changed: 14 additions & 10 deletions
@@ -44,9 +44,12 @@ ______________________________________________________________________
 
 When intended architecture transformations alter graph lowering, the reference file baselines require updates.
 
-> [!IMPORTANT]\
-> While running the update script locally is not the end of the world, **relying on local execution can cause remote CI tests to fail.**
-> The PR verification pipelines run the tests in a strictly locked GitHub Actions environment. The smallest discrepancies in local library installations will introduce slight backend lowering graph deviations. If your local execution leads to a remote CI check failure, rely on the GitHub Action trigger described below to generate environment-matching baselines.
+```{important}
+
+While running the update script locally is not the end of the world, **relying on local execution can cause remote CI tests to fail.**
+
+The PR verification pipelines run the tests in a strictly locked GitHub Actions environment. The smallest discrepancies in local library installations will introduce slight backend lowering graph deviations. If your local execution leads to a remote CI check failure, rely on the GitHub Action trigger described below to generate environment-matching baselines.
+```
 
 ### Method 1: Run the manual GitHub Action Workflow (Highly Recommended)
 
@@ -66,13 +69,14 @@ Alternatively, you can trigger the remote workflow via terminal CLI execution:
 gh workflow run update_reference_hlo.yml --ref <branch>
 ```
 
-> [!NOTE]
-> A successful run of the manual update workflow will add a new commit to your Pull Request branch. Once complete, you must:
->
-> 1. Pull the new commit from remote.
-> 2. Squash the commits in your branch once again to keep your PR history clean.
-> 3. Push the squashed commit to remote.
-> 4. Retry the `tpu-integration` workflow to verify tests pass on your PR.
+```{note}
+A successful run of the manual update workflow will add a new commit to your Pull Request branch. Once complete, you must:
+
+1. Pull the new commit from remote.
+2. Squash the commits in your branch once again to keep your PR history clean.
+3. Push the squashed commit to remote.
+4. Retry the `tpu-integration` workflow to verify tests pass on your PR.
+```
 
 ### Method 2: Local Execution
 

docs/guides.md

Lines changed: 19 additions & 18 deletions
@@ -18,58 +18,59 @@
 
 Explore our how-to guides for optimizing, debugging, and managing your MaxText workloads.
 
-::::{grid} 1 2 2 2
-:gutter: 2
-
-:::{grid-item-card} ⚡ Optimization
+````{grid} 1 2 2 2
+---
+gutter: 2
+---
+```{grid-item-card} ⚡ Optimization
 :link: guides/optimization
 :link-type: doc
 
 Techniques for maximizing performance, including sharding strategies, Pallas kernels, and benchmarking.
-:::
+```
 
-:::{grid-item-card} 💾 Data Pipelines
+```{grid-item-card} 💾 Data Pipelines
 :link: guides/data_input_pipeline
 :link-type: doc
 
 Configure input pipelines using **Grain** (recommended for determinism), **HuggingFace**, or **TFDS**.
-:::
+```
 
-:::{grid-item-card} 🔄 Checkpointing
+```{grid-item-card} 🔄 Checkpointing
 :link: guides/checkpointing_solutions
 :link-type: doc
 
 Manage GCS checkpoints, handle preemption with emergency checkpointing, and configure multi-tier storage.
-:::
+```
 
-:::{grid-item-card} 🔍 Monitoring & Debugging
+```{grid-item-card} 🔍 Monitoring & Debugging
 :link: guides/monitoring_and_debugging
 :link-type: doc
 
 Tools for observability: goodput monitoring, hung job debugging, and Vertex AI TensorBoard integration.
-:::
+```
 
-:::{grid-item-card} 🐍 Python Notebooks
+```{grid-item-card} 🐍 Python Notebooks
 :link: guides/run_python_notebook
 :link-type: doc
 
 Interactive development guides for running MaxText on Google Colab or local JupyterLab environments.
-:::
+```
 
-:::{grid-item-card} 🌱 Model Bringup
+```{grid-item-card} 🌱 Model Bringup
 :link: guides/model_bringup
 :link-type: doc
 
 A step-by-step guide for the community to help expand MaxText's model library.
-:::
+```
 
-:::{grid-item-card} 🎓 Distillation
+```{grid-item-card} 🎓 Distillation
 :link: guides/distillation
 :link-type: doc
 
 How online distillation works in MaxText: loss anatomy, α / β / temperature schedule tuning, layer indices, monitoring metrics, and troubleshooting.
-:::
-::::
+```
+````
 
 ```{toctree}
 ---

docs/guides/data_input_pipeline.md

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@ Training in a multi-host environment presents unique challenges for data input p
 
 ### Random access dataset (Recommended)
 
-Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.<br>
+Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.
+
 In MaxText, this is best supported by the ArrayRecord format using the Grain input pipeline. This approach gracefully handles the key challenges:
 
 - **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.

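In practice this means selecting the Grain pipeline with ArrayRecord files; a minimal flag combination, with a placeholder file pattern, looks like the following (full details in the Grain pipeline guide):

```bash
dataset_type=grain \
grain_file_type=arrayrecord \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*
```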
docs/guides/data_input_pipeline/data_input_grain.md

Lines changed: 29 additions & 7 deletions
@@ -32,9 +32,14 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state
 
 ## Using Grain
 
-1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord)(sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources/protocol.html) class.
-   - **Community Resource**: The MaxText community has created a [ArrayRecord Documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers yet.
-2. If the dataset is hosted on a Cloud Storage bucket, the path `gs://` can be provided directly. However, for the best performance, it's recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This will significantly improve the perf for the ArrayRecord format as it allows meta data caching to speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/scripts/setup_gcsfuse.sh). The script configures some parameters for the mount.
+Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random-access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources/protocol.html) class.
+
+```{admonition} Community Resource
+
+The MaxText community has created an [ArrayRecord Documentation](https://array-record.readthedocs.io/) site. We appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers.
+```
+
+If the dataset is hosted on a Cloud Storage bucket, the `gs://` path can be provided directly. However, for the best performance, it's recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, because metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path on each worker, using the script [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/scripts/setup_gcsfuse.sh). The script configures some parameters for the mount.
 
 ```sh
 bash src/dependencies/scripts/setup_gcsfuse.sh \
@@ -45,11 +50,13 @@ MOUNT_PATH=${MOUNT_PATH?} \
 
 Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://docs.cloud.google.com/storage/docs/cloud-storage-fuse/performance)).
 
+### Configuration
+
 1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.
 
 2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/scripts/setup_gcsfuse.sh) to avoid gcsfuse throttling.
 
-3. ArrayRecord Only: For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0. For example:
+3. *ArrayRecord Only*: For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0. For example:
 
 ```
 # Blend two data sources with 30% from first source and 70% from second source
@@ -120,17 +127,32 @@ grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
 grain_worker_count=2
 ```
 
-1. Using validation set for evaluation
+### Using validation set for evaluation
 
-When setting eval_interval > 0, evaluation will be run with a specified eval dataset. Example config (set in [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml) or through command line):
+When setting `eval_interval > 0`, evaluation will be run with a specified eval dataset. Example config (set in [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml) or through command line):
 
 ```yaml
 eval_interval: 10000
 eval_steps: 50
 grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*'
 ```
 
-1. Experimental: resuming training with a different chip count
+### Tokenizer support
+
+The Grain pipeline supports three tokenizer types:
+
+- `sentencepiece`: for SentencePiece tokenizers;
+- `huggingface`: for HuggingFace tokenizers (requires `hf_access_token` for gated models);
+- `tiktoken`: for OpenAI's tiktoken tokenizers.
+
+Example with SentencePiece:
+
+```bash
+tokenizer_type=sentencepiece \
+tokenizer_path=gs://<your-bucket>/tokenizers/c4_en_301_5Mexp2_spm.model
+```
+
+### Experimental: resuming training with a different chip count
 
 In Grain checkpoints, each data-loading host has a corresponding JSON file. For cases where a user wants to resume training with a different number of data-loading hosts, MaxText provides an experimental feature:

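Putting the flags from this guide together, a full pre-training invocation with the Grain pipeline might look roughly like this; the bucket, run name, and step count are placeholders, and the module path follows the Hugging Face example in data_input_hf.md:

```bash
python3 -m maxtext.trainers.pre_train.train \
base_output_directory=gs://<your-bucket> \
run_name=grain_demo \
dataset_type=grain \
grain_file_type=arrayrecord \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
grain_worker_count=2 \
tokenizer_type=sentencepiece \
tokenizer_path=gs://<your-bucket>/tokenizers/c4_en_301_5Mexp2_spm.model \
steps=1000 \
per_device_batch_size=8
```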
docs/guides/data_input_pipeline/data_input_hf.md

Lines changed: 34 additions & 0 deletions
@@ -39,6 +39,40 @@ hf_eval_files: 'gs://<bucket>/<folder>/*-validation-*.parquet' # match the val
 tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-t5/t5-large
 ```
 
+## Tokenizer configuration
+
+The Hugging Face pipeline only supports Hugging Face tokenizers and will ignore the `tokenizer_type` flag.
+
+## Using gated datasets
+
+For [gated datasets](https://huggingface.co/docs/hub/en/datasets-gated) or tokenizers from [gated models](https://huggingface.co/docs/hub/en/models-gated), you need to:
+
+1. Request access on HuggingFace.
+2. Generate an access token from your [HuggingFace settings](https://huggingface.co/settings/tokens).
+3. Provide the token in your command:
+
+```bash
+hf_access_token=<YOUR_TOKEN>
+```
+
+Example with a gated model:
+
+```bash
+python3 -m maxtext.trainers.pre_train.train \
+base_output_directory=gs://<your-bucket> \
+run_name=llama2_demo \
+model_name=llama2-7b \
+dataset_type=hf \
+hf_path=allenai/c4 \
+hf_data_dir=en \
+train_split=train \
+tokenizer_type=huggingface \
+tokenizer_path=meta-llama/Llama-2-7b \
+hf_access_token=hf_xxxxxxxxxxxxx \
+steps=1000 \
+per_device_batch_size=8
+```
+
 ## Limitations and Recommendations
 
 1. Streaming data directly from Hugging Face Hub may be impacted by the traffic of the server. During peak hours you may encounter "504 Server Error: Gateway Time-out". It's recommended to download the Hugging Face dataset to a Cloud Storage bucket or disk for the most stable experience.

docs/guides/data_input_pipeline/data_input_tfds.md

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,9 @@
 # TFDS pipeline
 
+The TensorFlow Datasets (TFDS) pipeline uses datasets in TFRecord format, which is performant and widely supported in the TensorFlow ecosystem.
+
+## Example config for streaming from a TFDS dataset in a Cloud Storage bucket
+
 1. Download the Allenai C4 dataset in TFRecord format to a Cloud Storage bucket. For information about cost, see [this discussion](https://github.com/allenai/allennlp/discussions/5056)
 
 ```shell
@@ -18,3 +22,11 @@ eval_split: 'validation'
 # TFDS input pipeline only supports tokenizer in spm format
 tokenizer_path: 'src/maxtext/assets/tokenizers/tokenizer.llama2'
 ```
+
+### Tokenizer support
+
+The TFDS pipeline supports three tokenizer types:
+
+- `sentencepiece`: for SentencePiece tokenizers
+- `huggingface`: for HuggingFace tokenizers (requires `hf_access_token` for gated models)
+- `tiktoken`: for OpenAI's tiktoken tokenizers

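Mirroring the Grain example, a SentencePiece tokenizer for the TFDS pipeline can be selected with the same pair of flags, here using the tokenizer path already shown in the config above:

```bash
tokenizer_type=sentencepiece \
tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2
```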
docs/run_maxtext.md

Lines changed: 22 additions & 13 deletions
@@ -2,50 +2,59 @@
 
 Choose your environment and orchestration method to run MaxText.
 
-::::{grid} 1 2 2 2
-:gutter: 2
+````{grid} 1 2 2 2
+---
+gutter: 2
+---
+```{grid-item-card} 🚀 Pre-training
+:link: run_maxtext/run_maxtext_pretraining
+:link-type: doc
+
+Complete guide to pre-training language models from scratch. Covers model selection, hyperparameters, dataset configuration, deployment options, and monitoring.
+```
 
-:::{grid-item-card} 💻 Localhost / Single VM
+```{grid-item-card} 💻 Localhost / Single VM
 :link: run_maxtext/run_maxtext_localhost
 :link-type: doc
 
 Get started quickly on a single machine. Clone the repo, install dependencies, and run your first training job on a single TPU or GPU VM.
-:::
+```
 
-:::{grid-item-card} 🎮 Single-host GPU
+```{grid-item-card} 🎮 Single-host GPU
 :link: run_maxtext/run_maxtext_single_host_gpu
 :link-type: doc
 
 Run MaxText on single-host NVIDIA GPUs (e.g., A3 High/Mega). Includes Docker setup, NVIDIA Container Toolkit installation, and 1B/7B model training examples.
-:::
+```
 
-:::{grid-item-card} 🏗️ At scale with XPK (GKE)
+```{grid-item-card} 🏗️ At scale with XPK (GKE)
 :link: run_maxtext/run_maxtext_via_xpk
 :link-type: doc
 
 Deploy to Google Kubernetes Engine (GKE) using XPK. Orchestrate large-scale training jobs on TPU or GPU clusters with simple CLI commands.
-:::
+```
 
-:::{grid-item-card} 🌐 Multi-host via Pathways
+```{grid-item-card} 🌐 Multi-host via Pathways
 :link: run_maxtext/run_maxtext_via_pathways
 :link-type: doc
 
 Run large-scale JAX jobs on TPUs using Pathways. Supports batch and headless (interactive) workloads on GKE.
-:::
+```
 
-:::{grid-item-card} 🔌 Decoupled Mode
+```{grid-item-card} 🔌 Decoupled Mode
 :link: run_maxtext/decoupled_mode
 :link-type: doc
 
 Run tests and local development without Google Cloud dependencies (no `gcloud`, GCS, or Vertex AI required).
-:::
-::::
+```
+````
 
 ```{toctree}
 ---
 hidden:
 maxdepth: 1
 ---
+run_maxtext/run_maxtext_pretraining.md
 run_maxtext/run_maxtext_localhost.md
 run_maxtext/run_maxtext_single_host_gpu.md
 run_maxtext/run_maxtext_via_xpk.md
