Skip to content

Commit 5c1d0f6

Browse files
João Lucas de Sousa Almeida
João Lucas de Sousa Almeida
authored and
João Lucas de Sousa Almeida
committed
Docstrings for simulai.models.Transformer
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent a33fe17 commit 5c1d0f6

File tree

1 file changed

+130
-10
lines changed

1 file changed

+130
-10
lines changed

simulai/models/_pytorch_models/_transformer.py

+130-10
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,30 @@
99
class BaseTemplate(NetworkTemplate):
1010

1111
def __init__(self):
12+
"""
13+
Template used for sharing fundamental methods with the
14+
children transformer-like encoders and decoders.
15+
"""
1216

1317
super(BaseTemplate, self).__init__()
1418

1519
def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.nn.Module:
20+
"""
21+
It configures the activation functions for the transformer layers.
22+
23+
Parameters
24+
----------
25+
activation : Union[str, torch.nn.Module]
26+
Activation function to be used in all the network layers
27+
Returns
28+
-------
29+
A Module object for this activation function.
30+
31+
Raises
32+
------
33+
Exception :
34+
When the activation function is not supported.
35+
"""
1636

1737
if isinstance(activation, torch.nn.Module):
1838
return encoder_activation
@@ -23,11 +43,25 @@ def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.n
2343

2444
class BasicEncoder(BaseTemplate):
2545

26-
def __init__(self, num_heads=1,
46+
def __init__(self, num_heads:int=1,
2747
activation:Union[str, torch.nn.Module]='relu',
2848
mlp_layer:torch.nn.Module=None,
2949
embed_dim:Union[int, Tuple]=None,
30-
):
50+
) -> None:
51+
"""
52+
Generic transformer encoder.
53+
54+
Parameters
55+
----------
56+
num_heads : int
57+
Number of attention heads for the self-attention layers.
58+
activation : Union[str, torch.nn.Module]
59+
Activation function to be used in all the network layers
60+
mlp_layer : torch.nn.Module
61+
A Module object representing the MLP (Dense) operation.
62+
embed_dim : Union[int, Tuple]
63+
Dimension used for the transformer embedding.
64+
"""
3165

3266
super(BasicEncoder, self).__init__()
3367

@@ -53,6 +87,18 @@ def __init__(self, num_heads=1,
5387

5488
def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None
5589
) -> torch.Tensor:
90+
"""
91+
92+
Parameters
93+
----------
94+
input_data : Union[torch.Tensor, np.ndarray]
95+
The input dataset.
96+
97+
Returns
98+
-------
99+
torch.Tensor
100+
The output generated by the encoder.
101+
"""
56102

57103
h = input_data
58104
h1 = self.activation_1(h)
@@ -68,6 +114,20 @@ def __init__(self, num_heads:int=1,
68114
activation:Union[str, torch.nn.Module]='relu',
69115
mlp_layer:torch.nn.Module=None,
70116
embed_dim:Union[int, Tuple]=None):
117+
"""
118+
Generic transformer decoder.
119+
120+
Parameters
121+
----------
122+
num_heads : int
123+
Number of attention heads for the self-attention layers.
124+
activation : Union[str, torch.nn.Module]
125+
Activation function to be used in all the network layers
126+
mlp_layer : torch.nn.Module
127+
A Module object representing the MLP (Dense) operation.
128+
embed_dim : Union[int, Tuple]
129+
Dimension used for the transformer embedding.
130+
"""
71131

72132
super(BasicDecoder, self).__init__()
73133

@@ -94,6 +154,20 @@ def __init__(self, num_heads:int=1,
94154
def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None,
95155
encoder_output:torch.Tensor=None,
96156
) -> torch.Tensor:
157+
"""
158+
159+
Parameters
160+
----------
161+
input_data : Union[torch.Tensor, np.ndarray]
162+
The input dataset (in principle, the same input used for the encoder).
163+
encoder_output : torch.Tensor
164+
The output provided by the encoder stage.
165+
166+
Returns
167+
-------
168+
torch.Tensor
169+
The decoder output.
170+
"""
97171

98172
h = input_data
99173
h1 = self.activation_1(h)
@@ -115,6 +189,37 @@ def __init__(self, num_heads_encoder:int=1,
115189
decoder_mlp_layer_config:dict=None,
116190
number_of_encoders:int=1,
117191
number_of_decoders:int=1) -> None:
192+
"""
193+
A classical encoder-decoder transformer:
194+
195+
U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e
196+
197+
(u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... Decoder_N ) -> V
198+
199+
Parameters
200+
----------
201+
num_heads_encoder : int
202+
The number of heads for the self-attention layer of the encoder.
203+
num_heads_decoder : int
204+
The number of heads for the self-attention layer of the decoder.
205+
embed_dim_encoder : int
206+
The dimension of the embedding for the encoder.
207+
embed_dim_decoder : int
208+
The dimension of the embedding for the decoder.
209+
encoder_activation : Union[str, torch.nn.Module]
210+
The activation to be used in all the encoder layers.
211+
decoder_activation : Union[str, torch.nn.Module]
212+
The activation to be used in all the decoder layers.
213+
encoder_mlp_layer_config : dict
214+
A configuration dictionary used to instantiate the encoder MLP layer.
215+
decoder_mlp_layer_config : dict
216+
A configuration dictionary used to instantiate the decoder MLP layer.
217+
number_of_encoders : int
218+
The number of encoders to be used.
219+
number_of_decoders : int
220+
The number of decoders to be used.
221+
222+
"""
118223

119224
super(Transformer, self).__init__()
120225

@@ -165,7 +270,6 @@ def __init__(self, num_heads_encoder:int=1,
165270
]
166271

167272

168-
169273
self.weights = list()
170274

171275
for e, encoder_e in enumerate(self.EncoderStage):
@@ -179,15 +283,31 @@ def __init__(self, num_heads_encoder:int=1,
179283
@as_tensor
180284
def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None) -> torch.Tensor:
181285

182-
encoder_output = self.EncoderStage(input_data)
286+
"""
287+
288+
Parameters
289+
----------
290+
input_data : Union[torch.Tensor, np.ndarray]
291+
The input dataset.
292+
293+
Returns
294+
-------
295+
torch.Tensor
296+
The transformer output.
297+
"""
298+
299+
encoder_output = self.EncoderStage(input_data)
183300

184-
current_input = input_data
185-
for decoder in self.DecoderStage:
186-
output = decoder(input_data=current_input, encoder_output=encoder_output)
187-
current_input = output
301+
current_input = input_data
302+
for decoder in self.DecoderStage:
303+
output = decoder(input_data=current_input, encoder_output=encoder_output)
304+
current_input = output
188305

189-
return output
306+
return output
190307

191308
def summary(self):
309+
"""
310+
It prints a general view of the architecture.
311+
"""
192312

193-
print(self)
313+
print(self)

0 commit comments

Comments
 (0)