Skip to content

Commit b487891

Browse files
styusuffacebook-github-bot
authored andcommitted
Moving the Base Test Helper (#1485)
Summary: Pull Request resolved: #1485 moving the base test helper to the captum library. This has a lot of file dependency so apologies for the large Diff Reviewed By: cyrjano, lurunming Differential Revision: D68174196 fbshipit-source-id: 84f180748bd6c70cba1b7079a14e97e520228feb
1 parent 4b7879c commit b487891

File tree

93 files changed

+246
-217
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+246
-217
lines changed

tests/helpers/__init__.py renamed to captum/testing/helpers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44

55
try:
6-
from tests.helpers.fb.internal_base import FbBaseTest as BaseTest
6+
from captum.testing.helpers.fb.internal_base import FbBaseTest as BaseTest
77

88
__all__ = [
99
"BaseTest",
@@ -13,4 +13,4 @@
1313
# tests/helpers/__init__.py:13: error: Incompatible import of "BaseTest"
1414
# (imported name has type "type[BaseTest]", local name has type
1515
# "type[FbBaseTest]") [assignment]
16-
from tests.helpers.basic import BaseTest # type: ignore
16+
from captum.testing.helpers.basic import BaseTest # type: ignore
File renamed without changes.

tests/attr/helpers/attribution_delta_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Tuple, Union
55

66
import torch
7-
from tests.helpers import BaseTest
7+
from captum.testing.helpers import BaseTest
88
from torch import Tensor
99

1010

tests/attr/helpers/get_config_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77
from captum._utils.gradient import compute_gradients
8-
from tests.helpers.basic_models import BasicModel, BasicModel5_MultiArgs
8+
from captum.testing.helpers.basic_models import BasicModel, BasicModel5_MultiArgs
99
from torch import Tensor
1010
from torch.nn import Module
1111

tests/attr/helpers/test_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
from captum.attr._core.saliency import Saliency
4141
from captum.attr._core.shapley_value import ShapleyValueSampling
4242
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
43-
from tests.helpers.basic import set_all_random_seeds
44-
from tests.helpers.basic_models import (
43+
from captum.testing.helpers.basic import set_all_random_seeds
44+
from captum.testing.helpers.basic_models import (
4545
BasicModel_ConvNet,
4646
BasicModel_MultiLayer,
4747
BasicModel_MultiLayer_MultiInput,

tests/attr/layer/test_grad_cam.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import torch
99
from captum._utils.typing import TensorLikeList
1010
from captum.attr._core.layer.grad_cam import LayerGradCam
11-
from packaging import version
12-
from tests.helpers import BaseTest
13-
from tests.helpers.basic import assertTensorTuplesAlmostEqual
14-
from tests.helpers.basic_models import (
11+
from captum.testing.helpers import BaseTest
12+
from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual
13+
from captum.testing.helpers.basic_models import (
1514
BasicModel_ConvNet_One_Conv,
1615
BasicModel_MultiLayer,
1716
)
17+
from packaging import version
1818
from torch import Tensor
1919
from torch.nn import Module
2020

tests/attr/layer/test_internal_influence.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import torch
88
from captum._utils.typing import BaselineType
99
from captum.attr._core.layer.internal_influence import InternalInfluence
10-
from packaging import version
11-
from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
12-
from tests.helpers.basic_models import (
10+
from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
11+
from captum.testing.helpers.basic_models import (
1312
BasicModel_MultiLayer,
1413
BasicModel_MultiLayer_MultiInput,
1514
)
15+
from packaging import version
1616
from torch import Tensor
1717
from torch.nn import Module
1818

tests/attr/layer/test_layer_ablation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import torch
99
from captum._utils.typing import BaselineType
1010
from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation
11-
from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
12-
from tests.helpers.basic_models import (
11+
from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
12+
from captum.testing.helpers.basic_models import (
1313
BasicModel_ConvNet_One_Conv,
1414
BasicModel_MultiLayer,
1515
BasicModel_MultiLayer_MultiInput,

tests/attr/layer/test_layer_activation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
import torch
99
import torch.nn as nn
1010
from captum.attr._core.layer.layer_activation import LayerActivation
11-
from tests.helpers.basic import (
11+
from captum.testing.helpers.basic import (
1212
assertTensorAlmostEqual,
1313
assertTensorTuplesAlmostEqual,
1414
BaseTest,
1515
)
16-
from tests.helpers.basic_models import (
16+
from captum.testing.helpers.basic_models import (
1717
BasicModel_MultiLayer,
1818
BasicModel_MultiLayer_MultiInput,
1919
Conv1dSeqModel,

tests/attr/layer/test_layer_conductance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@
88
import torch
99
from captum._utils.typing import BaselineType
1010
from captum.attr._core.layer.layer_conductance import LayerConductance
11-
from packaging import version
12-
from tests.attr.helpers.conductance_reference import ConductanceReference
13-
from tests.helpers.basic import (
11+
from captum.testing.helpers.basic import (
1412
assertTensorAlmostEqual,
1513
assertTensorTuplesAlmostEqual,
1614
BaseTest,
1715
)
18-
from tests.helpers.basic_models import (
16+
from captum.testing.helpers.basic_models import (
1917
BasicModel_ConvNet,
2018
BasicModel_MultiLayer,
2119
BasicModel_MultiLayer_MultiInput,
2220
)
21+
from packaging import version
22+
from tests.attr.helpers.conductance_reference import ConductanceReference
2323
from torch import Tensor
2424
from torch.nn import Module
2525

tests/attr/layer/test_layer_deeplift.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,25 @@
99

1010
import torch
1111
from captum.attr._core.layer.layer_deep_lift import LayerDeepLift, LayerDeepLiftShap
12-
from packaging import version
13-
from tests.attr.helpers.neuron_layer_testing_util import (
14-
create_inps_and_base_for_deeplift_neuron_layer_testing,
15-
create_inps_and_base_for_deepliftshap_neuron_layer_testing,
16-
)
17-
from tests.helpers.basic import (
12+
from captum.testing.helpers.basic import (
1813
assert_delta,
1914
assertTensorAlmostEqual,
2015
assertTensorTuplesAlmostEqual,
2116
BaseTest,
2217
)
23-
from tests.helpers.basic_models import (
18+
from captum.testing.helpers.basic_models import (
2419
BasicModel_ConvNet,
2520
BasicModel_ConvNet_MaxPool3d,
2621
BasicModel_MaxPool_ReLU,
2722
BasicModel_MultiLayer,
2823
LinearMaxPoolLinearModel,
2924
ReLULinearModel,
3025
)
26+
from packaging import version
27+
from tests.attr.helpers.neuron_layer_testing_util import (
28+
create_inps_and_base_for_deeplift_neuron_layer_testing,
29+
create_inps_and_base_for_deepliftshap_neuron_layer_testing,
30+
)
3131
from torch import Tensor
3232

3333

tests/attr/layer/test_layer_feature_permutation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import torch
66
from captum.attr._core.layer.layer_feature_permutation import LayerFeaturePermutation
7-
from tests.helpers import BaseTest
8-
from tests.helpers.basic import assertTensorAlmostEqual
9-
from tests.helpers.basic_models import BasicModel_MultiLayer
7+
from captum.testing.helpers import BaseTest
8+
from captum.testing.helpers.basic import assertTensorAlmostEqual
9+
from captum.testing.helpers.basic_models import BasicModel_MultiLayer
1010
from torch import Tensor
1111

1212

tests/attr/layer/test_layer_gradient_shap.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313
LayerGradientShap,
1414
LayerInputBaselineXGradient,
1515
)
16-
from packaging import version
17-
from tests.attr.helpers.attribution_delta_util import assert_attribution_delta
18-
from tests.helpers.basic import (
16+
from captum.testing.helpers.basic import (
1917
assertTensorAlmostEqual,
2018
assertTensorTuplesAlmostEqual,
2119
BaseTest,
2220
)
23-
from tests.helpers.basic_models import (
21+
from captum.testing.helpers.basic_models import (
2422
BasicModel_MultiLayer,
2523
BasicModel_MultiLayer_MultiInput,
2624
)
27-
from tests.helpers.classification_models import SoftmaxModel
25+
from captum.testing.helpers.classification_models import SoftmaxModel
26+
from packaging import version
27+
from tests.attr.helpers.attribution_delta_util import assert_attribution_delta
2828
from torch import Tensor
2929
from torch.nn import Module
3030

tests/attr/layer/test_layer_gradient_x_activation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from captum._utils.typing import ModuleOrModuleList
1010
from captum.attr._core.layer.layer_activation import LayerActivation
1111
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
12-
from packaging import version
13-
from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
14-
from tests.helpers.basic_models import (
12+
from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest
13+
from captum.testing.helpers.basic_models import (
1514
BasicEmbeddingModel,
1615
BasicModel_MultiLayer,
1716
BasicModel_MultiLayer_MultiInput,
1817
)
18+
from packaging import version
1919
from torch import Tensor
2020
from torch.nn import Module
2121

tests/attr/layer/test_layer_integrated_gradients.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
configure_interpretable_embedding_layer,
1414
remove_interpretable_embedding_layer,
1515
)
16-
from packaging import version
17-
from tests.helpers.basic import (
16+
from captum.testing.helpers.basic import (
1817
assertTensorAlmostEqual,
1918
assertTensorTuplesAlmostEqual,
2019
BaseTest,
2120
)
22-
from tests.helpers.basic_models import (
21+
from captum.testing.helpers.basic_models import (
2322
BasicEmbeddingModel,
2423
BasicModel_MultiLayer,
2524
BasicModel_MultiLayer_TrueMultiInput,
2625
)
26+
from packaging import version
2727
from torch import Tensor
2828
from torch.nn import Module
2929

tests/attr/layer/test_layer_lrp.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
from captum.attr import LayerLRP
1010
from captum.attr._utils.lrp_rules import Alpha1_Beta0_Rule, EpsilonRule, GammaRule
1111

12-
from tests.helpers import BaseTest
13-
from tests.helpers.basic import assertTensorAlmostEqual
14-
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel
12+
from captum.testing.helpers import BaseTest
13+
from captum.testing.helpers.basic import assertTensorAlmostEqual
14+
from captum.testing.helpers.basic_models import (
15+
BasicModel_ConvNet_One_Conv,
16+
SimpleLRPModel,
17+
)
1518
from torch import Tensor
1619

1720

tests/attr/models/test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
InterpretableEmbeddingBase,
1313
remove_interpretable_embedding_layer,
1414
)
15-
from tests.helpers.basic import assertTensorAlmostEqual
16-
from tests.helpers.basic_models import BasicEmbeddingModel, TextModule
15+
from captum.testing.helpers.basic import assertTensorAlmostEqual
16+
from captum.testing.helpers.basic_models import BasicEmbeddingModel, TextModule
1717
from torch.nn import Embedding
1818

1919

tests/attr/neuron/test_neuron_ablation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
TensorOrTupleOfTensorsGeneric,
1313
)
1414
from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation
15-
from tests.helpers import BaseTest
16-
from tests.helpers.basic import assertTensorAlmostEqual
17-
from tests.helpers.basic_models import (
15+
from captum.testing.helpers import BaseTest
16+
from captum.testing.helpers.basic import assertTensorAlmostEqual
17+
from captum.testing.helpers.basic_models import (
1818
BasicModel_ConvNet_One_Conv,
1919
BasicModel_MultiLayer,
2020
BasicModel_MultiLayer_MultiInput,

tests/attr/neuron/test_neuron_conductance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
1010
from captum.attr._core.layer.layer_conductance import LayerConductance
1111
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
12-
13-
from packaging import version
14-
from tests.helpers import BaseTest
15-
from tests.helpers.basic import assertTensorAlmostEqual
16-
from tests.helpers.basic_models import (
12+
from captum.testing.helpers import BaseTest
13+
from captum.testing.helpers.basic import assertTensorAlmostEqual
14+
from captum.testing.helpers.basic_models import (
1715
BasicModel_ConvNet,
1816
BasicModel_MultiLayer,
1917
BasicModel_MultiLayer_MultiInput,
2018
)
19+
20+
from packaging import version
2121
from torch import Tensor
2222
from torch.nn import Module
2323

tests/attr/neuron/test_neuron_deeplift.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99
import torch
1010
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
1111
from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
12-
from tests.attr.helpers.neuron_layer_testing_util import (
13-
create_inps_and_base_for_deeplift_neuron_layer_testing,
14-
create_inps_and_base_for_deepliftshap_neuron_layer_testing,
15-
)
16-
from tests.helpers import BaseTest
17-
from tests.helpers.basic import assertTensorAlmostEqual
18-
from tests.helpers.basic_models import (
12+
from captum.testing.helpers import BaseTest
13+
from captum.testing.helpers.basic import assertTensorAlmostEqual
14+
from captum.testing.helpers.basic_models import (
1915
BasicModel_ConvNet,
2016
BasicModel_ConvNet_MaxPool3d,
2117
LinearMaxPoolLinearModel,
2218
ReLULinearModel,
2319
)
20+
from tests.attr.helpers.neuron_layer_testing_util import (
21+
create_inps_and_base_for_deeplift_neuron_layer_testing,
22+
create_inps_and_base_for_deepliftshap_neuron_layer_testing,
23+
)
2424
from torch import Tensor
2525

2626

tests/attr/neuron/test_neuron_gradient.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
1111
from captum.attr._core.neuron.neuron_gradient import NeuronGradient
1212
from captum.attr._core.saliency import Saliency
13-
from tests.helpers.basic import (
13+
from captum.testing.helpers.basic import (
1414
assertTensorAlmostEqual,
1515
assertTensorTuplesAlmostEqual,
1616
BaseTest,
1717
)
18-
from tests.helpers.basic_models import (
18+
from captum.testing.helpers.basic_models import (
1919
BasicModel_ConvNet,
2020
BasicModel_MultiLayer,
2121
BasicModel_MultiLayer_MultiInput,

tests/attr/neuron/test_neuron_gradient_shap.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from captum.attr._core.neuron.neuron_integrated_gradients import (
99
NeuronIntegratedGradients,
1010
)
11-
from tests.helpers import BaseTest
12-
from tests.helpers.basic import assertTensorAlmostEqual
13-
from tests.helpers.basic_models import BasicModel_MultiLayer
14-
from tests.helpers.classification_models import SoftmaxModel
11+
from captum.testing.helpers import BaseTest
12+
from captum.testing.helpers.basic import assertTensorAlmostEqual
13+
from captum.testing.helpers.basic_models import BasicModel_MultiLayer
14+
from captum.testing.helpers.classification_models import SoftmaxModel
1515
from torch import Tensor
1616
from torch.nn import Module
1717

tests/attr/neuron/test_neuron_integrated_gradients.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from captum.attr._core.neuron.neuron_integrated_gradients import (
1212
NeuronIntegratedGradients,
1313
)
14-
from tests.helpers.basic import (
14+
from captum.testing.helpers.basic import (
1515
assertTensorAlmostEqual,
1616
assertTensorTuplesAlmostEqual,
1717
BaseTest,
1818
)
19-
from tests.helpers.basic_models import (
19+
from captum.testing.helpers.basic_models import (
2020
BasicModel_ConvNet,
2121
BasicModel_MultiLayer,
2222
BasicModel_MultiLayer_MultiInput,

tests/attr/test_approximation_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import torch
99
from captum.attr._utils.approximation_methods import Riemann, riemann_builders
10-
from tests.helpers.basic import assertTensorAlmostEqual
10+
from captum.testing.helpers.basic import assertTensorAlmostEqual
1111

1212

1313
class Test(unittest.TestCase):

tests/attr/test_baselines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from captum.attr._utils.baselines import ProductBaselines
55

66
# from parameterized import parameterized
7-
from tests.helpers import BaseTest
7+
from captum.testing.helpers import BaseTest
88

99

1010
class TestProductBaselines(BaseTest):

tests/attr/test_class_summarizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77
from captum.attr import ClassSummarizer, CommonStats
8-
from tests.helpers import BaseTest
8+
from captum.testing.helpers import BaseTest
99

1010

1111
class Test(BaseTest):

0 commit comments

Comments
 (0)