Skip to content

Commit 381550c

Browse files
authored
chore: Use dedicated sklearn directory for wrappers of sklearn functions (#693)
Closes #696
1 parent 0b33dae commit 381550c

File tree

5 files changed

+22
-15
lines changed

5 files changed

+22
-15
lines changed

examples/plot_03_cross_validate.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
ML/DS projects.
1111
"""
1212

13-
1413
# %%
1514
import subprocess
1615

@@ -21,8 +20,7 @@
2120
from sklearn import svm
2221
from sklearn.model_selection import cross_validate as sklearn_cross_validate
2322

24-
from skore import load
25-
import skore.cross_validate
23+
import skore
2624

2725

2826
# %%
@@ -39,7 +37,7 @@
3937

4038

4139
# %%
42-
my_project_gs = load("my_project_cv.skore")
40+
my_project_gs = skore.load("my_project_cv.skore")
4341

4442
# %%
4543
# Cross-validation in scikit-learn
@@ -61,25 +59,28 @@
6159

6260
# %%
6361
X, y = datasets.load_iris(return_X_y=True)
64-
clf = svm.SVC(kernel='linear', C=1, random_state=0)
62+
clf = svm.SVC(kernel="linear", C=1, random_state=0)
6563

6664
# %%
6765
# Single metric evaluation using ``cross_validate``:
6866

6967
# %%
7068
cv_results = sklearn_cross_validate(clf, X, y, cv=5)
71-
cv_results['test_score']
69+
cv_results["test_score"]
7270

7371
# %%
7472
# Multiple metric evaluation using ``cross_validate``:
7573

7674
# %%
7775
scores = sklearn_cross_validate(
78-
clf, X, y, cv=5,
79-
scoring=['accuracy', 'precision_macro'],
76+
clf,
77+
X,
78+
y,
79+
cv=5,
80+
scoring=["accuracy", "precision_macro"],
8081
)
81-
print(scores['test_accuracy'])
82-
print(scores['test_precision_macro'])
82+
print(scores["test_accuracy"])
83+
print(scores["test_precision_macro"])
8384

8485
# %%
8586
# In scikit-learn, why do we recommend using ``cross_validate`` over ``cross_val_score``?

skore/src/skore/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44

55
import rich.logging
66

7-
from skore.cross_validate import cross_validate
87
from skore.project import Project, load
9-
10-
from .utils._show_versions import show_versions
8+
from skore.sklearn import cross_validate
9+
from skore.utils._show_versions import show_versions
1110

1211
__all__ = [
12+
"Project",
1313
"cross_validate",
1414
"load",
1515
"show_versions",
16-
"Project",
1716
]
1817

1918

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Enhance `sklearn` functions."""
2+
3+
from skore.sklearn.cross_validate import cross_validate
4+
5+
__all__ = [
6+
"cross_validate",
7+
]
File renamed without changes.

skore/tests/integration/test_cross_validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.ensemble import RandomForestClassifier
99
from sklearn.multiclass import OneVsOneClassifier
1010
from sklearn.svm import SVC
11-
from skore.cross_validate import cross_validate
11+
from skore import cross_validate
1212
from skore.item.cross_validation_item import (
1313
CrossValidationAggregationItem,
1414
CrossValidationItem,

0 commit comments

Comments
 (0)