Commit 5c1d0f6

João Lucas de Sousa Almeida authored and committed
Docstrings for simulai.models.Transformer
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent a33fe17 commit 5c1d0f6

File tree

1 file changed: +130 −10 lines changed


simulai/models/_pytorch_models/_transformer.py

Lines changed: 130 additions & 10 deletions
@@ -9,10 +9,30 @@
 class BaseTemplate(NetworkTemplate):
 
     def __init__(self):
+        """
+        Template used for sharing fundamental methods with the
+        children transformer-like encoders and decoders.
+        """
 
         super(BaseTemplate, self).__init__()
 
     def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.nn.Module:
+        """
+        It configures the activation functions for the transformer layers.
+
+        Parameters
+        ----------
+        activation : Union[str, torch.nn.Module]
+            Activation function to be used in all the network layers.
+
+        Returns
+        -------
+        torch.nn.Module
+            A Module object for this activation function.
+
+        Raises
+        ------
+        Exception
+            When the activation function is not supported.
+        """
 
         if isinstance(activation, torch.nn.Module):
             return activation
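
The docstring documents the contract of _activation_getter, but the diff shows only its first branch. Below is a minimal sketch of such a getter, assuming a simple name-to-class lookup table; the table contents and helper name are illustrative, not simulai's actual implementation:

import torch

def activation_getter(activation):
    # An already-instantiated Module passes through unchanged.
    if isinstance(activation, torch.nn.Module):
        return activation
    # A string is looked up among the activations available in torch.nn
    # and instantiated; unknown names raise, as documented above.
    if isinstance(activation, str):
        table = {"relu": torch.nn.ReLU, "tanh": torch.nn.Tanh,
                 "sigmoid": torch.nn.Sigmoid, "elu": torch.nn.ELU}
        if activation in table:
            return table[activation]()
    raise Exception(f"The activation {activation} is not supported.")
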
@@ -23,11 +43,25 @@ def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.n
 
 class BasicEncoder(BaseTemplate):
 
-    def __init__(self, num_heads=1,
+    def __init__(self, num_heads:int=1,
                  activation:Union[str, torch.nn.Module]='relu',
                  mlp_layer:torch.nn.Module=None,
                  embed_dim:Union[int, Tuple]=None,
-                 ):
+                 ) -> None:
+        """
+        Generic transformer encoder.
+
+        Parameters
+        ----------
+        num_heads : int
+            Number of attention heads for the self-attention layers.
+        activation : Union[str, torch.nn.Module]
+            Activation function to be used in all the network layers.
+        mlp_layer : torch.nn.Module
+            A Module object representing the MLP (Dense) operation.
+        embed_dim : Union[int, Tuple]
+            Dimension used for the transformer embedding.
+        """
 
         super(BasicEncoder, self).__init__()
 
@@ -53,6 +87,18 @@ def __init__(self, num_heads=1,
 
     def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None
                 ) -> torch.Tensor:
+        """
+
+        Parameters
+        ----------
+        input_data : Union[torch.Tensor, np.ndarray]
+            The input dataset.
+
+        Returns
+        -------
+        torch.Tensor
+            The output generated by the encoder.
+        """
 
         h = input_data
         h1 = self.activation_1(h)
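
Only the first two statements of the encoder body appear as context above. For orientation, here is a self-contained sketch of a block that matches this signature and naming, assuming a standard self-attention plus MLP layout with residual connections; the wiring is an assumption, not this commit's code:

import torch

class EncoderBlockSketch(torch.nn.Module):
    def __init__(self, embed_dim: int = 16, num_heads: int = 1):
        super().__init__()
        self.activation_1 = torch.nn.ReLU()
        self.activation_2 = torch.nn.ReLU()
        self.self_attention = torch.nn.MultiheadAttention(embed_dim, num_heads,
                                                          batch_first=True)
        self.mlp_layer = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        h = input_data
        h1 = self.activation_1(h)
        att, _ = self.self_attention(h1, h1, h1)  # self-attention over the input
        h = h + att                               # residual connection
        h2 = self.activation_2(h)
        return h + self.mlp_layer(h2)             # residual around the MLP
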
@@ -68,6 +114,20 @@ def __init__(self, num_heads:int=1,
                  activation:Union[str, torch.nn.Module]='relu',
                  mlp_layer:torch.nn.Module=None,
                  embed_dim:Union[int, Tuple]=None):
+        """
+        Generic transformer decoder.
+
+        Parameters
+        ----------
+        num_heads : int
+            Number of attention heads for the self-attention layers.
+        activation : Union[str, torch.nn.Module]
+            Activation function to be used in all the network layers.
+        mlp_layer : torch.nn.Module
+            A Module object representing the MLP (Dense) operation.
+        embed_dim : Union[int, Tuple]
+            Dimension used for the transformer embedding.
+        """
 
         super(BasicDecoder, self).__init__()
 
@@ -94,6 +154,20 @@ def __init__(self, num_heads:int=1,
     def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None,
                 encoder_output:torch.Tensor=None,
                 ) -> torch.Tensor:
+        """
+
+        Parameters
+        ----------
+        input_data : Union[torch.Tensor, np.ndarray]
+            The input dataset (in principle, the same input used for the encoder).
+        encoder_output : torch.Tensor
+            The output provided by the encoder stage.
+
+        Returns
+        -------
+        torch.Tensor
+            The decoder output.
+        """
 
         h = input_data
         h1 = self.activation_1(h)
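
The encoder_output argument is what distinguishes the decoder from the encoder: it is typically consumed by a cross-attention step in which the decoder state supplies the queries and the encoder output supplies the keys and values. A hedged sketch of that single step (the wiring is an assumption, not this commit's body):

import torch

embed_dim, num_heads = 16, 1
cross_attention = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

h = torch.randn(8, 10, embed_dim)               # decoder state: (batch, seq, embed)
encoder_output = torch.randn(8, 10, embed_dim)  # output of the encoder stack

# Queries come from the decoder, keys and values from the encoder.
att, _ = cross_attention(h, encoder_output, encoder_output)
h = h + att  # residual connection
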
@@ -115,6 +189,37 @@ def __init__(self, num_heads_encoder:int=1,
                  decoder_mlp_layer_config:dict=None,
                  number_of_encoders:int=1,
                  number_of_decoders:int=1) -> None:
+        """
+        A classical encoder-decoder transformer:
+
+        U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e
+
+        (u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... -> Decoder_N ) -> V
+
+        Parameters
+        ----------
+        num_heads_encoder : int
+            The number of heads for the self-attention layer of the encoder.
+        num_heads_decoder : int
+            The number of heads for the self-attention layer of the decoder.
+        embed_dim_encoder : int
+            The dimension of the embedding for the encoder.
+        embed_dim_decoder : int
+            The dimension of the embedding for the decoder.
+        encoder_activation : Union[str, torch.nn.Module]
+            The activation to be used in all the encoder layers.
+        decoder_activation : Union[str, torch.nn.Module]
+            The activation to be used in all the decoder layers.
+        encoder_mlp_layer_config : dict
+            A configuration dictionary to instantiate the encoder MLP layer.
+        decoder_mlp_layer_config : dict
+            A configuration dictionary to instantiate the decoder MLP layer.
+        number_of_encoders : int
+            The number of encoders to be used.
+        number_of_decoders : int
+            The number of decoders to be used.
+        """
 
         super(Transformer, self).__init__()
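
Reading the constructor docstring together with forward below, a hedged usage sketch; the keys inside the MLP config dicts are placeholders, since the real schema comes from simulai's dense-network templates:

import numpy as np
from simulai.models import Transformer

# Placeholder MLP configuration; the expected keys depend on the
# Dense template simulai instantiates from this dict.
mlp_config = {"architecture": [64, 64], "activation": "relu"}

transformer = Transformer(num_heads_encoder=2,
                          num_heads_decoder=2,
                          embed_dim_encoder=16,
                          embed_dim_decoder=16,
                          encoder_activation="relu",
                          decoder_activation="relu",
                          encoder_mlp_layer_config=mlp_config,
                          decoder_mlp_layer_config=mlp_config,
                          number_of_encoders=2,
                          number_of_decoders=2)

U = np.random.rand(100, 16)            # input dataset, embed_dim = 16 features
V = transformer.forward(input_data=U)  # encoder stack, then the chained decoders
transformer.summary()
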

@@ -165,7 +270,6 @@ def __init__(self, num_heads_encoder:int=1,
             ]
 
 
-
         self.weights = list()
 
         for e, encoder_e in enumerate(self.EncoderStage):
@@ -179,15 +283,31 @@ def __init__(self, num_heads_encoder:int=1,
     @as_tensor
     def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None) -> torch.Tensor:
 
-        encoder_output = self.EncoderStage(input_data)
+        """
+
+        Parameters
+        ----------
+        input_data : Union[torch.Tensor, np.ndarray]
+            The input dataset.
+
+        Returns
+        -------
+        torch.Tensor
+            The transformer output.
+        """
+
+        encoder_output = self.EncoderStage(input_data)
 
-        current_input = input_data
-        for decoder in self.DecoderStage:
-            output = decoder(input_data=current_input, encoder_output=encoder_output)
-            current_input = output
+        current_input = input_data
+        for decoder in self.DecoderStage:
+            output = decoder(input_data=current_input, encoder_output=encoder_output)
+            current_input = output
 
-        return output
+        return output
 
     def summary(self):
+        """
+        It prints a general view of the architecture.
+        """
 
-        print(self)
+        print(self)
