Skip to content

Commit 77a8feb

Browse files
committed
feat(poet): make multiple sentences
1 parent 422ad2d commit 77a8feb

3 files changed

Lines changed: 100 additions & 1 deletion

File tree

src/rime/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <string>
1919
#include <utility>
2020
#include <vector>
21+
#include <deque>
2122
#define BOOST_BIND_NO_PLACEHOLDERS
2223
#include <boost/signals2/connection.hpp>
2324
#include <boost/signals2/signal.hpp>
@@ -37,6 +38,7 @@
3738

3839
namespace rime {
3940

41+
using std::deque;
4042
using std::function;
4143
using std::list;
4244
using std::make_pair;

src/rime/gear/poet.cc

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct Line {
2626
const DictEntry* entry;
2727
size_t end_pos;
2828
double weight;
29+
size_t text_hash; // for dedup
2930

3031
static const Line kEmpty;
3132

@@ -68,7 +69,7 @@ struct Line {
6869
}
6970
};
7071

71-
const Line Line::kEmpty{nullptr, nullptr, 0, 0.0};
72+
const Line Line::kEmpty{nullptr, nullptr, 0, 0.0, 0};
7273

7374
inline static Grammar* create_grammar(Config* config) {
7475
if (auto* grammar = Grammar::Require("grammar")) {
@@ -251,4 +252,96 @@ an<Sentence> Poet::MakeSentence(const WordGraph& graph,
251252
graph, total_length, preceding_text);
252253
}
253254

255+
// Make `max_sentences` sentences using beam search and dp on word graph.
256+
//
257+
// There is no strategy because it unconditionally use grammar.
258+
deque<an<Sentence>> Poet::MakeSentences(const WordGraph& graph,
259+
size_t total_length,
260+
const string& preceding_text,
261+
size_t max_sentences) {
262+
size_t beam_width =
263+
max_sentences * 3; // allow more possibilities during search
264+
using State = std::list<Line>;
265+
map<int, State> states;
266+
states[0].push_back(Line::kEmpty);
267+
for (const auto& sv : graph) {
268+
size_t start_pos = sv.first;
269+
if (states.find(start_pos) == states.end())
270+
continue;
271+
272+
const auto& source_state = states[start_pos];
273+
for (const auto& ev : sv.second) {
274+
size_t end_pos = ev.first;
275+
if (start_pos == 0 && end_pos == total_length)
276+
continue;
277+
const DictEntryList& entries = ev.second;
278+
bool is_rear = end_pos == total_length;
279+
auto& target_state = states[end_pos];
280+
281+
for (const auto& source_line : source_state) {
282+
for (const auto& entry : entries) {
283+
const string& context =
284+
source_line.empty() ? preceding_text : source_line.context();
285+
double weight = source_line.weight +
286+
Grammar::Evaluate(context, entry->text, entry->weight,
287+
is_rear, grammar_.get());
288+
size_t new_hash = source_line.text_hash;
289+
for (char c : entry->text) {
290+
new_hash = new_hash * 31 + c;
291+
}
292+
Line new_line{&source_line, entry.get(), end_pos, weight, new_hash};
293+
294+
// dedup by text hash
295+
auto dup = std::find_if(
296+
target_state.begin(), target_state.end(),
297+
[&](const Line& l) { return l.text_hash == new_line.text_hash; });
298+
if (dup != target_state.end()) {
299+
if (new_line.weight > dup->weight) {
300+
target_state.erase(dup);
301+
} else {
302+
continue;
303+
}
304+
}
305+
306+
// insert in descending order of weight
307+
auto it = std::find_if(
308+
target_state.begin(), target_state.end(),
309+
[&](const Line& l) { return l.weight < new_line.weight; });
310+
target_state.insert(it, new_line);
311+
if (target_state.size() > beam_width)
312+
target_state.pop_back();
313+
}
314+
}
315+
}
316+
}
317+
318+
auto found = states.find(total_length);
319+
if (found == states.end() || found->second.empty())
320+
return {};
321+
322+
deque<an<Sentence>> results;
323+
size_t i = 0;
324+
double last_weight = found->second.front().weight;
325+
for (const auto& candidate : found->second) {
326+
i++;
327+
if (i > max_sentences)
328+
break;
329+
double cur_weight = candidate.weight;
330+
// idea: if the current sentence is, on average, not too rare when
331+
// compared to last sentence, we should consider it too
332+
if (fabs(cur_weight - last_weight) / fabs(last_weight) > 0.05) {
333+
break;
334+
}
335+
last_weight = cur_weight;
336+
auto sentence = New<Sentence>(language_);
337+
for (const auto* c : candidate.components()) {
338+
if (!c->entry)
339+
continue;
340+
sentence->Extend(*c->entry, c->end_pos, c->weight);
341+
}
342+
results.emplace_back(sentence);
343+
}
344+
return results;
345+
}
346+
254347
} // namespace rime

src/rime/gear/poet.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class Poet {
3939
an<Sentence> MakeSentence(const WordGraph& graph,
4040
size_t total_length,
4141
const string& preceding_text);
42+
// Builds several candidate sentences for the given word graph and returns
// them in a deque. `count` is the upper bound on how many sentences are
// returned; the definition in poet.cc names this parameter `max_sentences`.
deque<an<Sentence>> MakeSentences(const WordGraph& graph,
43+
size_t total_length,
44+
const string& preceding_text,
45+
size_t count);
4246

4347
template <class TranslatorT>
4448
an<Translation> ContextualWeighted(an<Translation> translation,

0 commit comments

Comments
 (0)