diff --git a/skore-hub-project/src/skore_hub_project/report/cross_validation_report.py b/skore-hub-project/src/skore_hub_project/report/cross_validation_report.py index b774ecd96..f28fdf4e9 100644 --- a/skore-hub-project/src/skore_hub_project/report/cross_validation_report.py +++ b/skore-hub-project/src/skore_hub_project/report/cross_validation_report.py @@ -72,8 +72,6 @@ class CrossValidationReportPayload(ReportPayload): The project to which the report payload should be sent. report : CrossValidationReport The report on which to calculate the payload to be sent. - upload : bool, optional - Upload the report to the artefacts storage, default True. key : str The key to associate to the report. """ @@ -224,7 +222,6 @@ def estimators(self) -> list[EstimatorReportPayload]: EstimatorReportPayload( project=self.project, report=report, - upload=False, key=f"{self.key}:estimator-report", ) for report in self.report.estimator_reports_ @@ -236,7 +233,7 @@ def parameters(self) -> CrossValidationReportArtefact | dict[()]: """ The checksum of the instance. - The checksum of the instance that was assigned after being uploaded to the + The checksum of the instance that was assigned before being uploaded to the artefact storage. It is based on its ``joblib`` serialization and mainly used to retrieve it from the artefacts storage. @@ -244,9 +241,4 @@ def parameters(self) -> CrossValidationReportArtefact | dict[()]: The ``parameters`` property will be removed in favor of a new ``checksum`` property in a near future. """ - if self.upload: - return CrossValidationReportArtefact( - project=self.project, - report=self.report, - ) - return {} + return CrossValidationReportArtefact(project=self.project, report=self.report) diff --git a/skore-hub-project/src/skore_hub_project/report/estimator_report.py b/skore-hub-project/src/skore_hub_project/report/estimator_report.py index 623267060..ea2f4d30f 100644 --- a/skore-hub-project/src/skore_hub_project/report/estimator_report.py +++ b/skore-hub-project/src/skore_hub_project/report/estimator_report.py @@ -62,8 +62,6 @@ class EstimatorReportPayload(ReportPayload): The project to which the report payload should be sent. report : EstimatorReport The report on which to calculate the payload to be sent. - upload : bool, optional - Upload the report to the artefacts storage, default True. key : str The key to associate to the report. """ @@ -120,7 +118,7 @@ def parameters(self) -> EstimatorReportArtefact | dict: """ The checksum of the instance. - The checksum of the instance that was assigned after being uploaded to the + The checksum of the instance that was assigned before being uploaded to the artefact storage. It is based on its ``joblib`` serialization and mainly used to retrieve it from the artefacts storage. @@ -128,6 +126,4 @@ def parameters(self) -> EstimatorReportArtefact | dict: The ``parameters`` property will be removed in favor of a new ``checksum`` property in a near future. """ - if self.upload: - return EstimatorReportArtefact(project=self.project, report=self.report) - return {} + return EstimatorReportArtefact(project=self.project, report=self.report) diff --git a/skore-hub-project/src/skore_hub_project/report/report.py b/skore-hub-project/src/skore_hub_project/report/report.py index 968d81e78..dbdc2e305 100644 --- a/skore-hub-project/src/skore_hub_project/report/report.py +++ b/skore-hub-project/src/skore_hub_project/report/report.py @@ -28,8 +28,6 @@ class ReportPayload(ABC, BaseModel): The project to which the report payload should be sent. report : EstimatorReport | CrossValidationReport The report on which to calculate the payload to be sent. - upload : bool, optional - Upload the report to the artefacts storage, default True. key : str The key to associate to the report. """ @@ -41,7 +39,6 @@ class ReportPayload(ABC, BaseModel): project: Project = Field(repr=False, exclude=True) report: EstimatorReport | CrossValidationReport = Field(repr=False, exclude=True) - upload: bool = Field(default=True, repr=False, exclude=True) key: str @computed_field # type: ignore[prop-decorator] @@ -79,7 +76,7 @@ def parameters(self) -> Artefact | dict[()]: """ The checksum of the instance. - The checksum of the instance that was assigned after being uploaded to the + The checksum of the instance that was assigned before being uploaded to the artefact storage. It is based on its ``joblib`` serialization and mainly used to retrieve it from the artefacts storage. @@ -106,14 +103,12 @@ def metrics(self) -> list[Metric]: - int [0, inf[, to be displayed at the position, - None, not to be displayed. """ - payloads = [ + return [ payload for metric in self.METRICS if (payload := metric(report=self.report)).value is not None ] - return payloads - @computed_field # type: ignore[prop-decorator] @cached_property def related_items(self) -> list[Media]: @@ -126,10 +121,8 @@ def related_items(self) -> list[Media]: ----- Unavailable medias have been filtered out. """ - payloads = [ + return [ payload for media in cast(list[Callable], self.MEDIAS) if (payload := media(report=self.report)).representation is not None ] - - return payloads diff --git a/skore-hub-project/tests/unit/report/test_cross_validation_report.py b/skore-hub-project/tests/unit/report/test_cross_validation_report.py index cbf077ded..1d24f5299 100644 --- a/skore-hub-project/tests/unit/report/test_cross_validation_report.py +++ b/skore-hub-project/tests/unit/report/test_cross_validation_report.py @@ -8,8 +8,9 @@ from sklearn.datasets import make_classification, make_regression from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.model_selection import ShuffleSplit -from skore import CrossValidationReport +from skore import CrossValidationReport, EstimatorReport from skore_hub_project import Project +from skore_hub_project.artefact import EstimatorReportArtefact from skore_hub_project.artefact.serializer import Serializer from skore_hub_project.media import EstimatorHtmlRepr from skore_hub_project.media.data import TableReport @@ -168,18 +169,59 @@ def test_classes(self, payload): def test_classes_many_rows(self, payload): assert payload.classes == [0, 0, 1, 1, 1, 0, 0, 1, 0, 1] + @mark.usefixtures("monkeypatch_routes") def test_estimators(self, payload, respx_mock): - respx_mock.post("projects///runs").mock( - Response(200, json={"id": 0}) - ) + estimators = [estimator.model_dump() for estimator in payload.estimators] - assert len(payload.estimators) == len(payload.report.estimator_reports_) + # Ensure payload dict is well constructed + assert len(estimators) == len(payload.report.estimator_reports_) - for i, er_payload in enumerate(payload.estimators): - assert isinstance(er_payload, EstimatorReportPayload) - assert er_payload.report == payload.report.estimator_reports_[i] - assert er_payload.upload is False - assert er_payload.parameters == {} + for i, estimator in enumerate(payload.estimators): + assert isinstance(estimator, EstimatorReportPayload) + assert isinstance(estimator.parameters, EstimatorReportArtefact) + assert estimator.report == payload.report.estimator_reports_[i] + + # Ensure upload is well done + requests = [call.request for call in respx_mock.calls][1:] + + assert len(requests) == (len(payload.report.estimator_reports_) * 3) + + def serialize(object: EstimatorReport) -> tuple[bytes, str]: + cache = object._cache + object._cache = {} + + try: + with Serializer(object) as serializer: + pickle = serializer.filepath.read_bytes() + checksum = serializer.checksum + finally: + object._cache = cache + + return pickle, checksum + + for i in range(len(payload.report.estimator_reports_)): + pickle, checksum = serialize(payload.report.estimator_reports_[i]) + r0 = requests[(i * 3)] + r1 = requests[(i * 3) + 1] + r2 = requests[(i * 3) + 2] + + assert r0.url.path == "/projects///artefacts" + assert loads(r0.content.decode()) == [ + { + "checksum": checksum, + "chunk_number": 1, + "content_type": "estimator-report", + } + ] + assert r1.url == "http://chunk1.com/" + assert r1.content == pickle + assert r2.url.path == "/projects///artefacts/complete" + assert loads(r2.content.decode()) == [ + { + "checksum": checksum, + "etags": {"1": '""'}, + } + ] @mark.usefixtures("monkeypatch_routes") def test_parameters(self, small_cv_binary_classification, payload, respx_mock):