
Commit be0842f

[scripts,egs] Add TDNNF to pytorch. (#3892)

1 parent ee517cd

10 files changed: +561, -162 lines

egs/aishell/s10/chain/egs_dataset.py (+6, -4)
@@ -170,10 +170,10 @@ def __call__(self, batch):


 def _test_nnet_chain_example_dataset():
-    egs_dir = '/cache/fangjun/chain/aishell_kaldi_pybind/test'
+    egs_dir = 'exp/chain/merged_egs'
     dataset = NnetChainExampleDataset(egs_dir=egs_dir)
-    egs_left_context = 23
-    egs_right_context = 23
+    egs_left_context = 29
+    egs_right_context = 29
     frame_subsampling_factor = 3

     collate_fn = NnetChainExampleDatasetCollateFunc(
@@ -200,7 +200,9 @@ def _test_nnet_chain_example_dataset():
                             collate_fn=collate_fn)
     for b in dataloader:
         key_list, feature_list, supervision_list = b
-        assert feature_list[0].shape == (128, 192, 120)
+        assert feature_list[0].shape == (128, 204, 129) \
+                or feature_list[0].shape == (128, 144, 129) \
+                or feature_list[0].shape == (128, 165, 129)
         assert supervision_list[0].weight == 1
         assert supervision_list[0].num_sequences == 128  # minibatch size is 128
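The widened assertion reflects the larger context the TDNNF model consumes: each minibatch still carries 128 sequences, the feature dimension grows from 120 to 129 (matching the feat_dim * 3 splice in model.py, with feat_dim = 43), and the frame count now varies with the egs chunk width. A quick sanity check of those numbers; the per-chunk frame counts below are inferred from the assertion itself, not stated anywhere in the diff:

# The last dim, 129, is feat_dim * 3 (43-dim features spliced across
# 3 frames); the middle dim is chunk frames plus left/right context.
egs_left_context = 29
egs_right_context = 29
feat_dim = 43
assert feat_dim * 3 == 129
for T in (204, 144, 165):
    # inferred chunk widths: 146, 86, and 107 input frames
    print(T - egs_left_context - egs_right_context)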

egs/aishell/s10/chain/inference.py (+3, -2)
@@ -34,8 +34,9 @@ def main():
                             output_dim=args.output_dim,
                             lda_mat_filename=args.lda_mat_filename,
                             hidden_dim=args.hidden_dim,
-                            kernel_size_list=args.kernel_size_list,
-                            stride_list=args.stride_list)
+                            bottleneck_dim=args.bottleneck_dim,
+                            time_stride_list=args.time_stride_list,
+                            conv_stride_list=args.conv_stride_list)

     load_checkpoint(args.checkpoint, model)
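For reference, a hypothetical call with the new keyword arguments, using the ChainModel defaults from the model.py diff below; this is an illustration, not a tested recipe. Note that in this commit get_chain_model accepts bottleneck_dim but does not forward it to ChainModel, so the constructor default of 128 is what actually takes effect:

from model import get_chain_model

model = get_chain_model(
    feat_dim=43,
    output_dim=4344,
    hidden_dim=1024,
    bottleneck_dim=128,
    time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
    conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
    lda_mat_filename=None)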

egs/aishell/s10/chain/model.py (+143, -97)
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
 # Apache 2.0

 import logging
@@ -10,109 +10,118 @@
 import torch.nn.functional as F

 from common import load_lda_mat
-'''
-input dim=$feat_dim name=input
-
-# please note that it is important to have input layer with the name=input
-# as the layer immediately preceding the fixed-affine-layer to enable
-# the use of short notation for the descriptor
-fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=$dir/configs/lda.mat
-
-# the first splicing is moved before the lda layer, so no splicing here
-relu-batchnorm-layer name=tdnn1 dim=625
-relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625
-relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625
-relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625
-relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625
-relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625
-
-## adding the layers for chain branch
-relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5
-output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
-
-# adding the layers for xent branch
-# This block prints the configs for a separate output that will be
-# trained with a cross-entropy objective in the 'chain' models... this
-# has the effect of regularizing the hidden parts of the model. we use
-# 0.5 / args.xent_regularize as the learning rate factor- the factor of
-# 0.5 / args.xent_regularize is suitable as it means the xent
-# final-layer learns at a rate independent of the regularization
-# constant; and the 0.5 was tuned so as to make the relative progress
-# similar in the xent and regular final layers.
-relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5
-output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5
-'''
+from tdnnf_layer import FactorizedTDNN
+from tdnnf_layer import OrthonormalLinear
+from tdnnf_layer import PrefinalLayer


 def get_chain_model(feat_dim,
                     output_dim,
                     hidden_dim,
-                    kernel_size_list,
-                    stride_list,
+                    bottleneck_dim,
+                    time_stride_list,
+                    conv_stride_list,
                     lda_mat_filename=None):
     model = ChainModel(feat_dim=feat_dim,
                        output_dim=output_dim,
                        lda_mat_filename=lda_mat_filename,
                        hidden_dim=hidden_dim,
-                       kernel_size_list=kernel_size_list,
-                       stride_list=stride_list)
+                       time_stride_list=time_stride_list,
+                       conv_stride_list=conv_stride_list)
     return model


+'''
+input dim=43 name=input
+
+# please note that it is important to have input layer with the name=input
+# as the layer immediately preceding the fixed-affine-layer to enable
+# the use of short notation for the descriptor
+fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=exp/chain_cleaned_1c/tdnn1c_sp/configs/lda.mat
+
+# the first splicing is moved before the lda layer, so no splicing here
+relu-batchnorm-dropout-layer name=tdnn1 l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim-continuous=true dim=1024
+tdnnf-layer name=tdnnf2 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf3 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf4 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf5 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=0
+tdnnf-layer name=tdnnf6 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf7 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf8 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf9 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf10 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf11 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf12 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf13 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+linear-component name=prefinal-l dim=256 l2-regularize=0.008 orthonormal-constraint=-1.0
+
+prefinal-layer name=prefinal-chain input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
+output-layer name=output include-log-softmax=false dim=3456 l2-regularize=0.002
+
+prefinal-layer name=prefinal-xent input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
+output-layer name=output-xent dim=3456 learning-rate-factor=5.0 l2-regularize=0.002
+'''
+
+
 # Create a network like the above one
 class ChainModel(nn.Module):

     def __init__(self,
                  feat_dim,
                  output_dim,
-                 lda_mat_filename,
-                 hidden_dim=625,
-                 kernel_size_list=[1, 3, 3, 3, 3, 3],
-                 stride_list=[1, 1, 3, 1, 1, 1],
+                 lda_mat_filename=None,
+                 hidden_dim=1024,
+                 bottleneck_dim=128,
+                 time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
+                 conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
                  frame_subsampling_factor=3):
         super().__init__()

         # at present, we support only frame_subsampling_factor to be 3
         assert frame_subsampling_factor == 3

-        assert len(kernel_size_list) == len(stride_list)
-        num_layers = len(kernel_size_list)
+        assert len(time_stride_list) == len(conv_stride_list)
+        num_layers = len(time_stride_list)
+
+        # tdnn1_affine requires [N, T, C]
+        self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
+                                      out_features=hidden_dim)

-        tdnns = []
+        # tdnn1_batchnorm requires [N, C, T]
+        self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
+
+        tdnnfs = []
         for i in range(num_layers):
-            in_channels = hidden_dim
-            if i == 0:
-                in_channels = feat_dim * 3
-
-            kernel_size = kernel_size_list[i]
-            stride = stride_list[i]
-
-            # we do not need to perform padding in Conv1d because it
-            # has been included in left/right context while generating egs
-            layer = nn.Conv1d(in_channels=in_channels,
-                              out_channels=hidden_dim,
-                              kernel_size=kernel_size,
-                              stride=stride)
-            tdnns.append(layer)
-
-        self.tdnns = nn.ModuleList(tdnns)
-        self.batch_norms = nn.ModuleList([
-            nn.BatchNorm1d(num_features=hidden_dim) for i in range(num_layers)
-        ])
-
-        self.prefinal_chain_tdnn = nn.Conv1d(in_channels=hidden_dim,
-                                             out_channels=hidden_dim,
-                                             kernel_size=1)
-        self.prefinal_chain_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
-        self.output_fc = nn.Linear(in_features=hidden_dim,
-                                   out_features=output_dim)
-
-        self.prefinal_xent_tdnn = nn.Conv1d(in_channels=hidden_dim,
-                                            out_channels=hidden_dim,
-                                            kernel_size=1)
-        self.prefinal_xent_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
-        self.output_xent_fc = nn.Linear(in_features=hidden_dim,
-                                        out_features=output_dim)
+            time_stride = time_stride_list[i]
+            conv_stride = conv_stride_list[i]
+            layer = FactorizedTDNN(dim=hidden_dim,
+                                   bottleneck_dim=bottleneck_dim,
+                                   time_stride=time_stride,
+                                   conv_stride=conv_stride)
+            tdnnfs.append(layer)
+
+        # tdnnfs requires [N, C, T]
+        self.tdnnfs = nn.ModuleList(tdnnfs)
+
+        # prefinal_l affine requires [N, C, T]
+        self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
+                                            bottleneck_dim=bottleneck_dim * 2,
+                                            time_stride=0)
+
+        # prefinal_chain requires [N, C, T]
+        self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
+                                            small_dim=bottleneck_dim * 2)
+
+        # output_affine requires [N, T, C]
+        self.output_affine = nn.Linear(in_features=bottleneck_dim * 2,
+                                       out_features=output_dim)
+
+        # prefinal_xent requires [N, C, T]
+        self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim,
+                                           small_dim=bottleneck_dim * 2)
+
+        self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2,
+                                            out_features=output_dim)

         if lda_mat_filename:
             logging.info('Use LDA from {}'.format(lda_mat_filename))
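FactorizedTDNN, OrthonormalLinear, and PrefinalLayer come from tdnnf_layer.py, which is added elsewhere in this commit and not shown in this excerpt. As a rough sketch of what one factorized layer computes, following the TDNN-F recipe of Povey et al., "Semi-Orthogonal Low-Rank Matrix Factorization for Deep Neural Networks" (Interspeech 2018): the weight matrix is factored into a semi-orthogonal bottleneck convolution followed by an affine convolution, then ReLU and batchnorm, with a scaled bypass connection (bypass-scale=0.66 in the config above). The module below illustrates that idea for the conv_stride == 1 case only and is not the commit's actual implementation; the subsampling layer (conv_stride=3 with time_stride=0) would additionally subsample its bypass.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FactorizedTDNNSketch(nn.Module):
    # Illustrative only: dim -> bottleneck -> dim with ReLU + BatchNorm
    # and a scaled residual bypass. Kernel size 2 with dilation equal to
    # time_stride gives context (-t, 0) then (0, t), i.e. +-t frames
    # overall, mirroring the time-stride option of the tdnnf-layer config.
    def __init__(self, dim=1024, bottleneck_dim=128, time_stride=1,
                 bypass_scale=0.66):
        super().__init__()
        kernel_size = 2 if time_stride != 0 else 1
        dilation = max(time_stride, 1)
        # the first factor is the one kept approximately semi-orthogonal
        self.linear = nn.Conv1d(dim, bottleneck_dim, kernel_size,
                                dilation=dilation, bias=False)
        self.affine = nn.Conv1d(bottleneck_dim, dim, kernel_size,
                                dilation=dilation)
        self.batchnorm = nn.BatchNorm1d(num_features=dim)
        self.bypass_scale = bypass_scale

    def forward(self, x):
        # x is [N, C, T]; the two convolutions trim the ends of T, so
        # the bypass input is cropped to match before it is added back
        y = self.batchnorm(F.relu(self.affine(self.linear(x))))
        t = (x.shape[2] - y.shape[2]) // 2
        return y + self.bypass_scale * x[:, :, t:t + y.shape[2]]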
@@ -146,32 +155,69 @@ def forward(self, x):

         # at this point, x is [N, C, T]

-        # Conv1d requires input of shape [N, C, T]
-        for i in range(len(self.tdnns)):
-            x = self.tdnns[i](x)
-            x = F.relu(x)
-            x = self.batch_norms[i](x)
+        x = x.permute(0, 2, 1)
+
+        # at this point, x is [N, T, C]
+
+        x = self.tdnn1_affine(x)
+
+        # at this point, x is [N, T, C]
+
+        x = F.relu(x)
+
+        x = x.permute(0, 2, 1)
+
+        # at this point, x is [N, C, T]
+
+        x = self.tdnn1_batchnorm(x)
+
+        # tdnnf requires input of shape [N, C, T]
+        for i in range(len(self.tdnnfs)):
+            x = self.tdnnfs[i](x)

         # at this point, x is [N, C, T]

-        # we have two branches from this point on
+        x = self.prefinal_l(x)
+
+        # at this point, x is [N, C, T]

-        # first, for the chain branch
-        x_chain = self.prefinal_chain_tdnn(x)
-        x_chain = F.relu(x_chain)
-        x_chain = self.prefinal_chain_batch_norm(x_chain)
-        x_chain = x_chain.permute(0, 2, 1)
-        # at this point, x_chain is [N, T, C]
-        nnet_output = self.output_fc(x_chain)
+        # for the output node
+        nnet_output = self.prefinal_chain(x)

-        # now for the xent branch
-        x_xent = self.prefinal_xent_tdnn(x)
-        x_xent = F.relu(x_xent)
-        x_xent = self.prefinal_xent_batch_norm(x_xent)
-        x_xent = x_xent.permute(0, 2, 1)
+        # at this point, nnet_output is [N, C, T]
+        nnet_output = nnet_output.permute(0, 2, 1)
+        # at this point, nnet_output is [N, T, C]
+        nnet_output = self.output_affine(nnet_output)
+
+        # for the xent node
+        xent_output = self.prefinal_xent(x)
+
+        # at this point, xent_output is [N, C, T]
+        xent_output = xent_output.permute(0, 2, 1)
+        # at this point, xent_output is [N, T, C]
+        xent_output = self.output_xent_affine(xent_output)

-        # at this point x_xent is [N, T, C]
-        xent_output = self.output_xent_fc(x_xent)
         xent_output = F.log_softmax(xent_output, dim=-1)

         return nnet_output, xent_output
+
+    def constrain_orthonormal(self):
+        for i in range(len(self.tdnnfs)):
+            self.tdnnfs[i].constrain_orthonormal()
+
+        self.prefinal_l.constrain_orthonormal()
+        self.prefinal_chain.constrain_orthonormal()
+        self.prefinal_xent.constrain_orthonormal()
+
+
+if __name__ == '__main__':
+    feat_dim = 43
+    output_dim = 4344
+    model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)
+    N = 1
+    T = 150 + 27 + 27
+    C = feat_dim * 3
+    x = torch.arange(N * T * C).reshape(N, T, C).float()
+    nnet_output, xent_output = model(x)
+    print(x.shape, nnet_output.shape, xent_output.shape)
+    model.constrain_orthonormal()
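constrain_orthonormal is meant to be called periodically during training, under torch.no_grad(), to keep each bottleneck factor M close to a scaled semi-orthogonal matrix, i.e. M * M^T ~= scale^2 * I. The function below sketches the approximate gradient step from the TDNN-F paper, which Kaldi's ConstrainOrthonormalInternal implements; the details in this commit's tdnnf_layer.py may differ:

import torch

def constrain_orthonormal_step(M, update_speed=0.125):
    # M is a 2-D weight with rows <= cols (a conv weight would be
    # reshaped to [out_channels, in_channels * kernel_size] first).
    assert M.shape[0] <= M.shape[1]
    P = torch.mm(M, M.t())
    trace_P = torch.trace(P)
    trace_P_P = torch.trace(torch.mm(P, P.t()))
    # the "floating" scale that M is currently closest to
    scale = torch.sqrt(trace_P_P / trace_P)
    # slow the update down when M is far from semi-orthogonal
    ratio = trace_P_P * P.shape[0] / (trace_P * trace_P)
    if ratio > 1.02:
        update_speed *= 0.5
        if ratio > 1.1:
            update_speed *= 0.5
    alpha = update_speed / (scale * scale)
    I = torch.eye(P.shape[0], dtype=M.dtype, device=M.device)
    # one gradient step on sum((M M^T - scale^2 I)^2)
    return M - 4 * alpha * torch.mm(P - scale * scale * I, M)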

egs/aishell/s10/chain/options.py (+20, -13)
@@ -129,18 +129,19 @@ def _check_args(args):
     assert args.feat_dim > 0
     assert args.output_dim > 0
     assert args.hidden_dim > 0
+    assert args.bottleneck_dim > 0

-    assert args.kernel_size_list is not None
-    assert len(args.kernel_size_list) > 0
+    assert args.time_stride_list is not None
+    assert len(args.time_stride_list) > 0

-    assert args.stride_list is not None
-    assert len(args.stride_list) > 0
+    assert args.conv_stride_list is not None
+    assert len(args.conv_stride_list) > 0

-    args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
+    args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]

-    args.stride_list = [int(k) for k in args.stride_list.split(', ')]
+    args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]

-    assert len(args.kernel_size_list) == len(args.stride_list)
+    assert len(args.time_stride_list) == len(args.conv_stride_list)

     assert args.log_level in ['debug', 'info', 'warning']

@@ -195,15 +196,21 @@ def get_args():
                         required=True,
                         type=int)

-    parser.add_argument('--kernel-size-list',
-                        dest='kernel_size_list',
-                        help='kernel size list',
+    parser.add_argument('--bottleneck-dim',
+                        dest='bottleneck_dim',
+                        help='nn bottleneck dimension',
+                        required=True,
+                        type=int)
+
+    parser.add_argument('--time-stride-list',
+                        dest='time_stride_list',
+                        help='time stride list',
                         required=True,
                         type=str)

-    parser.add_argument('--stride-list',
-                        dest='stride_list',
-                        help='stride list',
+    parser.add_argument('--conv-stride-list',
+                        dest='conv_stride_list',
+                        help='conv stride list',
                         required=True,
                         type=str)
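One detail worth noting in _check_args: the list flags are split on ', ' (a comma followed by a space), so their values must be quoted and formatted exactly that way on the command line. A minimal illustration; the flag value mirrors the ChainModel defaults, and the actual training invocation is not part of this excerpt:

# '1,1,0' (no spaces) would raise ValueError here, because split(', ')
# would leave the whole string as a single element.
time_stride_list = '1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1'
parsed = [int(k) for k in time_stride_list.split(', ')]
assert parsed == [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]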
