Implement trainable diagonal exponential rescaling functions #1

Open
@jlperla

Might see something here: https://towardsdatascience.com/how-to-build-your-own-pytorch-neural-network-layer-from-scratch-842144d623f6

The key is that the layer needs to hold its internal weights as trainable parameters for PyTorch, rather than as fixed values.

If the input is x in R^N then, for all of these, we want to map to an N x N matrix (which is diagonal in this case). We can then take that matrix and multiply it by the input to rescale. For a diagonal matrix, the multiplication is just a pointwise multiplication by the diagonal entries.
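As a quick check of the claim that the diagonal case reduces to a pointwise product, here is a minimal sketch. The dimension and the values of `d` and `x` are made up for illustration and are not part of the issue.

```python
import torch

N = 4  # hypothetical dimension, for illustration only
d = torch.tensor([0.5, -1.0, 2.0, 0.0])  # diagonal entries (made-up values)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # input vector (made-up values)

# Build the full N x N diagonal matrix and apply it to x
via_matrix = torch.diag(d) @ x

# Same result without ever forming the matrix
via_pointwise = d * x

assert torch.allclose(via_matrix, via_pointwise)
```

This is why the layer only needs to store N numbers rather than an N x N matrix.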

To start, implement the pointwise function f(x) = exp(D x) for a diagonal D and input x, i.e.

f(x_1) = exp(D_1 x_1)
f(x_2) = exp(D_2 x_2)
...
f(x_N) = exp(D_N x_N)

which is an N-parameter learnable function (i.e. self.D = torch.nn.Parameter(torch.zeros(N)) etc.)

Code that might not be that far off is

import torch
from torch import nn

# Given inputs x and y this calculates:
# exp(D x) * y, with both the product and the exponential taken pointwise
class DiagonalExponentialRescaling(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.n_in = n_in
        # Wrapping in nn.Parameter registers the diagonal D as trainable
        self.weights = nn.Parameter(torch.empty(n_in))
        self.reset_parameters()

    def reset_parameters(self):
        # Start at zero, so exp(D x) = 1; could become an option later
        nn.init.zeros_(self.weights)

    def forward(self, x, y):
        exp_x = torch.exp(torch.mul(x, self.weights))  # pointwise exp(D_n x_n)
        return torch.mul(exp_x, y)  # pointwise rescaling of y

Because this will be relatively low dimensional in parameters, we will need to make sure we start at the right place. Maybe even D_n = 0 is a better initial condition than something totally random.
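To see what the D_n = 0 starting point implies, here is a small sketch with made-up values for `x` and `y`. At zero the rescaling is the identity, and the gradient with respect to D_n is x_n y_n, so training is not stuck at the initial condition as long as the inputs are nonzero.

```python
import torch

N = 3  # hypothetical dimension, for illustration only
x = torch.tensor([1.0, -2.0, 0.5])  # made-up input
y = torch.tensor([2.0, 3.0, 4.0])   # made-up values to be rescaled

D = torch.nn.Parameter(torch.zeros(N))  # the D_n = 0 initial condition
out = torch.exp(D * x) * y

# exp(0) = 1, so the output is just y at initialization
assert torch.allclose(out.detach(), y)

# d/dD_n [exp(D_n x_n) y_n] = x_n y_n exp(D_n x_n), which is x_n y_n at D = 0
out.sum().backward()
assert torch.allclose(D.grad, x * y)
```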

Note that we would call this with two inputs:

model = DiagonalExponentialRescaling(5)
x = torch.tensor([...])
y = torch.tensor([...]) # maybe coming out of another network
out = model(x, y)

For the unit tests:
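The issue does not spell out the tests, so the following is only one possible set of checks, written as plain functions that pytest would pick up. The test names, the dimension 5, and the batch size 7 are assumptions; the class is restated in compact form so the block runs on its own.

```python
import torch
from torch import nn


class DiagonalExponentialRescaling(nn.Module):
    # Compact restatement of the layer from the issue, so this block is self-contained
    def __init__(self, n_in):
        super().__init__()
        self.n_in = n_in
        self.weights = nn.Parameter(torch.zeros(n_in))

    def forward(self, x, y):
        return torch.exp(x * self.weights) * y


def test_zero_init_is_identity():
    # With D = 0 the layer should return y unchanged
    model = DiagonalExponentialRescaling(5)
    x, y = torch.randn(5), torch.randn(5)
    assert torch.allclose(model(x, y), y)


def test_weights_are_trainable():
    # Exactly one parameter of shape (n_in,), visible to optimizers
    model = DiagonalExponentialRescaling(5)
    params = list(model.parameters())
    assert len(params) == 1
    assert params[0].shape == (5,)
    assert params[0].requires_grad


def test_matches_closed_form():
    # Set D by hand and compare against exp(D x) * y computed directly
    model = DiagonalExponentialRescaling(5)
    d = torch.tensor([0.1, -0.2, 0.3, 0.0, 1.0])
    with torch.no_grad():
        model.weights.copy_(d)
    x, y = torch.randn(5), torch.randn(5)
    assert torch.allclose(model(x, y), torch.exp(d * x) * y)


def test_batched_input():
    # A (batch, n_in) input should broadcast against the (n_in,) weights
    model = DiagonalExponentialRescaling(5)
    x, y = torch.randn(7, 5), torch.randn(7, 5)
    assert model(x, y).shape == (7, 5)
```

The batched test relies on standard broadcasting of the length-N weight vector across the leading dimension; whether batched inputs are in scope is not stated in the issue.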
