@@ -26,6 +26,7 @@ struct Line {
   const DictEntry* entry;
   size_t end_pos;
   double weight;
+  size_t text_hash;  // for dedup
 
   static const Line kEmpty;
 
@@ -68,7 +69,7 @@ struct Line {
   }
 };
 
-const Line Line::kEmpty{nullptr, nullptr, 0, 0.0};
+const Line Line::kEmpty{nullptr, nullptr, 0, 0.0, 0};
 
 inline static Grammar* create_grammar(Config* config) {
   if (auto* grammar = Grammar::Require("grammar")) {
@@ -251,4 +252,96 @@ an<Sentence> Poet::MakeSentence(const WordGraph& graph,
       graph, total_length, preceding_text);
 }
 
+// Make up to `max_sentences` sentences using beam search with dynamic
+// programming over the word graph.
+//
+// Unlike MakeSentence, there is no strategy parameter: candidates are
+// always evaluated with the grammar.
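+//
+// A minimal usage sketch (hypothetical caller; in practice `graph` and
+// `total_length` come from segmentation elsewhere in rime):
+//
+//   deque<an<Sentence>> sentences =
+//       poet->MakeSentences(graph, total_length, /*preceding_text=*/"", 3);
+//   for (const an<Sentence>& sentence : sentences) {
+//     // rank or display the candidate sentences
+//   }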
+deque<an<Sentence>> Poet::MakeSentences(const WordGraph& graph,
+                                        size_t total_length,
+                                        const string& preceding_text,
+                                        size_t max_sentences) {
+  size_t beam_width =
+      max_sentences * 3;  // allow more possibilities during the search
+  using State = std::list<Line>;
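+  // Maps an end position to the beam of partial lines reaching it; each
+  // beam stays sorted by weight, descending, and is capped at beam_width.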
+  map<size_t, State> states;
+  states[0].push_back(Line::kEmpty);
+  for (const auto& sv : graph) {
+    size_t start_pos = sv.first;
+    if (states.find(start_pos) == states.end())
+      continue;  // no partial line reaches this start position
+
+    const auto& source_state = states[start_pos];
+    for (const auto& ev : sv.second) {
+      size_t end_pos = ev.first;
+      if (start_pos == 0 && end_pos == total_length)
+        continue;  // exclude single entries that span the whole input
+      const DictEntryList& entries = ev.second;
+      bool is_rear = end_pos == total_length;  // entry would end the sentence
+      auto& target_state = states[end_pos];
+
+      for (const auto& source_line : source_state) {
+        for (const auto& entry : entries) {
+          const string& context =
+              source_line.empty() ? preceding_text : source_line.context();
+          double weight = source_line.weight +
+                          Grammar::Evaluate(context, entry->text, entry->weight,
+                                            is_rear, grammar_.get());
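+          // Extend a 31-based polynomial hash over the line's accumulated
+          // text; it identifies lines that spell the same string below.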
+          size_t new_hash = source_line.text_hash;
+          for (char c : entry->text) {
+            // Cast avoids sign extension on non-ASCII (e.g. UTF-8) bytes.
+            new_hash = new_hash * 31 + static_cast<unsigned char>(c);
+          }
+          Line new_line{&source_line, entry.get(), end_pos, weight, new_hash};
+
+          // Dedupe: among lines spelling the same text, keep only the
+          // best-weighted one.
+          auto dup = std::find_if(
+              target_state.begin(), target_state.end(),
+              [&](const Line& l) { return l.text_hash == new_line.text_hash; });
+          if (dup != target_state.end()) {
+            if (new_line.weight > dup->weight) {
+              target_state.erase(dup);
+            } else {
+              continue;
+            }
+          }
+
+          // Insert while keeping the beam sorted by weight, descending;
+          // drop the worst line once the beam exceeds beam_width.
+          auto it = std::find_if(
+              target_state.begin(), target_state.end(),
+              [&](const Line& l) { return l.weight < new_line.weight; });
+          target_state.insert(it, new_line);
+          if (target_state.size() > beam_width)
+            target_state.pop_back();
+        }
+      }
+    }
+  }
+
+  auto found = states.find(total_length);
+  if (found == states.end() || found->second.empty())
+    return {};
+
+  deque<an<Sentence>> results;
+  size_t i = 0;
+  double last_weight = found->second.front().weight;
+  for (const auto& candidate : found->second) {
+    i++;
+    if (i > max_sentences)
+      break;
+    double cur_weight = candidate.weight;
+    // Heuristic cutoff: also stop early once a candidate's weight falls
+    // more than 5%, in relative terms, below the previous one, i.e. it is
+    // considerably rarer than the sentences already collected.
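+    // E.g. with last_weight == -20.0, any candidate weighted below -21.0
+    // ends the scan (hypothetical log-weight values, for illustration only).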
+    if (fabs(cur_weight - last_weight) / fabs(last_weight) > 0.05) {
+      break;
+    }
+    last_weight = cur_weight;
+    auto sentence = New<Sentence>(language_);
+    for (const auto* c : candidate.components()) {
+      if (!c->entry)
+        continue;  // skip the sentinel head (Line::kEmpty)
+      sentence->Extend(*c->entry, c->end_pos, c->weight);
+    }
+    results.emplace_back(std::move(sentence));
+  }
+  return results;
+}
+
 }  // namespace rime