Skip to content

Commit c3cd2f4

Browse files
authored
Merge branch 'master' into master
2 parents b6f2c90 + 36434b2 commit c3cd2f4

File tree

10 files changed

+521
-15
lines changed

10 files changed

+521
-15
lines changed

captum/attr/_utils/lrp_rules.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
from ..._utils.common import _format_tensor_into_tuples
7+
from captum._utils.common import _format_tensor_into_tuples
88

99

1010
class PropagationRule(ABC):

captum/influence/_core/arnoldi_influence_function.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66

77
from captum._utils.gradient import _extract_parameters_from_layers
88

9+
from captum.influence._core.influence_function import (
10+
_get_dataset_embeddings_intermediate_quantities_influence_function,
11+
InfluenceFunctionBase,
12+
IntermediateQuantitiesInfluenceFunction,
13+
)
14+
915
from captum.influence._utils.common import (
1016
_compute_batch_loss_influence_function_base,
1117
_compute_jacobian_sample_wise_grads_per_batch,
@@ -35,12 +41,6 @@
3541
from torch.utils.data import DataLoader, Dataset
3642
from tqdm import tqdm
3743

38-
from .influence_function import (
39-
_get_dataset_embeddings_intermediate_quantities_influence_function,
40-
InfluenceFunctionBase,
41-
IntermediateQuantitiesInfluenceFunction,
42-
)
43-
4444

4545
def _parameter_arnoldi(
4646
hvp: Callable,

captum/insights/attr_vis/app.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _update_config(self, settings):
228228

229229
@log_usage()
230230
def render(self, debug=True):
231-
from captum.insights.attr_vis.widget import CaptumInsights
231+
from captum.insights.attr_vis.widget.widget import CaptumInsights
232232
from IPython.display import display
233233

234234
widget = CaptumInsights(visualizer=self)

tests/attr/layer/test_layer_lrp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from captum.attr import LayerLRP
66
from captum.attr._utils.lrp_rules import Alpha1_Beta0_Rule, EpsilonRule, GammaRule
77

8-
from ...helpers.basic import assertTensorAlmostEqual, BaseTest
9-
from ...helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel
8+
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
9+
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel
1010

1111

1212
def _get_basic_config():

tests/attr/models/test_pytext.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from pytext.config.doc_classification import ModelInputConfig, TargetConfig
2121
from pytext.config.field_config import FeatureConfig, WordFeatConfig
2222
from pytext.data import CommonMetadata
23-
from pytext.data.doc_classification_data_handler import DocClassificationDataHandler
23+
from pytext.data.doc_classification_data_handler import ( # @manual=//pytext:main_lib
24+
DocClassificationDataHandler,
25+
)
2426
from pytext.data.featurizer import SimpleFeaturizer
2527
from pytext.fields import FieldMeta
2628
from pytext.models.decoders.mlp_decoder import MLPDecoder
27-
from pytext.models.doc_model import DocModel_Deprecated
29+
from pytext.models.doc_model import DocModel_Deprecated # @manual=//pytext:main_lib
2830
from pytext.models.embeddings.word_embedding import WordEmbedding
2931
from pytext.models.representations.bilstm_doc_attention import BiLSTMDocAttention
3032
except ImportError:

tests/concept/test_tcav.py

-2
Original file line numberDiff line numberDiff line change
@@ -1193,8 +1193,6 @@ def test_TCAV_x_1_0_1_w_flipped_class_id(self) -> None:
11931193
# Testing TCAV with default classifier and experimental sets of varying lengths
11941194
def test_exp_sets_with_diffent_lengths(self) -> None:
11951195
try:
1196-
import sklearn
1197-
import sklearn.linear_model
11981196
import sklearn.svm # noqa: F401
11991197
except ImportError:
12001198
raise unittest.SkipTest("sklearn is not available.")

tests/influence/_core/test_tracin_show_progress.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import io
22
import tempfile
3-
import unittest
43
import unittest.mock
54
from typing import Callable
65

tutorials/Llama2_LLM_Attribution.ipynb

+498
Large diffs are not rendered by default.

website/pages/tutorials/index.js

+5
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ class TutorialHome extends React.Component {
8585
Using Captum and Integrated Gradients we interpret the output of several test questions and analyze the attribution scores
8686
of the text and visual parts of the model. Find the tutorial <a href="Multimodal_VQA_Interpret">here</a>.
8787

88+
<h4>Understanding Llama2 with Captum LLM Attribution:</h4>
89+
This tutorial demonstrates how to easily use the LLM attribution functionality to interpret the large langague models (LLM) in text generation.
90+
It takes Llama2 as the example and shows the step-by-step improvements from the basic attribution setting to more advanced techniques.
91+
Find the tutorial <a href="Llama2_LLM_Attribution">here</a>.
92+
8893
<h4>Interpreting question answering with BERT Part 1:</h4>
8994
This tutorial demonstrates how to use Captum to interpret a BERT model for question answering.
9095
We use a pre-trained model from Hugging Face fine-tuned on the SQUAD dataset and show how to use hooks to

website/tutorials.json

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646
"id": "Image_and_Text_Classification_LIME",
4747
"title": "Interpreting vision and text models with LIME"
4848
},
49+
{
50+
"id": "Llama2_LLM_Attribution",
51+
"title": "Understanding Llama2 with Captum LLM Attribution"
52+
},
4953
{
5054
"title": "Interpreting BERT",
5155
"children": [

0 commit comments

Comments
 (0)