Skip to content

Commit 6947efe

Browse files
committed
pull first
1 parent 4588f28 commit 6947efe

6 files changed

Lines changed: 247 additions & 109 deletions

File tree

molpipeline/experimental/uncertainty/conformal.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@
33
Provides unified and cross-conformal prediction with Mondrian and nonconformity options.
44
"""
55

6+
# pylint: disable=too-many-instance-attributes, attribute-defined-outside-init
7+
68
from typing import Any, cast
79

810
import numpy as np
911
from crepes import WrapClassifier, WrapRegressor
1012
from crepes.extras import MondrianCategorizer
13+
from numpy.typing import NDArray
1114
from scipy.stats import mode
1215
from sklearn.base import BaseEstimator, clone
1316
from sklearn.model_selection import KFold, StratifiedKFold
1417

1518

16-
def bin_targets(y: np.ndarray, n_bins: int = 10) -> np.ndarray:
17-
"""Bin continuous targets for stratified splitting in regression.
19+
def bin_targets(y: NDArray[Any], n_bins: int = 10) -> NDArray[np.int_]:
20+
"""
21+
Bin continuous targets for stratified splitting in regression.
1822
1923
Parameters
2024
----------
@@ -27,7 +31,6 @@ def bin_targets(y: np.ndarray, n_bins: int = 10) -> np.ndarray:
2731
-------
2832
np.ndarray
2933
Binned targets.
30-
3134
"""
3235
y = np.asarray(y)
3336
bins = np.linspace(np.min(y), np.max(y), n_bins + 1)
@@ -99,6 +102,7 @@ def __init__(
99102
Number of parallel jobs (default: 1).
100103
**kwargs : Any
101104
Additional keyword arguments for crepes.
105+
102106
"""
103107
self.estimator = estimator
104108
self.mondrian = mondrian
@@ -110,7 +114,7 @@ def __init__(
110114
self.n_jobs = n_jobs
111115
self.kwargs = kwargs
112116

113-
def fit(self, x: np.ndarray, y: np.ndarray) -> "UnifiedConformalCV":
117+
def fit(self, x: NDArray[Any], y: NDArray[Any]) -> "UnifiedConformalCV":
114118
"""Fit the conformal predictor.
115119
116120
Parameters
@@ -142,7 +146,7 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> "UnifiedConformalCV":
142146
return self
143147

144148
def calibrate(
145-
self, x_calib: np.ndarray, y_calib: np.ndarray, **calib_params: Any,
149+
self, x_calib: NDArray[Any], y_calib: NDArray[Any], **calib_params: Any,
146150
) -> None:
147151
"""Calibrate the conformal predictor.
148152
@@ -180,7 +184,7 @@ def calibrate(
180184
else:
181185
raise ValueError("estimator_type must be 'classifier' or 'regressor'")
182186

183-
def predict(self, x: np.ndarray) -> np.ndarray:
187+
def predict(self, x: NDArray[Any]) -> NDArray[Any]:
184188
"""Predict using the conformal predictor.
185189
186190
Parameters
@@ -196,7 +200,7 @@ def predict(self, x: np.ndarray) -> np.ndarray:
196200
"""
197201
return self._conformal.predict(x)
198202

199-
def predict_proba(self, x: np.ndarray) -> np.ndarray:
203+
def predict_proba(self, x: NDArray[Any]) -> NDArray[Any]:
200204
"""Predict probabilities using the conformal predictor.
201205
202206
Parameters
@@ -221,7 +225,7 @@ def predict_proba(self, x: np.ndarray) -> np.ndarray:
221225
return conformal.predict_proba(x)
222226

223227
def predict_conformal_set(
224-
self, x: np.ndarray, confidence: float | None = None,
228+
self, x: NDArray[Any], confidence: float | None = None,
225229
) -> Any:
226230
"""Predict conformal sets.
227231
@@ -251,7 +255,7 @@ def predict_conformal_set(
251255
conformal = cast("WrapClassifier", self._conformal)
252256
return conformal.predict_set(x, confidence=conf)
253257

254-
def predict_p(self, x: np.ndarray, **kwargs: Any) -> Any:
258+
def predict_p(self, x: NDArray[Any], **kwargs: Any) -> Any:
255259
"""Predict p-values.
256260
257261
Parameters
@@ -276,7 +280,7 @@ def predict_p(self, x: np.ndarray, **kwargs: Any) -> Any:
276280
raise NotImplementedError("predict_p is only for classification.")
277281
return self._conformal.predict_p(x, **kwargs)
278282

279-
def predict_int(self, x: np.ndarray, confidence: float | None = None) -> Any:
283+
def predict_int(self, x: NDArray[Any], confidence: float | None = None) -> Any:
280284
"""Predict intervals.
281285
282286
Parameters
@@ -370,6 +374,7 @@ def __init__(
370374
Number of bins for stratified splitting in regression (default: 10).
371375
**kwargs : Any
372376
Additional keyword arguments for crepes.
377+
373378
"""
374379
self.estimator = estimator
375380
self.n_folds = n_folds
@@ -383,8 +388,8 @@ def __init__(
383388

384389
def fit(
385390
self,
386-
x: np.ndarray,
387-
y: np.ndarray,
391+
x: NDArray[Any],
392+
y: NDArray[Any],
388393
) -> "CrossConformalCV":
389394
"""Fit the cross-conformal predictor.
390395
@@ -453,7 +458,7 @@ def _bin_func(
453458
self.models_.append(model)
454459
return self
455460

456-
def predict(self, x: np.ndarray) -> np.ndarray:
461+
def predict(self, x: NDArray[Any]) -> NDArray[Any]:
457462
"""Predict using the cross-conformal predictor.
458463
459464
Parameters
@@ -476,7 +481,7 @@ def predict(self, x: np.ndarray) -> np.ndarray:
476481
pred_mode = mode(result, axis=0, keepdims=False)
477482
return np.ravel(pred_mode.mode)
478483

479-
def predict_proba(self, x: np.ndarray) -> np.ndarray:
484+
def predict_proba(self, x: NDArray[Any]) -> NDArray[Any]:
480485
"""Predict probabilities using the cross-conformal predictor.
481486
482487
Parameters
@@ -511,7 +516,7 @@ def predict_proba(self, x: np.ndarray) -> np.ndarray:
511516
return proba
512517

513518
def predict_conformal_set(
514-
self, x: np.ndarray, confidence: float | None = None,
519+
self, x: NDArray[Any], confidence: float | None = None,
515520
) -> list[list[Any]]:
516521
"""Predict conformal sets using the cross-conformal predictor.
517522

0 commit comments

Comments
 (0)