9
9
class BaseTemplate (NetworkTemplate ):
10
10
11
11
def __init__ (self ):
12
+ """
13
+ Template used for sharing fundamental methods with the
14
+ children transformer-like encoders and decoders.
15
+ """
12
16
13
17
super (BaseTemplate , self ).__init__ ()
14
18
15
19
def _activation_getter (self , activation : Union [str , torch .nn .Module ]) -> torch .nn .Module :
20
+ """
21
+ It configures the activation functions for the transformer layers.
22
+
23
+ Parameters
24
+ ----------
25
+ activation : Union[str, torch.nn.Module]
26
+ Activation function to be used in all the network layers
27
+ Returns
+ -------
+ A Module object for this activation function.
30
+
31
+ Raises
32
+ ------
33
+ Exception :
34
+ When the activation function is not supported.
35
+ """
16
36
17
37
if isinstance (activation , torch .nn .Module ):
18
38
return encoder_activation
@@ -23,11 +43,25 @@ def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.n
23
43
24
44
class BasicEncoder (BaseTemplate ):
25
45
26
- def __init__ (self , num_heads = 1 ,
46
+ def __init__ (self , num_heads : int = 1 ,
27
47
activation :Union [str , torch .nn .Module ]= 'relu' ,
28
48
mlp_layer :torch .nn .Module = None ,
29
49
embed_dim :Union [int , Tuple ]= None ,
30
- ):
50
+ ) -> None :
51
+ """
52
+ Generic transformer encoder.
53
+
54
+ Parameters
55
+ ----------
56
+ num_heads : int
57
+ Number of attention heads for the self-attention layers.
58
+ activation : Union[str, torch.nn.Module]
59
+ Activation function to be used in all the network layers
60
+ mlp_layer : torch.nn.Module
61
+ A Module object representing the MLP (Dense) operation.
62
+ embed_dim : Union[int, Tuple]
63
+ Dimension used for the transformer embedding.
64
+ """
31
65
32
66
super (BasicEncoder , self ).__init__ ()
33
67
@@ -53,6 +87,18 @@ def __init__(self, num_heads=1,
53
87
54
88
def forward (self , input_data : Union [torch .Tensor , np .ndarray ] = None
55
89
) -> torch .Tensor :
90
+ """
91
+
92
+ Parameters
93
+ ----------
94
+ input_data : Union[torch.Tensor, np.ndarray]
95
+ The input dataset.
96
+
97
+ Returns
98
+ -------
99
+ torch.Tensor
100
+ The output generated by the encoder.
101
+ """
56
102
57
103
h = input_data
58
104
h1 = self .activation_1 (h )
@@ -68,6 +114,20 @@ def __init__(self, num_heads:int=1,
68
114
activation :Union [str , torch .nn .Module ]= 'relu' ,
69
115
mlp_layer :torch .nn .Module = None ,
70
116
embed_dim :Union [int , Tuple ]= None ):
117
+ """
118
+ Generic transformer decoder.
119
+
120
+ Parameters
121
+ ----------
122
+ num_heads : int
123
+ Number of attention heads for the self-attention layers.
124
+ activation : Union[str, torch.nn.Module]
125
+ Activation function to be used in all the network layers
126
+ mlp_layer : torch.nn.Module
127
+ A Module object representing the MLP (Dense) operation.
128
+ embed_dim : Union[int, Tuple]
129
+ Dimension used for the transformer embedding.
130
+ """
71
131
72
132
super (BasicDecoder , self ).__init__ ()
73
133
@@ -94,6 +154,20 @@ def __init__(self, num_heads:int=1,
94
154
def forward (self , input_data : Union [torch .Tensor , np .ndarray ] = None ,
95
155
encoder_output :torch .Tensor = None ,
96
156
) -> torch .Tensor :
157
+ """
158
+
159
+ Parameters
160
+ ----------
161
+ input_data : Union[torch.Tensor, np.ndarray]
162
+ The input dataset (in principle, the same input used for the encoder).
163
+ encoder_output : torch.Tensor
164
+ The output provided by the encoder stage.
165
+
166
+ Returns
167
+ -------
168
+ torch.Tensor
169
+ The decoder output.
170
+ """
97
171
98
172
h = input_data
99
173
h1 = self .activation_1 (h )
@@ -115,6 +189,37 @@ def __init__(self, num_heads_encoder:int=1,
115
189
decoder_mlp_layer_config :dict = None ,
116
190
number_of_encoders :int = 1 ,
117
191
number_of_decoders :int = 1 ) -> None :
192
+ """
193
+ A classical encoder-decoder transformer:
194
+
195
+ U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e
196
+
197
+ (u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... Decoder_N ) -> V
198
+
199
+ Parameters
200
+ ----------
201
+ num_heads_encoder : int
202
+ The number of heads for the self-attention layer of the encoder.
203
+ num_heads_decoder :int
204
+ The number of heads for the self-attention layer of the decoder.
205
+ embed_dim_encoder : int
206
+ The dimension of the embedding for the encoder.
207
+ embed_dim_decoder : int
208
+ The dimension of the embedding for the decoder.
209
+ encoder_activation : Union[str, torch.nn.Module]
210
+ The activation to be used in all the encoder layers.
211
+ decoder_activation : Union[str, torch.nn.Module]
212
+ The activation to be used in all the decoder layers.
213
+ encoder_mlp_layer_config : dict
214
+ A configuration dictionary to instantiate the encoder MLP layer.weights
215
+ decoder_mlp_layer_config : dict
216
+ A configuration dictionary to instantiate the encoder MLP layer.weights
217
+ number_of_encoders : int
218
+ The number of encoders to be used.
219
+ number_of_decoders : int
220
+ The number of decoders to be used.
221
+
222
+ """
118
223
119
224
super (Transformer , self ).__init__ ()
120
225
@@ -165,7 +270,6 @@ def __init__(self, num_heads_encoder:int=1,
165
270
]
166
271
167
272
168
-
169
273
self .weights = list ()
170
274
171
275
for e , encoder_e in enumerate (self .EncoderStage ):
@@ -179,15 +283,31 @@ def __init__(self, num_heads_encoder:int=1,
179
283
@as_tensor
180
284
def forward (self , input_data : Union [torch .Tensor , np .ndarray ] = None ) -> torch .Tensor :
181
285
182
- encoder_output = self .EncoderStage (input_data )
286
+ """
287
+
288
+ Parameters
289
+ ----------
290
+ input_data : Union[torch.Tensor, np.ndarray]
291
+ The input dataset.
292
+
293
+ Returns
294
+ -------
295
+ torch.Tensor
296
+ The transformer output.
297
+ """
298
+
299
+ encoder_output = self .EncoderStage (input_data )
183
300
184
- current_input = input_data
185
- for decoder in self .DecoderStage :
186
- output = decoder (input_data = current_input , encoder_output = encoder_output )
187
- current_input = output
301
+ current_input = input_data
302
+ for decoder in self .DecoderStage :
303
+ output = decoder (input_data = current_input , encoder_output = encoder_output )
304
+ current_input = output
188
305
189
- return output
306
+ return output
190
307
191
308
def summary (self ):
309
+ """
310
+ It prints a general view of the architecture.
311
+ """
192
312
193
- print (self )
313
+ print (self )
0 commit comments