The google-cloud-mldiagnostics library is a Python package designed to help
engineers and researchers monitor and diagnose machine learning training runs
with Google Cloud's suite of diagnostic tools.
It provides tools for tracking workload progress, collecting metrics and
profiling performance.
Supported frameworks:
- JAX (any version)
- Other frameworks: in progress
Install the package from PyPI:

```shell
pip install google-cloud-mldiagnostics
```

This package does not install libtpu, jax, or xprof; it expects them to be
installed separately.
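Because jax, xprof, and libtpu are not pulled in as dependencies, it can help to check for them before launching a run. The following is a minimal sketch using only the standard library; `missing_modules` is a hypothetical helper, and it assumes the import names match the pip package names (jax, xprof, libtpu), which you should verify for your environment.

```python
import importlib.util

# Import names are ASSUMED to match the pip package names; verify them
# for your environment before relying on this check.
REQUIRED_MODULES = ("jax", "xprof", "libtpu")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names from `names` that cannot be found on the import path."""
    # find_spec locates a top-level module without executing it and
    # returns None when the module is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Install separately before creating a run:", ", ".join(missing))
    else:
        print("jax, xprof, and libtpu were all found.")
```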
At the beginning of the training script, create a machine learning run:
```python
from google_cloud_mldiagnostics import machinelearning_run

machinelearning_run(
    name="<run-name>",
    gcs_path="gs://<bucket>",
)
```

To enable on-demand XProf profiling, pass on_demand_xprof=True when creating the run:

```python
from google_cloud_mldiagnostics import machinelearning_run

machinelearning_run(
    name="<run-name>",
    gcs_path="gs://<bucket>",
    on_demand_xprof=True,
)
```

To capture a profile programmatically, start and stop XProf around the code you want to profile:

```python
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof

machinelearning_run(
    name="<run-name>",
    gcs_path="gs://<bucket>",
)

profiler = xprof()
profiler.start()
# Code to profile goes here.
profiler.stop()
```

To record a predefined metric, such as loss:

```python
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import metrics
from google_cloud_mldiagnostics import metric_types

machinelearning_run(
    name="<run-name>",
    gcs_path="gs://<bucket>",
)

metrics.record(metric_types.MetricType.LOSS, <value>)
```

To pair the metric value with the current step:

```python
metrics.record(metric_types.MetricType.LOSS, <value>, step=<step>)
```

To record a custom metric, pass its name as a string:

```python
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import metrics

machinelearning_run(
    name="<run-name>",
    gcs_path="gs://<bucket>",
)

metrics.record("<my-metric>", <value>)
```

To pair the metric value with the current step:

```python
metrics.record("<my-metric>", <value>, step=<step>)
```
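The profiling and metrics snippets above can be combined in a training loop. The following is a minimal sketch, not part of the library: `profiled`, `train`, and `my_train_step` are hypothetical helpers, and `RecordingProfiler` / `RecordingMetrics` are offline stand-ins so the loop can be dry-run without a GCS bucket or the package installed. It assumes only what the examples show: the object returned by `xprof()` has `start()` and `stop()` methods, and `metrics.record` accepts a metric, a value, and an optional `step` keyword.

```python
from contextlib import contextmanager


@contextmanager
def profiled(profiler):
    """Run the enclosed block between profiler.start() and profiler.stop()."""
    profiler.start()
    try:
        yield profiler
    finally:
        # stop() runs even if the profiled block raises.
        profiler.stop()


def train(num_steps, train_step, record, metric):
    """Hypothetical loop: run train_step(step) and record its result at each step."""
    losses = []
    for step in range(num_steps):
        loss = train_step(step)
        record(metric, loss, step=step)
        losses.append(loss)
    return losses


class RecordingProfiler:
    """Offline stand-in for a profiler: logs start/stop calls instead of profiling."""

    def __init__(self):
        self.events = []

    def start(self):
        self.events.append("start")

    def stop(self):
        self.events.append("stop")


class RecordingMetrics:
    """Offline stand-in for metrics.record: keeps (metric, value, step) tuples."""

    def __init__(self):
        self.rows = []

    def record(self, metric, value, step=None):
        self.rows.append((metric, value, step))


# Dry run with the stand-ins; the fake step returns 1 / (step + 1) as its "loss".
profiler = RecordingProfiler()
sink = RecordingMetrics()
with profiled(profiler):
    losses = train(3, lambda step: 1.0 / (step + 1), sink.record, "<my-metric>")

# With the library installed and a run created (assumed usage, not run here):
# from google_cloud_mldiagnostics import machinelearning_run, metrics, xprof
# from google_cloud_mldiagnostics import metric_types
# machinelearning_run(name="<run-name>", gcs_path="gs://<bucket>")
# with profiled(xprof()):
#     train(100, my_train_step, metrics.record, metric_types.MetricType.LOSS)
```

Passing the profiler and the recorder in as arguments keeps the loop independent of the library, so the same code can be dry-run locally and then pointed at `xprof()` and `metrics.record`. Calling `stop()` in a `finally` clause closes the profile even if a training step raises.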