Skip to content

Commit c615db2

Browse files
committed
added function to get num nodes
1 parent c3ce714 commit c615db2

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

src/callbacks/mlperf_logging.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,29 @@ void print_value(std::ostringstream& os, T value)
7474
//FIXME: Should I push the value anyway?
7575
os << "\"UNKNOWN_DATA_TYPE\"";
7676
}
77+
78+
//FIXME: placeholder only -- real accelerator detection has not been
// implemented yet (flagged by the original author as "Tom's problem").
/** @brief Number of accelerators per node (stub).
 *
 *  @return Always 0 for now.
 *
 *  NOTE(review): the result is logged under "accelerators_per_node". The
 *  code this replaced logged -1 when the count was unknown, and
 *  get_num_nodes() also uses -1 for "unknown" -- confirm that 0 is the
 *  intended placeholder value.
 */
int get_real_num_accelerators()
{
  return 0;
}
83+
84+
/** @brief Number of nodes in the current job, as reported by the scheduler.
 *
 *  Consults the environment variables SLURM_NNODES and then FLUX_JOB_NNODES,
 *  in that order, and returns the first one that holds a usable value.
 *
 *  A value is usable when it begins with a decimal integer that is strictly
 *  positive and fits in an int. Unset, empty, non-numeric, zero, negative or
 *  out-of-range values are skipped rather than being reported as a bogus
 *  node count.
 *
 *  @return The node count, or -1 if it could not be determined.
 */
int get_num_nodes()
{
  // Priority order: the first variable with a usable value wins.
  char const* const scheduler_vars[] = {"SLURM_NNODES", "FLUX_JOB_NNODES"};

  for (char const* var : scheduler_vars) {
    // Look the variable up once and reuse the pointer.
    char const* const text = std::getenv(var);
    if (text == nullptr)
      continue;

    // strtol (unlike atoi) reports where parsing stopped, so "no digits at
    // all" can be told apart from a genuine value.
    char* end = nullptr;
    long const parsed = std::strtol(text, &end, 10);
    bool const has_digits = (end != text);

    // Round-trip through int to reject values that would not survive the
    // narrowing conversion (including strtol's saturated overflow result
    // where long is wider than int).
    bool const fits_int = (static_cast<long>(static_cast<int>(parsed)) == parsed);

    if (has_digits && fits_int && parsed > 0)
      return static_cast<int>(parsed);
  }

  //FIXME: count number of unique hostnames in universe?
  return -1;
}
7793
}// namespace
7894

7995
template <typename T>
80-
void mlperf_logging::print(std::ostringstream& os, mlperf_logging::event_type et,
81-
std::string key, T value, char const* file,
82-
size_t line, double epoch) const
96+
void mlperf_logging::print(std::ostringstream& os,
97+
mlperf_logging::event_type et, std::string key,
98+
T value, char const* file, size_t line,
99+
double epoch) const
83100
{
84101
os << "{"
85102
<< "\"namespace\": \"\", "
@@ -144,7 +161,6 @@ void mlperf_logging::setup(model *m)
144161
print(os, mlperf_logging::event_type::TIME_POINT, "submission_platform",
145162
m_sub_platform, __FILE__, __LINE__);
146163

147-
//value = "null";
148164
print(os, mlperf_logging::event_type::INT_START, "init_start", "null",
149165
__FILE__, __LINE__);
150166
}
@@ -157,13 +173,10 @@ void mlperf_logging::on_setup_end(model *m)
157173
print(os, mlperf_logging::event_type::TIME_POINT, "number_of_ranks",
158174
static_cast<int>(comm->get_procs_in_world()), __FILE__, __LINE__);
159175

160-
//FIXME
161-
auto nodes = -1;
162176
print(os, mlperf_logging::event_type::TIME_POINT, "number_of_nodes",
163-
static_cast<int>(nodes), __FILE__, __LINE__);
177+
static_cast<int>(get_num_nodes()), __FILE__, __LINE__);
164178

165-
//FIXME
166-
auto accelerators = -1;
179+
auto accelerators = get_real_num_accelerators();
167180
print(os, mlperf_logging::event_type::TIME_POINT, "accelerators_per_node",
168181
static_cast<int>(accelerators), __FILE__, __LINE__);
169182

@@ -308,6 +321,7 @@ build_mlperf_logging_callback_from_pbuf(
308321
params.sub_division(),
309322
params.sub_status(),
310323
params.sub_platform());
324+
//params.num_nodes());
311325
}
312326
} // namespace callback
313327
} // namespace lbann

0 commit comments

Comments
 (0)