Skip to content

Commit f5875be

Browse files
authored
show L2 norm of parameters during training; add example script (#3925)
1 parent be0842f commit f5875be

File tree

9 files changed

+448
-235
lines changed

9 files changed

+448
-235
lines changed

egs/aishell/s10/chain/egs_dataset.py

+16-6
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
# Apache 2.0
55

66
import glob
7+
import os
78

89
import numpy as np
910
import torch
@@ -17,9 +18,12 @@
1718
from common import splice_feats
1819

1920

20-
def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):
21+
def get_egs_dataloader(egs_dir_or_scp,
22+
egs_left_context,
23+
egs_right_context,
24+
shuffle=True):
2125

22-
dataset = NnetChainExampleDataset(egs_dir=egs_dir)
26+
dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir_or_scp)
2327
frame_subsampling_factor = 3
2428

2529
# we have merged egs offline, so batch size is 1
@@ -32,6 +36,7 @@ def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):
3236

3337
dataloader = DataLoader(dataset,
3438
batch_size=batch_size,
39+
shuffle=shuffle,
3540
num_workers=0,
3641
collate_fn=collate_fn)
3742
return dataloader
@@ -50,11 +55,16 @@ def read_nnet_chain_example(rxfilename):
5055

5156
class NnetChainExampleDataset(Dataset):
5257

53-
def __init__(self, egs_dir):
58+
def __init__(self, egs_dir_or_scp):
5459
'''
55-
We assume that there exist many cegs.*.scp files inside egs_dir
60+
If egs_dir_or_scp is a directory, we assume that there exist many cegs.*.scp files
61+
inside egs_dir_or_scp.
5662
'''
57-
self.scps = glob.glob('{}/cegs.*.scp'.format(egs_dir))
63+
if os.path.isdir(egs_dir_or_scp):
64+
self.scps = glob.glob('{}/cegs.*.scp'.format(egs_dir_or_scp))
65+
else:
66+
self.scps = [egs_dir_or_scp]
67+
5868
assert len(self.scps) > 0
5969
self.items = list()
6070
for scp in self.scps:
@@ -171,7 +181,7 @@ def __call__(self, batch):
171181

172182
def _test_nnet_chain_example_dataset():
173183
egs_dir = 'exp/chain/merged_egs'
174-
dataset = NnetChainExampleDataset(egs_dir=egs_dir)
184+
dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir)
175185
egs_left_context = 29
176186
egs_right_context = 29
177187
frame_subsampling_factor = 3

egs/aishell/s10/chain/inference.py

+9-7
Original file line number | Diff line number | Diff line change
@@ -30,13 +30,15 @@ def main():
3030
else:
3131
device = torch.device('cuda', args.device_id)
3232

33-
model = get_chain_model(feat_dim=args.feat_dim,
34-
output_dim=args.output_dim,
35-
lda_mat_filename=args.lda_mat_filename,
36-
hidden_dim=args.hidden_dim,
37-
bottleneck_dim=args.bottleneck_dim,
38-
time_stride_list=args.time_stride_list,
39-
conv_stride_list=args.conv_stride_list)
33+
model = get_chain_model(
34+
feat_dim=args.feat_dim,
35+
output_dim=args.output_dim,
36+
lda_mat_filename=args.lda_mat_filename,
37+
hidden_dim=args.hidden_dim,
38+
bottleneck_dim=args.bottleneck_dim,
39+
prefinal_bottleneck_dim=args.prefinal_bottleneck_dim,
40+
kernel_size_list=args.kernel_size_list,
41+
subsampling_factor_list=args.subsampling_factor_list)
4042

4143
load_checkpoint(args.checkpoint, model)
4244

egs/aishell/s10/chain/model.py

+33-21
Original file line number | Diff line number | Diff line change
@@ -19,15 +19,18 @@ def get_chain_model(feat_dim,
1919
output_dim,
2020
hidden_dim,
2121
bottleneck_dim,
22-
time_stride_list,
23-
conv_stride_list,
22+
prefinal_bottleneck_dim,
23+
kernel_size_list,
24+
subsampling_factor_list,
2425
lda_mat_filename=None):
2526
model = ChainModel(feat_dim=feat_dim,
2627
output_dim=output_dim,
2728
lda_mat_filename=lda_mat_filename,
2829
hidden_dim=hidden_dim,
29-
time_stride_list=time_stride_list,
30-
conv_stride_list=conv_stride_list)
30+
bottleneck_dim=bottleneck_dim,
31+
prefinal_bottleneck_dim=prefinal_bottleneck_dim,
32+
kernel_size_list=kernel_size_list,
33+
subsampling_factor_list=subsampling_factor_list)
3134
return model
3235

3336

@@ -72,55 +75,58 @@ def __init__(self,
7275
lda_mat_filename=None,
7376
hidden_dim=1024,
7477
bottleneck_dim=128,
75-
time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
76-
conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
78+
prefinal_bottleneck_dim=256,
79+
kernel_size_list=[2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2],
80+
subsampling_factor_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
7781
frame_subsampling_factor=3):
7882
super().__init__()
7983

8084
# at present, we support only frame_subsampling_factor to be 3
8185
assert frame_subsampling_factor == 3
8286

83-
assert len(time_stride_list) == len(conv_stride_list)
84-
num_layers = len(time_stride_list)
87+
assert len(kernel_size_list) == len(subsampling_factor_list)
88+
num_layers = len(kernel_size_list)
8589

8690
# tdnn1_affine requires [N, T, C]
8791
self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
8892
out_features=hidden_dim)
8993

9094
# tdnn1_batchnorm requires [N, C, T]
91-
self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
95+
self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim,
96+
affine=False)
9297

9398
tdnnfs = []
9499
for i in range(num_layers):
95-
time_stride = time_stride_list[i]
96-
conv_stride = conv_stride_list[i]
100+
kernel_size = kernel_size_list[i]
101+
subsampling_factor = subsampling_factor_list[i]
97102
layer = FactorizedTDNN(dim=hidden_dim,
98103
bottleneck_dim=bottleneck_dim,
99-
time_stride=time_stride,
100-
conv_stride=conv_stride)
104+
kernel_size=kernel_size,
105+
subsampling_factor=subsampling_factor)
101106
tdnnfs.append(layer)
102107

103108
# tdnnfs requires [N, C, T]
104109
self.tdnnfs = nn.ModuleList(tdnnfs)
105110

106111
# prefinal_l affine requires [N, C, T]
107-
self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
108-
bottleneck_dim=bottleneck_dim * 2,
109-
time_stride=0)
112+
self.prefinal_l = OrthonormalLinear(
113+
dim=hidden_dim,
114+
bottleneck_dim=prefinal_bottleneck_dim,
115+
kernel_size=1)
110116

111117
# prefinal_chain requires [N, C, T]
112118
self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
113-
small_dim=bottleneck_dim * 2)
119+
small_dim=prefinal_bottleneck_dim)
114120

115121
# output_affine requires [N, T, C]
116-
self.output_affine = nn.Linear(in_features=bottleneck_dim * 2,
122+
self.output_affine = nn.Linear(in_features=prefinal_bottleneck_dim,
117123
out_features=output_dim)
118124

119125
# prefinal_xent requires [N, C, T]
120126
self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim,
121-
small_dim=bottleneck_dim * 2)
127+
small_dim=prefinal_bottleneck_dim)
122128

123-
self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2,
129+
self.output_xent_affine = nn.Linear(in_features=prefinal_bottleneck_dim,
124130
out_features=output_dim)
125131

126132
if lda_mat_filename:
@@ -130,7 +136,8 @@ def __init__(self,
130136
self.has_LDA = True
131137
else:
132138
logging.info('replace LDA with BatchNorm')
133-
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3)
139+
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
140+
affine=False)
134141
self.has_LDA = False
135142

136143
def forward(self, x):
@@ -211,9 +218,14 @@ def constrain_orthonormal(self):
211218

212219

213220
if __name__ == '__main__':
221+
logging.basicConfig(
222+
level=logging.DEBUG,
223+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
224+
)
214225
feat_dim = 43
215226
output_dim = 4344
216227
model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)
228+
logging.info(model)
217229
N = 1
218230
T = 150 + 27 + 27
219231
C = feat_dim * 3

egs/aishell/s10/chain/options.py

+41-14
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,11 @@ def _set_training_args(parser):
5454
help='cegs dir containing comibined cegs.*.scp',
5555
type=str)
5656

57+
parser.add_argument('--train.valid-cegs-scp',
58+
dest='valid_cegs_scp',
59+
help='validation cegs scp',
60+
type=str)
61+
5762
parser.add_argument('--train.den-fst',
5863
dest='den_fst_filename',
5964
help='denominator fst filename',
@@ -84,9 +89,20 @@ def _set_training_args(parser):
8489
help='l2 regularize',
8590
type=float)
8691

92+
parser.add_argument('--train.xent-regularize',
93+
dest='xent_regularize',
94+
help='xent regularize',
95+
type=float)
96+
97+
parser.add_argument('--train.leaky-hmm-coefficient',
98+
dest='leaky_hmm_coefficient',
99+
help='leaky hmm coefficient',
100+
type=float)
101+
87102

88103
def _check_training_args(args):
89104
assert os.path.isdir(args.cegs_dir)
105+
assert os.path.isfile(args.valid_cegs_scp)
90106

91107
assert os.path.isfile(args.den_fst_filename)
92108

@@ -95,7 +111,9 @@ def _check_training_args(args):
95111

96112
assert args.num_epochs > 0
97113
assert args.learning_rate > 0
98-
assert args.l2_regularize > 0
114+
assert args.l2_regularize >= 0
115+
assert args.xent_regularize >= 0
116+
assert args.leaky_hmm_coefficient >= 0
99117

100118
if args.checkpoint:
101119
assert os.path.exists(args.checkpoint)
@@ -130,18 +148,21 @@ def _check_args(args):
130148
assert args.output_dim > 0
131149
assert args.hidden_dim > 0
132150
assert args.bottleneck_dim > 0
151+
assert args.prefinal_bottleneck_dim > 0
133152

134-
assert args.time_stride_list is not None
135-
assert len(args.time_stride_list) > 0
153+
assert args.kernel_size_list is not None
154+
assert len(args.kernel_size_list) > 0
136155

137-
assert args.conv_stride_list is not None
138-
assert len(args.conv_stride_list) > 0
156+
assert args.subsampling_factor_list is not None
157+
assert len(args.subsampling_factor_list) > 0
139158

140-
args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]
159+
args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
141160

142-
args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]
161+
args.subsampling_factor_list = [
162+
int(k) for k in args.subsampling_factor_list.split(', ')
163+
]
143164

144-
assert len(args.time_stride_list) == len(args.conv_stride_list)
165+
assert len(args.kernel_size_list) == len(args.subsampling_factor_list)
145166

146167
assert args.log_level in ['debug', 'info', 'warning']
147168

@@ -202,15 +223,21 @@ def get_args():
202223
required=True,
203224
type=int)
204225

205-
parser.add_argument('--time-stride-list',
206-
dest='time_stride_list',
207-
help='time stride list',
226+
parser.add_argument('--prefinal-bottleneck-dim',
227+
dest='prefinal_bottleneck_dim',
228+
help='nn prefinal bottleneck dimension',
229+
required=True,
230+
type=int)
231+
232+
parser.add_argument('--kernel-size-list',
233+
dest='kernel_size_list',
234+
help='kernel_size_list',
208235
required=True,
209236
type=str)
210237

211-
parser.add_argument('--conv-stride-list',
212-
dest='conv_stride_list',
213-
help='conv stride list',
238+
parser.add_argument('--subsampling-factor-list',
239+
dest='subsampling_factor_list',
240+
help='subsampling_factor_list',
214241
required=True,
215242
type=str)
216243

0 commit comments

Comments
 (0)