-
Notifications
You must be signed in to change notification settings - Fork 101
Closed
Copy link
Labels
bug 🐛 (Something isn't working) · needs-triage ⚠️ (This has been recently submitted and needs attention)
Description
Describe the bug
I can't `put` an `EstimatorReport` into a project when its estimator doesn't implement `predict_proba` (e.g. `SVC` with the default `probability=False`).
Steps/Code to Reproduce
# %%
import pandas as pd
from skore import EstimatorReport
from sklearn.datasets import make_classification
from sklearn import svm
from skore import train_test_split
X, y = make_classification(n_samples=1000, n_features=9, random_state=1)
X = pd.DataFrame(data=X, columns=["1", "2", "3", "4", "5", "6", "7", "8", "9"])
X_train, X_test, y_train, y_test = train_test_split(X,y)
clf = svm.SVC(kernel='rbf', C=2, random_state=0)
est_rep = EstimatorReport(estimator=clf, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
# %%
import skore
project = skore.Project("project")
# %%
project.put("est_report", est_rep)
Expected Behavior
`project.put` should succeed; the user shouldn't have to worry about the fact that some models don't have `predict_proba`.
Actual Behavior
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[29], line 1
----> 1 project.put("est_report", est_rep)
File ~/Documents/skore/skore/src/skore/project/project.py:209, in Project.put(self, key, report)
204 if not isinstance(report, EstimatorReport):
205 raise TypeError(
206 f"Report must be a `skore.EstimatorReport` (found '{type(report)}')"
207 )
--> 209 return self.__project.put(key=key, report=report)
File ~/anaconda3/envs/skore_test/lib/python3.12/site-packages/skore_local_project/project.py:215, in Project.put(self, key, report)
202 return float(getattr(report.metrics, name)(data_source="test"))
203 return None
205 self.__metadata_storage[uuid4().hex] = {
206 "project_name": self.name,
207 "run_id": self.run_id,
208 "key": key,
209 "artifact_id": pickle_hash,
210 "date": datetime.now(timezone.utc).isoformat(),
211 "learner": report.estimator_name_,
212 "dataset": joblib.hash(report.y_test),
213 "ml_task": report._ml_task,
214 "rmse": metric("rmse"),
--> 215 "log_loss": metric("log_loss"),
216 "roc_auc": metric("roc_auc"),
217 # timings must be calculated last
218 "fit_time": report.metrics.timings().get("fit_time"),
219 "predict_time": report.metrics.timings().get("predict_time_test"),
220 }
File ~/anaconda3/envs/skore_test/lib/python3.12/site-packages/skore_local_project/project.py:202, in Project.put.<locals>.metric(name)
200 if hasattr(report.metrics, name):
201 with suppress(TypeError):
--> 202 return float(getattr(report.metrics, name)(data_source="test"))
203 return None
File ~/Documents/skore/skore/src/skore/sklearn/_estimator/metrics_accessor.py:1184, in _MetricsAccessor.log_loss(self, data_source, X, y)
1146 @available_if(attrgetter("_log_loss"))
1147 def log_loss(
1148 self,
(...)
1152 y: Optional[ArrayLike] = None,
1153 ) -> float:
1154 """Compute the log loss.
1155
1156 Parameters
(...)
1182 0.10...
1183 """
-> 1184 return self._log_loss(
1185 data_source=data_source,
1186 data_source_hash=None,
1187 X=X,
1188 y=y,
1189 )
File ~/Documents/skore/skore/src/skore/sklearn/_estimator/metrics_accessor.py:1210, in _MetricsAccessor._log_loss(self, data_source, data_source_hash, X, y)
1191 @available_if(
1192 _check_supported_ml_task(
1193 supported_ml_tasks=["binary-classification", "multiclass-classification"]
(...)
1202 y: Optional[ArrayLike] = None,
1203 ) -> float:
1204 """Private interface of `log_loss` to be able to pass `data_source_hash`.
1205
1206 `data_source_hash` is either an `int` when we already computed the hash
1207 and are able to pass it around or `None` and thus trigger its computation
1208 in the underlying process.
1209 """
-> 1210 result = self._compute_metric_scores(
1211 metrics.log_loss,
1212 X=X,
1213 y_true=y,
1214 data_source=data_source,
1215 data_source_hash=data_source_hash,
1216 response_method="predict_proba",
1217 )
1218 return cast(float, result)
File ~/Documents/skore/skore/src/skore/sklearn/_estimator/metrics_accessor.py:464, in _MetricsAccessor._compute_metric_scores(self, metric_fn, X, y_true, response_method, data_source, data_source_hash, pos_label, **metric_kwargs)
461 if "pos_label" in metric_params:
462 kwargs.update(pos_label=pos_label)
--> 464 y_pred = _get_cached_response_values(
465 cache=self._parent._cache,
466 estimator_hash=self._parent._hash,
467 estimator=self._parent.estimator_,
468 X=X,
469 response_method=response_method,
470 pos_label=pos_label,
471 data_source=data_source,
472 data_source_hash=data_source_hash,
473 )
475 score = metric_fn(y_true, y_pred, **kwargs)
477 if isinstance(score, np.ndarray):
File ~/Documents/skore/skore/src/skore/sklearn/_base.py:372, in _get_cached_response_values(cache, estimator_hash, estimator, X, response_method, pos_label, data_source, data_source_hash)
323 def _get_cached_response_values(
324 *,
325 cache: dict[tuple[Any, ...], Any],
(...)
332 data_source_hash: Optional[int] = None,
333 ) -> NDArray:
334 """Compute or load from local cache the response values.
335
336 Parameters
(...)
370 The response values.
371 """
--> 372 prediction_method = _check_response_method(estimator, response_method).__name__
374 if data_source == "X_y" and data_source_hash is None:
375 # Only trigger hash computation if it was not previously done.
376 # If data_source_hash is not None, we internally computed ourself the hash
377 # and it is trustful
378 data_source_hash = joblib.hash(X)
File ~/anaconda3/envs/skore_test/lib/python3.12/site-packages/sklearn/utils/validation.py:2283, in _check_response_method(estimator, response_method)
2281 prediction_method = reduce(lambda x, y: x or y, prediction_method)
2282 if prediction_method is None:
-> 2283 raise AttributeError(
2284 f"{estimator.__class__.__name__} has none of the following attributes: "
2285 f"{', '.join(list_methods)}."
2286 )
2288 return prediction_method
AttributeError: SVC has none of the following attributes: predict_proba.
Environment
System:
python: 3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:31:09) [GCC 11.2.0]
executable: /home/marie/anaconda3/envs/skore_test/bin/python
machine: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Python dependencies:
skore: 0.0.0+unknown
pip: 24.2
anywidget: 0.9.18
ipython: 8.30.0
ipywidgets: 8.1.5
joblib: 1.4.2
matplotlib: 3.9.3
numpy: 2.2.0
pandas: 2.2.3
plotly: 5.24.1
rich: 14.0.0
scikit-learn: 1.6.1
skore-local-project: 0.0.1
Metadata
Assignees
Labels
bug 🐛 (Something isn't working) · needs-triage ⚠️ (This has been recently submitted and needs attention)