Skip to content

Commit 98a8aa2

Browse files
authored
fix: mlflow run id changes in SDK (eval-hub#85)
* fix: mlflow run id changes in SDK
* fix: lint errors
1 parent babf6fa commit 98a8aa2

5 files changed

Lines changed: 224 additions & 17 deletions

File tree

src/evalhub/adapter/callbacks.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,23 @@ class _MlflowOps:
3232
3333
from evalhub.adapter.mlflow import MlflowArtifact
3434
35-
callbacks.mlflow.save(
35+
rid = callbacks.mlflow.save(
3636
results,
3737
job_spec,
3838
artifacts=[
3939
MlflowArtifact("results.json", json_bytes, "application/json"),
4040
MlflowArtifact("report.html", html_bytes, "text/html"),
4141
],
4242
)
43+
if rid:
44+
results.mlflow_run_id = rid
4345
4446
Metrics, params, and all artifacts are saved in a single MLflow run.
45-
Does nothing if ``job_spec.experiment_name`` is not set.
47+
Does nothing if ``job_spec.experiment_name`` is not set (returns ``None``).
48+
49+
Returns the MLflow run id when a run is created. Assign it to
50+
``results.mlflow_run_id`` before ``callbacks.report_results(results)`` so
51+
Eval Hub stores the link.
4652
4753
The backend is controlled by the ``backend`` constructor argument or the
4854
``EVALHUB_MLFLOW_BACKEND`` environment variable:
@@ -60,16 +66,15 @@ def save(
6066
results: JobResults,
6167
job_spec: JobSpec,
6268
artifacts: list[MlflowArtifact] | None = None,
63-
) -> None:
69+
) -> str | None:
6470
if not job_spec.experiment_name:
6571
logger.debug("No MLflow experiment configured, skipping")
66-
return
72+
return None
6773

6874
try:
6975
if self._backend == MlflowBackend.UPSTREAM:
70-
self._save_upstream(results, job_spec, artifacts)
71-
else:
72-
self._save_odh(results, job_spec, artifacts)
76+
return self._save_upstream(results, job_spec, artifacts)
77+
return self._save_odh(results, job_spec, artifacts)
7378
except Exception as e:
7479
logger.error("Failed to save to MLflow: %s", e)
7580
raise RuntimeError(f"MLflow save failed: {e}") from e
@@ -82,16 +87,17 @@ def save(
8287
def _build_params_metrics(
8388
results: JobResults,
8489
) -> tuple[list, list]:
85-
from .mlflow import Metric, Param
90+
from .mlflow import Metric, Param, sanitize_metric_key_for_api
8691

8792
params = [
8893
Param("benchmark_id", results.benchmark_id),
8994
Param("model_name", results.model_name),
9095
Param("num_examples_evaluated", str(results.num_examples_evaluated)),
9196
Param("duration_seconds", str(results.duration_seconds)),
9297
]
98+
# MLflow rejects commas etc. in metric keys; Eval Hub keeps r.metric_name as-is.
9399
metrics: list[Metric] = [
94-
Metric(r.metric_name, float(r.metric_value))
100+
Metric(sanitize_metric_key_for_api(r.metric_name), float(r.metric_value))
95101
for r in results.results
96102
if isinstance(r.metric_value, int | float)
97103
]
@@ -104,21 +110,23 @@ def _save_odh(
104110
results: JobResults,
105111
job_spec: JobSpec,
106112
artifacts: list[MlflowArtifact] | None,
107-
) -> None:
113+
) -> str:
108114
from .mlflow import MlflowClient
109115

110116
params, metrics = self._build_params_metrics(results)
111117
run_tags: dict[str, str] = {
112118
tag["key"]: tag["value"] for tag in (job_spec.tags or [])
113119
}
114120

121+
run_id: str = ""
115122
with MlflowClient() as client:
116123
experiment_id = client.get_or_create_experiment(
117124
job_spec.experiment_name or ""
118125
)
119126
with client.start_run(
120127
experiment_id, run_name=job_spec.id, tags=run_tags
121-
) as run_id:
128+
) as rid:
129+
run_id = rid
122130
client.log_batch(run_id, metrics=metrics, params=params)
123131
for artifact in artifacts or []:
124132
client.upload_artifact(
@@ -129,20 +137,21 @@ def _save_odh(
129137
)
130138

131139
logger.info(
132-
"Saved to MLflow (odh) experiment '%s' (run: %s) — "
140+
"Saved to MLflow (odh) experiment '%s' (run_id: %s) — "
133141
"%d metric(s), %d artifact(s)",
134142
job_spec.experiment_name,
135-
job_spec.id,
143+
run_id,
136144
len(metrics),
137145
len(artifacts or []),
138146
)
147+
return run_id
139148

140149
def _save_upstream(
141150
self,
142151
results: JobResults,
143152
job_spec: JobSpec,
144153
artifacts: list[MlflowArtifact] | None,
145-
) -> None:
154+
) -> str:
146155
import tempfile
147156
from pathlib import Path as _Path
148157

@@ -160,7 +169,9 @@ def _save_upstream(
160169
}
161170

162171
mlflow.set_experiment(job_spec.experiment_name)
163-
with mlflow.start_run(run_name=job_spec.id, tags=run_tags):
172+
run_id = ""
173+
with mlflow.start_run(run_name=job_spec.id, tags=run_tags) as active_run:
174+
run_id = active_run.info.run_id
164175
mlflow.log_params({p.key: p.value for p in params})
165176
mlflow.log_metrics({m.key: m.value for m in metrics})
166177

@@ -177,13 +188,14 @@ def _save_upstream(
177188
mlflow.log_artifact(str(tmp_file), artifact_path=artifact_dir)
178189

179190
logger.info(
180-
"Saved to MLflow (upstream) experiment '%s' (run: %s) — "
191+
"Saved to MLflow (upstream) experiment '%s' (run_id: %s) — "
181192
"%d metric(s), %d artifact(s)",
182193
job_spec.experiment_name,
183-
job_spec.id,
194+
run_id,
184195
len(metrics),
185196
len(artifacts or []),
186197
)
198+
return run_id
187199

188200

189201
class DefaultCallbacks(JobCallbacks):
@@ -192,6 +204,16 @@ class DefaultCallbacks(JobCallbacks):
192204
This implementation:
193205
- Reports status updates to sidecar (if available) or logs them
194206
- Pushes OCI artifacts directly using OCIArtifactPersister
207+
- ``report_results(results)``: POSTs final results to Eval Hub; if
208+
``results.mlflow_run_id`` is set (for example from ``save()``), that id
209+
is included (if unset, the field is left out).
210+
211+
Example::
212+
213+
rid = callbacks.mlflow.save(results, job_spec)
214+
if rid:
215+
results.mlflow_run_id = rid
216+
callbacks.report_results(results)
195217
196218
This is the recommended callback implementation for both production and development.
197219
@@ -612,6 +634,9 @@ def report_results(self, results: JobResults) -> None:
612634
if self.provider_id:
613635
status_event["provider_id"] = self.provider_id
614636

637+
if results.mlflow_run_id:
638+
status_event["mlflow_run_id"] = results.mlflow_run_id
639+
615640
# Include OCI artifact reference if available
616641
if results.oci_artifact:
617642
status_event["artifacts"] = {

src/evalhub/adapter/mlflow.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import mimetypes
1414
import os
15+
import re
1516
import time
1617
from collections.abc import Iterable, Iterator
1718
from contextlib import contextmanager
@@ -42,6 +43,19 @@ class Metric:
4243
step: int = 0
4344

4445

46+
# The MLflow Tracking REST endpoint (/runs/log-batch) rejects metric keys
# containing characters outside [A-Za-z0-9_\-.\s:/]; runs of anything else
# are collapsed to a single "_" by the sanitizer below.
_BAD_MLFLOW_METRIC_KEY_CHARS = re.compile(r"[^a-zA-Z0-9_\-.\s:/]+")


def sanitize_metric_key_for_api(name: str) -> str:
    """Return an MLflow-safe metric key derived from ``name``.

    Disallowed character runs become a single underscore (e.g. lm-eval's
    ``acc,none`` -> ``acc_none``); surrounding whitespace and underscores
    are trimmed, and a fully-consumed name falls back to ``"metric"``.

    Used only when logging to MLflow; ``JobResults`` metric names are
    left unchanged.
    """
    cleaned = _BAD_MLFLOW_METRIC_KEY_CHARS.sub("_", name)
    cleaned = cleaned.strip().strip("_")
    return cleaned if cleaned else "metric"
57+
58+
4559
@dataclass
4660
class Param:
4761
key: str

src/evalhub/adapter/models/job.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ class JobResults(BaseModel):
260260
default=None, description="OCI artifact info if persisted"
261261
)
262262

263+
mlflow_run_id: str | None = Field(
264+
default=None,
265+
description="Optional MLflow run id included on the terminal results event when set",
266+
)
267+
263268

264269
class JobCallbacks(ABC):
265270
"""Abstract interface for job callbacks.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Tests for DefaultCallbacks POST /events payload (mlflow_run_id)."""
2+
3+
from __future__ import annotations
4+
5+
from datetime import UTC, datetime
6+
from unittest.mock import MagicMock, patch
7+
8+
from evalhub.adapter.callbacks import DefaultCallbacks
9+
from evalhub.adapter.models.job import JobResults
10+
from evalhub.models.api import EvaluationResult
11+
12+
13+
def _results(mlflow_run_id: str | None = None) -> JobResults:
    """Build a minimal JobResults fixture, optionally carrying an MLflow run id."""
    single_metric = [
        EvaluationResult(metric_name="acc", metric_value=0.9, metric_type="float")
    ]
    return JobResults(
        id="job-1",
        benchmark_id="arc_easy",
        benchmark_index=0,
        model_name="m",
        results=single_metric,
        num_examples_evaluated=1,
        duration_seconds=1.0,
        completed_at=datetime.now(UTC),
        mlflow_run_id=mlflow_run_id,
    )
)
27+
28+
29+
def test_report_results_sends_mlflow_run_id_when_set_on_job_results() -> None:
    """When JobResults.mlflow_run_id is set, the posted status event includes it."""
    http = MagicMock()
    response = MagicMock()
    response.raise_for_status = MagicMock()
    http.post.return_value = response

    with patch.object(DefaultCallbacks, "_create_http_client", return_value=http):
        callbacks = DefaultCallbacks(
            job_id="job-1",
            benchmark_id="arc_easy",
            provider_id="lm_evaluation_harness",
            benchmark_index=0,
            sidecar_url="http://evalhub:8080",
            insecure=True,
        )

        callbacks.report_results(_results(mlflow_run_id="mlflow-run-abc"))

    http.post.assert_called_once()
    payload = http.post.call_args.kwargs["json"]
    assert payload["benchmark_status_event"]["mlflow_run_id"] == "mlflow-run-abc"
50+
51+
52+
def test_report_results_omits_mlflow_run_id_when_not_set() -> None:
    """Without a run id on JobResults, the event must not carry the key at all."""
    http = MagicMock()
    response = MagicMock()
    response.raise_for_status = MagicMock()
    http.post.return_value = response

    with patch.object(DefaultCallbacks, "_create_http_client", return_value=http):
        callbacks = DefaultCallbacks(
            job_id="job-1",
            benchmark_id="arc_easy",
            benchmark_index=0,
            sidecar_url="http://evalhub:8080",
            insecure=True,
        )

        callbacks.report_results(_results())

    payload = http.post.call_args.kwargs["json"]
    assert "mlflow_run_id" not in payload["benchmark_status_event"]
71+
72+
73+
def _make_spec_and_results():
    """Build the JobSpec/JobResults pair shared by the save() regression tests."""
    from evalhub.adapter.models.job import JobResults, JobSpec
    from evalhub.models.api import EvaluationResult, ModelConfig

    spec = JobSpec(
        id="j1",
        provider_id="p",
        benchmark_id="b",
        benchmark_index=0,
        model=ModelConfig(url="http://localhost/v1", name="m"),
        parameters={},
        callback_url="http://localhost/",
        experiment_name="exp",
    )
    results = JobResults(
        id="j1",
        benchmark_id="b",
        benchmark_index=0,
        model_name="m",
        results=[
            EvaluationResult(metric_name="acc", metric_value=1.0, metric_type="float")
        ],
        num_examples_evaluated=1,
        duration_seconds=1.0,
        completed_at=datetime.now(UTC),
    )
    return spec, results


def _check_save_forwards_run_id(backend, patched_method: str, expected: str) -> None:
    """Patch one backend save method and verify _MlflowOps.save() returns its id."""
    from evalhub.adapter.callbacks import _MlflowOps

    spec, results = _make_spec_and_results()
    ops = _MlflowOps(backend=backend)
    with patch.object(_MlflowOps, patched_method, return_value=expected) as mocked:
        rid = ops.save(results, spec)
        assert rid == expected
        mocked.assert_called_once()


def test_mlflow_save_returns_run_id_from_odh_path() -> None:
    """Regression: save() must return _save_odh/_save_upstream result (not None)."""
    from evalhub.adapter.config import MlflowBackend

    _check_save_forwards_run_id(MlflowBackend.ODH, "_save_odh", "run-from-odh")


def test_mlflow_save_returns_run_id_from_upstream_path() -> None:
    """Same regression as above for the upstream MLflow backend."""
    from evalhub.adapter.config import MlflowBackend

    _check_save_forwards_run_id(MlflowBackend.UPSTREAM, "_save_upstream", "run-upstream")
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""MLflow metric key sanitization for REST API rules."""
2+
3+
from evalhub.adapter.mlflow import sanitize_metric_key_for_api
4+
5+
6+
def test_sanitize_lm_eval_style_comma() -> None:
7+
assert sanitize_metric_key_for_api("acc,none") == "acc_none"
8+
9+
10+
def test_sanitize_preserves_allowed_chars() -> None:
11+
assert (
12+
sanitize_metric_key_for_api("exact_match,strict-match")
13+
== "exact_match_strict-match"
14+
)
15+
16+
17+
def test_sanitize_empty_fallback() -> None:
18+
assert sanitize_metric_key_for_api(",,,") == "metric"
19+
20+
21+
def test_sanitize_simple_name_unchanged() -> None:
22+
assert sanitize_metric_key_for_api("accuracy") == "accuracy"

0 commit comments

Comments
 (0)