File tree 1 file changed +5
-5
lines changed
1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -17,18 +17,18 @@ def __init__(
17
17
def forward (self , input ):
18
18
return torch .cat ([input .pow (m ) for m in torch .arange (1 , self .n_moments + 1 )], 1 )
19
19
20
-
20
+
21
21
# rescaling by a specific element of a given input
22
22
class RescaleOutputsByInput (nn .Module ):
23
23
def __init__ (self , rescale_index : int = 0 , bias = False ):
24
24
super ().__init__ ()
25
25
self .rescale_index = rescale_index
26
26
if bias :
27
- self .bias = torch .nn .Parameter (torch .Tensor (1 )) # only a scalar here
28
- torch .nn .init .zeros_ (self .bias )
27
+ self .bias = torch .nn .Parameter (torch .Tensor (1 )) # only a scalar here
28
+ torch .nn .init .ones_ (self .bias )
29
29
else :
30
- self .bias = 0.0 # register_parameter('bias', None) # necessary?
31
-
30
+ self .bias = 0.0 # register_parameter('bias', None) # necessary?
31
+
32
32
def forward (self , x , y ):
33
33
if x .dim () == 1 :
34
34
return x [self .rescale_index ] * y + self .bias
You can’t perform that action at this time.
0 commit comments