From bdf6a554dfa52776a0d1820aff75c3191258a71a Mon Sep 17 00:00:00 2001
From: cssastry
Date: Mon, 19 Feb 2024 04:00:29 -0400
Subject: [PATCH] fix gram-matrix detector

---
 openood/postprocessors/gram_postprocessor.py | 218 ++++++++++---------
 1 file changed, 114 insertions(+), 104 deletions(-)

diff --git a/openood/postprocessors/gram_postprocessor.py b/openood/postprocessors/gram_postprocessor.py
index 1d214a10..c3c22d78 100644
--- a/openood/postprocessors/gram_postprocessor.py
+++ b/openood/postprocessors/gram_postprocessor.py
@@ -10,7 +10,8 @@
 from .base_postprocessor import BasePostprocessor
 from .info import num_classes_dict
-
+from collections import defaultdict, Counter
+import random
 
 
 class GRAMPostprocessor(BasePostprocessor):
     def __init__(self, config):
@@ -24,17 +25,21 @@ def __init__(self, config):
         self.setup_flag = False
 
     def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
+        net = FeatureExtractor(net)
         if not self.setup_flag:
-            self.feature_min, self.feature_max = sample_estimator(
+            self.feature_min, self.feature_max, self.normalize_factors = sample_estimator(
                 net, id_loader_dict['train'], self.num_classes, self.powers)
             self.setup_flag = True
         else:
             pass
+        net.destroy_hooks()
 
     def postprocess(self, net: nn.Module, data: Any):
+        net = FeatureExtractor(net)
         preds, deviations = get_deviations(net, data, self.feature_min,
-                                           self.feature_max, self.num_classes,
+                                           self.feature_max, self.normalize_factors,
                                            self.powers)
+        net.destroy_hooks()
         return preds, deviations
 
     def set_hyperparam(self, hyperparam: list):
@@ -47,121 +52,126 @@ def get_hyperparam(self):
 def tensor2list(x):
     return x.data.cuda().tolist()
 
+
+def G_p(ob, p):
+    # Row sums of the order-p Gram matrix (channel correlations of the p-th
+    # power of the feature map), with a signed p-th root to normalize scale.
+    temp = ob.detach()
+
+    temp = temp**p
+    temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
+    temp = ((torch.matmul(temp, temp.transpose(dim0=2, dim1=1)))).sum(dim=2)
+    temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(temp.shape[0], -1)
+
+    return temp
+
+
+def delta(mins, maxs, x):
+    # Per-sample deviation: how far x falls below mins plus how far it exceeds
+    # maxs, relative to the magnitude of the bounds.
+    dev = (F.relu(mins - x) / torch.abs(mins + 10**-6)).sum(dim=1, keepdim=True)
+    dev += (F.relu(x - maxs) / torch.abs(maxs + 10**-6)).sum(dim=1, keepdim=True)
+    return dev
+
+
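+# FeatureExtractor wraps the backbone and records the output of every Conv2d
+# and ReLU layer through forward hooks, so per-layer Gram features can be
+# computed without modifying the underlying network.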
+class FeatureExtractor(torch.nn.Module):
+    # Inspired by https://github.com/paaatcha/gram-ood
+    def __init__(self, torch_model):
+        super().__init__()
+        self.torch_model = torch_model
+        self.feat_list = list()
+
+        def _hook_fn(_, input, output):
+            self.feat_list.append(output)
+
+        # Select the layers to hook; edit this function to hook different layers.
+        def hook_layers(torch_model):
+            hooked_layers = list()
+            for layer in torch_model.modules():
+                if isinstance(layer, nn.ReLU) or isinstance(layer, nn.Conv2d):
+                    hooked_layers.append(layer)
+            return hooked_layers
+
+        def register_layers(layers):
+            regs_layers = list()
+            for lay in layers:
+                regs_layers.append(lay.register_forward_hook(_hook_fn))
+            return regs_layers
+
+        # Set up the hooks
+        hl = hook_layers(torch_model)
+        self.rgl = register_layers(hl)
+        # print(f"{len(self.rgl)} Features")
+
+    def forward(self, x, return_feature_list=True):
+        preds = self.torch_model(x)
+        feats = self.feat_list.copy()
+        self.feat_list.clear()
+        return preds, feats
+
+    def destroy_hooks(self):
+        for lay in self.rgl:
+            lay.remove()
+
 
 @torch.no_grad()
 def sample_estimator(model, train_loader, num_classes, powers):
     model.eval()
-
-    num_layer = 5  # 4 for lenet
-    num_poles_list = powers
-    num_poles = len(num_poles_list)
-    feature_class = [[[None for x in range(num_poles)]
-                      for y in range(num_layer)] for z in range(num_classes)]
-    label_list = []
-    mins = [[[None for x in range(num_poles)] for y in range(num_layer)]
-            for z in range(num_classes)]
-    maxs = [[[None for x in range(num_poles)] for y in range(num_layer)]
-            for z in range(num_classes)]
-
+    gram_features = defaultdict(
+        lambda: defaultdict(lambda: defaultdict(lambda: None)))
+    mins = dict()
+    maxs = dict()
+    class_counts = Counter()
     # collect features and compute gram metrix
     for batch in tqdm(train_loader, desc='Compute min/max'):
         data = batch['data'].cuda()
         label = batch['label']
         _, feature_list = model(data, return_feature_list=True)
-        label_list = tensor2list(label)
-        for layer_idx in range(num_layer):
-
-            for pole_idx, p in enumerate(num_poles_list):
-                temp = feature_list[layer_idx].detach()
-
-                temp = temp**p
-                temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
-                temp = ((torch.matmul(temp,
-                                      temp.transpose(dim0=2,
-                                                     dim1=1)))).sum(dim=2)
-                temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
-                    temp.shape[0], -1)
-
-                temp = tensor2list(temp)
-                for feature, label in zip(temp, label_list):
-                    if isinstance(feature_class[label][layer_idx][pole_idx],
-                                  type(None)):
-                        feature_class[label][layer_idx][pole_idx] = feature
+        class_counts.update(Counter(label.cpu().numpy()))
+        for layer_idx, feature in enumerate(feature_list):
+            for power in powers:
+                gram_feature = G_p(feature, power).cpu()
+                for class_ in range(num_classes):
+                    if gram_features[layer_idx][power][class_] is None:
+                        gram_features[layer_idx][power][class_] = gram_feature[label == class_]
                     else:
-                        feature_class[label][layer_idx][pole_idx].extend(
-                            feature)
+                        gram_features[layer_idx][power][class_] = torch.cat(
+                            [gram_features[layer_idx][power][class_],
+                             gram_feature[label == class_]], dim=0)
+
+    # hold out 10% of each class: min/max bounds come from the remaining 90%,
+    # and deviations on the held-out part give per-layer normalization factors
+    val_idxs = {}
+    train_idxs = {}
+    for c in class_counts:
+        L = class_counts[c]
+        val_idxs[c] = random.sample(range(L), int(0.1 * L))
+        train_idxs[c] = list(set(range(L)) - set(val_idxs[c]))
+    normalize_factors = []
     # compute mins/maxs
-    for label in range(num_classes):
-        for layer_idx in range(num_layer):
-            for poles_idx in range(num_poles):
-                feature = torch.tensor(
-                    np.array(feature_class[label][layer_idx][poles_idx]))
-                current_min = feature.min(dim=0, keepdim=True)[0]
-                current_max = feature.max(dim=0, keepdim=True)[0]
-
-                if mins[label][layer_idx][poles_idx] is None:
-                    mins[label][layer_idx][poles_idx] = current_min
-                    maxs[label][layer_idx][poles_idx] = current_max
-                else:
-                    mins[label][layer_idx][poles_idx] = torch.min(
-                        current_min, mins[label][layer_idx][poles_idx])
-                    maxs[label][layer_idx][poles_idx] = torch.max(
-                        current_min, maxs[label][layer_idx][poles_idx])
-
-    return mins, maxs
-
-
-def get_deviations(model, data, mins, maxs, num_classes, powers):
+    for layer_idx in gram_features:
+        total_delta = None
+        for class_ in class_counts:
+            trn = train_idxs[class_]
+            val = val_idxs[class_]
+            class_deltas = 0
+            for power in powers:
+                mins[layer_idx, power, class_] = gram_features[layer_idx][power][class_][trn].min(dim=0, keepdim=True)[0]
+                maxs[layer_idx, power, class_] = gram_features[layer_idx][power][class_][trn].max(dim=0, keepdim=True)[0]
+                class_deltas += delta(mins[layer_idx, power, class_],
+                                      maxs[layer_idx, power, class_],
+                                      gram_features[layer_idx][power][class_][val])
+            if total_delta is None:
+                total_delta = class_deltas
+            else:
+                total_delta = torch.cat([total_delta, class_deltas], dim=0)
+        normalize_factors.append(total_delta.mean(dim=0, keepdim=True))
+    normalize_factors = torch.cat(normalize_factors, dim=1)
+    return mins, maxs, normalize_factors
+
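+# Score a batch: for each sample, sum the deviations of its Gram features from
+# the [min, max] bounds of its predicted class over all layers and powers,
+# normalizing each layer by its expected deviation; the returned confidence is
+# -deviation / max-softmax, so higher (less negative) means more in-distribution.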
+def get_deviations(model, data, mins, maxs, normalize_factors, powers):
     model.eval()
-    num_layer = 5  # 4 for lenet
-    num_poles_list = powers
-    exist = 1
-    pred_list = []
-    dev = [0 for x in range(data.shape[0])]
+    deviations = torch.zeros(data.shape[0], 1)
 
     # get predictions
     logits, feature_list = model(data, return_feature_list=True)
-    confs = F.softmax(logits, dim=1).cpu().detach().numpy()
-    preds = np.argmax(confs, axis=1)
-    predsList = preds.tolist()
-    preds = torch.tensor(preds)
-
-    for pred in predsList:
-        exist = 1
-        if len(pred_list) == 0:
-            pred_list.extend([pred])
-        else:
-            for pred_now in pred_list:
-                if pred_now == pred:
-                    exist = 0
-            if exist == 1:
-                pred_list.extend([pred])
-
-    # compute sample level deviation
-    for layer_idx in range(num_layer):
-        for pole_idx, p in enumerate(num_poles_list):
-            # get gram metirx
-            temp = feature_list[layer_idx].detach()
-            temp = temp**p
-            temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
-            temp = ((torch.matmul(temp, temp.transpose(dim0=2,
-                                                       dim1=1)))).sum(dim=2)
-            temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
-                temp.shape[0], -1)
-            temp = tensor2list(temp)
-
-            # compute the deviations with train data
-            for idx in range(len(temp)):
-                dev[idx] += (F.relu(mins[preds[idx]][layer_idx][pole_idx] -
-                                    sum(temp[idx])) /
-                             torch.abs(mins[preds[idx]][layer_idx][pole_idx] +
-                                       10**-6)).sum()
-                dev[idx] += (F.relu(
-                    sum(temp[idx]) - maxs[preds[idx]][layer_idx][pole_idx]) /
-                             torch.abs(maxs[preds[idx]][layer_idx][pole_idx] +
-                                       10**-6)).sum()
-    conf = [i / 50 for i in dev]
-
-    return preds, torch.tensor(conf)
+    confs = F.softmax(logits, dim=1).cpu().detach()
+    confs, preds = confs.max(dim=1)
+
+    for layer_idx, feature in enumerate(feature_list):
+        n = normalize_factors[:, layer_idx].item()
+        for power in powers:
+            gram_feature = G_p(feature, power).cpu()
+            for class_ in range(logits.shape[1]):
+                deviations[preds == class_] += delta(mins[layer_idx, power, class_],
+                                                     maxs[layer_idx, power, class_],
+                                                     gram_feature[preds == class_]) / n
+
+    return preds, -deviations / confs[:, None]
\ No newline at end of file