Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 53 additions & 32 deletions skore-hub-project/src/skore_hub_project/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import re
from functools import cached_property, wraps
from operator import itemgetter
from tempfile import TemporaryFile
Expand Down Expand Up @@ -84,6 +85,10 @@ class Project:
The current run identifier of the project.
"""

__REPORT_URN_PATTERN = re.compile(
r"skore:report:(?P<type>(estimator|cross-validation)):(?P<id>.+)"
)

def __init__(self, tenant: str, name: str):
"""
Initialize a hub project.
Expand Down Expand Up @@ -168,46 +173,62 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport):
with HUBClient() as client:
client.post(url=url, json=payload_dict)

@property
@ensure_project_is_created
def reports(self):
"""Accessor for interaction with the persisted reports."""
def get(self, urn: str) -> EstimatorReport | CrossValidationReport:
"""Get a persisted report by its URN."""
if m := re.match(Project.__REPORT_URN_PATTERN, urn):
url = f"projects/{self.tenant}/{self.name}/{m['type']}-reports/{m['id']}"
else:
raise ValueError(
f"URN '{urn}' format does not match '{Project.__REPORT_URN_PATTERN}'"
)

def get(id: str) -> EstimatorReport:
"""Get a persisted report by its id."""
# Retrieve report metadata.
with HUBClient() as client:
response = client.get(
url=f"projects/{self.tenant}/{self.name}/experiments/estimator-reports/{id}"
)
# Retrieve report metadata.
with HUBClient() as client:
response = client.get(url=url)

metadata = response.json()
checksum = metadata["raw"]["checksum"]
metadata = response.json()
checksum = metadata["raw"]["checksum"]

# Ask for read url.
with HUBClient() as client:
response = client.get(
url=f"projects/{self.tenant}/{self.name}/artefacts/read",
params={"artefact_checksum": [checksum]},
)
# Ask for read url.
with HUBClient() as client:
response = client.get(
url=f"projects/{self.tenant}/{self.name}/artefacts/read",
params={"artefact_checksum": [checksum]},
)

url = response.json()[0]["url"]

# Download pickled report before unpickling it.
#
# It uses streaming responses that do not load the entire response body into
# memory at once.
with (
TemporaryFile(mode="w+b") as tmpfile,
Client() as client,
client.stream(method="GET", url=url, timeout=30) as response,
):
for data in response.iter_bytes():
tmpfile.write(data)

url = response.json()[0]["url"]
tmpfile.seek(0)

# Download pickled report before unpickling it.
#
# It uses streaming responses that do not load the entire response body into
# memory at once.
with (
TemporaryFile(mode="w+b") as tmpfile,
Client() as client,
client.stream(method="GET", url=url, timeout=30) as response,
):
for data in response.iter_bytes():
tmpfile.write(data)
return joblib.load(tmpfile)

@property
@ensure_project_is_created
def reports(self):
"""Accessor for interaction with the persisted reports."""

tmpfile.seek(0)
def get(urn: str) -> EstimatorReport | CrossValidationReport:
"""
Get a persisted report by its URN.

return joblib.load(tmpfile)
.. deprecated
The ``Project.reports.get`` function will be removed in favor of
``Project.get`` in a near future.
"""
return self.get(urn)

def metadata() -> list[Metadata]:
"""Obtain metadata for all persisted reports regardless of their run."""
Expand Down
39 changes: 34 additions & 5 deletions skore-hub-project/tests/unit/project/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import joblib
from httpx import Client, Response
from pytest import fixture, mark, raises
from skore import EstimatorReport
from skore import CrossValidationReport, EstimatorReport
from skore_hub_project import Project
from skore_hub_project.report import (
CrossValidationReportPayload,
Expand Down Expand Up @@ -172,13 +172,13 @@ def test_reports(self, respx_mock):
assert hasattr(project.reports, "get")
assert hasattr(project.reports, "metadata")

def test_reports_get(self, respx_mock, regression):
def test_reports_get_estimator_report(self, respx_mock, regression):
# Mock hub routes that will be called
url = "projects/<tenant>/<name>/runs"
response = Response(200, json={"id": 0})
respx_mock.post(url).mock(response)

url = "projects/<tenant>/<name>/experiments/estimator-reports/<report_id>"
url = "projects/<tenant>/<name>/estimator-reports/<report_id>"
response = Response(200, json={"raw": {"checksum": "<checksum>"}})
respx_mock.get(url).mock(response)

Expand All @@ -195,11 +195,40 @@ def test_reports_get(self, respx_mock, regression):

# Test
project = Project("<tenant>", "<name>")
report = project.reports.get("<report_id>")
report = project.reports.get("skore:report:estimator:<report_id>")

assert isinstance(report, EstimatorReport)
assert report.estimator_name_ == regression.estimator_name_
assert report._ml_task == regression._ml_task
assert report.ml_task == regression.ml_task

def test_reports_get_cross_validation_report(self, respx_mock, cv_regression):
# Mock hub routes that will be called
url = "projects/<tenant>/<name>/runs"
response = Response(200, json={"id": 0})
respx_mock.post(url).mock(response)

url = "projects/<tenant>/<name>/cross-validation-reports/<report_id>"
response = Response(200, json={"raw": {"checksum": "<checksum>"}})
respx_mock.get(url).mock(response)

url = "projects/<tenant>/<name>/artefacts/read"
response = Response(200, json=[{"url": "http://url.com"}])
respx_mock.get(url).mock(response)

with BytesIO() as stream:
joblib.dump(cv_regression, stream)

url = "http://url.com"
response = Response(200, content=stream.getvalue())
respx_mock.get(url).mock(response)

# Test
project = Project("<tenant>", "<name>")
report = project.reports.get("skore:report:cross-validation:<report_id>")

assert isinstance(report, CrossValidationReport)
assert report.estimator_name_ == cv_regression.estimator_name_
assert report.ml_task == cv_regression.ml_task

def test_reports_metadata(self, nowstr, respx_mock):
url = "projects/<tenant>/<name>/runs"
Expand Down
Loading