Skip to content

Commit b2d6933

Browse files
authored
feat(skore/project)!: Put/get/summarize CrossValidationReport (#2025)
Closes #1944 and #1923.
1 parent c87f9c4 commit b2d6933

File tree

5 files changed

+60
-53
lines changed

5 files changed

+60
-53
lines changed

skore/pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ dependencies = [
3737
"plotly",
3838
"rich",
3939
"scikit-learn",
40-
"skore-local-project",
40+
"skore[local]",
4141
"skrub",
4242
"seaborn",
4343
]
@@ -79,7 +79,8 @@ dev = [
7979
"nbformat",
8080
"ipykernel",
8181
]
82-
hub = ["skore-hub-project"]
82+
local = ["skore-local-project>=0.0.4"]
83+
hub = ["skore-hub-project>=0.0.10"]
8384

8485
[project.urls]
8586
Homepage = "https://probabl.ai"

skore/src/skore/project/project.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from importlib.metadata import entry_points
77
from typing import Any, Literal
88

9-
from skore import EstimatorReport
9+
from skore import CrossValidationReport, EstimatorReport
1010
from skore.project.summary import Summary
1111

1212

@@ -95,7 +95,7 @@ class Project:
9595
>>> from sklearn.datasets import make_classification, make_regression
9696
>>> from sklearn.linear_model import LinearRegression, LogisticRegression
9797
>>> from sklearn.model_selection import train_test_split
98-
>>> from skore import EstimatorReport
98+
>>> from skore import CrossValidationReport, EstimatorReport
9999
>>>
100100
>>> X, y = make_classification(random_state=42)
101101
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
@@ -118,6 +118,8 @@ class Project:
118118
>>> X_test=X_test,
119119
>>> y_test=y_test,
120120
>>> )
121+
>>>
122+
>>> cv_regressor_report = CrossValidationReport(regressor, X, y)
121123
122124
Construct the project in local mode, persisted in a temporary directory.
123125
@@ -132,6 +134,7 @@ class Project:
132134
133135
>>> local_project.put("my-simple-classification", classifier_report)
134136
>>> local_project.put("my-simple-regression", regressor_report)
137+
>>> local_project.put("my-simple-cv_regression", cv_regressor_report)
135138
136139
Investigate metadata/metrics to filter the best reports.
137140
@@ -153,6 +156,7 @@ def __setup_plugin(name: str) -> tuple[Literal["local", "hub"], str, Any, dict]:
153156
raise SystemError("No project plugin found, please install at least one.")
154157

155158
mode: Literal["local", "hub"]
159+
156160
if match := re.match(Project.__HUB_NAME_PATTERN, name):
157161
mode = "hub"
158162
name = match["name"]
@@ -162,10 +166,7 @@ def __setup_plugin(name: str) -> tuple[Literal["local", "hub"], str, Any, dict]:
162166
parameters = {"name": name}
163167

164168
if mode not in PLUGINS.names:
165-
raise ValueError(
166-
f"Unknown mode `{mode}`. "
167-
f"Please install the `skore-{mode}-project` python package."
168-
)
169+
raise ValueError(f"Unknown mode `{mode}`. Please install `skore[{mode}]`.")
169170

170171
return mode, name, PLUGINS[mode].load(), parameters
171172

@@ -214,7 +215,7 @@ def name(self):
214215
"""The name of the project."""
215216
return self.__name
216217

217-
def put(self, key: str, report: EstimatorReport):
218+
def put(self, key: str, report: EstimatorReport | CrossValidationReport):
218219
"""
219220
Put a key-report pair to the project.
220221
@@ -225,7 +226,7 @@ def put(self, key: str, report: EstimatorReport):
225226
----------
226227
key : str
227228
The key to associate with ``report`` in the project.
228-
report : skore.EstimatorReport
229+
report : EstimatorReport | CrossValidationReport
229230
The report to associate with ``key`` in the project.
230231
231232
Raises
@@ -236,19 +237,20 @@ def put(self, key: str, report: EstimatorReport):
236237
if not isinstance(key, str):
237238
raise TypeError(f"Key must be a string (found '{type(key)}')")
238239

239-
if not isinstance(report, EstimatorReport):
240+
if not isinstance(report, EstimatorReport | CrossValidationReport):
240241
raise TypeError(
241-
f"Report must be a `skore.EstimatorReport` (found '{type(report)}')"
242+
f"Report must be `EstimatorReport` or `CrossValidationReport` "
243+
f"(found '{type(report)}')"
242244
)
243245

244246
return self.__project.put(key=key, report=report)
245247

246-
def get(self, id: str) -> EstimatorReport:
248+
def get(self, id: str) -> EstimatorReport | CrossValidationReport:
247249
"""
248250
Get a persisted report by its id.
249251
250-
Report IDs can be found via :meth:`skore.Project.summarize`, which is
251-
also the preferred method of interacting with a ``skore.Project``.
252+
Report IDs can be found via :meth:`skore.Project.summarize`, which is also the
253+
preferred method of interacting with a ``skore.Project``.
252254
253255
Parameters
254256
----------
@@ -260,7 +262,7 @@ def get(self, id: str) -> EstimatorReport:
260262
KeyError
261263
If a non-existent ID is passed.
262264
"""
263-
return self.__project.reports.get(id)
265+
return self.__project.get(id)
264266

265267
def summarize(self) -> Summary:
266268
"""Obtain metadata/metrics for all persisted reports."""

skore/src/skore/project/summary.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,13 @@
66

77
from pandas import Categorical, DataFrame, Index, MultiIndex, RangeIndex
88

9-
from skore._sklearn import ComparisonReport
9+
from skore import ComparisonReport
1010
from skore.project.widget import ModelExplorerWidget
1111

1212
if TYPE_CHECKING:
1313
from typing import Literal
1414

15-
from skore._sklearn import EstimatorReport
15+
from skore import CrossValidationReport, EstimatorReport
1616

1717

1818
class Summary(DataFrame):
@@ -45,15 +45,15 @@ def factory(project, /):
4545
4646
Parameters
4747
----------
48-
project : Union[skore_local_project.Project, skore_hub_project.Project]
48+
project : ``skore_local_project.Project`` | ``skore_hub_project.Project``
4949
The project from which the summary object is to be constructed.
5050
5151
Notes
5252
-----
5353
This function is not intended for direct use. Instead simply use the accessor
5454
:meth:`skore.Project.summarize`.
5555
"""
56-
summary = DataFrame(project.reports.metadata(), copy=False)
56+
summary = DataFrame(project.summarize(), copy=False)
5757

5858
if not summary.empty:
5959
summary["learner"] = Categorical(summary["learner"])
@@ -79,7 +79,7 @@ def reports(
7979
*,
8080
filter: bool = True,
8181
return_as: Literal["list", "comparison"] = "list",
82-
) -> list[EstimatorReport] | ComparisonReport:
82+
) -> list[EstimatorReport | CrossValidationReport] | ComparisonReport:
8383
"""
8484
Return the reports referenced by the summary object from the project.
8585
@@ -100,9 +100,7 @@ def reports(
100100
if filter and (querystr := self._query_string_selection()):
101101
self = self.query(querystr)
102102

103-
reports = [
104-
self.project.reports.get(id) for id in self.index.get_level_values("id")
105-
]
103+
reports = [self.project.get(id) for id in self.index.get_level_values("id")]
106104

107105
if return_as == "comparison":
108106
try:

skore/tests/unit/project/test_project.py

Lines changed: 27 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
from importlib.metadata import EntryPoint, EntryPoints
2+
from re import escape
23
from unittest.mock import Mock
34

45
from pandas import DataFrame, MultiIndex, Series
5-
from pytest import fixture, raises
6+
from pytest import fixture, mark, param, raises
67
from sklearn.datasets import make_regression
78
from sklearn.linear_model import LinearRegression
89
from sklearn.model_selection import train_test_split
9-
from skore import EstimatorReport, Project
10+
from skore import CrossValidationReport, EstimatorReport, Project
1011
from skore.project.summary import Summary
1112

1213

@@ -47,7 +48,7 @@ def monkeypatch_entrypoints(monkeypatch, request, FakeLocalProject, FakeHubProje
4748

4849

4950
@fixture(scope="module")
50-
def regression():
51+
def regression() -> EstimatorReport:
5152
X, y = make_regression(random_state=42)
5253
X_train, X_test, y_train, y_test = train_test_split(
5354
X, y, test_size=0.2, random_state=42
@@ -62,6 +63,11 @@ def regression():
6263
)
6364

6465

66+
@fixture(scope="module")
67+
def cv_regression() -> CrossValidationReport:
68+
return CrossValidationReport(LinearRegression(), *make_regression(random_state=42))
69+
70+
6571
class TestProject:
6672
def test_init_local(self, FakeLocalProject):
6773
project = Project("<name>", workspace="<workspace>")
@@ -115,10 +121,7 @@ def test_init_exception_unknown_plugin(self, monkeypatch, tmp_path):
115121

116122
with raises(
117123
ValueError,
118-
match=(
119-
"Unknown mode `local`. "
120-
"Please install the `skore-local-project` python package."
121-
),
124+
match=escape("Unknown mode `local`. Please install `skore[local]`."),
122125
):
123126
Project("<name>")
124127

@@ -130,24 +133,32 @@ def test_name(self):
130133
assert Project("<name>").name == "<name>"
131134
assert Project("hub://<tenant>/<name>").name == "<name>"
132135

133-
def test_put(self, regression, FakeLocalProject):
136+
@mark.parametrize(
137+
"report",
138+
(
139+
param("regression", id="EstimatorReport - regression"),
140+
param("cv_regression", id="CrossValidationReport - regression"),
141+
),
142+
)
143+
def test_put(self, report, FakeLocalProject, request):
144+
report = request.getfixturevalue(report)
134145
project = Project("<name>")
135146

136-
project.put("<key>", regression)
147+
project.put("<key>", report)
137148

138149
assert FakeLocalProject.called
139150
assert project._Project__project.put.called
140151
assert not project._Project__project.put.call_args.args
141152
assert project._Project__project.put.call_args.kwargs == {
142153
"key": "<key>",
143-
"report": regression,
154+
"report": report,
144155
}
145156

146157
def test_put_exception(self):
147158
with raises(TypeError, match="Key must be a string"):
148159
Project("<name>").put(None, "<value>")
149160

150-
with raises(TypeError, match="Report must be a `skore.EstimatorReport`"):
161+
with raises(TypeError, match="Report must be `EstimatorReport` or"):
151162
Project("<name>").put("<key>", "<value>")
152163

153164
def test_get(self, FakeLocalProject):
@@ -156,13 +167,13 @@ def test_get(self, FakeLocalProject):
156167
project.get("<id>")
157168

158169
assert FakeLocalProject.called
159-
assert project._Project__project.reports.get.called
160-
assert project._Project__project.reports.get.call_args.args == ("<id>",)
161-
assert not project._Project__project.reports.get.call_args.kwargs
170+
assert project._Project__project.get.called
171+
assert project._Project__project.get.call_args.args == ("<id>",)
172+
assert not project._Project__project.get.call_args.kwargs
162173

163174
def test_summarize(self):
164175
project = Project("<name>")
165-
project._Project__project.reports.metadata.return_value = [
176+
project._Project__project.summarize.return_value = [
166177
{
167178
"learner": "<learner>",
168179
"accuracy": 1.0,
@@ -172,7 +183,7 @@ def test_summarize(self):
172183

173184
summary = project.summarize()
174185

175-
assert project._Project__project.reports.metadata.called
186+
assert project._Project__project.summarize.called
176187
assert isinstance(summary, DataFrame)
177188
assert isinstance(summary, Summary)
178189
assert DataFrame.equals(

skore/tests/unit/project/test_summary.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
from copy import deepcopy
2-
from types import SimpleNamespace
32

43
from joblib import hash as joblib_hash
54
from pandas import DataFrame, Index, MultiIndex, RangeIndex
@@ -58,7 +57,7 @@ class FakeProject:
5857
def __init__(self, *reports):
5958
self.__reports = reports
6059

61-
def _make_report_metadata(self, index, report):
60+
def make_report_metadata(self, index, report):
6261
return {
6362
"id": index,
6463
"run_id": None,
@@ -93,18 +92,14 @@ def _make_report_metadata(self, index, report):
9392
"predict_time_std": None,
9493
}
9594

96-
@property
97-
def reports(self):
98-
def get(id: str):
99-
return self.__reports[int(id)]
95+
def get(self, id: str):
96+
return self.__reports[int(id)]
10097

101-
def metadata():
102-
return [
103-
self._make_report_metadata(index, report)
104-
for index, report in enumerate(self.__reports)
105-
]
106-
107-
return SimpleNamespace(metadata=metadata, get=get)
98+
def summarize(self):
99+
return [
100+
self.make_report_metadata(index, report)
101+
for index, report in enumerate(self.__reports)
102+
]
108103

109104

110105
class TestSummary:

0 commit comments

Comments
 (0)