NoiseTunnel on LayerIntegratedGradients Perturbs Input Instead of Attribution Target (RuntimeError) #1281

@EldadTalShir

When NoiseTunnel wraps a LayerIntegratedGradients instance initialized on a specific model layer, the noise is applied to the model input, regardless of the chosen layer and attribution target (the layer's input or output).

For example, with a BERT-based model, setting the layer to bert.embeddings with attribute_to_layer_input=False still applies the noise at the token-ID level instead of the embedding level. The same happens when the layer is bert.encoder with attribute_to_layer_input=True.

This results in the following error when the perturbed token IDs reach the BERT embedding layer:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
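
The failure mechanism can be illustrated without Captum or PyTorch. The following is a plain-Python analogy, not the library's code: `EMBEDDING_TABLE`, `embed`, and `add_gaussian_noise` are hypothetical names made up for this sketch. It assumes only that smoothgrad-style perturbation adds Gaussian noise to whatever tensor it is given, and that an embedding lookup requires integer indices.

```python
import random

# Toy stand-in for an embedding table: row i holds the vector for token ID i.
EMBEDDING_TABLE = [[0.1 * i, 0.2 * i, 0.3 * i] for i in range(10)]


def embed(token_ids):
    """Look up one vector per token ID; the indices must be integers."""
    return [EMBEDDING_TABLE[i] for i in token_ids]


def add_gaussian_noise(values, stdev, rng):
    """Smoothgrad-style perturbation: add N(0, stdev^2) noise to every element."""
    return [v + rng.gauss(0.0, stdev) for v in values]


rng = random.Random(0)
token_ids = [1, 5, 7]

# Perturbing the raw input turns the integer IDs into floats ...
noisy_ids = add_gaussian_noise(token_ids, 1.0, rng)

# ... so the lookup fails before any attribution can be computed.
try:
    embed(noisy_ids)
    lookup_error = None
except TypeError as exc:
    lookup_error = exc

print(type(noisy_ids[0]).__name__, "->", type(lookup_error).__name__)
```

In this toy version the lookup raises a TypeError. In the real model the analogous type check inside the embedding layer produces the RuntimeError quoted above.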

To Reproduce

Steps to reproduce the behavior:

  1. Set up a BERT-based classifier/regressor.
  2. Load the classifier and initialize the attribution methods (LayerIntegratedGradients and a smoothgrad NoiseTunnel).
  3. Tokenize the input and the baseline.
  4. Run the attribution (smoothgrad on LayerIntegratedGradients).

import torch
from captum.attr import NoiseTunnel, LayerIntegratedGradients, TokenReferenceBase
from transformers import AutoTokenizer, AutoModel
from pytorch_lightning import LightningModule

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Transformer(LightningModule):
    """BERT encoder with a single-output linear head on the [CLS] embedding."""

    def __init__(self, model_name='mixedbread-ai/mxbai-embed-large-v1'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Linear head; 1024 must match the encoder's hidden size
        self.fc = torch.nn.Linear(1024, 1)

    def forward(self, x):
        '''
        x: inputs['input_ids'] from AutoTokenizer.from_pretrained(model_name)(...)
        '''
        # Use the final hidden state of the first ([CLS]) token
        outputs = self.bert(x).last_hidden_state[:, 0, :]
        prob = self.fc(outputs)
        return prob

# Instantiate the model (BERT encoder + linear head)
model = Transformer()
model.to(device)
model.eval()
model.zero_grad()

# Init attribution methods
lig = LayerIntegratedGradients(model, model.bert.embeddings)
sg = NoiseTunnel(lig)

# Get token indices for text
sentence = "Captum is great, but I found a bug today."
tokenizer = AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-embed-large-v1')
tokenized_text = tokenizer(sentence, return_tensors="pt", padding=False, truncation=False)
inputs = tokenized_text['input_ids'].to(device)

# Reference token for IG
ref_token_id = tokenizer.pad_token_id  # Use padding token as reference token
token_reference = TokenReferenceBase(reference_token_idx=ref_token_id)
ref = token_reference.generate_reference(len(inputs[0]), device=device)
# Set [CLS] and [SEP] token ids in baseline
ref[0] = 101
ref[-1] = 102

# Run attribution
attributions, delta = sg.attribute(
    inputs,
    nt_type='smoothgrad',
    nt_samples=100,
    nt_samples_batch_size=50,
    stdevs=1.0,
    draw_baseline_from_distrib=False,
    baselines=ref.unsqueeze(0),
    internal_batch_size=50,
    n_steps=50,
    method='riemann_trapezoid',
    attribute_to_layer_input=False,
    return_convergence_delta=True,
)

Expected behavior

NoiseTunnel should apply the noise to the attribution target (in this example, the embeddings rather than the token IDs).
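
To make the expected behavior concrete, here is a minimal plain-Python sketch of smoothgrad applied at the embedding level. It is an illustration under stated assumptions, not Captum's implementation: the toy model, its analytic gradient, and all function names are hypothetical, and a plain gradient stands in for integrated gradients to keep the example short. The point is only where the noise is added: to the continuous embedding vectors, never to the token IDs.

```python
import random


def gradient_wrt_embeddings(embeddings, weights):
    """Analytic gradient of the toy model f(E) = sum_ij w_j * E_ij**2."""
    return [[2.0 * w * e for e, w in zip(vec, weights)] for vec in embeddings]


def smoothgrad_on_embeddings(embeddings, weights, stdev, n_samples, rng):
    """Average the attribution over noisy copies of the embeddings."""
    total = [[0.0] * len(vec) for vec in embeddings]
    for _ in range(n_samples):
        # Noise goes on the embedding vectors, which are already floats.
        noisy = [[e + rng.gauss(0.0, stdev) for e in vec] for vec in embeddings]
        grads = gradient_wrt_embeddings(noisy, weights)
        for i, row in enumerate(grads):
            for j, g in enumerate(row):
                total[i][j] += g / n_samples
    return total


# Hypothetical embeddings for a two-token input, with two dimensions each.
embeddings = [[1.0, 2.0], [3.0, 4.0]]
weights = [0.5, -1.0]

attributions = smoothgrad_on_embeddings(
    embeddings, weights, stdev=1.0, n_samples=20, rng=random.Random(0)
)
print(attributions)
```

The result has one attribution per embedding dimension per token, which is the shape a layer attribution on the embedding output would be expected to have.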

Environment

 - Captum / PyTorch Version: (0.7.0 / 2.2.1+cu121)
 - OS: Linux
 - How you installed Captum / PyTorch: pip
 - Build command you used (if compiling from source): N/A
 - Python version: 3.10.12
 - CUDA/cuDNN version: 12.2.r12.2
 - GPU models and configuration: T4 (Google Colab)
 - Any other relevant information: N/A

Labels: enhancement (New feature or request)
