Description
In the implementation of DeepLIFT, there is a normalization step applied to the contributions from the Softmax module. I am pretty sure this step comes from a misreading of the DeepLIFT paper. Section 3.6 of the paper (link) recommends the following:
- (a) In the case of Softmax outputs, we may prefer to compute contributions to the logits rather than contributions to the Softmax outputs.
- (b) If we compute contributions to the logits, then we can normalize the contributions.
So we shouldn't be normalizing the contributions to the Softmax outputs at all; the normalization in (b) only applies when the contributions are computed to the logits.
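For reference, my reading of the normalization in (b) is: for each input feature, subtract its mean contribution across the logits, so that a constant shift of all logits (which leaves the Softmax output unchanged) receives zero total contribution. A minimal sketch, with a made-up function name and tensor layout:

import torch

def normalize_logit_contributions(contribs: torch.Tensor) -> torch.Tensor:
    """Sketch of the normalization from Section 3.6, as I read it.

    `contribs` is assumed to have shape (batch, num_inputs, num_logits):
    the contribution of each input feature to each logit. The name and
    layout are made up for illustration; this is not existing Captum code.
    """
    # Subtract, per input feature, the mean contribution across all logits.
    return contribs - contribs.mean(dim=-1, keepdim=True)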
Expected behaviour
- (1) We should remove the normalization step from the Softmax code.
- (2) We should include a warning somewhere that the user may prefer to use logits rather than the Softmax output.
  - Maybe in the docstring?
  - Or potentially a warning in the code when we spot a Softmax, which could also mention the change of behaviour from (1).
- (3) In an ideal world there would be an additional parameter for (a) which specifies "use the logits instead of the softmax outputs".
  - But I'm not sure we have a concept of "this is the penultimate layer". Maybe we just ask people to pass in a forward func that outputs logits (see the sketch after this list).
- (4) In an ideal world there would be an additional parameter for (b) which specifies "we are calculating contributions to the logits, please normalize the contributions in the (linear) logit layer".
  - But to do this we would need a hook around the penultimate linear layer (or around the final linear layer if we expect the forward func to produce logits), so this also depends on identifying the penultimate/final layer.
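To make (3) concrete, here is a rough sketch of the "just pass a forward func that outputs logits" option. Only the `DeepLift` class is real Captum API here; the toy model and variable names are made up:

import torch
import torch.nn as nn
from captum.attr import DeepLift

# Made-up toy model: a backbone that stops at the logit layer, plus a final Softmax.
backbone = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 3),  # logit layer
)
full_model = nn.Sequential(backbone, nn.Softmax(dim=-1))

# To get contributions to the logits, the user hands DeepLift the backbone only,
# so the forward func never reaches the Softmax.
inputs = torch.randn(4, 10, requires_grad=True)
baselines = torch.zeros_like(inputs)
attr_to_logits = DeepLift(backbone).attribute(inputs, baselines=baselines, target=0)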
I am happy to implement these myself, but I would appreciate feedback, on (3) and (4) in particular.
Current code with normalization
import torch
from torch import Tensor
from torch.nn import Module

# _compute_diffs is a Captum-internal helper used by these non-linearity hooks.

def softmax(
    module: Module,
    inputs: Tensor,
    outputs: Tensor,
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
) -> Tensor:
    delta_in, delta_out = _compute_diffs(inputs, outputs)
    # Rescale rule: the multiplier is delta_out / delta_in, falling back to the
    # plain gradient where delta_in is near zero.
    grad_input_unnorm = torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )
    # The normalization step in question: subtract the mean of the
    # unnormalized values.
    n = grad_input.numel()
    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
    return new_grad_inp
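For concreteness, (1) and (2) together could look roughly like this, keeping the surrounding Captum internals unchanged. The warning text and its placement are placeholders; emitting it from inside the hook is just one option:

import warnings

import torch
from torch import Tensor
from torch.nn import Module

# Assumes the same Captum-internal _compute_diffs helper as the current hook.

def softmax(
    module: Module,
    inputs: Tensor,
    outputs: Tensor,
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
) -> Tensor:
    # (2): let the user know they may prefer contributions to the logits.
    warnings.warn(
        "Computing DeepLIFT contributions to Softmax outputs; consider passing "
        "a forward func that returns logits instead (DeepLIFT paper, Section 3.6)."
    )
    delta_in, delta_out = _compute_diffs(inputs, outputs)
    # (1): keep the rescale rule (delta_out / delta_in, falling back to the
    # plain gradient where delta_in is near zero) and drop the normalization.
    return torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )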