Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions skore/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,16 @@ convention = "numpy"
"tests/*" = ["D"]

[tool.mypy]
ignore_missing_imports = true
exclude = ["src/skore/_externals/.*", "hatch/*", "tests/*"]
exclude = ["src/skore/_externals/", "hatch/", "tests/"]

[[tool.mypy.overrides]]
module = ["sklearn.*"]
ignore_missing_imports = true
module = [
"ipywidgets.*",
"joblib.*",
"pandas.*",
"plotly.*",
"seaborn.*",
"sklearn.*",
"skrub.*",
]
4 changes: 2 additions & 2 deletions skore/src/skore/_sklearn/_comparison/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def get_predictions(
] = "predict",
X: ArrayLike | None = None,
pos_label: PositiveLabel | None = _DEFAULT,
) -> list[ArrayLike]:
) -> list[ArrayLike] | list[list[ArrayLike]]:
"""Get predictions from the underlying reports.

This method has the advantage to reload from the cache if the predictions
Expand Down Expand Up @@ -406,7 +406,7 @@ def get_predictions(
>>> print([split_predictions.shape for split_predictions in predictions])
[(25,), (25,)]
"""
return [
return [ # type: ignore
report.get_predictions(
data_source=data_source,
response_method=response_method,
Expand Down
6 changes: 3 additions & 3 deletions skore/src/skore/_sklearn/_cross_validation/data_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _retrieve_data_as_frame(
y = self._parent.y

if not sbd.is_dataframe(X):
X = pd.DataFrame(X, columns=[f"Feature {i}" for i in range(X.shape[1])])
X = pd.DataFrame(X, columns=[f"Feature {i}" for i in range(X.shape[1])]) # type: ignore

if with_y:
if y is None:
Expand All @@ -52,10 +52,10 @@ def _retrieve_data_as_frame(
name = y.name if y.name is not None else "Target"
y = y.to_frame(name=name)
elif not sbd.is_dataframe(y):
if y.ndim == 1:
if y.ndim == 1: # type: ignore
columns = ["Target"]
else:
columns = [f"Target {i}" for i in range(y.shape[1])]
columns = [f"Target {i}" for i in range(y.shape[1])] # type: ignore
y = pd.DataFrame(y, columns=columns)

return X, y
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1143,16 +1143,17 @@ def _get_display(
X, y, _ = report.metrics._get_X_y_and_data_source_hash(
data_source=data_source
)

y_true.append(
YPlotData(
estimator_name=self._parent.estimator_name_,
split=report_idx,
y=y,
y=cast(ArrayLike, y),
)
)
results = _get_cached_response_values(
cache=report._cache,
estimator_hash=report._hash,
estimator_hash=int(report._hash),
estimator=report._estimator,
X=X,
response_method=response_method,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def _feature_permutation(
feature_names = (
self._parent.estimator_.feature_names_in_
if hasattr(self._parent.estimator_, "feature_names_in_")
else [f"Feature #{i}" for i in range(X_.shape[1])]
else [f"Feature #{i}" for i in range(X_.shape[1])] # type: ignore
)

# If there is more than one metric
Expand Down
4 changes: 2 additions & 2 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def _compute_metric_scores(

results = _get_cached_response_values(
cache=self._parent._cache,
estimator_hash=self._parent._hash,
estimator_hash=int(self._parent._hash),
estimator=self._parent.estimator_,
X=X,
response_method=response_method,
Expand Down Expand Up @@ -1674,7 +1674,7 @@ def _get_display(
else:
results = _get_cached_response_values(
cache=self._parent._cache,
estimator_hash=self._parent._hash,
estimator_hash=int(self._parent._hash),
estimator=self._parent.estimator_,
X=X,
response_method=response_method,
Expand Down
2 changes: 1 addition & 1 deletion skore/src/skore/_sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def get_predictions(

results = _get_cached_response_values(
cache=self._cache,
estimator_hash=self._hash,
estimator_hash=int(self._hash),
estimator=self._estimator,
X=X_,
response_method=response_method,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,17 +865,19 @@ def _compute_data_for_display(
):
label_binarizer = LabelBinarizer().fit(est.classes_)
y_true_onehot_i: NDArray = label_binarizer.transform(y_true_i.y)
y_pred_i_y = cast(NDArray, y_pred_i.y)

for class_idx, class_ in enumerate(est.classes_):
precision_class_i, recall_class_i, thresholds_class_i = (
precision_recall_curve(
y_true_onehot_i[:, class_idx],
y_pred_i.y[:, class_idx],
y_pred_i_y[:, class_idx],
pos_label=None,
drop_intermediate=drop_intermediate,
)
)
average_precision_class_i = average_precision_score(
y_true_onehot_i[:, class_idx], y_pred_i.y[:, class_idx]
y_true_onehot_i[:, class_idx], y_pred_i_y[:, class_idx]
)

for precision, recall, threshold in zip(
Expand Down
16 changes: 8 additions & 8 deletions skore/src/skore/_sklearn/_plot/metrics/prediction_error.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numbers
from collections import namedtuple
from typing import Any, Literal
from typing import Any, Literal, cast

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -265,7 +265,7 @@ def _plot_single_estimator(
self.ax_.legend(handles, labels, loc="lower right")
self.ax_.set_title(f"Prediction Error for {estimator_name}")

return scatter
return cast(list[Artist], scatter)

def _plot_cross_validated_estimator(
self,
Expand Down Expand Up @@ -352,7 +352,7 @@ def _plot_cross_validated_estimator(
self.ax_.legend(handles, labels, loc="lower right", title=legend_title)
self.ax_.set_title(f"Prediction Error for {estimator_name}")

return scatter
return cast(list[Artist], scatter)

def _plot_comparison_estimator(
self,
Expand Down Expand Up @@ -435,7 +435,7 @@ def _plot_comparison_estimator(
self.ax_.legend(handles, labels, loc="lower right", title=legend_title)
self.ax_.set_title("Prediction Error")

return scatter
return cast(list[Artist], scatter)

def _plot_comparison_cross_validation(
self,
Expand Down Expand Up @@ -518,7 +518,7 @@ def _plot_comparison_cross_validation(
self.ax_.legend(handles, labels, loc="lower right", title=legend_title)
self.ax_.set_title("Prediction Error")

return scatter
return cast(list[Artist], scatter)

@DisplayMixin.style_plot
def plot(
Expand Down Expand Up @@ -824,9 +824,9 @@ def _compute_data_for_display(
}
)
else:
y_true_sample = y_true_i.y
y_pred_sample = y_pred_i.y
residuals_sample = y_true_i.y - y_pred_i.y
y_true_sample = cast(np.typing.NDArray, y_true_i.y)
y_pred_sample = cast(np.typing.NDArray, y_pred_i.y)
residuals_sample = y_true_sample - y_pred_sample

for y_true_sample_i, y_pred_sample_i, residuals_sample_i in zip(
y_true_sample, y_pred_sample, residuals_sample, strict=False
Expand Down
6 changes: 5 additions & 1 deletion skore/src/skore/_sklearn/_plot/metrics/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def __init__(
self.ml_task = ml_task
self.report_type = report_type

self.chance_level_: Line2D | list[Line2D] | None

def _plot_single_estimator(
self,
*,
Expand Down Expand Up @@ -947,10 +949,12 @@ def _compute_data_for_display(
):
label_binarizer = LabelBinarizer().fit(est.classes_)
y_true_onehot_i: NDArray = label_binarizer.transform(y_true_i.y)
y_pred_i_y = cast(NDArray, y_pred_i.y)

for class_idx, class_ in enumerate(est.classes_):
fpr_class_i, tpr_class_i, thresholds_class_i = roc_curve(
y_true_onehot_i[:, class_idx],
y_pred_i.y[:, class_idx],
y_pred_i_y[:, class_idx],
pos_label=None,
drop_intermediate=drop_intermediate,
)
Expand Down
2 changes: 2 additions & 0 deletions skore/src/skore/_sklearn/train_test_split/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class labels.
if y is None and len(arrays) >= 2:
y = arrays[-1]

y_labels: np.ndarray | None

if y is not None:
y_labels = np.unique(y)
y_test = (
Expand Down
4 changes: 2 additions & 2 deletions skore/src/skore/project/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _filter_dataframe(self, ml_task: str, report_type: str) -> pd.DataFrame:
df.columns = [col.removesuffix("_mean") for col in df.columns]
return df

def _get_datasets(self, ml_task: str, report_type: str) -> np.ndarray:
def _get_datasets(self, ml_task: str, report_type: str) -> list[str]:
"""Get the unique datasets from the filtered dataframe.

Parameters
Expand All @@ -219,7 +219,7 @@ def _get_datasets(self, ml_task: str, report_type: str) -> np.ndarray:

Returns
-------
np.ndarray
list[str]
The unique datasets.
"""
return self._filter_dataframe(ml_task, report_type)["dataset"].unique()
Expand Down