Open
Description
If we find that we want more control over which inputs are rescaled, then something like the following could work:
import torch
from torch import nn


class RescaleInputsbyInput(nn.Module):
    """Divide selected inputs by one designated input.

    ``forward`` indexes ``x`` along its first axis. The entry
    ``x[rescale_index]`` supplies the divisor. Every entry selected by
    ``inputs_to_rescale`` is multiplied by ``1 / x[rescale_index]``, and all
    other entries pass through unchanged.

    NOTE(review): assumes ``x`` is either a tensor whose first dimension
    enumerates the inputs, e.g. shape ``(num_inputs, ...)``, or a plain
    sequence of inputs. Confirm against the caller.

    Args:
        rescale_index: Index of the entry of ``x`` used as the divisor.
            Negative indices are allowed.
        inputs_to_rescale: Indices of the entries to rescale. Negative
            indices are allowed. If ``None``, every index except
            ``rescale_index`` is rescaled. That default can only be resolved
            in ``forward``, because the number of inputs is unknown at
            construction time.
    """

    def __init__(self, rescale_index, inputs_to_rescale=None):
        super().__init__()
        self.rescale_index = rescale_index
        # None means "all except rescale_index". It is resolved per call in
        # forward, since len(x) is not available here. A tuple copy protects
        # against the caller mutating the list afterwards.
        self.inputs_to_rescale = (
            None if inputs_to_rescale is None else tuple(inputs_to_rescale)
        )

    def _target_indices(self, num_inputs):
        """Return the non-negative indices to rescale for ``num_inputs`` entries.

        Raises:
            IndexError: if an explicit index in ``inputs_to_rescale`` is out
                of range.
        """
        if self.inputs_to_rescale is None:
            # Normalize so a negative rescale_index (e.g. -1) is still excluded.
            divisor_pos = self.rescale_index % num_inputs
            return set(range(num_inputs)) - {divisor_pos}
        targets = set()
        for idx in self.inputs_to_rescale:
            if not -num_inputs <= idx < num_inputs:
                raise IndexError(
                    f"inputs_to_rescale index {idx} out of range "
                    f"for {num_inputs} inputs"
                )
            targets.add(idx % num_inputs)
        return targets

    def forward(self, x):
        """Return ``x`` with the selected entries divided by ``x[rescale_index]``.

        Tensor inputs produce a tensor (entries re-stacked along dim 0). Any
        other sequence produces a list. The input is not modified in place.
        A zero divisor is not special-cased: the result is whatever ``1 / 0``
        gives for the element type (inf for tensors, ZeroDivisionError for
        Python floats).
        """
        # Indexing first also surfaces an out-of-range rescale_index (and an
        # empty x) as IndexError before any modulo arithmetic below.
        rescale_scalar = 1 / x[self.rescale_index]
        num_inputs = len(x)
        targets = self._target_indices(num_inputs)
        # Build a new sequence rather than writing into x, so autograd sees
        # no in-place modification of the caller's tensor.
        new_x = [
            x[i] * rescale_scalar if i in targets else x[i]
            for i in range(num_inputs)
        ]
        if isinstance(x, torch.Tensor):
            return torch.stack(new_x)
        return new_x