Skip to content

Commit ec974af

Browse files
committed
chore(skore-hub-project): Fix typings
1 parent 3aa7942 commit ec974af

File tree

25 files changed

+114
-64
lines changed

25 files changed

+114
-64
lines changed

skore-hub-project/pyproject.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,15 @@ convention = "numpy"
8888
"tests/*" = ["D"]
8989

9090
[tool.mypy]
91+
exclude = ["hatch/", "tests/"]
92+
93+
[[tool.mypy.overrides]]
94+
follow_untyped_imports = true
95+
module = ["skore.*"]
96+
97+
[[tool.mypy.overrides]]
9198
ignore_missing_imports = true
92-
exclude = ["hatch/*", "tests/*"]
99+
module = [
100+
"joblib.*",
101+
"sklearn.*",
102+
]

skore-hub-project/src/skore_hub_project/artifact/artifact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Artifact(BaseModel, ABC):
3232
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
3333

3434
project: Project = Field(repr=False, exclude=True)
35-
content_type: str
35+
content_type: str = Field(init=False)
3636

3737
@abstractmethod
3838
def content_to_upload(self) -> Content | AbstractContextManager[Content]:

skore-hub-project/src/skore_hub_project/artifact/media/data.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
from typing import Literal
44

5-
from pydantic import Field
6-
75
from skore_hub_project import switch_mpl_backend
8-
from skore_hub_project.artifact.media.media import Media
6+
from skore_hub_project.artifact.media.media import Media, Report
97
from skore_hub_project.protocol import EstimatorReport
108

119

12-
class TableReport(Media): # noqa: D101
10+
class TableReport(Media[Report]): # noqa: D101
1311
name: Literal["table_report"] = "table_report"
12+
data_source: Literal["train", "test"] | None = None
1413
content_type: Literal["application/vnd.skrub.table-report.v1+json"] = (
1514
"application/vnd.skrub.table-report.v1+json"
1615
)
@@ -41,11 +40,9 @@ def content_to_upload(self) -> bytes: # noqa: D102
4140
)
4241

4342

44-
class TableReportTrain(TableReport): # noqa: D101
45-
report: EstimatorReport = Field(repr=False, exclude=True)
43+
class TableReportTrain(TableReport[EstimatorReport]): # noqa: D101
4644
data_source: Literal["train"] = "train"
4745

4846

49-
class TableReportTest(TableReport): # noqa: D101
50-
report: EstimatorReport = Field(repr=False, exclude=True)
47+
class TableReportTest(TableReport[EstimatorReport]): # noqa: D101
5148
data_source: Literal["test"] = "test"

skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
from functools import reduce
66
from typing import ClassVar, Literal, cast
77

8-
from pydantic import Field
9-
10-
from skore_hub_project.artifact.media.media import Media
8+
from skore_hub_project.artifact.media.media import Media, Report
119
from skore_hub_project.protocol import EstimatorReport
1210

1311

14-
class FeatureImportance(Media, ABC): # noqa: D101
12+
class FeatureImportance(Media[Report], ABC): # noqa: D101
1513
accessor: ClassVar[str]
1614
content_type: Literal["application/vnd.dataframe"] = "application/vnd.dataframe"
1715

@@ -41,8 +39,7 @@ def content_to_upload(self) -> bytes | None: # noqa: D102
4139
)
4240

4341

44-
class Permutation(FeatureImportance, ABC): # noqa: D101
45-
report: EstimatorReport = Field(repr=False, exclude=True)
42+
class Permutation(FeatureImportance[EstimatorReport], ABC): # noqa: D101
4643
accessor: ClassVar[str] = "feature_importance.permutation"
4744
name: Literal["permutation"] = "permutation"
4845

@@ -55,12 +52,13 @@ class PermutationTest(Permutation): # noqa: D101
5552
data_source: Literal["test"] = "test"
5653

5754

58-
class MeanDecreaseImpurity(FeatureImportance): # noqa: D101
59-
report: EstimatorReport = Field(repr=False, exclude=True)
55+
class MeanDecreaseImpurity(FeatureImportance[EstimatorReport]): # noqa: D101
6056
accessor: ClassVar[str] = "feature_importance.mean_decrease_impurity"
6157
name: Literal["mean_decrease_impurity"] = "mean_decrease_impurity"
58+
data_source: None = None
6259

6360

6461
class Coefficients(FeatureImportance): # noqa: D101
6562
accessor: ClassVar[str] = "feature_importance.coefficients"
6663
name: Literal["coefficients"] = "coefficients"
64+
data_source: None = None
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
"""Class definition of the payload used to associate a media with the report."""
22

3+
from abc import ABC
4+
from typing import Generic, TypeVar
5+
36
from pydantic import Field
47

58
from skore_hub_project.artifact.artifact import Artifact
69
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
710

11+
Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport))
12+
813

9-
class Media(Artifact):
14+
class Media(Artifact, Generic[Report], ABC):
1015
"""
1116
Payload used to associate a media with the report.
1217
@@ -17,13 +22,13 @@ class Media(Artifact):
1722
content_type : str
1823
The content-type of the artifact content.
1924
report : EstimatorReport | CrossValidationReport
20-
The report to pickled.
25+
The report on which compute the media.
2126
name : str
2227
The name of the media.
2328
data_source : str | None
2429
The source of the data used to generate the media.
2530
"""
2631

27-
report: EstimatorReport | CrossValidationReport = Field(repr=False, exclude=True)
28-
name: str
29-
data_source: str | None = None
32+
report: Report = Field(repr=False, exclude=True)
33+
name: str = Field(init=False)
34+
data_source: str | None = Field(init=False)

skore-hub-project/src/skore_hub_project/artifact/media/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
class EstimatorHtmlRepr(Media): # noqa: D101
99
name: Literal["estimator_html_repr"] = "estimator_html_repr"
10+
data_source: None = None
1011
content_type: Literal["text/html"] = "text/html"
1112

1213
def content_to_upload(self) -> str: # noqa: D102

skore-hub-project/src/skore_hub_project/client/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime
44
from time import sleep
55

6-
from httpx import HTTPError, TimeoutException
6+
from httpx import HTTPStatusError, TimeoutException
77

88

99
def get_oauth_device_login(success_uri: str | None = None):
@@ -125,7 +125,7 @@ def get_oauth_device_code_probe(device_code: str, *, timeout=600):
125125
while True:
126126
try:
127127
client.get(url, params=params)
128-
except HTTPError as exc:
128+
except HTTPStatusError as exc:
129129
if exc.response.status_code != 400:
130130
raise
131131

skore-hub-project/src/skore_hub_project/client/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Client(HTTPXClient):
5858
)
5959
)
6060

61-
RETRYABLE_EXCEPTIONS: Final[tuple[HTTPError, ...]] = (
61+
RETRYABLE_EXCEPTIONS: Final[tuple[type[HTTPError], ...]] = (
6262
TimeoutException,
6363
NetworkError,
6464
RemoteProtocolError,

skore-hub-project/src/skore_hub_project/metric/accuracy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class Accuracy(EstimatorReportMetric): # noqa: D101
1212
name: Literal["accuracy"] = "accuracy"
1313
verbose_name: Literal["Accuracy"] = "Accuracy"
1414
greater_is_better: Literal[True] = True
15+
position: None = None
1516

1617

1718
class AccuracyTrain(Accuracy): # noqa: D101
@@ -28,6 +29,7 @@ class AccuracyMean(CrossValidationReportMetric): # noqa: D101
2829
name: Literal["accuracy_mean"] = "accuracy_mean"
2930
verbose_name: Literal["Accuracy - MEAN"] = "Accuracy - MEAN"
3031
greater_is_better: Literal[True] = True
32+
position: None = None
3133

3234

3335
class AccuracyTrainMean(AccuracyMean): # noqa: D101
@@ -44,6 +46,7 @@ class AccuracyStd(CrossValidationReportMetric): # noqa: D101
4446
name: Literal["accuracy_std"] = "accuracy_std"
4547
verbose_name: Literal["Accuracy - STD"] = "Accuracy - STD"
4648
greater_is_better: Literal[False] = False
49+
position: None = None
4750

4851

4952
class AccuracyTrainStd(AccuracyStd): # noqa: D101

skore-hub-project/src/skore_hub_project/metric/brier_score.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class BrierScore(EstimatorReportMetric): # noqa: D101
1212
name: Literal["brier_score"] = "brier_score"
1313
verbose_name: Literal["Brier score"] = "Brier score"
1414
greater_is_better: Literal[False] = False
15+
position: None = None
1516

1617

1718
class BrierScoreTrain(BrierScore): # noqa: D101
@@ -28,6 +29,7 @@ class BrierScoreMean(CrossValidationReportMetric): # noqa: D101
2829
name: Literal["brier_score_mean"] = "brier_score_mean"
2930
verbose_name: Literal["Brier score - MEAN"] = "Brier score - MEAN"
3031
greater_is_better: Literal[False] = False
32+
position: None = None
3133

3234

3335
class BrierScoreTrainMean(BrierScoreMean): # noqa: D101
@@ -44,6 +46,7 @@ class BrierScoreStd(CrossValidationReportMetric): # noqa: D101
4446
name: Literal["brier_score_std"] = "brier_score_std"
4547
verbose_name: Literal["Brier score - STD"] = "Brier score - STD"
4648
greater_is_better: Literal[False] = False
49+
position: None = None
4750

4851

4952
class BrierScoreTrainStd(BrierScoreStd): # noqa: D101

0 commit comments

Comments
 (0)