Skip to content

Commit d211d33

Browse files
authored
[scripts] expose egs as Dataloader (#3999)
* expose egs as Dataloader * update dropout fraction computation to get zero at the end of training * remove useless code * change feature list to feature * decrease frequency of applying the orthonormal constraint
1 parent 051d6a2 commit d211d33

File tree

6 files changed

+404
-102
lines changed

6 files changed

+404
-102
lines changed
+219
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2020 Xiaomi Corporation, Beijing, China (author: Haowen Qiu)
4+
# Apache 2.0
5+
6+
from multiprocessing import Process
7+
import datetime
8+
import glob
9+
import os
10+
11+
import numpy as np
12+
import torch
13+
import torch.distributed as dist
14+
15+
from torch.utils.data import Dataset
16+
17+
from kaldi import SequentialNnetChainExampleReader
18+
import kaldi
19+
import kaldi_pybind.nnet3 as nnet3
20+
21+
from common import splice_feats
22+
23+
def get_egs_dataloader(egs_dir_or_scp,
                       egs_left_context,
                       egs_right_context,
                       frame_subsampling_factor=3,
                       world_size=None,
                       local_rank=None):
    '''Create a dataloader over nnet chain example scp files.

    Args:
        egs_dir_or_scp: a directory containing cegs.*.scp files, or the
            path of a single scp file.
        egs_left_context: left context, from egs/info/left_context.
        egs_right_context: right context, from egs/info/right_context.
        frame_subsampling_factor: either 1 or 3 (see
            NnetChainExampleCollateFunc).
        world_size: number of processes in DistributedDataParallel
            training; used only when local_rank is not None.
        local_rank: rank of this process in DistributedDataParallel
            training; None for single-process training.

    Returns:
        A NnetChainExampleDataLoader yielding (pseudo_epoch, batch) pairs.
    '''
    dataset = NnetChainExampleScpDataset(egs_dir_or_scp)

    collate_fn = NnetChainExampleCollateFunc(
        egs_left_context=egs_left_context,
        egs_right_context=egs_right_context,
        frame_subsampling_factor=frame_subsampling_factor)

    if local_rank is not None:
        # in DDP training each process draws a disjoint shard of the scp files
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=local_rank, shuffle=True)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    dataloader = NnetChainExampleDataLoader(dataset,
                                            sampler=sampler,
                                            collate_fn=collate_fn)
    return dataloader
50+
51+
52+
class NnetChainExampleScpDataset(Dataset):
    '''A Dataset over cegs scp files; each item is one scp file path.'''

    def __init__(self, egs_dir_or_scp):
        '''
        Accepts either a directory -- assumed to contain many
        cegs.*.scp files -- or the path of a single scp file.
        '''
        if os.path.isdir(egs_dir_or_scp):
            pattern = '{}/cegs.*.scp'.format(egs_dir_or_scp)
            self.scps = glob.glob(pattern)
        else:
            self.scps = [egs_dir_or_scp]

        assert len(self.scps) > 0

    def __len__(self):
        return len(self.scps)

    def __getitem__(self, i):
        return self.scps[i]

    def __str__(self):
        return 'num egs scp files: {}\n'.format(len(self.scps))
75+
76+
77+
class NnetChainExampleDataLoader(object):
    '''Nnet chain example data loader, an iterable over the given scp files.

    Each iteration step draws one scp file from the sampler, reads every
    merged eg inside it, and yields (pseudo_epoch, batch) pairs where
    batch is the result of applying collate_fn to the eg.

    Arguments:
        dataset (Dataset): scp files from which to load the egs.
        sampler (Sampler): defines the strategy to draw samples
            from the dataset.
        collate_fn (callable): creates a batch from a merged eg.
    '''

    def __init__(self, dataset, sampler, collate_fn):
        self.dataset = dataset
        self.sampler = sampler
        self.collate_fn = collate_fn

    def __len__(self):
        # one sample drawn from the sampler is one scp file
        return len(self.sampler)

    def __iter__(self):
        # each scp file constitutes one `pseudo_epoch`
        for pseudo_epoch, scp_idx in enumerate(self.sampler):
            rspecifier = 'scp:' + self.dataset[scp_idx]
            with SequentialNnetChainExampleReader(rspecifier) as reader:
                for _key, eg in reader:
                    yield pseudo_epoch, self.collate_fn(eg)
107+
108+
109+
class NnetChainExampleCollateFunc:
    '''Converts one merged NnetChainExample to a (feature, supervision) pair.'''

    def __init__(self, egs_left_context, egs_right_context,
                 frame_subsampling_factor=3):
        '''
        egs_left_context is from egs/info/left_context
        egs_right_context is from egs/info/right_context
        '''
        assert egs_left_context >= 0
        # bug fix: this checked egs_left_context twice and never validated
        # the right context
        assert egs_right_context >= 0

        # currently support either no subsampling or
        # subsampling factor to be 3
        assert frame_subsampling_factor in [1, 3]

        self.egs_left_context = egs_left_context
        self.egs_right_context = egs_right_context
        self.frame_subsampling_factor = frame_subsampling_factor

    def __call__(self, eg):
        '''
        eg is a batch as it has been merged.

        Returns a tuple (feat, supervision): feat is a float tensor of
        shape (batch_size, frames_per_sequence - 4, feat_dim) -- 4 frames
        are trimmed below for the random frame shift -- and supervision
        is eg.outputs[0].supervision.
        '''
        assert eg.inputs[0].name == 'input'
        assert len(eg.outputs) == 1
        assert eg.outputs[0].name == 'output'

        supervision = eg.outputs[0].supervision

        batch_size = supervision.num_sequences
        frames_per_sequence = (supervision.frames_per_sequence *
                               self.frame_subsampling_factor) + \
            self.egs_left_context + self.egs_right_context

        _feats = kaldi.FloatMatrix()
        eg.inputs[0].features.GetMatrix(_feats)
        feats = _feats.numpy()

        # the optional second input holds per-sequence i-vectors
        if len(eg.inputs) > 1:
            _ivectors = kaldi.FloatMatrix()
            eg.inputs[1].features.GetMatrix(_ivectors)
            ivectors = _ivectors.numpy()

        assert feats.shape[0] == batch_size * frames_per_sequence

        feat_list = []
        for i in range(batch_size):
            start_index = i * frames_per_sequence
            if self.frame_subsampling_factor == 3:
                # random frame shift for data augmentation
                shift = np.random.choice([-1, 0, 1], 1)[0]
                start_index += shift

            end_index = start_index + frames_per_sequence
            start_index += 2  # remove the leftmost frames added for frame shift
            end_index -= 2  # remove the rightmost frames added for frame shift
            feat = feats[start_index:end_index, :]
            if len(eg.inputs) > 1:
                # append the sequence's i-vector to every frame
                repeat_ivector = torch.from_numpy(
                    ivectors[i]).repeat(feat.shape[0], 1)
                feat = torch.cat(
                    (torch.from_numpy(feat), repeat_ivector), dim=1).numpy()
            feat_list.append(feat)

        batched_feat = np.stack(feat_list, axis=0)
        assert batched_feat.shape[0] == batch_size
        # 4 == 2 leftmost + 2 rightmost frames trimmed above
        assert batched_feat.shape[1] == frames_per_sequence - 4
        if len(eg.inputs) > 1:
            assert batched_feat.shape[2] == feats.shape[-1] + ivectors.shape[-1]
        else:
            assert batched_feat.shape[2] == feats.shape[-1]

        torch_feat = torch.from_numpy(batched_feat).float()

        return torch_feat, supervision
187+
188+
189+
def _test_nnet_chain_example_dataloader():
    # directory produced by the chain egs preparation stage
    egs_dir = 'exp/chain_pybind/tdnn_sp/egs_chain2'
    _test_dataloader_iter(egs_dir)
192+
193+
def _test_dataloader_iter(scp_dir_or_file):
    '''Smoke test: iterate the dataloader twice and sanity-check batches.'''
    left_context = 29
    right_context = 29
    subsampling_factor = 3

    dataloader = get_egs_dataloader(scp_dir_or_file, left_context,
                                    right_context, subsampling_factor)

    for epoch in range(2):
        for batch_idx, (pseudo_epoch, batch) in enumerate(dataloader):
            print('{}: epoch {}, pseudo_epoch {}, batch_idx {}'.format(
                datetime.datetime.now(), epoch, pseudo_epoch, batch_idx))
            feature, supervision = batch
            assert feature.shape in [(128, 204, 120),
                                     (128, 144, 120),
                                     (128, 165, 120)]
            assert supervision.weight == 1
            assert supervision.num_sequences == 128  # minibatch size is 128
216+
217+
218+
if __name__ == '__main__':
    _test_nnet_chain_example_dataloader()

egs/aishell/s10/chain/model.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def get_chain_model(feat_dim,
7070

7171

7272
def constrain_orthonormal_hook(model, unused_x):
73-
if model.training == False:
73+
if not model.training:
74+
return
75+
76+
model.ortho_constrain_count = (model.ortho_constrain_count + 1) % 2
77+
if model.ortho_constrain_count != 0:
7478
return
7579

7680
with torch.no_grad():
@@ -100,6 +104,8 @@ def __init__(self,
100104

101105
assert len(kernel_size_list) == len(subsampling_factor_list)
102106
num_layers = len(kernel_size_list)
107+
108+
self.ortho_constrain_count = 0
103109

104110
input_dim = feat_dim * 3 + ivector_dim
105111

0 commit comments

Comments
 (0)