
Commit 4fcfb19

add training script.
1 parent 0148732 commit 4fcfb19

8 files changed: +717, -2 lines changed


egs/aishell/s10b/ctc/common.py

+34
@@ -0,0 +1,34 @@
#!/usr/bin/env python3

# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

from datetime import datetime
import logging

import numpy as np

import torch

import kaldi


def setup_logger(log_filename, log_level='info'):
    now = datetime.now()
    date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
    log_filename = '{}-{}'.format(log_filename, date_time)
    formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
    if log_level == 'debug':
        level = logging.DEBUG
    elif log_level == 'info':
        level = logging.INFO
    elif log_level == 'warning':
        level = logging.WARNING
    logging.basicConfig(filename=log_filename,
                        format=formatter,
                        level=level,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(logging.Formatter(formatter))
    logging.getLogger('').addHandler(console)
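
setup_logger configures both a timestamped log file and a console handler at the chosen level. As a rough illustration (not part of this commit; the log path below is made up), a training script might call it like this:

import logging

from common import setup_logger  # assumes the caller sits next to common.py

setup_logger('exp/ctc/log/train', log_level='info')
logging.info('Training started')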

egs/aishell/s10b/ctc/dataset.py

+200
@@ -0,0 +1,200 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import os
import logging

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import kaldi


def get_ctc_dataloader(feats_scp,
                       labels_scp=None,
                       batch_size=1,
                       shuffle=False,
                       num_workers=0):

    dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

    collate_fn = CtcDatasetCollateFunc()

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            num_workers=num_workers,
                            collate_fn=collate_fn)

    return dataloader


class CtcDataset(Dataset):

    def __init__(self, feats_scp, labels_scp=None):
        '''
        Args:
            feats_scp: filename for feats.scp
            labels_scp: if provided, it is the filename of labels.scp
        '''
        assert os.path.isfile(feats_scp)
        if labels_scp:
            assert os.path.isfile(labels_scp)
            logging.info('labels scp: {}'.format(labels_scp))
        else:
            logging.warning('No labels scp is given.')

        # items is a dict of [uttid, feat_rxfilename, None]
        # or [uttid, feat_rxfilename, label_rxfilename] if labels_scp is not None
        items = dict()

        with open(feats_scp, 'r') as f:
            for line in f:
                # every line has the following format:
                # uttid feat_rxfilename
                uttid_rxfilename = line.split()
                assert len(uttid_rxfilename) == 2

                uttid, rxfilename = uttid_rxfilename

                assert uttid not in items

                items[uttid] = [uttid, rxfilename, None]

        if labels_scp:
            expected_count = len(items)
            n = 0
            with open(labels_scp, 'r') as f:
                for line in f:
                    # every line has the following format:
                    # uttid rxfilename
                    uttid_rxfilename = line.split()

                    assert len(uttid_rxfilename) == 2

                    uttid, rxfilename = uttid_rxfilename

                    assert uttid in items

                    items[uttid][-1] = rxfilename

                    n += 1

            # every utterance should have a label if
            # labels_scp is given
            assert n == expected_count

        self.items = list(items.values())
        self.num_items = len(self.items)
        self.feats_scp = feats_scp
        self.labels_scp = labels_scp

    def __len__(self):
        return self.num_items

    def __getitem__(self, i):
        '''
        Returns:
            a list [key, feat_rxfilename, label_rxfilename]
            Note that label_rxfilename may be None.
        '''
        return self.items[i]

    def __str__(self):
        s = 'feats scp: {}\n'.format(self.feats_scp)

        if self.labels_scp:
            s += 'labels scp: {}\n'.format(self.labels_scp)

        s += 'num utterances: {}\n'.format(self.num_items)

        return s


class CtcDatasetCollateFunc:

    def __call__(self, batch):
        '''
        Args:
            batch: a list of [uttid, feat_rxfilename, label_rxfilename].
                   Note that label_rxfilename may be None.

        Returns:
            uttid_list: a list of utterance ids

            feat: a 3-D float tensor of shape [batch_size, seq_len, feat_dim]

            feat_len_list: number of frames of each utterance before padding

            label_list: a list of labels of each utterance; it may be None

            label_len_list: label length of each utterance; it is None if label_list is None
        '''
        uttid_list = []  # utterance id of each utterance
        feat_len_list = []  # number of frames of each utterance
        label_list = []  # label of each utterance
        label_len_list = []  # label length of each utterance

        feat_list = []

        for b in batch:
            uttid, feat_rxfilename, label_rxfilename = b

            uttid_list.append(uttid)

            feat = kaldi.read_mat(feat_rxfilename).numpy()
            feat = torch.from_numpy(feat).float()
            feat_list.append(feat)

            feat_len_list.append(feat.size(0))

            if label_rxfilename:
                label = kaldi.read_vec_int(label_rxfilename)
                label_list.append(label)
                label_len_list.append(len(label))

        feat = pad_sequence(feat_list, batch_first=True)

        if not label_list:
            label_list = None
            label_len_list = None

        return uttid_list, feat, feat_len_list, label_list, label_len_list


def _test_dataset():
    feats_scp = 'data/train_sp/feats.scp'
    labels_scp = 'data/train_sp/labels.scp'

    dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

    print(dataset)


def _test_dataloader():
    feats_scp = 'data/test/feats.scp'
    labels_scp = 'data/test/labels.scp'

    dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

    dataloader = DataLoader(dataset,
                            batch_size=2,
                            num_workers=10,
                            shuffle=True,
                            collate_fn=CtcDatasetCollateFunc())
    i = 0
    for batch in dataloader:
        uttid_list, feat, feat_len_list, label_list, label_len_list = batch
        print(uttid_list, feat.shape, feat_len_list, label_len_list)
        i += 1
        if i > 10:
            break


if __name__ == '__main__':
    # _test_dataset()
    _test_dataloader()
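
For reference, a hedged sketch of driving this file through get_ctc_dataloader rather than constructing the DataLoader by hand (the scp paths reuse those from _test_dataset; batch_size and num_workers are arbitrary):

from dataset import get_ctc_dataloader  # assumes the caller sits next to dataset.py

dataloader = get_ctc_dataloader(feats_scp='data/train_sp/feats.scp',
                                labels_scp='data/train_sp/labels.scp',
                                batch_size=8,
                                shuffle=True,
                                num_workers=4)

for uttid_list, feat, feat_len_list, label_list, label_len_list in dataloader:
    # feat: [batch_size, max_seq_len, feat_dim];
    # label_list and label_len_list are None when labels_scp is not given
    print(uttid_list, feat.shape, feat_len_list, label_len_list)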

egs/aishell/s10b/ctc/model.py

+145
@@ -0,0 +1,145 @@
#!/usr/bin/env python3

# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


def get_ctc_model(input_dim,
                  output_dim,
                  num_layers=4,
                  hidden_dim=512,
                  proj_dim=256):
    model = CtcModel(input_dim=input_dim,
                     output_dim=output_dim,
                     num_layers=num_layers,
                     hidden_dim=hidden_dim,
                     proj_dim=proj_dim)

    return model


class CtcModel(nn.Module):

    def __init__(self, input_dim, output_dim, num_layers, hidden_dim, proj_dim):
        '''
        Args:
            input_dim: input dimension of the network

            output_dim: output dimension of the network

            num_layers: number of LSTM layers of the network

            hidden_dim: the dimension of the hidden state of the LSTM layers

            proj_dim: dimension of the affine layer after every LSTM layer
        '''
        super().__init__()

        lstm_layer_list = []
        proj_layer_list = []

        # batchnorm requires input of shape [N, C, L] == [batch_size, dim, seq_len]
        self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
                                               affine=False)

        for i in range(num_layers):
            if i == 0:
                lstm_input_dim = input_dim
            else:
                lstm_input_dim = proj_dim

            lstm_layer = nn.LSTM(input_size=lstm_input_dim,
                                 hidden_size=hidden_dim,
                                 num_layers=1,
                                 batch_first=True)

            proj_layer = nn.Linear(in_features=hidden_dim,
                                   out_features=proj_dim)

            lstm_layer_list.append(lstm_layer)
            proj_layer_list.append(proj_layer)

        self.lstm_layer_list = nn.ModuleList(lstm_layer_list)
        self.proj_layer_list = nn.ModuleList(proj_layer_list)

        self.num_layers = num_layers

        self.prefinal_affine = nn.Linear(in_features=proj_dim,
                                         out_features=output_dim)

    def forward(self, feat, feat_len_list):
        '''
        Args:
            feat: a 3-D tensor of shape [batch_size, seq_len, feat_dim]
            feat_len_list: feat length of each utterance before padding

        Returns:
            a 3-D tensor of shape [batch_size, seq_len, output_dim]
            representing log prob, i.e., the output of log_softmax.
        '''
        x = feat

        # at this point, x is of shape [batch_size, seq_len, feat_dim]
        x = x.permute(0, 2, 1)

        # at this point, x is of shape [batch_size, feat_dim, seq_len] == [N, C, L]
        x = self.input_batch_norm(x)

        x = x.permute(0, 2, 1)

        # at this point, x is of shape [batch_size, seq_len, feat_dim] == [N, L, C]

        for i in range(self.num_layers):
            x = pack_padded_sequence(input=x,
                                     lengths=feat_len_list,
                                     batch_first=True,
                                     enforce_sorted=False)

            # TODO(fangjun): save intermediate LSTM state to support streaming inference
            x, _ = self.lstm_layer_list[i](x)

            x, _ = pad_packed_sequence(x, batch_first=True)

            x = self.proj_layer_list[i](x)

            x = torch.tanh(x)

        x = self.prefinal_affine(x)

        x = F.log_softmax(x, dim=-1)

        return x


def _test_ctc_model():
    input_dim = 5
    output_dim = 20
    model = CtcModel(input_dim=input_dim,
                     output_dim=output_dim,
                     num_layers=2,
                     hidden_dim=3,
                     proj_dim=4)

    feat1 = torch.randn((6, input_dim))
    feat2 = torch.randn((8, input_dim))

    from torch.nn.utils.rnn import pad_sequence
    feat = pad_sequence([feat1, feat2], batch_first=True)
    assert feat.shape == torch.Size([2, 8, input_dim])

    feat_len_list = [6, 8]
    x = model(feat, feat_len_list)

    assert x.shape == torch.Size([2, 8, output_dim])


if __name__ == '__main__':
    _test_ctc_model()
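
The training script added by this commit is not shown in this excerpt. As a minimal sketch of how the pieces above could fit together with torch.nn.CTCLoss (assuming blank id 0, that kaldi.read_vec_int returns a plain list of integer labels, and placeholder feature/output dimensions and hyperparameters), one training step might look like:

import torch

from dataset import get_ctc_dataloader  # assumes the caller sits next to dataset.py and model.py
from model import get_ctc_model

# input_dim/output_dim are placeholders, not values from this commit
model = get_ctc_model(input_dim=40, output_dim=218)
loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

dataloader = get_ctc_dataloader(feats_scp='data/train_sp/feats.scp',
                                labels_scp='data/train_sp/labels.scp',
                                batch_size=16,
                                shuffle=True)

for uttid_list, feat, feat_len_list, label_list, label_len_list in dataloader:
    # CtcModel returns [batch_size, seq_len, output_dim];
    # CTCLoss expects log probs of shape [seq_len, batch_size, output_dim]
    log_probs = model(feat, feat_len_list).permute(1, 0, 2)

    # concatenate the per-utterance labels into one 1-D target tensor
    targets = torch.tensor([t for label in label_list for t in label],
                           dtype=torch.long)

    loss = loss_func(log_probs,
                     targets,
                     torch.tensor(feat_len_list, dtype=torch.long),
                     torch.tensor(label_len_list, dtype=torch.long))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()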

0 commit comments
