From f6ced6518828513b889de3530fd617dd2df1a234 Mon Sep 17 00:00:00 2001 From: "Pyre Bot Jr." Date: Fri, 1 Nov 2024 09:39:02 -0700 Subject: [PATCH] Add type error suppressions for upcoming upgrade Differential Revision: D65341979 --- captum/_utils/common.py | 6 +++++- captum/metrics/_core/infidelity.py | 2 ++ captum/metrics/_core/sensitivity.py | 2 ++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 6459cd8aa..4804a29b5 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -680,6 +680,7 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor: raise AssertionError(f"Target type {type(target)} is not valid.") +# pyre-fixme[24]: Generic type `slice` expects 3 type parameters. def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool: if isinstance(target, tuple): for index in target: @@ -690,7 +691,10 @@ def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool: def _verify_select_column( - output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]] + # pyre-fixme[24]: Generic type `slice` expects 3 type parameters. + output: Tensor, + # pyre-fixme[24]: Generic type `slice` expects 3 type parameters. + target: Union[int, Tuple[Union[int, slice], ...]], ) -> Tensor: target = (target,) if isinstance(target, int) else target assert ( diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py index c4c4bd061..0609d4fae 100644 --- a/captum/metrics/_core/infidelity.py +++ b/captum/metrics/_core/infidelity.py @@ -499,6 +499,8 @@ def _generate_perturbations( repeated instances per example. """ + # pyre-fixme[53]: Captured variable `baselines_expanded` is not annotated. + # pyre-fixme[53]: Captured variable `inputs_expanded` is not annotated. def call_perturb_func() -> ( Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] ): diff --git a/captum/metrics/_core/sensitivity.py b/captum/metrics/_core/sensitivity.py index 381dbbc44..b4b0190ea 100644 --- a/captum/metrics/_core/sensitivity.py +++ b/captum/metrics/_core/sensitivity.py @@ -232,6 +232,8 @@ def max_values(input_tnsr: Tensor) -> Tensor: # pyre-fixme[33]: Given annotation cannot be `Any`. kwargs_copy: Any = None + # pyre-fixme[53]: Captured variable `bsz` is not annotated. + # pyre-fixme[53]: Captured variable `expl_inputs` is not annotated. def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor: inputs_perturbed = _generate_perturbations(current_n_perturb_samples)