-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcvxpy_utils.py
More file actions
477 lines (394 loc) · 16.6 KB
/
cvxpy_utils.py
File metadata and controls
477 lines (394 loc) · 16.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
"""A set of helper functions for using cvxpy."""
from __future__ import annotations

import logging
from itertools import combinations, product

import numpy as np
import cvxpy as cp
from cvxpy.expressions.variable import Variable
from cvxpy.expressions.expression import Expression

from .roc_utils import calc_cost_of_point, compute_global_roc_from_groupwise
# Maximum distance from solution to feasibility or optimality; handed to the
# cvxpy solver as both its absolute-gap and its feasibility tolerance.
SOLUTION_TOLERANCE = 1e-9

# Names of every fairness constraint that has a cvxpy LP implementation here.
ALL_CONSTRAINTS = {
    "demographic_parity",  # equal positive prediction rates across groups
    "equalized_odds",  # equal TPR and equal FPR across groups
    "false_negative_rate_parity",  # FNR parity, same as TPR parity
    "false_positive_rate_parity",  # FPR parity, same as TNR parity
    "true_negative_rate_parity",  # TNR parity, same as FPR parity
    "true_positive_rate_parity",  # TPR parity, same as FNR parity
}

# Error message listing the supported constraint names in alphabetical order.
NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE = (
    "Currently only the following constraints are supported: "
    f"{', '.join(sorted(ALL_CONSTRAINTS))}."
)
def compute_line(p1: np.ndarray, p2: np.ndarray) -> tuple[float, float]:
    """Computes the slope and intercept of the line through two 2-D points.

    The intercept is the line's y-value at x=0. For vertical lines (points
    whose x-coordinates are numerically close, per ``np.isclose``) the slope
    is ``np.inf`` and the intercept is NaN; for such lines, use the
    x-coordinate of either point to describe the line instead.

    Parameters
    ----------
    p1 : np.ndarray
        A 2-D point (x, y). Any array-like of length 2 is accepted.
    p2 : np.ndarray
        A 2-D point (x, y). Any array-like of length 2 is accepted.

    Returns
    -------
    tuple[float, float]
        A tuple pair with (slope, intercept) of the line that goes from p1 to p2.

    Raises
    ------
    ValueError
        Raised when input is invalid, i.e., when p1 == p2 exactly.
    """
    # Normalize to arrays so the element-wise equality check below also works
    # for plain tuples/lists (``(0, 0) == (0, 0)`` is a single bool otherwise).
    p1 = np.asarray(p1)
    p2 = np.asarray(p2)
    p1x, p1y = p1
    p2x, p2y = p2

    if np.array_equal(p1, p2):
        raise ValueError("Invalid points: p1==p2;")

    # Vertical line: slope is undefined, so no y-intercept exists
    if np.isclose(p2x, p1x):
        return np.inf, np.nan

    # Diagonal or horizontal line
    slope = (p2y - p1y) / (p2x - p1x)
    intercept = p1y - slope * p1x
    return slope, intercept
def compute_halfspace_inequality(  # noqa: C901
    p1: np.ndarray,
    p2: np.ndarray,
) -> list[float]:
    """Computes the halfspace inequality defined by the directed edge p1->p2.

    The returned coefficients ``[a_x, a_y, b]`` encode the inequality

        a_x * x + a_y * y + b <= 0,

    which holds for points on the LEFT of (or on) the line through p1 and p2
    when travelling from p1 towards p2. Hence, when a polygon's vertices are
    given in COUNTER CLOCK-WISE order (right-hand rule), the inequalities of
    all its edges together describe the polygon's interior.

    Parameters
    ----------
    p1 : np.ndarray
        Start point (x, y) of the directed edge.
    p2 : np.ndarray
        End point (x, y) of the directed edge.

    Returns
    -------
    list[float]
        A list of length 3, ``[a_x, a_y, b]``, representing ``A @ x + b <= 0``.
        If no constraint can be derived (e.g., p1 and p2 are numerically, but
        not exactly, coincident), an error is logged and ``[0, 0, 0]`` (a
        trivially satisfied inequality) is returned.

    Raises
    ------
    ValueError
        Propagated from ``compute_line`` when p1 == p2.
    RuntimeError
        Raised if the line returned by ``compute_line`` is inconsistent with
        the given points (internal sanity check).
    """
    slope, intercept = compute_line(p1, p2)

    # Unpack the points for ease of use
    p1x, p1y = p1
    p2x, p2y = p2

    # Vertical line: the constraint only involves x, bounded by the line's x-value
    if np.isinf(slope):
        # Sanity check for vertical line
        if not np.isclose(p1x, p2x):
            raise RuntimeError(
                "Got infinite slope for line containing two points with "
                "different x-axis coordinates."
            )
        # Vector pointing downwards? Its left side is +x, so: x >= p1x
        if p2y < p1y:
            return [-1, 0, p1x]
        # Vector pointing upwards? Its left side is -x, so: x <= p1x
        elif p2y > p1y:
            return [1, 0, -p1x]

    # Horizontal line: the constraint only involves y, bounded by the intercept
    elif np.isclose(slope, 0.0):
        # Sanity checks for horizontal line
        if not np.isclose(p1y, p2y) or not np.isclose(p1y, intercept):
            raise RuntimeError(
                f"Invalid horizontal line; points p1 and p2 should have same "
                f"y-axis value as intercept ({p1y}, {p2y}, {intercept})."
            )
        # Vector pointing leftwards? Its left side is -y, so: y <= p1y
        if p2x < p1x:
            return [0, 1, -p1y]
        # Vector pointing rightwards? Its left side is +y, so: y >= p1y
        elif p2x > p1x:
            return [0, -1, p1y]

    # Standard diagonal line
    else:
        # Vector points left? Then its left side is below the line:
        # y <= mx + b  <=>  -mx + y - b <= 0
        if p2x < p1x:
            return [-slope, 1, -intercept]
        # Vector points right? Then its left side is above the line:
        # y >= mx + b  <=>  mx - y + b <= 0
        elif p2x > p1x:
            return [slope, -1, intercept]

    # Only reachable for degenerate (numerically coincident) points
    logging.error(f"No constraint can be concluded from points p1={p1} and p2={p2};")
    return [0, 0, 0]
def make_cvxpy_halfspace_inequality(
    p1: np.ndarray,
    p2: np.ndarray,
    cvxpy_point: Variable,
) -> Expression:
    """Builds one linear cvxpy constraint keeping `cvxpy_point` left of p1->p2.

    The coefficients ``[a_x, a_y, b]`` come from
    ``compute_halfspace_inequality(p1, p2)``, and the returned constraint is
    ``[a_x, a_y] @ cvxpy_point + b <= 0``. For the constraints of several
    edges to describe a polygon's interior, the edges must be traversed in
    COUNTER CLOCK-WISE order.

    Parameters
    ----------
    p1 : np.ndarray
        Start point of the directed edge.
    p2 : np.ndarray
        End point of the directed edge.
    cvxpy_point : Variable
        The cvxpy (x, y) variable that the constraint restricts.

    Returns
    -------
    Expression
        A linear inequality constraint of type Ax + b <= 0.
    """
    a_x, a_y, offset = compute_halfspace_inequality(p1, p2)
    normal = np.array([a_x, a_y])
    lhs = normal @ cvxpy_point + offset
    return lhs <= 0
def make_cvxpy_point_in_polygon_constraints(
    polygon_vertices: np.ndarray,
    cvxpy_point: Variable,
) -> list[Expression]:
    """Builds the cvxpy constraints that keep a point inside a convex polygon.

    One halfspace constraint is emitted per polygon edge, walking the vertices
    in the given order and closing the loop from the last vertex back to the
    first. Constraint ``k`` corresponds to the edge from vertex ``k`` to
    vertex ``k + 1`` (modulo the number of vertices).

    Parameters
    ----------
    polygon_vertices : np.ndarray
        The polygon's vertices, sorted in COUNTER CLOCK-WISE order
        (right-hand rule).
    cvxpy_point : cvxpy.Variable
        The cvxpy (x, y) variable the constraints are applied to.

    Returns
    -------
    list[Expression]
        A list of cvxpy constraints, one per edge (empty for no vertices).
    """
    num_vertices = len(polygon_vertices)
    # Pair each vertex index with its successor, wrapping around at the end
    edge_index_pairs = ((k, (k + 1) % num_vertices) for k in range(num_vertices))
    return [
        make_cvxpy_halfspace_inequality(
            p1=polygon_vertices[start],
            p2=polygon_vertices[end],
            cvxpy_point=cvxpy_point,
        )
        for start, end in edge_index_pairs
    ]
def _make_fairness_constraints(
    fairness_constraint: str,
    tolerance: float,
    roc_point_vars: list[Variable],
    groupwise_prevalence: np.ndarray,
    l_p_norm: int | float | str,
) -> list[Expression]:
    """Builds the pairwise fairness constraints between all groups' ROC points.

    Each group's ROC point variable is laid out as (FPR, TPR). One constraint
    is emitted per unordered pair of groups (i, j) with i < j, in
    lexicographic pair order.

    Parameters
    ----------
    fairness_constraint : str
        One of the names in ``ALL_CONSTRAINTS``.
    tolerance : float
        Maximum allowed disparity between any two groups.
    roc_point_vars : list[Variable]
        One cvxpy (FPR, TPR) variable per group.
    groupwise_prevalence : np.ndarray
        Positive-label prevalence of each group; only read for
        "demographic_parity".
    l_p_norm : int | float | str
        Norm order passed to ``cp.norm``; only read for "equalized_odds".

    Returns
    -------
    list[Expression]
        The cvxpy inequality constraints.

    Raises
    ------
    NotImplementedError
        If ``fairness_constraint`` has no implementation here.
    """
    group_pairs = list(combinations(range(len(roc_point_vars)), 2))

    # Equalized odds: the l-p distance between any two groups' ROC points must
    # be at most `tolerance` (default l-infinity: max of |dFPR| and |dTPR|)
    if fairness_constraint == "equalized_odds":
        return [
            cp.norm(roc_point_vars[i] - roc_point_vars[j], p=l_p_norm) <= tolerance
            for i, j in group_pairs
        ]

    # Rate parity: TNR = 1 - FPR and FNR = 1 - TPR, so parity of a rate is the
    # same constraint as parity of its complement; pick the ROC coordinate
    rate_parity_roc_index = {
        "false_positive_rate_parity": 0,  # FPR
        "true_negative_rate_parity": 0,  # TNR
        "true_positive_rate_parity": 1,  # TPR
        "false_negative_rate_parity": 1,  # FNR
    }
    if fairness_constraint in rate_parity_roc_index:
        roc_idx = rate_parity_roc_index[fairness_constraint]
        return [
            cp.abs(roc_point_vars[i][roc_idx] - roc_point_vars[j][roc_idx])
            <= tolerance
            for i, j in group_pairs
        ]

    # Demographic parity: equal positive prediction rates across groups
    # (ignores the labels Y and only considers predictions Y_hat)
    if fairness_constraint == "demographic_parity":

        def positive_prediction_rate(group_idx: int) -> Expression:
            """Group PPR = TPR * prevalence + FPR * (1 - prevalence)."""
            prevalence = groupwise_prevalence[group_idx]
            group_fpr = roc_point_vars[group_idx][0]
            group_tpr = roc_point_vars[group_idx][1]
            return group_tpr * prevalence + group_fpr * (1 - prevalence)

        return [
            cp.abs(positive_prediction_rate(i) - positive_prediction_rate(j))
            <= tolerance
            for i, j in group_pairs
        ]

    # NOTE: implement other constraints here
    raise NotImplementedError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)


def compute_fair_optimum(  # noqa: C901
    *,
    fairness_constraint: str,
    tolerance: float,
    groupwise_roc_hulls: dict[int, np.ndarray],
    group_sizes_label_pos: np.ndarray,
    group_sizes_label_neg: np.ndarray,
    groupwise_prevalence: np.ndarray,
    global_prevalence: float,
    false_positive_cost: float = 1.0,
    false_negative_cost: float = 1.0,
    l_p_norm: int | float | str = np.inf,
) -> tuple[np.ndarray, np.ndarray]:
    """Computes the cost-optimal classifier under a (relaxed) fairness constraint.

    Solves a convex cvxpy problem over one ROC point (FPR, TPR) per group,
    each constrained to lie inside its group's ROC convex hull, plus a global
    ROC point tied to the group-wise points by weighted sums. The objective
    is the misclassification cost of the global ROC point. The fairness
    constraint is relaxed by `tolerance`.

    Parameters
    ----------
    fairness_constraint : str
        The name of the fairness constraint under which the problem will be
        optimized. Possible inputs are:
            'equalized_odds'
                match true positive and false positive rates across groups
                (distance measured with `l_p_norm`);
            'true_positive_rate_parity' / 'false_negative_rate_parity'
                match true positive rates (equivalently, FNRs) across groups;
            'false_positive_rate_parity' / 'true_negative_rate_parity'
                match false positive rates (equivalently, TNRs) across groups;
            'demographic_parity'
                match positive prediction rates across groups.
    tolerance : float
        A value for the tolerance when enforcing the fairness constraint.
    groupwise_roc_hulls : dict[int, np.ndarray]
        A dict mapping each group to the convex hull of the group's ROC curve.
        The convex hull is an np.array of shape (n_points, 2), containing the
        points that form the convex hull of the ROC curve, sorted in COUNTER
        CLOCK-WISE order. Keys are looked up as 0..n_groups-1.
    group_sizes_label_pos : np.ndarray
        The relative or absolute number of positive samples in each group.
        Used directly as the weights of the group TPRs in the global TPR;
        NOTE(review): weights are not normalized here, so the global TPR is a
        weighted *average* only if these sum to 1.
    group_sizes_label_neg : np.ndarray
        The relative or absolute number of negative samples in each group.
        Used directly as the weights of the group FPRs in the global FPR (same
        normalization caveat as above).
    groupwise_prevalence : np.ndarray
        The prevalence of positive samples in each group. Only used for the
        "demographic_parity" constraint, where it must have one entry per
        group.
    global_prevalence : float
        The global prevalence of positive samples.
    false_positive_cost : float, optional
        The cost of a FALSE POSITIVE error, by default 1.
    false_negative_cost : float, optional
        The cost of a FALSE NEGATIVE error, by default 1.
    l_p_norm : int | float | str, optional
        The type of l-p norm to use when computing the distance between two ROC
        points. Used only for the "equalized_odds" constraint. By default uses
        `np.inf` (l-infinity distance): the maximum between groups' TPR and FPR
        differences. Using `l_p_norm=1` will correspond to the
        `average_abs_odds_difference`.
        See the following link for more information on this parameter:
        https://www.cvxpy.org/api_reference/cvxpy.atoms.other_atoms.html#norm

    Returns
    -------
    (groupwise_roc_points, global_roc_point) : tuple[np.ndarray, np.ndarray]
        A tuple pair, (<1>, <2>), containing:
        1: an array of shape (n_groups, 2) with the group-wise ROC points
           (FPR, TPR) of the solution.
        2: an array of shape (2,) with the global ROC point of the solution.

    Raises
    ------
    ValueError
        If the fairness constraint is unsupported, if the per-group inputs
        have inconsistent lengths, or if the solver finds no solution.
    """
    if fairness_constraint not in ALL_CONSTRAINTS:
        raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)

    n_groups = len(groupwise_roc_hulls)
    if n_groups != len(group_sizes_label_neg) or n_groups != len(group_sizes_label_pos):
        raise ValueError(
            "Invalid arguments; all of the following should have the same "
            "length: groupwise_roc_hulls, group_sizes_label_neg, group_sizes_label_pos; "
            f"got: {len(groupwise_roc_hulls)}, {len(group_sizes_label_neg)}, "
            f"{len(group_sizes_label_pos)};"
        )

    # Demographic parity indexes groupwise_prevalence per group; fail early
    # with a clear message instead of an IndexError deep in the LP build
    if fairness_constraint == "demographic_parity" and len(groupwise_prevalence) != n_groups:
        raise ValueError(
            "Invalid arguments; groupwise_prevalence should have one entry per "
            f"group; got {len(groupwise_prevalence)} entries for {n_groups} groups;"
        )

    # Group-wise ROC points --- in the form (FPR, TPR)
    groupwise_roc_points_vars = [
        cp.Variable(shape=2, name=f"ROC point for group {i}", nonneg=True)
        for i in range(n_groups)
    ]

    # Define global ROC point as a linear combination of the group-wise ROC points
    global_roc_point_var = cp.Variable(shape=2, name="Global ROC point", nonneg=True)
    constraints = [
        # Global FPR is the average of group FPRs weighted by LNs in each group
        global_roc_point_var[0]
        == group_sizes_label_neg @ np.array([p[0] for p in groupwise_roc_points_vars]),
        # Global TPR is the average of group TPRs weighted by LPs in each group
        global_roc_point_var[1]
        == group_sizes_label_pos @ np.array([p[1] for p in groupwise_roc_points_vars]),
    ]

    # ** APPLY FAIRNESS CONSTRAINTS **
    # NOTE: feature request: compatibility with multiple constraints simultaneously
    constraints += _make_fairness_constraints(
        fairness_constraint=fairness_constraint,
        tolerance=tolerance,
        roc_point_vars=groupwise_roc_points_vars,
        groupwise_prevalence=groupwise_prevalence,
        l_p_norm=l_p_norm,
    )

    # Constraints for points in respective group-wise ROC curves
    for idx in range(n_groups):
        constraints += make_cvxpy_point_in_polygon_constraints(
            polygon_vertices=groupwise_roc_hulls[idx],
            cvxpy_point=groupwise_roc_points_vars[idx],
        )

    # Define cost function
    obj = cp.Minimize(
        calc_cost_of_point(
            fpr=global_roc_point_var[0],
            fnr=1 - global_roc_point_var[1],
            prevalence=global_prevalence,
            false_pos_cost=false_positive_cost,
            false_neg_cost=false_negative_cost,
        )
    )

    # Define cvxpy problem
    prob = cp.Problem(obj, constraints)

    # Run solver
    # NOTE: ECOS solver has been deprecated in favor of CLARABEL in cvxpy 1.3.2+
    # https://www.cvxpy.org/updates/index.html?h=ecos#ecos-deprecation
    # NOTE: these tolerances are supposed to be smaller than the default np.isclose
    # tolerances (useful when comparing if two points are the same, within the
    # cvxpy accuracy tolerance)
    prob.solve(
        solver=cp.CLARABEL,
        tol_gap_abs=SOLUTION_TOLERANCE,
        tol_feas=SOLUTION_TOLERANCE,
    )

    # Log solution
    logging.info(
        f"cvxpy solver took {prob.solver_stats.solve_time}s; status is {prob.status}."
    )

    # In any of these statuses (including the "_inaccurate" variants) there is
    # no solution to read back; also guard against unset variable values.
    no_solution_statuses = {
        "infeasible",
        "infeasible_inaccurate",
        "unbounded",
        "unbounded_inaccurate",
        "infeasible_or_unbounded",
    }
    if prob.status in no_solution_statuses or any(
        variable.value is None for variable in prob.variables()
    ):
        # This should never be reached (there are always trivial fair
        # solutions in the ROC diagonal)
        raise ValueError(f"cvxpy problem has no solution; status={prob.status}")

    logging.info(f"Optimal solution value: {prob.value}")
    for variable in prob.variables():
        logging.info(f"Variable {variable.name()}: value {variable.value}")

    groupwise_roc_points = np.vstack([p.value for p in groupwise_roc_points_vars])
    global_roc_point = global_roc_point_var.value

    # Validating solution cost
    solution_cost = calc_cost_of_point(
        fpr=global_roc_point[0],
        fnr=1 - global_roc_point[1],
        prevalence=global_prevalence,
        false_pos_cost=false_positive_cost,
        false_neg_cost=false_negative_cost,
    )
    if not np.isclose(solution_cost, prob.value):
        logging.error(
            f"Solution was found but cost did not pass validation! "
            f"Found solution ROC point {global_roc_point} with theoretical cost "
            f"{prob.value}, but actual cost is {solution_cost};"
        )

    # Validating congruency between group-wise ROC points and global ROC point
    global_roc_from_groupwise = compute_global_roc_from_groupwise(
        groupwise_roc_points=groupwise_roc_points,
        groupwise_label_pos_weight=group_sizes_label_pos,
        groupwise_label_neg_weight=group_sizes_label_neg,
    )
    if not np.allclose(global_roc_from_groupwise, global_roc_point):
        logging.error(
            f"Solution: global ROC point ({global_roc_point}) does not seem to "
            f"match group-wise ROC points; global should be "
            f"({global_roc_from_groupwise}) to be consistent with group-wise;"
        )

    return groupwise_roc_points, global_roc_point