Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.2
skore==0.10.3
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
2 changes: 1 addition & 1 deletion skore-hub-project/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ test-base = [
"pytest-randomly",
"pytest-xdist",
"respx",
"skore",
"skore>=0.10.3",
"skrub",
"xdoctest",
]
Expand Down
10 changes: 7 additions & 3 deletions skore-hub-project/src/skore_hub_project/media/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import Field, computed_field

from skore_hub_project import switch_mpl_backend
from skore_hub_project.protocol import EstimatorReport
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport

from .media import Media, Representation

Expand All @@ -33,7 +33,7 @@ def _to_native(obj):


class TableReport(Media): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
report: EstimatorReport | CrossValidationReport = Field(repr=False, exclude=True)
key: str = "table_report"
verbose_name: str = "Table report"
category: Literal["data"] = "data"
Expand All @@ -51,10 +51,14 @@ def representation(self) -> Representation: # noqa: D102
table_report_display = function(**function_kwargs)
table_report = table_report_display.summary

table_report["extract"] = (
table_report["extract_head"] = (
table_report["dataframe"].head(3).to_dict(orient="split")
)

table_report["extract_tail"] = (
table_report["dataframe"].tail(3).to_dict(orient="split")
)

del table_report["dataframe"]
del table_report["sample_table"]

Expand Down
2 changes: 2 additions & 0 deletions skore-hub-project/src/skore_hub_project/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport):
if not isinstance(key, str):
raise TypeError(f"Key must be a string (found '{type(key)}')")

Payload: type

if isinstance(report, EstimatorReport):
Payload = EstimatorReportPayload
url = f"projects/{self.tenant}/{self.name}/estimator-reports"
Expand Down
1 change: 1 addition & 0 deletions skore-hub-project/src/skore_hub_project/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CrossValidationReport(Protocol):

_cache: Any
metrics: Any
data: Any
estimator_reports_: Any
ml_task: Any
estimator: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from sklearn.model_selection._split import _CVIterableWrapper

from skore_hub_project.artefact import CrossValidationReportArtefact
from skore_hub_project.media import (
EstimatorHtmlRepr,
)
from skore_hub_project.media import EstimatorHtmlRepr
from skore_hub_project.media.data import TableReport
from skore_hub_project.media.media import Media
from skore_hub_project.metric import (
AccuracyTestMean,
Expand Down Expand Up @@ -124,7 +123,13 @@ class CrossValidationReportPayload(ReportPayload):
PredictTimeTrainStd,
),
)
MEDIAS: ClassVar[tuple[Media, ...]] = cast(tuple[Media, ...], (EstimatorHtmlRepr,))
MEDIAS: ClassVar[tuple[Media, ...]] = cast(
tuple[Media, ...],
(
EstimatorHtmlRepr,
TableReport,
),
)

report: CrossValidationReport = Field(repr=False, exclude=True)

Expand Down Expand Up @@ -224,7 +229,6 @@ def estimators(self) -> list[EstimatorReportPayload]:
report=report,
upload=False,
key=f"{self.key}:estimator-report",
run_id=self.run_id,
)
for report in self.report.estimator_reports_
]
Expand Down
5 changes: 3 additions & 2 deletions skore-hub-project/src/skore_hub_project/report/report.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Class definition of the payload used to send a report to ``hub``."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import cached_property
from typing import ClassVar
from typing import ClassVar, cast

from pydantic import BaseModel, ConfigDict, Field, computed_field

Expand Down Expand Up @@ -127,7 +128,7 @@ def related_items(self) -> list[Media]:
"""
payloads = [
payload
for media in self.MEDIAS
for media in cast(list[Callable], self.MEDIAS)
if (payload := media(report=self.report)).representation is not None
]

Expand Down
3 changes: 2 additions & 1 deletion skore-hub-project/tests/unit/media/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def test_table_report(binary_classification, Media, data_source):
"n_rows",
"n_columns",
"n_constant_columns",
"extract",
"extract_head",
"extract_tail",
"columns",
"top_associations",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from skore import CrossValidationReport
from skore_hub_project import Project
from skore_hub_project.artefact.serializer import Serializer
from skore_hub_project.media import (
EstimatorHtmlRepr,
)
from skore_hub_project.media import EstimatorHtmlRepr
from skore_hub_project.media.data import TableReport
from skore_hub_project.metric import (
AccuracyTestMean,
AccuracyTestStd,
Expand Down Expand Up @@ -246,7 +245,10 @@ def test_metrics(self, payload):
]

def test_related_items(self, payload):
assert list(map(type, payload.related_items)) == [EstimatorHtmlRepr]
assert list(map(type, payload.related_items)) == [
EstimatorHtmlRepr,
TableReport,
]

@mark.usefixtures("monkeypatch_routes")
def test_model_dump(self, small_cv_binary_classification, payload):
Expand Down
Loading