-
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathguided_backprop.py
More file actions
22 lines (17 loc) · 823 Bytes
/
guided_backprop.py
File metadata and controls
22 lines (17 loc) · 823 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from vanilla_gradient import VanillaGradient
class GuidedBackprop(VanillaGradient):
    """Guided-backpropagation saliency (Springenberg et al., 2015).

    Extends vanilla gradients by doubly clipping the gradient at every
    ReLU during the backward pass: gradient flows only where BOTH the
    forward-pass ReLU input was positive AND the incoming (upstream)
    gradient is positive.
    """

    def __init__(self, model):
        super(GuidedBackprop, self).__init__(model)
        # Stack of the tensors fed into each ReLU during the forward
        # pass; popped in reverse (LIFO) order during backward, which
        # matches the order autograd visits the ReLUs.
        self.relu_inputs = list()
        self.update_relus()

    def update_relus(self):
        """Register forward/backward hooks on every ReLU in the model."""

        def clip_gradient(module, grad_input, grad_output):
            # Pop the forward-pass input that corresponds to this ReLU
            # (backward traverses the ReLUs in reverse forward order).
            relu_input = self.relu_inputs.pop()
            # Guided backprop mask: keep gradient only where the upstream
            # gradient is positive AND the ReLU's input was positive.
            return (grad_output[0] * (grad_output[0] > 0.).float()
                    * (relu_input > 0.).float(),)

        def save_input(module, input, output):
            # Record the ReLU's input tensor for use in the backward hook.
            self.relu_inputs.append(input[0])

        for module in self.model.modules():
            if isinstance(module, torch.nn.ReLU):
                self.hooks.append(module.register_forward_hook(save_input))
                # register_backward_hook is deprecated and has ill-defined
                # grad_input semantics; register_full_backward_hook has the
                # same hook signature and is the supported replacement.
                self.hooks.append(
                    module.register_full_backward_hook(clip_gradient))