Skip to content

Commit 899a2ed

Browse files
Directly import PyTorch modules
1 parent 42ed90f commit 899a2ed

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

neuralpredictors/layers/readouts/attention.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch.nn import functional as F
55
from torch.nn import init
6-
from torch.nn import modules as nn_modules
6+
from torch.nn.modules import ELU, BatchNorm2d, Conv2d, Module, Sequential
77
from torch.nn.parameter import Parameter
88

99
from .base import Readout
@@ -31,18 +31,18 @@ def __init__(
3131
c, w, h = in_shape
3232
self.features = Parameter(torch.Tensor(self.outdims, c))
3333

34-
attention = nn_modules.Sequential()
34+
attention = Sequential()
3535
for i in range(attention_layers - 1):
3636
attention.add_module(
3737
f"conv{i}",
38-
nn_modules.Conv2d(c, c, attention_kernel, padding=attention_kernel > 1),
38+
Conv2d(c, c, attention_kernel, padding=attention_kernel > 1),
3939
)
40-
attention.add_module(f"norm{i}", nn_modules.BatchNorm2d(c)) # type: ignore[no-untyped-call]
41-
attention.add_module(f"nonlin{i}", nn_modules.ELU())
40+
attention.add_module(f"norm{i}", BatchNorm2d(c)) # type: ignore[no-untyped-call]
41+
attention.add_module(f"nonlin{i}", ELU())
4242
else:
4343
attention.add_module(
4444
f"conv{attention_layers}",
45-
nn_modules.Conv2d(c, outdims, attention_kernel, padding=attention_kernel > 1),
45+
Conv2d(c, outdims, attention_kernel, padding=attention_kernel > 1),
4646
)
4747
self.attention = attention
4848

@@ -55,8 +55,8 @@ def __init__(
5555
self.initialize(mean_activity)
5656

5757
@staticmethod
58-
def init_conv(m: nn_modules.Module) -> None:
59-
if isinstance(m, nn_modules.Conv2d):
58+
def init_conv(m: Module) -> None:
59+
if isinstance(m, Conv2d):
6060
init.xavier_normal_(m.weight.data)
6161
if m.bias is not None:
6262
m.bias.data.fill_(0)

0 commit comments

Comments
 (0)