Commit d166674

Cuda13 linux nvshmem (PaddlePaddle#75557)

* nvshmem cuda13
* cuda13
* template bypass

1 parent a9af5ed · commit d166674

3 files changed: 344 additions & 1 deletion

File tree

.github/workflows/CheckPRTemplate.yml
cmake/external/nvshmem.cmake
patches/nvshmem/nvshmem_cuda13.patch

.github/workflows/CheckPRTemplate.yml (8 additions, 0 deletions)

@@ -16,7 +16,15 @@ jobs:
       - name: Clone paddle
         uses: actions/checkout@v4
 
+      - name: Check bypass
+        id: check-bypass
+        uses: ./.github/actions/check-bypass
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          workflow-name: template
+
       - name: Check PR Template
+        if: steps.check-bypass.outputs.can-skip != 'true'
         env:
           AGILE_PULL_ID: ${{ github.event.pull_request.number }}
           AGILE_COMPILE_BRANCH: ${{ github.base_ref }}

cmake/external/nvshmem.cmake (6 additions, 1 deletion)

@@ -53,7 +53,12 @@ else()
     extern_nvshmem)
 endif()
 
-set(NVSHMEM_PATCH_PATH ${PADDLE_SOURCE_DIR}/patches/nvshmem/nvshmem.patch)
+if(CUDA_VERSION VERSION_GREATER_EQUAL 13)
+  set(NVSHMEM_PATCH_PATH
+      ${PADDLE_SOURCE_DIR}/patches/nvshmem/nvshmem_cuda13.patch)
+else()
+  set(NVSHMEM_PATCH_PATH ${PADDLE_SOURCE_DIR}/patches/nvshmem/nvshmem.patch)
+endif()
 set(NVSHMEM_PATCH_COMMAND
     git init && git config --global --add safe.directory ${NVSHMEM_SOURCE_DIR}
     && git config user.name "PaddlePaddle" && git config user.email
patches/nvshmem/nvshmem_cuda13.patch (new file, 330 additions, 0 deletions)

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index cba899b..88f291d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -213,8 +213,8 @@ set_target_properties(nvshmem nvshmem_host
                       PROPERTIES POSITION_INDEPENDENT_CODE ON
                       CXX_STANDARD_REQUIRED ON
                       CUDA_STANDARD_REQUIRED ON
-                      CXX_STANDARD 11
-                      CUDA_STANDARD 11
+                      CXX_STANDARD 17
+                      CUDA_STANDARD 17
                       CUDA_SEPARABLE_COMPILATION ON
                       LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib"
                       ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib"
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 8b8a263..080a8fe 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -46,6 +46,8 @@
         qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID;  \
         qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;  \
         qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID;  \
+        qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
+        qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID;  \
         qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID;       \
         qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID;       \
     } while (0);
@@ -168,14 +170,18 @@ typedef struct {
         uint64_t get_head;  // last wqe idx + 1 with a "fetch" operation (g, get, amo_fetch)
         uint64_t get_tail;  // last wqe idx + 1 polled with cst; get_tail > get_head is possible
     } tx_wq;
+    struct {
+        uint64_t resv_head;  // last reserved wqe idx + 1
+        uint64_t cons_idx;   // polled wqe idx + 1 (consumer index + 1)
+    } rx_wq;
     struct {
         uint64_t head;
         uint64_t tail;
     } ibuf;
     char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
 } __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96,
-              "ibgda_device_qp_management_v1 must be 96 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112,
+              "ibgda_device_qp_management_v1 must be 112 bytes.");
 
 typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
 
@@ -199,9 +205,19 @@ typedef struct nvshmemi_ibgda_device_qp {
         // May point to mvars.prod_idx or internal prod_idx
         uint64_t *prod_idx;
     } tx_wq;
+    struct {
+        uint16_t nwqes;
+        uint64_t tail;
+        void *wqe;
+        __be32 *dbrec;
+        void *bf;
+        nvshmemi_ibgda_device_cq_t *cq;
+        // May point to mvars.prod_idx or internal prod_idx
+        uint64_t *prod_idx;
+    } rx_wq;
     nvshmemi_ibgda_device_qp_management_v1 mvars;  // management variables
 } nvshmemi_ibgda_device_qp_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, "ibgda_device_qp_v1 must be 184 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, "ibgda_device_qp_v1 must be 256 bytes.");
 
 typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
 
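The updated static_assert values line up with the added fields: the two new rx_wq counters add 16 bytes to the management struct (96 + 16 = 112), and the device QP struct grows by that 16 plus 56 bytes for its own rx_wq block (184 + 56 + 16 = 256). A standalone sketch of that arithmetic on an LP64 target (stand-in types only, not the real header; __be32 replaced by uint32_t):

// Size check mirroring only the *new* members; the real structs carry more fields.
#include <cstdint>

struct rx_wq_mgmt {      // added to nvshmemi_ibgda_device_qp_management_v1
    uint64_t resv_head;
    uint64_t cons_idx;
};
static_assert(sizeof(rx_wq_mgmt) == 16, "96 + 16 == 112");

struct rx_wq_qp {        // added to nvshmemi_ibgda_device_qp_v1
    uint16_t nwqes;      // padded to 8 bytes before the uint64_t below
    uint64_t tail;
    void *wqe;
    uint32_t *dbrec;     // __be32 in the source; a pointer either way
    void *bf;
    void *cq;            // nvshmemi_ibgda_device_cq_t * in the source
    uint64_t *prod_idx;
};
static_assert(sizeof(rx_wq_qp) == 56, "184 + 56 + 16 (mvars growth) == 256");
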
diff --git a/src/modules/transport/common/transport_ib_common.cpp b/src/modules/transport/common/transport_ib_common.cpp
index c89f408..f99018a 100644
--- a/src/modules/transport/common/transport_ib_common.cpp
+++ b/src/modules/transport/common/transport_ib_common.cpp
@@ -26,6 +26,9 @@ int nvshmemt_ib_common_nv_peer_mem_available() {
     if (access("/sys/kernel/mm/memory_peers/nvidia-peermem/version", F_OK) == 0) {
         return NVSHMEMX_SUCCESS;
     }
+    if (access("/sys/module/nvidia_peermem/version", F_OK) == 0) {
+        return NVSHMEMX_SUCCESS;
+    }
 
     return NVSHMEMX_ERROR_INTERNAL;
 }
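For reference, the patched probe simply accepts either sysfs path; a condensed, standalone restatement (plain bool in place of the NVSHMEMX_* status codes):

// nvidia-peermem may be reported under the legacy memory_peers node or,
// presumably with newer driver packagings, as a plain kernel module.
#include <unistd.h>

static bool nv_peer_mem_available() {
    return access("/sys/kernel/mm/memory_peers/nvidia-peermem/version", F_OK) == 0 ||
           access("/sys/module/nvidia_peermem/version", F_OK) == 0;
}
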
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index ef325cd..bc339c5 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -198,6 +198,7 @@ struct ibgda_ep {
     off_t dbr_offset;
 
     struct ibgda_cq *send_cq;
+    struct ibgda_cq *recv_cq;
     struct ibv_ah *ah;
 
     uint32_t user_index;
@@ -1066,7 +1067,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) {
     ibgda_host_mem_free(mobject);
 }
 
-static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) {
+static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) {
     int status = 0;
 
     struct ibgda_cq *gcq = NULL;
@@ -1117,7 +1118,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device)
     cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context);
     DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);
     DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B);
-    DEVX_SET(cqc, cq_context, cc, 0x1);  // Use collapsed CQ
+    DEVX_SET(cqc, cq_context, cc, cc);   // Use collapsed CQ
     DEVX_SET(cqc, cq_context, oi, 0x1);  // Allow overrun
     DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id);
     DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe));
@@ -1538,7 +1539,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
 
     struct ibv_context *context = device->context;
 
-    unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes;
+    // Each RC qp has one send CQ and one recv CQ.
+    unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2;
 
     assert(ibgda_qp_depth > 0);
     size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
@@ -1701,7 +1703,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
     }
 
     // Allocate and map WQ buffer for all QPs.
-    wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB;  // num_wqebb is always a power of 2
+    // Todo: reduce the size of wq buffer.
+    wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2;  // num_wqebb is always a power of 2
     wq_buf_size = wq_buf_size_per_qp * num_eps;
     status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE);
     NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "cannot allocate wq buf.\n");
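Worked numbers for the two shared-object changes above (all sizes are placeholders; the real values come from the device configuration): with 8 DCIs, 2 RC QPs per PE, and 4 PEs, the CQ count rises from 8 + 8 = 16 to 8 + 16 = 24, one extra recv CQ per RC QP, and each QP's WQ slice doubles so it can hold a receive queue alongside the send queue:

// Illustrative arithmetic only; dci_num_eps etc. are made-up stand-ins.
#include <cassert>
#include <cstddef>

int main() {
    const unsigned dci_num_eps = 8, rc_num_eps_per_pe = 2, n_pes = 4;
    unsigned num_cqs_before = dci_num_eps + rc_num_eps_per_pe * n_pes;     // 16
    unsigned num_cqs_after = dci_num_eps + rc_num_eps_per_pe * n_pes * 2;  // 24
    assert(num_cqs_after - num_cqs_before == rc_num_eps_per_pe * n_pes);

    const size_t num_wqebb = 1024, send_wqe_bb = 64;  // MLX5_SEND_WQE_BB is 64 B
    size_t wq_buf_size_per_qp = num_wqebb * send_wqe_bb * 2;  // doubled by the patch
    assert(wq_buf_size_per_qp == 128 * 1024);
    return 0;
}
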
@@ -1882,8 +1885,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
     int cqe_version = 0;
 
     struct ibgda_cq *send_cq = NULL;
+    struct ibgda_cq *recv_cq = NULL;
 
     size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
+    size_t num_recv_wqe = ibgda_qp_depth;
+    size_t recv_wqe_size = 16;
 
     int status = 0;
 
@@ -1911,6 +1917,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
     status = ibgda_create_cq(&send_cq, device);
     NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
 
+    if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+        status = ibgda_create_cq(&recv_cq, device);
+        NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
+    }
+
     ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep));
     NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
                             "Unable to allocate mem for ep.\n");
@@ -1939,12 +1950,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
     DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED);
     DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn);
     DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id);  // BF register
-    DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE);  // Shared Receive Queue
-    DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
     DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn);
-    DEVX_SET(qpc, qp_context, cqn_rcv, device->qp_shared_object.rcqn);
+    DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn);
     DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb));
-    DEVX_SET(qpc, qp_context, log_rq_size, 0);
     DEVX_SET(qpc, qp_context, cs_req, 0);  // Disable CS Request
     DEVX_SET(qpc, qp_context, cs_res, 0);  // Disable CS Response
     DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);  // Enable dbr_umem_id
@@ -1953,6 +1961,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
     DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id);  // DBR buffer
     DEVX_SET(qpc, qp_context, user_index, qp_idx);
     DEVX_SET(qpc, qp_context, page_offset, 0);
+    if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+        DEVX_SET(qpc, qp_context, rq_type, 0);  // Regular recv queue
+        DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe));  // 4 wqe
+        DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4);  // max recv wqe size = 16B
+    } else {
+        DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE);  // Shared Receive Queue, DC must use this.
+        DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
+        DEVX_SET(qpc, qp_context, log_rq_size, 0);
+    }
 
     ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));
     NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out,
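A note on the RQ geometry set above: mlx5 encodes log_rq_stride relative to a 16-byte base unit (stride = 16 * 2^log_rq_stride), so a 16-byte recv WQE yields stride 0, and log_rq_size is just log2 of the WQE count. A quick standalone check (IBGDA_ILOG2 replaced by a local helper; the depth is a placeholder):

#include <cassert>

static unsigned ilog2(unsigned long long v) {  // stand-in for IBGDA_ILOG2
    unsigned r = 0;
    while (v >>= 1) ++r;
    return r;
}

int main() {
    const unsigned recv_wqe_size = 16;  // bytes, as in ibgda_create_qp
    const unsigned num_recv_wqe = 128;  // placeholder for ibgda_qp_depth
    assert(ilog2(recv_wqe_size) - 4 == 0);  // log_rq_stride: 16 B is the base unit
    assert(ilog2(num_recv_wqe) == 7);       // log_rq_size
    return 0;
}
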
@@ -1962,9 +1979,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
     ep->portid = portid;
 
     ep->sq_cnt = num_wqebb;
-    ep->sq_buf_offset = 0;
+    ep->sq_buf_offset = num_recv_wqe * recv_wqe_size;
 
-    ep->rq_cnt = 0;
+    ep->rq_cnt = num_recv_wqe;
     ep->rq_buf_offset = 0;
 
     ep->wq_mobject = device->qp_shared_object.wq_mobject;
@@ -1978,6 +1995,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
     ep->uar_mobject = uar_mobject;
 
     ep->send_cq = send_cq;
+    ep->recv_cq = recv_cq;
 
     ep->qp_type = qp_type;
 
@@ -1989,6 +2007,7 @@ out:
     if (status) {
         if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject);
         if (send_cq) ibgda_destroy_cq(send_cq);
+        if (recv_cq) ibgda_destroy_cq(recv_cq);
         if (ep) free(ep);
     }
 
@@ -2287,6 +2306,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) {
         ibgda_destroy_cq(ep->send_cq);
     }
 
+    if (ep->recv_cq) {
+        ibgda_destroy_cq(ep->recv_cq);
+    }
+
     if (ep->ah) {
         ftable.destroy_ah(ep->ah);
     }
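The offsets above imply that each QP's WQ slice is laid out as [recv queue | send queue]: rq_buf_offset stays 0 while sq_buf_offset skips past the RQ. A minimal sketch of that bookkeeping (placeholder depth):

#include <cassert>
#include <cstddef>

int main() {
    const size_t ibgda_qp_depth = 128;           // placeholder
    const size_t num_recv_wqe = ibgda_qp_depth;  // as in ibgda_create_qp
    const size_t recv_wqe_size = 16;

    size_t rq_buf_offset = 0;                             // ep->rq_buf_offset
    size_t sq_buf_offset = num_recv_wqe * recv_wqe_size;  // ep->sq_buf_offset
    assert(rq_buf_offset == 0 && sq_buf_offset == 2048);
    // ibgda_get_device_qp later adds sq_buf_offset to tx_wq.wqe and
    // rq_buf_offset to rx_wq.wqe, on top of the per-QP wq_offset.
    return 0;
}
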
@@ -2318,7 +2341,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
     dev_qp->qpn = ep->qpn;
 
     assert(ep->wq_mobject->has_gpu_mapping);
-    dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset);
+    dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset);
 
     if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) {
         assert(ep->dbr_mobject->has_gpu_mapping);
@@ -2330,6 +2353,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
     }
 
     dev_qp->tx_wq.nwqes = ep->sq_cnt;
+    if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+        dev_qp->rx_wq.nwqes = ep->rq_cnt;
+        dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset);
+        dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset);
+        dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr;
+    }
 
     ibuf_dci_start = (uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr;
     ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps);
@@ -2379,6 +2408,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
     nvshmemi_ibgda_device_cq_t *cq_d = NULL;
     nvshmemi_ibgda_device_cq_t *cq_h = NULL;
 
+    nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL;
+    nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL;
+
     uint8_t *qp_group_switches_d = NULL;
 
     const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars);
@@ -2386,6 +2418,8 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
     const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx);
     const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
     const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
+    const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
+    const size_t rx_cons_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx);
 
     nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
     nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
@@ -2421,7 +2455,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
         num_dct_handles += device->dct.num_eps * n_pes;
         num_dci_handles += device->dci.num_eps;
         num_rc_handles += device->rc.num_eps_per_pe * n_pes;
-        num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1));
+        num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2);
         num_shared_dci_handles += device->dci.num_shared_eps;
     }
     assert(num_dci_handles - num_shared_dci_handles >= 0);
@@ -2456,6 +2490,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
     for (int i = 0; i < num_cq_handles; i++) {
         nvshmemi_init_ibgda_device_cq(cq_h[i]);
     }
+
+    recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h));
+    NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "recv_cq calloc err.");
+    nvshmemi_init_ibgda_device_cq(recv_cq_h[0]);
     /* allocate host memory for dct, rc, cq, dci end */
 
     /* allocate device memory for dct, rc, cq, dci start */
@@ -2559,6 +2597,15 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
             }
 
             ++cq_idx;
+
+            rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx];
+
+            ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
+            cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
+            cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset);
+            cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
+            cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
+            ++cq_idx;
         }
     }
 }
@@ -2936,17 +2983,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id
     INFO(ibgda_state->log_level, "Creating %d RC QPs", device->rc.num_eps_per_pe);
     for (int i = 0; i < num_rc_eps; ++i) {
         // Do not create loopback to self
-        if (i / device->rc.num_eps_per_pe == mype) {
+        int dst_pe = (i + 1 + mype) % n_pes;
+        int offset = i / n_pes;
+        int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset;
+        if (dst_pe == mype) {
             continue;
         }
-        status = ibgda_create_qp(&device->rc.eps[i], device, portid, i,
+        status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i,
                                  NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC);
         NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
-                              "ibgda_create_dci failed on RC #%d.", i);
+                              "ibgda_create_dci failed on RC #%d.", mapped_i);
 
-        status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device);
+        status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device);
         NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
-                              "ibgda_get_rc_handle failed on RC #%d.", i);
+                              "ibgda_get_rc_handle failed on RC #%d.", mapped_i);
     }
 
     if (num_rc_eps) {
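The remapped loop changes both the iteration order and the slot each RC QP occupies: iteration i targets dst_pe = (i + 1 + mype) % n_pes, starting from the next PE instead of PE 0, and the QP is stored at dst_pe * num_eps_per_pe + offset, so the eps array stays grouped by destination PE while creation round-robins across peers. A small standalone sketch (hypothetical sizes) that prints the mapping:

#include <cstdio>

int main() {
    const int n_pes = 4, mype = 1, num_eps_per_pe = 2;  // placeholder sizes
    const int num_rc_eps = num_eps_per_pe * n_pes;
    for (int i = 0; i < num_rc_eps; ++i) {
        int dst_pe = (i + 1 + mype) % n_pes;
        int offset = i / n_pes;
        int mapped_i = dst_pe * num_eps_per_pe + offset;
        if (dst_pe == mype) continue;  // still no loopback QP to self
        printf("i=%d -> dst_pe=%d slot=%d\n", i, dst_pe, mapped_i);
    }
    // Every PE other than mype ends up owning slots [pe*2, pe*2+1].
    return 0;
}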
