Skip to content

Commit 4157779

Browse files
authored
feat(skore-hub-project)!: Add CrossValidationReport metadata to summarize (#2020)
Following #2019. Part of #1923.
1 parent ada3a18 commit 4157779

File tree

2 files changed

+132
-39
lines changed

2 files changed

+132
-39
lines changed

skore-hub-project/src/skore_hub_project/project/project.py

Lines changed: 69 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import itertools
56
import re
67
from functools import cached_property, wraps
78
from operator import itemgetter
@@ -23,13 +24,19 @@ class Metadata(TypedDict): # noqa: D101
2324
key: str
2425
date: str
2526
learner: str
26-
dataset: str
2727
ml_task: str
28+
report_type: str
29+
dataset: str
2830
rmse: float | None
2931
log_loss: float | None
3032
roc_auc: float | None
31-
fit_time: float
32-
predict_time: float
33+
fit_time: float | None
34+
predict_time: float | None
35+
rmse_mean: float | None
36+
log_loss_mean: float | None
37+
roc_auc_mean: float | None
38+
fit_time_mean: float | None
39+
predict_time_mean: float | None
3340

3441

3542
def ensure_project_is_created(method):
@@ -215,6 +222,57 @@ def get(self, urn: str) -> EstimatorReport | CrossValidationReport:
215222

216223
return joblib.load(tmpfile)
217224

225+
@ensure_project_is_created
def summarize(self) -> list[Metadata]:
    """
    Obtain metadata/metrics for all persisted reports.

    Two hub endpoints are queried, one for estimator reports and one for
    cross-validation reports. Each returned summary is flattened into a
    single record tagged with its ``report_type``, and the records are
    returned sorted in ascending order of their ``date`` value.

    Only metrics whose ``data_source`` is ``None`` or ``"test"`` are kept;
    a tracked metric that is absent from a summary is reported as ``None``.
    """
    # Metric columns exposed in every record, in output order.
    tracked_metrics = (
        "rmse",
        "log_loss",
        "roc_auc",
        "fit_time",
        "predict_time",
        "rmse_mean",
        "log_loss_mean",
        "roc_auc_mean",
        "fit_time_mean",
        "predict_time_mean",
    )

    def to_record(report_type, summary):
        # Keep one value per metric name, ignoring e.g. "train" metrics;
        # a later entry with the same name overrides an earlier one.
        kept = {}
        for entry in summary["metrics"]:
            if entry["data_source"] in (None, "test"):
                metric_name = entry["name"]
                kept[metric_name] = entry["value"]

        record = {
            "id": summary["urn"],
            "run_id": summary["run_id"],
            "key": summary["key"],
            "date": summary["created_at"],
            "learner": summary["estimator_class_name"],
            "ml_task": summary["ml_task"],
            "report_type": report_type,
            "dataset": summary["dataset_fingerprint"],
        }
        for metric_name in tracked_metrics:
            record[metric_name] = kept.get(metric_name)

        return record

    # Both payloads are fully read while the client is still open.
    with HUBClient() as client:
        estimator_pairs = [
            ("estimator", summary)
            for summary in client.get(
                f"projects/{self.tenant}/{self.name}/estimator-reports/"
            ).json()
        ]
        cross_validation_pairs = [
            ("cross-validation", summary)
            for summary in client.get(
                f"projects/{self.tenant}/{self.name}/cross-validation-reports/"
            ).json()
        ]

    # Estimator reports come first; the stable sort keeps that relative
    # order for records sharing the same date.
    records = [
        to_record(report_type, summary)
        for report_type, summary in estimator_pairs + cross_validation_pairs
    ]
    records.sort(key=itemgetter("date"))

    return records
275+
218276
@property
219277
@ensure_project_is_created
220278
def reports(self):
@@ -231,36 +289,14 @@ def get(urn: str) -> EstimatorReport | CrossValidationReport:
231289
return self.get(urn)
232290

233291
def metadata() -> list[Metadata]:
234-
"""Obtain metadata for all persisted reports regardless of their run."""
235-
236-
def dto(summary):
237-
metrics = {
238-
metric["name"]: metric["value"]
239-
for metric in summary["metrics"]
240-
if metric["data_source"] in (None, "test")
241-
}
242-
243-
return {
244-
"id": summary["id"],
245-
"run_id": summary["run_id"],
246-
"key": summary["key"],
247-
"date": summary["created_at"],
248-
"learner": summary["estimator_class_name"],
249-
"dataset": summary["dataset_fingerprint"],
250-
"ml_task": summary["ml_task"],
251-
"rmse": metrics.get("rmse"),
252-
"log_loss": metrics.get("log_loss"),
253-
"roc_auc": metrics.get("roc_auc"),
254-
"fit_time": metrics.get("fit_time"),
255-
"predict_time": metrics.get("predict_time"),
256-
}
257-
258-
with HUBClient() as client:
259-
response = client.get(
260-
f"projects/{self.tenant}/{self.name}/experiments/estimator-reports"
261-
)
262-
263-
return sorted(map(dto, response.json()), key=itemgetter("date"))
292+
"""
293+
Obtain metadata/metrics for all persisted reports in insertion order.
294+
295+
.. deprecated
296+
The ``Project.reports.metadata`` function will be removed in favor of
297+
``Project.summarize`` in the near future.
298+
"""
299+
return self.summarize()
264300

265301
return SimpleNamespace(get=get, metadata=metadata)
266302

skore-hub-project/tests/unit/project/test_project.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,13 @@ def test_reports_metadata(self, nowstr, respx_mock):
234234
url = "projects/<tenant>/<name>/runs"
235235
respx_mock.post(url).mock(Response(200, json={"id": 2}))
236236

237-
url = "projects/<tenant>/<name>/experiments/estimator-reports"
237+
url = "projects/<tenant>/<name>/estimator-reports/"
238238
respx_mock.get(url).mock(
239239
Response(
240240
200,
241241
json=[
242242
{
243+
"urn": "skore:report:estimator:<report_id_0>",
243244
"id": "<report_id_0>",
244245
"run_id": 0,
245246
"key": "<key>",
@@ -253,6 +254,7 @@ def test_reports_metadata(self, nowstr, respx_mock):
253254
],
254255
},
255256
{
257+
"urn": "skore:report:estimator:<report_id_1>",
256258
"id": "<report_id_1>",
257259
"run_id": 1,
258260
"key": "<key>",
@@ -262,7 +264,30 @@ def test_reports_metadata(self, nowstr, respx_mock):
262264
"created_at": nowstr,
263265
"metrics": [
264266
{"name": "log_loss", "value": 0, "data_source": "train"},
265-
{"name": "log_loss", "value": 1, "data_source": "test"},
267+
{"name": "log_loss", "value": 2, "data_source": "test"},
268+
],
269+
},
270+
],
271+
)
272+
)
273+
274+
url = "projects/<tenant>/<name>/cross-validation-reports/"
275+
respx_mock.get(url).mock(
276+
Response(
277+
200,
278+
json=[
279+
{
280+
"urn": "skore:report:cross-validation:<report_id_2>",
281+
"id": "<report_id_2>",
282+
"run_id": 3,
283+
"key": "<key>",
284+
"ml_task": "<ml_task>",
285+
"estimator_class_name": "<estimator_class_name>",
286+
"dataset_fingerprint": "<dataset_fingerprint>",
287+
"created_at": nowstr,
288+
"metrics": [
289+
{"name": "rmse_mean", "value": 0, "data_source": "train"},
290+
{"name": "rmse_mean", "value": 3, "data_source": "test"},
266291
],
267292
},
268293
],
@@ -274,32 +299,64 @@ def test_reports_metadata(self, nowstr, respx_mock):
274299

275300
assert metadata == [
276301
{
277-
"id": "<report_id_0>",
302+
"id": "skore:report:estimator:<report_id_0>",
278303
"run_id": 0,
279304
"key": "<key>",
280305
"date": nowstr,
281306
"learner": "<estimator_class_name>",
282-
"dataset": "<dataset_fingerprint>",
283307
"ml_task": "<ml_task>",
308+
"report_type": "estimator",
309+
"dataset": "<dataset_fingerprint>",
284310
"rmse": 1,
285311
"log_loss": None,
286312
"roc_auc": None,
287313
"fit_time": None,
288314
"predict_time": None,
315+
"rmse_mean": None,
316+
"log_loss_mean": None,
317+
"roc_auc_mean": None,
318+
"fit_time_mean": None,
319+
"predict_time_mean": None,
289320
},
290321
{
291-
"id": "<report_id_1>",
322+
"id": "skore:report:estimator:<report_id_1>",
292323
"run_id": 1,
293324
"key": "<key>",
294325
"date": nowstr,
295326
"learner": "<estimator_class_name>",
327+
"ml_task": "<ml_task>",
328+
"report_type": "estimator",
296329
"dataset": "<dataset_fingerprint>",
330+
"rmse": None,
331+
"log_loss": 2,
332+
"roc_auc": None,
333+
"fit_time": None,
334+
"predict_time": None,
335+
"rmse_mean": None,
336+
"log_loss_mean": None,
337+
"roc_auc_mean": None,
338+
"fit_time_mean": None,
339+
"predict_time_mean": None,
340+
},
341+
{
342+
"id": "skore:report:cross-validation:<report_id_2>",
343+
"run_id": 3,
344+
"key": "<key>",
345+
"date": nowstr,
346+
"learner": "<estimator_class_name>",
297347
"ml_task": "<ml_task>",
348+
"report_type": "cross-validation",
349+
"dataset": "<dataset_fingerprint>",
298350
"rmse": None,
299-
"log_loss": 1,
351+
"log_loss": None,
300352
"roc_auc": None,
301353
"fit_time": None,
302354
"predict_time": None,
355+
"rmse_mean": 3,
356+
"log_loss_mean": None,
357+
"roc_auc_mean": None,
358+
"fit_time_mean": None,
359+
"predict_time_mean": None,
303360
},
304361
]
305362

0 commit comments

Comments
 (0)