1414
1515#include " bpe_model.h"
1616
17+ #include < cstdint>
1718#include < functional>
1819#include < memory>
1920#include < queue>
2223#include < vector>
2324
2425#include " freelist.h"
26+ #include " model_interface.h"
27+ #include " sentencepiece_model.pb.h"
28+ #include " third_party/absl/base/attributes.h"
2529#include " third_party/absl/container/flat_hash_map.h"
2630#include " third_party/absl/random/random.h"
31+ #include " third_party/absl/strings/string_view.h"
2732#include " util.h"
2833
2934namespace sentencepiece {
@@ -43,17 +48,52 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
4348 }
4449
4550 struct SymbolPair {
46- int left; // left index of this pair
47- int right; // right index of this pair
48- float score; // score of this pair. large is better.
49- size_t size; // length of this piece
51+ union {
52+ float score; // score of this pair. large is better.
53+ int32_t int_score;
54+ };
55+ uint32_t left; // left index of this pair
56+ int right; // right index of this pair
57+ unsigned int size; // length of this piece
5058 };
5159
5260 class SymbolPairComparator {
5361 public:
54- const bool operator ()(SymbolPair *h1, SymbolPair *h2) {
55- return (h1->score < h2->score ||
56- (h1->score == h2->score && h1->left > h2->left ));
62+ ABSL_ATTRIBUTE_ALWAYS_INLINE inline bool operator ()(const SymbolPair &h1,
63+ const SymbolPair &h2) {
64+ const int32_t i1 = h1.int_score ;
65+ const int32_t i2 = h2.int_score ;
66+
67+ // Fast path for the common case where both scores are negative because
68+ // they are log-probabilities.
69+ // Note: we use the fact that IEEE 754 floating point format enables
70+ // to compare the integer representation of negative floats which is
71+ // cheaper than using float comparison. And it works the same way for
72+ // little endian and big endian machines because the IEEE 754 format is
73+ // aligned with the endianness.
74+ // `(i1 & i2) < 0` is an efficient way to check `i1 < 0 && i2 < 0`.
75+ if ((i1 & i2) < 0 ) {
76+ // For negative floats, their integer representation order is the
77+ // reverse of the float order. That is, for two negative floats f1, f2,
78+ // f1 < f2 iff i1 > i2.
79+ return (i1 > i2) || (i1 == i2 && h1.left > h2.left );
80+ }
81+
82+ // Slow path for uncommon cases (mixed signs or both positive).
83+ // Note: the comparison between NaN and +0 and +1 can be different than
84+ // if we used float numbers but it should not influence the result.
85+ bool score_less;
86+ // If signs are different ((i1 ^ i2) < 0), the negative score is smaller.
87+ if ((i1 ^ i2) < 0 ) {
88+ score_less = i1 < 0 ;
89+ } else {
90+ // If signs are the same (and not both negative), they must both be
91+ // non-negative. For non-negative floats, integer order is the same as
92+ // float order.
93+ score_less = i1 < i2;
94+ }
95+
96+ return score_less || (i1 == i2 && h1.left > h2.left );
5797 }
5898 };
5999
@@ -64,50 +104,12 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
64104 absl::string_view piece;
65105 };
66106
67- using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
68- SymbolPairComparator>;
69- Agenda agenda;
70107 std::vector<Symbol> symbols;
71108 symbols.reserve (normalized.size ());
72109
73- // Reverse merge rules.
74- // key: merged symbol, value: pair of original symbols.
75- absl::flat_hash_map<absl::string_view,
76- std::pair<absl::string_view, absl::string_view>>
77- rev_merge;
78-
79- // Pre-allocates SymbolPair for efficiency.
80- constexpr size_t kPreallocateSymbolPairSize = 256 ;
81- model::FreeList<SymbolPair> symbol_pair_allocator (kPreallocateSymbolPairSize );
82-
83- // Lookup new symbol pair at [left, right] and inserts it to agenda.
84- auto MaybeAddNewSymbolPair = [this , &symbol_pair_allocator, &symbols, &agenda,
85- &rev_merge](int left, int right) {
86- if (left == -1 || right == -1 || symbols[left].freeze ||
87- symbols[right].freeze )
88- return ;
89- const absl::string_view piece (
90- symbols[left].piece .data (),
91- symbols[left].piece .size () + symbols[right].piece .size ());
92- const auto it = pieces_.find (piece);
93- if (it == pieces_.end ()) {
94- return ;
95- }
96- auto *h = symbol_pair_allocator.Allocate ();
97- h->left = left;
98- h->right = right;
99- h->score = GetScore (it->second );
100- h->size = piece.size ();
101- agenda.push (h);
102-
103- // Makes `rev_merge` for resegmentation.
104- if (IsUnusedInlined (it->second )) {
105- rev_merge[piece] =
106- std::make_pair (symbols[left].piece , symbols[right].piece );
107- }
108- };
109-
110- // Splits the input into character sequence
110+ // Splits the input into Symbols doing longest prefix match of the input
111+ // from pieces(type:UNUSED) in the vocabulary.
112+ // Does character splitting as a fallback of longest prefix match.
111113 int index = 0 ;
112114 while (!normalized.empty ()) {
113115 Symbol s;
@@ -124,53 +126,118 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
124126 return {};
125127 }
126128
129+ std::vector<SymbolPair> agenda_vec;
130+ agenda_vec.reserve (symbols.size ());
131+
132+ // Reverse merge rules.
133+ // key: merged symbol, value: pair of original symbols.
134+ absl::flat_hash_map<absl::string_view,
135+ std::pair<absl::string_view, absl::string_view>>
136+ rev_merge;
137+
127138 // Lookup all bigrams.
128- for (size_t i = 1 ; i < symbols.size (); ++i) {
129- MaybeAddNewSymbolPair (i - 1 , i);
139+ if (symbols.size () > 1 ) {
140+ int left = 0 ;
141+ int right = 1 ;
142+ Symbol *symbol_left = &symbols[left];
143+ Symbol *symbol_right = &symbols[right];
144+ for (; right < symbols.size ();
145+ left = right, symbol_left = symbol_right, ++right, ++symbol_right) {
146+ if (symbol_left->freeze || symbol_right->freeze ) continue ;
147+ const absl::string_view piece (
148+ symbol_left->piece .data (),
149+ symbol_left->piece .size () + symbol_right->piece .size ());
150+ const auto it = pieces_.find (piece);
151+ if (it == pieces_.end ()) continue ;
152+ SymbolPair &h = agenda_vec.emplace_back ();
153+ h.left = left;
154+ h.right = right;
155+ h.score = GetScore (it->second );
156+ h.size = piece.size ();
157+
158+ // Makes `rev_merge` for resegmentation.
159+ if (IsUnusedInlined (it->second ))
160+ rev_merge[piece] =
161+ std::make_pair (symbol_left->piece , symbol_right->piece );
162+ }
130163 }
131164
132- // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
133- absl::BitGen *rand_gen = nullptr ;
134- auto skip_merge = [&]() {
135- if (alpha <= 0.0 ) return false ;
136- if (alpha >= 1.0 ) return true ;
137- if (rand_gen == nullptr ) rand_gen = random::GetRandomGenerator ();
138- std::uniform_real_distribution<> gen (0.0 , 1.0 );
139- return gen (*rand_gen) < alpha;
165+ using Agenda = std::priority_queue<SymbolPair, std::vector<SymbolPair>,
166+ SymbolPairComparator>;
167+ Agenda agenda (SymbolPairComparator (), std::move (agenda_vec));
168+ // Lookup new symbol pair at [left, right] and inserts it to agenda.
169+ auto MaybeAddNewSymbolPair = [this , &symbols, &agenda, &rev_merge](
170+ int left, int right) {
171+ if (left == -1 || right == -1 ) return ;
172+ const Symbol &left_symbol = symbols[left];
173+ const Symbol &right_symbol = symbols[right];
174+ if (left_symbol.freeze || right_symbol.freeze ) return ;
175+ const absl::string_view piece (
176+ left_symbol.piece .data (),
177+ left_symbol.piece .size () + right_symbol.piece .size ());
178+ const auto it = pieces_.find (piece);
179+ if (it == pieces_.end ()) {
180+ return ;
181+ }
182+ const int id = it->second ;
183+ SymbolPair h;
184+ h.left = left;
185+ h.right = right;
186+ h.score = GetScore (id);
187+ h.size = piece.size ();
188+ agenda.push (h);
189+
190+ // Makes `rev_merge` for resegmentation.
191+ if (IsUnusedInlined (id))
192+ rev_merge[piece] = std::make_pair (left_symbol.piece , right_symbol.piece );
140193 };
141194
195+ absl::BitGen *rand_gen = nullptr ;
142196 // Main loop.
143197 while (!agenda.empty ()) {
144- SymbolPair *top = agenda.top ();
145- agenda.pop ();
146-
147- // `top` is no longer available.
148- if (symbols[top->left ].piece .empty () || symbols[top->right ].piece .empty () ||
149- symbols[top->left ].piece .size () + symbols[top->right ].piece .size () !=
150- top->size ) {
198+ // Pop the top pair if it is stale.
199+ const SymbolPair &top_ref = agenda.top ();
200+ if (symbols[top_ref.left ].piece .empty () ||
201+ symbols[top_ref.right ].piece .empty () ||
202+ (symbols[top_ref.left ].piece .size () +
203+ symbols[top_ref.right ].piece .size () !=
204+ top_ref.size )) {
205+ agenda.pop ();
151206 continue ;
152207 }
153208
154- // Note that orignal BPE-dropout paper assumes that all merged symbols are
155- // pre computed, but here we randomly skip merge opration inside this loop.
156- // This implemenation is theoretically equivalent to the original one.
157- if (skip_merge ()) continue ;
209+ SymbolPair top = agenda.top ();
210+ agenda.pop ();
211+
212+ Symbol &left_symbol = symbols[top.left ];
213+ Symbol &right_symbol = symbols[top.right ];
214+
215+ // Note that original BPE-dropout paper assumes that all merged symbols are
216+ // pre computed, but here we randomly skip merge operation inside this loop.
217+ // This implementation is theoretically equivalent to the original one.
218+ // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
219+ if (alpha > 0.0 ) {
220+ if (alpha >= 1.0 ) continue ;
221+ if (rand_gen == nullptr ) rand_gen = random::GetRandomGenerator ();
222+ std::uniform_real_distribution<> gen (0.0 , 1.0 );
223+ if (gen (*rand_gen) < alpha) continue ;
224+ }
158225
159226 // Replaces symbols with `top` rule.
160- symbols[top-> left ] .piece = absl::string_view (
161- symbols[top-> left ] .piece .data (),
162- symbols[top-> left ] .piece .size () + symbols[top-> right ] .piece .size ());
227+ left_symbol .piece =
228+ absl::string_view (left_symbol .piece .data (),
229+ left_symbol .piece .size () + right_symbol .piece .size ());
163230
164231 // Updates prev/next pointers.
165- symbols[top-> left ] .next = symbols[top-> right ] .next ;
166- if (symbols[top-> right ] .next >= 0 ) {
167- symbols[symbols[top-> right ] .next ].prev = top-> left ;
232+ left_symbol .next = right_symbol .next ;
233+ if (right_symbol .next >= 0 ) {
234+ symbols[right_symbol .next ].prev = top. left ;
168235 }
169- symbols[top-> right ] .piece = absl::string_view (" " );
236+ right_symbol .piece = absl::string_view (" " );
170237
171238 // Adds new symbol pairs which are newly added after symbol replacement.
172- MaybeAddNewSymbolPair (symbols[top-> left ] .prev , top-> left );
173- MaybeAddNewSymbolPair (top-> left , symbols[top-> left ] .next );
239+ MaybeAddNewSymbolPair (left_symbol .prev , top. left );
240+ MaybeAddNewSymbolPair (top. left , left_symbol .next );
174241 }
175242
176243 std::function<void (absl::string_view, EncodeResult *)> resegment;
@@ -194,6 +261,7 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
194261 };
195262
196263 EncodeResult output;
264+ output.reserve (symbols.size ());
197265 for (int index = 0 ; index != -1 ; index = symbols[index].next ) {
198266 if (index >= 0 && index < static_cast <int >(symbols.size ())) {
199267 resegment (symbols[index].piece , &output);
0 commit comments