Why do captum's perturbation and IG treat input & target differently? #1456
❓ Questions and Help
I've been successful using captum's `LayerIntegratedGradients` class, but none of my attempts to use the same sorts of inputs and targets with `LLMAttribution` seem to work.
I'm working with a `BertForMultipleChoice` model, and the input is a list of sequences, one per choice, each consisting of the repeated prompt followed by that choice:
```python
# tst['input_ids'] has shape (1, num_choices, seq_len); iterate over the choices
for i, c in enumerate(tst['input_ids'][0]):
    indices = c.detach().tolist()
    sepIdx = indices.index(SEP_IDX)  # first [SEP] marks the end of the prompt
    nearSep = indices[sepIdx-prefix:]
    # last few prompt tokens before [SEP], and the choice tokens after it
    preTokens = tokenizer.convert_ids_to_tokens(indices[sepIdx-prefix:sepIdx-1])
    choiceTokens = tokenizer.convert_ids_to_tokens(indices[sepIdx+1:])
    print(f"{i} {sepIdx} {' '.join(preTokens):>55}\t[SEP] {' '.join(choiceTokens)}")
```

```
0 120 behalf of fake charities . as webster sees it [SEP] recognizing the guidelines commentary is authoritative [SEP]
1 109 when he sol ##ici ##ted personal information from the [SEP] holding that a sentencing guide ##line pre ##va ##ils over its commentary if the two are inconsistent [SEP]
2 98 l ( b ) ( 9 ) ( a [SEP] holding that sentencing guidelines commentary must be given controlling weight unless it violate ##s the constitution or a federal statute or is plainly inconsistent with the guidelines itself [SEP]
3 99 ( b ) ( 9 ) ( a ) [SEP] holding that commentary is not authoritative if it is inconsistent with or a plainly er ##rone ##ous reading of the guide ##line it interpret ##s or explains [SEP]
4 119 on behalf of fake charities . as webster sees [SEP] holding that guidelines commentary is generally authoritative [SEP]
```
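The listing above joins WordPiece tokens with spaces, so sub-word pieces show up as `guide ##line` or `sol ##ici ##ted`. The same `' '.join(...)` is used later in this post to build `in0Txt` and `targTxt`, and text containing literal `##` markers will likely be re-tokenized into different pieces by `TextTokenInput`. A minimal pure-Python sketch of splitting a token list at the first `[SEP]` and merging continuation pieces back into words might look like this; `split_at_sep`, `wordpiece_to_text`, and the `prefix` default are hypothetical helpers (not part of captum or transformers), and unlike the loop above the sketch keeps the token directly before `[SEP]`.

```python
from typing import Iterable, List, Tuple


def split_at_sep(tokens: List[str], sep: str = "[SEP]",
                 prefix: int = 10) -> Tuple[List[str], List[str]]:
    """Return (up to `prefix` tokens before the first SEP, all tokens after it)."""
    sep_idx = tokens.index(sep)           # position of the first [SEP]
    start = max(0, sep_idx - prefix)      # clamp so the slice never starts negative
    return tokens[start:sep_idx], tokens[sep_idx + 1:]


def wordpiece_to_text(tokens: Iterable[str],
                      skip: Tuple[str, ...] = ("[CLS]", "[SEP]", "[PAD]")) -> str:
    """Drop special tokens and glue '##' continuation pieces onto the previous word."""
    text = " ".join(t for t in tokens if t not in skip)
    return text.replace(" ##", "")


if __name__ == "__main__":
    # Hypothetical token list modeled on choice 1 in the listing above
    toks = ["[CLS]", "when", "he", "sol", "##ici", "##ted", "personal", "[SEP]",
            "holding", "that", "a", "guide", "##line", "pre", "##va", "##ils", "[SEP]"]
    pre, choice = split_at_sep(toks, prefix=4)
    print(wordpiece_to_text(pre), "|", wordpiece_to_text(choice))
```

For a real BERT tokenizer, `tokenizer.convert_tokens_to_string(tokens)` should perform the same `##` merge, so it is probably a safer way to produce the text handed to `TextTokenInput` than `' '.join(...)`.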
I'm using `LayerIntegratedGradients` with a test example and a scalar target representing the index of the correct choice, like this:
```python
# BertForMultipleChoice inputs, each shaped (1, num_choices, seq_len)
tstEGTuple = (tst['input_ids'],
              tst['attention_mask'],
              tst['token_type_ids'])
targetIdx = 3  # for this particular test example
lig = LayerIntegratedGradients(custForwardModel, model.bert.embeddings)
attributions_ig = lig.attribute(tstEGTuple, n_steps=5, target=targetIdx)
```
and that works, e.g. allowing calculations like `summarize_attributions(attributions_ig)`, `viz.VisualizationDataRecord()`, etc.
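Part of why a bare integer works here: `BertForMultipleChoice` returns logits of shape `(batch_size, num_choices)`, and integrated gradients only needs `target` to pick which output column to differentiate, then averages that column's gradient along a straight path from a baseline to the input. The pure-Python sketch below illustrates that on a hypothetical two-choice toy model with an analytic gradient; `integrated_gradients`, `TOY_W`, `toy_scores`, and `toy_grad` are illustrative names, and the midpoint Riemann rule used here is an assumption (captum offers several integration rules and its default may differ).

```python
from typing import Callable, List, Sequence

Vector = List[float]

# Hypothetical toy "multiple-choice" model: 2 choices x 2 input features
TOY_W = [[1.0, 0.0], [0.5, 2.0]]


def toy_scores(x: Sequence[float]) -> Vector:
    # score for choice c = sum_i W[c][i] * x_i**2 (nonlinear, so the path matters)
    return [sum(w * xi * xi for w, xi in zip(row, x)) for row in TOY_W]


def toy_grad(x: Sequence[float], target: int) -> Vector:
    # analytic gradient of toy_scores(x)[target] with respect to x
    return [2.0 * w * xi for w, xi in zip(TOY_W[target], x)]


def integrated_gradients(grad_fn: Callable[[Sequence[float], int], Vector],
                         x: Sequence[float], baseline: Sequence[float],
                         target: int, n_steps: int = 5) -> Vector:
    """Midpoint-Riemann approximation of IG for output column `target`."""
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps       # midpoint of the k-th interval on [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad_fn(point, target))]
    # (input - baseline) times the average gradient along the path
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]


if __name__ == "__main__":
    x, base = [1.0, 2.0], [0.0, 0.0]
    attr = integrated_gradients(toy_grad, x, base, target=1)
    # completeness: attributions should sum to score(x) - score(baseline) for the target
    print(attr, sum(attr), toy_scores(x)[1] - toy_scores(base)[1])
```

Changing `target` changes only which row of `TOY_W` is differentiated, which mirrors how `target=targetIdx` selects the logit of one choice.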
For `LLMAttribution` I am following the Llama2 tutorial. The closest I can get with `LLMAttribution` seems to require `TextTokenInput` for the input, but raw text for the target?
```python
# input: the first choice sequence (prompt + choice 0), converted back to text
in0 = tst['input_ids'][0][0]
in0_tokens = tokenizer.convert_ids_to_tokens(in0)
in0Txt = ' '.join(in0_tokens)
in4captum = TextTokenInput(in0Txt, tokenizer, skip_tokens=skip_tokens)

# target: the sequence of the correct choice, also converted to text
target = targetList[egIdx]
targetIn = tst['input_ids'][0][target]
targ_tokens = tokenizer.convert_ids_to_tokens(targetIn)
targTxt = ' '.join(targ_tokens)
# targ4captum = TextTokenInput(targTxt, tokenizer, skip_tokens=skip_tokens)

llm_attr = LLMAttribution(fa, tokenizer)
attributions_fa = llm_attr.attribute(in4captum, target=targTxt)
```
but this raises an exception saying that `prepare_inputs_for_generation` isn't available for this `BertForMultipleChoice` model:
```
Traceback (most recent call last):
  File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 874, in <module>
    main()
  File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 854, in main
    captumPerturb(model,tokenizer,tstEGTensorDict,tstEGtarget,OutDir)
  File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 479, in captumPerturb
    attributions_fa = llm_attr.attribute(in4captum, target=targTxt)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py", line 667, in attribute
    cur_attr = self.attr_method.attribute(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/log/dummy_log.py", line 39, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
                                                  ^^^^^^^^^^^^^
  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/_utils/common.py", line 588, in _run_forward
    output = forward_func(
             ^^^^^^^^^^^^^
  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py", line 567, in _forward_func
    model_inputs = self.model.prepare_inputs_for_generation(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
    raise NotImplementedError(
```
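The traceback shows `LLMAttribution`'s `_forward_func` calling `self.model.prepare_inputs_for_generation`, which suggests it scores the target *text* by running the model as a generative language model; an encoder classifier like `BertForMultipleChoice` does not implement that. For a classifier, a perturbation method such as captum's `FeatureAblation` can, as far as I know, be applied directly to a forward function that returns choice logits, with an integer `target` just like `LayerIntegratedGradients`. The pure-Python sketch below only illustrates the underlying idea (occlude one feature, measure the drop in the target score) on a hypothetical linear toy model; `feature_ablation`, `TOY_W`, and `toy_forward` are illustrative names, not captum's API.

```python
from typing import Callable, List, Sequence

# Hypothetical toy "multiple-choice" model: 2 choices x 3 input features
TOY_W = [[1.0, -1.0, 0.0], [0.0, 2.0, 3.0]]


def toy_forward(x: Sequence[float]) -> List[float]:
    # one linear score per choice, analogous to logits of shape (num_choices,)
    return [sum(w * xi for w, xi in zip(row, x)) for row in TOY_W]


def feature_ablation(forward: Callable[[Sequence[float]], List[float]],
                     x: Sequence[float], target: int,
                     baseline: float = 0.0) -> List[float]:
    """attr[i] = score[target](x) - score[target](x with feature i set to baseline)."""
    base_score = forward(x)[target]
    attrs = []
    for i in range(len(x)):
        ablated = list(x)
        ablated[i] = baseline             # occlude a single feature
        attrs.append(base_score - forward(ablated)[target])
    return attrs


if __name__ == "__main__":
    x = [1.0, 1.0, 1.0]
    for t in range(len(TOY_W)):
        print(t, feature_ablation(toy_forward, x, target=t))
```

With token inputs, the "feature" would typically be a token (or group of tokens) replaced by some baseline token id, and `target` would again be the index of the correct choice rather than the choice's text.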
Thanks for any suggestions!
I also posted this question on the Discussion Forum.