Skip to content

Commit 6e32bac

Browse files
committed
Additional cleanup of RNG checkpoint code.
1 parent f4f5ae8 commit 6e32bac

File tree

1 file changed

+14
-47
lines changed

1 file changed

+14
-47
lines changed

src/utils/random.cpp

Lines changed: 14 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -101,36 +101,17 @@ bool save_rng_to_checkpoint(persist& p, lbann_comm* comm, bool is_distributed)
101101
save_rng_state(rng_name, get_data_seq_generator());
102102

103103
rng_name = dirname + "/EL_generator";
104-
std::ofstream rng_EL(rng_name);
105-
if (!rng_EL) {
106-
LBANN_ERROR("Failed to open ", rng_name);
107-
}
108-
rng_EL << El::Generator();
109-
rng_EL.close();
104+
save_rng_state(rng_name, El::Generator());
110105
}
111106

112107
for (int i = 0; i < get_num_io_generators(); i++) {
113-
rng_name = dirname + "/rng_io_generator_" + rank_in_trainer + "_t" +
114-
std::to_string(i);
115-
std::ofstream rng_io(rng_name);
116-
if (!rng_io) {
117-
LBANN_ERROR("Failed to open ", rng_name);
118-
}
119-
rng_name = dirname + "/rng_fast_io_generator_" + rank_in_trainer + "_t" +
120-
std::to_string(i);
121-
std::ofstream rng_fast_io(rng_name);
122-
if (!rng_fast_io) {
123-
LBANN_ERROR("Failed to open ", rng_name);
124-
}
125-
126108
locked_io_rng_ref io_rng = set_io_generators_local_index(i);
127-
// save_rng_state(rng_name);
128-
// save_rng_state(rng_name);
129-
rng_io << get_io_generator();
130-
rng_fast_io << get_fast_io_generator();
131-
132-
rng_io.close();
133-
rng_fast_io.close();
109+
save_rng_state(dirname + "/rng_io_generator_" + rank_in_trainer + "_t" +
110+
std::to_string(i),
111+
get_io_generator());
112+
save_rng_state(dirname + "/rng_fast_io_generator_" + rank_in_trainer +
113+
"_t" + std::to_string(i),
114+
get_fast_io_generator());
134115
}
135116

136117
rng_name = dirname + "/rng_generator_" + rank_in_trainer;
@@ -143,7 +124,6 @@ bool save_rng_to_checkpoint(persist& p, lbann_comm* comm, bool is_distributed)
143124
save_rng_state(rng_name, get_ltfb_generator());
144125

145126
#if not defined(LBANN_DETERMINISTIC) && defined(_OPENMP)
146-
// #ifdef _OPENMP
147127
#pragma omp parallel private(rng_name)
148128
{
149129
rng_name = dirname + "/rng_OMP_generator_" + rank_in_trainer + "_" +
@@ -186,11 +166,7 @@ bool load_rng_from_checkpoint(persist& p, const lbann_comm* comm)
186166
load_rng_state(rng_name, get_data_seq_generator());
187167

188168
rng_name = dirname + "/EL_generator";
189-
std::ifstream rng_EL(rng_name);
190-
if (!rng_EL) {
191-
LBANN_ERROR("Failed to open ", rng_name);
192-
}
193-
rng_EL >> El::Generator();
169+
load_rng_state(rng_name, El::Generator());
194170

195171
std::string rank_in_trainer;
196172
if (comm == nullptr) {
@@ -201,22 +177,13 @@ bool load_rng_from_checkpoint(persist& p, const lbann_comm* comm)
201177
}
202178

203179
for (int i = 0; i < get_num_io_generators(); i++) {
204-
rng_name = dirname + "/rng_io_generator_" + rank_in_trainer + "_t" +
205-
std::to_string(i);
206-
std::ifstream rng_io(rng_name);
207-
if (!rng_io) {
208-
LBANN_ERROR("Failed to open ", rng_name);
209-
}
210-
rng_name = dirname + "/rng_fast_io_generator_" + rank_in_trainer + "_t" +
211-
std::to_string(i);
212-
std::ifstream rng_fast_io(rng_name);
213-
if (!rng_fast_io) {
214-
LBANN_ERROR("Failed to open ", rng_name);
215-
}
216-
217180
locked_io_rng_ref io_rng = set_io_generators_local_index(i);
218-
rng_io >> get_io_generator();
219-
rng_fast_io >> get_fast_io_generator();
181+
load_rng_state(dirname + "/rng_io_generator_" + rank_in_trainer + "_t" +
182+
std::to_string(i),
183+
get_io_generator());
184+
load_rng_state(dirname + "/rng_fast_io_generator_" + rank_in_trainer +
185+
"_t" + std::to_string(i),
186+
get_fast_io_generator());
220187
}
221188

222189
rng_name = dirname + "/rng_generator_" + rank_in_trainer;

0 commit comments

Comments
 (0)