#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <string>
#include <string_view>
#include <vector>

#include <cuda_runtime.h>
#include <fmt/core.h>
#include <mpi.h>
#include <nccl.h>
21- namespace {
22-
23- enum struct LogLevel {
24- FATAL,
25- ERROR,
26- WARN,
27- INFO,
28- TRACE,
29- };
30-
31- using namespace std ::string_view_literals;
32- constexpr std::array level_txt{" FATAL" sv, " ERROR" sv, " WARNING" sv, " INFO" sv, " TRACE" sv};
33-
34- #define KC_LOG (lvl, ...) \
35- fmt::println (" [{}] {}:{}: {}" , level_txt[static_cast <int >(lvl)], __FILE__, __LINE__, fmt::format(__VA_ARGS__))
36-
37- #define KC_FATAL (...) (KC_LOG(LogLevel::FATAL, __VA_ARGS__), std::exit(EXIT_FAILURE))
38-
39- #define KC_ERROR (...) KC_LOG(LogLevel::ERROR, __VA_ARGS__)
14+ #include " logging.hpp"
4015
// Severity shorthands over KC_LOG.
#define KC_WARN(...) KC_LOG(LogLevel::WARN, __VA_ARGS__)

#define KC_INFO(...) KC_LOG(LogLevel::INFO, __VA_ARGS__)

#define KC_TRACE(...) KC_LOG(LogLevel::TRACE, __VA_ARGS__)

// Assert-like guard: if expr is false, log the formatted message and exit.
#define KC_CHECK(expr, ...) ((expr) ? void(0) : KC_FATAL(__VA_ARGS__))

// Abort with a diagnostic unless the MPI call returns MPI_SUCCESS.
// The immediately-invoked lambda evaluates expr exactly once.
#define KC_MPI_CHECK(expr)                                                                                 \
    ([&]() {                                                                                               \
        int kc_res_ = (expr);                                                                              \
        return kc_res_ == MPI_SUCCESS ? void(0) : KC_FATAL("MPI check failed: `" #expr "`: {}", kc_res_);  \
    }())

// Abort with a human-readable diagnostic unless the NCCL call returns ncclSuccess.
#define KC_NCCL_CHECK(expr)                                                                                \
    ([&]() {                                                                                               \
        ncclResult_t kc_res_ = (expr);                                                                     \
        return kc_res_ == ncclSuccess ? void(0)                                                            \
                                      : KC_FATAL("NCCL check failed: `" #expr "`: {}",                     \
                                                 ncclGetErrorString(kc_res_));                             \
    }())

// Abort with a human-readable diagnostic unless the CUDA runtime call returns cudaSuccess.
#define KC_CUDA_CHECK(expr)                                                                                \
    ([&]() {                                                                                               \
        cudaError_t kc_res_ = (expr);                                                                      \
        return kc_res_ == cudaSuccess ? void(0)                                                            \
                                      : KC_FATAL("CUDA check failed: `" #expr "`: {}",                     \
                                                 cudaGetErrorString(kc_res_));                             \
    }())
16+ namespace {
6817
6918[[nodiscard]] auto get_local_rank (MPI_Comm comm, int my_rank) -> int {
7019 MPI_Comm node_comm;
@@ -94,16 +43,15 @@ class Ctx {
9443 int n_ranks, my_rank;
9544 MPI_Comm_size (mpi_comm, &n_ranks);
9645 MPI_Comm_rank (mpi_comm, &my_rank);
97- KC_INFO (" P%d/%d - MPI initialized" , n_ranks, my_rank);
9846 int local_rank = get_local_rank (mpi_comm, my_rank);
9947
10048 int n_gpus;
10149 KC_CUDA_CHECK (cudaGetDeviceCount (&n_gpus));
102- KC_INFO (" P%d found %d CUDA devices" , my_rank, n_gpus);
50+ KC_INFO (" P{} found {} CUDA devices" , my_rank, n_gpus);
10351
104- KC_CHECK (local_rank <= n_gpus, " P%d needs GPU %d but only %d devices available" , my_rank, local_rank, n_gpus);
52+ KC_CHECK (local_rank <= n_gpus, " P{} needs device #{} but only {} devices available" , my_rank, local_rank, n_gpus);
10553 KC_CUDA_CHECK (cudaSetDevice (local_rank));
106- KC_INFO (" P%d assigned to CUDA device %d " , my_rank, local_rank);
54+ KC_INFO (" P{} assigned to CUDA device #{} " , my_rank, local_rank);
10755
10856 // Get NCCL unique ID at rank 0 and broadcast it to all others
10957 ncclUniqueId nccl_id;
@@ -125,6 +73,7 @@ class Ctx {
12573 }
12674
12775 ~Ctx () { KC_NCCL_CHECK (ncclCommDestroy (comm_)); }
76+ // Forbid copies and moves
12877 Ctx (const Ctx &) = delete ;
12978 auto operator =(const Ctx &) -> Ctx & = delete ;
13079 Ctx (Ctx &&) = delete ;
0 commit comments