diff --git a/tests/helpers/__init__.py b/captum/testing/helpers/__init__.py similarity index 66% rename from tests/helpers/__init__.py rename to captum/testing/helpers/__init__.py index 745946af71..16a664407b 100644 --- a/tests/helpers/__init__.py +++ b/captum/testing/helpers/__init__.py @@ -3,7 +3,7 @@ # pyre-strict try: - from tests.helpers.fb.internal_base import FbBaseTest as BaseTest + from captum.testing.helpers.fb.internal_base import FbBaseTest as BaseTest __all__ = [ "BaseTest", @@ -13,4 +13,4 @@ # tests/helpers/__init__.py:13: error: Incompatible import of "BaseTest" # (imported name has type "type[BaseTest]", local name has type # "type[FbBaseTest]") [assignment] - from tests.helpers.basic import BaseTest # type: ignore + from captum.testing.helpers.basic import BaseTest # type: ignore diff --git a/tests/helpers/basic.py b/captum/testing/helpers/basic.py similarity index 100% rename from tests/helpers/basic.py rename to captum/testing/helpers/basic.py diff --git a/tests/helpers/basic_models.py b/captum/testing/helpers/basic_models.py similarity index 100% rename from tests/helpers/basic_models.py rename to captum/testing/helpers/basic_models.py diff --git a/tests/helpers/classification_models.py b/captum/testing/helpers/classification_models.py similarity index 100% rename from tests/helpers/classification_models.py rename to captum/testing/helpers/classification_models.py diff --git a/tests/helpers/evaluate_linear_model.py b/captum/testing/helpers/evaluate_linear_model.py similarity index 100% rename from tests/helpers/evaluate_linear_model.py rename to captum/testing/helpers/evaluate_linear_model.py diff --git a/tests/attr/helpers/attribution_delta_util.py b/tests/attr/helpers/attribution_delta_util.py index 4fcdc09ed9..dd4fc100e8 100644 --- a/tests/attr/helpers/attribution_delta_util.py +++ b/tests/attr/helpers/attribution_delta_util.py @@ -4,7 +4,7 @@ from typing import Tuple, Union import torch -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest from torch import Tensor diff --git a/tests/attr/helpers/get_config_util.py b/tests/attr/helpers/get_config_util.py index d3627527c4..aa66d08a86 100644 --- a/tests/attr/helpers/get_config_util.py +++ b/tests/attr/helpers/get_config_util.py @@ -5,7 +5,7 @@ import torch from captum._utils.gradient import compute_gradients -from tests.helpers.basic_models import BasicModel, BasicModel5_MultiArgs +from captum.testing.helpers.basic_models import BasicModel, BasicModel5_MultiArgs from torch import Tensor from torch.nn import Module diff --git a/tests/attr/helpers/test_config.py b/tests/attr/helpers/test_config.py index 636232b6c0..effe71a69a 100644 --- a/tests/attr/helpers/test_config.py +++ b/tests/attr/helpers/test_config.py @@ -40,8 +40,8 @@ from captum.attr._core.saliency import Saliency from captum.attr._core.shapley_value import ShapleyValueSampling from captum.attr._utils.input_layer_wrapper import ModelInputWrapper -from tests.helpers.basic import set_all_random_seeds -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import set_all_random_seeds +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/layer/test_grad_cam.py b/tests/attr/layer/test_grad_cam.py index 7487b2123b..1f8829d24d 100644 --- a/tests/attr/layer/test_grad_cam.py +++ b/tests/attr/layer/test_grad_cam.py @@ -8,13 +8,13 @@ import torch from captum._utils.typing import TensorLikeList from captum.attr._core.layer.grad_cam import LayerGradCam -from packaging import version -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorTuplesAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, ) +from packaging import version from torch import Tensor from torch.nn import Module diff --git a/tests/attr/layer/test_internal_influence.py b/tests/attr/layer/test_internal_influence.py index 9b57f6c75c..5316e49a38 100644 --- a/tests/attr/layer/test_internal_influence.py +++ b/tests/attr/layer/test_internal_influence.py @@ -7,12 +7,12 @@ import torch from captum._utils.typing import BaselineType from captum.attr._core.layer.internal_influence import InternalInfluence -from packaging import version -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest +from captum.testing.helpers.basic_models import ( BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, ) +from packaging import version from torch import Tensor from torch.nn import Module diff --git a/tests/attr/layer/test_layer_ablation.py b/tests/attr/layer/test_layer_ablation.py index 8d14f5ad88..4d35c9f801 100644 --- a/tests/attr/layer/test_layer_ablation.py +++ b/tests/attr/layer/test_layer_ablation.py @@ -8,8 +8,8 @@ import torch from captum._utils.typing import BaselineType from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/layer/test_layer_activation.py b/tests/attr/layer/test_layer_activation.py index e368483ae1..3da6871bbf 100644 --- a/tests/attr/layer/test_layer_activation.py +++ b/tests/attr/layer/test_layer_activation.py @@ -8,12 +8,12 @@ import torch import torch.nn as nn from captum.attr._core.layer.layer_activation import LayerActivation -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, Conv1dSeqModel, diff --git a/tests/attr/layer/test_layer_conductance.py b/tests/attr/layer/test_layer_conductance.py index 29fb343788..a8b13c538c 100644 --- a/tests/attr/layer/test_layer_conductance.py +++ b/tests/attr/layer/test_layer_conductance.py @@ -8,18 +8,18 @@ import torch from captum._utils.typing import BaselineType from captum.attr._core.layer.layer_conductance import LayerConductance -from packaging import version -from tests.attr.helpers.conductance_reference import ConductanceReference -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, ) +from packaging import version +from tests.attr.helpers.conductance_reference import ConductanceReference from torch import Tensor from torch.nn import Module diff --git a/tests/attr/layer/test_layer_deeplift.py b/tests/attr/layer/test_layer_deeplift.py index ae24fb7171..3dcdb3efc0 100644 --- a/tests/attr/layer/test_layer_deeplift.py +++ b/tests/attr/layer/test_layer_deeplift.py @@ -9,18 +9,13 @@ import torch from captum.attr._core.layer.layer_deep_lift import LayerDeepLift, LayerDeepLiftShap -from packaging import version -from tests.attr.helpers.neuron_layer_testing_util import ( - create_inps_and_base_for_deeplift_neuron_layer_testing, - create_inps_and_base_for_deepliftshap_neuron_layer_testing, -) -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assert_delta, assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_ConvNet_MaxPool3d, BasicModel_MaxPool_ReLU, @@ -28,6 +23,11 @@ LinearMaxPoolLinearModel, ReLULinearModel, ) +from packaging import version +from tests.attr.helpers.neuron_layer_testing_util import ( + create_inps_and_base_for_deeplift_neuron_layer_testing, + create_inps_and_base_for_deepliftshap_neuron_layer_testing, +) from torch import Tensor diff --git a/tests/attr/layer/test_layer_feature_permutation.py b/tests/attr/layer/test_layer_feature_permutation.py index 9796a65248..ab945c3680 100644 --- a/tests/attr/layer/test_layer_feature_permutation.py +++ b/tests/attr/layer/test_layer_feature_permutation.py @@ -4,9 +4,9 @@ import torch from captum.attr._core.layer.layer_feature_permutation import LayerFeaturePermutation -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_MultiLayer +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel_MultiLayer from torch import Tensor diff --git a/tests/attr/layer/test_layer_gradient_shap.py b/tests/attr/layer/test_layer_gradient_shap.py index b50bac751f..04112792f2 100644 --- a/tests/attr/layer/test_layer_gradient_shap.py +++ b/tests/attr/layer/test_layer_gradient_shap.py @@ -13,18 +13,18 @@ LayerGradientShap, LayerInputBaselineXGradient, ) -from packaging import version -from tests.attr.helpers.attribution_delta_util import assert_attribution_delta -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, ) -from tests.helpers.classification_models import SoftmaxModel +from captum.testing.helpers.classification_models import SoftmaxModel +from packaging import version +from tests.attr.helpers.attribution_delta_util import assert_attribution_delta from torch import Tensor from torch.nn import Module diff --git a/tests/attr/layer/test_layer_gradient_x_activation.py b/tests/attr/layer/test_layer_gradient_x_activation.py index e3248f2bcd..1dc1888fc4 100644 --- a/tests/attr/layer/test_layer_gradient_x_activation.py +++ b/tests/attr/layer/test_layer_gradient_x_activation.py @@ -9,13 +9,13 @@ from captum._utils.typing import ModuleOrModuleList from captum.attr._core.layer.layer_activation import LayerActivation from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation -from packaging import version -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest +from captum.testing.helpers.basic_models import ( BasicEmbeddingModel, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, ) +from packaging import version from torch import Tensor from torch.nn import Module diff --git a/tests/attr/layer/test_layer_integrated_gradients.py b/tests/attr/layer/test_layer_integrated_gradients.py index 0ad158ed6d..56f965854f 100644 --- a/tests/attr/layer/test_layer_integrated_gradients.py +++ b/tests/attr/layer/test_layer_integrated_gradients.py @@ -13,17 +13,17 @@ configure_interpretable_embedding_layer, remove_interpretable_embedding_layer, ) -from packaging import version -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicEmbeddingModel, BasicModel_MultiLayer, BasicModel_MultiLayer_TrueMultiInput, ) +from packaging import version from torch import Tensor from torch.nn import Module diff --git a/tests/attr/layer/test_layer_lrp.py b/tests/attr/layer/test_layer_lrp.py index ccc56377e2..8217b4c2ff 100644 --- a/tests/attr/layer/test_layer_lrp.py +++ b/tests/attr/layer/test_layer_lrp.py @@ -9,9 +9,12 @@ from captum.attr import LayerLRP from captum.attr._utils.lrp_rules import Alpha1_Beta0_Rule, EpsilonRule, GammaRule -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( + BasicModel_ConvNet_One_Conv, + SimpleLRPModel, +) from torch import Tensor diff --git a/tests/attr/models/test_base.py b/tests/attr/models/test_base.py index 40ee37f36d..cdfdf66444 100644 --- a/tests/attr/models/test_base.py +++ b/tests/attr/models/test_base.py @@ -12,8 +12,8 @@ InterpretableEmbeddingBase, remove_interpretable_embedding_layer, ) -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicEmbeddingModel, TextModule +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicEmbeddingModel, TextModule from torch.nn import Embedding diff --git a/tests/attr/neuron/test_neuron_ablation.py b/tests/attr/neuron/test_neuron_ablation.py index 37c30daae2..b316400feb 100644 --- a/tests/attr/neuron/test_neuron_ablation.py +++ b/tests/attr/neuron/test_neuron_ablation.py @@ -12,9 +12,9 @@ TensorOrTupleOfTensorsGeneric, ) from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/neuron/test_neuron_conductance.py b/tests/attr/neuron/test_neuron_conductance.py index 73692fec2f..619268aaa3 100644 --- a/tests/attr/neuron/test_neuron_conductance.py +++ b/tests/attr/neuron/test_neuron_conductance.py @@ -9,15 +9,15 @@ from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric from captum.attr._core.layer.layer_conductance import LayerConductance from captum.attr._core.neuron.neuron_conductance import NeuronConductance - -from packaging import version -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, ) + +from packaging import version from torch import Tensor from torch.nn import Module diff --git a/tests/attr/neuron/test_neuron_deeplift.py b/tests/attr/neuron/test_neuron_deeplift.py index 8a522a3c92..b1f05f4845 100644 --- a/tests/attr/neuron/test_neuron_deeplift.py +++ b/tests/attr/neuron/test_neuron_deeplift.py @@ -9,18 +9,18 @@ import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap -from tests.attr.helpers.neuron_layer_testing_util import ( - create_inps_and_base_for_deeplift_neuron_layer_testing, - create_inps_and_base_for_deepliftshap_neuron_layer_testing, -) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_ConvNet_MaxPool3d, LinearMaxPoolLinearModel, ReLULinearModel, ) +from tests.attr.helpers.neuron_layer_testing_util import ( + create_inps_and_base_for_deeplift_neuron_layer_testing, + create_inps_and_base_for_deepliftshap_neuron_layer_testing, +) from torch import Tensor diff --git a/tests/attr/neuron/test_neuron_gradient.py b/tests/attr/neuron/test_neuron_gradient.py index 1439fcae30..466ee5dd8f 100644 --- a/tests/attr/neuron/test_neuron_gradient.py +++ b/tests/attr/neuron/test_neuron_gradient.py @@ -10,12 +10,12 @@ from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._core.neuron.neuron_gradient import NeuronGradient from captum.attr._core.saliency import Saliency -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/neuron/test_neuron_gradient_shap.py b/tests/attr/neuron/test_neuron_gradient_shap.py index 7035847222..2311c8b441 100644 --- a/tests/attr/neuron/test_neuron_gradient_shap.py +++ b/tests/attr/neuron/test_neuron_gradient_shap.py @@ -8,10 +8,10 @@ from captum.attr._core.neuron.neuron_integrated_gradients import ( NeuronIntegratedGradients, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_MultiLayer -from tests.helpers.classification_models import SoftmaxModel +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel_MultiLayer +from captum.testing.helpers.classification_models import SoftmaxModel from torch import Tensor from torch.nn import Module diff --git a/tests/attr/neuron/test_neuron_integrated_gradients.py b/tests/attr/neuron/test_neuron_integrated_gradients.py index fb8c316d63..ce5819ffc7 100644 --- a/tests/attr/neuron/test_neuron_integrated_gradients.py +++ b/tests/attr/neuron/test_neuron_integrated_gradients.py @@ -11,12 +11,12 @@ from captum.attr._core.neuron.neuron_integrated_gradients import ( NeuronIntegratedGradients, ) -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/test_approximation_methods.py b/tests/attr/test_approximation_methods.py index 5c11d36bbd..b2ced6ecf8 100644 --- a/tests/attr/test_approximation_methods.py +++ b/tests/attr/test_approximation_methods.py @@ -7,7 +7,7 @@ import torch from captum.attr._utils.approximation_methods import Riemann, riemann_builders -from tests.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic import assertTensorAlmostEqual class Test(unittest.TestCase): diff --git a/tests/attr/test_baselines.py b/tests/attr/test_baselines.py index 6b7dbfb1be..1aab426f96 100644 --- a/tests/attr/test_baselines.py +++ b/tests/attr/test_baselines.py @@ -4,7 +4,7 @@ from captum.attr._utils.baselines import ProductBaselines # from parameterized import parameterized -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest class TestProductBaselines(BaseTest): diff --git a/tests/attr/test_class_summarizer.py b/tests/attr/test_class_summarizer.py index 86f9a9f1f5..13c4fdeef9 100644 --- a/tests/attr/test_class_summarizer.py +++ b/tests/attr/test_class_summarizer.py @@ -5,7 +5,7 @@ import torch from captum.attr import ClassSummarizer, CommonStats -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest class Test(BaseTest): diff --git a/tests/attr/test_common.py b/tests/attr/test_common.py index 95778bb9c7..5cf48ea309 100644 --- a/tests/attr/test_common.py +++ b/tests/attr/test_common.py @@ -5,7 +5,7 @@ import torch from captum.attr._core.noise_tunnel import SUPPORTED_NOISE_TUNNEL_TYPES from captum.attr._utils.common import _validate_input, _validate_noise_tunnel_type -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest class Test(BaseTest): diff --git a/tests/attr/test_data_parallel.py b/tests/attr/test_data_parallel.py index 2135e9e368..4f03f8f13e 100644 --- a/tests/attr/test_data_parallel.py +++ b/tests/attr/test_data_parallel.py @@ -18,6 +18,11 @@ ) from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.attribution import Attribution, InternalAttribution +from captum.testing.helpers.basic import ( + assertTensorTuplesAlmostEqual, + BaseTest, + deep_copy_args, +) from tests.attr.helpers.gen_test_utils import ( gen_test_name, get_target_layer, @@ -25,7 +30,6 @@ should_create_generated_test, ) from tests.attr.helpers.test_config import config -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest, deep_copy_args from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_dataloader_attr.py b/tests/attr/test_dataloader_attr.py index d7501134e8..8a568a127d 100644 --- a/tests/attr/test_dataloader_attr.py +++ b/tests/attr/test_dataloader_attr.py @@ -9,12 +9,12 @@ from captum.attr._core.dataloader_attr import DataLoaderAttribution, InputRole from captum.attr._core.feature_ablation import FeatureAblation -from parameterized import parameterized -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertAttributionComparision, assertTensorAlmostEqual, BaseTest, ) +from parameterized import parameterized from torch import Tensor from torch.utils.data import DataLoader, TensorDataset diff --git a/tests/attr/test_deconvolution.py b/tests/attr/test_deconvolution.py index 571bf40314..5cde4f3be2 100644 --- a/tests/attr/test_deconvolution.py +++ b/tests/attr/test_deconvolution.py @@ -13,9 +13,9 @@ from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( NeuronDeconvolution, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel_ConvNet_One_Conv from torch.nn import Module diff --git a/tests/attr/test_deeplift_basic.py b/tests/attr/test_deeplift_basic.py index d1d145e732..88e17f9ccd 100644 --- a/tests/attr/test_deeplift_basic.py +++ b/tests/attr/test_deeplift_basic.py @@ -8,12 +8,12 @@ import torch from captum.attr._core.deep_lift import DeepLift, DeepLiftShap from captum.attr._core.integrated_gradients import IntegratedGradients -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertAttributionComparision, assertTensorAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicModelWithReusedModules, Conv1dSeqModel, LinearMaxPoolLinearModel, diff --git a/tests/attr/test_deeplift_classification.py b/tests/attr/test_deeplift_classification.py index fac39040de..17b984605f 100644 --- a/tests/attr/test_deeplift_classification.py +++ b/tests/attr/test_deeplift_classification.py @@ -8,13 +8,13 @@ from captum._utils.typing import TargetType from captum.attr._core.deep_lift import DeepLift, DeepLiftShap from captum.attr._core.integrated_gradients import IntegratedGradients -from tests.helpers.basic import assertAttributionComparision, BaseTest -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import assertAttributionComparision, BaseTest +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_ConvNet_MaxPool1d, BasicModel_ConvNet_MaxPool3d, ) -from tests.helpers.classification_models import ( +from captum.testing.helpers.classification_models import ( SigmoidDeepLiftModel, SoftmaxDeepLiftModel, ) diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index 5f01f2ab9d..3646bd6c58 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -15,9 +15,9 @@ from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.attribution import Attribution -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel, BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py index ed4db030f5..e1795e0ebe 100644 --- a/tests/attr/test_feature_permutation.py +++ b/tests/attr/test_feature_permutation.py @@ -6,9 +6,9 @@ import torch from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModelWithSparseInputs +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModelWithSparseInputs from torch import Tensor diff --git a/tests/attr/test_gradient_shap.py b/tests/attr/test_gradient_shap.py index 93bc16c811..924e137236 100644 --- a/tests/attr/test_gradient_shap.py +++ b/tests/attr/test_gradient_shap.py @@ -10,14 +10,14 @@ from captum._utils.typing import Tensor from captum.attr._core.gradient_shap import GradientShap from captum.attr._core.integrated_gradients import IntegratedGradients +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicLinearModel, BasicModel2 +from captum.testing.helpers.classification_models import SoftmaxModel from tests.attr.helpers.attribution_delta_util import ( assert_attribution_delta, assert_delta, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicLinearModel, BasicModel2 -from tests.helpers.classification_models import SoftmaxModel class Test(BaseTest): diff --git a/tests/attr/test_guided_backprop.py b/tests/attr/test_guided_backprop.py index 8a3d5ac69d..a82273009c 100644 --- a/tests/attr/test_guided_backprop.py +++ b/tests/attr/test_guided_backprop.py @@ -11,9 +11,9 @@ from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( NeuronGuidedBackprop, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel_ConvNet_One_Conv from torch.nn import Module diff --git a/tests/attr/test_guided_grad_cam.py b/tests/attr/test_guided_grad_cam.py index 40c63f48af..fa55209218 100644 --- a/tests/attr/test_guided_grad_cam.py +++ b/tests/attr/test_guided_grad_cam.py @@ -8,9 +8,9 @@ import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._core.guided_grad_cam import GuidedGradCam -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel_ConvNet_One_Conv from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_hook_removal.py b/tests/attr/test_hook_removal.py index 8bac3c9432..730f663a5c 100644 --- a/tests/attr/test_hook_removal.py +++ b/tests/attr/test_hook_removal.py @@ -9,6 +9,8 @@ from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._models.base import _set_deep_layer_value from captum.attr._utils.attribution import Attribution, InternalAttribution +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import deep_copy_args from tests.attr.helpers.gen_test_utils import ( gen_test_name, get_target_layer, @@ -16,8 +18,6 @@ should_create_generated_test, ) from tests.attr.helpers.test_config import config -from tests.helpers import BaseTest -from tests.helpers.basic import deep_copy_args from torch.nn import Module """ diff --git a/tests/attr/test_input_layer_wrapper.py b/tests/attr/test_input_layer_wrapper.py index 053858c23f..e9ed85a956 100644 --- a/tests/attr/test_input_layer_wrapper.py +++ b/tests/attr/test_input_layer_wrapper.py @@ -23,8 +23,8 @@ LayerIntegratedGradients, ) from captum.attr._utils.input_layer_wrapper import ModelInputWrapper -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest +from captum.testing.helpers.basic_models import ( BasicModel, BasicModel_MultiLayer_TrueMultiInput, MixedKwargsAndArgsModule, diff --git a/tests/attr/test_input_x_gradient.py b/tests/attr/test_input_x_gradient.py index c15a05ee43..5c1eac338e 100644 --- a/tests/attr/test_input_x_gradient.py +++ b/tests/attr/test_input_x_gradient.py @@ -7,13 +7,13 @@ from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._core.input_x_gradient import InputXGradient from captum.attr._core.noise_tunnel import NoiseTunnel +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.classification_models import SoftmaxModel from tests.attr.helpers.get_config_util import ( get_basic_config, get_multiargs_basic_config, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.classification_models import SoftmaxModel from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_integrated_gradients_basic.py b/tests/attr/test_integrated_gradients_basic.py index 378f7eb493..b3d4535817 100644 --- a/tests/attr/test_integrated_gradients_basic.py +++ b/tests/attr/test_integrated_gradients_basic.py @@ -11,9 +11,9 @@ from captum.attr._core.integrated_gradients import IntegratedGradients from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.common import _tensorize_baseline -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel, BasicModel2, BasicModel3, diff --git a/tests/attr/test_integrated_gradients_classification.py b/tests/attr/test_integrated_gradients_classification.py index f11c066590..b27ae28146 100644 --- a/tests/attr/test_integrated_gradients_classification.py +++ b/tests/attr/test_integrated_gradients_classification.py @@ -8,9 +8,9 @@ from captum._utils.typing import BaselineType, Tensor from captum.attr._core.integrated_gradients import IntegratedGradients from captum.attr._core.noise_tunnel import NoiseTunnel -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.classification_models import SigmoidModel, SoftmaxModel +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.classification_models import SigmoidModel, SoftmaxModel from torch.nn import Module diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py index 5b6bb89e0c..4d6be2c334 100644 --- a/tests/attr/test_interpretable_input.py +++ b/tests/attr/test_interpretable_input.py @@ -7,9 +7,9 @@ import torch from captum._utils.typing import BatchEncodingType from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor diff --git a/tests/attr/test_jit.py b/tests/attr/test_jit.py index 6eff6482d1..1d00aaef95 100644 --- a/tests/attr/test_jit.py +++ b/tests/attr/test_jit.py @@ -25,13 +25,17 @@ from captum.attr._core.saliency import Saliency from captum.attr._core.shapley_value import ShapleyValueSampling from captum.attr._utils.attribution import Attribution +from captum.testing.helpers.basic import ( + assertTensorTuplesAlmostEqual, + BaseTest, + deep_copy_args, +) from tests.attr.helpers.gen_test_utils import ( gen_test_name, parse_test_config, should_create_generated_test, ) from tests.attr.helpers.test_config import config -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest, deep_copy_args from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_kernel_shap.py b/tests/attr/test_kernel_shap.py index 2cd007550c..fa34d369d9 100644 --- a/tests/attr/test_kernel_shap.py +++ b/tests/attr/test_kernel_shap.py @@ -10,13 +10,13 @@ import torch from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric from captum.attr._core.kernel_shap import KernelShap -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, set_all_random_seeds, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicLinearModel, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 0c5ab7cf54..e3fccb2794 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -19,12 +19,12 @@ _format_input_baseline, _format_tensor_into_tuples, ) -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, ) -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic_models import ( BasicLinearModel, BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 0bbe4f4e73..d6f1a2a4ea 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -32,9 +32,9 @@ from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual, rand_like from parameterized import parameterized, parameterized_class -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual, rand_like from torch import nn, Tensor diff --git a/tests/attr/test_llm_attr_hf_compatibility.py b/tests/attr/test_llm_attr_hf_compatibility.py index f465cb0ef2..51c059eb29 100644 --- a/tests/attr/test_llm_attr_hf_compatibility.py +++ b/tests/attr/test_llm_attr_hf_compatibility.py @@ -13,8 +13,8 @@ from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import PerturbationAttribution from captum.attr._utils.interpretable_input import TextTemplateInput +from captum.testing.helpers import BaseTest from parameterized import parameterized, parameterized_class -from tests.helpers import BaseTest from torch import Tensor HAS_HF = True diff --git a/tests/attr/test_lrp.py b/tests/attr/test_lrp.py index 90e504d872..699c2bba99 100644 --- a/tests/attr/test_lrp.py +++ b/tests/attr/test_lrp.py @@ -12,9 +12,9 @@ GammaRule, IdentityRule, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, BasicModelWithReusedLinear, diff --git a/tests/attr/test_occlusion.py b/tests/attr/test_occlusion.py index 72acfc85d5..32705cbdbb 100644 --- a/tests/attr/test_occlusion.py +++ b/tests/attr/test_occlusion.py @@ -14,9 +14,9 @@ TensorOrTupleOfTensorsGeneric, ) from captum.attr._core.occlusion import Occlusion -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel3, BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, diff --git a/tests/attr/test_saliency.py b/tests/attr/test_saliency.py index e6695eaeca..196e3c3d7e 100644 --- a/tests/attr/test_saliency.py +++ b/tests/attr/test_saliency.py @@ -7,18 +7,18 @@ from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._core.saliency import Saliency +from captum.testing.helpers.basic import ( + assertTensorAlmostEqual, + assertTensorTuplesAlmostEqual, + BaseTest, +) +from captum.testing.helpers.classification_models import SoftmaxModel from tests.attr.helpers.get_config_util import ( get_basic_config, get_multiargs_basic_config, get_multiargs_basic_config_large, ) -from tests.helpers.basic import ( - assertTensorAlmostEqual, - assertTensorTuplesAlmostEqual, - BaseTest, -) -from tests.helpers.classification_models import SoftmaxModel from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index 320246f1d4..976adc55f2 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -10,8 +10,8 @@ import torch from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest -from tests.helpers.basic_models import ( +from captum.testing.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest +from captum.testing.helpers.basic_models import ( BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, BasicModelBoolInput, diff --git a/tests/attr/test_stat.py b/tests/attr/test_stat.py index 0b00ad82aa..8289584cce 100644 --- a/tests/attr/test_stat.py +++ b/tests/attr/test_stat.py @@ -6,8 +6,8 @@ import torch from captum.attr import Max, Mean, Min, MSE, StdDev, Sum, Summarizer, Var -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual def get_values(n: int = 100, lo=None, hi=None, integers: bool = False): diff --git a/tests/attr/test_summarizer.py b/tests/attr/test_summarizer.py index 0f71ae4b71..186b339835 100644 --- a/tests/attr/test_summarizer.py +++ b/tests/attr/test_summarizer.py @@ -3,7 +3,7 @@ # pyre-unsafe import torch from captum.attr import CommonStats, Summarizer -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest class Test(BaseTest): diff --git a/tests/attr/test_targets.py b/tests/attr/test_targets.py index d17bc21e1b..d6b69f32ed 100644 --- a/tests/attr/test_targets.py +++ b/tests/attr/test_targets.py @@ -12,6 +12,12 @@ from captum.attr._core.lime import Lime from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.attribution import Attribution, InternalAttribution +from captum.testing.helpers.basic import ( + assertTensorTuplesAlmostEqual, + BaseTest, + deep_copy_args, +) +from captum.testing.helpers.basic_models import BasicModel_MultiLayer from tests.attr.helpers.gen_test_utils import ( gen_test_name, get_target_layer, @@ -19,8 +25,6 @@ should_create_generated_test, ) from tests.attr.helpers.test_config import config -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest, deep_copy_args -from tests.helpers.basic_models import BasicModel_MultiLayer from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_utils_batching.py b/tests/attr/test_utils_batching.py index 9a149f593b..25bb536f0a 100644 --- a/tests/attr/test_utils_batching.py +++ b/tests/attr/test_utils_batching.py @@ -8,8 +8,8 @@ _batched_operator, _tuple_splice_range, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual class Test(BaseTest): diff --git a/tests/concept/test_concept.py b/tests/concept/test_concept.py index dbf28a772d..5b0474a1d0 100644 --- a/tests/concept/test_concept.py +++ b/tests/concept/test_concept.py @@ -7,7 +7,7 @@ import torch from captum.concept._core.concept import Concept from captum.concept._utils.data_iterator import dataset_to_dataloader -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest from torch.utils.data import IterableDataset diff --git a/tests/concept/test_tcav.py b/tests/concept/test_tcav.py index 679dcd10dd..9c2cc6fbe2 100644 --- a/tests/concept/test_tcav.py +++ b/tests/concept/test_tcav.py @@ -28,9 +28,9 @@ from captum.concept._utils.classifier import Classifier from captum.concept._utils.common import concepts_to_str from captum.concept._utils.data_iterator import dataset_to_dataloader -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel_ConvNet +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel_ConvNet from torch import Tensor from torch.utils.data import DataLoader, IterableDataset diff --git a/tests/influence/_core/test_arnoldi_influence.py b/tests/influence/_core/test_arnoldi_influence.py index ad61f9c9cf..1b1c2a8cdb 100644 --- a/tests/influence/_core/test_arnoldi_influence.py +++ b/tests/influence/_core/test_arnoldi_influence.py @@ -17,6 +17,8 @@ _top_eigen, _unflatten_params_factory, ) +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, @@ -29,8 +31,6 @@ UnpackDataset, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_dataloader.py b/tests/influence/_core/test_dataloader.py index 3237270a91..abb646f987 100644 --- a/tests/influence/_core/test_dataloader.py +++ b/tests/influence/_core/test_dataloader.py @@ -9,6 +9,8 @@ TracInCPFast, TracInCPFastRandProj, ) +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, @@ -16,8 +18,6 @@ get_random_model_and_data, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_naive_influence.py b/tests/influence/_core/test_naive_influence.py index bb2bca2bcc..0706408dc4 100644 --- a/tests/influence/_core/test_naive_influence.py +++ b/tests/influence/_core/test_naive_influence.py @@ -12,6 +12,11 @@ _functional_call, _unflatten_params_factory, ) +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import ( + assertTensorAlmostEqual, + assertTensorTuplesAlmostEqual, +) from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, @@ -23,8 +28,6 @@ UnpackDataset, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual from torch.utils.data import DataLoader # TODO: for some unknow reason, this test does not work diff --git a/tests/influence/_core/test_similarity_influence.py b/tests/influence/_core/test_similarity_influence.py index 8fcf4bf72e..2a75d4766a 100644 --- a/tests/influence/_core/test_similarity_influence.py +++ b/tests/influence/_core/test_similarity_influence.py @@ -10,8 +10,8 @@ euclidean_distance, SimilarityInfluence, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from torch import Tensor from torch.utils.data import Dataset diff --git a/tests/influence/_core/test_tracin_aggregate_influence.py b/tests/influence/_core/test_tracin_aggregate_influence.py index 7b293d201b..5c30e7355d 100644 --- a/tests/influence/_core/test_tracin_aggregate_influence.py +++ b/tests/influence/_core/test_tracin_aggregate_influence.py @@ -9,13 +9,13 @@ import torch.nn as nn from captum.influence._core.tracincp import TracInCP +from captum.testing.helpers.basic import assertTensorAlmostEqual, BaseTest from captum.testing.helpers.influence.common import ( build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_intermediate_quantities.py b/tests/influence/_core/test_tracin_intermediate_quantities.py index 6298fba3d4..b0c87ec3c3 100644 --- a/tests/influence/_core/test_tracin_intermediate_quantities.py +++ b/tests/influence/_core/test_tracin_intermediate_quantities.py @@ -12,6 +12,8 @@ TracInCPFast, TracInCPFastRandProj, ) +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, @@ -19,8 +21,6 @@ get_random_model_and_data, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_k_most_influential.py b/tests/influence/_core/test_tracin_k_most_influential.py index 15cf5097d7..0b29b8a499 100644 --- a/tests/influence/_core/test_tracin_k_most_influential.py +++ b/tests/influence/_core/test_tracin_k_most_influential.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn from captum.influence._core.tracincp import TracInCP +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, @@ -16,8 +18,6 @@ ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual class TestTracInGetKMostInfluential(BaseTest): diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 80863494b2..05897b90ae 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -13,6 +13,8 @@ TracInCPFast, TracInCPFastRandProj, ) +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _isSorted, _wrap_model_in_dataparallel, @@ -23,8 +25,6 @@ RangeDataset, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index fba15734b5..dddb1ff75e 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -8,6 +8,8 @@ from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp import TracInCP, TracInCPBase from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, @@ -17,8 +19,6 @@ is_gpu, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_show_progress.py b/tests/influence/_core/test_tracin_show_progress.py index e259038acf..092297d983 100644 --- a/tests/influence/_core/test_tracin_show_progress.py +++ b/tests/influence/_core/test_tracin_show_progress.py @@ -8,13 +8,13 @@ import torch.nn as nn from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast +from captum.testing.helpers import BaseTest from captum.testing.helpers.influence.common import ( build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) from parameterized import parameterized -from tests.helpers import BaseTest from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_validation.py b/tests/influence/_core/test_tracin_validation.py index 57ad1ce0b6..431a8ea0c0 100644 --- a/tests/influence/_core/test_tracin_validation.py +++ b/tests/influence/_core/test_tracin_validation.py @@ -5,6 +5,7 @@ import torch.nn as nn from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast +from captum.testing.helpers import BaseTest from captum.testing.helpers.influence.common import ( build_test_name_func, DataInfluenceConstructor, @@ -12,7 +13,6 @@ ) from parameterized import parameterized -from tests.helpers import BaseTest class TestTracinValidator(BaseTest): diff --git a/tests/influence/_core/test_tracin_xor.py b/tests/influence/_core/test_tracin_xor.py index 307f19df93..6ae7059a7b 100644 --- a/tests/influence/_core/test_tracin_xor.py +++ b/tests/influence/_core/test_tracin_xor.py @@ -9,6 +9,8 @@ import torch.nn as nn import torch.nn.functional as F from captum.influence._core.tracincp import TracInCP +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from captum.testing.helpers.influence.common import ( _wrap_model_in_dataparallel, BasicLinearNet, @@ -17,8 +19,6 @@ DataInfluenceConstructor, ) from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual class TestTracInXOR(BaseTest): diff --git a/tests/influence/_utils/test_common.py b/tests/influence/_utils/test_common.py index 1ed074d120..9f927c559b 100644 --- a/tests/influence/_utils/test_common.py +++ b/tests/influence/_utils/test_common.py @@ -7,8 +7,8 @@ import torch from captum.influence._utils.common import _jacobian_loss_wrt_inputs -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual class TestCommon(BaseTest): diff --git a/tests/insights/test_contribution.py b/tests/insights/test_contribution.py index 8c60898458..6978f3a275 100644 --- a/tests/insights/test_contribution.py +++ b/tests/insights/test_contribution.py @@ -10,8 +10,8 @@ from captum.insights import AttributionVisualizer, Batch from captum.insights.attr_vis.app import FilterConfig from captum.insights.attr_vis.features import BaseFeature, FeatureOutput, ImageFeature +from captum.testing.helpers import BaseTest from packaging import version -from tests.helpers import BaseTest from torch import Tensor from torch.utils.data import DataLoader diff --git a/tests/insights/test_features.py b/tests/insights/test_features.py index 9ec15d43d1..a9d129d5c9 100644 --- a/tests/insights/test_features.py +++ b/tests/insights/test_features.py @@ -10,8 +10,8 @@ ImageFeature, TextFeature, ) +from captum.testing.helpers import BaseTest from matplotlib.figure import Figure -from tests.helpers import BaseTest class TestTextFeature(BaseTest): diff --git a/tests/metrics/test_infidelity.py b/tests/metrics/test_infidelity.py index 6516ace124..60b4a796ca 100644 --- a/tests/metrics/test_infidelity.py +++ b/tests/metrics/test_infidelity.py @@ -15,9 +15,9 @@ Saliency, ) from captum.metrics import infidelity, infidelity_perturb_func_decorator -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel2, BasicModel4_MultiArgs, BasicModel_ConvNet_One_Conv, diff --git a/tests/metrics/test_sensitivity.py b/tests/metrics/test_sensitivity.py index 9dbb4e4749..16c01b3934 100644 --- a/tests/metrics/test_sensitivity.py +++ b/tests/metrics/test_sensitivity.py @@ -16,9 +16,9 @@ ) from captum.metrics import sensitivity_max from captum.metrics._core.sensitivity import default_perturb_func -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel2, BasicModel4_MultiArgs, BasicModel_ConvNet_One_Conv, diff --git a/tests/module/test_binary_concrete_stochastic_gates.py b/tests/module/test_binary_concrete_stochastic_gates.py index f4ada7b9ef..a50d3a273f 100644 --- a/tests/module/test_binary_concrete_stochastic_gates.py +++ b/tests/module/test_binary_concrete_stochastic_gates.py @@ -6,9 +6,9 @@ import torch from captum.module.binary_concrete_stochastic_gates import BinaryConcreteStochasticGates +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from parameterized import parameterized_class -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual @parameterized_class( diff --git a/tests/module/test_gaussian_stochastic_gates.py b/tests/module/test_gaussian_stochastic_gates.py index 58b90d6673..9d2d926b71 100644 --- a/tests/module/test_gaussian_stochastic_gates.py +++ b/tests/module/test_gaussian_stochastic_gates.py @@ -7,9 +7,9 @@ import torch from captum.module.gaussian_stochastic_gates import GaussianStochasticGates +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual from parameterized import parameterized_class -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual @parameterized_class( diff --git a/tests/robust/test_FGSM.py b/tests/robust/test_FGSM.py index 907d074820..4202ef83c4 100644 --- a/tests/robust/test_FGSM.py +++ b/tests/robust/test_FGSM.py @@ -6,9 +6,13 @@ import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.robust import FGSM -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel, BasicModel2, BasicModel_MultiLayer +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( + BasicModel, + BasicModel2, + BasicModel_MultiLayer, +) from torch import Tensor from torch.nn import CrossEntropyLoss diff --git a/tests/robust/test_PGD.py b/tests/robust/test_PGD.py index 8f8f28e77d..6d7fd5c5fe 100644 --- a/tests/robust/test_PGD.py +++ b/tests/robust/test_PGD.py @@ -3,9 +3,13 @@ # pyre-unsafe import torch from captum.robust import PGD -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel, BasicModel2, BasicModel_MultiLayer +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( + BasicModel, + BasicModel2, + BasicModel_MultiLayer, +) from torch.nn import CrossEntropyLoss diff --git a/tests/robust/test_attack_comparator.py b/tests/robust/test_attack_comparator.py index ebd41fb685..6f0e74d44f 100644 --- a/tests/robust/test_attack_comparator.py +++ b/tests/robust/test_attack_comparator.py @@ -6,9 +6,9 @@ import torch from captum.robust import AttackComparator, FGSM -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel, BasicModel_MultiLayer +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel, BasicModel_MultiLayer from torch import Tensor diff --git a/tests/robust/test_min_param_perturbation.py b/tests/robust/test_min_param_perturbation.py index 0c0694c04d..e9513d0983 100644 --- a/tests/robust/test_min_param_perturbation.py +++ b/tests/robust/test_min_param_perturbation.py @@ -5,9 +5,9 @@ import torch from captum.robust import MinParamPerturbation -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicModel, BasicModel_MultiLayer +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicModel, BasicModel_MultiLayer from torch import Tensor diff --git a/tests/utils/models/linear_models/_test_linear_classifier.py b/tests/utils/models/linear_models/_test_linear_classifier.py index c144a394df..b1f458e922 100644 --- a/tests/utils/models/linear_models/_test_linear_classifier.py +++ b/tests/utils/models/linear_models/_test_linear_classifier.py @@ -8,7 +8,7 @@ import numpy.typing as npt import sklearn.datasets as datasets import torch -from tests.helpers.evaluate_linear_model import evaluate +from captum.testing.helpers.evaluate_linear_model import evaluate from torch.utils.data import DataLoader, TensorDataset diff --git a/tests/utils/test_av.py b/tests/utils/test_av.py index 0e4a4befa9..3dd639b485 100644 --- a/tests/utils/test_av.py +++ b/tests/utils/test_av.py @@ -6,9 +6,9 @@ import torch from captum._utils.av import AV -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicLinearReLULinear +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import BasicLinearReLULinear from torch.utils.data import DataLoader, Dataset DEFAULT_IDENTIFIER = "default_identifier" diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index 20814349f0..0c4d5d232c 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -14,7 +14,7 @@ parse_version, safe_div, ) -from tests.helpers.basic import ( +from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, BaseTest, diff --git a/tests/utils/test_gradient.py b/tests/utils/test_gradient.py index da0efae505..59ec021832 100644 --- a/tests/utils/test_gradient.py +++ b/tests/utils/test_gradient.py @@ -12,10 +12,9 @@ compute_layer_gradients_and_eval, undo_gradient_requirements, ) -from packaging import version -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel, BasicModel2, BasicModel4_MultiArgs, @@ -23,6 +22,7 @@ BasicModel6_MultiTensor, BasicModel_MultiLayer, ) +from packaging import version class Test(BaseTest): diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 9b176641f4..d989868ff5 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -3,8 +3,8 @@ # pyre-unsafe import torch -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual class HelpersTest(BaseTest): diff --git a/tests/utils/test_jacobian.py b/tests/utils/test_jacobian.py index 8a8532901c..972ebbfd33 100644 --- a/tests/utils/test_jacobian.py +++ b/tests/utils/test_jacobian.py @@ -8,9 +8,12 @@ _compute_jacobian_wrt_params, _compute_jacobian_wrt_params_with_sample_wise_trick, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import BasicLinearModel2, BasicLinearModel_Multilayer +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( + BasicLinearModel2, + BasicLinearModel_Multilayer, +) class Test(BaseTest): diff --git a/tests/utils/test_linear_model.py b/tests/utils/test_linear_model.py index acfc8b4774..7f6c789af7 100644 --- a/tests/utils/test_linear_model.py +++ b/tests/utils/test_linear_model.py @@ -10,9 +10,9 @@ SGDLinearRegression, SGDRidge, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.evaluate_linear_model import evaluate +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.evaluate_linear_model import evaluate from torch import Tensor diff --git a/tests/utils/test_progress.py b/tests/utils/test_progress.py index 0a4acda00f..a87b997f0f 100644 --- a/tests/utils/test_progress.py +++ b/tests/utils/test_progress.py @@ -7,7 +7,7 @@ import unittest.mock from captum._utils.progress import NullProgress, progress -from tests.helpers import BaseTest +from captum.testing.helpers import BaseTest class Test(BaseTest): diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index f59d11b4cc..854d3eb7c4 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -11,9 +11,9 @@ SampleGradientWrapper, SUPPORTED_MODULES, ) -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.basic_models import ( +from captum.testing.helpers import BaseTest +from captum.testing.helpers.basic import assertTensorAlmostEqual +from captum.testing.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_ConvNetWithPaddingDilation, BasicModel_MultiLayer,