Skip to content

Commit 2fa50b7

Browse files
Evaan2001aaronmondal
authored andcommitted
Add files for finetuning LLMs on CPUs blog
1 parent ee0ecec commit 2fa50b7

7 files changed

Lines changed: 560 additions & 0 deletions

File tree

finetuning_on_cpu/.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.DS_Store # only for macs
2+
bazel-finetune-repo
3+
bazel-out
4+
bazel-bin
5+
bazel-testlogs
6+
.bazelrc
7+
.env
8+
MODULE.bazel.lock

finetuning_on_cpu/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
exports_files(
2+
["python_provider.sh"],
3+
visibility = ["//visibility:public"], # Makes it visible to all packages
4+
)

finetuning_on_cpu/MODULE.bazel

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
bazel_dep(name = "rules_python", version = "1.2.0")
2+
bazel_dep(name = "rules_shell", version = "0.4.1")
3+
4+
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
5+
python.toolchain(
6+
python_version = "3.13",
7+
)
8+
9+
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
10+
pip.parse(
11+
hub_name = "pypi",
12+
python_version = "3.13",
13+
requirements_lock = "//:bazel_requirements_lock.txt",
14+
)
15+
16+
use_repo(pip, "pypi")
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
2+
aiohappyeyeballs==2.6.1
3+
# via aiohttp
4+
aiohttp==3.11.14
5+
# via
6+
# datasets
7+
# fsspec
8+
aiosignal==1.3.2
9+
# via aiohttp
10+
async-timeout==5.0.1
11+
# via aiohttp
12+
attrs==25.3.0
13+
# via aiohttp
14+
certifi==2025.1.31
15+
# via requests
16+
charset-normalizer==3.4.1
17+
# via requests
18+
datasets==3.4.1
19+
# via -r requirements.txt
20+
dill==0.3.7
21+
# via
22+
# datasets
23+
# multiprocess
24+
filelock==3.18.0
25+
# via
26+
# huggingface-hub
27+
# torch
28+
# transformers
29+
frozenlist==1.5.0
30+
# via
31+
# aiohttp
32+
# aiosignal
33+
fsspec[http]==2024.12.0
34+
# via
35+
# datasets
36+
# huggingface-hub
37+
# torch
38+
huggingface-hub==0.29.3
39+
# via
40+
# datasets
41+
# tokenizers
42+
# transformers
43+
idna==3.10
44+
# via
45+
# requests
46+
# yarl
47+
jinja2==3.1.6
48+
# via torch
49+
markupsafe==3.0.2
50+
# via jinja2
51+
mpmath==1.3.0
52+
# via sympy
53+
multidict==6.2.0
54+
# via
55+
# aiohttp
56+
# yarl
57+
multiprocess==0.70.15
58+
# via datasets
59+
networkx==3.2.1
60+
# via torch
61+
numpy==1.26.4
62+
# via
63+
# datasets
64+
# pandas
65+
# transformers
66+
packaging==24.2
67+
# via
68+
# datasets
69+
# huggingface-hub
70+
# transformers
71+
pandas==2.2.3
72+
# via datasets
73+
propcache==0.3.0
74+
# via
75+
# aiohttp
76+
# yarl
77+
pyarrow==19.0.1
78+
# via datasets
79+
python-dateutil==2.9.0.post0
80+
# via pandas
81+
pytz==2025.2
82+
# via pandas
83+
pyyaml==6.0.2
84+
# via
85+
# datasets
86+
# huggingface-hub
87+
# transformers
88+
regex==2024.11.6
89+
# via transformers
90+
requests==2.32.3
91+
# via
92+
# datasets
93+
# huggingface-hub
94+
# transformers
95+
safetensors==0.5.3
96+
# via transformers
97+
six==1.17.0
98+
# via python-dateutil
99+
sympy==1.13.1
100+
# via torch
101+
tokenizers==0.21.1
102+
# via transformers
103+
setuptools==69.0.2
104+
# via torch
105+
torch==2.6.0
106+
# via -r requirements.txt
107+
tqdm==4.67.1
108+
# via
109+
# datasets
110+
# huggingface-hub
111+
# transformers
112+
transformers==4.50.1
113+
# via -r requirements.txt
114+
typing-extensions==4.12.2
115+
# via
116+
# huggingface-hub
117+
# multidict
118+
# torch
119+
tzdata==2025.2
120+
# via pandas
121+
urllib3==2.3.0
122+
# via requests
123+
xxhash==3.5.0
124+
# via datasets
125+
yarl==1.18.3
126+
# via aiohttp
127+
accelerate==1.5.2
128+
psutil==5.9.8
129+
# for accelerate
130+
anyio==4.9.0
131+
# for transformers
132+
distro==1.9.0
133+
# for transformers
134+
httpx==0.28.1
135+
# for transformers
136+
jiter==0.9.0
137+
# for transformers
138+
pydantic-core==2.33.0
139+
# for transformers
140+
pydantic==2.11.0
141+
# for transformers
142+
sniffio==1.3.1
143+
# for httpx
144+
httpcore==1.0.7
145+
# for httpx
146+
h11==0.14.0
147+
# for httpx
148+
annotated-types==0.7.0
149+
# for transformers
150+
typing-inspection==0.4.0
151+
# for transformers
152+
exceptiongroup==1.2.2
153+
# for transformers
154+
python-dotenv==1.1.0
155+
# for loading env variables
156+
scipy==1.15.2
157+
# for torch to get embeddings
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/usr/bin/env bash
2+
# Generic Python test wrapper that handles cross-architecture execution
3+
4+
# Get the directory where this script is located
5+
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
6+
7+
# First argument is the Python module or script to run
8+
PYTHON_SCRIPT="$1"
9+
# Remaining arguments are passed to the Python script
10+
shift
11+
12+
echo "Running Python test: $PYTHON_SCRIPT"
13+
echo "Python version:"
14+
python3 --version
15+
16+
# Check if the argument is a file or a module
17+
if [[ -f "$PYTHON_SCRIPT" ]]; then
18+
# Run as a script if it's a file
19+
echo "Running as script"
20+
python3 "$PYTHON_SCRIPT" "$@"
21+
else
22+
# Run as a module if it's not a file
23+
echo "Running as module"
24+
python3 -m "$PYTHON_SCRIPT" "$@"
25+
fi
26+
27+
# Capture and return the exit code
28+
exit_code=$?
29+
echo "Python test completed with exit code: $exit_code"
30+
exit $exit_code
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
load("@rules_python//python:py_binary.bzl", "py_binary")
2+
load("@rules_shell//shell:sh_test.bzl", "sh_test")
3+
4+
exports_files(
5+
["train_model.py"],
6+
visibility=["//visibility:public"], # Makes it visible to all packages
7+
)
8+
9+
py_binary(
10+
name="train_model",
11+
srcs=["train_model.py"],
12+
deps=[
13+
"@pypi//torch",
14+
"@pypi//transformers",
15+
"@pypi//datasets",
16+
"@pypi//accelerate",
17+
"@pypi//psutil",
18+
],
19+
visibility=["//visibility:public"],
20+
)
21+
22+
sh_test(
23+
name="training_test",
24+
srcs=["//:python_provider.sh"], # Use the python provider script from root
25+
args=["src.training.train_model"], # Pass the Python module path as an argument
26+
data=[ # The Python script as a data dependency
27+
"//src/training:train_model.py",
28+
],
29+
# Optional: Set appropriate test size/timeout
30+
size="large",
31+
)

0 commit comments

Comments
 (0)