Commit b8ebd4c

Address review comments
1 parent: 9da46f3

2 files changed: +6, -22 lines

src/callbacks/gpu_memory_usage.cpp

Lines changed: 2 additions & 7 deletions
@@ -30,6 +30,7 @@
 #include "lbann/models/model.hpp"
 #include "lbann/utils/gpu/helpers.hpp"
 #include "lbann/utils/serialize.hpp"
+#include <h2/gpu/memory_utils.hpp>
 #include <iomanip>
 #include <sstream>

@@ -79,13 +80,7 @@ void gpu_memory_usage::write_specific_proto(lbann_data::Callback& proto) const
 void gpu_memory_usage::on_epoch_begin(model* m)
 {
 #ifdef LBANN_HAS_GPU
-  size_t available;
-  size_t total;
-#ifdef LBANN_HAS_CUDA
-  FORCE_CHECK_CUDA(cudaMemGetInfo(&available, &total));
-#elif defined(LBANN_HAS_ROCM)
-  FORCE_CHECK_ROCM(hipMemGetInfo(&available, &total));
-#endif
+  auto const [available, total] = h2::gpu::mem_info();
   size_t used = total - available;
   auto comm = m->get_comm();
   if (comm->am_trainer_master()) {

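For context, a minimal standalone sketch of the pattern this commit adopts: querying device memory once through h2::gpu::mem_info() instead of branching on cudaMemGetInfo/hipMemGetInfo. The header and the binding order (available first, total second) are taken from the diff above; the main() wrapper and the printed message are illustrative only, and the sketch assumes a build where DiHydrogen's GPU utilities are available.

// Sketch only: assumes <h2/gpu/memory_utils.hpp> declares h2::gpu::mem_info()
// returning an aggregate whose members bind as (available, total) in bytes,
// as used in the hunks above.
#include <h2/gpu/memory_utils.hpp>

#include <iostream>

int main()
{
  // One portable call replaces the CUDA- and ROCm-specific
  // cudaMemGetInfo/hipMemGetInfo branches.
  auto const [available, total] = h2::gpu::mem_info();
  auto const used = total - available;
  std::cout << "GPU memory used: " << used << " / " << total << " bytes\n";
  return 0;
}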
src/callbacks/memory_profiler.cpp

Lines changed: 4 additions & 15 deletions
@@ -39,6 +39,7 @@
 #include "h2/patterns/multimethods/SwitchDispatcher.hpp"
 
 #include <algorithm>
+#include <h2/gpu/memory_utils.hpp>
 #include <string>
 
 namespace lbann {
@@ -169,17 +170,12 @@ size_t get_activation_and_error_signal_size(Layer const& x, std::ostream& os)
 /**
  * @brief Returns the currently used memory, or 0 if LBANN was not compiled with
  * GPU support.
+ * TODO(later): Gather across all ranks?
  */
 size_t get_used_gpu_memory()
 {
 #ifdef LBANN_HAS_GPU
-  size_t available;
-  size_t total;
-#ifdef LBANN_HAS_CUDA
-  FORCE_CHECK_CUDA(cudaMemGetInfo(&available, &total));
-#elif defined(LBANN_HAS_ROCM)
-  FORCE_CHECK_ROCM(hipMemGetInfo(&available, &total));
-#endif
+  auto const [available, total] = h2::gpu::mem_info();
   // TODO(later): Might be nicer to return a struct with gathered information
   // (min, max, median across ranks)
   return total - available;
@@ -195,14 +191,7 @@ size_t get_used_gpu_memory()
 static inline size_t get_total_gpu_memory()
 {
 #ifdef LBANN_HAS_GPU
-  size_t available;
-  size_t total;
-#ifdef LBANN_HAS_CUDA
-  FORCE_CHECK_CUDA(cudaMemGetInfo(&available, &total));
-#elif defined(LBANN_HAS_ROCM)
-  FORCE_CHECK_ROCM(hipMemGetInfo(&available, &total));
-#endif
-  return total;
+  return h2::gpu::mem_info().total;
 #else
   return 0;
 #endif

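Reading the hunks back to back, the two memory_profiler.cpp helpers end up looking roughly like the sketch below. It is reconstructed from the diff, not copied from the full file: the #else branch of get_used_gpu_memory() is not shown in the hunks and is filled in here to match its doc comment, the includes are reduced to the one this commit adds, and size_t is assumed to be in scope via the file's existing headers.

// Reconstruction from the hunks above; not the complete source file.
#include <h2/gpu/memory_utils.hpp>

/**
 * @brief Returns the currently used memory, or 0 if LBANN was not compiled with
 * GPU support.
 * TODO(later): Gather across all ranks?
 */
size_t get_used_gpu_memory()
{
#ifdef LBANN_HAS_GPU
  auto const [available, total] = h2::gpu::mem_info();
  // TODO(later): Might be nicer to return a struct with gathered information
  // (min, max, median across ranks)
  return total - available;
#else
  return 0;  // Non-GPU branch inferred from the doc comment; not shown in the diff.
#endif
}

static inline size_t get_total_gpu_memory()
{
#ifdef LBANN_HAS_GPU
  return h2::gpu::mem_info().total;
#else
  return 0;
#endif
}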