@@ -101,36 +101,17 @@ bool save_rng_to_checkpoint(persist& p, lbann_comm* comm, bool is_distributed)
101101 save_rng_state (rng_name, get_data_seq_generator ());
102102
103103 rng_name = dirname + " /EL_generator" ;
104- std::ofstream rng_EL (rng_name);
105- if (!rng_EL) {
106- LBANN_ERROR (" Failed to open " , rng_name);
107- }
108- rng_EL << El::Generator ();
109- rng_EL.close ();
104+ save_rng_state (rng_name, El::Generator ());
110105 }
111106
112107 for (int i = 0 ; i < get_num_io_generators (); i++) {
113- rng_name = dirname + " /rng_io_generator_" + rank_in_trainer + " _t" +
114- std::to_string (i);
115- std::ofstream rng_io (rng_name);
116- if (!rng_io) {
117- LBANN_ERROR (" Failed to open " , rng_name);
118- }
119- rng_name = dirname + " /rng_fast_io_generator_" + rank_in_trainer + " _t" +
120- std::to_string (i);
121- std::ofstream rng_fast_io (rng_name);
122- if (!rng_fast_io) {
123- LBANN_ERROR (" Failed to open " , rng_name);
124- }
125-
126108 locked_io_rng_ref io_rng = set_io_generators_local_index (i);
127- // save_rng_state(rng_name);
128- // save_rng_state(rng_name);
129- rng_io << get_io_generator ();
130- rng_fast_io << get_fast_io_generator ();
131-
132- rng_io.close ();
133- rng_fast_io.close ();
109+ save_rng_state (dirname + " /rng_io_generator_" + rank_in_trainer + " _t" +
110+ std::to_string (i),
111+ get_io_generator ());
112+ save_rng_state (dirname + " /rng_fast_io_generator_" + rank_in_trainer +
113+ " _t" + std::to_string (i),
114+ get_fast_io_generator ());
134115 }
135116
136117 rng_name = dirname + " /rng_generator_" + rank_in_trainer;
@@ -143,7 +124,6 @@ bool save_rng_to_checkpoint(persist& p, lbann_comm* comm, bool is_distributed)
143124 save_rng_state (rng_name, get_ltfb_generator ());
144125
145126#if not defined(LBANN_DETERMINISTIC) && defined(_OPENMP)
146- // #ifdef _OPENMP
147127#pragma omp parallel private(rng_name)
148128 {
149129 rng_name = dirname + " /rng_OMP_generator_" + rank_in_trainer + " _" +
@@ -186,11 +166,7 @@ bool load_rng_from_checkpoint(persist& p, const lbann_comm* comm)
186166 load_rng_state (rng_name, get_data_seq_generator ());
187167
188168 rng_name = dirname + " /EL_generator" ;
189- std::ifstream rng_EL (rng_name);
190- if (!rng_EL) {
191- LBANN_ERROR (" Failed to open " , rng_name);
192- }
193- rng_EL >> El::Generator ();
169+ load_rng_state (rng_name, El::Generator ());
194170
195171 std::string rank_in_trainer;
196172 if (comm == nullptr ) {
@@ -201,22 +177,13 @@ bool load_rng_from_checkpoint(persist& p, const lbann_comm* comm)
201177 }
202178
203179 for (int i = 0 ; i < get_num_io_generators (); i++) {
204- rng_name = dirname + " /rng_io_generator_" + rank_in_trainer + " _t" +
205- std::to_string (i);
206- std::ifstream rng_io (rng_name);
207- if (!rng_io) {
208- LBANN_ERROR (" Failed to open " , rng_name);
209- }
210- rng_name = dirname + " /rng_fast_io_generator_" + rank_in_trainer + " _t" +
211- std::to_string (i);
212- std::ifstream rng_fast_io (rng_name);
213- if (!rng_fast_io) {
214- LBANN_ERROR (" Failed to open " , rng_name);
215- }
216-
217180 locked_io_rng_ref io_rng = set_io_generators_local_index (i);
218- rng_io >> get_io_generator ();
219- rng_fast_io >> get_fast_io_generator ();
181+ load_rng_state (dirname + " /rng_io_generator_" + rank_in_trainer + " _t" +
182+ std::to_string (i),
183+ get_io_generator ());
184+ load_rng_state (dirname + " /rng_fast_io_generator_" + rank_in_trainer +
185+ " _t" + std::to_string (i),
186+ get_fast_io_generator ());
220187 }
221188
222189 rng_name = dirname + " /rng_generator_" + rank_in_trainer;
0 commit comments