Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ class CrossValidationReportPayload(ReportPayload):
The project to which the report payload should be sent.
report : CrossValidationReport
The report on which to calculate the payload to be sent.
upload : bool, optional
Upload the report to the artefacts storage, default True.
key : str
The key to associate to the report.
"""
Expand Down Expand Up @@ -224,7 +222,6 @@ def estimators(self) -> list[EstimatorReportPayload]:
EstimatorReportPayload(
project=self.project,
report=report,
upload=False,
key=f"{self.key}:estimator-report",
)
for report in self.report.estimator_reports_
Expand All @@ -236,17 +233,12 @@ def parameters(self) -> CrossValidationReportArtefact | dict[()]:
"""
The checksum of the instance.

The checksum of the instance that was assigned after being uploaded to the
The checksum of the instance that was assigned before being uploaded to the
artefact storage. It is based on its ``joblib`` serialization and mainly used to
retrieve it from the artefacts storage.

.. deprecated
The ``parameters`` property will be removed in favor of a new ``checksum``
property in a near future.
"""
if self.upload:
return CrossValidationReportArtefact(
project=self.project,
report=self.report,
)
return {}
return CrossValidationReportArtefact(project=self.project, report=self.report)
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class EstimatorReportPayload(ReportPayload):
The project to which the report payload should be sent.
report : EstimatorReport
The report on which to calculate the payload to be sent.
upload : bool, optional
Upload the report to the artefacts storage, default True.
key : str
The key to associate to the report.
"""
Expand Down Expand Up @@ -120,14 +118,12 @@ def parameters(self) -> EstimatorReportArtefact | dict:
"""
The checksum of the instance.

The checksum of the instance that was assigned after being uploaded to the
The checksum of the instance that was assigned before being uploaded to the
artefact storage. It is based on its ``joblib`` serialization and mainly used to
retrieve it from the artefacts storage.

.. deprecated
The ``parameters`` property will be removed in favor of a new ``checksum``
property in a near future.
"""
if self.upload:
return EstimatorReportArtefact(project=self.project, report=self.report)
return {}
return EstimatorReportArtefact(project=self.project, report=self.report)
13 changes: 3 additions & 10 deletions skore-hub-project/src/skore_hub_project/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ class ReportPayload(ABC, BaseModel):
The project to which the report payload should be sent.
report : EstimatorReport | CrossValidationReport
The report on which to calculate the payload to be sent.
upload : bool, optional
Upload the report to the artefacts storage, default True.
key : str
The key to associate to the report.
"""
Expand All @@ -41,7 +39,6 @@ class ReportPayload(ABC, BaseModel):

project: Project = Field(repr=False, exclude=True)
report: EstimatorReport | CrossValidationReport = Field(repr=False, exclude=True)
upload: bool = Field(default=True, repr=False, exclude=True)
key: str

@computed_field # type: ignore[prop-decorator]
Expand Down Expand Up @@ -79,7 +76,7 @@ def parameters(self) -> Artefact | dict[()]:
"""
The checksum of the instance.

The checksum of the instance that was assigned after being uploaded to the
The checksum of the instance that was assigned before being uploaded to the
artefact storage. It is based on its ``joblib`` serialization and mainly used to
retrieve it from the artefacts storage.

Expand All @@ -106,14 +103,12 @@ def metrics(self) -> list[Metric]:
- int [0, inf[, to be displayed at the position,
- None, not to be displayed.
"""
payloads = [
return [
payload
for metric in self.METRICS
if (payload := metric(report=self.report)).value is not None
]

return payloads

@computed_field # type: ignore[prop-decorator]
@cached_property
def related_items(self) -> list[Media]:
Expand All @@ -126,10 +121,8 @@ def related_items(self) -> list[Media]:
-----
Unavailable medias have been filtered out.
"""
payloads = [
return [
payload
for media in cast(list[Callable], self.MEDIAS)
if (payload := media(report=self.report)).representation is not None
]

return payloads
62 changes: 52 additions & 10 deletions skore-hub-project/tests/unit/report/test_cross_validation_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from sklearn.datasets import make_classification, make_regression
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import ShuffleSplit
from skore import CrossValidationReport
from skore import CrossValidationReport, EstimatorReport
from skore_hub_project import Project
from skore_hub_project.artefact import EstimatorReportArtefact
from skore_hub_project.artefact.serializer import Serializer
from skore_hub_project.media import EstimatorHtmlRepr
from skore_hub_project.media.data import TableReport
Expand Down Expand Up @@ -168,18 +169,59 @@ def test_classes(self, payload):
def test_classes_many_rows(self, payload):
assert payload.classes == [0, 0, 1, 1, 1, 0, 0, 1, 0, 1]

@mark.usefixtures("monkeypatch_routes")
def test_estimators(self, payload, respx_mock):
respx_mock.post("projects/<tenant>/<name>/runs").mock(
Response(200, json={"id": 0})
)
estimators = [estimator.model_dump() for estimator in payload.estimators]

assert len(payload.estimators) == len(payload.report.estimator_reports_)
# Ensure payload dict is well constructed
assert len(estimators) == len(payload.report.estimator_reports_)

for i, er_payload in enumerate(payload.estimators):
assert isinstance(er_payload, EstimatorReportPayload)
assert er_payload.report == payload.report.estimator_reports_[i]
assert er_payload.upload is False
assert er_payload.parameters == {}
for i, estimator in enumerate(payload.estimators):
assert isinstance(estimator, EstimatorReportPayload)
assert isinstance(estimator.parameters, EstimatorReportArtefact)
assert estimator.report == payload.report.estimator_reports_[i]

# Ensure upload is well done
requests = [call.request for call in respx_mock.calls][1:]

assert len(requests) == (len(payload.report.estimator_reports_) * 3)

def serialize(object: EstimatorReport) -> tuple[bytes, str]:
cache = object._cache
object._cache = {}

try:
with Serializer(object) as serializer:
pickle = serializer.filepath.read_bytes()
checksum = serializer.checksum
finally:
object._cache = cache

return pickle, checksum

for i in range(len(payload.report.estimator_reports_)):
pickle, checksum = serialize(payload.report.estimator_reports_[i])
r0 = requests[(i * 3)]
r1 = requests[(i * 3) + 1]
r2 = requests[(i * 3) + 2]

assert r0.url.path == "/projects/<tenant>/<name>/artefacts"
assert loads(r0.content.decode()) == [
{
"checksum": checksum,
"chunk_number": 1,
"content_type": "estimator-report",
}
]
assert r1.url == "http://chunk1.com/"
assert r1.content == pickle
assert r2.url.path == "/projects/<tenant>/<name>/artefacts/complete"
assert loads(r2.content.decode()) == [
{
"checksum": checksum,
"etags": {"1": '"<etag1>"'},
}
]

@mark.usefixtures("monkeypatch_routes")
def test_parameters(self, small_cv_binary_classification, payload, respx_mock):
Expand Down