Skip to content

Commit a43d3a3

Browse files
andreas-abelcopybara-github
authored andcommitted
Add an option to set the random number generator seed to a specific value.
PiperOrigin-RevId: 605068471 Change-Id: I5ecf6ffcf0179db7ed71d5b80e5d0e9993267795
1 parent 76a403c commit a43d3a3

File tree

4 files changed

+50
-9
lines changed

4 files changed

+50
-9
lines changed

fleetbench/common/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ cc_library(
1111
"@bazel_tools//tools/cpp/runfiles",
1212
"@com_google_absl//absl/flags:flag",
1313
"@com_google_absl//absl/log",
14+
"@com_google_absl//absl/log:check",
1415
"@com_google_absl//absl/strings",
1516
"@com_google_benchmark//:benchmark",
1617
],

fleetbench/common/common.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include "tools/cpp/runfiles/runfiles.h"
2828
#include "absl/flags/flag.h"
29+
#include "absl/log/check.h"
2930
#include "absl/log/log.h"
3031
#include "absl/strings/match.h"
3132
#include "absl/strings/numbers.h"
@@ -37,7 +38,11 @@
3738
using bazel::tools::cpp::runfiles::Runfiles;
3839

3940
ABSL_FLAG(bool, fixed_seed, true,
40-
"Use a fixed seed for random number generation.");
41+
"Use a fixed seed for random number generation. The seed can be "
42+
"specified with the --seed flag; the default is 0. If "
43+
"--fixed_seed=false, random seeds are used.");
44+
ABSL_FLAG(std::optional<int>, seed, {},
45+
"Seed for random number generation. Requires --fixed_seed=true.");
4146

4247
ABSL_FLAG(std::optional<int>, L1_data_size, {},
4348
"Size of the L1 data cache in bytes. Determined automatically if the "
@@ -54,13 +59,21 @@ namespace fleetbench {
5459
using CacheInfo = benchmark::CPUInfo::CacheInfo;
5560

5661
Random& Random::instance() {
57-
static auto* instance = new Random(absl::GetFlag(FLAGS_fixed_seed));
62+
static auto* instance = [] {
63+
if (absl::GetFlag(FLAGS_seed).has_value()) {
64+
CHECK(absl::GetFlag(FLAGS_fixed_seed))
65+
<< "--seed requires --fixed_seed=true";
66+
}
67+
return new Random(absl::GetFlag(FLAGS_fixed_seed),
68+
absl::GetFlag(FLAGS_seed).value_or(0));
69+
}();
5870
return *instance;
5971
}
6072

61-
Random::Random(bool fixed_seed) : fixed_seed_(fixed_seed) {
73+
Random::Random(bool fixed_seed, int seed)
74+
: fixed_seed_(fixed_seed), seed_(seed) {
6275
if (fixed_seed_) {
63-
rng_.seed(0);
76+
rng_.seed(seed_);
6477
} else {
6578
std::random_device rd;
6679
rng_.seed(rd());
@@ -69,7 +82,7 @@ Random::Random(bool fixed_seed) : fixed_seed_(fixed_seed) {
6982

7083
void Random::Reset() {
7184
if (fixed_seed_) {
72-
rng_.seed(0);
85+
rng_.seed(seed_);
7386
}
7487
}
7588

fleetbench/common/common.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,35 @@
1515
#define THIRD_PARTY_FLEETBENCH_COMMON_COMMON_H_
1616

1717
#include <filesystem> // NOLINT
18+
#include <optional>
1819
#include <random>
1920
#include <string>
20-
#include <utility>
2121
#include <vector>
2222

23+
#include "absl/flags/declare.h"
2324
#include "absl/strings/string_view.h"
2425

26+
// Exposed for testing only.
27+
ABSL_DECLARE_FLAG(bool, fixed_seed);
28+
ABSL_DECLARE_FLAG(std::optional<int>, seed);
29+
2530
namespace fleetbench {
2631

2732
// Wrapper around a random number generator.
28-
// If --fixed_seed is true (the default), the seed will be 0 for better
29-
// reproducibility.
33+
// If --fixed_seed is true (the default), the seed will be set to the provided
34+
// value for better reproducibility. Otherwise, the seed is set to random
35+
// values.
3036
class Random {
3137
public:
3238
static Random& instance();
3339
std::default_random_engine& rng() { return rng_; }
3440
void Reset();
3541

3642
private:
37-
explicit Random(bool fixed_seed);
43+
explicit Random(bool fixed_seed, int seed);
3844

3945
const bool fixed_seed_;
46+
const int seed_;
4047
std::default_random_engine rng_;
4148
};
4249

fleetbench/common/common_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,24 @@ TEST(ReadCsvTest, Mixed) {
7676
EXPECT_EQ(result[1][1], "2");
7777
}
7878

79+
TEST(RandomTest, Seed) {
80+
absl::SetFlag(&FLAGS_fixed_seed, true);
81+
absl::SetFlag(&FLAGS_seed, 1);
82+
83+
int random1 = GetRNG()();
84+
Random::instance().Reset();
85+
int random2 = GetRNG()();
86+
EXPECT_EQ(random1, random2);
87+
}
88+
89+
TEST(DeathTest, SeedFlags) {
90+
ASSERT_DEATH(
91+
{
92+
absl::SetFlag(&FLAGS_fixed_seed, false);
93+
absl::SetFlag(&FLAGS_seed, 1);
94+
GetRNG()();
95+
},
96+
"--seed requires --fixed_seed=true");
97+
}
98+
7999
} // namespace fleetbench

0 commit comments

Comments
 (0)