-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathlayer_norm_indrnn.py
More file actions
219 lines (179 loc) · 7.88 KB
/
Copy pathlayer_norm_indrnn.py
File metadata and controls
219 lines (179 loc) · 7.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer Normalized Independently Recurrent Neural Network"""
import pkg_resources
import tensorflow as tf
from tensorflow.compat import v1
from tensorflow.compat.v1.nn import rnn_cell
from .base_rnn import BaseRNN
from .weight_config import WeightConfig
# Public API of this module.
__all__ = ['LayerNormIndRNN']

# Compiled Haste op library, resolved as a resource of this module via
# `pkg_resources.resource_filename`. The forward op
# (`haste_layer_norm_indrnn`) and its backward op
# (`haste_layer_norm_indrnn_grad`) are called through this handle.
LIB = tf.load_op_library(
    pkg_resources.resource_filename(__name__, 'libhaste_tf.so'))
@tf.RegisterGradient("HasteLayerNormIndrnn")
def layer_norm_indrnn_gradient(op, *grads):
    """Gradient function registered for the `HasteLayerNormIndrnn` op.

    TensorFlow calls this through the `tf.RegisterGradient` registration when
    differentiating through the fused forward op.

    Args:
      op: the forward `HasteLayerNormIndrnn` operation. Its inputs are read
        positionally as (x, W, u, b, gamma, zoneout_mask) and its first two
        outputs as (h, cache).
      *grads: upstream gradients for the op's outputs. Only `grads[0]` (the
        gradient of the first output) is passed to the backward kernel.

    Returns:
      A list containing every tensor returned by
      `LIB.haste_layer_norm_indrnn_grad`, followed by a trailing `None`
      (presumably the gradient slot for the non-differentiable `zoneout_mask`
      input; confirm against the op definition).

    Raises:
      ValueError: if the forward op was created with `training=False`.
    """
    if not op.get_attr('training'):
        raise ValueError(('LayerNormIndRNN can only compute gradients if `training=True` was specified '
                          'during the forward pass.\nFailed op: {}').format(op.name))

    # Pull the forward op's tensors out by position.
    x, W, u, b, gamma, zoneout_mask = (op.inputs[i] for i in range(6))
    h, cache = (op.outputs[i] for i in range(2))

    # Pre-transpose matrices for better performance. Moving the last axis of
    # x to the front gives [C, T, N] if x is time-major [T, N, C], as the
    # forward layer appears to supply (assumption, confirm).
    x_transposed = tf.transpose(x, [2, 0, 1])
    W_transposed = tf.transpose(W, [1, 0])

    input_grads = LIB.haste_layer_norm_indrnn_grad(
        x_transposed, W_transposed, u, b, gamma, zoneout_mask, h, cache, grads[0])
    return [*input_grads, None]
def _get_initializer(initializer):
if not isinstance(initializer, dict):
return initializer
if 'uniform' in initializer:
value = initializer['uniform']
return v1.initializers.random_uniform(-value, value)
if 'normal' in initializer:
value = initializer['normal']
return v1.initializers.truncated_normal(stddev=value)
raise ValueError(f'Unknown initializer {initializer}')
class LayerNormIndRNNLayer(tf.Module):
    """Single layer-normalized IndRNN layer backed by the fused Haste op.

    Owns the trainable variables (kernel, recurrent scale, bias, layer-norm
    gamma) and runs `LIB.haste_layer_norm_indrnn` over a time-major input.
    """

    def __init__(self,
                 num_units,
                 kernel_initializer=None,
                 recurrent_initializer=None,
                 bias_initializer=None,
                 kernel_transform=None,
                 recurrent_transform=None,
                 bias_transform=None,
                 zoneout=0.0,
                 dtype=None,
                 name=None):
        """Configure the layer. Variables are created lazily by `build`.

        Args:
          num_units: int, size of the hidden state.
          kernel_initializer: (optional) initializer or dict spec (see
            `_get_initializer`) for the input kernel. Defaults to
            `glorot_uniform`.
          recurrent_initializer: (optional) initializer or dict spec for the
            recurrent scale vector. Defaults to uniform in [-0.5, 0.5].
          bias_initializer: (optional) initializer or dict spec for the bias.
            Defaults to zeros.
          kernel_transform / recurrent_transform / bias_transform: (optional)
            callables applied to the corresponding variable in `get_weights`.
            Default to the identity.
          zoneout: float, zoneout rate. 0 disables zoneout.
          dtype: (optional) dtype of the variables and zoneout mask. Defaults
            to `tf.float32`.
          name: (optional) module name. It is also used as the variable scope
            name, falling back to 'indrnn_cell'.
        """
        super().__init__(name)
        self.realname = name
        self.num_units = num_units

        identity = lambda x: x
        self.kernel_config = WeightConfig(v1.initializers.glorot_uniform(), None, identity)
        self.recurrent_config = WeightConfig(v1.initializers.random_uniform(-0.5, 0.5), None, identity)
        self.bias_config = WeightConfig(v1.initializers.zeros(), None, identity)

        # Apply user overrides on top of the defaults (presumably `override`
        # ignores None arguments; confirm in weight_config).
        self.kernel_config.override(_get_initializer(kernel_initializer), None, kernel_transform)
        self.recurrent_config.override(_get_initializer(recurrent_initializer), None, recurrent_transform)
        self.bias_config.override(_get_initializer(bias_initializer), None, bias_transform)

        self.zoneout = zoneout
        self.dtype = dtype or tf.float32
        self.kernel = None
        self.recurrent_scale = None
        self.bias = None
        self.gamma = None
        # Not assigned anywhere else in this class; kept for compatibility.
        self.recurrent_bias = None
        self.built = False

    def build(self, shape):
        """Create the layer's variables on the first call. Later calls do nothing.

        Args:
          shape: shape of the layer input. Only the last dimension (the input
            feature size) is read.
        """
        if self.built:
            return

        num_units = self.num_units
        input_size = int(shape[-1])

        kernel_shape = tf.TensorShape([input_size, num_units])
        recurrent_shape = tf.TensorShape([num_units])
        bias_shape = tf.TensorShape([num_units])

        kernel_weights = self.kernel_config.initializer(kernel_shape, dtype=self.dtype)
        recurrent_weights = self.recurrent_config.initializer(recurrent_shape, dtype=self.dtype)
        # Use the layer dtype here as well. Otherwise the bias silently falls
        # back to the initializer's default dtype, which mismatches the other
        # weights when the layer is not float32.
        biases = self.bias_config.initializer(bias_shape, dtype=self.dtype)

        with self.name_scope, v1.variable_scope(self.realname, 'indrnn_cell'):
            self.kernel = v1.get_variable('kernel', initializer=kernel_weights)
            self.recurrent_scale = v1.get_variable('recurrent_scale', initializer=recurrent_weights)
            self.bias = v1.get_variable('bias', initializer=biases)
            # Layer-norm scale with two rows of `num_units` values, initialized
            # to ones. The dtype is explicit because get_variable otherwise
            # defaults to float32.
            self.gamma = v1.get_variable(
                'gamma',
                shape=[2, self.num_units],
                dtype=self.dtype,
                initializer=v1.initializers.ones())
        self.built = True

    def get_weights(self):
        """Return the weights passed to the op, with user transforms applied.

        Returns:
          dict with keys 'kernel', 'recurrent_scale', 'bias' (each passed
          through its configured transform) and 'gamma' (untransformed).
        """
        return {
            'kernel': self.kernel_config.transform(self.kernel),
            'recurrent_scale': self.recurrent_config.transform(self.recurrent_scale),
            'bias': self.bias_config.transform(self.bias),
            'gamma': self.gamma,
        }

    def __call__(self, inputs, sequence_length, training):
        """Run the layer over a time-major batch.

        Args:
          inputs: input tensor whose dimension 0 is time and dimension 1 is
            batch. The last dimension is the input feature size.
          sequence_length: (optional) integer tensor holding one length per
            batch entry, or None when every sequence uses all time steps.
          training: bool, forwarded to the op. Gradients are only available
            when True.

        Returns:
          A tuple `(outputs, state)`. `outputs` is the op's first result
          without its leading time entry. `state` is the per-example entry at
          each sequence's end, or the final entry when `sequence_length` is
          None.
        """
        self.build(inputs.shape)

        shape = tf.shape(inputs)
        time_steps = shape[0]
        batch_size = shape[1]

        # Use an empty zoneout mask if no zoneout is going to be applied.
        # Sadly, we can't pass `None` to the op but at least we won't be wasting
        # memory or bandwidth on this tensor.
        zoneout_mask = tf.zeros([0, 0, 0], dtype=self.dtype)
        if self.zoneout:
            # floor(1 - p + U[0, 1)) is 1 with probability 1 - p, else 0.
            zoneout_mask = 1.0 - self.zoneout
            zoneout_mask += tf.random.uniform([time_steps, batch_size, self.num_units], dtype=self.dtype)
            zoneout_mask = tf.floor(zoneout_mask)

        weights = self.get_weights()
        result, _ = LIB.haste_layer_norm_indrnn(
            inputs,
            weights['kernel'],
            weights['recurrent_scale'],
            weights['bias'],
            weights['gamma'],
            zoneout_mask,
            training=training,
            zoneout_prob=self.zoneout)

        if sequence_length is not None:
            # `result` has a leading entry along time that is dropped from the
            # returned outputs (`result[1:]` below). It appears to be the
            # initial state. The state after `sequence_length` steps therefore
            # sits at index `sequence_length`, not `sequence_length - 1`.
            indices = tf.stack(
                [sequence_length, tf.range(batch_size, dtype=sequence_length.dtype)],
                axis=-1)
            state = tf.gather_nd(result, indices)
        else:
            state = result[-1]

        return result[1:], state
class LayerNormIndRNN(BaseRNN):
    """
    Layer-normalized Independently Recurrent Neural Network (IndRNN) layer.

    Applies layer normalization to the input activations of a standard IndRNN,
    using a fused, GPU-accelerated implementation. Zoneout regularization is
    supported natively.
    """

    def __init__(self, num_units, direction='unidirectional', **kwargs):
        """
        Set up the IndRNN layer.

        Arguments:
          num_units: int, number of units in the IndRNN cell.
          direction: string, either 'unidirectional' or 'bidirectional'.
          **kwargs: Dict, extra keyword arguments (listed below).

        Keyword Arguments:
          kernel_initializer: (optional) initializer for the input weight
            matrix. Defaults to `glorot_uniform`. A dict of the form
            `{'uniform': v}` or `{'normal': stddev}` is also accepted.
          recurrent_initializer: (optional) initializer for the recurrent
            scale weights. Defaults to uniform random in [-0.5, 0.5]. Accepts
            the same dict forms. This differs from the initialization used in
            the original authors' implementation; see
            https://github.com/lmnt-com/haste/issues/7 for details.
          bias_initializer: (optional) initializer for the bias vector.
            Defaults to `zeros`. Accepts the same dict forms.
          kernel_transform: (optional) callable with signature
            `(kernel: Tensor) -> Tensor`, applied to the kernel before use.
            Defaults to the identity function.
          recurrent_transform: (optional) callable with signature
            `(recurrent_scale: Tensor) -> Tensor`, applied to the recurrent
            scale vector before use. Defaults to the identity function.
          bias_transform: (optional) callable with signature
            `(bias: Tensor) -> Tensor`, applied to the bias before use.
            Defaults to the identity function.
          zoneout: (optional) float, zoneout rate for Zoneout regularization.
            Defaults to 0.
          dtype: (optional) data type of this layer. Defaults to `tf.float32`.
          name: (optional) string, name of this layer.
        """
        cell_scope = 'indrnn_cell'
        super().__init__(LayerNormIndRNNLayer, num_units, direction, cell_scope, **kwargs)