Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test-pip-cpu-with-type-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ on:

jobs:
tests:
if: false # Workflow disabled
strategy:
matrix:
pytorch_args: ["", "-n"]
Expand All @@ -24,6 +23,7 @@ jobs:
sudo chmod -R 777 .
./scripts/install_via_pip.sh ${{ matrix.pytorch_args }}
./scripts/run_mypy.sh
pyre check
# Disabling pyre check until numpy issues are resolved
# pyre check
# Run Tests
python3 -m pytest -ra --cov=. --cov-report term-missing
1 change: 1 addition & 0 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ def _format_outputs(
# pyre-fixme[24] Callable requires 2 arguments
def _construct_future_forward(original_forward: Callable) -> Callable:
def future_forward(*args: Any, **kwargs: Any) -> torch.futures.Future[Tensor]:
# pyre-ignore[29]: `Future` is callable at runtime
fut: torch.futures.Future[Tensor] = torch.futures.Future()
fut.set_result(original_forward(*args, **kwargs))
return fut
Expand Down
12 changes: 5 additions & 7 deletions captum/attr/_utils/approximation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
from enum import Enum
from typing import Callable, cast, List, Tuple
from typing import Callable, List, Tuple

import torch

Expand Down Expand Up @@ -126,15 +126,13 @@ def gauss_legendre_builders() -> (
def step_sizes(n: int) -> List[float]:
assert n > 0, "The number of steps has to be larger than zero"
# Scaling from 2 to 1
return cast(
NDArray[np.float64], 0.5 * np.polynomial.legendre.leggauss(n)[1]
).tolist()
result: NDArray[np.float64] = 0.5 * np.polynomial.legendre.leggauss(n)[1]
return result.tolist()

def alphas(n: int) -> List[float]:
assert n > 0, "The number of steps has to be larger than zero"
# Scaling from [-1, 1] to [0, 1]
return cast(
NDArray[np.float64], 0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])
).tolist()
result: NDArray[np.float64] = 0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])
return result.tolist()

return step_sizes, alphas
2 changes: 0 additions & 2 deletions captum/concept/_core/cav.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,6 @@ def load(
ctx: AbstractContextManager[None, None]
if hasattr(torch.serialization, "safe_globals"):
safe_globals = [
# pyre-ignore[16]: Module `numpy.core.multiarray` has no attribute
# `_reconstruct`
np.core.multiarray._reconstruct, # type: ignore[attr-defined]
np.ndarray,
np.dtype,
Expand Down
1 change: 1 addition & 0 deletions captum/testing/helpers/basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ def forward(
self.relu(lin1_out)
else:
relu_out = self.relu(lin1_out)
# pyre-ignore[29]: `Future` is callable at runtime
result = Future()
lin2_out = self.linear2(relu_out)
if multidim_output:
Expand Down
Loading