Skip to content

Commit f2b6724

Browse files
committed
Tool for lattice rescoring by composing with per-utterance FSTs.
1 parent fe230a0 commit f2b6724

File tree

2 files changed

+193
-2
lines changed

2 files changed

+193
-2
lines changed

src/latbin/Makefile

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
2626
lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons \
2727
lattice-arc-post lattice-determinize-non-compact lattice-lmrescore-kaldi-rnnlm \
2828
lattice-lmrescore-pruned lattice-lmrescore-kaldi-rnnlm-pruned lattice-reverse \
29-
lattice-expand lattice-path-cover lattice-add-nnlmscore
29+
lattice-expand lattice-path-cover lattice-add-nnlmscore \
30+
lattice-compose-fsts
3031

3132
OBJFILES =
3233

@@ -36,6 +37,6 @@ TESTFILES =
3637
ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../nnet3/kaldi-nnet3.a \
3738
../cudamatrix/kaldi-cudamatrix.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \
3839
../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \
39-
../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a
40+
../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a
4041

4142
include ../makefiles/default_rules.mk

src/latbin/lattice-compose-fsts.cc

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
// latbin/lattice-compose-fsts.cc
2+
3+
// Copyright 2020 Brno University of Technology; Microsoft Corporation
4+
5+
// See ../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABILITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
21+
#include "base/kaldi-common.h"
22+
#include "util/common-utils.h"
23+
#include "fstext/fstext-lib.h"
24+
#include "lat/kaldi-lattice.h"
25+
#include "lat/lattice-functions.h"
26+
27+
int main(int argc, char *argv[]) {
28+
try {
29+
using namespace kaldi;
30+
typedef kaldi::int32 int32;
31+
typedef kaldi::int64 int64;
32+
using fst::SymbolTable;
33+
using fst::VectorFst;
34+
using fst::StdArc;
35+
36+
const char *usage =
37+
"Composes lattices (in transducer form, as type Lattice) with word-network FSTs.\n"
38+
"Either with a single FST from rxfilename or with per-utterance FSTs from rspecifier.\n"
39+
"The FST weights are interpreted as \"graph weights\" when converted into the Lattice format.\n"
40+
"\n"
41+
"Usage: lattice-compose-fsts [options] lattice-rspecifier1 "
42+
"(fst-rspecifier2|fst-rxfilename2) lattice-wspecifier\n"
43+
" e.g.: lattice-compose-fsts ark:1.lats ark:2.fsts ark:composed.lats\n"
44+
" or: lattice-compose-fsts ark:1.lats G.fst ark:composed.lats\n";
45+
46+
ParseOptions po(usage);
47+
48+
bool write_compact = true;
49+
int32 num_states_cache = 50000;
50+
int32 phi_label = fst::kNoLabel; // == -1
51+
po.Register("write-compact", &write_compact, "If true, write in normal (compact) form.");
52+
po.Register("phi-label", &phi_label, "If >0, the label on backoff arcs of the LM");
53+
po.Register("num-states-cache", &num_states_cache,
54+
"Number of states we cache when mapping LM FST to lattice type. "
55+
"More -> more memory but faster.");
56+
po.Read(argc, argv);
57+
58+
if (po.NumArgs() != 3) {
59+
po.PrintUsage();
60+
exit(1);
61+
}
62+
63+
KALDI_ASSERT(phi_label > 0 || phi_label == fst::kNoLabel); // e.g. 0 not allowed.
64+
65+
std::string lats_rspecifier1 = po.GetArg(1),
66+
arg2 = po.GetArg(2),
67+
lats_wspecifier = po.GetArg(3);
68+
int32 n_done = 0, n_fail = 0;
69+
70+
SequentialLatticeReader lattice_reader1(lats_rspecifier1);
71+
72+
CompactLatticeWriter compact_lattice_writer;
73+
LatticeWriter lattice_writer;
74+
75+
if (write_compact)
76+
compact_lattice_writer.Open(lats_wspecifier);
77+
else
78+
lattice_writer.Open(lats_wspecifier);
79+
80+
if (ClassifyRspecifier(arg2, NULL, NULL) == kNoRspecifier) {
81+
std::string fst_rxfilename = arg2;
82+
VectorFst<StdArc>* fst2 = fst::ReadFstKaldi(fst_rxfilename);
83+
// mapped_fst2 is fst2 interpreted using the LatticeWeight semiring,
84+
// with all the cost on the first member of the pair (since we're
85+
// assuming it's a graph weight).
86+
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
87+
// Make sure fst2 is sorted on ilabel.
88+
fst::ILabelCompare<StdArc> ilabel_comp;
89+
ArcSort(fst2, ilabel_comp);
90+
}
91+
if (phi_label > 0)
92+
PropagateFinal(phi_label, fst2);
93+
94+
fst::CacheOptions cache_opts(true, num_states_cache);
95+
fst::MapFstOptions mapfst_opts(cache_opts);
96+
fst::StdToLatticeMapper<BaseFloat> mapper;
97+
fst::MapFst<StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >
98+
mapped_fst2(*fst2, mapper, mapfst_opts);
99+
100+
for (; !lattice_reader1.Done(); lattice_reader1.Next()) {
101+
std::string key = lattice_reader1.Key();
102+
KALDI_VLOG(1) << "Processing lattice for key " << key;
103+
Lattice lat1 = lattice_reader1.Value();
104+
ArcSort(&lat1, fst::OLabelCompare<LatticeArc>());
105+
Lattice composed_lat;
106+
if (phi_label > 0) PhiCompose(lat1, mapped_fst2, phi_label, &composed_lat);
107+
else Compose(lat1, mapped_fst2, &composed_lat);
108+
if (composed_lat.Start() == fst::kNoStateId) {
109+
KALDI_WARN << "Empty lattice for utterance " << key << " (incompatible LM?)";
110+
n_fail++;
111+
} else {
112+
if (write_compact) {
113+
CompactLattice clat;
114+
ConvertLattice(composed_lat, &clat);
115+
compact_lattice_writer.Write(key, clat);
116+
} else {
117+
lattice_writer.Write(key, composed_lat);
118+
}
119+
n_done++;
120+
}
121+
}
122+
delete fst2;
123+
} else {
124+
// composing with each utterance with different fst,
125+
std::string fst_rspecifier2 = arg2;
126+
RandomAccessTableReader<fst::VectorFstHolder> fst_reader2(fst_rspecifier2);
127+
128+
for (; !lattice_reader1.Done(); lattice_reader1.Next()) {
129+
std::string key = lattice_reader1.Key();
130+
KALDI_VLOG(1) << "Processing lattice for key " << key;
131+
Lattice lat1 = lattice_reader1.Value();
132+
lattice_reader1.FreeCurrent();
133+
134+
if (!fst_reader2.HasKey(key)) {
135+
KALDI_WARN << "Not producing output for utterance " << key
136+
<< " because not present in second table.";
137+
n_fail++;
138+
continue;
139+
}
140+
141+
VectorFst<StdArc> fst2 = fst_reader2.Value(key);
142+
if (fst2.Properties(fst::kILabelSorted, true) == 0) {
143+
// Make sure fst2 is sorted on ilabel.
144+
fst::ILabelCompare<StdArc> ilabel_comp;
145+
fst::ArcSort(&fst2, ilabel_comp);
146+
}
147+
if (phi_label > 0)
148+
PropagateFinal(phi_label, &fst2);
149+
150+
// mapped_fst2 is fst2 interpreted using the LatticeWeight semiring,
151+
// with all the cost on the first member of the pair (since we're
152+
// assuming it's a graph weight).
153+
fst::CacheOptions cache_opts(true, num_states_cache);
154+
fst::MapFstOptions mapfst_opts(cache_opts);
155+
fst::StdToLatticeMapper<BaseFloat> mapper;
156+
fst::MapFst<StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >
157+
mapped_fst2(fst2, mapper, mapfst_opts);
158+
159+
// sort lat1 on olabel.
160+
ArcSort(&lat1, fst::OLabelCompare<LatticeArc>());
161+
162+
Lattice composed_lat;
163+
if (phi_label > 0) PhiCompose(lat1, mapped_fst2, phi_label, &composed_lat);
164+
else Compose(lat1, mapped_fst2, &composed_lat);
165+
166+
if (composed_lat.Start() == fst::kNoStateId) {
167+
KALDI_WARN << "Empty lattice for utterance " << key << " (incompatible LM?)";
168+
n_fail++;
169+
} else {
170+
if (write_compact) {
171+
CompactLattice clat;
172+
ConvertLattice(composed_lat, &clat);
173+
compact_lattice_writer.Write(key, clat);
174+
} else {
175+
lattice_writer.Write(key, composed_lat);
176+
}
177+
n_done++;
178+
}
179+
}
180+
}
181+
182+
KALDI_LOG << "Done " << n_done << " lattices; failed for "
183+
<< n_fail;
184+
185+
return (n_done != 0 ? 0 : 1);
186+
} catch(const std::exception &e) {
187+
std::cerr << e.what();
188+
return -1;
189+
}
190+
}

0 commit comments

Comments
 (0)