1515import  torch 
1616import  torch .nn  as  nn 
1717from  torch  import  Tensor 
18+ from  typing  import  Tuple 
1819
19- from  conformer .decoder  import  DecoderRNNT 
2020from  conformer .encoder  import  ConformerEncoder 
2121from  conformer .modules  import  Linear 
2222
@@ -31,17 +31,13 @@ class Conformer(nn.Module):
3131        num_classes (int): Number of classification classes 
3232        input_dim (int, optional): Dimension of input vector 
3333        encoder_dim (int, optional): Dimension of conformer encoder 
34-         decoder_dim (int, optional): Dimension of conformer decoder 
3534        num_encoder_layers (int, optional): Number of conformer blocks 
36-         num_decoder_layers (int, optional): Number of decoder layers 
37-         decoder_rnn_type (str, optional): type of RNN cell 
3835        num_attention_heads (int, optional): Number of attention heads 
3936        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 
4037        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 
4138        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 
4239        attention_dropout_p (float, optional): Probability of attention module dropout 
4340        conv_dropout_p (float, optional): Probability of conformer convolution module dropout 
44-         decoder_dropout_p (float, optional): Probability of conformer decoder dropout 
4541        conv_kernel_size (int or tuple, optional): Size of the convolving kernel 
4642        half_step_residual (bool): Flag indication whether to use half step residual or not 
4743
@@ -58,20 +54,16 @@ def __init__(
5854            num_classes : int ,
5955            input_dim : int  =  80 ,
6056            encoder_dim : int  =  512 ,
61-             decoder_dim : int  =  640 ,
6257            num_encoder_layers : int  =  17 ,
63-             num_decoder_layers : int  =  1 ,
6458            num_attention_heads : int  =  8 ,
6559            feed_forward_expansion_factor : int  =  4 ,
6660            conv_expansion_factor : int  =  2 ,
6761            input_dropout_p : float  =  0.1 ,
6862            feed_forward_dropout_p : float  =  0.1 ,
6963            attention_dropout_p : float  =  0.1 ,
7064            conv_dropout_p : float  =  0.1 ,
71-             decoder_dropout_p : float  =  0.1 ,
7265            conv_kernel_size : int  =  31 ,
7366            half_step_residual : bool  =  True ,
74-             decoder_rnn_type : str  =  "lstm" ,
7567    ) ->  None :
7668        super (Conformer , self ).__init__ ()
7769        self .encoder  =  ConformerEncoder (
@@ -88,137 +80,27 @@ def __init__(
8880            conv_kernel_size = conv_kernel_size ,
8981            half_step_residual = half_step_residual ,
9082        )
91-         self .decoder  =  DecoderRNNT (
92-             num_classes = num_classes ,
93-             hidden_state_dim = decoder_dim ,
94-             output_dim = encoder_dim ,
95-             num_layers = num_decoder_layers ,
96-             rnn_type = decoder_rnn_type ,
97-             dropout_p = decoder_dropout_p ,
98-         )
9983        self .fc  =  Linear (encoder_dim  <<  1 , num_classes , bias = False )
10084
101-     def  set_encoder (self , encoder ):
102-         """ Setter for encoder """ 
103-         self .encoder  =  encoder 
104- 
105-     def  set_decoder (self , decoder ):
106-         """ Setter for decoder """ 
107-         self .decoder  =  decoder 
108- 
10985    def  count_parameters (self ) ->  int :
11086        """ Count parameters of encoder """ 
111-         num_encoder_parameters  =  self .encoder .count_parameters ()
112-         num_decoder_parameters  =  self .decoder .count_parameters ()
113-         return  num_encoder_parameters  +  num_decoder_parameters 
87+         return  self .encoder .count_parameters ()
11488
11589    def  update_dropout (self , dropout_p ) ->  None :
11690        """ Update dropout probability of model """ 
11791        self .encoder .update_dropout (dropout_p )
118-         self .decoder .update_dropout (dropout_p )
119- 
120-     def  joint (self , encoder_outputs : Tensor , decoder_outputs : Tensor ) ->  Tensor :
121-         """ 
122-         Joint `encoder_outputs` and `decoder_outputs`. 
123- 
124-         Args: 
125-             encoder_outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size 
126-                 ``(batch, seq_length, dimension)`` 
127-             decoder_outputs (torch.FloatTensor): A output sequence of decoder. `FloatTensor` of size 
128-                 ``(batch, seq_length, dimension)`` 
129- 
130-         Returns: 
131-             * outputs (torch.FloatTensor): outputs of joint `encoder_outputs` and `decoder_outputs`.. 
132-         """ 
133-         if  encoder_outputs .dim () ==  3  and  decoder_outputs .dim () ==  3 :
134-             input_length  =  encoder_outputs .size (1 )
135-             target_length  =  decoder_outputs .size (1 )
13692
137-             encoder_outputs  =  encoder_outputs .unsqueeze (2 )
138-             decoder_outputs  =  decoder_outputs .unsqueeze (1 )
139- 
140-             encoder_outputs  =  encoder_outputs .repeat ([1 , 1 , target_length , 1 ])
141-             decoder_outputs  =  decoder_outputs .repeat ([1 , input_length , 1 , 1 ])
142- 
143-         outputs  =  torch .cat ((encoder_outputs , decoder_outputs ), dim = - 1 )
144-         outputs  =  self .fc (outputs )
145- 
146-         return  outputs 
147- 
148-     def  forward (
149-             self ,
150-             inputs : Tensor ,
151-             input_lengths : Tensor ,
152-             targets : Tensor ,
153-             target_lengths : Tensor 
154-     ) ->  Tensor :
93+     def  forward (self , inputs : Tensor , input_lengths : Tensor ) ->  Tuple [Tensor , Tensor ]:
15594        """ 
15695        Forward propagate a `inputs` and `targets` pair for training. 
15796
15897        Args: 
15998            inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 
16099                `FloatTensor` of size ``(batch, seq_length, dimension)``. 
161100            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)`` 
162-             targets (torch.LongTensr): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)`` 
163-             target_lengths (torch.LongTensor): The length of target tensor. ``(batch)`` 
164101
165102        Returns: 
166103            * predictions (torch.FloatTensor): Result of model predictions. 
167104        """ 
168-         encoder_outputs , _  =  self .encoder (inputs , input_lengths )
169-         decoder_outputs , _  =  self .decoder (targets , target_lengths )
170-         outputs  =  self .joint (encoder_outputs , decoder_outputs )
171-         return  outputs 
172- 
173-     @torch .no_grad () 
174-     def  decode (self , encoder_output : Tensor , max_length : int ) ->  Tensor :
175-         """ 
176-         Decode `encoder_outputs`. 
177- 
178-         Args: 
179-             encoder_output (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size 
180-                 ``(seq_length, dimension)`` 
181-             max_length (int): max decoding time step 
182- 
183-         Returns: 
184-             * predicted_log_probs (torch.FloatTensor): Log probability of model predictions. 
185-         """ 
186-         pred_tokens , hidden_state  =  list (), None 
187-         decoder_input  =  encoder_output .new_tensor ([[self .decoder .sos_id ]], dtype = torch .long )
188- 
189-         for  t  in  range (max_length ):
190-             decoder_output , hidden_state  =  self .decoder (decoder_input , hidden_states = hidden_state )
191-             step_output  =  self .joint (encoder_output [t ].view (- 1 ), decoder_output .view (- 1 ))
192-             step_output  =  step_output .softmax (dim = 0 )
193-             pred_token  =  step_output .argmax (dim = 0 )
194-             pred_token  =  int (pred_token .item ())
195-             pred_tokens .append (pred_token )
196-             decoder_input  =  step_output .new_tensor ([[pred_token ]], dtype = torch .long )
197- 
198-         return  torch .LongTensor (pred_tokens )
199- 
200-     @torch .no_grad () 
201-     def  recognize (self , inputs : Tensor , input_lengths : Tensor ):
202-         """ 
203-         Recognize input speech. This method consists of the forward of the encoder and the decode() of the decoder. 
204- 
205-         Args: 
206-             inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 
207-                 `FloatTensor` of size ``(batch, seq_length, dimension)``. 
208-             input_lengths (torch.LongTensor): The length of input tensor. ``(batch)`` 
209- 
210-         Returns: 
211-             * predictions (torch.FloatTensor): Result of model predictions. 
212-         """ 
213-         outputs  =  list ()
214- 
215-         encoder_outputs , output_lengths  =  self .encoder (inputs , input_lengths )
216-         max_length  =  encoder_outputs .size (1 )
217- 
218-         for  encoder_output  in  encoder_outputs :
219-             decoded_seq  =  self .decode (encoder_output , max_length )
220-             outputs .append (decoded_seq )
221- 
222-         outputs  =  torch .stack (outputs , dim = 1 ).transpose (0 , 1 )
223- 
224-         return  outputs 
105+         encoder_outputs , encoder_output_lengths  =  self .encoder (inputs , input_lengths )
106+         return  encoder_outputs , encoder_output_lengths 
0 commit comments