Skip to content

Commit 66664b1

Browse files
committed
Added user input args to mlperf callback, moved mlperf data class out of separate class
1 parent 5495bb6 commit 66664b1

File tree

3 files changed

+84
-78
lines changed

3 files changed

+84
-78
lines changed

include/lbann/callbacks/mlperf_logging.hpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,28 @@ class mlperf_logging : public callback_base {
5353
/** @brief mlperf_logging Constructor.
5454
* @param output_filename Output filename; an empty string selects "results.txt".
* @param sub_benchmark Value logged by setup() under the "submission_benchmark"
*        key; an empty string selects "UNKNOWN_SUBMISSION_BENCHMARK".
* @param sub_org Value logged by setup() under the "submission_org" key; an
*        empty string selects "LBANN".
* @param sub_division Value logged by setup() under the "submission_division"
*        key; an empty string selects "UNKNOWN_SUBMISSION_DIVISION".
* @param sub_status Value logged by setup() under the "submission_status" key;
*        an empty string selects "UNKNOWN_SUBMISSION_STATUS".
* @param sub_platform Value logged by setup() under the "submission_platform"
*        key; an empty string selects "UNKNOWN_SUBMISSION_PLATFORM".
5555
*/
56-
mlperf_logging(std::string output_filename)
56+
mlperf_logging(std::string output_filename, std::string sub_benchmark,
57+
std::string sub_org, std::string sub_division,
58+
std::string sub_status, std::string sub_platform)
5759
: callback_base(/*batch_interval=*/1),
5860
// Every string argument is taken by value and moved into its member; an
// empty argument is replaced by the fallback literal on the third line of
// each initializer.
m_output_filename{output_filename.size() ?
5961
std::move(output_filename) :
60-
std::string("results.txt")}
62+
std::string("results.txt")},
63+
m_sub_benchmark{sub_benchmark.size() ?
64+
std::move(sub_benchmark) :
65+
std::string("UNKNOWN_SUBMISSION_BENCHMARK")},
66+
m_sub_org{sub_org.size() ?
67+
std::move(sub_org) :
68+
std::string("LBANN")},
69+
m_sub_division{sub_division.size() ?
70+
std::move(sub_division) :
71+
std::string("UNKNOWN_SUBMISSION_DIVISION")},
72+
m_sub_status{sub_status.size() ?
73+
std::move(sub_status) :
74+
std::string("UNKNOWN_SUBMISSION_STATUS")},
75+
m_sub_platform{sub_platform.size() ?
76+
std::move(sub_platform) :
77+
std::string("UNKNOWN_SUBMISSION_PLATFORM")}
6178
{}
6279

6380
/** @brief Copy interface */
@@ -69,7 +86,7 @@ class mlperf_logging : public callback_base {
6986
std::string name() const override { return "mlperf_logging"; }
7087

7188
/** @brief Push mlperf formatted log string to stream object.
72-
* @param ostream os Stores log strings.
89+
* @param ostringstream os Stores log strings.
7390
* @param event_type et Type of mlperf style event.
7491
* @param string key Mlperf log key.
7592
* @param T value Mlperf log value.
@@ -78,7 +95,7 @@ class mlperf_logging : public callback_base {
7895
* @param double epoch Current epoch number.
7996
*/
8097
template <typename T>
81-
void print(std::ostream& os, mlperf_logging::event_type et, std::string key,
98+
void print(std::ostringstream& os, mlperf_logging::event_type et, std::string key,
8299
T value, char const* file, size_t line, double epoch = -1) const;
83100

84101
void setup(model *m) override;
@@ -93,22 +110,15 @@ class mlperf_logging : public callback_base {
93110
private:
94111

95112
/** @brief Populate log with mlperf event type.
96-
* @param ostream os Stores log string.
113+
* @param ostringstream os Stores log string.
97114
* @param event_type et Type of mlperf style event.
98115
*/
99-
void print_event_type(std::ostream& os, mlperf_logging::event_type et) const;
116+
void print_event_type(std::ostringstream& os, mlperf_logging::event_type et) const;
100117

101118
/** @brief Populate log with value.
102-
* @param ostream os Stores log string.
119+
* @param ostringstream os Stores log string.
103120
* @param event_type et Mlperf log value.
104121
*/
105-
void print_value(std::ostream& os, double value) const;
106-
void print_value(std::ostream& os, long value) const;
107-
void print_value(std::ostream& os, size_t value) const;
108-
void print_value(std::ostream& os, std::string value) const;
109-
//FIXME: Always picks this function first
110-
//template <typename T>
111-
//void print_value(std::ostream& os, T value) const;
112122

113123
static size_t get_ms_since_epoch();
114124

@@ -117,10 +127,14 @@ class mlperf_logging : public callback_base {
117127
//FIXME: get logger to output file
118128
/* @brief name of output file. Default = results.txt */
119129
std::string m_output_filename;
120-
121-
//FIXME: Add custom logging tag
122130
/* @brief DiHydrogen logger */
123-
h2::Logger m_logger;
131+
h2::Logger m_logger{":::MLLOG", m_output_filename};
132+
std::string m_sub_benchmark;
133+
std::string m_sub_org;
134+
std::string m_sub_division;
135+
std::string m_sub_status;
136+
std::string m_sub_platform;
137+
124138

125139
}; // class mlperf_logging
126140

src/callbacks/mlperf_logging.cpp

Lines changed: 47 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,36 @@
4242
namespace lbann {
4343
namespace callback {
4444

45+
// The print_value overloads are used only inside this translation unit, so
// they live in an anonymous namespace: this gives them internal linkage and
// keeps them from colliding with same-named functions in other files.
namespace {

/** @brief Write @c text to @c os as a double-quoted JSON string.
 *
 *  Double quotes, backslashes and control characters are escaped so that a
 *  user-supplied value cannot produce a malformed log entry.
 *
 *  @param os   Stream that accumulates the log entry.
 *  @param text Raw, unescaped string value.
 */
void write_json_string(std::ostringstream& os, std::string const& text)
{
  static constexpr char hex_digits[] = "0123456789abcdef";
  os << '"';
  for (char const c : text) {
    switch (c) {
    case '"':  os << "\\\""; break;
    case '\\': os << "\\\\"; break;
    case '\n': os << "\\n"; break;
    case '\r': os << "\\r"; break;
    case '\t': os << "\\t"; break;
    default:
      if (static_cast<unsigned char>(c) < 0x20) {
        // Remaining control characters must use the \u00XX form in JSON.
        os << "\\u00" << hex_digits[(c >> 4) & 0xF] << hex_digits[c & 0xF];
      }
      else {
        os << c;
      }
    }
  }
  os << '"';
}

/** @brief Append a floating-point value to the log entry (unquoted). */
void print_value(std::ostringstream& os, double value)
{
  os << value;
}

/** @brief Append a signed integer value to the log entry (unquoted). */
void print_value(std::ostringstream& os, long value)
{
  os << value;
}

/** @brief Append an unsigned size value to the log entry (unquoted). */
void print_value(std::ostringstream& os, size_t value)
{
  os << value;
}

/** @brief Append a string value to the log entry as a quoted JSON string. */
void print_value(std::ostringstream& os, std::string const& value)
{
  write_json_string(os, value);
}

/** @brief Append a C string to the log entry as a quoted JSON string.
 *
 *  A null pointer is written as the JSON literal @c null; streaming a null
 *  @c char const* directly would be undefined behaviour.
 */
void print_value(std::ostringstream& os, char const* value)
{
  if (value == nullptr) {
    os << "null";
    return;
  }
  write_json_string(os, std::string(value));
}

} // namespace
4566
template <typename T>
46-
void mlperf_logging::print(std::ostream& os, mlperf_logging::event_type et,
67+
void print_value(std::ostringstream& os, T value)
68+
{
69+
//FIXME: Should I push the value anyway?
70+
os << "UNKNOWN_DATA_TYPE";
71+
}
72+
73+
template <typename T>
74+
void mlperf_logging::print(std::ostringstream& os, mlperf_logging::event_type et,
4775
std::string key, T value, char const* file,
4876
size_t line, double epoch) const
4977
{
@@ -54,19 +82,22 @@ void mlperf_logging::print(std::ostream& os, mlperf_logging::event_type et,
5482
print_event_type(os, et);
5583

5684
os << "\", "
57-
<< "\"key\": " << key << "\", "
85+
<< "\"key\": \"" << key << "\", "
5886
<< "\"value\": ";
5987
print_value(os, value);
6088
os << ", "
6189
<< "\"metadata\": {\"file\": \"" << file << "\", "
6290
<< "\"lineno\": " << line;
6391
if(epoch < 0)
64-
os << "}}\n";
92+
os << "}}";
6593
else
66-
os << ", " << "\"epoch_num\": " << epoch << "}}\n";
94+
os << ", " << "\"epoch_num\": " << epoch << "}}";
95+
96+
H2_INFO(os.str());
97+
os.flush();
6798
}
6899

69-
void mlperf_logging::print_event_type(std::ostream& os, mlperf_logging::event_type et) const
100+
void mlperf_logging::print_event_type(std::ostringstream& os, mlperf_logging::event_type et) const
70101
{
71102
switch (et) {
72103
case mlperf_logging::event_type::TIME_POINT: os << "POINT_IN_TIME"; break;
@@ -76,30 +107,6 @@ void mlperf_logging::print_event_type(std::ostream& os, mlperf_logging::event_ty
76107
}
77108
}
78109

79-
void mlperf_logging::print_value(std::ostream& os, double value) const
80-
{
81-
os << value;
82-
}
83-
void mlperf_logging::print_value(std::ostream& os, long value) const
84-
{
85-
os << value;
86-
}
87-
void mlperf_logging::print_value(std::ostream& os, size_t value) const
88-
{
89-
os << value;
90-
}
91-
void mlperf_logging::print_value(std::ostream& os, std::string value) const
92-
{
93-
os << value;
94-
}
95-
/*template <typename T>
96-
void mlperf_logging::print_value(std::ostream& os, T value) const
97-
{
98-
//FIXME: Should I push the value anyway?
99-
os << "UNKNOWN_DATA_TYPE";
100-
}
101-
*/
102-
103110
size_t mlperf_logging::get_ms_since_epoch()
104111
{
105112
using namespace std::chrono;
@@ -117,35 +124,24 @@ void mlperf_logging::setup(model *m)
117124
print(os, mlperf_logging::event_type::TIME_POINT, "cache_clear", value,
118125
__FILE__, __LINE__);
119126

120-
//FIXME: Make these user input vars
121-
value = "oc20";
122127
print(os, mlperf_logging::event_type::TIME_POINT, "submission_benchmark",
123-
value, __FILE__, __LINE__);
128+
m_sub_benchmark, __FILE__, __LINE__);
124129

125-
value = "LBANN";
126130
print(os, mlperf_logging::event_type::TIME_POINT, "submission_org",
127-
value, __FILE__, __LINE__);
131+
m_sub_org, __FILE__, __LINE__);
128132

129-
//FIXME: value = closed?
130-
value = "closed";
131133
print(os, mlperf_logging::event_type::TIME_POINT, "submission_division",
132-
value, __FILE__, __LINE__);
134+
m_sub_division, __FILE__, __LINE__);
133135

134-
//FIXME: value = onprem?
135-
value = "onprem";
136136
print(os, mlperf_logging::event_type::TIME_POINT, "submission_status",
137-
value, __FILE__, __LINE__);
137+
m_sub_status, __FILE__, __LINE__);
138138

139-
//FIXME: value = SUBMISSION_PLATFORM_PLACEHOLDER?
140-
value = "?";
141139
print(os, mlperf_logging::event_type::TIME_POINT, "submission_platform",
142-
value, __FILE__, __LINE__);
140+
m_sub_platform, __FILE__, __LINE__);
143141

144142
value = "null";
145143
print(os, mlperf_logging::event_type::TIME_POINT, "init_start", value,
146144
__FILE__, __LINE__);
147-
148-
H2_INFO(os.str());
149145
}
150146
void mlperf_logging::on_setup_end(model *m)
151147
{
@@ -227,8 +223,6 @@ void mlperf_logging::on_setup_end(model *m)
227223

228224
print(os, mlperf_logging::event_type::TIME_POINT, "init_stop", "null",
229225
__FILE__, __LINE__);
230-
231-
H2_INFO(os.str());
232226
}
233227

234228
void mlperf_logging::on_epoch_begin(model *m)
@@ -239,8 +233,6 @@ void mlperf_logging::on_epoch_begin(model *m)
239233

240234
print(os, mlperf_logging::event_type::INT_START, "epoch_start", "null",
241235
__FILE__, __LINE__, epoch);
242-
243-
H2_INFO(os.str());
244236
}
245237

246238
void mlperf_logging::on_epoch_end(model *m)
@@ -251,8 +243,6 @@ void mlperf_logging::on_epoch_end(model *m)
251243

252244
print(os, mlperf_logging::event_type::INT_START, "epoch_stop", "null",
253245
__FILE__, __LINE__, epoch);
254-
255-
H2_INFO(os.str());
256246
}
257247

258248
void mlperf_logging::on_train_begin(model *m)
@@ -264,8 +254,6 @@ void mlperf_logging::on_train_begin(model *m)
264254
//FIXME: run_start? Same time stamp as epoch 1 in results
265255
print(os, mlperf_logging::event_type::INT_START, "run_start", "null",
266256
__FILE__, __LINE__, epoch);
267-
268-
H2_INFO(os.str());
269257
}
270258

271259
void mlperf_logging::on_train_end(model *m)
@@ -277,8 +265,6 @@ void mlperf_logging::on_train_end(model *m)
277265
//FIXME: run_stop? End of training?
278266
print(os, mlperf_logging::event_type::INT_START, "run_stop", "null",
279267
__FILE__, __LINE__, epoch);
280-
281-
H2_INFO(os.str());
282268
}
283269

284270
void mlperf_logging::on_batch_evaluate_begin(model *m)
@@ -289,8 +275,6 @@ void mlperf_logging::on_batch_evaluate_begin(model *m)
289275

290276
print(os, mlperf_logging::event_type::INT_START, "eval_start", "null",
291277
__FILE__, __LINE__, epoch);
292-
293-
H2_INFO(os.str());
294278
}
295279

296280
void mlperf_logging::on_batch_evaluate_end(model *m)
@@ -307,8 +291,6 @@ void mlperf_logging::on_batch_evaluate_end(model *m)
307291
print(os, mlperf_logging::event_type::TIME_POINT, "eval_error",
308292
static_cast<double>(eval_error), __FILE__,
309293
__LINE__, epoch);
310-
311-
H2_INFO(os.str());
312294
}
313295

314296
std::unique_ptr<callback_base>
@@ -318,7 +300,12 @@ build_mlperf_logging_callback_from_pbuf(
318300
{
319301
const auto& params =
320302
dynamic_cast<const lbann_data::Callback::CallbackMlperfLogging&>(proto_msg);
321-
return std::make_unique<mlperf_logging>(params.output_filename());
303+
// The argument order must match the mlperf_logging constructor, which takes
// the output filename first and the submission fields after it. Every
// parameter is a std::string, so a wrong order compiles but shifts each value
// into the neighbouring field.
return std::make_unique<mlperf_logging>(params.output_filename(),
                                        params.sub_benchmark(),
                                        params.sub_org(),
                                        params.sub_division(),
                                        params.sub_status(),
                                        params.sub_platform());
322309
}
323310
} // namespace callback
324311
} // namespace lbann

src/proto/callbacks.proto

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,11 @@ message Callback {
428428

429429
/** @brief Prints mlperf compliant benchmark logs */
430430
message CallbackMlperfLogging {
431-
string output_filename = 1;
431+
string output_filename = 1; // Output filename; the callback uses "results.txt" when empty
432+
string sub_benchmark = 2; // Logged as "submission_benchmark"; callback falls back to "UNKNOWN_SUBMISSION_BENCHMARK" when empty
433+
string sub_org = 3; // Logged as "submission_org"; callback falls back to "LBANN" when empty
434+
string sub_division = 4; // Logged as "submission_division"; callback falls back to "UNKNOWN_SUBMISSION_DIVISION" when empty
435+
string sub_status = 5; // Logged as "submission_status"; callback falls back to "UNKNOWN_SUBMISSION_STATUS" when empty
436+
string sub_platform = 6; // Logged as "submission_platform"; callback falls back to "UNKNOWN_SUBMISSION_PLATFORM" when empty
432437
}
433438
}

0 commit comments

Comments
 (0)