2 changes: 1 addition & 1 deletion .github/workflows/run_pathways_tests.yml
@@ -98,7 +98,7 @@ jobs:
FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only"
fi
export MAXTEXT_REPO_ROOT=$(pwd)
-export MAXTEXT_ASSETS_ROOT=$(pwd)/src/MaxText/assets
+export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText
# TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests_against_package.yml
@@ -108,7 +108,7 @@ jobs:
fi
# TODO: Use package data for testing and remove the env vars
export MAXTEXT_REPO_ROOT=$(pwd)
-export MAXTEXT_ASSETS_ROOT=$(pwd)/src/MaxText/assets
+export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets
export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText
# omit this libtpu init args for gpu tests
4 changes: 2 additions & 2 deletions .vscode/launch.json
@@ -15,7 +15,7 @@
"dataset_path=gs://test-maxtext-dataset",
"model_name=llama2-7b",
"load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_",
"tokenizer_path=src/MaxText/assets/tokenizer.llama2",
"tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2",
"per_device_batch_size=8",
"max_prefill_predict_length=8",
"max_target_length=20",
@@ -70,7 +70,7 @@
"args": [
"src/MaxText/configs/base.yml",
"model_name=llama2-7b",
"tokenizer_path=src/MaxText/assets/tokenizer.llama2",
"tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2",
"weight_dtype=bfloat16",
"scan_layers=false",
"attention=dot_product",
4 changes: 2 additions & 2 deletions benchmarks/globals.py
@@ -25,7 +25,7 @@
r if os.path.isdir(os.path.join(r := os.path.dirname(os.path.dirname(__file__)), ".git")) else MAXTEXT_PKG_DIR,
)

-# This is the assets root: with "tokenizer.gemma3"; &etc.
-MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_PKG_DIR, "assets"))
+# This is the assets root: with "tokenizers/"; &etc.
+MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"))

__all__ = ["MAXTEXT_ASSETS_ROOT", "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT"]
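The new default in `benchmarks/globals.py` anchors the assets root at the repository root rather than the package directory, and tokenizers move into a `tokenizers/` subdirectory. A minimal sketch of the resulting lookup order (the repo path is hypothetical; the variable names mirror the file above):

```python
import os

# Hypothetical checkout location, standing in for the value computed above.
MAXTEXT_REPO_ROOT = "/home/user/maxtext"

# An explicit MAXTEXT_ASSETS_ROOT in the environment still wins; the
# fallback is now <repo>/src/maxtext/assets instead of <pkg>/assets.
MAXTEXT_ASSETS_ROOT = os.environ.get(
    "MAXTEXT_ASSETS_ROOT",
    os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets"),
)

# Tokenizer paths gain a "tokenizers" segment, as in every config in this PR.
tokenizer_path = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2")
print(tokenizer_path)
# -> /home/user/maxtext/src/maxtext/assets/tokenizers/tokenizer.llama2
```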
28 changes: 14 additions & 14 deletions benchmarks/maxtext_trillium_model_configs.py
@@ -544,7 +544,7 @@
"profiler": "xplane",
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "tfds",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
@@ -1280,7 +1280,7 @@
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"tokenizer_type": "tiktoken",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
@@ -1336,7 +1336,7 @@
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"tokenizer_type": "tiktoken",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
@@ -1517,7 +1517,7 @@
"megablox": False,
"sparse_matmul": False,
"capacity_factor": 1.25,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
},
xla_flags=(
xla_flags_library.MOE_VMEM_LIMIT_FLAG
@@ -1552,7 +1552,7 @@
"sparse_matmul": False,
"capacity_factor": 1.25,
"quantization": "int8",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
},
xla_flags=(
xla_flags_library.MOE_VMEM_LIMIT_FLAG
@@ -1593,7 +1593,7 @@
"megablox": False,
"sparse_matmul": False,
"capacity_factor": 1.25,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v3"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v3"),
"dtype": "bfloat16",
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
@@ -1634,7 +1634,7 @@
"megablox": False,
"sparse_matmul": False,
"capacity_factor": 1.0,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v3"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v3"),
"dtype": "bfloat16",
"opt_type": "sgd",
"weight_dtype": "bfloat16",
@@ -1667,7 +1667,7 @@
"reuse_example_batch": 1,
"enable_checkpointing": False,
"profiler": "xplane",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"sa_block_q": 2048,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
@@ -1700,7 +1700,7 @@
"reuse_example_batch": 1,
"enable_checkpointing": False,
"profiler": "xplane",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"sa_block_q": 2048,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
@@ -1739,7 +1739,7 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
@@ -1779,7 +1779,7 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
@@ -1819,7 +1819,7 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
@@ -1868,7 +1868,7 @@
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"tokenizer_type": "tiktoken",
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
"packing": False,
},
xla_flags=(
@@ -1933,7 +1933,7 @@
"sa_use_fused_bwd_kernel": True,
"sparse_matmul": False,
"capacity_factor": 1.5,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.mistral-v1"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.mistral-v1"),
"dtype": "bfloat16",
"weight_dtype": "bfloat16",
"opt_type": "sgd",
4 changes: 2 additions & 2 deletions benchmarks/maxtext_v5e_model_configs.py
@@ -149,7 +149,7 @@
"remat_policy": "save_qkv_proj",
"max_target_length": 2048,
"use_iota_embed": True,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
@@ -171,7 +171,7 @@
"remat_policy": "qkv_proj_offloaded",
"max_target_length": 2048,
"use_iota_embed": True,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
2 changes: 1 addition & 1 deletion benchmarks/maxtext_v5p_model_configs.py
@@ -227,7 +227,7 @@
"remat_policy": "minimal",
"max_target_length": 4096,
"use_iota_embed": True,
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer.llama2"),
"tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"),
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
2 changes: 1 addition & 1 deletion docs/guides/data_input_pipeline/data_input_tfds.md
@@ -16,5 +16,5 @@ eval_interval: 10000
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
# TFDS input pipeline only supports tokenizer in spm format
-tokenizer_path: 'src/MaxText/assets/tokenizer.llama2'
+tokenizer_path: 'src/maxtext/assets/tokenizers/tokenizer.llama2'
```
41 changes: 23 additions & 18 deletions docs/tutorials/posttraining/multimodal.md
@@ -1,33 +1,34 @@


# Multimodal support

This document provides a guide to using the multimodal functionality in MaxText, including:

- **Checkpoint Conversion**: Convert a HuggingFace checkpoint into a MaxText-compatible Orbax checkpoint.
- **Multimodal Decode**: Inference with text+images as input.
- **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset.

We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) demonstrating the multimodal features. The following table lists the models and modalities we currently support:
-| Models | Input Modalities | Output Modalities |
-| :---- | :---- | :---- |
-| - Gemma3-4B/12B/27B<br>- Llama4-Scout/Maverick | Text, images | Text |

+| Models | Input Modalities | Output Modalities |
+| :--------------------------------------------- | :--------------- | :---------------- |
+| - Gemma3-4B/12B/27B<br>- Llama4-Scout/Maverick | Text, images | Text |

## Introduction

Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline:

- **Data Preprocessing**: We apply modality-specific preprocessing steps to prepare the raw input data (e.g., image resizing and normalization), transforming them into a format which neural networks can understand.
- **Modality-Specific Encoders**: Modality-specific encoders will transform the preprocessed data into high-dimensional representations (e.g., vision transformers for images).
- **Projection and Merge**: Projection layers will map these modality-specific embeddings into the shared embedding space of the language model, usually aligned with the dimension of text embeddings. These projected embeddings are then merged with text token embeddings, allowing the unified model to process and reason over multiple modalities simultaneously within a single coherent framework.

![Illustration of multimodal MaxText.](../../_static/multimodal_overview.png)
*Figure 1: Overview of multimodal dataflow in MaxText.*


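To make the projection-and-merge stage concrete, here is a toy numpy sketch of the idea. The shapes, the `IMG_TOKEN` sentinel, and the single linear projection are illustrative assumptions, not MaxText's actual implementation:

```python
import numpy as np

SEQ, D_TXT, N_PATCH, D_IMG = 8, 16, 4, 32  # toy sizes, not real model dims
IMG_TOKEN = -1  # hypothetical id marking image-placeholder positions

rng = np.random.default_rng(0)
text_emb = rng.normal(size=(SEQ, D_TXT))       # embedded text tokens
patch_emb = rng.normal(size=(N_PATCH, D_IMG))  # vision-encoder output
proj = rng.normal(size=(D_IMG, D_TXT))         # projection into text space

token_ids = np.array([5, 9, IMG_TOKEN, IMG_TOKEN, IMG_TOKEN, IMG_TOKEN, 7, 2])

# Project image embeddings to the text dimension, then splice them in at the
# placeholder positions so the decoder sees one unified sequence.
merged = text_emb.copy()
merged[token_ids == IMG_TOKEN] = patch_emb @ proj
print(merged.shape)  # (8, 16)
```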
## Checkpoint Conversion

Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md)).

Install pytorch:

```
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```
@@ -58,7 +58,9 @@ python -m MaxText.utils.ckpt_scripts.llama4_ckpt_unscanned \
```

## Multimodal Decode

MaxText supports multimodal decoding, allowing you to input text with multiple images to get a text output. To use this feature, you need three main settings:

- `use_multimodal=True`: Initializes the multimodal preprocessing steps and network components.
- `prompt`: Specifies the position of image placeholder tokens in your input. If you don't manually place them, MaxText will automatically append the required placeholder (e.g., `<start_of_image>` for Gemma3, `<|image|>` for Llama4). The exact placeholder is listed under the `image_placeholder` field in each model's configuration file.
- `image_path`: The path(s) to the image file(s) MaxText will load and process.
@@ -73,7 +76,7 @@ python -m MaxText.decode \
MaxText/configs/base.yml \
model_name=gemma3-4b \
hf_access_token=$HF_ACCESS_TOKEN \
-tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
+tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \
load_parameters_path=$MAXTEXT_CKPT_GCS_PATH/0/items \
per_device_batch_size=1 \
run_name=ht_test \
@@ -89,6 +92,7 @@ python -m MaxText.decode \
```

The decoding results will look like this:

```
Input `<start_of_turn>user
Describe image <start_of_image><end_of_turn>
@@ -123,7 +127,6 @@ Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically

Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as an example to demonstrate SFT functionality:

```shell
export UNSCANNED_CKPT_PATH=... # either set to an already available MaxText ckpt or to the one we just converted in the previous step
python -m MaxText.sft_trainer \
@@ -148,14 +151,16 @@
```

## Other Recommendations

- **Setting appropriate prefill length**: To prevent truncation and ensure your full input (text + image) is processed, the prefill length should be set longer than the total combined length of your text tokens and image tokens. This combined length makes up the final sequence fed to the decoder. We recommend estimating the combined sequence length from your full input and then adding a buffer when setting your `max_prefill_predict_length` for decoding. Token estimation rules:
-  - For text tokens, a good estimate is:
-    $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
-  - For Gemma3, each image is resized to 896*896 and contributes 256 tokens:
-    $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
-  - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens:
-    $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
+  - For text tokens, a good estimate is:
+
+    $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
+
+  - For Gemma3, each image is resized to 896\*896 and contributes 256 tokens:
+
+    $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
+
+  - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens:
+
+    $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
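As a worked example of the estimation rules above, here is a small sketch that turns them into a prefill budget. The function name and the 1.2x safety margin are hypothetical; the per-word ratio and per-image token count are the ones stated in the doc:

```python
def estimate_prefill_length(prompt: str, num_images: int, margin: float = 1.2) -> int:
  """Hypothetical helper: budget max_prefill_predict_length for Gemma3 inputs."""
  text_tokens = 1.3 * len(prompt.split())  # ~1.3 tokens per word
  total = text_tokens + num_images * 256   # Gemma3: 256 tokens per image
  # For Llama4, add 144 tokens per tile instead, summed over every image's tiles.
  return int(total * margin)               # buffer against truncation

# A 40-word prompt with one image: (1.3*40 + 256) * 1.2 = 369.6 -> 369.
print(estimate_prefill_length("word " * 40, num_images=1))
```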
2 changes: 1 addition & 1 deletion end_to_end/gpu/a3/test_gemma3_logits.sh
@@ -44,5 +44,5 @@ python3 -m MaxText.utils.ckpt_scripts.convert_gemma3_chkpt --base_model_path ${C
export UNSCANNED_CKPT_PATH=gs://runner-maxtext-logs/unscanned_chkpt_2025-04-16-00-01/checkpoints/0/items
export NVTE_FUSED_ATTN=1
# # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
+python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0

6 changes: 3 additions & 3 deletions end_to_end/gpu/mixtral/test_8x7b.sh
@@ -31,7 +31,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
attention=cudnn_flash_te capacity_factor=1.25 dtype=bfloat16 \
enable_checkpointing=false ici_expert_parallelism=-1 ici_fsdp_parallelism=1 \
max_target_length=1024 megablox=False per_device_batch_size=1 \
-reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 \
+reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 \
weight_dtype=bfloat16 sparse_matmul=False packing=False
echo "Finished pre-training"

@@ -43,7 +43,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
attention=cudnn_flash_te capacity_factor=1.25 dtype=bfloat16 \
ici_expert_parallelism=-1 ici_fsdp_parallelism=1 \
max_target_length=1024 megablox=False per_device_batch_size=1 \
-reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 \
+reuse_example_batch=1 steps=5 tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 \
weight_dtype=bfloat16 sparse_matmul=False packing=False
echo "Finished fine-tuning"

@@ -55,7 +55,7 @@ echo "Finished fine-tuning"
# ici_expert_parallelism=8 ici_fsdp_parallelism=1 max_prefill_predict_length=11 \
# max_target_length=24 megablox=False per_device_batch_size=1 \
# prompt='"[INST] I love to [/INST]"' scan_layers=false \
-# tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1
+# tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1
# echo "Finished decoding"

