Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 45 additions & 35 deletions deepmark_chainer/net/deepspeech2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import six

from chainer.functions.activation import clipped_relu
from chainer.functions.array import concat
from chainer.functions.array import reshape
from chainer.functions.array import split_axis
from chainer import link
Expand Down Expand Up @@ -55,9 +57,11 @@ def __call__(self, x, train=True):
return self.h


class BRNN(link.Chain):
class BRNNReLU(link.Chain):

def __init__(self, input_dim, output_dim, rnn_unit):
assert output_dim % 2 == 0
output_dim //= 2 # output dim for each direction is halved
if rnn_unit == 'LSTM':
forward = lstm.LSTM(input_dim, output_dim)
reverse = lstm.LSTM(input_dim, output_dim)
Expand All @@ -69,7 +73,7 @@ def __init__(self, input_dim, output_dim, rnn_unit):
reverse = StatefulLinearRNN(input_dim, output_dim)
else:
raise ValueError('Invalid rnn_unit:{}'.format(rnn_unit))
super(BRNN, self).__init__(forward=forward, reverse=reverse)
super(BRNNReLU, self).__init__(forward=forward, reverse=reverse)

def reset_state(self):
self.forward.reset_state()
Expand All @@ -81,20 +85,23 @@ def __call__(self, xs, train=True):
x_reverse = [self.reverse(xs[n], train) for n
in six.moves.range(N - 1, -1, -1)]
x_reverse.reverse()
return [x_f + x_r for x_f, x_r in zip(x_forward, x_reverse)]
xs = [concat.concat((x_f, x_r)) for x_f, x_r
in six.moves.zip(x_forward, x_reverse)]
return [clipped_relu.clipped_relu(x) for x in xs]


class ConvBNReLU(link.Chain):

    """Convolution2D -> BatchNormalization -> clipped ReLU.

    All constructor arguments are forwarded verbatim to
    ``C.Convolution2D``; the batch-normalization channel count is taken
    from the convolution's weight array.
    """

    def __init__(self, *args, **kwargs):
        conv = C.Convolution2D(*args, **kwargs)
        # len() over the weight array gives the output-channel count.
        n_out = len(conv.W.data)
        batch_norm = B.BatchNormalization(n_out)
        super(ConvBNReLU, self).__init__(conv=conv, batch_norm=batch_norm)

    def __call__(self, x, train=True):
        # BatchNormalization runs in inference mode when train is False.
        h = self.batch_norm(self.conv(x), test=not train)
        return clipped_relu.clipped_relu(h)


class LinearBN(link.Chain):
Expand All @@ -110,39 +117,44 @@ def __call__(self, x, train=True):
return self.batch_norm(x, test=not train)


class Sequential(link.ChainList):

    """Feed an input through every child link in registration order."""

    def __call__(self, x, *args, **kwargs):
        # Extra positional/keyword arguments (e.g. ``train``) are passed
        # through unchanged to each child link.
        for child in self:
            x = child(x, *args, **kwargs)
        return x


class DeepSpeech2(link.Chain):

    """Deep Speech 2 acoustic model.

    Two conv-BN-ReLU stages, seven bidirectional recurrent layers, and a
    two-layer fully-connected head.

    NOTE(review): the first BRNN takes ``31 * channel_dim`` features,
    i.e. the conv front-end is presumed to leave 31 bins on the split
    axis -- TODO confirm against the expected input spectrogram size.
    """

    def __init__(self, channel_dim=32, hidden_dim=1760, out_dim=29,
                 rnn_unit='Linear', use_cudnn=True):
        c1 = ConvBNReLU(1, channel_dim, (5, 20), 2, use_cudnn=use_cudnn)
        c2 = ConvBNReLU(channel_dim, channel_dim, (5, 10), (1, 2),
                        use_cudnn=use_cudnn)
        convolution = link.ChainList(c1, c2)

        # Each BRNN concatenates its forward and backward outputs, so
        # every layer after the first consumes hidden_dim * 2 features.
        width = hidden_dim * 2
        brnns = [BRNNReLU(31 * channel_dim, width, rnn_unit=rnn_unit)]
        brnns.extend(BRNNReLU(width, width, rnn_unit=rnn_unit)
                     for _ in range(6))
        recurrent = link.ChainList(*brnns)

        fc1 = LinearBN(width, hidden_dim)
        fc2 = L.Linear(hidden_dim, out_dim)
        linear = link.ChainList(fc1, fc2)
        super(DeepSpeech2, self).__init__(convolution=convolution,
                                          recurrent=recurrent,
                                          linear=linear)

def _linear(self, xs, train=True):

def _convolution(self, x, train):
    """Apply every conv-BN-ReLU stage of ``self.convolution`` in turn."""
    for stage in self.convolution:
        x = stage(x, train)
    return x

def _recurrent(self, xs, train):
    """Run the sequence ``xs`` through the stack of BRNN layers."""
    for layer in self.recurrent:
        # The layers are stateful: clear any state left over from the
        # previous sequence before feeding the new one.
        layer.reset_state()
        xs = layer(xs, train)
    return xs

def _linear(self, xs, train):
ret = []
for x in xs:
x = self.linear[0](x, train)
Expand All @@ -152,12 +164,10 @@ def _linear(self, xs, train=True):

def __call__(self, x, train=True):
    """Full forward pass: conv front-end, recurrent stack, linear head."""
    # Insert a singleton channel axis in front of the remaining axes
    # before the 2-D convolutions.
    x = reshape.reshape(x, (len(x.data), 1) + x.data.shape[1:])
    x = self._convolution(x, train)
    # Split axis 2 into per-slice inputs for the recurrent stack.
    # NOTE(review): presumably axis 2 is the time axis here -- confirm.
    xs = split_axis.split_axis(x, x.data.shape[2], 2)
    for piece in xs:
        # split_axis yields non-contiguous views; make each slice
        # contiguous before handing it to the RNN links.
        piece.data = self.xp.ascontiguousarray(piece.data)
    xs = self._recurrent(xs, train)
    return self._linear(xs, train)