Commit 02847ec

RdoubleA authored and joecummings committed
Llama3 tutorial updates (#800)
1 parent c5e4050 commit 02847ec

1 file changed

docs/source/tutorials/llama3.rst (70 additions, 60 deletions)

@@ -22,7 +22,7 @@ Llama3-8B
 ----------
 
 `Llama3-8B <https://llama.meta.com/llama3>`_ is a new model released by Meta AI that improves upon the performance of the Llama2 family
-of models across a `range of different benchmarks <https://github.com/meta-llama/llama3/blob/main/eval_details.md>`_.
+of models across a `range of different benchmarks <https://huggingface.co/meta-llama/Meta-Llama-3-8B#base-pretrained-models>`_.
 There are a few main changes between Llama2-7B and Llama3-8B models:
 
 - Llama3-8B uses `grouped-query attention <https://arxiv.org/abs/2305.13245>`_ instead of the standard multi-head attention from Llama2-7B

@@ -93,7 +93,7 @@ In our experiments, we observed a peak memory usage of 18.5 GB. The default conf
 
 If you have multiple GPUs available, you can run the distributed version of the recipe.
 torchtune makes use of the `FSDP <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`_ APIs from PyTorch Distributed
-to shard the model, optimizer states, and gradients. This should enable you to increase your batch size, resulting in faster training.
+to shard the model, optimizer states, and gradients. This should enable you to increase your batch size, resulting in faster overall training.
 For example, on two devices:
 
 .. code-block:: bash
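
The launch command under that ``.. code-block:: bash`` directive falls just outside this hunk. For orientation, a two-device run with torchtune's distributed LoRA recipe would look roughly like the sketch below; the recipe and config names are assumed from torchtune's stock Llama3 configs, not taken from this diff:

    tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora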

@@ -140,28 +140,31 @@ Next, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoint
 
 .. code-block:: yaml
 
+    model:
+      _component_: torchtune.models.llama3.llama3_8b
+
     checkpointer:
-      _component_: torchtune.utils.FullModelMetaCheckpointer
+      _component_: torchtune.utils.FullModelMetaCheckpointer
 
-      # directory with the checkpoint files
-      # this should match the output_dir specified during
-      # fine-tuning
-      checkpoint_dir: <checkpoint_dir>
+      # directory with the checkpoint files
+      # this should match the output_dir specified during
+      # fine-tuning
+      checkpoint_dir: <checkpoint_dir>
 
-      # checkpoint files for the fine-tuned model. These will be logged
-      # at the end of your fine-tune
-      checkpoint_files: [
-        consolidated.00.pth
-      ]
+      # checkpoint files for the fine-tuned model. These will be logged
+      # at the end of your fine-tune
+      checkpoint_files: [
+        consolidated.00.pth
+      ]
 
-      output_dir: <checkpoint_dir>
-      model_type: LLAMA3
+      output_dir: <checkpoint_dir>
+      model_type: LLAMA3
 
     # Make sure to update the tokenizer path to the right
     # checkpoint directory as well
     tokenizer:
-      _component_: torchtune.models.llama3.llama3_tokenizer
-      path: <checkpoint_dir>/tokenizer.model
+      _component_: torchtune.models.llama3.llama3_tokenizer
+      path: <checkpoint_dir>/tokenizer.model
 
 Finally, we can run evaluation using our modified config.
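
The evaluation command itself sits outside this hunk. Assuming the EleutherAI eval-harness recipe that ships with torchtune, the run would look roughly like:

    tune run eleuther_eval --config ./custom_eval_config.yaml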

@@ -189,28 +192,31 @@ Now we modify ``custom_generation_config.yaml`` to point to our checkpoint and t
 
 .. code-block:: yaml
 
+    model:
+      _component_: torchtune.models.llama3.llama3_8b
+
     checkpointer:
-      _component_: torchtune.utils.FullModelMetaCheckpointer
+      _component_: torchtune.utils.FullModelMetaCheckpointer
 
-      # directory with the checkpoint files
-      # this should match the output_dir specified during
-      # fine-tuning
-      checkpoint_dir: <checkpoint_dir>
+      # directory with the checkpoint files
+      # this should match the output_dir specified during
+      # fine-tuning
+      checkpoint_dir: <checkpoint_dir>
 
-      # checkpoint files for the fine-tuned model. These will be logged
-      # at the end of your fine-tune
-      checkpoint_files: [
-        consolidated.00.pth
-      ]
+      # checkpoint files for the fine-tuned model. These will be logged
+      # at the end of your fine-tune
+      checkpoint_files: [
+        consolidated.00.pth
+      ]
 
-      output_dir: <checkpoint_dir>
-      model_type: LLAMA3
+      output_dir: <checkpoint_dir>
+      model_type: LLAMA3
 
     # Make sure to update the tokenizer path to the right
     # checkpoint directory as well
     tokenizer:
-      _component_: torchtune.models.llama3.llama3_tokenizer
-      path: <checkpoint_dir>/tokenizer.model
+      _component_: torchtune.models.llama3.llama3_tokenizer
+      path: <checkpoint_dir>/tokenizer.model
 
 Running generation with our LoRA-finetuned model, we see the following output:
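
The generation command and its sample output fall outside this hunk. Assuming torchtune's generation recipe, the run would look roughly like:

    tune run generate --config ./custom_generation_config.yaml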

@@ -243,32 +249,36 @@ And update ``custom_quantization_config.yaml`` with the following:
 
 .. code-block:: yaml
 
+    # Model arguments
+    model:
+      _component_: torchtune.models.llama3.llama3_8b
+
     checkpointer:
-      _component_: torchtune.utils.FullModelMetaCheckpointer
+      _component_: torchtune.utils.FullModelMetaCheckpointer
 
-      # directory with the checkpoint files
-      # this should match the output_dir specified during
-      # fine-tuning
-      checkpoint_dir: <checkpoint_dir>
+      # directory with the checkpoint files
+      # this should match the output_dir specified during
+      # fine-tuning
+      checkpoint_dir: <checkpoint_dir>
 
-      # checkpoint files for the fine-tuned model. These will be logged
-      # at the end of your fine-tune
-      checkpoint_files: [
-        consolidated.00.pth
-      ]
+      # checkpoint files for the fine-tuned model. These will be logged
+      # at the end of your fine-tune
+      checkpoint_files: [
+        consolidated.00.pth
+      ]
 
-      output_dir: <checkpoint_dir>
-      model_type: LLAMA3
+      output_dir: <checkpoint_dir>
+      model_type: LLAMA3
 
 To quantize the model, we can now run:
 
 .. code-block:: bash
 
-    tune run quantize ./custom_quantization_config.yaml
+    tune run quantize --config ./custom_quantization_config.yaml
 
     [quantize.py:90] Time for quantization: 2.93 sec
     [quantize.py:91] Memory used: 23.13 GB
-    [quantize.py:104] Model checkpoint of size 4.92 GB saved to /tmp/Llama-3-8B-hf/meta_model_0-4w.pt
+    [quantize.py:104] Model checkpoint of size 4.92 GB saved to /tmp/Llama-3-8B-hf/consolidated-4w.pt
 
 We can see that the model is now under 5 GB, or just over four bits for each of the 8B parameters.
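
As a rough sanity check on that sentence, treating GB as 10^9 bytes and the parameter count as roughly 8 x 10^9 (both assumptions, not figures from this diff):

    (4.92 \times 10^{9}\ \text{bytes} \times 8\ \text{bits/byte}) \div (8 \times 10^{9}\ \text{parameters}) \approx 4.9\ \text{bits per parameter}

i.e. 4-bit weights plus room for per-group quantization metadata and any tensors left unquantized.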

@@ -286,29 +296,29 @@ First, we'll make one more change to our ``custom_generation_config.yaml``.
 .. code-block:: yaml
 
     checkpointer:
-      # we need to use the custom TorchTune checkpointer
-      # instead of the HF checkpointer for loading
-      # quantized models
-      _component_: torchtune.utils.FullModelTorchTuneCheckpointer
+      # we need to use the custom TorchTune checkpointer
+      # instead of the HF checkpointer for loading
+      # quantized models
+      _component_: torchtune.utils.FullModelTorchTuneCheckpointer
 
-      # directory with the checkpoint files
-      # this should match the output_dir specified during
-      # fine-tuning
-      checkpoint_dir: <checkpoint_dir>
+      # directory with the checkpoint files
+      # this should match the output_dir specified during
+      # fine-tuning
+      checkpoint_dir: <checkpoint_dir>
 
-      # checkpoint files point to the quantized model
-      checkpoint_files: [
-        meta_model_0-4w.pt,
-      ]
+      # checkpoint files point to the quantized model
+      checkpoint_files: [
+        consolidated-4w.pt,
+      ]
 
-      output_dir: <checkpoint_dir>
-      model_type: LLAMA3
+      output_dir: <checkpoint_dir>
+      model_type: LLAMA3
 
     # we also need to update the quantizer to what was used during
     # quantization
     quantizer:
-      _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
-      groupsize: 256
+      _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
+      groupsize: 256
 
 Let's re-run generation!