limitations under the License.
-->

# Full fine-tuning on single-host TPUs

Full Fine-Tuning (FFT) is a common technique used in post-training to adapt a pre-trained Large Language Model (LLM) to a specific downstream task or dataset. In this process, all the parameters (weights) of the original model are "unfrozen" and updated during training on the new task-specific data. This allows the entire model to adjust and specialize, potentially leading to the best performance on the new task.

This tutorial provides step-by-step instructions for setting up the environment, converting a model checkpoint, and then fine-tuning the model on the Common Crawl (c4) dataset using FFT.

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!
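
If you do not already have a single-host TPU VM, you can create one with `gcloud`. The commands below are a sketch rather than part of the original tutorial; the VM name, zone, accelerator type, and runtime version are placeholders you should adjust for your project and TPU generation.

```sh
export TPU_NAME=<your TPU VM name>            # placeholder used only in this sketch
export TPU_ZONE=<zone with v6e/v5p capacity>
export TPU_PROJECT=<Google Cloud Project ID>

# Create a single-host TPU VM (here a v6e-8) and connect to it.
gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --project=${TPU_PROJECT} \
  --zone=${TPU_ZONE} \
  --accelerator-type=v6e-8 \
  --version=<TPU runtime version>

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project=${TPU_PROJECT} --zone=${TPU_ZONE}
```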

## Install dependencies

```sh
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

# 2. Create virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed $VENV_NAME
source $VENV_NAME/bin/activate

# 3. Install dependencies in editable mode
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_github_deps
```
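
After installing the dependencies, it can be useful to confirm that JAX detects the TPU. This quick check is an addition to the tutorial; it simply prints the devices visible to JAX on the VM.

```sh
# Should print the TPU devices available on this single-host VM.
python3 -c "import jax; print(jax.devices())"
```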

## Setup environment variables

```sh
# -- Model configuration --
export MODEL_NAME=<model name> # e.g., 'llama3.1-8b'
export MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
```
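
Note that models such as Llama 3.1 are gated on Hugging Face, so `HF_TOKEN` must belong to an account that has accepted the model license. If the `huggingface_hub` CLI happens to be available in your environment (an assumption, it is not required by this tutorial), you can sanity-check the token:

```sh
# Prints the account associated with HF_TOKEN if the token is valid.
huggingface-cli whoami
```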

## Hugging Face checkpoint to MaxText checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export MODEL_CKPT_PATH=<GCS path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint

If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:

```sh
export MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
```

2. **Run the Conversion Script:** Run the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face; to see the specific models and versions currently supported for conversion, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).

```sh
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # Ensure torch is installed for the conversion script

python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
  model_name=${MODEL_NAME} \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${MODEL_CKPT_DIRECTORY} \
  scan_layers=True skip_jax_distributed_system=True
```

3. **Use the Converted Checkpoint:** Set the following environment variable to use the converted checkpoint:

```sh
export MODEL_CKPT_PATH=${MODEL_CKPT_DIRECTORY}/0/items
```
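
Optionally, you can confirm that the converted checkpoint was written where you expect before starting training. This check is an addition to the tutorial and assumes only that the conversion step saved its output under `${MODEL_CKPT_DIRECTORY}/0/items`.

```sh
# Should list the checkpoint files produced by the conversion step.
gsutil ls ${MODEL_CKPT_PATH}
```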

## Dataset

MaxText provides examples that work with [Common Crawl](https://commoncrawl.org/). The dataset is available in TFRecord format in a cloud bucket, and MaxText provides a script to copy it to your own Google Cloud Storage bucket.

### Common Crawl (c4) dataset setup

Run these steps once per project prior to any local development or cluster experiments.

1. Create two GCS buckets in your project, one for downloading and retrieving the dataset and the other for storing the logs (see the `gsutil` sketch after this list).
2. Download the dataset into your GCS bucket.
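
If you have not created the buckets yet, the commands below are a minimal sketch using `gsutil`; the project ID, location, and bucket names are placeholders to replace with your own.

```sh
# One bucket for the dataset and one for run logs/outputs (names are examples).
gsutil mb -p <Google Cloud Project ID> -l <region> gs://<your-dataset-bucket>
gsutil mb -p <Google Cloud Project ID> -l <region> gs://<your-logs-bucket>
```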

MaxText assumes these GCS buckets are created in the same project and that it has permissions to read and write from them.

```sh
export PROJECT=<Google Cloud Project ID>
export DATASET_GCS_BUCKET=<GCS bucket for dataset> # e.g., gs://my-bucket/my-dataset

bash tools/data_generation/download_dataset.sh ${PROJECT} ${DATASET_GCS_BUCKET}
```

The above will download the c4 dataset to your GCS bucket.

## Sample full fine-tuning script

Below is a sample training script. It loads the converted checkpoint, updates all model parameters on the c4 dataset for 10 steps, and writes logs and checkpoints under `${BASE_OUTPUT_DIRECTORY}`. You may need to adjust parameters such as `per_device_batch_size` to fit the model on your TPU shape and to obtain good performance.

```sh
python3 -m MaxText.train \
  src/MaxText/configs/base.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  load_parameters_path=${MODEL_CKPT_PATH} \
  model_name=${MODEL_NAME} \
  dataset_path=${DATASET_GCS_BUCKET} \
  async_checkpointing=False \
  tokenizer_path=${MODEL_TOKENIZER} \
  hf_access_token=${HF_TOKEN} \
  steps=10 per_device_batch_size=1
```
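
Once the run starts, logs and checkpoints are written under the output directory. As an optional follow-up (the exact layout is an assumption based on the configured `base_output_directory` and `run_name`), you can inspect what the run has produced:

```sh
# List artifacts written by the run (logs, checkpoints, TensorBoard events).
gsutil ls -r ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/
```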

You can find some [end-to-end scripts here](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu). These scripts can serve as a reference point for fine-tuning other models.

## Parameters to achieve high MFU