Description
Might see something here: https://towardsdatascience.com/how-to-build-your-own-pytorch-neural-network-layer-from-scratch-842144d623f6
The key is that it needs to have its internal weights as trainable parameters for PyTorch, rather than as fixed values.
If the input is x : R^N, then for all of these we want to map to an N x N matrix (which is diagonal in this case). We can then multiply that matrix by the input to rescale it. For a diagonal matrix, the multiplication is just a pointwise multiplication.
To start with, implement the pointwise f(x) = exp(D x) for a diagonal D and input x, i.e.

f(x_1) = exp(D_1 x_1)
f(x_2) = exp(D_2 x_2)
...
f(x_N) = exp(D_N x_N)

which is an N-parameter learnable function (i.e. `self.D = torch.nn.Parameter(torch.zeros(N))`, etc.)
Code that might not be that far off is

```python
import torch
import torch.nn as nn

# Given inputs x and y, this calculates exp(D x) * y
# (pointwise exponential and multiplication, with D the learnable diagonal)
class DiagonalExponentialRescaling(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.n_in = n_in
        self.weights = torch.nn.Parameter(torch.empty(n_in))
        self.reset_parameters()

    def reset_parameters(self):
        # Let's start at zero, but later this could be an option
        torch.nn.init.zeros_(self.weights)

    def forward(self, x, y):
        exp_x = torch.exp(torch.mul(x, self.weights))  # elementwise exp(D x)
        return torch.mul(exp_x, y)
```
Because this will be relatively low-dimensional in parameters, we will need to make sure we start at the right place. D_n = 0 may even be a better initial condition than something totally random.
Note that we would call this with two inputs:

```python
model = DiagonalExponentialRescaling(5)
x = torch.tensor([...])
y = torch.tensor([...])  # maybe coming out of another network
out = model(x, y)
```
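One consequence of the zero initialization worth noting: since exp(0 * x) = 1, the layer starts out as the identity on y. A minimal sketch checking this (the class is reproduced from above so the snippet stands alone):

```python
import torch
import torch.nn as nn

class DiagonalExponentialRescaling(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(n_in))  # D_n = 0 at init

    def forward(self, x, y):
        return torch.exp(self.weights * x) * y

model = DiagonalExponentialRescaling(5)
x = torch.randn(5)
y = torch.randn(5)
out = model(x, y)
# exp(0 * x) == 1 elementwise, so the layer passes y through unchanged at init
assert torch.allclose(out, y)
```

This also suggests a cheap sanity test to include alongside the gradient checks.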
For the unit tests:
- Make sure to use `gradcheck` with double precision to make sure it doesn't have any issues: https://github.com/HighDimensionalEconLab/econ_layers/blob/main/tests/test_flexible_sequential.py#L20-L21
- Somehow you also need to ensure that it can take the gradients with respect to the parameters themselves. For that, I would set up a simple PyTorch training loop to fit it to exponential data that you generate.
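A sketch of both tests, under some assumptions: `gradcheck` is run on a functional form `f(w, x, y)` so the parameter vector itself is perturbed (gradcheck only perturbs the inputs you pass it, not module parameters), and the training loop fits a hypothetical known `D_true` that you would replace with your own generated data:

```python
import torch
import torch.nn as nn

class DiagonalExponentialRescaling(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(n_in))

    def forward(self, x, y):
        return torch.exp(self.weights * x) * y

# 1. gradcheck in double precision, including the parameter vector:
#    express the layer as a pure function of (w, x, y) so all three are checked
def f(w, x, y):
    return torch.exp(w * x) * y

n = 5
w = (0.5 * torch.randn(n, dtype=torch.float64)).requires_grad_()
x = (0.5 * torch.randn(n, dtype=torch.float64)).requires_grad_()
y = torch.randn(n, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(f, (w, x, y))

# 2. simple training loop: fit synthetic exponential data with a known D_true
torch.manual_seed(0)
D_true = torch.tensor([0.5, -0.3, 0.1, 0.8, -0.6])  # hypothetical target
model = DiagonalExponentialRescaling(n)
opt = torch.optim.Adam(model.parameters(), lr=0.05)
for _ in range(500):
    xb = torch.randn(64, n)
    yb = torch.randn(64, n)
    target = torch.exp(D_true * xb) * yb  # noiseless synthetic data
    loss = ((model(xb, yb) - target) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# the learned diagonal should recover D_true
assert torch.allclose(model.weights.detach(), D_true, atol=5e-2)
```

The broadcasting in `forward` means the same module handles both a single vector and a batch of shape `(B, N)`, which the training loop relies on.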