Skip to content

Commit f9ae938

Browse files
committed
updating the HCLG boosting code, debugging
1 parent d87ec23 commit f9ae938

File tree

2 files changed

+95
-81
lines changed

2 files changed

+95
-81
lines changed

egs/wsj/s5/steps/nnet3/decode_compose.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ min_active=200
2222
ivector_scale=1.0
2323
lattice_beam=8.0 # Beam we use in lattice generation.
2424
iter=final
25-
num_threads=1 # if >1, will use gmm-latgen-faster-parallel
26-
use_gpu=false # If true, will use a GPU, with nnet3-latgen-faster-batch.
25+
#num_threads=1 # if >1, will use gmm-latgen-faster-parallel
26+
#use_gpu=false # If true, will use a GPU, with nnet3-latgen-faster-batch.
2727
# In that case it is recommended to set num-threads to a large
2828
# number, e.g. 20 if you have that many free CPU slots on a GPU
2929
# node, and to use a small number of jobs.

src/nnet3bin/nnet3-latgen-faster-compose.cc

+93-79
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "base/timer.h"
3232

3333
#include <fst/compose.h>
34+
#include <fst/rmepsilon.h>
3435
#include <memory>
3536

3637

@@ -154,106 +155,119 @@ int main(int argc, char *argv[]) {
154155

155156
RandomAccessTableReader<fst::VectorFstHolder> boosting_fst_reader(boosting_fst_rspecifier);
156157

157-
// HCLG FST is just one FST, not a table of FSTs.
158-
auto hclg_fst = std::unique_ptr<VectorFst<StdArc>>(fst::ReadFstKaldi(hclg_fst_rxfilename));
158+
// 'hclg_fst' is a single FST.
159+
VectorFst<StdArc> hclg_fst;
160+
{
161+
auto hclg_fst_tmp = std::unique_ptr<Fst<StdArc>>(fst::ReadFstKaldiGeneric(hclg_fst_rxfilename));
162+
hclg_fst = VectorFst<StdArc>(*hclg_fst_tmp); // Fst -> VectorFst, as it has to be MutableFst...
163+
// 'hclg_fst_tmp' is deleted by 'going out of scope' ...
164+
}
159165

160166
// make sure hclg is sorted on olabel
161-
if (hclg_fst->Properties(fst::kOLabelSorted, true) == 0) {
167+
if (hclg_fst.Properties(fst::kOLabelSorted, true) == 0) {
162168
fst::OLabelCompare<StdArc> olabel_comp;
163-
fst::ArcSort(hclg_fst.get(), olabel_comp);
169+
fst::ArcSort(&hclg_fst, olabel_comp);
164170
}
165171

166172
timer.Reset();
167173

168-
{
169-
170-
for (; !feature_reader.Done(); feature_reader.Next()) {
171-
std::string utt = feature_reader.Key();
172-
const Matrix<BaseFloat> &features (feature_reader.Value());
173-
if (features.NumRows() == 0) {
174-
KALDI_WARN << "Zero-length utterance: " << utt;
174+
//// MAIN LOOP ////
175+
for (; !feature_reader.Done(); feature_reader.Next()) {
176+
std::string utt = feature_reader.Key();
177+
const Matrix<BaseFloat> &features (feature_reader.Value());
178+
if (features.NumRows() == 0) {
179+
KALDI_WARN << "Zero-length utterance: " << utt;
180+
num_fail++;
181+
continue;
182+
}
183+
const Matrix<BaseFloat> *online_ivectors = NULL;
184+
const Vector<BaseFloat> *ivector = NULL;
185+
if (!ivector_rspecifier.empty()) {
186+
if (!ivector_reader.HasKey(utt)) {
187+
KALDI_WARN << "No iVector available for utterance " << utt;
175188
num_fail++;
176189
continue;
190+
} else {
191+
ivector = &ivector_reader.Value(utt);
177192
}
178-
const Matrix<BaseFloat> *online_ivectors = NULL;
179-
const Vector<BaseFloat> *ivector = NULL;
180-
if (!ivector_rspecifier.empty()) {
181-
if (!ivector_reader.HasKey(utt)) {
182-
KALDI_WARN << "No iVector available for utterance " << utt;
183-
num_fail++;
184-
continue;
185-
} else {
186-
ivector = &ivector_reader.Value(utt);
187-
}
188-
}
189-
if (!online_ivector_rspecifier.empty()) {
190-
if (!online_ivector_reader.HasKey(utt)) {
191-
KALDI_WARN << "No online iVector available for utterance " << utt;
192-
num_fail++;
193-
continue;
194-
} else {
195-
online_ivectors = &online_ivector_reader.Value(utt);
196-
}
197-
}
198-
199-
// get the boosting graph,
200-
VectorFst<StdArc> boosting_fst;
201-
if (!boosting_fst_reader.HasKey(utt)) {
202-
KALDI_WARN << "No boosting fst for utterance " << utt;
193+
}
194+
if (!online_ivector_rspecifier.empty()) {
195+
if (!online_ivector_reader.HasKey(utt)) {
196+
KALDI_WARN << "No online iVector available for utterance " << utt;
203197
num_fail++;
204198
continue;
205199
} else {
206-
boosting_fst = boosting_fst_reader.Value(utt); // copy,
200+
online_ivectors = &online_ivector_reader.Value(utt);
207201
}
202+
}
208203

209-
timer_compose.Reset();
210-
211-
// make sure boosting graph is sorted on ilabel,
212-
if (boosting_fst.Properties(fst::kILabelSorted, true) == 0) {
213-
fst::ILabelCompare<StdArc> ilabel_comp;
214-
fst::ArcSort(&boosting_fst, ilabel_comp);
215-
}
204+
// get the boosting graph,
205+
VectorFst<StdArc> boosting_fst;
206+
if (!boosting_fst_reader.HasKey(utt)) {
207+
KALDI_WARN << "No boosting fst for utterance " << utt;
208+
num_fail++;
209+
continue;
210+
} else {
211+
boosting_fst = boosting_fst_reader.Value(utt); // copy,
212+
}
216213

217-
// TODO: should we call rmepsilon on boosting_fst ?
214+
timer_compose.Reset();
218215

219-
// run composition (measure time),
220-
VectorFst<StdArc> decode_fst;
221-
fst::Compose(*hclg_fst, boosting_fst, &decode_fst);
216+
// RmEpsilon saved 30% of composition runtime...
217+
// - Note: we are loading 2-state graphs with eps back-link to the initial state.
218+
if (boosting_fst.Properties(fst::kIEpsilons, true) != 0) {
219+
fst::RmEpsilon(&boosting_fst);
220+
}
222221

223-
// TODO: should we sort the 'decode_fst' by isymbols ?
224-
// (we don't do it, as it would take time.
225-
// not sure if decoding would be faster if
226-
// decode_fst was sorted by isymbols)
222+
// make sure boosting graph is sorted on ilabel,
223+
if (boosting_fst.Properties(fst::kILabelSorted, true) == 0) {
224+
fst::ILabelCompare<StdArc> ilabel_comp;
225+
fst::ArcSort(&boosting_fst, ilabel_comp);
226+
}
227227

228-
// Check that composed graph is non-empty,
229-
if (decode_fst.Start() == fst::kNoStateId) {
230-
KALDI_WARN << "Empty 'decode_fst' HCLG for utterance "
231-
<< utt << " (bad boosting graph?)";
232-
num_fail++;
233-
continue;
234-
}
228+
// run composition,
229+
VectorFst<StdArc> decode_fst;
230+
fst::Compose(hclg_fst, boosting_fst, &decode_fst);
235231

236-
elapsed_compose += timer_compose.Elapsed();
237-
238-
DecodableAmNnetSimple nnet_decodable(
239-
decodable_opts, trans_model, am_nnet,
240-
features, ivector, online_ivectors,
241-
online_ivector_period, &compiler);
242-
243-
LatticeFasterDecoder decoder(decode_fst, config);
244-
245-
double like;
246-
if (DecodeUtteranceLatticeFaster(
247-
decoder, nnet_decodable, trans_model, word_syms.get(), utt,
248-
decodable_opts.acoustic_scale, determinize, allow_partial,
249-
&alignment_writer, &words_writer, &compact_lattice_writer,
250-
&lattice_writer,
251-
&like)) {
252-
tot_like += like;
253-
frame_count += nnet_decodable.NumFramesReady();
254-
num_success++;
255-
} else num_fail++;
232+
// check that composed graph is non-empty,
233+
if (decode_fst.Start() == fst::kNoStateId) {
234+
KALDI_WARN << "Empty 'decode_fst' HCLG for utterance "
235+
<< utt << " (bad boosting graph?)";
236+
num_fail++;
237+
continue;
256238
}
239+
240+
elapsed_compose += timer_compose.Elapsed();
241+
242+
DecodableAmNnetSimple nnet_decodable(
243+
decodable_opts, trans_model, am_nnet,
244+
features, ivector, online_ivectors,
245+
online_ivector_period, &compiler);
246+
247+
// Note: decode_fst is VectorFst, not ConstFst.
248+
//
249+
// OpenFst docs say that more specific iterators
250+
// are faster than generic iterators. And HCLG
251+
// is usually loaded for decoding as ConstFst.
252+
//
253+
// auto decode_fst_ = ConstFst<StdArc>(decode_fst);
254+
//
255+
// In this way, I tried to cast VectorFst to ConstFst,
256+
// but this made the decoding 20% slower.
257+
//
258+
LatticeFasterDecoder decoder(decode_fst, config);
259+
260+
double like;
261+
if (DecodeUtteranceLatticeFaster(
262+
decoder, nnet_decodable, trans_model, word_syms.get(), utt,
263+
decodable_opts.acoustic_scale, determinize, allow_partial,
264+
&alignment_writer, &words_writer, &compact_lattice_writer,
265+
&lattice_writer,
266+
&like)) {
267+
tot_like += like;
268+
frame_count += nnet_decodable.NumFramesReady();
269+
num_success++;
270+
} else num_fail++;
257271
}
258272
}
259273

0 commit comments

Comments
 (0)