|
| 1 | +from typing import Any, Literal, Mapping, Optional, Tuple |
| 2 | + |
1 | 3 | import torch |
2 | | -from torch import nn |
3 | | -from torch.nn import Parameter |
4 | 4 | from torch.nn import functional as F |
| 5 | +from torch.nn import init |
| 6 | +from torch.nn.modules import ELU, BatchNorm2d, Conv2d, Module, Sequential |
| 7 | +from torch.nn.parameter import Parameter |
| 8 | + |
5 | 9 | from .base import Readout |
6 | 10 |
|
7 | 11 |
|
class AttentionReadout(Readout):
    """Readout that pools a feature map with a learned spatial attention per output.

    A small convolutional stack maps the input to one attention map per output
    dimension. Each map is normalized with a softmax over all spatial positions
    and used to take a weighted spatial average of every input channel. The
    pooled channels are then combined linearly with per-output ``features``
    weights, and an optional bias is added.

    Args:
        in_shape: Input shape without the batch dimension, unpacked as
            ``(c, w, h)``; only the channel count ``c`` is used to size layers.
        outdims: Number of outputs.
        bias: Whether to register a learnable bias of size ``outdims``.
        init_noise: Standard deviation of the zero-mean normal used to
            initialize ``features``.
        attention_kernel: Kernel size of every convolution in the attention
            stack. Odd sizes are padded so the spatial size is preserved,
            which ``forward`` requires.
        attention_layers: Number of convolutions in the attention stack:
            ``attention_layers - 1`` blocks of Conv2d/BatchNorm2d/ELU followed
            by one final Conv2d that maps to ``outdims`` channels.
        mean_activity: Stored and forwarded to ``initialize_bias`` of the base
            class when a bias is present.
        feature_reg_weight: Factor applied to the L1 penalty in ``regularizer``.
        gamma_readout: Deprecated, use ``feature_reg_weight`` instead. Resolved
            through ``resolve_deprecated_gamma_readout`` of the base class.
        **kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        in_shape: Tuple[int, int, int],
        outdims: int,
        bias: bool,
        init_noise: float = 1e-3,
        attention_kernel: int = 1,
        attention_layers: int = 1,
        mean_activity: Optional[Mapping[str, float]] = None,
        feature_reg_weight: float = 1.0,
        gamma_readout: Optional[float] = None,  # deprecated, use feature_reg_weight instead
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.in_shape = in_shape
        self.outdims = outdims
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(feature_reg_weight, gamma_readout)  # type: ignore[no-untyped-call]
        self.mean_activity = mean_activity
        c, _, _ = in_shape
        # Allocated uninitialized; values are set in ``initialize`` below.
        self.features = Parameter(torch.empty(self.outdims, c))

        # "Same" padding for odd kernels. The previous boolean expression
        # (``attention_kernel > 1``) only ever produced 0 or 1, so kernels
        # larger than 3 shrank the attention map and broke the einsum in
        # ``forward``. Kernels 1, 2 and 3 get exactly the padding they had.
        padding = attention_kernel // 2

        attention = Sequential()
        for i in range(attention_layers - 1):
            attention.add_module(f"conv{i}", Conv2d(c, c, attention_kernel, padding=padding))
            attention.add_module(f"norm{i}", BatchNorm2d(c))  # type: ignore[no-untyped-call]
            attention.add_module(f"nonlin{i}", ELU())
        # The final projection is always added. Its name is kept as
        # ``conv{attention_layers}`` so existing state_dict keys still match.
        attention.add_module(
            f"conv{attention_layers}",
            Conv2d(c, outdims, attention_kernel, padding=padding),
        )
        self.attention = attention

        self.init_noise = init_noise
        if bias:
            bias_param = Parameter(torch.empty(self.outdims))
            self.register_parameter("bias", bias_param)
        else:
            self.register_parameter("bias", None)
        self.initialize(mean_activity)

    @staticmethod
    def init_conv(m: Module) -> None:
        """Xavier-normal initialize the weight of a Conv2d and zero its bias.

        Modules that are not ``Conv2d`` are left untouched.
        """
        if isinstance(m, Conv2d):
            init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0)

    def initialize_attention(self) -> None:
        """Apply ``init_conv`` to this module and all of its submodules."""
        self.apply(self.init_conv)

    def initialize(self, mean_activity: Optional[Mapping[str, float]] = None) -> None:  # type: ignore[override]
        """(Re-)initialize features, bias and attention convolutions.

        Args:
            mean_activity: Passed to ``initialize_bias`` when a bias exists.
                Falls back to the value given at construction when ``None``.
        """
        if mean_activity is None:
            mean_activity = self.mean_activity
        self.features.data.normal_(0, self.init_noise)
        if self.bias is not None:
            self.initialize_bias(mean_activity=mean_activity)  # type: ignore[no-untyped-call]
        self.initialize_attention()

    def feature_l1(
        self, reduction: Literal["sum", "mean", None] = "sum", average: Optional[bool] = None
    ) -> torch.Tensor:
        """Return the absolute values of ``features`` reduced via ``apply_reduction``.

        ``reduction`` and ``average`` are forwarded unchanged to the base
        class' ``apply_reduction``.
        """
        return self.apply_reduction(self.features.abs(), reduction=reduction, average=average)  # type: ignore[no-untyped-call,no-any-return]

    def regularizer(
        self, reduction: Literal["sum", "mean", None] = "sum", average: Optional[bool] = None
    ) -> torch.Tensor:
        """Return the feature L1 penalty scaled by ``feature_reg_weight``."""
        return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight  # type: ignore[no-any-return]

    def forward(self, x: torch.Tensor, shift: Optional[Any] = None) -> torch.Tensor:
        """Compute the readout for a batch of feature maps.

        Args:
            x: 4-D input whose channel count matches ``in_shape[0]``.
            shift: Unused; accepted for interface compatibility.

        Returns:
            Tensor of shape ``(batch, outdims)``.
        """
        attention = self.attention(x)
        b, n, w, h = attention.shape
        # Softmax over all spatial positions jointly, separately per output map.
        attention = F.softmax(attention.view(b, n, -1), dim=-1).view(b, n, w, h)
        # Attention-weighted spatial average of each input channel, per output.
        y: torch.Tensor = torch.einsum("bnwh,bcwh->bcn", attention, x)  # type: ignore[attr-defined]
        # Linear combination of the pooled channels with per-output weights.
        y = torch.einsum("bcn,nc->bn", y, self.features)  # type: ignore[attr-defined]
        if self.bias is not None:
            y = y + self.bias
        return y

    def __repr__(self) -> str:
        """Return ``ClassName (c x w x h -> outdims)``."""
        c, w, h = self.in_shape
        return f"{self.__class__.__name__} ({c} x {w} x {h} -> {self.outdims})"
0 commit comments