
Commit

Merge branch 'main' into change-987432
digantdesai authored Feb 25, 2025
2 parents 73c995b + 745be4e commit 8a45bfa
Showing 45 changed files with 536 additions and 273 deletions.
9 changes: 5 additions & 4 deletions .buckconfig
@@ -8,14 +8,15 @@
root = .
prelude = third-party/prelude
shim = shim
shim_et = shim_et

[repository_aliases]
config = prelude
ovr_config = prelude
toolchains = shim
fbcode = shim
fbcode_macros = shim
fbsource = shim
toolchains = shim_et
fbcode = shim_et
fbcode_macros = shim_et
fbsource = shim_et
buck = shim

[cxx]
3 changes: 3 additions & 0 deletions .gitmodules
@@ -70,3 +70,6 @@
[submodule "third-party/pocketfft"]
path = third-party/pocketfft
url = https://github.com/mreineck/pocketfft
[submodule "shim"]
path = shim
url = https://github.com/facebook/buck2-shims-meta
4 changes: 2 additions & 2 deletions build/Utils.cmake
@@ -206,9 +206,9 @@ function(extract_sources sources_file)

if(ANDROID_ABI)
if("${ANDROID_ABI}" STREQUAL "arm64-v8a")
set(target_platforms_arg "--target-platforms=shim//:android-arm64")
set(target_platforms_arg "--target-platforms=shim_et//:android-arm64")
elseif("${ANDROID_ABI}" STREQUAL "x86_64")
set(target_platforms_arg "--target-platforms=shim//:android-x86_64")
set(target_platforms_arg "--target-platforms=shim_et//:android-x86_64")
else()
message(
FATAL_ERROR
68 changes: 48 additions & 20 deletions examples/llm_pte_finetuning/README.md
@@ -6,21 +6,43 @@ In this tutorial, we show how to fine-tune an LLM using ExecuTorch.

You will need a model checkpoint in the Hugging Face format. For example:

```
git clone https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
```console
git clone https://huggingface.co/Qwen/Qwen2-0.5B-Instruct
```

You will need to install [torchtune](https://github.com/pytorch/torchtune) following [its installation instructions](https://github.com/pytorch/torchtune?tab=readme-ov-file#installation).
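If you just want a released build, the quickest route is usually a plain pip install; the linked installation instructions remain the authoritative reference (they also cover nightly and source installs):

```console
pip install torchtune
```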

You might run into an issue with the `triton` package when installing `torchtune`. You can build `triton` locally following the [instructions in their repo](https://github.com/triton-lang/triton?tab=readme-ov-file#install-from-source).
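If you do need to build `triton` from source, the flow has typically looked like the sketch below; the repository layout and prerequisites change over time, so defer to their README for the current commands:

```console
git clone https://github.com/triton-lang/triton.git
cd triton
pip install ninja cmake wheel
pip install -e python
```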

## Config Files

The directory structure of the `llm_pte_finetuning` example is:

```console
examples/llm_pte_finetuning
├── README.md
├── TARGETS
├── __init__.py
├── __pycache__
│   ├── model_loading_lib.cpython-312.pyc
│   └── training_lib.cpython-312.pyc
├── model_exporter.py
├── model_loading_lib.py
├── phi3_alpaca_code_config.yaml
├── phi3_config.yaml
├── qwen_05b_config.yaml
├── runner.py
└── training_lib.py
```

We already provide configs out of the box. The following sections explain how you can set up the config for your own model or dataset.

As mentioned in the previous section, we use `torchtune` APIs internally, so our config files follow `torchtune`'s structure. In the following sections we walk through a working example based on the `phi3_config.yaml` config file.
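For intuition, every `_component_` entry in these YAML files names a Python callable that `torchtune` imports and calls with the remaining keys as arguments. A minimal sketch of that pattern, assuming the config file above is on disk (the objects printed here are purely illustrative; the provided scripts do this for you):

```python
from omegaconf import OmegaConf
from torchtune import config as tt_config

# Load a torchtune-style YAML config.
cfg = OmegaConf.load("phi3_config.yaml")

# Each section with a `_component_` key becomes an object: the component
# path is imported and called with the section's other keys as kwargs.
tokenizer = tt_config.instantiate(cfg.tokenizer)
loss_fn = tt_config.instantiate(cfg.loss)
print(type(tokenizer), type(loss_fn))
```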

### Tokenizer

First, we define the tokenizer. Suppose we want to use the [Phi-3 Mini Instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) model from Microsoft; its tokenizer component is defined as follows:

```
```yaml
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
@@ -33,7 +55,7 @@ This will load the tokenizer and set the max sequence length to 1024. The class

In this example we use the [Alpaca-Cleaned dataset](https://huggingface.co/datasets/yahma/alpaca-cleaned). We need to define the following parameters:

```
```yaml
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  seed: null
@@ -47,7 +69,7 @@ Torchtune supports datasets using huggingface dataloaders, so custom datasets co

For the loss function, we use PyTorch losses. In this example we use the `CrossEntropyLoss`:

```
```yaml
loss:
  _component_: torch.nn.CrossEntropyLoss
```
@@ -56,7 +78,7 @@ loss:

Model parameters can also be configured; in this example we replicate the configuration used for the Phi-3 Mini Instruct benchmarks:

```
```yaml
model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
@@ -70,7 +92,7 @@ model:

Depending on how your model is defined, you will need to instantiate different components. In these examples we use checkpoints in the Hugging Face format, so we instantiate a `FullModelHFCheckpointer` object. We pass the checkpoint directory, the files containing the tensors, the output directory for training, and the model type:

```
```yaml
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
@@ -87,7 +109,7 @@ checkpointer:

Torchtune supports `cuda` and `bf16` tensors. However, for ExecuTorch training we only support `cpu` and `fp32`:

```
```yaml
device: cpu
dtype: fp32
```
@@ -101,28 +123,34 @@ The `model_exporter.py` exports the LLM checkpoint into an ExecuTorch checkpoint
* `cfg`: Configuration file
* `output_file`: The `.pte` output path

```
python model_exporter.py --cfg=phi3_config.yaml --output_file=phi3_mini_lora.pte
```console
python model_exporter.py \
--cfg=qwen_05b_config.yaml \
--output_file=qwen2_0_5B.pte
```
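Under the hood the exporter follows the standard ExecuTorch flow: capture the module with `torch.export`, lower it to the Edge dialect, and serialize an ExecuTorch program. A stripped-down sketch of that flow on a toy module (it deliberately omits the LoRA setup and the joint forward/backward graph that `model_exporter.py` builds):

```python
import torch
from torch.export import export
from executorch.exir import to_edge


class TinyModule(torch.nn.Module):
    """Toy stand-in for the fine-tuning model."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# 1. Capture the graph with torch.export.
exported = export(TinyModule(), (torch.randn(1, 8),))

# 2. Lower to the Edge dialect.
edge = to_edge(exported)

# 3. Convert to an ExecuTorch program and write the .pte file.
et_program = edge.to_executorch()
with open("tiny_module.pte", "wb") as f:
    f.write(et_program.buffer)
```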

### Step 2: Run the fine-tuning job

To run the fine-tuning job:

```
python runner.py --cfg=phi3_config.yaml --model_file=phi3_mini_lora.pte
```console
python runner.py \
--cfg=qwen_05b_config.yaml \
--model_file=qwen2_0_5B.pte \
--num_training_steps=10 \
--num_eval_steps=5
```
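For reference, the runner drives the exported program through the ExecuTorch pybindings. A rough sketch of loading a `.pte` from disk and invoking it (the file name, tensor shapes, dtypes, and input packing below are placeholders; the real inputs are the token/label batches that `runner.py` builds from the dataloader):

```python
import torch
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

with open("qwen2_0_5B.pte", "rb") as f:
    et_mod = _load_for_executorch_from_buffer(f.read())

# Placeholder batch; real shapes come from the tokenizer's max_seq_len.
tokens = torch.zeros((1, 128), dtype=torch.int64)
labels = torch.zeros((1, 128), dtype=torch.int64)
outputs = et_mod.forward((tokens, labels))  # assumed input packing
print(len(outputs))
```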

You need to use **the same** config file as in the previous step. The `model_file` arg is the `.pte` model produced by that step.

Example output:

```
Evaluating the model before training...
100%|██████████████████████████████████████████████████████████████████████████████████████| 3/3 [31:23<00:00, 627.98s/it]
Eval loss: tensor(2.3778)
100%|██████████████████████████████████████████████████████████████████████████████████████| 5/5 [52:29<00:00, 629.84s/it]
Losses: [2.7152762413024902, 0.7890686988830566, 2.249271869659424, 1.4777560234069824, 0.8378427624702454]
100%|██████████████████████████████████████████████████████████████████████████████████████| 3/3 [30:35<00:00, 611.90s/it]
Eval loss: tensor(0.8464)
```console
Evaluating the model before training
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:47<00:00, 9.45s/it]
Eval loss: tensor(0.9441)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:30<00:00, 9.09s/it]
Losses: [0.5646533966064453, 1.3464953899383545, 1.297974705696106, 1.2249481678009033, 0.6750457286834717, 0.7721152901649475, 1.0774847269058228, 0.7962403893470764, 0.8448256850242615, 0.8731598854064941]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:45<00:00, 9.18s/it]
Eval loss: tensor(0.7679)
```
5 changes: 3 additions & 2 deletions examples/llm_pte_finetuning/TARGETS
@@ -12,7 +12,7 @@ python_library(
"fbcode//caffe2:torch",
"fbcode//executorch/examples/llm_pte_finetuning:training_lib",
"fbcode//executorch/exir:lib",
"fbcode//executorch/extension/pybindings:aten_lib", # @manual For PTE loader
"fbcode//executorch/extension/pybindings:portable_lib", # @manual For PTE loader
"fbcode//pytorch/torchtune:lib",
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer
"fbsource//third-party/pypi/omegaconf:omegaconf",
@@ -27,11 +27,12 @@ python_library(
],
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/extension/pybindings:aten_lib", # @manual For PTE loader
"fbcode//executorch/extension/pybindings:portable_lib", # @manual For PTE loader
"fbcode//pytorch/torchtune:lib",
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer
"fbsource//third-party/pypi/tqdm:tqdm",
"fbcode//executorch/backends/xnnpack/partition:xnnpack_partitioner",
],
)

19 changes: 19 additions & 0 deletions examples/llm_pte_finetuning/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model_loading_lib import export_model_lora_training, load_checkpoint, setup_model
from .training_lib import eval_model, get_dataloader, TrainingModule, update_function

__all__ = [
"eval_model",
"get_dataloader",
"update_function",
"TrainingModule",
"export_model_lora_training",
"load_checkpoint",
"setup_model",
]
59 changes: 57 additions & 2 deletions examples/llm_pte_finetuning/model_loading_lib.py
@@ -9,8 +9,9 @@
from typing import Any, Dict, Tuple

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge
from executorch.exir import EdgeCompileConfig, to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
@@ -72,16 +73,70 @@ def export_model_lora_training(
    exported_graph: ExportedProgram = export(model, example_args, strict=False)
    print("Creating a joint forward-backwards graph for training")
    joint_graph = _export_forward_backward(exported_graph)
    ep = joint_graph

    # Currently there is no implementation of empty_permuted for edge dialect.
    # We manually make a pass to rewrite the empty_permuted to empty and permute.
    for node in ep.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.empty_permuted.out
        ):
            print("found empty_permute: ", node)
            empty_permuted_node = node
            with ep.graph.inserting_before(empty_permuted_node):
                empty_node = ep.graph.create_node(
                    "call_function",
                    torch.ops.aten.empty.memory_format,
                    (node.args[0],),
                    empty_permuted_node.kwargs,
                )
                permute_node = ep.graph.create_node(
                    "call_function",
                    torch.ops.aten.permute,
                    (empty_node, node.args[1]),
                )
                for user in empty_permuted_node.users.copy():
                    user.replace_input_with(empty_permuted_node, permute_node)
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.empty_permuted.default
        ):
            print("found empty_permute default: ", node)
            empty_permuted_node = node
            with ep.graph.inserting_before(empty_permuted_node):
                empty_node = ep.graph.create_node(
                    "call_function",
                    torch.ops.aten.empty.memory_format,
                    (node.args[0],),
                    empty_permuted_node.kwargs,
                )
                permute_node = ep.graph.create_node(
                    "call_function",
                    torch.ops.aten.permute.default,
                    (empty_node, node.args[1]),
                )
                for user in empty_permuted_node.users.copy():
                    user.replace_input_with(empty_permuted_node, permute_node)

    # 2. to_edge: Make optimizations for Edge devices.
    print("Lowering to edge dialect")
    edge_program = to_edge(joint_graph)
    edge_program = to_edge(
        joint_graph,
        compile_config=EdgeCompileConfig(
            _core_aten_ops_exception_list=[torch.ops.aten.empty_permuted.default]
        ),
    )

    print(edge_program._edge_programs["forward"].graph_module)

    # 3. to_executorch: Convert the graph to an ExecuTorch program.
    print("Exporting to executorch")
    edge_program = edge_program.to_backend(
        XnnpackPartitioner(force_fp32_dynamic_linear=True)
    )
    executorch_program = edge_program.to_executorch()

    print(executorch_program.exported_program().graph_signature)
    print(f"Saving to {output_file}")
    with open(output_file, "wb") as file:
2 changes: 1 addition & 1 deletion examples/llm_pte_finetuning/qwen_05b_config.yaml
@@ -27,7 +27,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Qwen2-0.5B-Instruct
output_dir: /tmp/qwen_0.5B_ft-output
model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False
25 changes: 19 additions & 6 deletions examples/llm_pte_finetuning/runner.py
@@ -15,7 +15,7 @@
update_function,
)

from executorch.extension.pybindings.aten_lib import ( # @manual
from executorch.extension.pybindings.portable_lib import ( # @manual
_load_for_executorch_from_buffer,
)
from omegaconf import OmegaConf
@@ -30,6 +30,18 @@
)
parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--model_file", type=str, help="Path to the ET model file.")
parser.add_argument(
    "--num_training_steps",
    type=int,
    help="Number of training steps, assuming 1 epoch.",
    default=100,
)
parser.add_argument(
    "--num_eval_steps",
    type=int,
    help="Number of eval steps, assuming 1 epoch.",
    default=5,
)


def main() -> None:
@@ -47,10 +59,11 @@ def main() -> None:
train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
val_dataloader = get_dataloader(cfg, val_set, tokenizer, loss_fn)
num_training_steps = args.num_training_steps
num_eval_steps = args.num_eval_steps

max_seq_len = cfg.tokenizer.max_seq_len
# Num of steps to run training. Assume 1 epoch
num_steps = 100
with open(file, "rb") as f:
model_bytes = f.read()
et_mod = _load_for_executorch_from_buffer(model_bytes)
@@ -62,7 +75,7 @@
dataloader=val_dataloader,
loss_fn=loss_fn,
max_seq_len=max_seq_len,
num_eval_steps=10,
num_eval_steps=num_eval_steps,
)
print("Eval loss: ", eval_loss)

@@ -74,9 +87,9 @@
learning_rate = 5e-3
f.seek(0)
losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=num_steps):
for i, batch in tqdm(enumerate(train_dataloader), total=num_training_steps):
# Run for a limited number of steps.
if i >= num_steps:
if i >= num_training_steps:
break
tokens, labels = batch["tokens"], batch["labels"]
token_size = tokens.shape[1]
@@ -113,7 +126,7 @@ def main() -> None:
dataloader=val_dataloader,
loss_fn=loss_fn,
max_seq_len=max_seq_len,
num_eval_steps=10,
num_eval_steps=num_eval_steps,
)
print("Eval loss: ", eval_loss)

2 changes: 1 addition & 1 deletion examples/llm_pte_finetuning/training_lib.py
@@ -10,7 +10,7 @@
from typing import Any

import torch
from executorch.extension.pybindings.aten_lib import ExecuTorchModule # @manual
from executorch.extension.pybindings.portable_lib import ExecuTorchModule # @manual

from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, DistributedSampler
