Commit 6a9ef6f

Committed by: The tunix Authors
Merge pull request #1160 from rajasekharporeddy:qlora_llama3
PiperOrigin-RevId: 877423188
2 parents a3389dc + fdef654, commit 6a9ef6f

File tree: 1 file changed (+44 −14 lines)

examples/qlora_llama3_gpu.ipynb

Lines changed: 44 additions & 14 deletions
@@ -38,6 +38,44 @@
 "!nvidia-smi"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "svAMl1BxH_nC"
+},
+"source": [
+"## Installing Libraries: RESTART RUNTIME AFTER INSTALLATION\n",
+"\n",
+"The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with all dependencies preinstalled. To install the dependencies manually:\n",
+"\n",
+"```bash\n",
+"pip install 'jax[cuda13]' flax optax transformers datasets qwix huggingface_hub wandb\n",
+"```\n",
+"\n",
+"On top of the installation (either container or manual), you will need Tunix:\n",
+"\n",
+"```bash\n",
+"pip install git+https://github.com/google/tunix\n",
+"```"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "YIEvogq3Lvh2"
+},
+"outputs": [],
+"source": [
+"# Install necessary libraries\n",
+"import importlib\n",
+"\n",
+"if importlib.util.find_spec(\"tunix\") is None:\n",
+" print(\"Required packages not found. Running full installation...\")\n",
+" %pip install 'jax[cuda13]' flax optax transformers datasets qwix huggingface_hub wandb\n",
+" %pip install git+https://github.com/google/tunix"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
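
The install cell added in this hunk keys its behavior off `importlib.util.find_spec`. One caveat: the cell only runs `import importlib`, and `importlib.util` is not guaranteed to be reachable after that import alone, so spelling out `import importlib.util` is safer. A minimal standalone sketch of the check (the `needs_install` helper name is ours, not from the notebook):

```python
import importlib.util


def needs_install(module_name: str) -> bool:
    """True when the module cannot be found on the current sys.path."""
    return importlib.util.find_spec(module_name) is None


print(needs_install("json"))            # stdlib module, always present -> False
print(needs_install("no_such_pkg_xyz")) # unknown top-level module -> True
```

Note that `find_spec` only probes for the module; it does not import it, which keeps the guard cheap even for heavy packages.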
@@ -118,7 +156,7 @@
 "source": [
 "# Prefer environment variable if already set\n",
 "\n",
-"from huggingface_hub.v1.hf_api import whoami\n",
+"from huggingface_hub import whoami\n",
 "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
 "\n",
 "if HF_TOKEN:\n",
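
The corrected line imports `whoami` from the package's top-level namespace, which is the documented `huggingface_hub` API. The cell's "prefer environment variable" pattern can be sketched as below; `resolve_hf_token` and the fallback callable are illustrative, not from the notebook:

```python
import os


def resolve_hf_token(prompt_fn=None):
    """Return HF_TOKEN from the environment if set, else fall back to prompt_fn.

    prompt_fn stands in for an interactive prompt (e.g. getpass.getpass);
    it is a hypothetical hook, not part of the notebook.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    return prompt_fn() if prompt_fn else None


os.environ["HF_TOKEN"] = "hf_dummy_token"  # placeholder value for illustration
print(resolve_hf_token())  # -> hf_dummy_token
```

Reading the token from the environment first keeps credentials out of the notebook itself and lets the same cell run unattended in CI.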
@@ -160,19 +198,7 @@
160198
"- **Optax**: Gradient processing and optimization library for JAX\n",
161199
"- **Transformers**: Hugging Face library for tokenizers and model configurations\n",
162200
"- **Qwix**: Quantization and LoRA utilities for JAX models\n",
163-
"- **Tunix**: Training utilities including `PeftTrainer` and `AutoModel` for streamlined fine-tuning\n",
164-
"\n",
165-
"The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with all dependencies preinstalled. To install the dependencies manually:\n",
166-
"\n",
167-
"```bash\n",
168-
"pip install 'jax[cuda13]' flax optax transformers datasets qwix\n",
169-
"```\n",
170-
"\n",
171-
"On top of the installation (either container or manual), you will need Tunix:\n",
172-
"\n",
173-
"```bash\n",
174-
"pip install tunix\n",
175-
"```"
201+
"- **Tunix**: Training utilities including `PeftTrainer` and `AutoModel` for streamlined fine-tuning"
176202
]
177203
},
178204
{
@@ -197,6 +223,7 @@
 "import qwix\n",
 "from tunix.models.automodel import AutoModel\n",
 "from tunix.sft import peft_trainer, metrics_logger\n",
+"import wandb\n",
 "\n",
 "print(f\"JAX {jax.__version__} | Devices: {jax.devices()}\")"
 ]
@@ -618,6 +645,9 @@
 },
 "outputs": [],
 "source": [
+"#Initialize wandb\n",
+"wandb.init()\n",
+"\n",
 "# Quick inference test with the fine-tuned LoRA model\n",
 "prompt = \"What is the capital of France?\"\n",
 "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
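
A bare `wandb.init()` like the one added here will prompt for a login when no W&B credentials are configured. A variant that degrades gracefully is sketched below; the `init_tracking` helper, the project name, and the `mode="disabled"` default are our illustrative choices, not from the notebook:

```python
import importlib.util


def init_tracking(project="qlora-llama3-demo", mode="disabled"):
    """Start a W&B run if wandb is importable; otherwise return None.

    mode="disabled" makes wandb.init a no-op run that needs no login or
    network access; pass mode="online" for real experiment tracking.
    """
    if importlib.util.find_spec("wandb") is None:
        print("wandb not installed; skipping experiment tracking")
        return None
    import wandb
    return wandb.init(project=project, mode=mode)


run = init_tracking()
if run is not None:
    run.finish()  # close the run cleanly
```

Keeping the guard in one helper means the rest of the notebook can log unconditionally through the returned run object (or skip logging when it is None).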
