@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
 # Apache 2.0
 
 import logging
@@ -10,109 +10,118 @@
 import torch.nn.functional as F
 
 from common import load_lda_mat
-'''
- input dim=$feat_dim name=input
-
- # please note that it is important to have input layer with the name=input
- # as the layer immediately preceding the fixed-affine-layer to enable
- # the use of short notation for the descriptor
- fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=$dir/configs/lda.mat
-
- # the first splicing is moved before the lda layer, so no splicing here
- relu-batchnorm-layer name=tdnn1 dim=625
- relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625
- relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625
- relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625
- relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625
- relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625
-
- ## adding the layers for chain branch
- relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5
- output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
-
- # adding the layers for xent branch
- # This block prints the configs for a separate output that will be
- # trained with a cross-entropy objective in the 'chain' models... this
- # has the effect of regularizing the hidden parts of the model. we use
- # 0.5 / args.xent_regularize as the learning rate factor- the factor of
- # 0.5 / args.xent_regularize is suitable as it means the xent
- # final-layer learns at a rate independent of the regularization
- # constant; and the 0.5 was tuned so as to make the relative progress
- # similar in the xent and regular final layers.
- relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5
- output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5
-'''
+from tdnnf_layer import FactorizedTDNN
+from tdnnf_layer import OrthonormalLinear
+from tdnnf_layer import PrefinalLayer
 
 
 def get_chain_model(feat_dim,
                     output_dim,
                     hidden_dim,
-                    kernel_size_list,
-                    stride_list,
+                    bottleneck_dim,
+                    time_stride_list,
+                    conv_stride_list,
                     lda_mat_filename=None):
     model = ChainModel(feat_dim=feat_dim,
                        output_dim=output_dim,
                        lda_mat_filename=lda_mat_filename,
                        hidden_dim=hidden_dim,
-                       kernel_size_list=kernel_size_list,
-                       stride_list=stride_list)
+                       time_stride_list=time_stride_list,
+                       conv_stride_list=conv_stride_list)
     return model
 
 
+'''
+input dim=43 name=input
+
+# please note that it is important to have input layer with the name=input
+# as the layer immediately preceding the fixed-affine-layer to enable
+# the use of short notation for the descriptor
+fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=exp/chain_cleaned_1c/tdnn1c_sp/configs/lda.mat
+
+# the first splicing is moved before the lda layer, so no splicing here
+relu-batchnorm-dropout-layer name=tdnn1 l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim-continuous=true dim=1024
+tdnnf-layer name=tdnnf2 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf3 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf4 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf5 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=0
+tdnnf-layer name=tdnnf6 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf7 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf8 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf9 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf10 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf11 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf12 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf13 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+linear-component name=prefinal-l dim=256 l2-regularize=0.008 orthonormal-constraint=-1.0
+
+prefinal-layer name=prefinal-chain input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
+output-layer name=output include-log-softmax=false dim=3456 l2-regularize=0.002
+
+prefinal-layer name=prefinal-xent input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
+output-layer name=output-xent dim=3456 learning-rate-factor=5.0 l2-regularize=0.002
+'''
+
+
 # Create a network like the above one
 class ChainModel(nn.Module):
 
     def __init__(self,
                  feat_dim,
                  output_dim,
-                 lda_mat_filename,
-                 hidden_dim=625,
-                 kernel_size_list=[1, 3, 3, 3, 3, 3],
-                 stride_list=[1, 1, 3, 1, 1, 1],
+                 lda_mat_filename=None,
+                 hidden_dim=1024,
+                 bottleneck_dim=128,
+                 time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
+                 conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
                  frame_subsampling_factor=3):
         super().__init__()
 
         # at present, we support only frame_subsampling_factor to be 3
        assert frame_subsampling_factor == 3
 
-        assert len(kernel_size_list) == len(stride_list)
-        num_layers = len(kernel_size_list)
+        assert len(time_stride_list) == len(conv_stride_list)
+        num_layers = len(time_stride_list)
+
+        # tdnn1_affine requires [N, T, C]
+        self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
+                                      out_features=hidden_dim)
 
-        tdnns = []
+        # tdnn1_batchnorm requires [N, C, T]
+        self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
+
+        tdnnfs = []
         for i in range(num_layers):
-            in_channels = hidden_dim
-            if i == 0:
-                in_channels = feat_dim * 3
-
-            kernel_size = kernel_size_list[i]
-            stride = stride_list[i]
-
-            # we do not need to perform padding in Conv1d because it
-            # has been included in left/right context while generating egs
-            layer = nn.Conv1d(in_channels=in_channels,
-                              out_channels=hidden_dim,
-                              kernel_size=kernel_size,
-                              stride=stride)
-            tdnns.append(layer)
-
-        self.tdnns = nn.ModuleList(tdnns)
-        self.batch_norms = nn.ModuleList([
-            nn.BatchNorm1d(num_features=hidden_dim) for i in range(num_layers)
-        ])
-
-        self.prefinal_chain_tdnn = nn.Conv1d(in_channels=hidden_dim,
-                                             out_channels=hidden_dim,
-                                             kernel_size=1)
-        self.prefinal_chain_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
-        self.output_fc = nn.Linear(in_features=hidden_dim,
-                                   out_features=output_dim)
-
-        self.prefinal_xent_tdnn = nn.Conv1d(in_channels=hidden_dim,
-                                            out_channels=hidden_dim,
-                                            kernel_size=1)
-        self.prefinal_xent_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
-        self.output_xent_fc = nn.Linear(in_features=hidden_dim,
-                                        out_features=output_dim)
+            time_stride = time_stride_list[i]
+            conv_stride = conv_stride_list[i]
+            layer = FactorizedTDNN(dim=hidden_dim,
+                                   bottleneck_dim=bottleneck_dim,
+                                   time_stride=time_stride,
+                                   conv_stride=conv_stride)
+            tdnnfs.append(layer)
+
+        # tdnnfs requires [N, C, T]
+        self.tdnnfs = nn.ModuleList(tdnnfs)
+
+        # prefinal_l affine requires [N, C, T]
+        self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
+                                            bottleneck_dim=bottleneck_dim * 2,
+                                            time_stride=0)
+
+        # prefinal_chain requires [N, C, T]
+        self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
+                                            small_dim=bottleneck_dim * 2)
+
+        # output_affine requires [N, T, C]
+        self.output_affine = nn.Linear(in_features=bottleneck_dim * 2,
+                                       out_features=output_dim)
+
+        # prefinal_xent requires [N, C, T]
+        self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim,
+                                           small_dim=bottleneck_dim * 2)
+
+        self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2,
+                                            out_features=output_dim)
 
         if lda_mat_filename:
             logging.info('Use LDA from {}'.format(lda_mat_filename))
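For reference, here is a minimal sketch (not part of the commit above) of how the new model is built so that it matches the xconfig quoted in the docstring: hidden_dim=1024 and bottleneck_dim=128 correspond to dim=1024 and bottleneck-dim=128, and the stride lists are simply the ChainModel defaults. Note that get_chain_model accepts bottleneck_dim but does not forward it to ChainModel, so the constructor default of 128 is what actually takes effect.

# Sketch only: feat_dim=43 and output_dim=3456 are taken from the xconfig above.
# conv_stride=3 at the fourth TDNN-F layer performs the frame subsampling by 3;
# the later layers then use time_stride=1 on the subsampled sequence, which
# presumably plays the role of the xconfig's time-stride=3.
model = get_chain_model(feat_dim=43,
                        output_dim=3456,
                        hidden_dim=1024,
                        bottleneck_dim=128,
                        time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
                        conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
                        lda_mat_filename=None)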
@@ -146,32 +155,69 @@ def forward(self, x):
 
         # at this point, x is [N, C, T]
 
-        # Conv1d requires input of shape [N, C, T]
-        for i in range(len(self.tdnns)):
-            x = self.tdnns[i](x)
-            x = F.relu(x)
-            x = self.batch_norms[i](x)
+        x = x.permute(0, 2, 1)
+
+        # at this point, x is [N, T, C]
+
+        x = self.tdnn1_affine(x)
+
+        # at this point, x is [N, T, C]
+
+        x = F.relu(x)
+
+        x = x.permute(0, 2, 1)
+
+        # at this point, x is [N, C, T]
+
+        x = self.tdnn1_batchnorm(x)
+
+        # tdnnf requires input of shape [N, C, T]
+        for i in range(len(self.tdnnfs)):
+            x = self.tdnnfs[i](x)
 
         # at this point, x is [N, C, T]
 
-        # we have two branches from this point on
+        x = self.prefinal_l(x)
+
+        # at this point, x is [N, C, T]
 
-        # first, for the chain branch
-        x_chain = self.prefinal_chain_tdnn(x)
-        x_chain = F.relu(x_chain)
-        x_chain = self.prefinal_chain_batch_norm(x_chain)
-        x_chain = x_chain.permute(0, 2, 1)
-        # at this point, x_chain is [N, T, C]
-        nnet_output = self.output_fc(x_chain)
+        # for the output node
+        nnet_output = self.prefinal_chain(x)
 
-        # now for the xent branch
-        x_xent = self.prefinal_xent_tdnn(x)
-        x_xent = F.relu(x_xent)
-        x_xent = self.prefinal_xent_batch_norm(x_xent)
-        x_xent = x_xent.permute(0, 2, 1)
+        # at this point, nnet_output is [N, C, T]
+        nnet_output = nnet_output.permute(0, 2, 1)
+        # at this point, nnet_output is [N, T, C]
+        nnet_output = self.output_affine(nnet_output)
+
+        # for the xent node
+        xent_output = self.prefinal_xent(x)
+
+        # at this point, xent_output is [N, C, T]
+        xent_output = xent_output.permute(0, 2, 1)
+        # at this point, xent_output is [N, T, C]
+        xent_output = self.output_xent_affine(xent_output)
 
-        # at this point x_xent is [N, T, C]
-        xent_output = self.output_xent_fc(x_xent)
         xent_output = F.log_softmax(xent_output, dim=-1)
 
         return nnet_output, xent_output
+
+    def constrain_orthonormal(self):
+        for i in range(len(self.tdnnfs)):
+            self.tdnnfs[i].constrain_orthonormal()
+
+        self.prefinal_l.constrain_orthonormal()
+        self.prefinal_chain.constrain_orthonormal()
+        self.prefinal_xent.constrain_orthonormal()
+
+
+if __name__ == '__main__':
+    feat_dim = 43
+    output_dim = 4344
+    model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)
+    N = 1
+    T = 150 + 27 + 27
+    C = feat_dim * 3
+    x = torch.arange(N * T * C).reshape(N, T, C).float()
+    nnet_output, xent_output = model(x)
+    print(x.shape, nnet_output.shape, xent_output.shape)
+    model.constrain_orthonormal()
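The new constrain_orthonormal() hook is meant to be driven from the training loop rather than from forward(): in Kaldi's TDNN-F training the semi-orthogonal constraint is applied only occasionally, roughly every four minibatches at a randomly chosen step, not after every update. Below is a minimal sketch of such a loop; the dummy minibatch and the placeholder objective stand in for real chain egs and the LF-MMI chain loss, which are not part of this file.

import random

import torch
import torch.optim as optim

model = ChainModel(feat_dim=43, output_dim=3456)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for step in range(10):
    # dummy minibatch standing in for real chain egs: [N, T, feat_dim * 3]
    feats = torch.randn(2, 204, 43 * 3)
    nnet_output, xent_output = model(feats)
    # placeholder objective; a real recipe would combine the LF-MMI chain loss
    # on nnet_output with a cross-entropy regularizer on xent_output
    loss = -nnet_output.mean() - 0.1 * xent_output.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if random.random() < 0.25:
        # apply the semi-orthogonal constraint only about every 4 steps
        model.constrain_orthonormal()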