-
Notifications
You must be signed in to change notification settings - Fork 128
Expand file tree
/
Copy pathtl_ucp_ep.c
More file actions
169 lines (154 loc) · 5.36 KB
/
tl_ucp_ep.c
File metadata and controls
169 lines (154 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
/**
* Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
#include "tl_ucp.h"
#include "tl_ucp_ep.h"
// NOLINTNEXTLINE
static void ucc_tl_ucp_err_handler(void *arg, ucp_ep_h ep, ucs_status_t status)
{
    /* In case we don't have OOB barrier, errors are expected.
     * This cb will suppress UCX from raising errors */
    (void)arg;    /* silence -Wunused-parameter: callback signature is fixed by UCX */
    (void)ep;
    (void)status;
}
/* Creates a UCP endpoint to the peer described by @ucp_address on either
 * the service or the main worker of @ctx. A no-op if *ep is already set. */
static inline ucc_status_t ucc_tl_ucp_connect_ep(
    ucc_tl_ucp_context_t *ctx, int is_service, ucp_ep_h *ep, void *ucp_address)
{
    ucp_ep_params_t ep_params;
    ucp_worker_h    worker;
    ucs_status_t    st;

    /* Endpoint already established — nothing to do */
    if (NULL != *ep) {
        return UCC_OK;
    }

    worker = is_service ? ctx->service_worker.ucp_worker
                        : ctx->worker.ucp_worker;

    ep_params.address    = (ucp_address_t *)ucp_address;
    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    if (!UCC_TL_CTX_HAS_OOB(ctx)) {
        /* Without an OOB barrier teardown errors are expected; install a
         * no-op handler so UCX does not raise them. */
        ep_params.field_mask     |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
                                    UCP_EP_PARAM_FIELD_ERR_HANDLER;
        ep_params.err_mode        = UCP_ERR_HANDLING_MODE_PEER;
        ep_params.err_handler.cb  = ucc_tl_ucp_err_handler;
        ep_params.err_handler.arg = NULL;
    }

    st = ucp_ep_create(worker, &ep_params, ep);
    if (ucc_unlikely(st != UCS_OK)) {
        tl_error(ctx->super.super.lib, "ucp returned connect error: %s",
                 ucs_status_string(st));
        return ucs_status_to_ucc_status(st);
    }
    return UCC_OK;
}
/* Looks up the packed UCP address of @core_rank within the team and
 * connects the appropriate (service or main) worker endpoint to it. */
ucc_status_t ucc_tl_ucp_connect_team_ep(
    ucc_tl_ucp_team_t *team, ucc_rank_t core_rank, ucp_ep_h *ep)
{
    int   is_service = USE_SERVICE_WORKER(team);
    void *base_addr;
    void *worker_addr;

    /* Base of this rank's packed address blob for the UCP TL */
    base_addr   = ucc_get_team_ep_addr(UCC_TL_CORE_CTX(team),
                                       UCC_TL_CORE_TEAM(team), core_rank,
                                       ucc_tl_ucp.super.super.id);
    /* Select the service- or main-worker section within the blob */
    worker_addr = is_service ? TL_UCP_EP_ADDR_WORKER_SERVICE(base_addr)
                             : TL_UCP_EP_ADDR_WORKER(base_addr);

    return ucc_tl_ucp_connect_ep(UCC_TL_UCP_TEAM_CTX(team), is_service, ep,
                                 worker_addr);
}
/* Finds next non-NULL ep in the storage and returns that handle
   for closure. In case of "hash" storage it pops the item,
   in case of "array" sets it to NULL */
static inline ucp_ep_h get_next_ep_to_close(
    ucc_tl_ucp_worker_t *worker, ucc_tl_ucp_context_t *ctx, int *i)
{
    ucp_ep_h   found = NULL;
    ucc_rank_t n_eps;

    if (NULL == worker->eps) {
        /* Hash-based storage: popping removes the entry as a side effect */
        return tl_ucp_hash_pop(worker->ep_hash);
    }

    /* Array storage: scan forward from *i, clearing each visited slot */
    n_eps = (ucc_rank_t)ctx->super.super.ucc_context->params.oob.n_oob_eps;
    for (; (NULL == found) && ((*i) < n_eps); (*i)++) {
        found           = worker->eps[*i];
        worker->eps[*i] = NULL;
    }
    return found;
}
/* Closes every endpoint held by @worker and waits until all outstanding
 * non-blocking close requests complete.
 * With an OOB barrier available the close is graceful (flushing); without
 * one it is forced, since peers may already be gone. Errors returned by
 * individual closes are logged, not propagated. */
void ucc_tl_ucp_close_eps(
    ucc_tl_ucp_worker_t *worker, ucc_tl_ucp_context_t *ctx)
{
    int i = 0;           /* scan cursor passed to get_next_ep_to_close */
    int n_reqs = 0;      /* number of pending close requests collected */
    int n_inflight;
    int j;
    ucp_ep_h ep;
    ucs_status_t status;
    ucs_status_ptr_t close_req;
    ucp_request_param_t param;
    size_t max_eps;
    ucs_status_ptr_t *reqs;

    /* Upper bound on endpoint count: OOB team size for array storage,
     * current hash occupancy for hash storage */
    max_eps = worker->eps
        ? (size_t)ctx->super.super.ucc_context->params.oob.n_oob_eps
        : kh_size(worker->ep_hash);
    if (max_eps == 0) {
        return;
    }
    reqs = (ucs_status_ptr_t *)ucc_calloc(max_eps, sizeof(*reqs), "close_reqs");
    if (!reqs) {
        tl_error(
            ctx->super.super.lib, "failed to allocate close requests array");
        return;
    }
    /* Use graceful flush with OOB, force close otherwise */
    param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
    param.flags = UCC_TL_CTX_HAS_OOB(ctx) ? 0 : UCP_EP_CLOSE_FLAG_FORCE;
    /* Phase 1: issue a non-blocking close per endpoint. Closes that
     * complete inline are checked immediately; pending ones are kept */
    ep = get_next_ep_to_close(worker, ctx, &i);
    while (ep) {
        close_req = ucp_ep_close_nbx(ep, &param);
        if (UCS_PTR_IS_PTR(close_req)) {
            reqs[n_reqs++] = close_req;
        } else {
            /* Completed (or failed) immediately — no request to track */
            status = UCS_PTR_STATUS(close_req);
            ucc_assert(status <= UCS_OK);
            if (status != UCS_OK) {
                tl_error(
                    ctx->super.super.lib,
                    "error during ucp ep close, ep %p, status %s",
                    ep,
                    ucs_status_string(status));
            }
        }
        ep = get_next_ep_to_close(worker, ctx, &i);
    }
    /* Phase 2: progress the worker(s) until every pending close request
     * has completed, freeing each request as it finishes */
    n_inflight = n_reqs;
    while (n_inflight > 0) {
        ucp_worker_progress(ctx->worker.ucp_worker);
        if (ctx->cfg.service_worker != 0) {
            /* Service worker is enabled — it needs progress too */
            ucp_worker_progress(ctx->service_worker.ucp_worker);
        }
        n_inflight = 0;
        for (j = 0; j < n_reqs; j++) {
            if (!reqs[j]) {
                continue; /* already completed and freed */
            }
            status = ucp_request_check_status(reqs[j]);
            if (status != UCS_INPROGRESS) {
                ucc_assert(status <= UCS_OK);
                if (status != UCS_OK) {
                    tl_error(
                        ctx->super.super.lib,
                        "error during ucp ep close, status %s",
                        ucs_status_string(status));
                }
                ucp_request_free(reqs[j]);
                reqs[j] = NULL;
            } else {
                n_inflight++;
            }
        }
    }
    ucc_free(reqs);
}