33import torch
44from torch .nn import functional as F
55from torch .nn import init
6- from torch .nn import modules as nn_modules
6+ from torch .nn . modules import ELU , BatchNorm2d , Conv2d , Module , Sequential
77from torch .nn .parameter import Parameter
88
99from .base import Readout
@@ -31,18 +31,18 @@ def __init__(
3131 c , w , h = in_shape
3232 self .features = Parameter (torch .Tensor (self .outdims , c ))
3333
34- attention = nn_modules . Sequential ()
34+ attention = Sequential ()
3535 for i in range (attention_layers - 1 ):
3636 attention .add_module (
3737 f"conv{ i } " ,
38- nn_modules . Conv2d (c , c , attention_kernel , padding = attention_kernel > 1 ),
38+ Conv2d (c , c , attention_kernel , padding = attention_kernel > 1 ),
3939 )
40- attention .add_module (f"norm{ i } " , nn_modules . BatchNorm2d (c )) # type: ignore[no-untyped-call]
41- attention .add_module (f"nonlin{ i } " , nn_modules . ELU ())
40+ attention .add_module (f"norm{ i } " , BatchNorm2d (c )) # type: ignore[no-untyped-call]
41+ attention .add_module (f"nonlin{ i } " , ELU ())
4242 else :
4343 attention .add_module (
4444 f"conv{ attention_layers } " ,
45- nn_modules . Conv2d (c , outdims , attention_kernel , padding = attention_kernel > 1 ),
45+ Conv2d (c , outdims , attention_kernel , padding = attention_kernel > 1 ),
4646 )
4747 self .attention = attention
4848
@@ -55,8 +55,8 @@ def __init__(
5555 self .initialize (mean_activity )
5656
5757 @staticmethod
58- def init_conv (m : nn_modules . Module ) -> None :
59- if isinstance (m , nn_modules . Conv2d ):
58+ def init_conv (m : Module ) -> None :
59+ if isinstance (m , Conv2d ):
6060 init .xavier_normal_ (m .weight .data )
6161 if m .bias is not None :
6262 m .bias .data .fill_ (0 )
0 commit comments