Skip to content

Commit 3b7104c

Browse files
Enable use of 'brier_score' and 'average_precision' for Tuner and UQEnsemble. Co-authored-by: vaifai <vaifaipandey1996@gmail.com>
* Initialising new metrics * Updating docstrings, examples and adding the new metrics to the validate_tuning_inputs method * Replacing brier_score with average_precision in threshold objectives * Removing brier_score and average_precision from threshold_objective * Removing instances of average_precision from demo files * Removing average_precision from thresh_objective in ensemble.py * Improving upon the comments * fix notebook description * add clip step to ensemble score computation * fix docstring * fix unit test --------- Co-authored-by: vaifai <vaifaipandey1996@gmail.com>
1 parent 6b8083d commit 3b7104c

4 files changed

Lines changed: 24 additions & 24 deletions

File tree

examples/ensemble_tuning_demo.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,8 @@
501501
" <li><code>ground_truth_answers</code> - (<strong>List[str]</strong>) A list of ideal (correct) responses.</li>\n",
502502
" <li><code>grader_function</code> - (<strong>callable, default=None</strong>) A user-defined function that takes a response and a ground truth 'answer' and returns a boolean indicator of whether the response is correct. If not provided, vectara's HHEM is used: https://huggingface.co/vectara/hallucination_evaluation_model</li>\n",
503503
" <li><code>num_responses</code> - (<strong>int, default=5</strong>) The number of sampled responses used to compute consistency.</li>\n",
504-
" <li><code>weights_objective</code> - (<strong>str, default='roc_auc'</strong>) Objective function for weight optimization. Must match thresh_objective if one of {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}. If same as thresh_objective, joint optimization will be done.</li>\n",
505-
" <li><code>thresh_objective</code> - (<strong>str, default='fbeta_score'</strong>) Objective function for threshold optimization via grid search. One of {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}.</li>\n",
504+
" <li><code>weights_objective</code> - (<strong>str, default='roc_auc'</strong>) Objective function for weight optimization. One of {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss', 'average_precision', 'brier_score'}. Must match thresh_objective if one of {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}. If same as thresh_objective, joint optimization will be done.</li>\n",
505+
" <li><code>thresh_objective</code> - (<strong>str, default='fbeta_score'</strong>) Objective function for threshold optimization via grid search. One of {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}.</li>\n",
506506
" <li><code>thresh_bounds</code> - (<strong>tuple of floats, default=(0,1)</strong>) Bounds to search for threshold.</li>\n",
507507
" <li><code>n_trials</code> - (<strong>int, default=100</strong>) Indicates how many trials to search over with optuna optimizer</li>\n",
508508
" <li><code>step_size</code> - (<strong>float, default=0.01</strong>) Indicates step size in grid search, if used.</li>\n",
@@ -1229,15 +1229,15 @@
12291229
],
12301230
"metadata": {
12311231
"environment": {
1232-
"kernel": "uqlm",
1233-
"name": "workbench-notebooks.m125",
1232+
"kernel": "uqlm_my_test",
1233+
"name": "workbench-notebooks.m126",
12341234
"type": "gcloud",
1235-
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125"
1235+
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m126"
12361236
},
12371237
"kernelspec": {
1238-
"display_name": "uqlm",
1238+
"display_name": "uqlm_my_test",
12391239
"language": "python",
1240-
"name": "uqlm"
1240+
"name": "uqlm_my_test"
12411241
},
12421242
"language_info": {
12431243
"codemirror_mode": {
@@ -1249,7 +1249,7 @@
12491249
"name": "python",
12501250
"nbconvert_exporter": "python",
12511251
"pygments_lexer": "ipython3",
1252-
"version": "3.9.21"
1252+
"version": "3.11.12"
12531253
}
12541254
},
12551255
"nbformat": 4,

tests/test_tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def setup_method(self):
2727
def test_initialization(self):
2828
# Test default initialization
2929
tuner = Tuner()
30-
assert list(tuner.objective_to_func.keys()) == ["fbeta_score", "accuracy_score", "balanced_accuracy_score", "log_loss", "roc_auc"]
30+
assert list(tuner.objective_to_func.keys()) == ["fbeta_score", "accuracy_score", "balanced_accuracy_score", "log_loss", "roc_auc", "average_precision", "brier_score"]
3131

3232
def test_tune_threshold(self):
3333
tuner = Tuner()
@@ -61,7 +61,7 @@ def test_validation_errors_and_optimization_paths(self):
6161
# test unsupported weights_objective
6262
with pytest.raises(ValueError) as e:
6363
Tuner().tune_params(score_lists=self.score_lists, correct_indicators=self.correct_indicators, weights_objective="invalid")
64-
assert "Only 'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc_score', and 'log_loss' are supported for tuning objectives." in str(e.value)
64+
assert "Only 'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc_score', 'log_loss', 'average_precision', and 'brier_score' are supported for tuning objectives." in str(e.value)
6565

6666
# test unsupported thresh_objective
6767
with pytest.raises(ValueError) as e:

uqlm/scorers/ensemble.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,14 @@ def tune_from_graded(self, correct_indicators: List[bool], weights_objective: st
218218
correct_indicators : list of bool
219219
A list of boolean indicators of whether self.responses are correct.
220220
221-
weights_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='roc_auc'
221+
weights_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss', 'average_precision', 'brier_score'}, default='roc_auc'
222222
Objective function for weight optimization. Must match thresh_objective if one of 'fbeta_score',
223223
'accuracy_score', 'balanced_accuracy_score'. If same as thresh_objective, joint optimization will be done.
224224
225225
thresh_bounds : tuple of floats, default=(0,1)
226226
Bounds to search for threshold
227227
228-
thresh_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='fbeta_score'
228+
thresh_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}, default='fbeta_score'
229229
Objective function for threshold optimization via grid search.
230230
231231
n_trials : int, default=100
@@ -269,14 +269,14 @@ async def tune(self, prompts: List[str], ground_truth_answers: List[str], grader
269269
num_responses : int, default=5
270270
The number of sampled responses used to compute consistency.
271271
272-
weights_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='roc_auc'
272+
weights_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss', 'average_precision', 'brier_score'}, default='roc_auc'
273273
Objective function for weight optimization. Must match thresh_objective if one of 'fbeta_score',
274274
'accuracy_score', 'balanced_accuracy_score'. If same as thresh_objective, joint optimization will be done.
275275
276276
thresh_bounds : tuple of floats, default=(0,1)
277277
Bounds to search for threshold
278278
279-
thresh_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='fbeta_score'
279+
thresh_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}, default='fbeta_score'
280280
Objective function for threshold optimization via grid search.
281281
282282
n_trials : int, default=100

uqlm/utils/tuner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818
import optuna
1919
from typing import Any, Dict, List, Tuple
2020

21-
from sklearn.metrics import fbeta_score, balanced_accuracy_score, accuracy_score, roc_auc_score, log_loss
21+
from sklearn.metrics import fbeta_score, balanced_accuracy_score, accuracy_score, roc_auc_score, log_loss, average_precision_score, brier_score_loss
2222

2323
optuna.logging.set_verbosity(optuna.logging.WARNING)
2424

2525

2626
class Tuner:
2727
def __init__(self) -> None:
2828
"""
29-
Class for tuning weights and threshold for UQEnsemble class.
29+
Class for tuning weights and threshold for UQEnsemble
3030
"""
31-
self.objective_to_func = {"fbeta_score": self._f_score, "accuracy_score": accuracy_score, "balanced_accuracy_score": balanced_accuracy_score, "log_loss": log_loss, "roc_auc": roc_auc_score}
31+
self.objective_to_func = {"fbeta_score": self._f_score, "accuracy_score": accuracy_score, "balanced_accuracy_score": balanced_accuracy_score, "log_loss": log_loss, "roc_auc": roc_auc_score, "average_precision": average_precision_score, "brier_score": brier_score_loss}
3232

3333
def tune_threshold(self, y_scores: List[float], correct_indicators: List[bool], thresh_objective: str = "fbeta_score", fscore_beta: float = 1, bounds: Tuple[float, float] = (0, 1), step_size: int = 0.01) -> float:
3434
"""
@@ -42,7 +42,7 @@ def tune_threshold(self, y_scores: List[float], correct_indicators: List[bool],
4242
correct_indicators : list of bool
4343
A list of boolean indicators of whether self.original_responses are correct.
4444
45-
thresh_objective: {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='fbeta_score'
45+
thresh_objective: {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}, default='fbeta_score'
4646
Objective function for threshold optimization via grid search.
4747
4848
fscore_beta : float, default=1
@@ -84,11 +84,11 @@ def tune_params(self, score_lists: List[List[float]], correct_indicators: List[b
8484
correct_indicators : list of bool
8585
A list of boolean indicators of whether self.original_responses are correct.
8686
87-
weights_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='roc_auc'
87+
weights_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss', 'average_precision', 'brier_score'}, default='roc_auc'
8888
Objective function for optimization of weights. Must match thresh_objective if one of 'fbeta_score',
8989
'accuracy_score', 'balanced_accuracy_score'. If same as thresh_objective, joint optimization will be done.
9090
91-
thresh_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc', 'log_loss'}, default='fbeta_score'
91+
thresh_objective : {'fbeta_score', 'accuracy_score', 'balanced_accuracy_score'}, default='fbeta_score'
9292
Objective function for threshold optimization via grid search.
9393
9494
thresh_bounds : tuple of floats, default=(0,1)
@@ -118,9 +118,8 @@ def tune_params(self, score_lists: List[List[float]], correct_indicators: List[b
118118
self.step_size = step_size
119119
self.fscore_beta = fscore_beta
120120
self.optimize_jointly = weights_objective == thresh_objective
121-
self.obj_multiplier = 1 if weights_objective == "logloss" else -1
121+
self.obj_multiplier = 1 if weights_objective in ["logloss", "brier_score"] else -1
122122

123-
# Validate inputs are correct
124123
self._validate_tuning_inputs()
125124
self.weights_tuning_objective = self.objective_to_func[self.weights_objective]
126125
self.threshold_tuning_objective = self.objective_to_func[self.thresh_objective]
@@ -171,7 +170,7 @@ def _validate_tuning_inputs(self):
171170
if self.weights_objective not in self.objective_to_func:
172171
raise ValueError(
173172
"""
174-
Only 'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc_score', and 'log_loss' are supported for tuning objectives.
173+
Only 'fbeta_score', 'accuracy_score', 'balanced_accuracy_score', 'roc_auc_score', 'log_loss', 'average_precision', and 'brier_score' are supported for tuning objectives.
175174
"""
176175
)
177176
if self.thresh_objective not in ["fbeta_score", "accuracy_score", "balanced_accuracy_score"]:
@@ -214,7 +213,8 @@ def _compute_ensemble_scores(self, weights: List[float], score_lists: List[List[
214213
adjusted_weights = weights[:, None] * valid_mask
215214
normalized_weights = adjusted_weights / np.sum(adjusted_weights, axis=0, keepdims=True)
216215
stacked_nonan = np.nan_to_num(score_lists, nan=0.0)
217-
return np.sum(stacked_nonan * normalized_weights, axis=0)
216+
ensemble_scores = np.sum(stacked_nonan * normalized_weights, axis=0)
217+
return np.clip(ensemble_scores, 0, 1)
218218

219219
def _grid_search_weights_thresh(self):
220220
"""

0 commit comments

Comments
 (0)