
Commit 58a975d

Move src/MaxText/assets to src/maxtext/assets/tokenizers
1 parent 941d46a commit 58a975d

82 files changed · +198 −195 lines changed

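Every hunk below applies the same mechanical rewrite: tokenizer assets move from `src/MaxText/assets/` to `src/maxtext/assets/tokenizers/`, and each reference is updated to match. A minimal sketch of the mapping, assuming plain string prefixes (the helper name is hypothetical, not part of the commit):

```python
# Hypothetical helper illustrating the path rewrite this commit applies tree-wide.
OLD_PREFIX = "src/MaxText/assets/"
NEW_PREFIX = "src/maxtext/assets/tokenizers/"

def rewrite_tokenizer_path(path: str) -> str:
    """Map an old-style tokenizer path to its new location (illustrative only)."""
    if path.startswith(OLD_PREFIX):
        return NEW_PREFIX + path[len(OLD_PREFIX):]
    return path

assert rewrite_tokenizer_path("src/MaxText/assets/tokenizer.llama2") == (
    "src/maxtext/assets/tokenizers/tokenizer.llama2"
)
```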

.vscode/launch.json

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 "dataset_path=gs://test-maxtext-dataset",
 "model_name=llama2-7b",
 "load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_",
-"tokenizer_path=src/MaxText/assets/tokenizer.llama2",
+"tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2",
 "per_device_batch_size=8",
 "max_prefill_predict_length=8",
 "max_target_length=20",
@@ -70,7 +70,7 @@
 "args": [
 "src/MaxText/configs/base.yml",
 "model_name=llama2-7b",
-"tokenizer_path=src/MaxText/assets/tokenizer.llama2",
+"tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2",
 "weight_dtype=bfloat16",
 "scan_layers=false",
 "attention=dot_product",

benchmarks/globals.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 r if os.path.isdir(os.path.join(r := os.path.dirname(os.path.dirname(__file__)), ".git")) else MAXTEXT_PKG_DIR,
 )
 
-# This is the assets root: with "tokenizer.gemma3"; &etc.
-MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_PKG_DIR, "assets"))
+# This is the assets root: with "tokenizers/"; &etc.
+MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"))
 
 __all__ = ["MAXTEXT_ASSETS_ROOT", "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT"]
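With this change `MAXTEXT_ASSETS_ROOT` defaults to the repo-level `src/maxtext/assets` directory (an explicit env var still wins), and callers join the new `tokenizers/` subdirectory themselves. A minimal sketch of the resolution order, with a hypothetical checkout path standing in for the real `MAXTEXT_REPO_ROOT`:

```python
import os

MAXTEXT_REPO_ROOT = "/path/to/maxtext"  # hypothetical stand-in for the real constant

# Explicit MAXTEXT_ASSETS_ROOT env var wins; otherwise use the new in-repo default.
MAXTEXT_ASSETS_ROOT = os.environ.get(
    "MAXTEXT_ASSETS_ROOT",
    os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"),
)

# Call sites now add the "tokenizers" subdirectory explicitly:
tokenizer_path = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2")
```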

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 14 additions & 14 deletions
@@ -544,7 +544,7 @@
 "profiler": "xplane",
 "dataset_path": "gs://max-datasets-rogue",
 "dataset_type": "tfds",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
 "sa_block_q": 1024,
 "sa_block_q_dkv": 2048,
 "sa_block_q_dq": 2048,
@@ -1280,7 +1280,7 @@
 "skip_first_n_steps_for_profiler": 10,
 "profiler_steps": 5,
 "tokenizer_type": "tiktoken",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
 },
 xla_flags=(
 xla_flags_library.DENSE_VMEM_LIMIT_FLAG
@@ -1336,7 +1336,7 @@
 "skip_first_n_steps_for_profiler": 10,
 "profiler_steps": 5,
 "tokenizer_type": "tiktoken",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
 },
 xla_flags=(
 xla_flags_library.DENSE_VMEM_LIMIT_FLAG
@@ -1517,7 +1517,7 @@
 "megablox": False,
 "sparse_matmul": False,
 "capacity_factor": 1.25,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
 },
 xla_flags=(
 xla_flags_library.MOE_VMEM_LIMIT_FLAG
@@ -1552,7 +1552,7 @@
 "sparse_matmul": False,
 "capacity_factor": 1.25,
 "quantization": "int8",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
 },
 xla_flags=(
 xla_flags_library.MOE_VMEM_LIMIT_FLAG
@@ -1593,7 +1593,7 @@
 "megablox": False,
 "sparse_matmul": False,
 "capacity_factor": 1.25,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v3"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v3"),
 "dtype": "bfloat16",
 "weight_dtype": "bfloat16",
 "allow_split_physical_axes": True,
@@ -1634,7 +1634,7 @@
 "megablox": False,
 "sparse_matmul": False,
 "capacity_factor": 1.0,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v3"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v3"),
 "dtype": "bfloat16",
 "opt_type": "sgd",
 "weight_dtype": "bfloat16",
@@ -1667,7 +1667,7 @@
 "reuse_example_batch": 1,
 "enable_checkpointing": False,
 "profiler": "xplane",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
 "sa_block_q": 2048,
 "sa_block_q_dkv": 2048,
 "sa_block_q_dq": 2048,
@@ -1700,7 +1700,7 @@
 "reuse_example_batch": 1,
 "enable_checkpointing": False,
 "profiler": "xplane",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
 "sa_block_q": 2048,
 "sa_block_q_dkv": 2048,
 "sa_block_q_dq": 2048,
@@ -1739,7 +1739,7 @@
 "profiler": "xplane",
 "skip_first_n_steps_for_profiler": 10,
 "profiler_steps": 2,
-"tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
 "sa_block_q": 1024,
 "sa_block_kv": 1024,
 "sa_block_kv_compute": 1024,
@@ -1779,7 +1779,7 @@
 "profiler": "xplane",
 "skip_first_n_steps_for_profiler": 10,
 "profiler_steps": 2,
-"tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
 "sa_block_q": 1024,
 "sa_block_kv": 1024,
 "sa_block_kv_compute": 1024,
@@ -1819,7 +1819,7 @@
 "profiler": "xplane",
 "skip_first_n_steps_for_profiler": 10,
 "profiler_steps": 2,
-"tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
 "sa_block_q": 1024,
 "sa_block_kv": 1024,
 "sa_block_kv_compute": 1024,
@@ -1868,7 +1868,7 @@
 "skip_first_n_steps_for_profiler": 10,
 "profiler_steps": 5,
 "tokenizer_type": "tiktoken",
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
 "packing": False,
 },
 xla_flags=(
@@ -1933,7 +1933,7 @@
 "sa_use_fused_bwd_kernel": True,
 "sparse_matmul": False,
 "capacity_factor": 1.5,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
 "dtype": "bfloat16",
 "weight_dtype": "bfloat16",
 "opt_type": "sgd",

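Two path conventions appear in the hunks above: most configs anchor on `MAXTEXT_ASSETS_ROOT`, while the Gemma3 configs build a relative `assets/...` path resolved against the working directory; both now insert the `tokenizers` component. A small contrast sketch (the root value is illustrative):

```python
import os

MAXTEXT_ASSETS_ROOT = "/path/to/maxtext/src/maxtext/assets"  # illustrative value

# Absolute form used by most benchmark configs:
llama2_tokenizer = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2")

# Relative form used by the Gemma3 configs, resolved against the current working directory:
gemma3_tokenizer = os.path.join("assets", "tokenizers", "tokenizer.gemma3")
```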
benchmarks/maxtext_v5e_model_configs.py

Lines changed: 2 additions & 2 deletions
@@ -149,7 +149,7 @@
 "remat_policy": "save_qkv_proj",
 "max_target_length": 2048,
 "use_iota_embed": True,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
 "dataset_path": "gs://max-datasets-rogue",
 "dataset_type": "synthetic",
 "reuse_example_batch": 1,
@@ -171,7 +171,7 @@
 "remat_policy": "qkv_proj_offloaded",
 "max_target_length": 2048,
 "use_iota_embed": True,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
 "dataset_path": "gs://max-datasets-rogue",
 "dataset_type": "synthetic",
 "reuse_example_batch": 1,

benchmarks/maxtext_v5p_model_configs.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@
 "remat_policy": "minimal",
 "max_target_length": 4096,
 "use_iota_embed": True,
-"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
+"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
 "dataset_path": "gs://max-datasets-rogue",
 "dataset_type": "synthetic",
 "reuse_example_batch": 1,

docs/guides/data_input_pipeline/data_input_tfds.md

Lines changed: 1 addition & 1 deletion
@@ -16,5 +16,5 @@ eval_interval: 10000
 eval_dataset_name: 'c4/en:3.0.1'
 eval_split: 'validation'
 # TFDS input pipeline only supports tokenizer in spm format
-tokenizer_path: 'src/MaxText/assets/tokenizer.llama2'
+tokenizer_path: 'src/maxtext/assets/tokenizers/tokenizer.llama2'
 ```

docs/tutorials/posttraining/multimodal.md

Lines changed: 23 additions & 18 deletions
@@ -1,33 +1,34 @@
-
-
 # Multimodal support
 
 This document provides a guide to use the multimodal functionalities in MaxText including:
+
 - **Checkpoint Conversion**: Convert a MaxText-compatible orbax checkpoint from HuggingFace.
 - **Multimodal Decode**: Inference with text+images as input.
 - **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset.
 
 We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support:
-| Models | Input Modalities | Output Modalities |
-| :---- | :---- | :---- |
-| - Gemma3-4B/12B/27B<br>- Llama4-Scout/Maverick | Text, images | Text |
+
+| Models                                         | Input Modalities | Output Modalities |
+| :--------------------------------------------- | :--------------- | :---------------- |
+| - Gemma3-4B/12B/27B<br>- Llama4-Scout/Maverick | Text, images     | Text              |
 
 ## Introduction
 
-Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline:
+Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline:
+
 - **Data Preprocessing**: We apply modality-specific preprocessing steps to prepare the raw input data (e.g., image resizing and normalization), transforming them into a format which neural networks can understand.
 - **Modality-Specific Encoders**: Modality-specific encoders will transform the preprocessed data into high-dimensional representations (e.g., vision transformers for images).
 - **Projection and Merge**: Projection layers will map these modality-specific embeddings into the shared embedding space of the language model, usually aligned with the dimension of text embeddings. These projected embeddings are then merged with text token embeddings, allowing the unified model to process and reason over multiple modalities simultaneously within a single coherent framework.
 
 ![Illustration of multimodal MaxText.](../../_static/multimodal_overview.png)
 *Figure 1: Overview of multimodal dataflow in MaxText.*
 
-
 ## Checkpoint Conversion
 
 Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md)).
 
 Install pytorch:
+
 ```
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 ```
@@ -58,7 +59,9 @@ python -m MaxText.utils.ckpt_scripts.llama4_ckpt_unscanned \
 ```
 
 ## Multimodal Decode
+
 MaxText supports multimodal decoding, allowing you to input text with multiple images to get a text output. To use this feature, you need three main settings:
+
 - `use_multimodal=True`: Initializes the multimodal preprocessing steps and network components.
 - `prompt`: Specifies the position of image placeholder tokens in your input. If you don't manually place them, MaxText will automatically append the required placeholder (e.g., `<start_of_image>` for Gemma3, `<|image|>` for Llama4). The exact placeholder is listed under the `image_placeholder` field in each model's configuration file.
 - `image_path`: The path(s) to the image file(s) MaxText will load and process.
@@ -73,7 +76,7 @@ python -m MaxText.decode \
 MaxText/configs/base.yml \
 model_name=gemma3-4b \
 hf_access_token=$HF_ACCESS_TOKEN \
-tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
+tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \
 load_parameters_path=$MAXTEXT_CKPT_GCS_PATH/0/items \
 per_device_batch_size=1 \
 run_name=ht_test \
@@ -89,6 +92,7 @@ python -m MaxText.decode \
 ```
 
 The decoding results will look like this:
+
 ```
 Input `<start_of_turn>user
 Describe image <start_of_image><end_of_turn>
@@ -123,7 +127,6 @@ Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically
 
 Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as an example to demonstrate SFT functionality:
 
-
 ```shell
 export UNSCANNED_CKPT_PATH=... # either set to an already available MaxText ckpt or to the one we just converted in the previous step
 python -m MaxText.sft_trainer \
@@ -148,14 +151,16 @@ python -m MaxText.sft_trainer \
 ```
 
 ## Other Recommendations
+
 - **Setting appropriate prefill length**: To prevent truncation and ensure your full input (text + image) is processed, the prefill length should be set longer than the total combined length of your text tokens and image tokens. This combined length makes up the final sequence fed to the decoder. We recommend to estimate the combined sequence length from your full input and then add a buffer when setting your `max_prefill_predict_length` for decoding. Token estimation rules:
-  - For text tokens, a good estimate is:
-
-    $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
-  - For Gemma3, each image is resized to 896*896 and contributes 256 tokens:
-
-    $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
-  - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens:
-
-    $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
+  - For text tokens, a good estimate is:
+
+    $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
+
+  - For Gemma3, each image is resized to 896\*896 and contributes 256 tokens:
+
+    $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
+
+  - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens:
 
+    $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
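The estimation rules in the doc reduce to simple arithmetic. A short sketch, assuming the 1.3-tokens-per-word heuristic and the per-image costs quoted above (the function name and buffer choice are illustrative):

```python
def estimate_prefill_tokens(prompt, num_gemma3_images=0, llama4_tiles_per_image=()):
    """Rough combined-sequence estimate from the rules in this doc (illustrative)."""
    text_tokens = int(1.3 * len(prompt.split()))       # ~1.3 tokens per word
    image_tokens = 256 * num_gemma3_images             # Gemma3: 256 tokens per image
    image_tokens += 144 * sum(llama4_tiles_per_image)  # Llama4: 144 tokens per tile
    return text_tokens + image_tokens

# Example: a 40-word prompt plus one Gemma3 image -> ~52 + 256 = 308 tokens,
# so set max_prefill_predict_length comfortably above that (e.g. 512).
```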

end_to_end/gpu/a3/test_gemma3_logits.sh

Lines changed: 1 addition & 1 deletion
@@ -44,5 +44,5 @@ python3 -m MaxText.utils.ckpt_scripts.convert_gemma3_chkpt --base_model_path ${C
 export UNSCANNED_CKPT_PATH=gs://runner-maxtext-logs/unscanned_chkpt_2025-04-16-00-01/checkpoints/0/items
 export NVTE_FUSED_ATTN=1
 # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
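The updated invocation keeps the script's nested `${VAR:-default}` shell fallbacks: `MAXTEXT_ASSETS_ROOT` wins if set, else `MAXTEXT_PKG_DIR`, else a default derived from `MAXTEXT_REPO_ROOT` (falling back to `$PWD`), which now points at `src/maxtext/assets/tokenizers`. A Python sketch of the equivalent resolution, for illustration only:

```python
import os

def resolve_tokenizer_dir():
    """Mirror the shell's ${VAR:-default} chain for the tokenizer dir (illustrative)."""
    repo_root = os.environ.get("MAXTEXT_REPO_ROOT", os.getcwd())
    default_dir = os.path.join(repo_root, "src", "maxtext", "assets", "tokenizers")
    pkg_dir = os.environ.get("MAXTEXT_PKG_DIR", default_dir)
    return os.environ.get("MAXTEXT_ASSETS_ROOT", pkg_dir)

tokenizer_path = os.path.join(resolve_tokenizer_dir(), "tokenizer.gemma3")
```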
