Skip to content

Commit af6b19a

Browse files
authored
Merge pull request #1241 from google/internal-merge
Merges internal changes to OSS.
2 parents 0d6ab09 + de32a1e commit af6b19a

8 files changed

Lines changed: 263 additions & 190 deletions

src/bpe_model.cc

Lines changed: 146 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "bpe_model.h"
1616

17+
#include <cstdint>
1718
#include <functional>
1819
#include <memory>
1920
#include <queue>
@@ -22,8 +23,12 @@
2223
#include <vector>
2324

2425
#include "freelist.h"
26+
#include "model_interface.h"
27+
#include "sentencepiece_model.pb.h"
28+
#include "third_party/absl/base/attributes.h"
2529
#include "third_party/absl/container/flat_hash_map.h"
2630
#include "third_party/absl/random/random.h"
31+
#include "third_party/absl/strings/string_view.h"
2732
#include "util.h"
2833

2934
namespace sentencepiece {
@@ -43,17 +48,52 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
4348
}
4449

4550
struct SymbolPair {
46-
int left; // left index of this pair
47-
int right; // right index of this pair
48-
float score; // score of this pair. large is better.
49-
size_t size; // length of this piece
51+
union {
52+
float score; // score of this pair. large is better.
53+
int32_t int_score;
54+
};
55+
uint32_t left; // left index of this pair
56+
int right; // right index of this pair
57+
unsigned int size; // length of this piece
5058
};
5159

5260
class SymbolPairComparator {
5361
public:
54-
const bool operator()(SymbolPair *h1, SymbolPair *h2) {
55-
return (h1->score < h2->score ||
56-
(h1->score == h2->score && h1->left > h2->left));
62+
ABSL_ATTRIBUTE_ALWAYS_INLINE inline bool operator()(const SymbolPair &h1,
63+
const SymbolPair &h2) {
64+
const int32_t i1 = h1.int_score;
65+
const int32_t i2 = h2.int_score;
66+
67+
// Fast path for the common case where both scores are negative because
68+
// they are log-probabilities.
69+
// Note: we use the fact that IEEE 754 floating point format enables
70+
// to compare the integer representation of negative floats which is
71+
// cheaper than using float comparison. And it works the same way for
72+
// little endian and big endian machines because the IEEE 754 format is
73+
// aligned with the endianness.
74+
// `(i1 & i2) < 0` is an efficient way to check `i1 < 0 && i2 < 0`.
75+
if ((i1 & i2) < 0) {
76+
// For negative floats, their integer representation order is the
77+
// reverse of the float order. That is, for two negative floats f1, f2,
78+
// f1 < f2 iff i1 > i2.
79+
return (i1 > i2) || (i1 == i2 && h1.left > h2.left);
80+
}
81+
82+
// Slow path for uncommon cases (mixed signs or both positive).
83+
// Note: the comparison between NaN and +0 and +1 can be different than
84+
// if we used float numbers but it should not influence the result.
85+
bool score_less;
86+
// If signs are different ((i1 ^ i2) < 0), the negative score is smaller.
87+
if ((i1 ^ i2) < 0) {
88+
score_less = i1 < 0;
89+
} else {
90+
// If signs are the same (and not both negative), they must both be
91+
// non-negative. For non-negative floats, integer order is the same as
92+
// float order.
93+
score_less = i1 < i2;
94+
}
95+
96+
return score_less || (i1 == i2 && h1.left > h2.left);
5797
}
5898
};
5999

@@ -64,50 +104,12 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
64104
absl::string_view piece;
65105
};
66106

67-
using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
68-
SymbolPairComparator>;
69-
Agenda agenda;
70107
std::vector<Symbol> symbols;
71108
symbols.reserve(normalized.size());
72109

73-
// Reverse merge rules.
74-
// key: merged symbol, value: pair of original symbols.
75-
absl::flat_hash_map<absl::string_view,
76-
std::pair<absl::string_view, absl::string_view>>
77-
rev_merge;
78-
79-
// Pre-allocates SymbolPair for efficiency.
80-
constexpr size_t kPreallocateSymbolPairSize = 256;
81-
model::FreeList<SymbolPair> symbol_pair_allocator(kPreallocateSymbolPairSize);
82-
83-
// Lookup new symbol pair at [left, right] and inserts it to agenda.
84-
auto MaybeAddNewSymbolPair = [this, &symbol_pair_allocator, &symbols, &agenda,
85-
&rev_merge](int left, int right) {
86-
if (left == -1 || right == -1 || symbols[left].freeze ||
87-
symbols[right].freeze)
88-
return;
89-
const absl::string_view piece(
90-
symbols[left].piece.data(),
91-
symbols[left].piece.size() + symbols[right].piece.size());
92-
const auto it = pieces_.find(piece);
93-
if (it == pieces_.end()) {
94-
return;
95-
}
96-
auto *h = symbol_pair_allocator.Allocate();
97-
h->left = left;
98-
h->right = right;
99-
h->score = GetScore(it->second);
100-
h->size = piece.size();
101-
agenda.push(h);
102-
103-
// Makes `rev_merge` for resegmentation.
104-
if (IsUnusedInlined(it->second)) {
105-
rev_merge[piece] =
106-
std::make_pair(symbols[left].piece, symbols[right].piece);
107-
}
108-
};
109-
110-
// Splits the input into character sequence
110+
// Splits the input into Symbols doing longest prefix match of the input
111+
// from pieces(type:UNUSED) in the vocabulary.
112+
// Does character splitting as a fallback of longest prefix match.
111113
int index = 0;
112114
while (!normalized.empty()) {
113115
Symbol s;
@@ -124,53 +126,118 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
124126
return {};
125127
}
126128

129+
std::vector<SymbolPair> agenda_vec;
130+
agenda_vec.reserve(symbols.size());
131+
132+
// Reverse merge rules.
133+
// key: merged symbol, value: pair of original symbols.
134+
absl::flat_hash_map<absl::string_view,
135+
std::pair<absl::string_view, absl::string_view>>
136+
rev_merge;
137+
127138
// Lookup all bigrams.
128-
for (size_t i = 1; i < symbols.size(); ++i) {
129-
MaybeAddNewSymbolPair(i - 1, i);
139+
if (symbols.size() > 1) {
140+
int left = 0;
141+
int right = 1;
142+
Symbol *symbol_left = &symbols[left];
143+
Symbol *symbol_right = &symbols[right];
144+
for (; right < symbols.size();
145+
left = right, symbol_left = symbol_right, ++right, ++symbol_right) {
146+
if (symbol_left->freeze || symbol_right->freeze) continue;
147+
const absl::string_view piece(
148+
symbol_left->piece.data(),
149+
symbol_left->piece.size() + symbol_right->piece.size());
150+
const auto it = pieces_.find(piece);
151+
if (it == pieces_.end()) continue;
152+
SymbolPair &h = agenda_vec.emplace_back();
153+
h.left = left;
154+
h.right = right;
155+
h.score = GetScore(it->second);
156+
h.size = piece.size();
157+
158+
// Makes `rev_merge` for resegmentation.
159+
if (IsUnusedInlined(it->second))
160+
rev_merge[piece] =
161+
std::make_pair(symbol_left->piece, symbol_right->piece);
162+
}
130163
}
131164

132-
// BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
133-
absl::BitGen *rand_gen = nullptr;
134-
auto skip_merge = [&]() {
135-
if (alpha <= 0.0) return false;
136-
if (alpha >= 1.0) return true;
137-
if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator();
138-
std::uniform_real_distribution<> gen(0.0, 1.0);
139-
return gen(*rand_gen) < alpha;
165+
using Agenda = std::priority_queue<SymbolPair, std::vector<SymbolPair>,
166+
SymbolPairComparator>;
167+
Agenda agenda(SymbolPairComparator(), std::move(agenda_vec));
168+
// Lookup new symbol pair at [left, right] and inserts it to agenda.
169+
auto MaybeAddNewSymbolPair = [this, &symbols, &agenda, &rev_merge](
170+
int left, int right) {
171+
if (left == -1 || right == -1) return;
172+
const Symbol &left_symbol = symbols[left];
173+
const Symbol &right_symbol = symbols[right];
174+
if (left_symbol.freeze || right_symbol.freeze) return;
175+
const absl::string_view piece(
176+
left_symbol.piece.data(),
177+
left_symbol.piece.size() + right_symbol.piece.size());
178+
const auto it = pieces_.find(piece);
179+
if (it == pieces_.end()) {
180+
return;
181+
}
182+
const int id = it->second;
183+
SymbolPair h;
184+
h.left = left;
185+
h.right = right;
186+
h.score = GetScore(id);
187+
h.size = piece.size();
188+
agenda.push(h);
189+
190+
// Makes `rev_merge` for resegmentation.
191+
if (IsUnusedInlined(id))
192+
rev_merge[piece] = std::make_pair(left_symbol.piece, right_symbol.piece);
140193
};
141194

195+
absl::BitGen *rand_gen = nullptr;
142196
// Main loop.
143197
while (!agenda.empty()) {
144-
SymbolPair *top = agenda.top();
145-
agenda.pop();
146-
147-
// `top` is no longer available.
148-
if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
149-
symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
150-
top->size) {
198+
// Pop the top pair if it is stale.
199+
const SymbolPair &top_ref = agenda.top();
200+
if (symbols[top_ref.left].piece.empty() ||
201+
symbols[top_ref.right].piece.empty() ||
202+
(symbols[top_ref.left].piece.size() +
203+
symbols[top_ref.right].piece.size() !=
204+
top_ref.size)) {
205+
agenda.pop();
151206
continue;
152207
}
153208

154-
// Note that orignal BPE-dropout paper assumes that all merged symbols are
155-
// pre computed, but here we randomly skip merge opration inside this loop.
156-
// This implemenation is theoretically equivalent to the original one.
157-
if (skip_merge()) continue;
209+
SymbolPair top = agenda.top();
210+
agenda.pop();
211+
212+
Symbol &left_symbol = symbols[top.left];
213+
Symbol &right_symbol = symbols[top.right];
214+
215+
// Note that original BPE-dropout paper assumes that all merged symbols are
216+
// pre computed, but here we randomly skip merge operation inside this loop.
217+
// This implementation is theoretically equivalent to the original one.
218+
// BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
219+
if (alpha > 0.0) {
220+
if (alpha >= 1.0) continue;
221+
if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator();
222+
std::uniform_real_distribution<> gen(0.0, 1.0);
223+
if (gen(*rand_gen) < alpha) continue;
224+
}
158225

159226
// Replaces symbols with `top` rule.
160-
symbols[top->left].piece = absl::string_view(
161-
symbols[top->left].piece.data(),
162-
symbols[top->left].piece.size() + symbols[top->right].piece.size());
227+
left_symbol.piece =
228+
absl::string_view(left_symbol.piece.data(),
229+
left_symbol.piece.size() + right_symbol.piece.size());
163230

164231
// Updates prev/next pointers.
165-
symbols[top->left].next = symbols[top->right].next;
166-
if (symbols[top->right].next >= 0) {
167-
symbols[symbols[top->right].next].prev = top->left;
232+
left_symbol.next = right_symbol.next;
233+
if (right_symbol.next >= 0) {
234+
symbols[right_symbol.next].prev = top.left;
168235
}
169-
symbols[top->right].piece = absl::string_view("");
236+
right_symbol.piece = absl::string_view("");
170237

171238
// Adds new symbol pairs which are newly added after symbol replacement.
172-
MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
173-
MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
239+
MaybeAddNewSymbolPair(left_symbol.prev, top.left);
240+
MaybeAddNewSymbolPair(top.left, left_symbol.next);
174241
}
175242

176243
std::function<void(absl::string_view, EncodeResult *)> resegment;
@@ -194,6 +261,7 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
194261
};
195262

196263
EncodeResult output;
264+
output.reserve(symbols.size());
197265
for (int index = 0; index != -1; index = symbols[index].next) {
198266
if (index >= 0 && index < static_cast<int>(symbols.size())) {
199267
resegment(symbols[index].piece, &output);

0 commit comments

Comments
 (0)