Skip to content

Commit 2e25be6

Browse files
authored
fix(skore-hub-project,skore-local-project): Remove adherence to private skore modules (#1903)
Needed for #1900 to be merged. Knowing that `skore-hub-project` and `skore-local-project` are always used in combination of `skore`, we can compare objects type using directly `EstimatorReport` class instead of diving into the MRO.
1 parent 2f7f1ab commit 2e25be6

File tree

5 files changed

+20
-24
lines changed

5 files changed

+20
-24
lines changed

skore-hub-project/src/skore_hub_project/item/skore_estimator_report_item.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from operator import attrgetter
1414
from typing import TYPE_CHECKING
1515

16-
from .item import ItemTypeError, lazy_is_instance, switch_mpl_backend
16+
from .item import ItemTypeError, switch_mpl_backend
1717
from .matplotlib_figure_item import MatplotlibFigureItem
1818
from .media_item import MediaItem
1919
from .pandas_dataframe_item import PandasDataFrameItem
@@ -23,7 +23,7 @@
2323
from collections.abc import Generator
2424
from typing import Any, Literal, TypedDict
2525

26-
from skore.sklearn import EstimatorReport
26+
from skore import EstimatorReport
2727

2828
class MetadataFunction: # noqa: D101
2929
metadata: Any
@@ -416,7 +416,9 @@ def factory(cls, value: EstimatorReport, /) -> SkoreEstimatorReportItem:
416416
ItemTypeError
417417
If ``value`` is not an instance of ``skore.EstimatorReport``.
418418
"""
419-
if lazy_is_instance(value, "skore.sklearn._estimator.report.EstimatorReport"):
419+
from skore import EstimatorReport
420+
421+
if isinstance(value, EstimatorReport):
420422
return super().factory(value)
421423

422424
raise ItemTypeError(f"Type '{value.__class__}' is not supported.")

skore-hub-project/src/skore_hub_project/project/project.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99

1010
from .. import item as item_module
1111
from ..client.client import AuthenticatedClient, HTTPStatusError
12-
from ..item.item import lazy_is_instance
1312

1413
if TYPE_CHECKING:
1514
from typing import TypedDict
1615

17-
from skore.sklearn import EstimatorReport
16+
from skore import EstimatorReport
1817

1918
class Metadata(TypedDict): # noqa: D101
2019
id: str
@@ -129,9 +128,9 @@ def put(self, key: str, report: EstimatorReport):
129128
if not isinstance(key, str):
130129
raise TypeError(f"Key must be a string (found '{type(key)}')")
131130

132-
if not lazy_is_instance(
133-
report, "skore.sklearn._estimator.report.EstimatorReport"
134-
):
131+
from skore import EstimatorReport
132+
133+
if not isinstance(report, EstimatorReport):
135134
raise TypeError(
136135
f"Report must be a `skore.EstimatorReport` (found '{type(report)}')"
137136
)

skore-hub-project/tests/unit/item/test_skore_estimator_report_item.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,15 @@ def test_metadata(self, monkeypatch, report, report_b64_str):
121121
item2 = SkoreEstimatorReportItem(report_b64_str)
122122

123123
monkeypatch.setattr(
124-
"skore.sklearn._estimator.metrics_accessor._MetricsAccessor.r2",
124+
"skore.EstimatorReport.metrics.r2",
125125
lambda self, data_source: float(data_source == "train"),
126126
)
127127
monkeypatch.setattr(
128-
"skore.sklearn._estimator.metrics_accessor._MetricsAccessor.rmse",
128+
"skore.EstimatorReport.metrics.rmse",
129129
lambda self, data_source: float("nan"),
130130
)
131131
monkeypatch.setattr(
132-
"skore.sklearn._estimator.metrics_accessor._MetricsAccessor.timings",
132+
"skore.EstimatorReport.metrics.timings",
133133
lambda self: {
134134
"fit_time": hash("fit_time"),
135135
"predict_time_test": hash("predict_time_test"),

skore-local-project/src/skore_local_project/project.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
from .storage import DiskCacheStorage
1919

2020
if TYPE_CHECKING:
21-
from typing import Any, TypedDict
21+
from typing import TypedDict
2222

23-
from skore.sklearn import EstimatorReport
23+
from skore import EstimatorReport
2424

2525
class PersistedMetadata: # noqa: D101
2626
artifact_id: str
@@ -52,13 +52,6 @@ class Metadata(TypedDict): # noqa: D101
5252
predict_time: float
5353

5454

55-
def lazy_is_instance_skore_estimator_report(value: Any) -> bool:
56-
"""Return True if value is an instance of ``skore.EstimatorReport``."""
57-
return "skore.sklearn._estimator.report.EstimatorReport" in {
58-
f"{cls.__module__}.{cls.__name__}" for cls in value.__class__.__mro__
59-
}
60-
61-
6255
class Project:
6356
r"""
6457
API to manage a collection of key-report pairs persisted in a local storage.
@@ -214,7 +207,9 @@ def put(self, key: str, report: EstimatorReport):
214207
if not isinstance(key, str):
215208
raise TypeError(f"Key must be a string (found '{type(key)}')")
216209

217-
if not lazy_is_instance_skore_estimator_report(report):
210+
from skore import EstimatorReport
211+
212+
if not isinstance(report, EstimatorReport):
218213
raise TypeError(
219214
f"Report must be a `skore.EstimatorReport` (found '{type(report)}')"
220215
)

skore-local-project/tests/unit/test_project.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn.datasets import make_classification, make_regression
77
from sklearn.linear_model import LinearRegression, LogisticRegression
88
from sklearn.model_selection import train_test_split
9-
from skore.sklearn import EstimatorReport
9+
from skore import EstimatorReport
1010
from skore_local_project import Project
1111
from skore_local_project.storage import DiskCacheStorage
1212

@@ -49,11 +49,11 @@ def monkeypatch_datetime(self, monkeypatch, Datetime):
4949
@fixture(autouse=True)
5050
def monkeypatch_metrics(self, monkeypatch, Datetime):
5151
monkeypatch.setattr(
52-
"skore.sklearn._estimator.metrics_accessor._MetricsAccessor.rmse",
52+
"skore.EstimatorReport.metrics.rmse",
5353
lambda _, data_source: float(hash(f"<rmse_{data_source}>")),
5454
)
5555
monkeypatch.setattr(
56-
"skore.sklearn._estimator.metrics_accessor._MetricsAccessor.timings",
56+
"skore.EstimatorReport.metrics.timings",
5757
lambda self: {
5858
"fit_time": float(hash("<fit_time>")),
5959
"predict_time_test": float(hash("<predict_time_test>")),

0 commit comments

Comments
 (0)