Skip to content

Commit 0fb3b96

Browse files
committed
feat(script_translator): allow multiple sentence candidates
1 parent 77a8feb commit 0fb3b96

5 files changed

Lines changed: 84 additions & 23 deletions

File tree

src/rime/gear/poet.cc

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ an<Sentence> Poet::MakeSentence(const WordGraph& graph,
258258
deque<an<Sentence>> Poet::MakeSentences(const WordGraph& graph,
259259
size_t total_length,
260260
const string& preceding_text,
261-
size_t max_sentences) {
261+
size_t max_sentences,
262+
double cutoff_threshold) {
262263
size_t beam_width =
263264
max_sentences * 3; // allow more possibilities during search
264265
using State = std::list<Line>;
@@ -320,17 +321,24 @@ deque<an<Sentence>> Poet::MakeSentences(const WordGraph& graph,
320321
return {};
321322

322323
deque<an<Sentence>> results;
323-
size_t i = 0;
324-
double last_weight = found->second.front().weight;
325-
for (const auto& candidate : found->second) {
326-
i++;
327-
if (i > max_sentences)
328-
break;
324+
double last_weight;
325+
double acceleration = 1.0 - 1.0 / (double)max_sentences;
326+
auto iter = found->second.begin();
327+
for (size_t i = 0; iter != found->second.end() && i < max_sentences;
328+
++i, ++iter) {
329+
const auto& candidate = *iter;
329330
double cur_weight = candidate.weight;
330-
// idea: if the current sentence is, on average, not too rare when
331-
// compared to last sentence, we should consider it too
332-
if (fabs(cur_weight - last_weight) / fabs(last_weight) > 0.05) {
333-
break;
331+
if (i > 0) {
332+
// idea: if the current sentence is, on average, not too rare when
333+
// compared to last sentence, we should consider it too
334+
if (fabs(cur_weight - last_weight) / fabs(last_weight) >
335+
cutoff_threshold) {
336+
break;
337+
}
338+
// but don't deviate too far from the first weight by accelerating
339+
// the cutoff threshold. cutoff_threshold becomes
340+
// ~0.36*cutoff_threshold after N candidates are added.
341+
cutoff_threshold *= acceleration;
334342
}
335343
last_weight = cur_weight;
336344
auto sentence = New<Sentence>(language_);

src/rime/gear/poet.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class Poet {
4242
deque<an<Sentence>> MakeSentences(const WordGraph& graph,
4343
size_t total_length,
4444
const string& preceding_text,
45-
size_t count);
45+
size_t count,
46+
double cutoff_threshold);
4647

4748
template <class TranslatorT>
4849
an<Translation> ContextualWeighted(an<Translation> translation,

src/rime/gear/script_translator.cc

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,18 @@ class ScriptTranslation : public Translation {
115115
Poet* poet,
116116
const string& input,
117117
size_t start,
118-
size_t end_of_input)
118+
size_t end_of_input,
119+
int max_sentences,
120+
double sentence_cutoff_threshold)
119121
: translator_(translator),
120122
poet_(poet),
121123
start_(start),
122124
end_of_input_(end_of_input),
123125
syllabifier_(
124126
New<ScriptSyllabifier>(translator, corrector, input, start)),
125-
enable_correction_(corrector) {
127+
enable_correction_(corrector),
128+
max_sentences_(max_sentences),
129+
sentence_cutoff_threshold_(sentence_cutoff_threshold) {
126130
set_exhausted(true);
127131
}
128132
bool Evaluate(Dictionary* dict, UserDictionary* user_dict);
@@ -136,7 +140,11 @@ class ScriptTranslation : public Translation {
136140
template <class QueryResult>
137141
void EnrollEntries(map<int, DictEntryList>& entries_by_end_pos,
138142
const an<QueryResult>& query_result);
143+
WordGraph PrepareForMakingSentence(Dictionary* dict,
144+
UserDictionary* user_dict);
139145
an<Sentence> MakeSentence(Dictionary* dict, UserDictionary* user_dict);
146+
deque<an<Sentence>> MakeSentences(Dictionary* dict,
147+
UserDictionary* user_dict);
140148

141149
ScriptTranslator* translator_;
142150
Poet* poet_;
@@ -146,7 +154,7 @@ class ScriptTranslation : public Translation {
146154

147155
an<DictEntryCollector> phrase_;
148156
an<UserDictEntryCollector> user_phrase_;
149-
an<Sentence> sentence_;
157+
deque<an<Sentence>> sentences_;
150158

151159
an<Phrase> candidate_ = nullptr;
152160
size_t candidate_index_ = 0;
@@ -163,6 +171,8 @@ class ScriptTranslation : public Translation {
163171

164172
const size_t max_corrections_ = 4;
165173
size_t correction_count_ = 0;
174+
int max_sentences_ = 1;
175+
double sentence_cutoff_threshold_ = 0.1;
166176

167177
bool enable_correction_;
168178
};
@@ -210,8 +220,9 @@ an<Translation> ScriptTranslator::Query(const string& input,
210220

211221
size_t end_of_input = engine_->context()->input().length();
212222
// the translator should survive translations it creates
213-
auto result = New<ScriptTranslation>(this, corrector_.get(), poet_.get(),
214-
input, segment.start, end_of_input);
223+
auto result = New<ScriptTranslation>(
224+
this, corrector_.get(), poet_.get(), input, segment.start, end_of_input,
225+
max_sentences_, sentence_cutoff_threshold_);
215226
if (!result || !result->Evaluate(
216227
dict_.get(), enable_user_dict ? user_dict_.get() : NULL)) {
217228
return nullptr;
@@ -484,7 +495,15 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {
484495
// make sentences when there is no exact-matching phrase candidate
485496
if (has_at_least_two_syllables && !has_reliable_phrase &&
486497
!has_reliable_user_phrase) {
487-
sentence_ = MakeSentence(dict, user_dict);
498+
if (max_sentences_ > 1)
499+
sentences_ = MakeSentences(dict, user_dict);
500+
else if (max_sentences_) {
501+
auto sentence = MakeSentence(dict, user_dict);
502+
if (sentence)
503+
sentences_ = {sentence};
504+
else
505+
sentences_.clear();
506+
}
488507
}
489508

490509
return !CheckEmpty();
@@ -501,7 +520,8 @@ bool ScriptTranslation::Next() {
501520
case kUninitialized:
502521
break;
503522
case kSentence:
504-
sentence_.reset();
523+
if (!sentences_.empty())
524+
sentences_.pop_front();
505525
break;
506526
case kUserPhrase: {
507527
UserDictEntryIterator& uter(user_phrase_iter_->second);
@@ -575,9 +595,9 @@ bool ScriptTranslation::PrepareCandidate() {
575595
candidate_ = nullptr;
576596
return false;
577597
}
578-
if (sentence_) {
598+
if (!sentences_.empty()) {
579599
candidate_source_ = kSentence;
580-
candidate_ = sentence_;
600+
candidate_ = sentences_[0];
581601
return true;
582602
}
583603
const size_t full_code_length = end_of_input_ - start_;
@@ -675,8 +695,9 @@ void ScriptTranslation::EnrollEntries(
675695
}
676696
}
677697

678-
an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
679-
UserDictionary* user_dict) {
698+
WordGraph ScriptTranslation::PrepareForMakingSentence(
699+
Dictionary* dict,
700+
UserDictionary* user_dict) {
680701
const int kMaxSyllablesForUserPhraseQuery = 5;
681702
const auto& syllable_graph = syllabifier_->syllable_graph();
682703
WordGraph graph;
@@ -691,6 +712,29 @@ an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
691712
EnrollEntries(same_start_pos, dict->Lookup(syllable_graph, x.first,
692713
&translator_->blacklist()));
693714
}
715+
return graph;
716+
}
717+
718+
deque<an<Sentence>> ScriptTranslation::MakeSentences(
719+
Dictionary* dict,
720+
UserDictionary* user_dict) {
721+
const auto& syllable_graph = syllabifier_->syllable_graph();
722+
WordGraph graph = PrepareForMakingSentence(dict, user_dict);
723+
auto sentences =
724+
poet_->MakeSentences(graph, syllable_graph.interpreted_length,
725+
translator_->GetPrecedingText(start_),
726+
max_sentences_, sentence_cutoff_threshold_);
727+
for (auto& sentence : sentences) {
728+
sentence->Offset(start_);
729+
sentence->set_syllabifier(syllabifier_);
730+
}
731+
return sentences;
732+
}
733+
734+
an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
735+
UserDictionary* user_dict) {
736+
const auto& syllable_graph = syllabifier_->syllable_graph();
737+
WordGraph graph = PrepareForMakingSentence(dict, user_dict);
694738
if (auto sentence =
695739
poet_->MakeSentence(graph, syllable_graph.interpreted_length,
696740
translator_->GetPrecedingText(start_))) {

src/rime/gear/translator_commons.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,18 @@ TranslatorOptions::TranslatorOptions(const Ticket& ticket) {
125125
config->GetBool(ticket.name_space + "/strict_spelling", &strict_spelling_);
126126
config->GetDouble(ticket.name_space + "/initial_quality",
127127
&initial_quality_);
128+
config->GetInt(ticket.name_space + "/max_sentences", &max_sentences_);
129+
max_sentences_ = std::min(std::max(1, max_sentences_), 100);
130+
config->GetDouble(ticket.name_space + "/sentence_cutoff_threshold",
131+
&sentence_cutoff_threshold_);
132+
128133
preedit_formatter_.Load(
129134
config->GetList(ticket.name_space + "/preedit_format"));
130135
comment_formatter_.Load(
131136
config->GetList(ticket.name_space + "/comment_format"));
132137
user_dict_disabling_patterns_.Load(
133138
config->GetList(ticket.name_space + "/disable_user_dict_for_patterns"));
139+
134140
string tag;
135141
if (config->GetString(ticket.name_space + "/tag", &tag)) {
136142
// replace the first tag, and understand /tags as extra tags

src/rime/gear/translator_commons.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ class TranslatorOptions {
176176
bool enable_completion_ = true;
177177
bool strict_spelling_ = false;
178178
double initial_quality_ = 0.;
179+
int max_sentences_ = 1;
180+
double sentence_cutoff_threshold_ = 0.1;
179181
Projection preedit_formatter_;
180182
Projection comment_formatter_;
181183
Patterns user_dict_disabling_patterns_;

0 commit comments

Comments
 (0)