Skip to content

Commit 6a1f1f0

Browse files
committed
ruffed and ready
1 parent 6947efe commit 6a1f1f0

5 files changed

Lines changed: 245 additions & 365 deletions

File tree

molpipeline/experimental/uncertainty/conformal.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818

1919
def bin_targets(y: NDArray[Any], n_bins: int = 10) -> NDArray[np.int_]:
20-
"""
21-
Bin continuous targets for stratified splitting in regression.
20+
"""Bin continuous targets for stratified splitting in regression.
2221
2322
Parameters
2423
----------
@@ -31,6 +30,7 @@ def bin_targets(y: NDArray[Any], n_bins: int = 10) -> NDArray[np.int_]:
3130
-------
3231
np.ndarray
3332
Binned targets.
33+
3434
"""
3535
y = np.asarray(y)
3636
bins = np.linspace(np.min(y), np.max(y), n_bins + 1)
@@ -40,9 +40,9 @@ def bin_targets(y: NDArray[Any], n_bins: int = 10) -> NDArray[np.int_]:
4040

4141

4242
class UnifiedConformalCV(BaseEstimator):
43-
"""One wrapper to rule them all: conformal prediction for both classifiers and regressors.
43+
"""Conformal prediction wrapper for both classifiers and regressors.
4444
45-
Uses crepes under the hood, so you know it's sweet.
45+
Uses crepes under the hood.
4646
4747
Parameters
4848
----------
@@ -146,7 +146,10 @@ def fit(self, x: NDArray[Any], y: NDArray[Any]) -> "UnifiedConformalCV":
146146
return self
147147

148148
def calibrate(
149-
self, x_calib: NDArray[Any], y_calib: NDArray[Any], **calib_params: Any,
149+
self,
150+
x_calib: NDArray[Any],
151+
y_calib: NDArray[Any],
152+
**calib_params: Any,
150153
) -> None:
151154
"""Calibrate the conformal predictor.
152155
@@ -225,7 +228,9 @@ def predict_proba(self, x: NDArray[Any]) -> NDArray[Any]:
225228
return conformal.predict_proba(x)
226229

227230
def predict_conformal_set(
228-
self, x: NDArray[Any], confidence: float | None = None,
231+
self,
232+
x: NDArray[Any],
233+
confidence: float | None = None,
229234
) -> Any:
230235
"""Predict conformal sets.
231236
@@ -309,7 +314,7 @@ def predict_int(self, x: NDArray[Any], confidence: float | None = None) -> Any:
309314

310315

311316
class CrossConformalCV(BaseEstimator):
312-
"""Cross-conformal prediction for both classifiers and regressors using WrapClassifier/WrapRegressor.
317+
"""Cross-conformal prediction using WrapClassifier/WrapRegressor.
313318
314319
Handles Mondrian (class_cond) logic as described.
315320
@@ -416,7 +421,9 @@ def fit(
416421
self.models_ = []
417422
if self.estimator_type == "classifier":
418423
splitter = StratifiedKFold(
419-
n_splits=self.n_folds, shuffle=True, random_state=42,
424+
n_splits=self.n_folds,
425+
shuffle=True,
426+
random_state=42,
420427
)
421428
y_split = y
422429
elif self.estimator_type == "regressor":
@@ -448,7 +455,8 @@ def fit(
448455
calib_idx_val = calib_idx
449456

450457
def _bin_func(
451-
_: Any, calib_idx_val: Any = calib_idx_val,
458+
_: Any,
459+
calib_idx_val: Any = calib_idx_val,
452460
) -> Any:
453461
return y[calib_idx_val]
454462

@@ -516,7 +524,9 @@ def predict_proba(self, x: NDArray[Any]) -> NDArray[Any]:
516524
return proba
517525

518526
def predict_conformal_set(
519-
self, x: NDArray[Any], confidence: float | None = None,
527+
self,
528+
x: NDArray[Any],
529+
confidence: float | None = None,
520530
) -> list[list[Any]]:
521531
"""Predict conformal sets using the cross-conformal predictor.
522532

0 commit comments

Comments
 (0)