Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions methods/catalog/probe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .model import Probe
3 changes: 3 additions & 0 deletions methods/catalog/probe/library/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .wachter_rip import probe_recourse
261 changes: 261 additions & 0 deletions methods/catalog/probe/library/wachter_rip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import datetime
from typing import List, Optional

import numpy as np
import math
import torch
import torch.optim as optim
import torch.distributions.normal as normal_distribution
from torch.distributions.multivariate_normal import MultivariateNormal
from torch import nn
from torch.autograd import Variable

from methods.processing import reconstruct_encoding_constraints

DECISION_THRESHOLD = 0.5

# Mean and variance for rectified normal distribution:
# see in here : http://journal-sfds.fr/article/view/669


def compute_jacobian(inputs, output, num_classes=1):
    """Differentiate *output* with respect to *inputs* via autograd.

    :param inputs: Batch X Size (e.g. Depth X Width X Height); must have
        ``requires_grad`` set, otherwise autograd cannot reach it.
    :param output: Batch X Classes
    :param num_classes: kept for interface compatibility; not used here.
    :return: jacobian: Batch X Classes X Size
    """
    assert inputs.requires_grad
    jacobian = gradient(output, inputs)
    return jacobian


def gradient(y, x, grad_outputs=None):
    """Compute dy/dx @ grad_outputs.

    Defaults ``grad_outputs`` to a scalar one — the usual seed when ``y``
    is a scalar. ``create_graph=True`` keeps the backward graph so the
    returned gradient can itself be differentiated.
    """
    seed = torch.tensor(1) if grad_outputs is None else grad_outputs
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True)
    return dy_dx


def compute_invalidation_rate_closed(torch_model, x, sigma2):
    """Closed-form (first-order) invalidation rate under Gaussian input noise.

    Approximates IR(x) = 1 - Phi(logit(x) / (sqrt(sigma2) * ||J(x)||_2)),
    where J is the gradient of the logit w.r.t. the input and Phi is the
    standard normal CDF. Differentiable, so usable as a training loss.

    :param torch_model: callable returning class probabilities; assumes a
        binary classifier with column 0 = negative, column 1 = positive class
    :param x: input tensor with ``requires_grad=True`` (shape (1, d))
    :param sigma2: scalar tensor, variance of the isotropic Gaussian noise
    :return: invalidation rate as a differentiable scalar tensor
    """
    # Log-odds of the positive class for the (single) input row.
    prob = torch_model(x)
    logit_x = torch.log(prob[0][1] / prob[0][0])

    # Fix: removed dead local `Sigma2 = sigma2 * torch.eye(x.shape[0])`,
    # which allocated a matrix every call and was never read.
    jacobian_x = compute_jacobian(x, logit_x, num_classes=1).reshape(-1)
    denom = torch.sqrt(sigma2) * torch.norm(jacobian_x, 2)
    arg = logit_x / denom

    # Evaluate standard Gaussian cdf
    normal = normal_distribution.Normal(loc=0.0, scale=1.0)
    normal_cdf = normal.cdf(arg)

    # Get invalidation rate
    ir = 1 - normal_cdf

    return ir


def perturb_sample(x, n_samples, sigma2):
    """Draw ``n_samples`` noisy copies of ``x``, i.e. samples of N(x, sigma2*I).

    :param x: tensor of shape (1, d)
    :param n_samples: number of perturbed copies to draw
    :param sigma2: scalar noise variance
    :return: tensor of shape (n_samples, d)
    """
    dim = x.shape[1]
    noise_dist = MultivariateNormal(
        loc=torch.zeros(dim), covariance_matrix=sigma2 * torch.eye(dim)
    )
    noise = noise_dist.sample((n_samples,))
    # One row of x per sample, shifted by the Gaussian noise.
    stacked = x.repeat(n_samples, 1)
    return stacked + noise

def reparametrization_trick(mu, sigma2, n_samples):
    """Sample from N(mu, sigma2 * I) as mu + sqrt(sigma2) * eps.

    Writing the draw this way keeps ``mu`` inside the computation graph,
    so gradients flow through the samples (the reparametrization trick).

    :param mu: tensor of shape (1, d)
    :param sigma2: scalar tensor, noise variance
    :param n_samples: number of samples to draw
    :return: tensor of shape (n_samples, d)
    """
    dim = mu.shape[1]
    standard_normal = MultivariateNormal(
        loc=torch.zeros(dim), covariance_matrix=torch.eye(dim)
    )
    eps = standard_normal.sample((n_samples,))  # unit Gaussian noise
    # Broadcast mu across the sample dimension, then scale and shift eps.
    center = mu.reshape(-1) * torch.ones_like(eps)
    return center + torch.sqrt(sigma2) * eps


def compute_invalidation_rate(torch_model, random_samples):
    """Monte-Carlo invalidation rate: fraction of samples the model rejects.

    :param torch_model: callable returning class probabilities (column 1 =
        positive class)
    :param random_samples: (n, d) batch of perturbed inputs
    :return: scalar tensor, 1 minus the acceptance frequency
    """
    positive_scores = torch_model(random_samples)[:, 1]
    accepted = (positive_scores > 0.5).float()
    return 1 - accepted.mean(dim=0)


def probe_recourse(
    torch_model,
    x: np.ndarray,
    cat_feature_indices: List[int],
    binary_cat_features: bool = True,
    feature_costs: Optional[List[float]] = None,
    lr: float = 0.07,
    lambda_param: float = 5,
    y_target: Optional[List[float]] = None,
    n_iter: int = 500,
    t_max_min: float = 1.0,
    norm: int = 1,
    clamp: bool = False,
    loss_type: str = "MSE",
    invalidation_target: float = 0.45,
    inval_target_eps: float = 0.005,
    noise_variance: float = 0.01,
) -> np.ndarray:
    """
    Generates a probabilistically robust counterfactual (PROBE, built on
    Wachter et al.'s recourse search) for input instance x.

    Parameters
    ----------
    torch_model: black-box-model to discover
    x: factual to explain, shape (1, d)
    cat_feature_indices: list of positions of categorical features in x
    binary_cat_features: If true, the encoding of x is done by drop_if_binary
    feature_costs: List with costs per feature
    lr: learning rate for gradient descent
    lambda_param: weight factor for feature_cost
    y_target: List of one-hot-encoded target class (defaults to [0.45, 0.55])
    n_iter: maximum number of iteration
    t_max_min: maximum time of search
    norm: L-norm to calculate cost
    clamp: If true, feature values will be clamped to (0, 1)
    loss_type: String for loss function (MSE or BCE)
    invalidation_target: desired invalidation rate of the counterfactual
    inval_target_eps: tolerance around the invalidation target
    noise_variance: variance of the Gaussian input noise used to estimate
        the invalidation rate

    Returns
    -------
    Counterfactual example as np.ndarray
    """
    # Avoid a mutable default argument; None stands in for the paper default.
    if y_target is None:
        y_target = [0.45, 0.55]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)  # deterministic noise draws for reproducibility
    noise_variance = torch.tensor(noise_variance)

    if feature_costs is not None:
        # Fix: the documented input is a List[float]; torch.from_numpy only
        # accepts ndarrays, so build the tensor with as_tensor (handles both).
        feature_costs = torch.as_tensor(feature_costs).float().to(device)

    x = torch.from_numpy(x).float().to(device)
    y_target = torch.tensor(y_target).float().to(device)
    lamb = torch.tensor(lambda_param).float().to(device)

    # x_new is the optimization variable; x_new_enc mirrors it with the
    # categorical encoding constraints of x re-applied, so that categorical
    # data is either 0 or 1.
    x_new = Variable(x.clone(), requires_grad=True)
    x_new_enc = reconstruct_encoding_constraints(
        x_new, cat_feature_indices, binary_cat_features
    )

    optimizer = optim.Adam([x_new], lr, amsgrad=True)
    softmax = nn.Softmax()

    if loss_type == "MSE":
        loss_fn = torch.nn.MSELoss()
        # NOTE(review): indexing [1] here assumes a 1-D model output; for a
        # (1, 2)-shaped output this would hit the batch dim — confirm upstream.
        f_x_new = softmax(torch_model(x_new))[1]
    else:
        loss_fn = torch.nn.BCELoss()
        f_x_new = torch_model(x_new)[:, 1]

    t0 = datetime.datetime.now()
    t_max = datetime.timedelta(minutes=t_max_min)

    costs = []
    ces = []

    # Monte-Carlo estimate of the invalidation rate at the starting point.
    random_samples = reparametrization_trick(x_new, noise_variance, n_samples=1000)
    invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

    # Keep searching until the counterfactual flips the prediction AND its
    # estimated invalidation rate is within eps of the requested target.
    while (f_x_new <= DECISION_THRESHOLD) or (
        invalidation_rate > invalidation_target + inval_target_eps
    ):
        for it in range(n_iter):
            optimizer.zero_grad()
            # Fix: torch tensors take `dim` (or a positional arg) for
            # squeeze; the previous squeeze(axis=0) raised a TypeError.
            f_x_new_binary = torch_model(x_new).squeeze(0)

            cost = (
                torch.dist(x_new, x, norm)
                if feature_costs is None
                else torch.norm(feature_costs * (x_new - x), norm)
            )

            # Differentiable closed-form invalidation rate for the loss term.
            invalidation_rate_c = compute_invalidation_rate_closed(
                torch_model, x_new, noise_variance
            )

            # Hinge loss: only penalize exceeding the target rate.
            loss_invalidation = invalidation_rate_c - invalidation_target
            loss_invalidation[loss_invalidation < 0] = 0

            # Invalidation term is up-weighted (factor 3) relative to the
            # prediction loss; the distance cost is weighted by lambda.
            loss = (
                3 * loss_invalidation
                + loss_fn(f_x_new_binary, y_target)
                + lamb * cost
            )
            loss.backward()
            optimizer.step()

            # Re-estimate the empirical invalidation rate after the step.
            random_samples = reparametrization_trick(
                x_new, noise_variance, n_samples=10000
            )
            invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

            # NOTE(review): clamp_ acts on a clone, so x_new itself is never
            # clamped — looks like a no-op; left unchanged pending confirmation.
            if clamp:
                x_new.clone().clamp_(0, 1)

            x_new_enc = reconstruct_encoding_constraints(
                x_new, cat_feature_indices, binary_cat_features
            )
            f_x_new = torch_model(x_new)[:, 1]

            # Accept the candidate once it is valid and robust enough.
            if (f_x_new > DECISION_THRESHOLD) and (
                invalidation_rate < invalidation_target + inval_target_eps
            ):
                costs.append(cost)
                ces.append(x_new)
                break

        # Relax the cost weight so the next sweep may move further from x.
        lamb -= 0.10

        if datetime.datetime.now() - t0 > t_max:
            print("Timeout")
            break

    if not ces:
        print("No Counterfactual Explanation Found at that Target Rate - Try Different Target")
    else:
        print("Counterfactual Explanation Found")
        # Among all accepted candidates, return the cheapest one.
        costs = torch.tensor(costs)
        min_idx = int(torch.argmin(costs).numpy())
        x_new_enc = ces[min_idx]

    # numpy's squeeze does accept axis; drops the leading batch dimension.
    return x_new_enc.cpu().detach().numpy().squeeze(axis=0)
105 changes: 105 additions & 0 deletions methods/catalog/probe/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import List
import pandas as pd
from sklearn.base import BaseEstimator

from ...api import RecourseMethod
from methods.catalog.probe.library import probe_recourse
from methods.processing import (
check_counterfactuals,
merge_default_parameters,
)

class Probe(RecourseMethod):
    """
    Implementation of the PROBE framework using Wachter-style recourse
    generation from Pawelczyk et al. [1]_.

    Parameters
    ----------
    mlmodel : model.MLModel
        Black-Box-Model
    hyperparams : dict
        Dictionary containing hyperparameters. See notes below for its contents.

    Methods
    -------
    get_counterfactuals:
        Generate counterfactual examples for given factuals.

    .. [1] Martin Pawelczyk, Teresa Datta, Johan Van den Heuvel, Gjergji Kasneci,
        Himabindu Lakkaraju. 2023.
        Probabilistically Robust Recourse: Navigating the Trade-offs between
        Costs and Robustness in Algorithmic Recourse.
        https://openreview.net/pdf?id=sC-PmTsiTB
    """

    # "_optional_" marks hyperparameters that the caller may omit entirely;
    # resolution happens in merge_default_parameters.
    _DEFAULT_HYPERPARAMS = {
        "feature_cost": "_optional_",
        "lr": 0.001,
        "lambda_": 0.01,
        "n_iter": 1000,
        "t_max_min": 1.0,
        "norm": 1,
        "clamp": True,
        "loss_type": "MSE",
        "y_target": [0, 1],
        "binary_cat_features": True,
        "noise_variance": 0.01,
        "invalidation_target": 0.45,
        "inval_target_eps": 0.005,
    }

    def __init__(self, mlmodel, hyperparams):
        super().__init__(mlmodel)

        checked_hyperparams = merge_default_parameters(
            hyperparams, self._DEFAULT_HYPERPARAMS
        )
        self._feature_costs = checked_hyperparams["feature_cost"]
        self._lr = checked_hyperparams["lr"]
        self._lambda_param = checked_hyperparams["lambda_"]
        self._n_iter = checked_hyperparams["n_iter"]
        self._t_max_min = checked_hyperparams["t_max_min"]
        self._norm = checked_hyperparams["norm"]
        self._clamp = checked_hyperparams["clamp"]
        self._loss_type = checked_hyperparams["loss_type"]
        self._y_target = checked_hyperparams["y_target"]
        self._binary_cat_features = checked_hyperparams["binary_cat_features"]
        self._noise_variance = checked_hyperparams["noise_variance"]
        self._invalidation_target = checked_hyperparams["invalidation_target"]
        self._inval_target_eps = checked_hyperparams["inval_target_eps"]

    def get_counterfactuals(self, factuals: pd.DataFrame) -> pd.DataFrame:
        """Generate one counterfactual per row of *factuals*.

        Rows are reordered to the model's feature order, each factual is
        passed through probe_recourse row-by-row, and the results are
        validated by check_counterfactuals before being returned.
        """
        factuals = self._mlmodel.get_ordered_features(factuals)

        # Column positions of the (encoded) categorical features, needed to
        # re-apply the encoding constraints during optimization.
        encoded_feature_names = self._mlmodel.data.categorical
        cat_features_indices = [
            factuals.columns.get_loc(feature) for feature in encoded_feature_names
        ]

        df_cfs = factuals.apply(
            lambda x: probe_recourse(
                self._mlmodel.raw_model,
                x.reshape((1, -1)),
                cat_features_indices,
                binary_cat_features=self._binary_cat_features,
                feature_costs=self._feature_costs,
                lr=self._lr,
                lambda_param=self._lambda_param,
                # Fix: the configured target class was previously not
                # forwarded, so probe_recourse silently used its own default.
                y_target=self._y_target,
                n_iter=self._n_iter,
                t_max_min=self._t_max_min,
                norm=self._norm,
                clamp=self._clamp,
                loss_type=self._loss_type,
                invalidation_target=self._invalidation_target,
                inval_target_eps=self._inval_target_eps,
                noise_variance=self._noise_variance,
            ),
            raw=True,
            axis=1,
        )

        df_cfs = check_counterfactuals(self._mlmodel, df_cfs, factuals.index)
        df_cfs = self._mlmodel.get_ordered_features(df_cfs)
        return df_cfs
Loading
Loading