forked from k2-fsa/icefall
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdprnn.py
More file actions
305 lines (259 loc) · 9.39 KB
/
dprnn.py
File metadata and controls
305 lines (259 loc) · 9.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import random
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM
from torch.autograd import Variable
# Machine epsilon of torch's default floating-point dtype, evaluated once when
# this module is imported.
EPS = torch.finfo(torch.get_default_dtype()).eps
def _pad_segment(input, segment_size):
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342
# input is the features: (B, N, T)
batch_size, dim, seq_len = input.shape
segment_stride = segment_size // 2
rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
if rest > 0:
pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
input = torch.cat([input, pad], 2)
pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
input = torch.cat([pad_aux, input, pad_aux], 2)
return input, rest
def split_feature(input, segment_size):
    """Split features into overlapping segments of length ``segment_size``.

    Adapted from
    https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358

    Args:
      input:
        The features, of shape (B, N, T).
      segment_size:
        Length of one segment. Consecutive segments are shifted by
        ``segment_size // 2`` frames.

    Returns:
      A tuple ``(segments, rest)``:
        - segments: contiguous tensor of shape (B, N, segment_size, S), where S
          is the number of segments.
        - rest: the padding amount reported by ``_pad_segment``; pass it to
          ``merge_feature`` to undo the padding.
    """
    input, rest = _pad_segment(input, segment_size)
    batch_size, dim, padded_len = input.shape
    segment_stride = segment_size // 2

    # Use an explicit end index rather than ``[:-segment_stride]``: when the
    # stride is 0 (segment_size == 1) the negative form selects nothing.
    usable_len = padded_len - segment_stride

    # Two non-overlapping tilings of the padded signal, offset by one stride.
    segments1 = (
        input[:, :, :usable_len]
        .contiguous()
        .view(batch_size, dim, -1, segment_size)
    )
    segments2 = (
        input[:, :, segment_stride:]
        .contiguous()
        .view(batch_size, dim, -1, segment_size)
    )

    # Concatenating on the last axis and re-viewing interleaves the two
    # tilings, so segments end up in time order.
    interleaved = torch.cat([segments1, segments2], 3).view(
        batch_size, dim, -1, segment_size
    )
    segments = interleaved.transpose(2, 3)
    return segments.contiguous(), rest
def merge_feature(input, rest):
    """Merge overlapping segments back into a full-length sequence.

    Adapted from
    https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385

    Args:
      input:
        Segmented features of shape (B, N, K, S), where K is the segment size
        and S the (even) number of segments, ordered as produced by
        ``split_feature``.
      rest:
        Number of trailing padded frames to drop from the merged sequence.

    Returns:
      A contiguous tensor of shape (B, N, T). Each frame is the sum of the two
      overlapping segments that cover it.
    """
    batch_size, dim, segment_size, _ = input.shape
    segment_stride = segment_size // 2

    # Put every pair of neighbouring segments side by side: (B, N, S / 2, 2 * K).
    paired = (
        input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
    )

    # Even-indexed and odd-indexed segments, each flattened along time.
    first = paired[:, :, :, :segment_size].contiguous().view(batch_size, dim, -1)
    second = paired[:, :, :, segment_size:].contiguous().view(batch_size, dim, -1)

    # Explicit end index instead of ``[:-segment_stride]``: with a stride of 0
    # (segment_size == 1) the negative form would select nothing and the sum
    # below would fail on mismatched shapes.
    overlap_len = second.size(2) - segment_stride
    output = first[:, :, segment_stride:] + second[:, :, :overlap_len]

    if rest > 0:
        output = output[:, :, :-rest]

    return output.contiguous()  # B, N, T
class RNNEncoderLayer(nn.Module):
    """
    RNNEncoderLayer is a single LSTM layer wrapped in a residual connection,
    followed by an activation balancer and a normalization layer.

    Args:
      input_size:
        The number of expected features in the input (required).
      hidden_size:
        The total output width of the LSTM. With ``bidirectional=True`` each
        direction gets ``hidden_size // 2`` units. Because the LSTM output is
        added to the layer input, the resulting output width must equal
        ``input_size``.
      dropout:
        The dropout value applied to the LSTM output (default=0.1).
      bidirectional:
        Whether the LSTM is bidirectional (default=False).

    Raises:
      ValueError:
        If the LSTM output width would differ from ``input_size``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float = 0.1,
        bidirectional: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        assert hidden_size >= input_size, (hidden_size, input_size)

        # No projection is used (proj_size=0) and forward() adds the LSTM
        # output to its input, so both must have the same number of features.
        # Reject mismatching sizes here instead of failing inside forward().
        lstm_hidden_size = hidden_size // 2 if bidirectional else hidden_size
        lstm_output_size = 2 * lstm_hidden_size if bidirectional else lstm_hidden_size
        if lstm_output_size != input_size:
            raise ValueError(
                "LSTM output width must equal input_size for the residual "
                f"connection, got input_size={input_size}, "
                f"hidden_size={hidden_size}, bidirectional={bidirectional}"
            )

        self.lstm = ScaledLSTM(
            input_size=input_size,
            hidden_size=lstm_hidden_size,
            proj_size=0,
            num_layers=1,
            dropout=0.0,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.norm_final = BasicNorm(input_size)

        # try to ensure the output is close to zero-mean (or at least, zero-median).  # noqa
        self.balancer = ActivationBalancer(
            num_channels=input_size,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> torch.Tensor:
        """
        Pass the input through the encoder layer.

        Args:
          src:
            The sequence to the encoder layer (required).
            The LSTM is built with ``batch_first=True``, so its shape is
            (N, S, E), where N is the batch size, S is the sequence length,
            and E is the feature number.
          states:
            A tuple of 2 tensors (optional), forwarded unchanged to the LSTM
            as its initial (hidden, cell) states.
            NOTE(review): shapes presumably follow torch.nn.LSTM - confirm
            against ScaledLSTM.
          warmup:
            Blend factor used only in training mode: 1.0 means fully use this
            layer, 0.0 means completely bypass it.

        Returns:
          The layer output tensor, expected to have the same shape as ``src``.
          The updated LSTM states are discarded and not returned.
        """
        # lstm module; only the output sequence is used by this layer.
        lstm_out, _ = self.lstm(src, states)

        # Residual connection around the LSTM, then balance and normalize.
        out = self.dropout(lstm_out) + src
        out = self.norm_final(self.balancer(out))

        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
        alpha = warmup if self.training else 1.0
        if alpha != 1.0:
            out = alpha * out + (1 - alpha) * src

        return out
# dual-path RNN
class DPRNN(nn.Module):
    """Deep dual-path RNN.

    Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py

    Args:
      feature_dim:
        int, dimension of the input feature. The input should have shape
        (batch, seq_len, feature_dim).
      input_size:
        int, dimension the features are projected to before the RNN blocks.
      hidden_size:
        int, hidden dimension passed to every RNNEncoderLayer.
      output_size:
        int, dimension of the output.
      dropout:
        float, dropout ratio passed to every RNNEncoderLayer. Default is 0.1.
      num_blocks:
        int, number of stacked (intra-segment, inter-segment) RNN pairs.
        Default is 1.
      segment_size:
        int, length of the segments the sequence is split into. Default is 50.
      chunk_width_randomization:
        bool, if True then in training mode the segment length is drawn
        uniformly from ``[segment_size // 2, segment_size]`` on every forward
        call. Default is False.
    """

    def __init__(
        self,
        feature_dim,
        input_size,
        hidden_size,
        output_size,
        dropout=0.1,
        num_blocks=1,
        segment_size=50,
        chunk_width_randomization=False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.segment_size = segment_size
        self.chunk_width_randomization = chunk_width_randomization

        self.input_embed = self._make_embed(feature_dim, input_size)

        # dual-path RNN
        self.row_rnn = nn.ModuleList()
        self.col_rnn = nn.ModuleList()
        for _ in range(num_blocks):
            # intra-RNN is non-causal
            self.row_rnn.append(
                RNNEncoderLayer(
                    input_size, hidden_size, dropout=dropout, bidirectional=True
                )
            )
            # inter-RNN is unidirectional
            self.col_rnn.append(
                RNNEncoderLayer(
                    input_size, hidden_size, dropout=dropout, bidirectional=False
                )
            )

        # output layer
        self.out_embed = self._make_embed(input_size, output_size)

    @staticmethod
    def _make_embed(in_dim, out_dim):
        """Build the linear -> norm -> balancer stack used at input and output."""
        return nn.Sequential(
            ScaledLinear(in_dim, out_dim),
            BasicNorm(out_dim),
            ActivationBalancer(
                num_channels=out_dim,
                channel_dim=-1,
                min_positive=0.45,
                max_positive=0.55,
            ),
        )

    def forward(self, input):
        """Run the dual-path RNN.

        Args:
          input:
            Features of shape (B, T, feature_dim).

        Returns:
          A non-negative tensor of shape (B, T, output_size).
        """
        embedded = self.input_embed(input)  # B, T, input_size

        if self.chunk_width_randomization and self.training:
            segment_size = random.randint(self.segment_size // 2, self.segment_size)
        else:
            segment_size = self.segment_size

        # segments shape: B, N, dim1, dim2
        # (dim1 = frames per segment, dim2 = number of segments)
        segments, rest = split_feature(embedded.transpose(1, 2), segment_size)
        batch_size, _, dim1, dim2 = segments.shape

        # apply RNN on dim1 first and then dim2
        output = segments
        for row_rnn, col_rnn in zip(self.row_rnn, self.col_rnn):
            # Intra-segment pass: every segment is a sequence of length dim1.
            row_input = (
                output.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * dim2, dim1, -1)
            )  # B*dim2, dim1, N
            row_output = row_rnn(row_input)  # B*dim2, dim1, N
            output = (
                row_output.view(batch_size, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )  # B, N, dim1, dim2

            # Inter-segment pass: every within-segment position is a sequence
            # of length dim2 across segments.
            col_input = (
                output.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * dim1, dim2, -1)
            )  # B*dim1, dim2, N
            col_output = col_rnn(col_input)  # B*dim1, dim2, N
            output = (
                col_output.view(batch_size, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )  # B, N, dim1, dim2

        output = merge_feature(output, rest)  # B, N, T
        output = output.transpose(1, 2)  # B, T, N
        output = self.out_embed(output)  # B, T, output_size

        # Apply ReLU to the output
        return torch.relu(output)
if __name__ == "__main__":
    # Smoke test: build a small model, then report its parameter count and the
    # output shape for a random batch.
    dprnn = DPRNN(
        feature_dim=80,
        input_size=256,
        hidden_size=256,
        output_size=160,
        dropout=0.1,
        num_blocks=4,
        segment_size=32,
        chunk_width_randomization=True,
    )
    features = torch.randn(2, 1002, 80)  # B, T, F
    num_params = sum(param.numel() for param in dprnn.parameters())
    print(num_params)
    result = dprnn(features)
    print(result.shape)