Skip to content

Commit 7f1b303

Browse files
Merge pull request #144 from cblessing24/mypy
Annotate Attention Readout
2 parents 80134f5 + 28e564e commit 7f1b303

File tree

11 files changed

+119
-37
lines changed

11 files changed

+119
-37
lines changed

.dockerignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.mypy_cache/

.github/workflows/mypy.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
name: Mypy
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
mypy:
7+
runs-on: ubuntu-18.04
8+
steps:
9+
- uses: actions/checkout@v2
10+
- uses: actions/setup-python@v2
11+
- name: Check code with mypy
12+
run: touch .env && docker-compose run mypy

.pre-commit-config.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
repos:
2+
- repo: local
3+
hooks:
4+
- id: mypy
5+
name: mypy
6+
language: system
7+
entry: docker-compose run --rm mypy
8+
files: ^neuralpredictors/
9+
types: [python]
10+
pass_filenames: false

Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
FROM sinzlab/pytorch:v3.8-torch1.7.0-cuda11.0-dj0.12.7
1+
FROM sinzlab/pytorch:v3.9-torch1.9.0-cuda11.1-dj0.12.7
22

33
COPY . /src/neuralpredictors
4+
WORKDIR /src/neuralpredictors
45

56
RUN python3 -m pip install --upgrade pip &&\
7+
python3 -m pip install mypy==$(cat mypy_version.txt) &&\
68
python3 -m pip install -e /src/neuralpredictors
79

8-
WORKDIR /src/neuralpredictors
910

1011
ENTRYPOINT ["python3"]

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
![Test](https://github.com/sinzlab/neuralpredictors/workflows/Test/badge.svg)
44
[![codecov](https://codecov.io/gh/sinzlab/neuralpredictors/branch/main/graph/badge.svg)](https://codecov.io/gh/sinzlab/neuralpredictors)
55
![Black](https://github.com/sinzlab/neuralpredictors/workflows/Black/badge.svg)
6+
[![Mypy](https://github.com/sinzlab/neuralpredictors/actions/workflows/mypy.yml/badge.svg)](https://github.com/sinzlab/neuralpredictors/actions/workflows/mypy.yml)
67
[![PyPI version](https://badge.fury.io/py/neuralpredictors.svg)](https://badge.fury.io/py/neuralpredictors)
78

89
[Sinz Lab](https://sinzlab.org/) Neural System Identification Utilities for [PyTorch](https://pytorch.org/).

docker-compose.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ services:
1818
build: .
1919
volumes:
2020
- .:/src/neuralpredictors
21+
mypy:
22+
build: .
23+
volumes:
24+
- .:/src/neuralpredictors
25+
- mypy-cache:/src/neuralpredictors/.mypy_cache
26+
entrypoint: ["/src/neuralpredictors/run_mypy.sh"]
27+
28+
volumes:
29+
mypy-cache:

mypy_files.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
neuralpredictors/layers/readouts/attention.py

mypy_version.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.910
Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,96 @@
1+
from typing import Any, Literal, Mapping, Optional, Tuple
2+
13
import torch
2-
from torch import nn
3-
from torch.nn import Parameter
44
from torch.nn import functional as F
5+
from torch.nn import init
6+
from torch.nn.modules import ELU, BatchNorm2d, Conv2d, Module, Sequential
7+
from torch.nn.parameter import Parameter
8+
59
from .base import Readout
610

711

812
class AttentionReadout(Readout):
913
def __init__(
1014
self,
11-
in_shape,
12-
outdims,
13-
bias,
14-
init_noise=1e-3,
15-
attention_kernel=1,
16-
attention_layers=1,
17-
mean_activity=None,
18-
feature_reg_weight=1.0,
19-
gamma_readout=None, # depricated, use feature_reg_weight instead
20-
**kwargs,
21-
):
15+
in_shape: Tuple[int, int, int],
16+
outdims: int,
17+
bias: bool,
18+
init_noise: float = 1e-3,
19+
attention_kernel: int = 1,
20+
attention_layers: int = 1,
21+
mean_activity: Optional[Mapping[str, float]] = None,
22+
feature_reg_weight: float = 1.0,
23+
gamma_readout: Optional[float] = None, # deprecated, use feature_reg_weight instead
24+
**kwargs: Any,
25+
) -> None:
2226
super().__init__()
2327
self.in_shape = in_shape
2428
self.outdims = outdims
25-
self.feature_reg_weight = self.resolve_deprecated_gamma_readout(feature_reg_weight, gamma_readout)
29+
self.feature_reg_weight = self.resolve_deprecated_gamma_readout(feature_reg_weight, gamma_readout) # type: ignore[no-untyped-call]
2630
self.mean_activity = mean_activity
2731
c, w, h = in_shape
2832
self.features = Parameter(torch.Tensor(self.outdims, c))
2933

30-
attention = nn.Sequential()
34+
attention = Sequential()
3135
for i in range(attention_layers - 1):
3236
attention.add_module(
3337
f"conv{i}",
34-
nn.Conv2d(c, c, attention_kernel, padding=attention_kernel > 1),
38+
Conv2d(c, c, attention_kernel, padding=attention_kernel > 1),
3539
)
36-
attention.add_module(f"norm{i}", nn.BatchNorm2d(c))
37-
attention.add_module(f"nonlin{i}", nn.ELU())
40+
attention.add_module(f"norm{i}", BatchNorm2d(c)) # type: ignore[no-untyped-call]
41+
attention.add_module(f"nonlin{i}", ELU())
3842
else:
3943
attention.add_module(
4044
f"conv{attention_layers}",
41-
nn.Conv2d(c, outdims, attention_kernel, padding=attention_kernel > 1),
45+
Conv2d(c, outdims, attention_kernel, padding=attention_kernel > 1),
4246
)
4347
self.attention = attention
4448

4549
self.init_noise = init_noise
4650
if bias:
47-
bias = Parameter(torch.Tensor(self.outdims))
48-
self.register_parameter("bias", bias)
51+
bias_param = Parameter(torch.Tensor(self.outdims))
52+
self.register_parameter("bias", bias_param)
4953
else:
5054
self.register_parameter("bias", None)
5155
self.initialize(mean_activity)
5256

5357
@staticmethod
54-
def init_conv(m):
55-
if isinstance(m, nn.Conv2d):
56-
nn.init.xavier_normal_(m.weight.data)
58+
def init_conv(m: Module) -> None:
59+
if isinstance(m, Conv2d):
60+
init.xavier_normal_(m.weight.data)
5761
if m.bias is not None:
5862
m.bias.data.fill_(0)
5963

60-
def initialize_attention(self):
64+
def initialize_attention(self) -> None:
6165
self.apply(self.init_conv)
6266

63-
def initialize(self, mean_activity=None):
67+
def initialize(self, mean_activity: Optional[Mapping[str, float]] = None) -> None: # type: ignore[override]
6468
if mean_activity is None:
6569
mean_activity = self.mean_activity
6670
self.features.data.normal_(0, self.init_noise)
6771
if self.bias is not None:
68-
self.initialize_bias(mean_activity=mean_activity)
72+
self.initialize_bias(mean_activity=mean_activity) # type: ignore[no-untyped-call]
6973
self.initialize_attention()
7074

71-
def feature_l1(self, reduction="sum", average=None):
72-
return self.apply_reduction(self.features.abs(), reduction=reduction, average=average)
75+
def feature_l1(
76+
self, reduction: Literal["sum", "mean", None] = "sum", average: Optional[bool] = None
77+
) -> torch.Tensor:
78+
return self.apply_reduction(self.features.abs(), reduction=reduction, average=average) # type: ignore[no-untyped-call,no-any-return]
7379

74-
def regularizer(self, reduction="sum", average=None):
75-
return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
80+
def regularizer(
81+
self, reduction: Literal["sum", "mean", None] = "sum", average: Optional[bool] = None
82+
) -> torch.Tensor:
83+
return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight # type: ignore[no-any-return]
7684

77-
def forward(self, x, shift=None):
85+
def forward(self, x: torch.Tensor, shift: Optional[Any] = None) -> torch.Tensor:
7886
attention = self.attention(x)
7987
b, c, w, h = attention.shape
8088
attention = F.softmax(attention.view(b, c, -1), dim=-1).view(b, c, w, h)
81-
y = torch.einsum("bnwh,bcwh->bcn", attention, x)
82-
y = torch.einsum("bcn,nc->bn", y, self.features)
89+
y: torch.Tensor = torch.einsum("bnwh,bcwh->bcn", attention, x) # type: ignore[attr-defined]
90+
y = torch.einsum("bcn,nc->bn", y, self.features) # type: ignore[attr-defined]
8391
if self.bias is not None:
8492
y = y + self.bias
8593
return y
8694

87-
def __repr__(self):
95+
def __repr__(self) -> str:
8896
return self.__class__.__name__ + " (" + "{} x {} x {}".format(*self.in_shape) + " -> " + str(self.outdims) + ")"

pyproject.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,24 @@ line-length = 120
44
[tool.coverage.run]
55
branch = true
66
source = ["neuralpredictors"]
7+
8+
[tool.mypy]
9+
python_version = "3.8"
10+
files = "neuralpredictors"
11+
exclude = "old_\\w+\\.py$"
12+
strict = true
13+
disallow_untyped_calls = true
14+
disallow_untyped_defs = true
15+
disallow_incomplete_defs = true
16+
disallow_untyped_decorators = true
17+
18+
[[tool.mypy.overrides]]
19+
module = [
20+
"h5py",
21+
"scipy.signal",
22+
"scipy.special",
23+
"skimage.transform",
24+
"torchvision",
25+
"tqdm"
26+
]
27+
ignore_missing_imports = true

0 commit comments

Comments
 (0)