Skip to content

Commit 28e564e

Browse files
Remove ignore in forward method
1 parent 899a2ed commit 28e564e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

neuralpredictors/layers/readouts/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def forward(self, x: torch.Tensor, shift: Optional[Any] = None) -> torch.Tensor:
8686
attention = self.attention(x)
8787
b, c, w, h = attention.shape
8888
attention = F.softmax(attention.view(b, c, -1), dim=-1).view(b, c, w, h)
89-
y = torch.einsum("bnwh,bcwh->bcn", attention, x) # type: ignore[attr-defined]
89+
y: torch.Tensor = torch.einsum("bnwh,bcwh->bcn", attention, x) # type: ignore[attr-defined]
9090
y = torch.einsum("bcn,nc->bn", y, self.features) # type: ignore[attr-defined]
9191
if self.bias is not None:
9292
y = y + self.bias
93-
return y # type: ignore[no-any-return]
93+
return y
9494

9595
def __repr__(self) -> str:
9696
return self.__class__.__name__ + " (" + "{} x {} x {}".format(*self.in_shape) + " -> " + str(self.outdims) + ")"

0 commit comments

Comments
 (0)