Commit b314c5a

Merge pull request AI-Hypercomputer#2925 from CIeNET-International:user/sharony/mtxdoc
PiperOrigin-RevId: 857196217
2 parents: 6bcd411 + 701ab96


docs/tutorials/posttraining/full_finetuning.md

Lines changed: 86 additions & 51 deletions

# Full fine-tuning on single-host TPUs

Full Fine-Tuning (FFT) is a common technique used in post-training to adapt a pre-trained Large Language Model (LLM) to a specific downstream task or dataset. In this process, all the parameters (weights) of the original model are "unfrozen" and updated during training on the new task-specific data. This allows the entire model to adjust and specialize, potentially leading to the best performance on the new task.

This tutorial provides step-by-step instructions for setting up the environment, converting a checkpoint, and training the model on a Hugging Face dataset using FFT.

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!
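
If you still need to provision the TPU VM itself, a command along the following lines can create one. This is a sketch rather than part of the tutorial; the TPU name, zone, and software version are placeholders, so check the Cloud TPU documentation for the values that apply to your project.

```sh
# Hypothetical provisioning step -- all <...> values are placeholders.
gcloud compute tpus tpu-vm create <tpu name> \
  --zone=<zone> \
  --accelerator-type=v6e-8 \
  --version=<TPU software version> \
  --project=<Google Cloud Project ID>
```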

## Install dependencies

```sh
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

# 2. Create virtual environment
export VENV_NAME=<your virtual env name>  # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed $VENV_NAME
source $VENV_NAME/bin/activate

# 3. Install dependencies in editable mode
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_github_deps
```
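
As an optional sanity check (not a step from the original instructions), you can confirm that JAX sees the TPU before moving on:

```sh
# This should list the local TPU devices available to JAX.
python3 -c "import jax; print(jax.devices())"
```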

## Set up environment variables

```sh
# -- Model configuration --
export MODEL_NAME=<model name>  # e.g., 'llama3.1-8b'
export MODEL_TOKENIZER=<tokenizer path>  # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>  # e.g., gs://my-bucket/my-output-directory
export RUN_NAME=<name for this run>  # e.g., $(date +%Y-%m-%d-%H-%M-%S)
```
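
Pulling the inline `e.g.` hints together, a filled-in configuration might look like the following; the bucket path and token are illustrative placeholders, not required values.

```sh
# Illustrative values only -- substitute your own token, bucket, and run name.
export MODEL_NAME='llama3.1-8b'
export MODEL_TOKENIZER='meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<your Hugging Face access token>
export BASE_OUTPUT_DIRECTORY=gs://my-bucket/my-output-directory
export RUN_NAME=$(date +%Y-%m-%d-%H-%M-%S)
```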

## Hugging Face checkpoint to MaxText checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and skip ahead to the Dataset section.

```sh
export MODEL_CKPT_PATH=<GCS path to MaxText checkpoint>  # e.g., gs://my-bucket/my-model-checkpoint/0/items
```
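
One quick way to confirm the path is valid (an optional check, not part of the original instructions) is to list the checkpoint contents with `gsutil`:

```sh
# The path should point at an existing checkpoint directory (e.g., .../0/items).
gsutil ls ${MODEL_CKPT_PATH}
```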

### Option 2: Converting a Hugging Face checkpoint

If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

1. **Set the output path:** First, define where the converted MaxText checkpoint will be saved. For example:

    ```sh
    export MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
    ```

2. **Run the conversion script:** Execute the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face; to see the specific models and versions currently supported, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).

    ```sh
    # Ensure torch is installed for the conversion script
    python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

    python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
      model_name=${MODEL_NAME} \
      hf_access_token=${HF_TOKEN} \
      base_output_directory=${MODEL_CKPT_DIRECTORY} \
      scan_layers=True skip_jax_distributed_system=True
    ```

3. **Use the converted checkpoint:** Set the following environment variable to point at the converted checkpoint:

    ```sh
    export MODEL_CKPT_PATH=${MODEL_CKPT_DIRECTORY}/0/items
    ```

## Dataset

MaxText provides examples that work with [Common Crawl](https://commoncrawl.org/). The dataset is available in TFRecord format in a cloud bucket, and MaxText provides scripts to copy it to a Google Cloud Storage bucket.

### Common Crawl (c4) dataset setup

Run these steps once per project, prior to any local development or cluster experiments.

1. Create two GCS buckets in your project: one for downloading and retrieving the dataset, and the other for storing the logs.
2. Download the dataset to your GCS bucket.

MaxText assumes these GCS buckets are created in the same project and that it has permissions to read from and write to them.
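
For example, the buckets can be created with `gsutil mb`; the bucket names and region below are placeholders of my own, not values from the tutorial.

```sh
# Create one bucket for the dataset and one for run logs/output.
gsutil mb -p <Google Cloud Project ID> -l <region> gs://my-dataset-bucket
gsutil mb -p <Google Cloud Project ID> -l <region> gs://my-output-bucket
```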

```sh
export PROJECT=<Google Cloud Project ID>
export DATASET_GCS_BUCKET=<GCS bucket for the dataset>  # e.g., gs://my-bucket/my-dataset

bash tools/data_generation/download_dataset.sh ${PROJECT} ${DATASET_GCS_BUCKET}
```

The above command downloads the c4 dataset to your GCS bucket.
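
The transfer can take a while. As an optional check (an addition of mine, not from the original doc), you can watch how much data has landed in the bucket:

```sh
# Summarize the total size of the data copied so far.
gsutil du -sh ${DATASET_GCS_BUCKET}
```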

## Sample full fine-tuning script

Below is a sample training script. You may need to adjust parameters such as `per_device_batch_size` and `steps` to fit the model to your TPU shape and to obtain good performance.

```sh
python3 -m MaxText.train \
  src/MaxText/configs/base.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  load_parameters_path=${MODEL_CKPT_PATH} \
  model_name=${MODEL_NAME} \
  dataset_path=${DATASET_GCS_BUCKET} \
  async_checkpointing=False \
  tokenizer_path=${MODEL_TOKENIZER} \
  hf_access_token=${HF_TOKEN} \
  steps=10 per_device_batch_size=1
```
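
Once training finishes, the run's artifacts (TensorBoard logs and any saved checkpoints) are written under the output directory. Assuming the usual `${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}` layout, a quick way to inspect them is:

```sh
# Inspect the run directory (TensorBoard logs and any saved checkpoints).
gsutil ls -r ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME} | head -n 20
```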

You can find [end-to-end scripts here](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu); they can serve as a reference point for additional models and workflows.

## Parameters to achieve high MFU
