Skip to content

Commit 3993c81

Browse files
committed
uses abseil random library
1 parent aacaa0a commit 3993c81

11 files changed

Lines changed: 42 additions & 224 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,3 +75,5 @@ _sentencepiece.*.so
7575
third_party/abseil-cpp
7676

7777
python/sentencepiece
78+
79+
third_party/absl

data/gen_spec_parser.pl

Lines changed: 0 additions & 175 deletions
This file was deleted.

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ list(APPEND SPM_LIBS absl::flags_parse)
1818
list(APPEND SPM_LIBS absl::log)
1919
list(APPEND SPM_LIBS absl::log_initialize)
2020
list(APPEND SPM_LIBS absl::check)
21+
list(APPEND SPM_LIBS absl::random_random)
2122

2223
if (SPM_PROTOBUF_PROVIDER STREQUAL "internal")
2324
set(SPM_PROTO_HDRS builtin_pb/sentencepiece.pb.h)

src/bpe_model.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323

2424
#include "freelist.h"
2525
#include "third_party/absl/container/flat_hash_map.h"
26+
#include "third_party/absl/random/random.h"
2627
#include "util.h"
2728

2829
namespace sentencepiece {
@@ -129,7 +130,7 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
129130
}
130131

131132
// BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
132-
std::mt19937 *rand_gen = nullptr;
133+
absl::BitGen *rand_gen = nullptr;
133134
auto skip_merge = [&]() {
134135
if (alpha <= 0.0) return false;
135136
if (alpha >= 1.0) return true;

src/sentencepiece_trainer.cc

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -316,8 +316,7 @@ util::Status SentencePieceNormalizer::Load(absl::string_view filename) {
316316
util::Status SentencePieceNormalizer::LoadFromSerializedProto(
317317
absl::string_view serialized) {
318318
auto model_proto = std::make_unique<ModelProto>();
319-
RET_CHECK(
320-
model_proto->ParseFromArray(serialized.data(), serialized.size()));
319+
RET_CHECK(model_proto->ParseFromArray(serialized.data(), serialized.size()));
321320
return Load(std::move(model_proto));
322321
}
323322

src/trainer_interface.cc

Lines changed: 14 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#include "sentencepiece_processor.h"
3131
#include "sentencepiece_trainer.h"
3232
#include "third_party/absl/container/flat_hash_map.h"
33+
#include "third_party/absl/random/random.h"
3334
#include "third_party/absl/strings/numbers.h"
3435
#include "third_party/absl/strings/str_cat.h"
3536
#include "third_party/absl/strings/str_format.h"
@@ -77,7 +78,7 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
7778
#undef CHECK_RANGE
7879

7980
RET_CHECK(trainer_spec.input_sentence_size() <= 0 ||
80-
trainer_spec.input_sentence_size() > 100);
81+
trainer_spec.input_sentence_size() > 100);
8182

8283
RET_CHECK(!trainer_spec.unk_piece().empty());
8384
RET_CHECK(!trainer_spec.bos_piece().empty());
@@ -87,7 +88,7 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
8788
if (SentencePieceTrainer::GetPretokenizerForTraining() ||
8889
!trainer_spec.pretokenization_delimiter().empty()) {
8990
RET_CHECK(trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
90-
trainer_spec.model_type() == TrainerSpec::BPE)
91+
trainer_spec.model_type() == TrainerSpec::BPE)
9192
<< "PretokenizerForTraining is only supported in UNIGRAM or BPE mode.";
9293
}
9394

@@ -307,7 +308,7 @@ bool TrainerInterface::IsValidSentencePiece(
307308
}
308309

309310
template <typename T>
310-
void AddDPNoise(const TrainerSpec &trainer_spec, std::mt19937 *generator,
311+
void AddDPNoise(const TrainerSpec &trainer_spec, absl::BitGen *generator,
311312
T *to_update) {
312313
if (trainer_spec.differential_privacy_noise_level() > 0) {
313314
std::normal_distribution<float> dist(
@@ -327,13 +328,12 @@ util::Status TrainerInterface::LoadSentences() {
327328
RET_CHECK(sentences_.empty());
328329
RET_CHECK(required_chars_.empty());
329330
RET_CHECK(trainer_spec_.input_format().empty() ||
330-
trainer_spec_.input_format() == "text" ||
331-
trainer_spec_.input_format() == "tsv")
331+
trainer_spec_.input_format() == "text" ||
332+
trainer_spec_.input_format() == "tsv")
332333
<< "Supported formats are 'text' and 'tsv'.";
333334

334-
RET_CHECK(
335-
(sentence_iterator_ != nullptr && trainer_spec_.input().empty()) ||
336-
(sentence_iterator_ == nullptr && !trainer_spec_.input().empty()))
335+
RET_CHECK((sentence_iterator_ != nullptr && trainer_spec_.input().empty()) ||
336+
(sentence_iterator_ == nullptr && !trainer_spec_.input().empty()))
337337
<< "SentenceIterator and trainer_spec.input() must be exclusive.";
338338

339339
RET_CHECK(
@@ -487,7 +487,7 @@ util::Status TrainerInterface::LoadSentences() {
487487
auto *generator = random::GetRandomGenerator();
488488
for (size_t i = n; i < sentences_.size(); i += num_workers) {
489489
AddDPNoise<int64_t>(trainer_spec_, generator,
490-
&(sentences_[i].second));
490+
&(sentences_[i].second));
491491
}
492492
});
493493
}
@@ -581,9 +581,8 @@ util::Status TrainerInterface::LoadSentences() {
581581

582582
if (trainer_spec_.model_type() != TrainerSpec::WORD &&
583583
trainer_spec_.model_type() != TrainerSpec::CHAR) {
584-
RET_CHECK_LE(
585-
static_cast<int>(required_chars_.size() + meta_pieces_.size()),
586-
trainer_spec_.vocab_size())
584+
RET_CHECK_LE(static_cast<int>(required_chars_.size() + meta_pieces_.size()),
585+
trainer_spec_.vocab_size())
587586
<< "Vocabulary size is smaller than required_chars. "
588587
<< trainer_spec_.vocab_size() << " vs "
589588
<< required_chars_.size() + meta_pieces_.size() << ". "
@@ -619,7 +618,7 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
619618

620619
model_proto->Clear();
621620

622-
#define CHECK_PIECE(piece) \
621+
#define CHECK_PIECE(piece) \
623622
RET_CHECK(string_util::IsStructurallyValid(piece)); \
624623
RET_CHECK(!piece.empty()); \
625624
RET_CHECK(dup.insert(piece).second) << piece << " is already defined";
@@ -656,17 +655,15 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
656655
if (!trainer_spec_.hard_vocab_limit() ||
657656
trainer_spec_.model_type() == TrainerSpec::CHAR) {
658657
RET_CHECK_GE(trainer_spec_.vocab_size(), model_proto->pieces_size());
659-
RET_CHECK_GE(trainer_spec_.vocab_size(),
660-
static_cast<int32_t>(dup.size()));
658+
RET_CHECK_GE(trainer_spec_.vocab_size(), static_cast<int32_t>(dup.size()));
661659
model_proto->mutable_trainer_spec()->set_vocab_size(
662660
model_proto->pieces_size());
663661
} else {
664662
RET_CHECK_EQ(trainer_spec_.vocab_size(), model_proto->pieces_size())
665663
<< absl::StrFormat(
666664
"Vocabulary size too high (%d). Please set it to a value <= %d.",
667665
trainer_spec_.vocab_size(), model_proto->pieces_size());
668-
RET_CHECK_EQ(trainer_spec_.vocab_size(),
669-
static_cast<int32_t>(dup.size()));
666+
RET_CHECK_EQ(trainer_spec_.vocab_size(), static_cast<int32_t>(dup.size()));
670667
}
671668

672669
// Saves self-testing data.

src/unicode_script.cc

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.!
1414

15-
#include <unordered_map>
16-
1715
#include "third_party/absl/container/flat_hash_map.h"
1816
#include "unicode_script.h"
1917
#include "unicode_script_map.h"

src/unigram_model.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include <vector>
2626

2727
#include "third_party/absl/container/flat_hash_map.h"
28+
#include "third_party/absl/random/random.h"
2829
#include "third_party/absl/strings/str_split.h"
2930
#include "third_party/absl/strings/string_view.h"
3031
#include "util.h"

src/unigram_model_trainer.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ class BoundedPriorityQueue {
123123
}
124124

125125
size_t max_capacity_;
126-
std::unordered_map<std::string, int64_t> data_;
126+
absl::flat_hash_map<std::string, int64_t> data_;
127127
};
128128
} // namespace
129129

src/util.cc

Lines changed: 10 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -18,24 +18,18 @@
1818
#include <cstddef>
1919
#include <memory>
2020

21+
#include "third_party/absl/random/random.h"
22+
2123
namespace sentencepiece {
2224

2325
namespace {
24-
constexpr uint32_t kDefaultSeed = static_cast<uint32_t>(-1);
26+
static constexpr uint32_t kDefaultSeed = static_cast<uint32_t>(-1);
2527
static std::atomic<uint32_t> g_seed = kDefaultSeed;
2628
} // namespace
2729

28-
void SetRandomGeneratorSeed(uint32_t seed) {
29-
if (seed != kDefaultSeed) g_seed.store(seed);
30-
}
30+
void SetRandomGeneratorSeed(uint32_t seed) { g_seed.store(seed); }
3131

32-
uint32_t GetRandomGeneratorSeed() {
33-
try {
34-
return g_seed == kDefaultSeed ? std::random_device{}() : g_seed.load();
35-
} catch (...) {
36-
return g_seed.load();
37-
}
38-
}
32+
uint32_t GetRandomGeneratorSeed() { return g_seed.load(); }
3933

4034
namespace {
4135
std::shared_ptr<const std::string> *GetSharedDataDir() {
@@ -174,18 +168,14 @@ std::string UnicodeTextToUTF8(const UnicodeText &utext) {
174168
} // namespace string_util
175169

176170
namespace random {
177-
std::mt19937 *GetRandomGenerator() {
171+
absl::BitGen *GetRandomGenerator() {
178172
// Thread-locals occupy stack space in every thread ever created by the
179173
// program, even if that thread never uses the thread-local variable.
180-
//
181-
// https://maskray.me/blog/2021-02-14-all-about-thread-local-storage
182-
//
183-
// sizeof(std::mt19937) is several kilobytes, so it is safer to put that on
184-
// the heap, leaving only a pointer to it in thread-local storage. This must
185-
// be a unique_ptr, not a raw pointer, so that the generator is not leaked on
186-
// thread exit.
187174
thread_local static auto mt =
188-
std::make_unique<std::mt19937>(GetRandomGeneratorSeed());
175+
GetRandomGeneratorSeed() == kDefaultSeed
176+
? std::make_unique<absl::BitGen>()
177+
: std::make_unique<absl::BitGen>(
178+
std::seed_seq{GetRandomGeneratorSeed()});
189179
return mt.get();
190180
}
191181
} // namespace random

0 commit comments

Comments
 (0)