Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion skore-hub-project/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,15 @@ convention = "numpy"
"tests/*" = ["D"]

[tool.mypy]
exclude = ["hatch/", "tests/"]

[[tool.mypy.overrides]]
follow_untyped_imports = true
module = ["skore.*"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
exclude = ["hatch/*", "tests/*"]
module = [
"joblib.*",
"sklearn.*",
]
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Artifact(BaseModel, ABC):
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

project: Project = Field(repr=False, exclude=True)
content_type: str
content_type: str = Field(init=False)

@abstractmethod
def content_to_upload(self) -> Content | AbstractContextManager[Content]:
Expand Down
13 changes: 5 additions & 8 deletions skore-hub-project/src/skore_hub_project/artifact/media/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from typing import Literal

from pydantic import Field

from skore_hub_project import switch_mpl_backend
from skore_hub_project.artifact.media.media import Media
from skore_hub_project.artifact.media.media import Media, Report
from skore_hub_project.protocol import EstimatorReport


class TableReport(Media): # noqa: D101
class TableReport(Media[Report]): # noqa: D101
name: Literal["table_report"] = "table_report"
data_source: Literal["train", "test"] | None = None
content_type: Literal["application/vnd.skrub.table-report.v1+json"] = (
"application/vnd.skrub.table-report.v1+json"
)
Expand Down Expand Up @@ -41,11 +40,9 @@ def content_to_upload(self) -> bytes: # noqa: D102
)


class TableReportTrain(TableReport): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
class TableReportTrain(TableReport[EstimatorReport]): # noqa: D101
data_source: Literal["train"] = "train"


class TableReportTest(TableReport): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
class TableReportTest(TableReport[EstimatorReport]): # noqa: D101
data_source: Literal["test"] = "test"
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
from functools import reduce
from typing import ClassVar, Literal, cast

from pydantic import Field

from skore_hub_project.artifact.media.media import Media
from skore_hub_project.artifact.media.media import Media, Report
from skore_hub_project.protocol import EstimatorReport


class FeatureImportance(Media, ABC): # noqa: D101
class FeatureImportance(Media[Report], ABC): # noqa: D101
accessor: ClassVar[str]
content_type: Literal["application/vnd.dataframe"] = "application/vnd.dataframe"

Expand Down Expand Up @@ -41,8 +39,7 @@ def content_to_upload(self) -> bytes | None: # noqa: D102
)


class Permutation(FeatureImportance, ABC): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
class Permutation(FeatureImportance[EstimatorReport], ABC): # noqa: D101
accessor: ClassVar[str] = "feature_importance.permutation"
name: Literal["permutation"] = "permutation"

Expand All @@ -55,12 +52,13 @@ class PermutationTest(Permutation): # noqa: D101
data_source: Literal["test"] = "test"


class MeanDecreaseImpurity(FeatureImportance): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
class MeanDecreaseImpurity(FeatureImportance[EstimatorReport]): # noqa: D101
accessor: ClassVar[str] = "feature_importance.mean_decrease_impurity"
name: Literal["mean_decrease_impurity"] = "mean_decrease_impurity"
data_source: None = None


class Coefficients(FeatureImportance): # noqa: D101
accessor: ClassVar[str] = "feature_importance.coefficients"
name: Literal["coefficients"] = "coefficients"
data_source: None = None
15 changes: 10 additions & 5 deletions skore-hub-project/src/skore_hub_project/artifact/media/media.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Class definition of the payload used to associate a media with the report."""

from abc import ABC
from typing import Generic, TypeVar

from pydantic import Field

from skore_hub_project.artifact.artifact import Artifact
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport

Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport))


class Media(Artifact):
class Media(Artifact, Generic[Report], ABC):
"""
Payload used to associate a media with the report.

Expand All @@ -17,13 +22,13 @@ class Media(Artifact):
content_type : str
The content-type of the artifact content.
report : EstimatorReport | CrossValidationReport
The report to pickled.
The report on which compute the media.
name : str
The name of the media.
data_source : str | None
The source of the data used to generate the media.
"""

report: EstimatorReport | CrossValidationReport = Field(repr=False, exclude=True)
name: str
data_source: str | None = None
report: Report = Field(repr=False, exclude=True)
name: str = Field(init=False)
data_source: str | None = Field(init=False)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class EstimatorHtmlRepr(Media): # noqa: D101
name: Literal["estimator_html_repr"] = "estimator_html_repr"
data_source: None = None
content_type: Literal["text/html"] = "text/html"

def content_to_upload(self) -> str: # noqa: D102
Expand Down
4 changes: 2 additions & 2 deletions skore-hub-project/src/skore_hub_project/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from time import sleep

from httpx import HTTPError, TimeoutException
from httpx import HTTPStatusError, TimeoutException


def get_oauth_device_login(success_uri: str | None = None):
Expand Down Expand Up @@ -125,7 +125,7 @@ def get_oauth_device_code_probe(device_code: str, *, timeout=600):
while True:
try:
client.get(url, params=params)
except HTTPError as exc:
except HTTPStatusError as exc:
if exc.response.status_code != 400:
raise

Expand Down
2 changes: 1 addition & 1 deletion skore-hub-project/src/skore_hub_project/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Client(HTTPXClient):
)
)

RETRYABLE_EXCEPTIONS: Final[tuple[HTTPError, ...]] = (
RETRYABLE_EXCEPTIONS: Final[tuple[type[HTTPError], ...]] = (
TimeoutException,
NetworkError,
RemoteProtocolError,
Expand Down
3 changes: 3 additions & 0 deletions skore-hub-project/src/skore_hub_project/metric/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class Accuracy(EstimatorReportMetric): # noqa: D101
name: Literal["accuracy"] = "accuracy"
verbose_name: Literal["Accuracy"] = "Accuracy"
greater_is_better: Literal[True] = True
position: None = None


class AccuracyTrain(Accuracy): # noqa: D101
Expand All @@ -28,6 +29,7 @@ class AccuracyMean(CrossValidationReportMetric): # noqa: D101
name: Literal["accuracy_mean"] = "accuracy_mean"
verbose_name: Literal["Accuracy - MEAN"] = "Accuracy - MEAN"
greater_is_better: Literal[True] = True
position: None = None


class AccuracyTrainMean(AccuracyMean): # noqa: D101
Expand All @@ -44,6 +46,7 @@ class AccuracyStd(CrossValidationReportMetric): # noqa: D101
name: Literal["accuracy_std"] = "accuracy_std"
verbose_name: Literal["Accuracy - STD"] = "Accuracy - STD"
greater_is_better: Literal[False] = False
position: None = None


class AccuracyTrainStd(AccuracyStd): # noqa: D101
Expand Down
3 changes: 3 additions & 0 deletions skore-hub-project/src/skore_hub_project/metric/brier_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class BrierScore(EstimatorReportMetric): # noqa: D101
name: Literal["brier_score"] = "brier_score"
verbose_name: Literal["Brier score"] = "Brier score"
greater_is_better: Literal[False] = False
position: None = None


class BrierScoreTrain(BrierScore): # noqa: D101
Expand All @@ -28,6 +29,7 @@ class BrierScoreMean(CrossValidationReportMetric): # noqa: D101
name: Literal["brier_score_mean"] = "brier_score_mean"
verbose_name: Literal["Brier score - MEAN"] = "Brier score - MEAN"
greater_is_better: Literal[False] = False
position: None = None


class BrierScoreTrainMean(BrierScoreMean): # noqa: D101
Expand All @@ -44,6 +46,7 @@ class BrierScoreStd(CrossValidationReportMetric): # noqa: D101
name: Literal["brier_score_std"] = "brier_score_std"
verbose_name: Literal["Brier score - STD"] = "Brier score - STD"
greater_is_better: Literal[False] = False
position: None = None


class BrierScoreTrainStd(BrierScoreStd): # noqa: D101
Expand Down
1 change: 1 addition & 0 deletions skore-hub-project/src/skore_hub_project/metric/log_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class LogLossStd(CrossValidationReportMetric): # noqa: D101
name: str = "log_loss_std"
verbose_name: str = "Log loss - STD"
greater_is_better: bool = False
position: None = None


class LogLossTrainStd(LogLossStd): # noqa: D101
Expand Down
33 changes: 18 additions & 15 deletions skore-hub-project/src/skore_hub_project/metric/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from contextlib import suppress
from functools import cached_property, reduce
from math import isfinite
from typing import Any, ClassVar, Literal, cast
from typing import Any, ClassVar, Generic, Literal, TypeVar, cast

from pydantic import BaseModel, ConfigDict, Field, computed_field

from skore_hub_project.protocol import CrossValidationReport, EstimatorReport

Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport))


def cast_to_float(value: Any) -> float | None:
"""Cast value to float."""
Expand All @@ -23,12 +25,14 @@ def cast_to_float(value: Any) -> float | None:
return None


class Metric(ABC, BaseModel):
class Metric(BaseModel, Generic[Report], ABC):
"""
Payload used to send a metric to ``hub``.

Attributes
----------
report : EstimatorReport | CrossValidationReport
The report on which compute the metric.
name : str
Name of the metric.
verbose_name : str
Expand All @@ -44,11 +48,12 @@ class Metric(ABC, BaseModel):

model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

name: str
verbose_name: str
data_source: Literal["train", "test"] | None = None
greater_is_better: bool | None = None
position: int | None = None
report: Report = Field(repr=False, exclude=True)
name: str = Field(init=False)
verbose_name: str = Field(init=False)
data_source: Literal["train", "test"] | None = Field(init=False)
greater_is_better: bool | None = Field(init=False)
position: int | None = Field(init=False)

@computed_field # type: ignore[prop-decorator]
@property
Expand All @@ -57,12 +62,14 @@ def value(self) -> float | None:
"""The value of the metric."""


class EstimatorReportMetric(Metric):
class EstimatorReportMetric(Metric[EstimatorReport]):
"""
Payload used to send an estimator report metric.

Attributes
----------
report: EstimatorReport
The report on which compute the metric.
name : str
Name of the metric.
verbose_name : str
Expand All @@ -74,13 +81,10 @@ class EstimatorReportMetric(Metric):
position: int | None, optional
Indicator of the "position" of the metric in the parallel coordinates plot,
default None to disable its display.
report: EstimatorReport
The report on which compute the metric.
accessor : ClassVar[str]
The "accessor" of the metric i.e., the path to the metric calculation function.
"""

report: EstimatorReport = Field(repr=False, exclude=True)
accessor: ClassVar[str]

@computed_field # type: ignore[prop-decorator]
Expand All @@ -98,12 +102,14 @@ def value(self) -> float | None:
return cast_to_float(function(data_source=self.data_source))


class CrossValidationReportMetric(Metric):
class CrossValidationReportMetric(Metric[CrossValidationReport]):
"""
Payload used to send a cross-validation report metric, usually MEAN or STD.

Attributes
----------
report: CrossValidationReport
The report on which compute the metric.
name : str
Name of the metric.
verbose_name : str
Expand All @@ -115,15 +121,12 @@ class CrossValidationReportMetric(Metric):
position: int | None, optional
Indicator of the "position" of the metric in the parallel coordinates plot,
default None to disable its display.
report: CrossValidationReport
The report on which compute the metric.
accessor : ClassVar[str]
The "accessor" of the metric i.e., the path to the metric calculation function.
aggregate : ClassVar[Literal["mean", "std"]]
The aggregation parameter passed to the ``accessor``.
"""

report: CrossValidationReport = Field(repr=False, exclude=True)
accessor: ClassVar[str]
aggregate: ClassVar[Literal["mean", "std"]]

Expand Down
3 changes: 3 additions & 0 deletions skore-hub-project/src/skore_hub_project/metric/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Precision(EstimatorReportMetric): # noqa: D101
name: str = "precision"
verbose_name: str = "Precision (macro)"
greater_is_better: bool = True
position: None = None

@computed_field # type: ignore[prop-decorator]
@cached_property
Expand All @@ -41,6 +42,7 @@ class PrecisionMean(CrossValidationReportMetric): # noqa: D101
name: str = "precision_mean"
verbose_name: str = "Precision (macro) - MEAN"
greater_is_better: bool = True
position: None = None

@computed_field # type: ignore[prop-decorator]
@cached_property
Expand Down Expand Up @@ -71,6 +73,7 @@ class PrecisionStd(CrossValidationReportMetric): # noqa: D101
name: str = "precision_std"
verbose_name: str = "Precision (macro) - STD"
greater_is_better: bool = False
position: None = None

@computed_field # type: ignore[prop-decorator]
@cached_property
Expand Down
3 changes: 3 additions & 0 deletions skore-hub-project/src/skore_hub_project/metric/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class R2(EstimatorReportMetric): # noqa: D101
name: str = "r2"
verbose_name: str = "R²"
greater_is_better: bool = True
position: None = None


class R2Train(R2): # noqa: D101
Expand All @@ -28,6 +29,7 @@ class R2Mean(CrossValidationReportMetric): # noqa: D101
name: str = "r2_mean"
verbose_name: str = "R² - MEAN"
greater_is_better: bool = True
position: None = None


class R2TrainMean(R2Mean): # noqa: D101
Expand All @@ -44,6 +46,7 @@ class R2Std(CrossValidationReportMetric): # noqa: D101
name: str = "r2_std"
verbose_name: str = "R² - STD"
greater_is_better: bool = False
position: None = None


class R2TrainStd(R2Std): # noqa: D101
Expand Down
Loading
Loading