33Provides unified and cross-conformal prediction with Mondrian and nonconformity options.
44"""
55
6+ # pylint: disable=too-many-instance-attributes, attribute-defined-outside-init
7+
68from typing import Any , cast
79
810import numpy as np
911from crepes import WrapClassifier , WrapRegressor
1012from crepes .extras import MondrianCategorizer
13+ from numpy .typing import NDArray
1114from scipy .stats import mode
1215from sklearn .base import BaseEstimator , clone
1316from sklearn .model_selection import KFold , StratifiedKFold
1417
1518
16- def bin_targets (y : np .ndarray , n_bins : int = 10 ) -> np .ndarray :
17- """Bin continuous targets for stratified splitting in regression.
19+ def bin_targets (y : NDArray [Any ], n_bins : int = 10 ) -> NDArray [np .int_ ]:
20+ """
21+ Bin continuous targets for stratified splitting in regression.
1822
1923 Parameters
2024 ----------
@@ -27,7 +31,6 @@ def bin_targets(y: np.ndarray, n_bins: int = 10) -> np.ndarray:
2731 -------
2832 np.ndarray
2933 Binned targets.
30-
3134 """
3235 y = np .asarray (y )
3336 bins = np .linspace (np .min (y ), np .max (y ), n_bins + 1 )
@@ -99,6 +102,7 @@ def __init__(
99102 Number of parallel jobs (default: 1).
100103 **kwargs : Any
101104 Additional keyword arguments for crepes.
105+
102106 """
103107 self .estimator = estimator
104108 self .mondrian = mondrian
@@ -110,7 +114,7 @@ def __init__(
110114 self .n_jobs = n_jobs
111115 self .kwargs = kwargs
112116
113- def fit (self , x : np . ndarray , y : np . ndarray ) -> "UnifiedConformalCV" :
117+ def fit (self , x : NDArray [ Any ] , y : NDArray [ Any ] ) -> "UnifiedConformalCV" :
114118 """Fit the conformal predictor.
115119
116120 Parameters
@@ -142,7 +146,7 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> "UnifiedConformalCV":
142146 return self
143147
144148 def calibrate (
145- self , x_calib : np . ndarray , y_calib : np . ndarray , ** calib_params : Any ,
149+ self , x_calib : NDArray [ Any ] , y_calib : NDArray [ Any ] , ** calib_params : Any ,
146150 ) -> None :
147151 """Calibrate the conformal predictor.
148152
@@ -180,7 +184,7 @@ def calibrate(
180184 else :
181185 raise ValueError ("estimator_type must be 'classifier' or 'regressor'" )
182186
183- def predict (self , x : np . ndarray ) -> np . ndarray :
187+ def predict (self , x : NDArray [ Any ] ) -> NDArray [ Any ] :
184188 """Predict using the conformal predictor.
185189
186190 Parameters
@@ -196,7 +200,7 @@ def predict(self, x: np.ndarray) -> np.ndarray:
196200 """
197201 return self ._conformal .predict (x )
198202
199- def predict_proba (self , x : np . ndarray ) -> np . ndarray :
203+ def predict_proba (self , x : NDArray [ Any ] ) -> NDArray [ Any ] :
200204 """Predict probabilities using the conformal predictor.
201205
202206 Parameters
@@ -221,7 +225,7 @@ def predict_proba(self, x: np.ndarray) -> np.ndarray:
221225 return conformal .predict_proba (x )
222226
223227 def predict_conformal_set (
224- self , x : np . ndarray , confidence : float | None = None ,
228+ self , x : NDArray [ Any ] , confidence : float | None = None ,
225229 ) -> Any :
226230 """Predict conformal sets.
227231
@@ -251,7 +255,7 @@ def predict_conformal_set(
251255 conformal = cast ("WrapClassifier" , self ._conformal )
252256 return conformal .predict_set (x , confidence = conf )
253257
254- def predict_p (self , x : np . ndarray , ** kwargs : Any ) -> Any :
258+ def predict_p (self , x : NDArray [ Any ] , ** kwargs : Any ) -> Any :
255259 """Predict p-values.
256260
257261 Parameters
@@ -276,7 +280,7 @@ def predict_p(self, x: np.ndarray, **kwargs: Any) -> Any:
276280 raise NotImplementedError ("predict_p is only for classification." )
277281 return self ._conformal .predict_p (x , ** kwargs )
278282
279- def predict_int (self , x : np . ndarray , confidence : float | None = None ) -> Any :
283+ def predict_int (self , x : NDArray [ Any ] , confidence : float | None = None ) -> Any :
280284 """Predict intervals.
281285
282286 Parameters
@@ -370,6 +374,7 @@ def __init__(
370374 Number of bins for stratified splitting in regression (default: 10).
371375 **kwargs : Any
372376 Additional keyword arguments for crepes.
377+
373378 """
374379 self .estimator = estimator
375380 self .n_folds = n_folds
@@ -383,8 +388,8 @@ def __init__(
383388
384389 def fit (
385390 self ,
386- x : np . ndarray ,
387- y : np . ndarray ,
391+ x : NDArray [ Any ] ,
392+ y : NDArray [ Any ] ,
388393 ) -> "CrossConformalCV" :
389394 """Fit the cross-conformal predictor.
390395
@@ -453,7 +458,7 @@ def _bin_func(
453458 self .models_ .append (model )
454459 return self
455460
456- def predict (self , x : np . ndarray ) -> np . ndarray :
461+ def predict (self , x : NDArray [ Any ] ) -> NDArray [ Any ] :
457462 """Predict using the cross-conformal predictor.
458463
459464 Parameters
@@ -476,7 +481,7 @@ def predict(self, x: np.ndarray) -> np.ndarray:
476481 pred_mode = mode (result , axis = 0 , keepdims = False )
477482 return np .ravel (pred_mode .mode )
478483
479- def predict_proba (self , x : np . ndarray ) -> np . ndarray :
484+ def predict_proba (self , x : NDArray [ Any ] ) -> NDArray [ Any ] :
480485 """Predict probabilities using the cross-conformal predictor.
481486
482487 Parameters
@@ -511,7 +516,7 @@ def predict_proba(self, x: np.ndarray) -> np.ndarray:
511516 return proba
512517
513518 def predict_conformal_set (
514- self , x : np . ndarray , confidence : float | None = None ,
519+ self , x : NDArray [ Any ] , confidence : float | None = None ,
515520 ) -> list [list [Any ]]:
516521 """Predict conformal sets using the cross-conformal predictor.
517522
0 commit comments