 #include "lbann/callbacks/mlperf_logging.hpp"
 #include "lbann/metrics/metric.hpp"
 #include "lbann/weights/weights.hpp"
+#include "lbann/trainers/trainer.hpp"

 #include <callbacks.pb.h>

 namespace lbann {
 namespace callback {

-// FIXME Does this need an anon namespace since it's only in the cpp file?
+namespace {
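+// Write a single metadata value into the log stream; one overload per
+// supported type, with a templated catch-all below for unknown types.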
+void print_value(std::ostringstream& os, int value)
+{
+  os << value;
+}
 void print_value(std::ostringstream& os, double value)
 {
   os << value;
@@ -67,8 +72,9 @@ template <typename T>
 void print_value(std::ostringstream& os, T value)
 {
   // FIXME: Should I push the value anyway?
-  os << "UNKNOWN_DATA_TYPE";
+  os << "\"UNKNOWN_DATA_TYPE\"";
 }
+} // namespace

 template <typename T>
 void mlperf_logging::print(std::ostringstream& os, mlperf_logging::event_type et,
@@ -114,15 +120,14 @@ size_t mlperf_logging::get_ms_since_epoch()
     system_clock::now().time_since_epoch()).count();
 }

-// FIXME(KLG): There is no on_setup_begin. Can I steal this as a callback hook?
 void mlperf_logging::setup(model *m)
 {
   std::ostringstream os;

-  // FIXME: What is this?
-  std::string value = "null";
-  print(os, mlperf_logging::event_type::TIME_POINT, "cache_clear", value,
-        __FILE__, __LINE__);
+  // Not a good/portable way to do this in C++
+  // std::string value = "null";
+  // print(os, mlperf_logging::event_type::TIME_POINT, "cache_clear", value,
+  //       __FILE__, __LINE__);

   print(os, mlperf_logging::event_type::TIME_POINT, "submission_benchmark",
         m_sub_benchmark, __FILE__, __LINE__);
@@ -139,89 +144,89 @@ void mlperf_logging::setup(model *m)
   print(os, mlperf_logging::event_type::TIME_POINT, "submission_platform",
         m_sub_platform, __FILE__, __LINE__);

-  value = "null";
-  print(os, mlperf_logging::event_type::TIME_POINT, "init_start", value,
+  // value = "null";
+  print(os, mlperf_logging::event_type::INT_START, "init_start", "null",
         __FILE__, __LINE__);
 }
 void mlperf_logging::on_setup_end(model *m)
 {
   std::ostringstream os;
   lbann_comm *comm = m->get_comm();
+  auto const& trainer = get_const_trainer();

-  // FIXME: num_trainers or world size?
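+  // Use the MPI world size rather than the number of LBANN trainers.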
   print(os, mlperf_logging::event_type::TIME_POINT, "number_of_ranks",
-        static_cast<double>(comm->get_num_trainers()), __FILE__, __LINE__);
+        static_cast<int>(comm->get_procs_in_world()), __FILE__, __LINE__);

   // FIXME
   auto nodes = -1;
   print(os, mlperf_logging::event_type::TIME_POINT, "number_of_nodes",
-        static_cast<double>(nodes), __FILE__, __LINE__);
+        static_cast<int>(nodes), __FILE__, __LINE__);

   // FIXME
   auto accelerators = -1;
   print(os, mlperf_logging::event_type::TIME_POINT, "accelerators_per_node",
-        static_cast<double>(accelerators), __FILE__, __LINE__);
+        static_cast<int>(accelerators), __FILE__, __LINE__);

-  // FIXME: From trainer.hpp?
-  auto seed = -1;
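+  // The trainer owns the random seed; the new int overload of print_value()
+  // lets it be logged without a cast.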
+  auto const seed = trainer.get_random_seed();
   print(os, mlperf_logging::event_type::TIME_POINT, "seed",
-        static_cast<double>(seed), __FILE__, __LINE__);
+        seed, __FILE__, __LINE__);

-  // FIXME: Add get_minibatch_size to model or metrics?
-  auto batch_size = -1;
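+  // The data coordinator tracks the global (cross-rank) mini-batch size
+  // and the sample counts for each execution mode.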
+  auto const& dc = trainer.get_data_coordinator();
+  auto const batch_size = dc.get_global_mini_batch_size(
+    execution_mode::training);
   print(os, mlperf_logging::event_type::TIME_POINT, "global_batch_size",
-        static_cast<double>(batch_size), __FILE__, __LINE__);
+        batch_size, __FILE__, __LINE__);

-  metric_statistics metrics;
-  auto samples = metrics.get_num_samples();
+  auto samples = dc.get_total_num_samples(execution_mode::training);
   print(os, mlperf_logging::event_type::TIME_POINT, "train_samples",
-        static_cast<double>(samples), __FILE__, __LINE__);
+        samples, __FILE__, __LINE__);

-  // FIXME
-  auto eval_samples = -1;
+  // FIXME: Should this be execution_mode::validation? Tom thinks no
+  auto eval_samples = dc.get_total_num_samples(execution_mode::testing);
   print(os, mlperf_logging::event_type::TIME_POINT, "eval_samples",
-        static_cast<double>(eval_samples), __FILE__, __LINE__);
-
-  // FIXME: I couldn't get this to work
-  // auto* optimizer = m->get_weights().get_optimizer();
-  std::string opt = "opt_name";
-  print(os, mlperf_logging::event_type::TIME_POINT, "opt_name",
-        opt, __FILE__, __LINE__);
-
-  // FIXME
-  auto opt_learning_rate = -1;
-  print(os, mlperf_logging::event_type::TIME_POINT, "opt_base_learning_rate",
-        static_cast<double>(opt_learning_rate), __FILE__, __LINE__);
-
-  // FIXME
-  auto opt_warmup_steps = -1;
-  print(os, mlperf_logging::event_type::TIME_POINT,
-        "opt_learning_rate_warmup_steps",
-        static_cast<double>(opt_warmup_steps),
-        __FILE__, __LINE__);
-
-  // FIXME
-  auto opt_warmup_factor = -1;
-  print(os, mlperf_logging::event_type::TIME_POINT,
-        "opt_learning_rate_warmup_factor",
-        static_cast<double>(opt_warmup_factor),
-        __FILE__, __LINE__);
-
-  // FIXME
-  auto opt_decay_bound_steps = -1;
-  print(os, mlperf_logging::event_type::TIME_POINT,
-        "opt_learning_rate_decay_boundary_steps",
-        static_cast<double>(opt_decay_bound_steps),
-        __FILE__, __LINE__);
-
-  // FIXME
-  auto opt_decay_factor = -1;
-  print(os, mlperf_logging::event_type::TIME_POINT,
-        "opt_learning_rate_decay_factor",
-        static_cast<double>(opt_decay_factor),
-        __FILE__, __LINE__);
-
-  print(os, mlperf_logging::event_type::TIME_POINT, "init_stop", "null",
+        eval_samples, __FILE__, __LINE__);
+
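+  // Log the name and base learning rate of the first weights object that
+  // has an optimizer attached, assuming one optimizer is shared model-wide.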
+  auto const weights = m->get_weights();
+  for (auto const w : weights)
+    if (w->get_optimizer() != nullptr) {
+      std::string opt = w->get_optimizer()->get_type();
+      print(os, mlperf_logging::event_type::TIME_POINT, "opt_name",
+            opt, __FILE__, __LINE__);
+
+      auto opt_learning_rate = w->get_optimizer()->get_learning_rate();
+      print(os, mlperf_logging::event_type::TIME_POINT,
+            "opt_base_learning_rate", static_cast<double>(opt_learning_rate),
+            __FILE__, __LINE__);
+      break;
+    }
+
+  // LBANN does not perform warmup steps.
+  // auto opt_warmup_steps = -1;
+  // print(os, mlperf_logging::event_type::TIME_POINT,
+  //       "opt_learning_rate_warmup_steps",
+  //       static_cast<size_t>(opt_warmup_steps),
+  //       __FILE__, __LINE__);
+
+  // auto opt_warmup_factor = -1;
+  // print(os, mlperf_logging::event_type::TIME_POINT,
+  //       "opt_learning_rate_warmup_factor",
+  //       static_cast<double>(opt_warmup_factor),
+  //       __FILE__, __LINE__);
+
+  // FIXME (Tom's problem)
+  // auto opt_decay_bound_steps = -1;
+  // print(os, mlperf_logging::event_type::TIME_POINT,
+  //       "opt_learning_rate_decay_boundary_steps",
+  //       static_cast<size_t>(opt_decay_bound_steps),
+  //       __FILE__, __LINE__);
+
+  // auto opt_decay_factor = -1;
+  // print(os, mlperf_logging::event_type::TIME_POINT,
+  //       "opt_learning_rate_decay_factor",
+  //       static_cast<double>(opt_decay_factor),
+  //       __FILE__, __LINE__);
+
+  print(os, mlperf_logging::event_type::INT_END, "init_stop", "null",
         __FILE__, __LINE__);
 }

@@ -241,7 +246,7 @@ void mlperf_logging::on_epoch_end(model *m)
   const auto& epoch = static_cast<const SGDExecutionContext&>(
     m->get_execution_context()).get_epoch();

-  print(os, mlperf_logging::event_type::INT_START, "epoch_stop", "null",
+  print(os, mlperf_logging::event_type::INT_END, "epoch_stop", "null",
         __FILE__, __LINE__, epoch);
 }

@@ -251,7 +256,6 @@ void mlperf_logging::on_train_begin(model *m)
   const auto& epoch = static_cast<const SGDExecutionContext&>(
     m->get_execution_context()).get_epoch();

-  // FIXME: run_start? Same time stamp as epoch 1 in results
   print(os, mlperf_logging::event_type::INT_START, "run_start", "null",
         __FILE__, __LINE__, epoch);
 }
@@ -262,8 +266,7 @@ void mlperf_logging::on_train_end(model *m)
   const auto& epoch = static_cast<const SGDExecutionContext&>(
     m->get_execution_context()).get_epoch();

-  // FIXME: run_stop? End of training?
-  print(os, mlperf_logging::event_type::INT_START, "run_stop", "null",
+  print(os, mlperf_logging::event_type::INT_END, "run_stop", "null",
         __FILE__, __LINE__, epoch);
 }

@@ -283,10 +286,10 @@ void mlperf_logging::on_batch_evaluate_end(model *m)
   const auto& epoch = static_cast<const SGDExecutionContext&>(
     m->get_execution_context()).get_epoch();

-  print(os, mlperf_logging::event_type::INT_START, "eval_stop", "null",
+  print(os, mlperf_logging::event_type::INT_END, "eval_stop", "null",
         __FILE__, __LINE__, epoch);

-  // FIXME
+  // FIXME (Tom's problem)
   auto eval_error = -1;
   print(os, mlperf_logging::event_type::TIME_POINT, "eval_error",
         static_cast<double>(eval_error), __FILE__,
@@ -304,8 +307,7 @@ build_mlperf_logging_callback_from_pbuf(
     params.sub_org(),
     params.sub_division(),
     params.sub_status(),
-    params.sub_platform(),
-    params.output_filename());
+    params.sub_platform());
 }
 } // namespace callback
 } // namespace lbann