Skip to content

Commit 3cabb19

Browse files
authored
fix(skore-hub-project): Remove adherence to skore for type checking (#2018)
Fixes conda-forge/skore-hub-project-feedstock#4.
1 parent 57e3653 commit 3cabb19

File tree

13 files changed

+89
-12
lines changed

13 files changed

+89
-12
lines changed

skore-hub-project/src/skore_hub_project/artefact/artefact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from typing import Any
88

99
from pydantic import BaseModel, ConfigDict, Field, computed_field
10-
from skore import CrossValidationReport, EstimatorReport
1110

1211
from skore_hub_project import Project
1312
from skore_hub_project.artefact.upload import upload
13+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
1414

1515

1616
class Artefact(ABC, BaseModel):

skore-hub-project/src/skore_hub_project/media/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from typing import Literal
77

88
from pydantic import Field, computed_field
9-
from skore import EstimatorReport
9+
10+
from skore_hub_project.protocol import EstimatorReport
1011

1112
from .media import Media, Representation
1213

skore-hub-project/src/skore_hub_project/media/feature_importance.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from pandas import DataFrame
99
from pydantic import Field, computed_field
10-
from skore import CrossValidationReport, EstimatorReport
10+
11+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
1112

1213
from .media import Media, Representation
1314

skore-hub-project/src/skore_hub_project/media/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from typing import Literal
55

66
from pydantic import Field, computed_field
7-
from skore import CrossValidationReport, EstimatorReport
7+
8+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
89

910
from .media import Media, Representation
1011

@@ -22,5 +23,5 @@ def representation(self) -> Representation: # noqa: D102
2223

2324
return Representation(
2425
media_type="text/html",
25-
value=sklearn.utils.estimator_html_repr(self.report.estimator_),
26+
value=sklearn.utils.estimator_html_repr(self.report.estimator),
2627
)

skore-hub-project/src/skore_hub_project/media/performance.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from matplotlib import pyplot as plt
1010
from pydantic import Field, computed_field
11-
from skore import CrossValidationReport, EstimatorReport
11+
12+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
1213

1314
from .media import Media, Representation
1415

skore-hub-project/src/skore_hub_project/metric/metric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from typing import Any, ClassVar, Literal, cast
1111

1212
from pydantic import BaseModel, ConfigDict, Field, computed_field
13-
from skore import CrossValidationReport, EstimatorReport
13+
14+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
1415

1516

1617
def cast_to_float(value: Any) -> float | None:

skore-hub-project/src/skore_hub_project/metric/timing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from pandas import DataFrame, Series
99
from pydantic import Field, computed_field
10-
from skore import CrossValidationReport, EstimatorReport
10+
11+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
1112

1213
from .metric import Metric, cast_to_float
1314

skore-hub-project/src/skore_hub_project/project/project.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from typing import TYPE_CHECKING
1010

1111
import joblib
12-
from skore import CrossValidationReport, EstimatorReport
1312

1413
from skore_hub_project.client.client import Client, HTTPStatusError, HUBClient
14+
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
1515

1616
if TYPE_CHECKING:
1717
from typing import TypedDict
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Protocols definition used to remove adherence to ``skore`` for type checking."""
2+
3+
from typing import Any, Protocol, runtime_checkable
4+
5+
6+
@runtime_checkable
7+
class EstimatorReport(Protocol):
8+
"""Protocol equivalent to ``skore.EstimatorReport``."""
9+
10+
_cache: Any
11+
metrics: Any
12+
data: Any
13+
ml_task: Any
14+
estimator: Any
15+
estimator_name_: Any
16+
X_train: Any
17+
y_train: Any
18+
X_test: Any
19+
y_test: Any
20+
fit: Any
21+
22+
23+
@runtime_checkable
24+
class CrossValidationReport(Protocol):
25+
"""Protocol equivalent to ``skore.CrossValidationReport``."""
26+
27+
_cache: Any
28+
metrics: Any
29+
estimator_reports_: Any
30+
ml_task: Any
31+
estimator: Any
32+
estimator_name_: Any
33+
X: Any
34+
y: Any
35+
splitter: Any
36+
split_indices: Any

skore-hub-project/src/skore_hub_project/report/cross_validation_report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pydantic import Field, computed_field
88
from sklearn.model_selection import BaseCrossValidator
99
from sklearn.model_selection._split import _CVIterableWrapper
10-
from skore import CrossValidationReport
1110

1211
from skore_hub_project.artefact import CrossValidationReportArtefact
1312
from skore_hub_project.media import (
@@ -55,6 +54,7 @@
5554
RocAucTrainStd,
5655
)
5756
from skore_hub_project.metric.metric import Metric
57+
from skore_hub_project.protocol import CrossValidationReport
5858
from skore_hub_project.report.estimator_report import EstimatorReportPayload
5959
from skore_hub_project.report.report import ReportPayload
6060

0 commit comments

Comments
 (0)