Skip to content

Commit e965d15

Browse files
committed
release google-metrax 0.1.4 to pypi
1 parent 9628089 commit e965d15

File tree

2 files changed

+102
-102
lines changed

2 files changed

+102
-102
lines changed

metrax_example.ipynb

Lines changed: 101 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030
},
3131
{
3232
"cell_type": "code",
33-
"source": [
34-
"!pip install google-metrax clu"
35-
],
33+
"execution_count": null,
3634
"metadata": {
3735
"id": "t-eEZKa8qHy7"
3836
},
39-
"execution_count": null,
40-
"outputs": []
37+
"outputs": [],
38+
"source": [
39+
"!pip install google-metrax"
40+
]
4141
},
4242
{
4343
"cell_type": "code",
@@ -135,6 +135,9 @@
135135
},
136136
{
137137
"cell_type": "markdown",
138+
"metadata": {
139+
"id": "aZiutTU3qguP"
140+
},
138141
"source": [
139142
"## Lifecycle of a Metrax Metric\n",
140143
"\n",
@@ -144,10 +147,7 @@
144147
"2. **Iteration/Batch Processing:** As you process data in batches or on different devices, you create a new metric state for the current batch/device using `Metric.from_model_output()` (functional API) or update the existing metric object with `.update()` (object-oriented API), passing the predictions, labels, and any relevant weights for that specific data slice.\n",
145148
"3. **Merging/Updating:** For the functional API, you merge the newly created metric state for the current batch/device with the accumulated state in your dictionary using the `.merge()` method. For the object-oriented API, the `.update()` method directly modifies the state within the metric object in your dictionary. This step accumulates the necessary statistics across all processed data.\n",
146149
"4. **Final Computation:** After processing all data, you call the `.compute()` method on the final, merged (functional API) or updated (object-oriented API) metric state in your dictionary. This performs the final calculations and returns the metric's value (e.g., a single floating-point number like AUCPR)."
147-
],
148-
"metadata": {
149-
"id": "aZiutTU3qguP"
150-
}
150+
]
151151
},
152152
{
153153
"cell_type": "markdown",
@@ -171,16 +171,21 @@
171171
},
172172
{
173173
"cell_type": "markdown",
174+
"metadata": {
175+
"id": "TOT_PXRDv80K"
176+
},
174177
"source": [
175178
"###Basic Usage: Unweighted Metrics\n",
176179
"This first example demonstrates the core functional workflow without sample weights."
177-
],
178-
"metadata": {
179-
"id": "TOT_PXRDv80K"
180-
}
180+
]
181181
},
182182
{
183183
"cell_type": "code",
184+
"execution_count": null,
185+
"metadata": {
186+
"id": "HFHRQRmpxmys"
187+
},
188+
"outputs": [],
184189
"source": [
185190
"import metrax\n",
186191
"\n",
@@ -191,12 +196,7 @@
191196
" 'AUCPR': metrax.AUCPR,\n",
192197
" 'AUCROC': metrax.AUCROC,\n",
193198
"}"
194-
],
195-
"metadata": {
196-
"id": "HFHRQRmpxmys"
197-
},
198-
"execution_count": null,
199-
"outputs": []
199+
]
200200
},
201201
{
202202
"cell_type": "code",
@@ -220,6 +220,11 @@
220220
},
221221
{
222222
"cell_type": "code",
223+
"execution_count": null,
224+
"metadata": {
225+
"id": "sYZe0JUQxiNt"
226+
},
227+
"outputs": [],
223228
"source": [
224229
"# --- Method 2: Iterative Merging by Batch (Unweighted) ---\n",
225230
"print(\"\\n--- Method 2: Iterative Merging (Unweighted) ---\")\n",
@@ -239,25 +244,25 @@
239244
"for name, metric_state in iterative_metrics.items():\n",
240245
" iterative_results[name] = metric_state.compute()\n",
241246
" print(f\"{name}: {iterative_results[name]}\")"
242-
],
243-
"metadata": {
244-
"id": "sYZe0JUQxiNt"
245-
},
246-
"execution_count": null,
247-
"outputs": []
247+
]
248248
},
249249
{
250250
"cell_type": "markdown",
251+
"metadata": {
252+
"id": "B_HJqaJKv_73"
253+
},
251254
"source": [
252255
"###Advanced Usage: Incorporating Sample Weights\n",
253256
"Just like the NNX API, the functional API supports sample_weights in from_model_output for metrics where it is applicable. The following example calculates AUCPR and AUCROC using the weighted data we prepared earlier."
254-
],
255-
"metadata": {
256-
"id": "B_HJqaJKv_73"
257-
}
257+
]
258258
},
259259
{
260260
"cell_type": "code",
261+
"execution_count": null,
262+
"metadata": {
263+
"id": "czPiI6dhxvmL"
264+
},
265+
"outputs": [],
261266
"source": [
262267
"import metrax\n",
263268
"\n",
@@ -266,15 +271,15 @@
266271
" 'AUCPR': metrax.AUCPR,\n",
267272
" 'AUCROC': metrax.AUCROC,\n",
268273
"}"
269-
],
270-
"metadata": {
271-
"id": "czPiI6dhxvmL"
272-
},
273-
"execution_count": null,
274-
"outputs": []
274+
]
275275
},
276276
{
277277
"cell_type": "code",
278+
"execution_count": null,
279+
"metadata": {
280+
"id": "MSnndNonwDKq"
281+
},
282+
"outputs": [],
278283
"source": [
279284
"# --- Method 1: Full-Batch Calculation (Weighted) ---\n",
280285
"print(\"--- Method 1: Full-Batch Calculation (Weighted) ---\")\n",
@@ -287,15 +292,15 @@
287292
" )\n",
288293
" full_batch_results_weighted[name] = metric_state.compute()\n",
289294
" print(f\"{name}: {full_batch_results_weighted[name]}\")"
290-
],
291-
"metadata": {
292-
"id": "MSnndNonwDKq"
293-
},
294-
"execution_count": null,
295-
"outputs": []
295+
]
296296
},
297297
{
298298
"cell_type": "code",
299+
"execution_count": null,
300+
"metadata": {
301+
"id": "e8BPLW6XxxVr"
302+
},
303+
"outputs": [],
299304
"source": [
300305
"# --- Method 2: Iterative Merging by Batch (Weighted) ---\n",
301306
"print(\"\\n--- Method 2: Iterative Merging (Weighted) ---\")\n",
@@ -315,12 +320,7 @@
315320
"for name, metric_state in iterative_metrics_weighted.items():\n",
316321
" iterative_results_weighted[name] = metric_state.compute()\n",
317322
" print(f\"{name}: {iterative_results_weighted[name]}\")"
318-
],
319-
"metadata": {
320-
"id": "e8BPLW6XxxVr"
321-
},
322-
"execution_count": null,
323-
"outputs": []
323+
]
324324
},
325325
{
326326
"cell_type": "markdown",
@@ -341,17 +341,22 @@
341341
},
342342
{
343343
"cell_type": "markdown",
344+
"metadata": {
345+
"id": "TokxcQvTvWua"
346+
},
344347
"source": [
345348
"###Basic Usage: Unweighted Metrics\n",
346349
"Let's start with the simplest use case: calculating metrics without any sample weights. This example demonstrates the core object-oriented workflow. Note that for Precision and Recall, metrax uses a default classification threshold of 0.5.\n",
347350
"\n"
348-
],
349-
"metadata": {
350-
"id": "TokxcQvTvWua"
351-
}
351+
]
352352
},
353353
{
354354
"cell_type": "code",
355+
"execution_count": null,
356+
"metadata": {
357+
"id": "g0ULds_Sx06D"
358+
},
359+
"outputs": [],
355360
"source": [
356361
"import metrax.nnx\n",
357362
"\n",
@@ -362,12 +367,7 @@
362367
" 'AUCPR': metrax.nnx.AUCPR,\n",
363368
" 'AUCROC': metrax.nnx.AUCROC,\n",
364369
"}"
365-
],
366-
"metadata": {
367-
"id": "g0ULds_Sx06D"
368-
},
369-
"execution_count": null,
370-
"outputs": []
370+
]
371371
},
372372
{
373373
"cell_type": "code",
@@ -395,6 +395,11 @@
395395
},
396396
{
397397
"cell_type": "code",
398+
"execution_count": null,
399+
"metadata": {
400+
"id": "2Hz14HRfx5vM"
401+
},
402+
"outputs": [],
398403
"source": [
399404
"# --- Method 2: Iterative Updating by Batch (nnx) ---\n",
400405
"print(\"\\n--- Method 2: Iterative Updating with nnx (Unweighted) ---\")\n",
@@ -410,27 +415,27 @@
410415
"for name, metric_obj in iterative_metrics_nnx.items():\n",
411416
" iterative_results_nnx[name] = metric_obj.compute()\n",
412417
" print(f\"{name}: {iterative_results_nnx[name]}\")"
413-
],
414-
"metadata": {
415-
"id": "2Hz14HRfx5vM"
416-
},
417-
"execution_count": null,
418-
"outputs": []
418+
]
419419
},
420420
{
421421
"cell_type": "markdown",
422+
"metadata": {
423+
"id": "mhq1dM5Rvq-Q"
424+
},
422425
"source": [
423426
"###Advanced Usage: Incorporating Sample Weights\n",
424427
"In many real-world scenarios, you'll want to assign different importance to different examples. This is often done to handle class imbalance, where you might give more weight to examples from a rare class. metrax.nnx supports this through the sample_weights argument in the .update() method.\n",
425428
"\n",
426429
"The following example calculates AUCPR and AUCROC, which are metrics that support sample weights, using the weighted data we prepared earlier"
427-
],
428-
"metadata": {
429-
"id": "mhq1dM5Rvq-Q"
430-
}
430+
]
431431
},
432432
{
433433
"cell_type": "code",
434+
"execution_count": null,
435+
"metadata": {
436+
"id": "mB2z33Jax8md"
437+
},
438+
"outputs": [],
434439
"source": [
435440
"import metrax.nnx\n",
436441
"\n",
@@ -439,15 +444,15 @@
439444
" 'AUCPR': metrax.nnx.AUCPR,\n",
440445
" 'AUCROC': metrax.nnx.AUCROC,\n",
441446
"}"
442-
],
443-
"metadata": {
444-
"id": "mB2z33Jax8md"
445-
},
446-
"execution_count": null,
447-
"outputs": []
447+
]
448448
},
449449
{
450450
"cell_type": "code",
451+
"execution_count": null,
452+
"metadata": {
453+
"id": "HERtwSZbvs6-"
454+
},
455+
"outputs": [],
451456
"source": [
452457
"# --- Method 1: Full-Batch Calculation with Sample Weights ---\n",
453458
"print(\"--- Method 1: Full-Batch Calculation with nnx (Weighted) ---\")\n",
@@ -467,15 +472,15 @@
467472
"for name, metric_obj in full_batch_metrics_weighted.items():\n",
468473
" full_batch_results_weighted[name] = metric_obj.compute()\n",
469474
" print(f\"{name}: {full_batch_results_weighted[name]}\")"
470-
],
471-
"metadata": {
472-
"id": "HERtwSZbvs6-"
473-
},
474-
"execution_count": null,
475-
"outputs": []
475+
]
476476
},
477477
{
478478
"cell_type": "code",
479+
"execution_count": null,
480+
"metadata": {
481+
"id": "rTAmdmyHx-wi"
482+
},
483+
"outputs": [],
479484
"source": [
480485
"# --- Method 2: Iterative Updating with Sample Weights ---\n",
481486
"print(\"\\n--- Method 2: Iterative Updating with nnx (Weighted) ---\")\n",
@@ -495,15 +500,13 @@
495500
"for name, metric_obj in iterative_metrics_weighted.items():\n",
496501
" iterative_results_weighted[name] = metric_obj.compute()\n",
497502
" print(f\"{name}: {iterative_results_weighted[name]}\")"
498-
],
499-
"metadata": {
500-
"id": "rTAmdmyHx-wi"
501-
},
502-
"execution_count": null,
503-
"outputs": []
503+
]
504504
},
505505
{
506506
"cell_type": "markdown",
507+
"metadata": {
508+
"id": "5_MHjDaRzwBf"
509+
},
507510
"source": [
508511
"## Scaling to Multiple Devices\n",
509512
"\n",
@@ -521,13 +524,15 @@
521524
"4. **`jit`-Compile the Function**: You write a function that looks like a normal, single-device calculation and decorate it with `@jax.jit`. When JAX's compiler sees that the inputs to this function are sharded arrays, it automatically generates a distributed version of the code, implicitly handling all cross-device communication.\n",
522525
"\n",
523526
"This method provides the fine-grained control that is essential for all modern JAX parallelism patterns."
524-
],
525-
"metadata": {
526-
"id": "5_MHjDaRzwBf"
527-
}
527+
]
528528
},
529529
{
530530
"cell_type": "code",
531+
"execution_count": null,
532+
"metadata": {
533+
"id": "QkVcKJRPAbRE"
534+
},
535+
"outputs": [],
531536
"source": [
532537
"import jax\n",
533538
"import numpy as np\n",
@@ -547,15 +552,15 @@
547552
" labels=labels,\n",
548553
" sample_weights=sample_weights\n",
549554
" )"
550-
],
551-
"metadata": {
552-
"id": "QkVcKJRPAbRE"
553-
},
554-
"execution_count": null,
555-
"outputs": []
555+
]
556556
},
557557
{
558558
"cell_type": "code",
559+
"execution_count": null,
560+
"metadata": {
561+
"id": "VGsGxzR3Afsv"
562+
},
563+
"outputs": [],
559564
"source": [
560565
"# Advanced SPMD Parallelism: jit + Mesh\n",
561566
"def calculate_aucpr_mesh(predictions, labels, sample_weights):\n",
@@ -588,12 +593,7 @@
588593
"\n",
589594
" # The result is already a globally correct metric state, replicated on all devices.\n",
590595
" return jitted_calculate(sharded_predictions, sharded_labels, sharded_weights)"
591-
],
592-
"metadata": {
593-
"id": "VGsGxzR3Afsv"
594-
},
595-
"execution_count": null,
596-
"outputs": []
596+
]
597597
},
598598
{
599599
"cell_type": "code",

0 commit comments

Comments
 (0)