|
30 | 30 | }, |
31 | 31 | { |
32 | 32 | "cell_type": "code", |
33 | | - "source": [ |
34 | | - "!pip install google-metrax clu" |
35 | | - ], |
| 33 | + "execution_count": null, |
36 | 34 | "metadata": { |
37 | 35 | "id": "t-eEZKa8qHy7" |
38 | 36 | }, |
39 | | - "execution_count": null, |
40 | | - "outputs": [] |
| 37 | + "outputs": [], |
| 38 | + "source": [ |
| 39 | + "!pip install google-metrax" |
| 40 | + ] |
41 | 41 | }, |
42 | 42 | { |
43 | 43 | "cell_type": "code", |
|
135 | 135 | }, |
136 | 136 | { |
137 | 137 | "cell_type": "markdown", |
| 138 | + "metadata": { |
| 139 | + "id": "aZiutTU3qguP" |
| 140 | + }, |
138 | 141 | "source": [ |
139 | 142 | "## Lifecycle of a Metrax Metric\n", |
140 | 143 | "\n", |
|
144 | 147 | "2. **Iteration/Batch Processing:** As you process data in batches or on different devices, you create a new metric state for the current batch/device using `Metric.from_model_output()` (functional API) or update the existing metric object with `.update()` (object-oriented API), passing the predictions, labels, and any relevant weights for that specific data slice.\n", |
145 | 148 | "3. **Merging/Updating:** For the functional API, you merge the newly created metric state for the current batch/device with the accumulated state in your dictionary using the `.merge()` method. For the object-oriented API, the `.update()` method directly modifies the state within the metric object in your dictionary. This step accumulates the necessary statistics across all processed data.\n", |
146 | 149 | "4. **Final Computation:** After processing all data, you call the `.compute()` method on the final, merged (functional API) or updated (object-oriented API) metric state in your dictionary. This performs the final calculations and returns the metric's value (e.g., a single floating-point number like AUCPR)." |
147 | | - ], |
148 | | - "metadata": { |
149 | | - "id": "aZiutTU3qguP" |
150 | | - } |
| 150 | + ] |
151 | 151 | }, |
152 | 152 | { |
153 | 153 | "cell_type": "markdown", |
|
171 | 171 | }, |
172 | 172 | { |
173 | 173 | "cell_type": "markdown", |
| 174 | + "metadata": { |
| 175 | + "id": "TOT_PXRDv80K" |
| 176 | + }, |
174 | 177 | "source": [ |
175 | 178 | "### Basic Usage: Unweighted Metrics\n", |
176 | 179 | "This first example demonstrates the core functional workflow without sample weights." |
177 | | - ], |
178 | | - "metadata": { |
179 | | - "id": "TOT_PXRDv80K" |
180 | | - } |
| 180 | + ] |
181 | 181 | }, |
182 | 182 | { |
183 | 183 | "cell_type": "code", |
| 184 | + "execution_count": null, |
| 185 | + "metadata": { |
| 186 | + "id": "HFHRQRmpxmys" |
| 187 | + }, |
| 188 | + "outputs": [], |
184 | 189 | "source": [ |
185 | 190 | "import metrax\n", |
186 | 191 | "\n", |
|
191 | 196 | " 'AUCPR': metrax.AUCPR,\n", |
192 | 197 | " 'AUCROC': metrax.AUCROC,\n", |
193 | 198 | "}" |
194 | | - ], |
195 | | - "metadata": { |
196 | | - "id": "HFHRQRmpxmys" |
197 | | - }, |
198 | | - "execution_count": null, |
199 | | - "outputs": [] |
| 199 | + ] |
200 | 200 | }, |
201 | 201 | { |
202 | 202 | "cell_type": "code", |
|
220 | 220 | }, |
221 | 221 | { |
222 | 222 | "cell_type": "code", |
| 223 | + "execution_count": null, |
| 224 | + "metadata": { |
| 225 | + "id": "sYZe0JUQxiNt" |
| 226 | + }, |
| 227 | + "outputs": [], |
223 | 228 | "source": [ |
224 | 229 | "# --- Method 2: Iterative Merging by Batch (Unweighted) ---\n", |
225 | 230 | "print(\"\\n--- Method 2: Iterative Merging (Unweighted) ---\")\n", |
|
239 | 244 | "for name, metric_state in iterative_metrics.items():\n", |
240 | 245 | " iterative_results[name] = metric_state.compute()\n", |
241 | 246 | " print(f\"{name}: {iterative_results[name]}\")" |
242 | | - ], |
243 | | - "metadata": { |
244 | | - "id": "sYZe0JUQxiNt" |
245 | | - }, |
246 | | - "execution_count": null, |
247 | | - "outputs": [] |
| 247 | + ] |
248 | 248 | }, |
249 | 249 | { |
250 | 250 | "cell_type": "markdown", |
| 251 | + "metadata": { |
| 252 | + "id": "B_HJqaJKv_73" |
| 253 | + }, |
251 | 254 | "source": [ |
252 | 255 | "### Advanced Usage: Incorporating Sample Weights\n", |
253 | 256 | "Just like the NNX API, the functional API supports sample_weights in from_model_output where applicable. The following example calculates AUCPR and AUCROC using the weighted data we prepared earlier." |
254 | | - ], |
255 | | - "metadata": { |
256 | | - "id": "B_HJqaJKv_73" |
257 | | - } |
| 257 | + ] |
258 | 258 | }, |
259 | 259 | { |
260 | 260 | "cell_type": "code", |
| 261 | + "execution_count": null, |
| 262 | + "metadata": { |
| 263 | + "id": "czPiI6dhxvmL" |
| 264 | + }, |
| 265 | + "outputs": [], |
261 | 266 | "source": [ |
262 | 267 | "import metrax\n", |
263 | 268 | "\n", |
|
266 | 271 | " 'AUCPR': metrax.AUCPR,\n", |
267 | 272 | " 'AUCROC': metrax.AUCROC,\n", |
268 | 273 | "}" |
269 | | - ], |
270 | | - "metadata": { |
271 | | - "id": "czPiI6dhxvmL" |
272 | | - }, |
273 | | - "execution_count": null, |
274 | | - "outputs": [] |
| 274 | + ] |
275 | 275 | }, |
276 | 276 | { |
277 | 277 | "cell_type": "code", |
| 278 | + "execution_count": null, |
| 279 | + "metadata": { |
| 280 | + "id": "MSnndNonwDKq" |
| 281 | + }, |
| 282 | + "outputs": [], |
278 | 283 | "source": [ |
279 | 284 | "# --- Method 1: Full-Batch Calculation (Weighted) ---\n", |
280 | 285 | "print(\"--- Method 1: Full-Batch Calculation (Weighted) ---\")\n", |
|
287 | 292 | " )\n", |
288 | 293 | " full_batch_results_weighted[name] = metric_state.compute()\n", |
289 | 294 | " print(f\"{name}: {full_batch_results_weighted[name]}\")" |
290 | | - ], |
291 | | - "metadata": { |
292 | | - "id": "MSnndNonwDKq" |
293 | | - }, |
294 | | - "execution_count": null, |
295 | | - "outputs": [] |
| 295 | + ] |
296 | 296 | }, |
297 | 297 | { |
298 | 298 | "cell_type": "code", |
| 299 | + "execution_count": null, |
| 300 | + "metadata": { |
| 301 | + "id": "e8BPLW6XxxVr" |
| 302 | + }, |
| 303 | + "outputs": [], |
299 | 304 | "source": [ |
300 | 305 | "# --- Method 2: Iterative Merging by Batch (Weighted) ---\n", |
301 | 306 | "print(\"\\n--- Method 2: Iterative Merging (Weighted) ---\")\n", |
|
315 | 320 | "for name, metric_state in iterative_metrics_weighted.items():\n", |
316 | 321 | " iterative_results_weighted[name] = metric_state.compute()\n", |
317 | 322 | " print(f\"{name}: {iterative_results_weighted[name]}\")" |
318 | | - ], |
319 | | - "metadata": { |
320 | | - "id": "e8BPLW6XxxVr" |
321 | | - }, |
322 | | - "execution_count": null, |
323 | | - "outputs": [] |
| 323 | + ] |
324 | 324 | }, |
325 | 325 | { |
326 | 326 | "cell_type": "markdown", |
|
341 | 341 | }, |
342 | 342 | { |
343 | 343 | "cell_type": "markdown", |
| 344 | + "metadata": { |
| 345 | + "id": "TokxcQvTvWua" |
| 346 | + }, |
344 | 347 | "source": [ |
345 | 348 | "### Basic Usage: Unweighted Metrics\n", |
346 | 349 | "Let's start with the simplest use case: calculating metrics without any sample weights. This example demonstrates the core object-oriented workflow. Note that for Precision and Recall, metrax uses a default classification threshold of 0.5.\n", |
347 | 350 | "\n" |
348 | | - ], |
349 | | - "metadata": { |
350 | | - "id": "TokxcQvTvWua" |
351 | | - } |
| 351 | + ] |
352 | 352 | }, |
353 | 353 | { |
354 | 354 | "cell_type": "code", |
| 355 | + "execution_count": null, |
| 356 | + "metadata": { |
| 357 | + "id": "g0ULds_Sx06D" |
| 358 | + }, |
| 359 | + "outputs": [], |
355 | 360 | "source": [ |
356 | 361 | "import metrax.nnx\n", |
357 | 362 | "\n", |
|
362 | 367 | " 'AUCPR': metrax.nnx.AUCPR,\n", |
363 | 368 | " 'AUCROC': metrax.nnx.AUCROC,\n", |
364 | 369 | "}" |
365 | | - ], |
366 | | - "metadata": { |
367 | | - "id": "g0ULds_Sx06D" |
368 | | - }, |
369 | | - "execution_count": null, |
370 | | - "outputs": [] |
| 370 | + ] |
371 | 371 | }, |
372 | 372 | { |
373 | 373 | "cell_type": "code", |
|
395 | 395 | }, |
396 | 396 | { |
397 | 397 | "cell_type": "code", |
| 398 | + "execution_count": null, |
| 399 | + "metadata": { |
| 400 | + "id": "2Hz14HRfx5vM" |
| 401 | + }, |
| 402 | + "outputs": [], |
398 | 403 | "source": [ |
399 | 404 | "# --- Method 2: Iterative Updating by Batch (nnx) ---\n", |
400 | 405 | "print(\"\\n--- Method 2: Iterative Updating with nnx (Unweighted) ---\")\n", |
|
410 | 415 | "for name, metric_obj in iterative_metrics_nnx.items():\n", |
411 | 416 | " iterative_results_nnx[name] = metric_obj.compute()\n", |
412 | 417 | " print(f\"{name}: {iterative_results_nnx[name]}\")" |
413 | | - ], |
414 | | - "metadata": { |
415 | | - "id": "2Hz14HRfx5vM" |
416 | | - }, |
417 | | - "execution_count": null, |
418 | | - "outputs": [] |
| 418 | + ] |
419 | 419 | }, |
420 | 420 | { |
421 | 421 | "cell_type": "markdown", |
| 422 | + "metadata": { |
| 423 | + "id": "mhq1dM5Rvq-Q" |
| 424 | + }, |
422 | 425 | "source": [ |
423 | 426 | "### Advanced Usage: Incorporating Sample Weights\n", |
424 | 427 | "In many real-world scenarios, you'll want to assign different importance to different examples. This is often done to handle class imbalance, where you might give more weight to examples from a rare class. metrax.nnx supports this through the sample_weights argument in the .update() method.\n", |
425 | 428 | "\n", |
426 | 429 | "The following example calculates AUCPR and AUCROC, which are metrics that support sample weights, using the weighted data we prepared earlier." |
427 | | - ], |
428 | | - "metadata": { |
429 | | - "id": "mhq1dM5Rvq-Q" |
430 | | - } |
| 430 | + ] |
431 | 431 | }, |
432 | 432 | { |
433 | 433 | "cell_type": "code", |
| 434 | + "execution_count": null, |
| 435 | + "metadata": { |
| 436 | + "id": "mB2z33Jax8md" |
| 437 | + }, |
| 438 | + "outputs": [], |
434 | 439 | "source": [ |
435 | 440 | "import metrax.nnx\n", |
436 | 441 | "\n", |
|
439 | 444 | " 'AUCPR': metrax.nnx.AUCPR,\n", |
440 | 445 | " 'AUCROC': metrax.nnx.AUCROC,\n", |
441 | 446 | "}" |
442 | | - ], |
443 | | - "metadata": { |
444 | | - "id": "mB2z33Jax8md" |
445 | | - }, |
446 | | - "execution_count": null, |
447 | | - "outputs": [] |
| 447 | + ] |
448 | 448 | }, |
449 | 449 | { |
450 | 450 | "cell_type": "code", |
| 451 | + "execution_count": null, |
| 452 | + "metadata": { |
| 453 | + "id": "HERtwSZbvs6-" |
| 454 | + }, |
| 455 | + "outputs": [], |
451 | 456 | "source": [ |
452 | 457 | "# --- Method 1: Full-Batch Calculation with Sample Weights ---\n", |
453 | 458 | "print(\"--- Method 1: Full-Batch Calculation with nnx (Weighted) ---\")\n", |
|
467 | 472 | "for name, metric_obj in full_batch_metrics_weighted.items():\n", |
468 | 473 | " full_batch_results_weighted[name] = metric_obj.compute()\n", |
469 | 474 | " print(f\"{name}: {full_batch_results_weighted[name]}\")" |
470 | | - ], |
471 | | - "metadata": { |
472 | | - "id": "HERtwSZbvs6-" |
473 | | - }, |
474 | | - "execution_count": null, |
475 | | - "outputs": [] |
| 475 | + ] |
476 | 476 | }, |
477 | 477 | { |
478 | 478 | "cell_type": "code", |
| 479 | + "execution_count": null, |
| 480 | + "metadata": { |
| 481 | + "id": "rTAmdmyHx-wi" |
| 482 | + }, |
| 483 | + "outputs": [], |
479 | 484 | "source": [ |
480 | 485 | "# --- Method 2: Iterative Updating with Sample Weights ---\n", |
481 | 486 | "print(\"\\n--- Method 2: Iterative Updating with nnx (Weighted) ---\")\n", |
|
495 | 500 | "for name, metric_obj in iterative_metrics_weighted.items():\n", |
496 | 501 | " iterative_results_weighted[name] = metric_obj.compute()\n", |
497 | 502 | " print(f\"{name}: {iterative_results_weighted[name]}\")" |
498 | | - ], |
499 | | - "metadata": { |
500 | | - "id": "rTAmdmyHx-wi" |
501 | | - }, |
502 | | - "execution_count": null, |
503 | | - "outputs": [] |
| 503 | + ] |
504 | 504 | }, |
505 | 505 | { |
506 | 506 | "cell_type": "markdown", |
| 507 | + "metadata": { |
| 508 | + "id": "5_MHjDaRzwBf" |
| 509 | + }, |
507 | 510 | "source": [ |
508 | 511 | "## Scaling to Multiple Devices\n", |
509 | 512 | "\n", |
|
521 | 524 | "4. **`jit`-Compile the Function**: You write a function that looks like a normal, single-device calculation and decorate it with `@jax.jit`. When JAX's compiler sees that the inputs to this function are sharded arrays, it automatically generates a distributed version of the code, implicitly handling all cross-device communication.\n", |
522 | 525 | "\n", |
523 | 526 | "This method provides the fine-grained control that is essential for all modern JAX parallelism patterns." |
524 | | - ], |
525 | | - "metadata": { |
526 | | - "id": "5_MHjDaRzwBf" |
527 | | - } |
| 527 | + ] |
528 | 528 | }, |
529 | 529 | { |
530 | 530 | "cell_type": "code", |
| 531 | + "execution_count": null, |
| 532 | + "metadata": { |
| 533 | + "id": "QkVcKJRPAbRE" |
| 534 | + }, |
| 535 | + "outputs": [], |
531 | 536 | "source": [ |
532 | 537 | "import jax\n", |
533 | 538 | "import numpy as np\n", |
|
547 | 552 | " labels=labels,\n", |
548 | 553 | " sample_weights=sample_weights\n", |
549 | 554 | " )" |
550 | | - ], |
551 | | - "metadata": { |
552 | | - "id": "QkVcKJRPAbRE" |
553 | | - }, |
554 | | - "execution_count": null, |
555 | | - "outputs": [] |
| 555 | + ] |
556 | 556 | }, |
557 | 557 | { |
558 | 558 | "cell_type": "code", |
| 559 | + "execution_count": null, |
| 560 | + "metadata": { |
| 561 | + "id": "VGsGxzR3Afsv" |
| 562 | + }, |
| 563 | + "outputs": [], |
559 | 564 | "source": [ |
560 | 565 | "# Advanced SPMD Parallelism: jit + Mesh\n", |
561 | 566 | "def calculate_aucpr_mesh(predictions, labels, sample_weights):\n", |
|
588 | 593 | "\n", |
589 | 594 | " # The result is already a globally correct metric state, replicated on all devices.\n", |
590 | 595 | " return jitted_calculate(sharded_predictions, sharded_labels, sharded_weights)" |
591 | | - ], |
592 | | - "metadata": { |
593 | | - "id": "VGsGxzR3Afsv" |
594 | | - }, |
595 | | - "execution_count": null, |
596 | | - "outputs": [] |
| 596 | + ] |
597 | 597 | }, |
598 | 598 | { |
599 | 599 | "cell_type": "code", |
|