Skip to content

Commit dc69a23

Browse files
authored
Merge pull request #46 from HighDimensionalEconLab/modification_bias_jan
different bias initialization
2 parents 518a190 + 1c9b25a commit dc69a23

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

econ_layers/layers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ def __init__(
1717
def forward(self, input):
1818
return torch.cat([input.pow(m) for m in torch.arange(1, self.n_moments + 1)], 1)
1919

20-
20+
2121
# rescaling by a specific element of a given input
2222
class RescaleOutputsByInput(nn.Module):
2323
def __init__(self, rescale_index: int = 0, bias=False):
2424
super().__init__()
2525
self.rescale_index = rescale_index
2626
if bias:
27-
self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here
28-
torch.nn.init.zeros_(self.bias)
27+
self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here
28+
torch.nn.init.ones_(self.bias)
2929
else:
30-
self.bias = 0.0 # register_parameter('bias', None) # necessary?
31-
30+
self.bias = 0.0 # register_parameter('bias', None) # necessary?
31+
3232
def forward(self, x, y):
3333
if x.dim() == 1:
3434
return x[self.rescale_index] * y + self.bias

0 commit comments

Comments
 (0)