forked from deepseek-ai/DeepEP
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathruntime.cu
More file actions
126 lines (99 loc) · 3.43 KB
/
runtime.cu
File metadata and controls
126 lines (99 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <cstring>
#include <vector>
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#ifndef DISABLE_NVSHMEM
#include "ibgda_device.cuh"
#include "nvshmem.h"
#endif
namespace deep_ep {
namespace intranode {
// Intra-node barrier kernel, templated on the number of participating ranks.
// Delegates entirely to `barrier_block` (defined in the project headers);
// expected to be launched with a single block — see the host launcher below.
template <int kNumRanks>
__global__ void barrier(int** barrier_signal_ptrs, int rank) {
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}
// Host-side launcher for the intra-node barrier kernel.
// `SWITCH_RANKS` dispatches on `num_ranks` to the matching `barrier<ranks>`
// template instantiation; the launch config is (1, 32) on `stream`
// (presumably grid=1 block, 32 threads — confirm against launch.cuh).
void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) {
#define BARRIER_LAUNCH_CASE(ranks) \
    LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \
    break
    SETUP_LAUNCH_CONFIG(1, 32, stream);
    SWITCH_RANKS(BARRIER_LAUNCH_CASE);
#undef BARRIER_LAUNCH_CASE
}
} // namespace intranode
namespace internode {
#ifndef DISABLE_NVSHMEM
// Optional NVSHMEM sub-team created in `init` (low-latency mode only, when
// num_ranks > NUM_MAX_NVL_PEERS); stays NVSHMEM_TEAM_INVALID otherwise and
// is torn down in `finalize`.
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
nvshmem_team_config_t cpu_rdma_team_config;
#ifndef GLOBALS_H
#define GLOBALS_H
// Process-wide singleton holding shared bookkeeping state.
// Currently it only carries `counter`, used by `finalize` to track how many
// times finalization has been requested.
class GlobalState {
private:
    GlobalState() = default;

public:
    // Non-copyable: there is exactly one instance per process.
    GlobalState(const GlobalState&) = delete;
    GlobalState& operator=(const GlobalState&) = delete;

    // Meyers singleton: lazily constructed on first use, thread-safe in C++11+.
    static GlobalState& instance() {
        static GlobalState state;
        return state;
    }

    // Invocation counter; starts at zero.
    int64_t counter = 0;
};
#endif
std::vector<uint8_t> get_unique_id() {
nvshmemx_uniqueid_t unique_id;
nvshmemx_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
return result;
}
// Initialize NVSHMEM for this process using a unique id obtained on the root
// rank (via `get_unique_id`) and distributed out-of-band.
// `root_unique_id_val` must hold exactly sizeof(nvshmemx_uniqueid_t) bytes.
// In low-latency mode with num_ranks > NUM_MAX_NVL_PEERS, additionally splits
// NVSHMEM_TEAM_WORLD into strided sub-teams stored in the global
// `cpu_rdma_team`: each team groups the ranks sharing the same
// `rank % NUM_MAX_NVL_PEERS` position, with stride NUM_MAX_NVL_PEERS and
// num_ranks / NUM_MAX_NVL_PEERS members.
// Returns this process's NVSHMEM PE id.
int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
    nvshmemx_uniqueid_t root_unique_id;
    nvshmemx_init_attr_t attr;
    std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
    nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
    nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
    // All PEs must have completed initialization before nvshmem_team_split_strided
    nvshmem_barrier_all();
    // Create sub-RDMA teams
    // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
    if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
        // Team must not already exist, and ranks must tile evenly into nodes.
        EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        // nvshmem_team_split_strided returns 0 on success.
        EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD,
                                                  rank % NUM_MAX_NVL_PEERS,
                                                  NUM_MAX_NVL_PEERS,
                                                  num_ranks / NUM_MAX_NVL_PEERS,
                                                  &cpu_rdma_team_config,
                                                  0,
                                                  &cpu_rdma_team) == 0);
        EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
    }
    // nvshmem_barrier_all();
    return nvshmem_my_pe();
}
// Allocate `size` bytes from the NVSHMEM symmetric heap with the requested
// alignment. Thin wrapper over `nvshmem_align` — note the swapped argument
// order (alignment first). Release with the `free` wrapper below.
void* alloc(size_t size, size_t alignment) {
    return nvshmem_align(alignment, size);
}
// Return memory previously obtained via `alloc` to the NVSHMEM symmetric heap.
void free(void* ptr) {
    nvshmem_free(ptr);
}
// Host-side global barrier across all NVSHMEM PEs.
void barrier() {
    nvshmem_barrier_all();
}
// Tear down internode state: destroy `cpu_rdma_team` if `init` created it,
// then (conditionally) finalize NVSHMEM.
// NOTE(review): the counter gating below skips `nvshmem_finalize()` on the
// FIRST call (counter becomes 1) and invokes it on EVERY subsequent call.
// Since the counter is never reset, calling `finalize()` three or more times
// would finalize NVSHMEM repeatedly, which NVSHMEM does not allow — confirm
// the intended call pattern of this fork-added logic (upstream called
// `nvshmem_finalize()` unconditionally, see the commented-out line).
void finalize() {
    if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
        nvshmem_team_destroy(cpu_rdma_team);
        cpu_rdma_team = NVSHMEM_TEAM_INVALID;
    }
    // Count finalize invocations in process-wide singleton state.
    GlobalState::instance().counter++;
    if (GlobalState::instance().counter > 1) {
        nvshmem_finalize();
    }
    // nvshmem_finalize();
}
#endif
} // namespace internode
} // namespace deep_ep