2020#include < random>
2121
2222#include " absl/log/log.h"
23+ #include " absl/status/status.h"
2324#include " openfst/lib/mutable-fst.h"
2425#include " openfst/lib/properties.h"
2526#include " openfst/lib/verify.h"
@@ -28,9 +29,22 @@ namespace fst {
2829
2930// Generates a random FST.
3031template <class Arc , class Generate >
31- void RandFst (const int num_random_states, const int num_random_arcs,
32- const int num_random_labels, const float acyclic_prob,
33- Generate generate, uint64_t seed, MutableFst<Arc>* fst) {
32+ absl::Status RandFst (const int num_random_states, const int num_random_arcs,
33+ const int num_random_labels, const float acyclic_prob,
34+ Generate generate, uint64_t seed, MutableFst<Arc>* fst) {
35+ if (num_random_states < 0 ) {
36+ return absl::InvalidArgumentError (" num_random_states must be non-negative" );
37+ }
38+ if (num_random_arcs < 0 ) {
39+ return absl::InvalidArgumentError (" num_random_arcs must be non-negative" );
40+ }
41+ if (num_random_arcs > 0 && num_random_labels <= 0 ) {
42+ return absl::InvalidArgumentError (" num_random_labels must be positive" );
43+ }
44+ if (acyclic_prob < 0.0 || acyclic_prob > 1.0 ) {
45+ return absl::InvalidArgumentError (" acyclic_prob must be in [0, 1]" );
46+ }
47+
3448 using Label = typename Arc::Label;
3549 using StateId = typename Arc::StateId;
3650
@@ -43,49 +57,51 @@ void RandFst(const int num_random_states, const int num_random_arcs,
4357 NUM_DIRECTIONS = 3
4458 };
4559
60+ fst->DeleteStates ();
61+ if (num_random_states == 0 ) return absl::OkStatus ();
62+
4663 std::mt19937_64 rand (seed);
4764 const StateId ns =
4865 std::uniform_int_distribution<>(0 , num_random_states - 1 )(rand);
49- std::uniform_int_distribution<size_t > arc_dist (0 , num_random_arcs - 1 );
50- std::uniform_int_distribution<Label> label_dist (0 , num_random_labels - 1 );
66+ if (ns == 0 ) return absl::OkStatus ();
67+ fst->AddStates (ns);
68+
5169 std::uniform_int_distribution<StateId> ns_dist (0 , ns - 1 );
70+ const StateId start = ns_dist (rand);
71+ fst->SetStart (start);
5272
5373 ArcDirection arc_direction = ANY_DIRECTION;
5474 if (!std::bernoulli_distribution (acyclic_prob)(rand)) {
5575 arc_direction = std::bernoulli_distribution (.5 )(rand) ? FORWARD_DIRECTION
5676 : REVERSE_DIRECTION;
5777 }
5878
59- fst->DeleteStates ();
60-
61- if (ns == 0 ) return ;
62- fst->AddStates (ns);
63-
64- const StateId start = ns_dist (rand);
65- fst->SetStart (start);
66-
67- const size_t na = arc_dist (rand);
68- for (size_t n = 0 ; n < na; ++n) {
69- StateId s = ns_dist (rand);
70- Arc arc;
71- arc.ilabel = label_dist (rand);
72- arc.olabel = label_dist (rand);
73- arc.weight = generate ();
74- arc.nextstate = ns_dist (rand);
75- if ((arc_direction == FORWARD_DIRECTION ||
76- arc_direction == REVERSE_DIRECTION) &&
77- s == arc.nextstate ) {
78- continue ; // Skips self-loops.
79+ if (num_random_arcs > 0 ) {
80+ std::uniform_int_distribution<size_t > arc_dist (0 , num_random_arcs - 1 );
81+ std::uniform_int_distribution<Label> label_dist (0 , num_random_labels - 1 );
82+ const size_t na = arc_dist (rand);
83+ for (size_t n = 0 ; n < na; ++n) {
84+ StateId s = ns_dist (rand);
85+ Arc arc;
86+ arc.ilabel = label_dist (rand);
87+ arc.olabel = label_dist (rand);
88+ arc.weight = generate ();
89+ arc.nextstate = ns_dist (rand);
90+ if ((arc_direction == FORWARD_DIRECTION ||
91+ arc_direction == REVERSE_DIRECTION) &&
92+ s == arc.nextstate ) {
93+ continue ; // Skips self-loops.
94+ }
95+
96+ if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate ) ||
97+ (arc_direction == REVERSE_DIRECTION && s < arc.nextstate )) {
98+ StateId t = s; // reverses arcs
99+ s = arc.nextstate ;
100+ arc.nextstate = t;
101+ }
102+
103+ fst->AddArc (s, arc);
79104 }
80-
81- if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate ) ||
82- (arc_direction == REVERSE_DIRECTION && s < arc.nextstate )) {
83- StateId t = s; // reverses arcs
84- s = arc.nextstate ;
85- arc.nextstate = t;
86- }
87-
88- fst->AddArc (s, arc);
89105 }
90106
91107 const StateId nf = std::uniform_int_distribution<>(0 , ns)(rand);
@@ -107,6 +123,7 @@ void RandFst(const int num_random_states, const int num_random_arcs,
107123 }
108124 mask &= ~kTrinaryProperties ;
109125 fst->SetProperties (props & ~mask, mask);
126+ return absl::OkStatus ();
110127}
111128
112129} // namespace fst
0 commit comments