|
38 | 38 | "!nvidia-smi" |
39 | 39 | ] |
40 | 40 | }, |
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "metadata": { |
| 44 | + "id": "svAMl1BxH_nC" |
| 45 | + }, |
| 46 | + "source": [ |
| 47 | + "## Installing Libraries: RESTART RUNTIME AFTER INSTALLATION\n", |
| 48 | + "\n", |
| 49 | + "The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with all dependencies preinstalled. To install the dependencies manually:\n", |
| 50 | + "\n", |
| 51 | + "```bash\n", |
| 52 | + "pip install 'jax[cuda13]' flax optax transformers datasets qwix huggingface_hub wandb\n", |
| 53 | + "```\n", |
| 54 | + "\n", |
| 55 | + "In addition to either setup (container or manual install), you will need Tunix:\n",
| 56 | + "\n", |
| 57 | + "```bash\n", |
| 58 | + "pip install git+https://github.com/google/tunix\n", |
| 59 | + "```" |
| 60 | + ] |
| 61 | + }, |
| 62 | + { |
| 63 | + "cell_type": "code", |
| 64 | + "execution_count": null, |
| 65 | + "metadata": { |
| 66 | + "id": "YIEvogq3Lvh2" |
| 67 | + }, |
| 68 | + "outputs": [], |
| 69 | + "source": [ |
| 70 | + "# Install required libraries if they are not already present\n",
| 71 | + "import importlib.util\n",
| 72 | + "\n", |
| 73 | + "if importlib.util.find_spec(\"tunix\") is None:\n", |
| 74 | + " print(\"Required packages not found. Running full installation...\")\n", |
| 75 | + " %pip install 'jax[cuda13]' flax optax transformers datasets qwix huggingface_hub wandb\n", |
| 76 | + " %pip install git+https://github.com/google/tunix" |
| 77 | + ] |
| 78 | + }, |
41 | 79 | { |
42 | 80 | "cell_type": "markdown", |
43 | 81 | "metadata": { |
|
118 | 156 | "source": [ |
119 | 157 | "# Prefer environment variable if already set\n", |
120 | 158 | "\n", |
121 | | - "from huggingface_hub.v1.hf_api import whoami\n", |
| 159 | + "from huggingface_hub import whoami\n", |
122 | 160 | "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n", |
123 | 161 | "\n", |
124 | 162 | "if HF_TOKEN:\n", |
|
160 | 198 | "- **Optax**: Gradient processing and optimization library for JAX\n", |
161 | 199 | "- **Transformers**: Hugging Face library for tokenizers and model configurations\n", |
162 | 200 | "- **Qwix**: Quantization and LoRA utilities for JAX models\n", |
163 | | - "- **Tunix**: Training utilities including `PeftTrainer` and `AutoModel` for streamlined fine-tuning\n", |
164 | | - "\n", |
165 | | - "The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with all dependencies preinstalled. To install the dependencies manually:\n", |
166 | | - "\n", |
167 | | - "```bash\n", |
168 | | - "pip install 'jax[cuda13]' flax optax transformers datasets qwix\n", |
169 | | - "```\n", |
170 | | - "\n", |
171 | | - "On top of the installation (either container or manual), you will need Tunix:\n", |
172 | | - "\n", |
173 | | - "```bash\n", |
174 | | - "pip install tunix\n", |
175 | | - "```" |
| 201 | + "- **Tunix**: Training utilities including `PeftTrainer` and `AutoModel` for streamlined fine-tuning" |
176 | 202 | ] |
177 | 203 | }, |
178 | 204 | { |
|
197 | 223 | "import qwix\n", |
198 | 224 | "from tunix.models.automodel import AutoModel\n", |
199 | 225 | "from tunix.sft import peft_trainer, metrics_logger\n", |
| 226 | + "import wandb\n", |
200 | 227 | "\n", |
201 | 228 | "print(f\"JAX {jax.__version__} | Devices: {jax.devices()}\")" |
202 | 229 | ] |
|
618 | 645 | }, |
619 | 646 | "outputs": [], |
620 | 647 | "source": [ |
| 648 | + "# Initialize a wandb run for experiment tracking\n",
| 649 | + "wandb.init()\n", |
| 650 | + "\n", |
621 | 651 | "# Quick inference test with the fine-tuned LoRA model\n", |
622 | 652 | "prompt = \"What is the capital of France?\"\n", |
623 | 653 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n", |
|