Skip to content

Commit ecb0ce7

Browse files
committed
nnet3-latgen-faster-compose, online composition of HCLG graph with boosting graph
1 parent f2b6724 commit ecb0ce7

File tree

4 files changed

+283
-2
lines changed

4 files changed

+283
-2
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ core
3737
.[#]*
3838
*~
3939

40+
# vim autosave and backup files.
41+
*.sw?
42+
4043
# [ecg]tag files.
4144
TAGS
4245
tags

src/latbin/lattice-compose-fsts.cc

+4
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ int main(int argc, char *argv[]) {
8888
fst::ILabelCompare<StdArc> ilabel_comp;
8989
ArcSort(fst2, ilabel_comp);
9090
}
91+
/* // THIS MAKES ALL STATES FINAL STATES! WHY?
9192
if (phi_label > 0)
9293
PropagateFinal(phi_label, fst2);
94+
*/
9395

9496
fst::CacheOptions cache_opts(true, num_states_cache);
9597
fst::MapFstOptions mapfst_opts(cache_opts);
@@ -144,8 +146,10 @@ int main(int argc, char *argv[]) {
144146
fst::ILabelCompare<StdArc> ilabel_comp;
145147
fst::ArcSort(&fst2, ilabel_comp);
146148
}
149+
/* // THIS MAKES ALL STATES FINAL STATES! WHY?
147150
if (phi_label > 0)
148151
PropagateFinal(phi_label, &fst2);
152+
*/
149153

150154
// mapped_fst2 is fst2 interpreted using the LatticeWeight semiring,
151155
// with all the cost on the first member of the pair (since we're

src/nnet3bin/Makefile

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \
2222
nnet3-xvector-compute-batched \
2323
nnet3-latgen-grammar nnet3-compute-batch nnet3-latgen-faster-batch \
2424
nnet3-latgen-faster-lookahead cuda-gpu-available cuda-compiled \
25-
nnet3-latgen-faster-looped-parallel
25+
nnet3-latgen-faster-looped-parallel \
26+
nnet3-latgen-faster-compose
2627

2728
OBJFILES =
2829

@@ -37,7 +38,7 @@ ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \
3738
../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \
3839
../transform/kaldi-transform.a ../ivector/kaldi-ivector.a ../gmm/kaldi-gmm.a \
3940
../tree/kaldi-tree.a ../util/kaldi-util.a \
40-
../matrix/kaldi-matrix.a ../base/kaldi-base.a
41+
../matrix/kaldi-matrix.a ../base/kaldi-base.a
4142

4243

4344
include ../makefiles/default_rules.mk
+273
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
// nnet3bin/nnet3-latgen-faster-compose.cc
2+
3+
// Copyright 2020 Brno University of Technology (author: Karel Vesely)
4+
// 2012-2015 Johns Hopkins University (author: Daniel Povey)
5+
// 2014 Guoguo Chen
6+
7+
// See ../../COPYING for clarification regarding multiple authors
8+
//
9+
// Licensed under the Apache License, Version 2.0 (the "License");
10+
// you may not use this file except in compliance with the License.
11+
// You may obtain a copy of the License at
12+
//
13+
// http://www.apache.org/licenses/LICENSE-2.0
14+
//
15+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18+
// MERCHANTABLITY OR NON-INFRINGEMENT.
19+
// See the Apache 2 License for the specific language governing permissions and
20+
// limitations under the License.
21+
22+
23+
#include "base/kaldi-common.h"
24+
#include "util/common-utils.h"
25+
#include "tree/context-dep.h"
26+
#include "hmm/transition-model.h"
27+
#include "fstext/fstext-lib.h"
28+
#include "decoder/decoder-wrappers.h"
29+
#include "nnet3/nnet-am-decodable-simple.h"
30+
#include "nnet3/nnet-utils.h"
31+
#include "base/timer.h"
32+
33+
#include <fst/compose.h>
34+
#include <memory>
35+
36+
37+
int main(int argc, char *argv[]) {
38+
// note: making this program work with GPUs is as simple as initializing the
39+
// device, but it probably won't make a huge difference in speed for typical
40+
// setups. You should use nnet3-latgen-faster-batch if you want to use a GPU.
41+
try {
42+
using namespace kaldi;
43+
using namespace kaldi::nnet3;
44+
typedef kaldi::int32 int32;
45+
using fst::SymbolTable;
46+
using fst::Fst;
47+
using fst::VectorFst;
48+
using fst::StdArc;
49+
50+
const char *usage =
51+
"Generate lattices using nnet3 neural net model, with on-the-fly composition HCLG o B.\n"
52+
"B is utterance-specific boosting graph, typically a single-state FST with\n"
53+
"all words from words.txt on self loop arcs (then composition is not prohibitevly slow).\n"
54+
"Some word-arcs will have score discounts as costs, to boost them in HMM beam-search.\n"
55+
"Or, by not including words in B, we can remove them from HCLG network.\n"
56+
"Usage: nnet3-latgen-faster-compose [options] <nnet-in> <fst-in> <boost-fsts-rspecifier> <features-rspecifier>"
57+
" <lattice-wspecifier> [ <words-wspecifier> [<alignments-wspecifier>] ]\n"
58+
"See also: nnet3-latgen-faster-parallel, nnet3-latgen-faster-batch\n";
59+
60+
ParseOptions po(usage);
61+
62+
Timer timer, timer_compose;
63+
double elapsed_compose = 0.0;
64+
65+
bool allow_partial = false;
66+
LatticeFasterDecoderConfig config;
67+
NnetSimpleComputationOptions decodable_opts;
68+
69+
std::string word_syms_filename;
70+
std::string ivector_rspecifier,
71+
online_ivector_rspecifier,
72+
utt2spk_rspecifier;
73+
int32 online_ivector_period = 0;
74+
config.Register(&po);
75+
decodable_opts.Register(&po);
76+
po.Register("word-symbol-table", &word_syms_filename,
77+
"Symbol table for words [for debug output]");
78+
po.Register("allow-partial", &allow_partial,
79+
"If true, produce output even if end state was not reached.");
80+
po.Register("ivectors", &ivector_rspecifier, "Rspecifier for "
81+
"iVectors as vectors (i.e. not estimated online); per utterance "
82+
"by default, or per speaker if you provide the --utt2spk option.");
83+
po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for "
84+
"utt2spk option used to get ivectors per speaker");
85+
po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for "
86+
"iVectors estimated online, as matrices. If you supply this,"
87+
" you must set the --online-ivector-period option.");
88+
po.Register("online-ivector-period", &online_ivector_period, "Number of frames "
89+
"between iVectors in matrices supplied to the --online-ivectors "
90+
"option");
91+
92+
po.Read(argc, argv);
93+
94+
if (po.NumArgs() < 4 || po.NumArgs() > 6) {
95+
po.PrintUsage();
96+
exit(1);
97+
}
98+
99+
std::string model_in_filename = po.GetArg(1),
100+
hclg_fst_rxfilename = po.GetArg(2),
101+
boosting_fst_rspecifier = po.GetArg(3),
102+
feature_rspecifier = po.GetArg(4),
103+
lattice_wspecifier = po.GetArg(5),
104+
words_wspecifier = po.GetOptArg(6),
105+
alignment_wspecifier = po.GetOptArg(7);
106+
107+
TransitionModel trans_model;
108+
AmNnetSimple am_nnet;
109+
{
110+
bool binary;
111+
Input ki(model_in_filename, &binary);
112+
trans_model.Read(ki.Stream(), binary);
113+
am_nnet.Read(ki.Stream(), binary);
114+
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
115+
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
116+
CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));
117+
}
118+
119+
bool determinize = config.determinize_lattice;
120+
CompactLatticeWriter compact_lattice_writer;
121+
LatticeWriter lattice_writer;
122+
if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
123+
: lattice_writer.Open(lattice_wspecifier)))
124+
KALDI_ERR << "Could not open table for writing lattices: "
125+
<< lattice_wspecifier;
126+
127+
RandomAccessBaseFloatMatrixReader online_ivector_reader(
128+
online_ivector_rspecifier);
129+
RandomAccessBaseFloatVectorReaderMapped ivector_reader(
130+
ivector_rspecifier, utt2spk_rspecifier);
131+
132+
Int32VectorWriter words_writer(words_wspecifier);
133+
Int32VectorWriter alignment_writer(alignment_wspecifier);
134+
135+
std::unique_ptr<fst::SymbolTable> word_syms = nullptr;
136+
if (word_syms_filename != "") {
137+
word_syms.reset(fst::SymbolTable::ReadText(word_syms_filename));
138+
if (!word_syms)
139+
KALDI_ERR << "Could not read symbol table from file "
140+
<< word_syms_filename;
141+
}
142+
143+
double tot_like = 0.0;
144+
kaldi::int64 frame_count = 0;
145+
int num_success = 0, num_fail = 0;
146+
// this compiler object allows caching of computations across
147+
// different utterances.
148+
CachingOptimizingCompiler compiler(am_nnet.GetNnet(),
149+
decodable_opts.optimize_config);
150+
151+
KALDI_ASSERT(ClassifyRspecifier(hclg_fst_rxfilename, NULL, NULL) == kNoRspecifier);
152+
{
153+
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
154+
155+
RandomAccessTableReader<fst::VectorFstHolder> boosting_fst_reader(boosting_fst_rspecifier);
156+
157+
// HCLG FST is just one FST, not a table of FSTs.
158+
auto hclg_fst = std::unique_ptr<VectorFst<StdArc>>(fst::ReadFstKaldi(hclg_fst_rxfilename));
159+
160+
// make sure hclg is sorted on olabel
161+
if (hclg_fst->Properties(fst::kOLabelSorted, true) == 0) {
162+
fst::OLabelCompare<StdArc> olabel_comp;
163+
fst::ArcSort(hclg_fst.get(), olabel_comp);
164+
}
165+
166+
timer.Reset();
167+
168+
{
169+
170+
for (; !feature_reader.Done(); feature_reader.Next()) {
171+
std::string utt = feature_reader.Key();
172+
const Matrix<BaseFloat> &features (feature_reader.Value());
173+
if (features.NumRows() == 0) {
174+
KALDI_WARN << "Zero-length utterance: " << utt;
175+
num_fail++;
176+
continue;
177+
}
178+
const Matrix<BaseFloat> *online_ivectors = NULL;
179+
const Vector<BaseFloat> *ivector = NULL;
180+
if (!ivector_rspecifier.empty()) {
181+
if (!ivector_reader.HasKey(utt)) {
182+
KALDI_WARN << "No iVector available for utterance " << utt;
183+
num_fail++;
184+
continue;
185+
} else {
186+
ivector = &ivector_reader.Value(utt);
187+
}
188+
}
189+
if (!online_ivector_rspecifier.empty()) {
190+
if (!online_ivector_reader.HasKey(utt)) {
191+
KALDI_WARN << "No online iVector available for utterance " << utt;
192+
num_fail++;
193+
continue;
194+
} else {
195+
online_ivectors = &online_ivector_reader.Value(utt);
196+
}
197+
}
198+
199+
// get the boosting graph,
200+
VectorFst<StdArc> boosting_fst;
201+
if (!boosting_fst_reader.HasKey(utt)) {
202+
KALDI_WARN << "No boosting fst for utterance " << utt;
203+
num_fail++;
204+
continue;
205+
} else {
206+
boosting_fst = boosting_fst_reader.Value(utt); // copy,
207+
}
208+
209+
timer_compose.Reset();
210+
211+
// make sure boosting graph is sorted on ilabel,
212+
if (boosting_fst.Properties(fst::kILabelSorted, true) == 0) {
213+
fst::ILabelCompare<StdArc> ilabel_comp;
214+
fst::ArcSort(&boosting_fst, ilabel_comp);
215+
}
216+
217+
// TODO: should we call rmepsilon on boosting_fst ?
218+
219+
// run composition (measure time),
220+
VectorFst<StdArc> decode_fst;
221+
fst::Compose(*hclg_fst, boosting_fst, &decode_fst);
222+
223+
// TODO: should we sort the 'decode_fst' by isymbols ?
224+
// (we don't do it, as it would take time.
225+
// not sure it decoding would be faster if
226+
// decode_fst was sorted by isymbols)
227+
228+
elapsed_compose += timer_compose.Elapsed();
229+
230+
DecodableAmNnetSimple nnet_decodable(
231+
decodable_opts, trans_model, am_nnet,
232+
features, ivector, online_ivectors,
233+
online_ivector_period, &compiler);
234+
235+
LatticeFasterDecoder decoder(decode_fst, config);
236+
237+
double like;
238+
if (DecodeUtteranceLatticeFaster(
239+
decoder, nnet_decodable, trans_model, word_syms.get(), utt,
240+
decodable_opts.acoustic_scale, determinize, allow_partial,
241+
&alignment_writer, &words_writer, &compact_lattice_writer,
242+
&lattice_writer,
243+
&like)) {
244+
tot_like += like;
245+
frame_count += nnet_decodable.NumFramesReady();
246+
num_success++;
247+
} else num_fail++;
248+
}
249+
}
250+
}
251+
252+
kaldi::int64 input_frame_count =
253+
frame_count * decodable_opts.frame_subsampling_factor;
254+
255+
double elapsed = timer.Elapsed();
256+
KALDI_LOG << "Time taken "<< elapsed
257+
<< "s: real-time factor assuming 100 frames/sec is "
258+
<< (elapsed * 100.0 / input_frame_count);
259+
KALDI_LOG << "Composition time "<< elapsed_compose
260+
<< "s (" << (elapsed_compose * 100.0 / elapsed) << "%)";
261+
KALDI_LOG << "Done " << num_success << " utterances, failed for "
262+
<< num_fail;
263+
KALDI_LOG << "Overall log-likelihood per frame is "
264+
<< (tot_like / frame_count) << " over "
265+
<< frame_count << " frames.";
266+
267+
if (num_success != 0) return 0;
268+
else return 1;
269+
} catch(const std::exception &e) {
270+
std::cerr << e.what();
271+
return -1;
272+
}
273+
}

0 commit comments

Comments
 (0)