Skip to content

Commit e444fc7

Browse files
committed
rand-fst: Avoid UB in random number generation
std::uniform_int_distribution(a, b) defines a closed interval and requires a <= b. Ensure it is never constructed with a > b. https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.html Change return type of RandFst from void to absl::Status. This should eventually be moved to absl random functions, but that will most likely require the same checks. It's difficult to make sense of `Uniform(0, -1)`. Tested: scripts/test_bazel.sh scripts/test_cmake.sh PiperOrigin-RevId: 882610853
1 parent cd5a5ce commit e444fc7

File tree

3 files changed

+62
-41
lines changed

3 files changed

+62
-41
lines changed

openfst/test/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ cc_library(
191191
"//openfst/lib:weight",
192192
"@com_google_absl//absl/flags:flag",
193193
"@com_google_absl//absl/log",
194+
"@com_google_absl//absl/log:check",
194195
"@com_google_googletest//:gtest",
195196
],
196197
)
@@ -331,6 +332,7 @@ cc_library(
331332
deps = [
332333
"//openfst/lib",
333334
"@com_google_absl//absl/log",
335+
"@com_google_absl//absl/status",
334336
],
335337
)
336338

@@ -2181,8 +2183,10 @@ cc_test(
21812183
"//openfst/lib:weight",
21822184
"@com_google_absl//absl/flags:flag",
21832185
"@com_google_absl//absl/log",
2186+
"@com_google_absl//absl/log:check",
21842187
"@com_google_absl//absl/log:flags",
21852188
"@com_google_absl//absl/memory",
2189+
"@com_google_absl//absl/status",
21862190
"@com_google_googletest//:gtest",
21872191
],
21882192
)

openfst/test/algo_test.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "gtest/gtest.h"
3333
#include "absl/flags/declare.h"
3434
#include "absl/flags/flag.h"
35+
#include "absl/log/check.h"
3536
#include "absl/log/log.h"
3637
#include "openfst/lib/arc-map.h"
3738
#include "openfst/lib/arc.h"
@@ -1383,10 +1384,9 @@ class AlgoTester {
13831384
new UnweightedTester<Arc>(zero_fst_, one_fst_, univ_fst_, seed));
13841385
}
13851386

1386-
void MakeRandFst(MutableFst<Arc>* fst) {
1387-
RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
1388-
kNumRandomLabels, kAcyclicProb, generate_,
1389-
rand_(), fst);
1387+
absl::Status MakeRandFst(MutableFst<Arc>* fst) {
1388+
return RandFst(kNumRandomStates, kNumRandomArcs, kNumRandomLabels,
1389+
kAcyclicProb, generate_, rand_(), fst);
13901390
}
13911391

13921392
void Test() {
@@ -1397,9 +1397,9 @@ class AlgoTester {
13971397
VectorFst<Arc> T1;
13981398
VectorFst<Arc> T2;
13991399
VectorFst<Arc> T3;
1400-
MakeRandFst(&T1);
1401-
MakeRandFst(&T2);
1402-
MakeRandFst(&T3);
1400+
CHECK_OK(MakeRandFst(&T1));
1401+
CHECK_OK(MakeRandFst(&T2));
1402+
CHECK_OK(MakeRandFst(&T3));
14031403
weighted_tester_->Test(T1, T2, T3);
14041404

14051405
VectorFst<Arc> A1(T1);

openfst/test/rand-fst.h

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <random>
2121

2222
#include "absl/log/log.h"
23+
#include "absl/status/status.h"
2324
#include "openfst/lib/mutable-fst.h"
2425
#include "openfst/lib/properties.h"
2526
#include "openfst/lib/verify.h"
@@ -28,9 +29,22 @@ namespace fst {
2829

2930
// Generates a random FST.
3031
template <class Arc, class Generate>
31-
void RandFst(const int num_random_states, const int num_random_arcs,
32-
const int num_random_labels, const float acyclic_prob,
33-
Generate generate, uint64_t seed, MutableFst<Arc>* fst) {
32+
absl::Status RandFst(const int num_random_states, const int num_random_arcs,
33+
const int num_random_labels, const float acyclic_prob,
34+
Generate generate, uint64_t seed, MutableFst<Arc>* fst) {
35+
if (num_random_states < 0) {
36+
return absl::InvalidArgumentError("num_random_states must be non-negative");
37+
}
38+
if (num_random_arcs < 0) {
39+
return absl::InvalidArgumentError("num_random_arcs must be non-negative");
40+
}
41+
if (num_random_arcs > 0 && num_random_labels <= 0) {
42+
return absl::InvalidArgumentError("num_random_labels must be positive");
43+
}
44+
if (acyclic_prob < 0.0 || acyclic_prob > 1.0) {
45+
return absl::InvalidArgumentError("acyclic_prob must be in [0, 1]");
46+
}
47+
3448
using Label = typename Arc::Label;
3549
using StateId = typename Arc::StateId;
3650

@@ -43,49 +57,51 @@ void RandFst(const int num_random_states, const int num_random_arcs,
4357
NUM_DIRECTIONS = 3
4458
};
4559

60+
fst->DeleteStates();
61+
if (num_random_states == 0) return absl::OkStatus();
62+
4663
std::mt19937_64 rand(seed);
4764
const StateId ns =
4865
std::uniform_int_distribution<>(0, num_random_states - 1)(rand);
49-
std::uniform_int_distribution<size_t> arc_dist(0, num_random_arcs - 1);
50-
std::uniform_int_distribution<Label> label_dist(0, num_random_labels - 1);
66+
if (ns == 0) return absl::OkStatus();
67+
fst->AddStates(ns);
68+
5169
std::uniform_int_distribution<StateId> ns_dist(0, ns - 1);
70+
const StateId start = ns_dist(rand);
71+
fst->SetStart(start);
5272

5373
ArcDirection arc_direction = ANY_DIRECTION;
5474
if (!std::bernoulli_distribution(acyclic_prob)(rand)) {
5575
arc_direction = std::bernoulli_distribution(.5)(rand) ? FORWARD_DIRECTION
5676
: REVERSE_DIRECTION;
5777
}
5878

59-
fst->DeleteStates();
60-
61-
if (ns == 0) return;
62-
fst->AddStates(ns);
63-
64-
const StateId start = ns_dist(rand);
65-
fst->SetStart(start);
66-
67-
const size_t na = arc_dist(rand);
68-
for (size_t n = 0; n < na; ++n) {
69-
StateId s = ns_dist(rand);
70-
Arc arc;
71-
arc.ilabel = label_dist(rand);
72-
arc.olabel = label_dist(rand);
73-
arc.weight = generate();
74-
arc.nextstate = ns_dist(rand);
75-
if ((arc_direction == FORWARD_DIRECTION ||
76-
arc_direction == REVERSE_DIRECTION) &&
77-
s == arc.nextstate) {
78-
continue; // Skips self-loops.
79+
if (num_random_arcs > 0) {
80+
std::uniform_int_distribution<size_t> arc_dist(0, num_random_arcs - 1);
81+
std::uniform_int_distribution<Label> label_dist(0, num_random_labels - 1);
82+
const size_t na = arc_dist(rand);
83+
for (size_t n = 0; n < na; ++n) {
84+
StateId s = ns_dist(rand);
85+
Arc arc;
86+
arc.ilabel = label_dist(rand);
87+
arc.olabel = label_dist(rand);
88+
arc.weight = generate();
89+
arc.nextstate = ns_dist(rand);
90+
if ((arc_direction == FORWARD_DIRECTION ||
91+
arc_direction == REVERSE_DIRECTION) &&
92+
s == arc.nextstate) {
93+
continue; // Skips self-loops.
94+
}
95+
96+
if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate) ||
97+
(arc_direction == REVERSE_DIRECTION && s < arc.nextstate)) {
98+
StateId t = s; // reverses arcs
99+
s = arc.nextstate;
100+
arc.nextstate = t;
101+
}
102+
103+
fst->AddArc(s, arc);
79104
}
80-
81-
if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate) ||
82-
(arc_direction == REVERSE_DIRECTION && s < arc.nextstate)) {
83-
StateId t = s; // reverses arcs
84-
s = arc.nextstate;
85-
arc.nextstate = t;
86-
}
87-
88-
fst->AddArc(s, arc);
89105
}
90106

91107
const StateId nf = std::uniform_int_distribution<>(0, ns)(rand);
@@ -107,6 +123,7 @@ void RandFst(const int num_random_states, const int num_random_arcs,
107123
}
108124
mask &= ~kTrinaryProperties;
109125
fst->SetProperties(props & ~mask, mask);
126+
return absl::OkStatus();
110127
}
111128

112129
} // namespace fst

0 commit comments

Comments
 (0)