|
| 1 | +# This file is part of NPFL139 <http://github.com/ufal/npfl139/>. |
| 2 | +# |
| 3 | +# This Source Code Form is subject to the terms of the Mozilla Public |
| 4 | +# License, v. 2.0. If a copy of the MPL was not distributed with this |
| 5 | +# file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 6 | +from collections.abc import Callable |
| 7 | +import math |
| 8 | +from typing import Any |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | + |
class KerasParameterInitialization:
    """Keras-default ``reset_parameters`` replacements for ``torch.nn`` modules.

    Each ``reset_parameters_*`` method below is written to be assigned as the
    ``reset_parameters`` method of the corresponding ``torch.nn`` module class
    (see the ``overrides`` mapping), so ``self`` is the module instance being
    (re)initialized, not an instance of this class.
    """

    def reset_parameters_linear(self) -> None:
        """Initialize linear/convolutional layers the Keras way.

        Keras default: Glorot (Xavier) uniform weights, zero biases.
        """
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def reset_parameters_bilinear(self) -> None:
        """Initialize a ``torch.nn.Bilinear`` layer analogously to Keras."""
        # Keras does not have a Bilinear layer. But we analogously use
        # the Xavier uniform initialization, where
        # - the fan_out for each out_feature is in_feature1 * in_feature2
        # - the fan_in for each in_feature1 is out_feature * in_feature2
        # - the fan_in for each in_feature2 is out_feature * in_feature1
        # - the overall fan_in is computed as a weighted average of the above two as
        #   (2 * out_feature * in_feature1 * in_feature2) / (in_feature1 + in_feature2)
        out, in1, in2 = self.weight.shape
        fan_in = (2 * out * in1 * in2) / (in1 + in2)
        fan_out = in1 * in2
        # Xavier-uniform bound computed by hand from the custom fan_in/fan_out.
        bound = math.sqrt(6 / (fan_in + fan_out))
        torch.nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def reset_parameters_rnn(self) -> None:
        """Initialize RNN/LSTM/GRU layers and cells the Keras way.

        Keras default: Glorot uniform for input-to-hidden weights, orthogonal
        for hidden-to-hidden weights, zeros for biases (with the LSTM forget
        gate bias set to 1, i.e., Keras' ``unit_forget_bias=True``).
        """
        # Matching by substring also covers multi-layer parameter names such
        # as "weight_ih_l1" or "bias_hh_l0_reverse".
        for name, parameter in self.named_parameters():
            if "weight_ih" in name:
                torch.nn.init.xavier_uniform_(parameter)
            elif "weight_hh" in name:
                torch.nn.init.orthogonal_(parameter)
            elif "bias" in name:
                torch.nn.init.zeros_(parameter)
                if isinstance(self, (torch.nn.LSTM, torch.nn.LSTMCell)):  # Set LSTM forget gate bias to 1
                    # PyTorch LSTM gate order is (input, forget, cell, output),
                    # so the slice [hidden_size:2*hidden_size] is the forget gate.
                    # NOTE(review): this runs for both bias_ih and bias_hh, so the
                    # effective forget-gate bias is 1 + 1 = 2, while Keras'
                    # unit_forget_bias yields 1 — confirm this is intended.
                    parameter.data[self.hidden_size:self.hidden_size * 2] = 1

    def reset_parameters_embedding(self) -> None:
        """Initialize embedding layers the Keras way: uniform U(-0.05, 0.05)."""
        torch.nn.init.uniform_(self.weight, -0.05, 0.05)
        # Keep the padding_idx row (if any) at zero, as torch's own
        # Embedding.reset_parameters does.
        self._fill_padding_idx_with_zero()

    # Maps each torch module class to the replacement reset_parameters method
    # that should be monkey-patched onto it (see global_keras_initializers).
    overrides: dict[type[torch.nn.Module], Callable] = {
        torch.nn.Linear: reset_parameters_linear,
        torch.nn.Conv1d: reset_parameters_linear,
        torch.nn.Conv2d: reset_parameters_linear,
        torch.nn.Conv3d: reset_parameters_linear,
        torch.nn.ConvTranspose1d: reset_parameters_linear,
        torch.nn.ConvTranspose2d: reset_parameters_linear,
        torch.nn.ConvTranspose3d: reset_parameters_linear,
        torch.nn.Bilinear: reset_parameters_bilinear,
        torch.nn.RNN: reset_parameters_rnn,
        torch.nn.RNNCell: reset_parameters_rnn,
        torch.nn.LSTM: reset_parameters_rnn,
        torch.nn.LSTMCell: reset_parameters_rnn,
        torch.nn.GRU: reset_parameters_rnn,
        torch.nn.GRUCell: reset_parameters_rnn,
        torch.nn.Embedding: reset_parameters_embedding,
        torch.nn.EmbeddingBag: reset_parameters_embedding,
    }
| 68 | + |
| 69 | + |
class KerasNormalizationLayers:
    """Helpers for overriding default argument values of normalization layers."""

    @staticmethod
    def override_default_argument_value(func: Callable, name: str, default: Any) -> None:
        """Replace the default value of the argument `name` of `func` in place.

        Works by rewriting ``func.__defaults__``; only positional-or-keyword
        arguments with a default value can be targeted (keyword-only defaults
        live in ``__kwdefaults__`` and are not handled here).

        Parameters:
          func: The function (typically a layer's ``__init__``) to modify.
          name: The name of the argument whose default should change.
          default: The new default value.

        Raises:
          AssertionError: If `func` has no defaults or `name` is not among
            its defaulted arguments.
        """
        defaults = func.__defaults__
        # Without this guard, a function with no defaults (defaults is None)
        # would fail below with an opaque `len(None)` TypeError.
        assert defaults, f"Function {func.__name__} has no default argument values"
        # The defaults belong to the trailing positional-or-keyword arguments.
        default_names = func.__code__.co_varnames[:func.__code__.co_argcount][-len(defaults):]
        assert name in default_names, f"Argument {name} not found in {func.__name__} arguments"
        func.__defaults__ = tuple(
            default if arg_name == name else arg_value
            for arg_name, arg_value in zip(default_names, defaults)
        )

    # Batch-normalization layer classes whose `momentum` default is overridden.
    batch_norms = [
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LazyBatchNorm1d,
        torch.nn.LazyBatchNorm2d,
        torch.nn.LazyBatchNorm3d,
        torch.nn.SyncBatchNorm,
    ]

    # All normalization layer classes whose `eps` default is overridden.
    all_norms = batch_norms + [
        torch.nn.LayerNorm,
        torch.nn.GroupNorm,
    ]
| 94 | + |
| 95 | + |
def global_keras_initializers(
    parameter_initialization: bool = True,
    batchnorm_momentum_override: float | None = 0.01,
    norm_layer_epsilon_override: float | None = 0.001,
) -> None:
    """Change default PyTorch initializers to Keras defaults.

    The following initializers are used:

    - `Linear`, `Conv1d`, `Conv2d`, `Conv3d`, `ConvTranspose1d`, `ConvTranspose2d`, `ConvTranspose3d`, `Bilinear`:
      Xavier uniform for weights, zeros for biases.
    - `Embedding`, `EmbeddingBag`: Uniform [-0.05, 0.05] for weights.
    - `RNN`, `RNNCell`, `LSTM`, `LSTMCell`, `GRU`, `GRUCell`: Xavier uniform for input weights,
      orthogonal for recurrent weights, zeros for biases (with LSTM forget gate bias set to 1).

    Furthermore, for batch normalization layers (including their subclasses),
    the default momentum value is changed from 0.1 to the Keras default of 0.01
    (or any other value specified).

    Finally, for batch normalization, layer normalization, and group normalization
    layers (including their subclasses), the default epsilon value is changed
    from 1e-5 to the Keras default of 1e-3 (or any other value specified).

    Parameters:
      parameter_initialization: If True, override the default PyTorch initializers with Keras defaults.
      batchnorm_momentum_override: If not None, override the default value of batch normalization
        momentum from 0.1 to this value.
      norm_layer_epsilon_override: If not None, override the default value of epsilon
        for batch normalization, layer normalization, and group normalization layers from
        1e-5 to this value.
    """
    def transitive_subclasses(class_: type) -> list[type]:
        # Collect `class_` and all its direct and indirect subclasses (deduplicated);
        # a plain `__subclasses__()` would miss subclasses of subclasses.
        collected, stack = [], [class_]
        while stack:
            current = stack.pop()
            if current not in collected:
                collected.append(current)
                stack.extend(current.__subclasses__())
        return collected

    if parameter_initialization:
        # Monkey-patch reset_parameters; only modules constructed afterwards
        # (or explicitly reset) pick up the Keras initialization.
        for class_, reset_parameters_method in KerasParameterInitialization.overrides.items():
            class_.reset_parameters = reset_parameters_method

    if batchnorm_momentum_override is not None:
        for batch_norm_super in KerasNormalizationLayers.batch_norms:
            for batch_norm in transitive_subclasses(batch_norm_super):
                KerasNormalizationLayers.override_default_argument_value(
                    batch_norm.__init__, "momentum", batchnorm_momentum_override
                )

    if norm_layer_epsilon_override is not None:
        for norm_layer_super in KerasNormalizationLayers.all_norms:
            for norm_layer in transitive_subclasses(norm_layer_super):
                KerasNormalizationLayers.override_default_argument_value(
                    norm_layer.__init__, "eps", norm_layer_epsilon_override
                )
0 commit comments