
DeepLIFT with Softmax - normalization error #1367

Open
@AITheorem

Description

In the DeepLIFT implementation, a normalization step is applied to the contributions coming out of the Softmax module. I am pretty sure this step comes from a misreading of the DeepLIFT paper. Section 3.6 of the paper (link) makes two recommendations:

  • (a) In the case of Softmax outputs, we may prefer to compute contributions to the logits rather than contributions to the Softmax outputs.
  • (b) If we compute contributions to the logits, then we can normalize the contributions.

So we actually shouldn't be normalizing the contributions to the Softmax outputs at all; the normalization described in the paper is meant for contributions to the logits.
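
For concreteness, here is a minimal sketch of the normalization I understand (b) to describe: the adjustment is applied to contributions to the logits, subtracting the mean contribution across classes, because adding a constant to every logit leaves the softmax unchanged. The function name and tensor layout below are my own, for illustration only.

import torch

def normalize_logit_contributions(contribs: torch.Tensor) -> torch.Tensor:
    # contribs[..., j] = contribution of an input feature to logit j,
    # e.g. shape (batch, n_features, n_classes).
    # As I read Section 3.6: subtract the mean contribution across the n
    # logits, since a shift common to all logits cancels in the softmax.
    n = contribs.shape[-1]
    return contribs - contribs.sum(dim=-1, keepdim=True) / n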

Expected behaviour

  1. We should remove the normalization step from the Softmax code.
  2. We should include a warning somewhere that the user may prefer to use logits rather than Softmax output.
    • Maybe in the docstring?
    • Or potentially a runtime warning when we detect a Softmax module, which could also mention the change of behaviour from (1).
  3. In an ideal world there would be an additional parameter for (a) which specifies "use the logits instead of the softmax outputs".
    • But I'm not sure we have a concept of "this is the penultimate layer". Maybe we just ask people to pass in a forward func that outputs logits (see the sketch after this list).
  4. In an ideal world there would be an additional parameter for (b) which specifies "we are calculating contributions to the logits, please normalize the contributions in the (linear) logit layer".
    • But to do this we would need a hook around the penultimate linear layer (or around the final linear layer if we expect the forward func to produce logits), so this also depends on identifying the penultimate / final layer.
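
As a workaround for (3) that already works today, the user can hand DeepLift only the part of the model that produces the logits. A minimal sketch, using a hypothetical two-stage model (module names and shapes are made up for illustration):

import torch
import torch.nn as nn
from captum.attr import DeepLift

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # logits come from `backbone`; the softmax is applied on top
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        return self.softmax(self.backbone(x))

model = Net().eval()
x = torch.randn(2, 4)
baselines = torch.zeros_like(x)

# Attribute to the logits by wrapping only the logit-producing submodule,
# which never triggers the softmax hook (or its normalization) at all.
dl = DeepLift(model.backbone)
attr_to_logits = dl.attribute(x, baselines=baselines, target=1)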

I am happy to implement these myself, but I would appreciate feedback, on (3) and (4) in particular.

Current code with normalization

def softmax(
    module: Module,
    inputs: Tensor,
    outputs: Tensor,
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
) -> Tensor:
    delta_in, delta_out = _compute_diffs(inputs, outputs)

    grad_input_unnorm = torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )
    # normalizing
    n = grad_input.numel()

    # updating only the first half
    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
    return new_grad_inp
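
For reference, a minimal sketch of what (1) could look like: the same hook with the normalization step dropped, so the softmax is handled like the other nonlinearities (multiplier = delta_out / delta_in, falling back to the gradient when delta_in is tiny). This reuses the _compute_diffs helper from the snippet above and is a sketch, not a tested patch.

def softmax_without_normalization(
    module: Module,
    inputs: Tensor,
    outputs: Tensor,
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
) -> Tensor:
    # same multiplier rule as above, just without the mean-subtraction step
    delta_in, delta_out = _compute_diffs(inputs, outputs)
    return torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )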
