// nnet3/nnet-utils.h
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// 2016 Daniel Galvez
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET3_NNET_UTILS_H_
#define KALDI_NNET3_NNET_UTILS_H_
#include "base/kaldi-common.h"
#include "util/kaldi-io.h"
#include "matrix/matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/nnet-component-itf.h"
#include "nnet3/nnet-descriptor.h"
#include "nnet-computation-graph.h"
namespace kaldi {
namespace nnet3 {
/// \file nnet3/nnet-utils.h
/// This file contains some miscellaneous functions dealing with class Nnet.
/// Given an nnet and a computation request, this function works out which
/// requested outputs in the computation request are computable; it outputs this
/// information as a vector "is_computable" indexed by the same indexes as
/// request.outputs.
/// It does this by executing some of the early stages of compilation.
void EvaluateComputationRequest(
    const Nnet &nnet,
    const ComputationRequest &request,
    std::vector<std::vector<bool> > *is_computable);
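// Example usage (an illustrative sketch, not part of the original header;
// 'nnet' and 'request' are hypothetical variables):
//
//   std::vector<std::vector<bool> > is_computable;
//   EvaluateComputationRequest(nnet, request, &is_computable);
//   // is_computable[i] corresponds to request.outputs[i].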
/// returns the number of output nodes of this nnet.
int32 NumOutputNodes(const Nnet &nnet);
/// returns the number of input nodes of this nnet.
int32 NumInputNodes(const Nnet &nnet);
/// Calls SetZero (with the given is_gradient parameter) on all updatable
/// components of the nnet.
void SetZero(bool is_gradient,
             Nnet *nnet);
/// Calls PerturbParams (with the given stddev) on all updatable components of
/// the nnet.
void PerturbParams(BaseFloat stddev,
                   Nnet *nnet);
/// Returns dot product between two networks of the same structure (calls the
/// DotProduct functions of the Updatable components and sums up the return
/// values).
BaseFloat DotProduct(const Nnet &nnet1,
                     const Nnet &nnet2);
/// Returns dot products between two networks of the same structure (calls the
/// DotProduct functions of the Updatable components and fills in the output
/// vector).
void ComponentDotProducts(const Nnet &nnet1,
                          const Nnet &nnet2,
                          VectorBase<BaseFloat> *dot_prod);
/// This function is for printing, to a string, a vector with one element per
/// updatable component of the nnet (e.g. the output of ComponentDotProducts),
/// in a human readable way, as [ component-name1:number1
/// component-name2:number2 ... ].
std::string PrintVectorPerUpdatableComponent(const Nnet &nnet,
                                             const VectorBase<BaseFloat> &vec);
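// Example usage (an illustrative sketch; 'nnet1' and 'nnet2' are hypothetical
// networks with the same structure):
//
//   Vector<BaseFloat> dot_prod(NumUpdatableComponents(nnet1));
//   ComponentDotProducts(nnet1, nnet2, &dot_prod);
//   KALDI_LOG << PrintVectorPerUpdatableComponent(nnet1, dot_prod);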
/// This function returns true if the nnet has the following properties:
/// It has an output called "output" (other outputs are allowed but may be
/// ignored).
/// It has an input called "input", and possibly an extra input called
/// "ivector", but no other inputs.
/// There are probably some other properties that we really ought to
/// be checking, and we may add more later on.
bool IsSimpleNnet(const Nnet &nnet);
/// Zeroes the component stats in all nonlinear components in the nnet.
void ZeroComponentStats(Nnet *nnet);
/// ComputeNnetContext computes the left-context and right-context of a nnet.
/// The nnet must satisfy IsSimpleNnet(nnet).
///
/// It does this by constructing a ComputationRequest with a certain number of
/// input frames available and working out for which output frames a computation
/// is possible. It does the same after shifting the time index of the output to
/// all values 0, 1, ... n-1, where n is the value returned by nnet.Modulus().
/// Then it returns the largest left context and the largest right context that
/// it infers from any of these computation requests.
void ComputeSimpleNnetContext(const Nnet &nnet,
                              int32 *left_context,
                              int32 *right_context);
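// Example usage (an illustrative sketch; 'nnet' is a hypothetical network
// satisfying IsSimpleNnet()):
//
//   int32 left_context, right_context;
//   ComputeSimpleNnetContext(nnet, &left_context, &right_context);
//   // Output frame t then depends on input frames t - left_context
//   // through t + right_context.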
/// Sets the underlying learning rate for all the components in the nnet to this
/// value. This will get multiplied by the individual learning-rate-factors to
/// produce the actual learning rates.
void SetLearningRate(BaseFloat learning_rate,
                     Nnet *nnet);
/// Scales the actual learning rate for all the components in the nnet
/// by this factor.
void ScaleLearningRate(BaseFloat learning_rate_scale,
                       Nnet *nnet);
/// Sets the actual learning rates for all the updatable components in the
/// neural net to the values in the 'learning_rates' vector
/// (one for each updatable component).
void SetLearningRates(const Vector<BaseFloat> &learning_rates,
                      Nnet *nnet);
/// Get the learning rates for all the updatable components in the neural net
/// (the output must have dim equal to the number of updatable components).
void GetLearningRates(const Nnet &nnet,
                      Vector<BaseFloat> *learning_rates);
/// Scales the nnet parameters and stats by this scale.
void ScaleNnet(BaseFloat scale, Nnet *nnet);
/// Scales the parameters of each of the updatable components.
/// Here, scales is a vector of size equal to the number of updatable
/// components.
void ScaleNnetComponents(const Vector<BaseFloat> &scales,
                         Nnet *nnet);
/// Does *dest += alpha * src (affects nnet parameters and
/// stored stats).
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);
/// Does *dest += alpha * src for updatable components (affects nnet parameters),
/// and *dest += scale * src for other components (affects stored stats).
/// Here, alphas is a vector of size equal to the number of updatable components.
void AddNnetComponents(const Nnet &src, const Vector<BaseFloat> &alphas,
                       BaseFloat scale, Nnet *dest);
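// Example usage (an illustrative sketch; 'src' and 'dest' are hypothetical
// networks of the same structure):
//
//   Vector<BaseFloat> alphas(NumUpdatableComponents(*dest));
//   alphas.Set(0.5);  // weight every updatable component's parameters by 0.5
//   AddNnetComponents(src, alphas, 0.5, dest);  // stored stats scaled by 0.5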
/// Returns the total number of parameters in the updatable components of
/// the nnet.
int32 NumParameters(const Nnet &src);
/// Copies the nnet parameters to *params, whose dimension must
/// be equal to NumParameters(src).
void VectorizeNnet(const Nnet &src,
                   VectorBase<BaseFloat> *params);
/// Copies the parameters from 'params' to *dest. The dimension of 'params'
/// must be equal to NumParameters(*dest).
void UnVectorizeNnet(const VectorBase<BaseFloat> &params,
                     Nnet *dest);
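// Example usage (an illustrative sketch; 'nnet' is a hypothetical network):
// copy the parameters out, modify them, and write them back.
//
//   Vector<BaseFloat> params(NumParameters(nnet));
//   VectorizeNnet(nnet, &params);
//   params.Scale(0.5);  // e.g. halve all parameters
//   UnVectorizeNnet(params, &nnet);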
/// Returns the number of updatable components in the nnet.
int32 NumUpdatableComponents(const Nnet &dest);
/// Convert all components of type RepeatedAffineComponent or
/// NaturalGradientRepeatedAffineComponent to BlockAffineComponent in nnet.
void ConvertRepeatedToBlockAffine(Nnet *nnet);
/// This function returns various info about the neural net.
/// If the nnet satisfies IsSimpleNnet(nnet), the info includes
/// "left-context=5\nright-context=3\n...". The info also includes
/// the output of nnet.Info().
/// This is modeled after the info that AmNnetSimple returns in its
/// Info() function (we need this in the CTC code).
std::string NnetInfo(const Nnet &nnet);
/// This function sets the dropout proportion in all components of type
/// DropoutComponent to the value 'dropout_proportion', and sets their
/// dropout-per-frame flag to 'dropout_per_frame'.
void SetDropoutProportion(BaseFloat dropout_proportion,
                          bool dropout_per_frame,
                          Nnet *nnet);
/// This function finds a list of components that are never used, and outputs
/// the integer component indexes (you can use these to index
/// nnet.GetComponentNames() to get their names).
void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components);
/// This function finds a list of nodes that are never used to compute any
/// output, and outputs the integer node indexes (you can use these to index
/// nnet.GetNodeNames() to get their names).
void FindOrphanNodes(const Nnet &nnet, std::vector<int32> *nodes);
/**
ReadEditConfig() reads a file with a similar-looking format to the config file
read by Nnet::ReadConfig(), but this consists of a sequence of operations to
perform on an existing network, mostly modifying components. It's one
"directive" (i.e. command) per line.
The following describes the allowed commands. Note: all patterns are like
UNIX globbing patterns where the only metacharacter is '*', representing zero
or more characters.
\verbatim
  convert-to-fixed-affine [name=<name-pattern>]
    Converts the given affine components to FixedAffineComponent which is not updatable.
  remove-orphan-nodes [remove-orphan-inputs=(true|false)]
    Removes orphan nodes (that are never used to compute anything). Note:
    remove-orphan-inputs defaults to false.
  remove-orphan-components
    Removes orphan components (those that are never used by any node).
  remove-orphans [remove-orphan-inputs=(true|false)]
    The same as calling remove-orphan-nodes and then remove-orphan-components.
  set-learning-rate [name=<name-pattern>] learning-rate=<learning-rate>
    Sets the learning rate for any updatable nodes matching the name pattern.
  rename-node old-name=<old-name> new-name=<new-name>
    Renames a node; this is a surface renaming that does not affect the structure
    (for structural changes, use the regular config file format, not the
    edits-config). This is mostly useful for outputs, e.g. when doing
    multilingual experiments.
  remove-output-nodes name=<name-pattern>
    Removes a subset of output nodes, those matching the pattern. You cannot
    remove internal nodes directly; instead you should use the command
    'remove-orphans'.
  set-dropout-proportion [name=<name-pattern>] proportion=<dropout-proportion> dropout-per-frame=<dropout-per-frame>
    Sets the dropout rates for any components of type DropoutComponent whose
    names match the given <name-pattern> (e.g. lstm*). <name-pattern> defaults to "*".
\endverbatim
*/
void ReadEditConfig(std::istream &config_file, Nnet *nnet);
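// Example usage (an illustrative sketch; the config text and 'nnet' are
// hypothetical):
//
//   std::istringstream edits("set-learning-rate name=lstm* learning-rate=0.001\n"
//                            "remove-orphans\n");
//   ReadEditConfig(edits, &nnet);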
} // namespace nnet3
} // namespace kaldi
#endif