Skip to content

Commit 3d04127

Browse files
jiaxiyanshijin-aws
authored andcommitted
prov/efa: Implement FI_CONTEXT2 in EFA Direct
Store the completion flags and peer address in FI_CONTEXT2 and retrieve later when writing cq. Signed-off-by: Jessie Yang <[email protected]>
1 parent 56254ad commit 3d04127

File tree

5 files changed

+102
-17
lines changed

5 files changed

+102
-17
lines changed

prov/efa/src/efa.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,41 @@ struct efa_fabric {
107107
#endif
108108
};
109109

110+
struct efa_context {
111+
uint64_t completion_flags;
112+
fi_addr_t addr;
113+
};
114+
115+
#if defined(static_assert)
116+
static_assert(sizeof(struct efa_context) <= sizeof(struct fi_context2),
117+
"efa_context must not be larger than fi_context2");
118+
#endif
119+
120+
/**
121+
* Prepare and return a pointer to an EFA context structure.
122+
*
123+
* @param context Pointer to the msg context.
124+
* @param addr Peer address associated with the operation.
125+
* @param flags Operation flags (e.g., FI_COMPLETION).
126+
* @param completion_flags Completion flags reported in the cq entry.
127+
* @return A pointer to an initialized EFA context structure,
128+
* or NULL if context is invalid or FI_COMPLETION is not set.
129+
*/
130+
static inline struct efa_context *efa_fill_context(const void *context,
131+
fi_addr_t addr,
132+
uint64_t flags,
133+
uint64_t completion_flags)
134+
{
135+
if (!context || !(flags & FI_COMPLETION))
136+
return NULL;
137+
138+
struct efa_context *efa_context = (struct efa_context *) context;
139+
efa_context->completion_flags = completion_flags;
140+
efa_context->addr = addr;
141+
142+
return efa_context;
143+
}
144+
110145
static inline
111146
int efa_str_to_ep_addr(const char *node, const char *service, struct efa_ep_addr *addr)
112147
{

prov/efa/src/efa_cq.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ static void efa_cq_construct_cq_entry(struct ibv_cq_ex *ibv_cqx,
3535
struct fi_cq_tagged_entry *entry)
3636
{
3737
entry->op_context = (void *)ibv_cqx->wr_id;
38-
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
38+
if (ibv_cqx->wr_id)
39+
entry->flags = ((struct efa_context *) ibv_cqx->wr_id)->completion_flags;
40+
else
41+
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
3942
entry->len = ibv_wc_read_byte_len(ibv_cqx);
4043
entry->buf = NULL;
4144
entry->data = 0;
@@ -80,8 +83,7 @@ static void efa_cq_handle_error(struct efa_base_ep *base_ep,
8083
err_entry.prov_errno = prov_errno;
8184

8285
if (is_tx)
83-
// TODO: get correct peer addr for TX operation
84-
addr = FI_ADDR_NOTAVAIL;
86+
addr = ibv_cq_ex->wr_id ? ((struct efa_context *)ibv_cq_ex->wr_id)->addr : FI_ADDR_NOTAVAIL;
8587
else
8688
addr = efa_av_reverse_lookup(base_ep->av,
8789
ibv_wc_read_slid(ibv_cq_ex),

prov/efa/src/efa_msg.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi
101101
wr = &base_ep->efa_recv_wr_vec[wr_index].wr;
102102
wr->num_sge = msg->iov_count;
103103
wr->sg_list = base_ep->efa_recv_wr_vec[wr_index].sge;
104-
wr->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
104+
wr->wr_id = (uintptr_t) efa_fill_context(msg->context, msg->addr, flags,
105+
FI_RECV | FI_MSG);
105106

106107
for (i = 0; i < msg->iov_count; i++) {
107108
addr = (uintptr_t)msg->msg_iov[i].iov_base;
@@ -224,7 +225,8 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi
224225
base_ep->is_wr_started = true;
225226
}
226227

227-
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
228+
qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
229+
msg->context, msg->addr, flags, FI_SEND | FI_MSG);
228230

229231
if (flags & FI_REMOTE_CQ_DATA) {
230232
ibv_wr_send_imm(qp->ibv_qp_ex, msg->data);

prov/efa/src/efa_rma.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep,
9090
ibv_wr_start(qp->ibv_qp_ex);
9191
base_ep->is_wr_started = true;
9292
}
93-
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
93+
94+
qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
95+
msg->context, msg->addr, flags, FI_RMA | FI_READ);
9496

9597
/* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */
9698
ibv_wr_rdma_read(qp->ibv_qp_ex, msg->rma_iov[0].key, msg->rma_iov[0].addr);
@@ -225,7 +227,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,
225227
ibv_wr_start(qp->ibv_qp_ex);
226228
base_ep->is_wr_started = true;
227229
}
228-
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
230+
231+
qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
232+
msg->context, msg->addr, flags, FI_RMA | FI_WRITE);
229233

230234
if (flags & FI_REMOTE_CQ_DATA) {
231235
ibv_wr_rdma_write_imm(qp->ibv_qp_ex, msg->rma_iov[0].key,

prov/efa/test/efa_unit_test_cq.c

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,8 @@ void test_ibv_cq_ex_read_ignore_removed_peer()
811811
#endif
812812

813813
static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
814-
int ibv_wc_opcode, int status, int vendor_error)
814+
int ibv_wc_opcode, int status, int vendor_error,
815+
struct efa_context *ctx)
815816
{
816817
int ret;
817818
size_t raw_addr_len = sizeof(struct efa_ep_addr);
@@ -845,16 +846,19 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
845846
if (ibv_wc_opcode == IBV_WC_RECV) {
846847
ibv_cqx = container_of(base_ep->util_ep.rx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex;
847848
ibv_cqx->start_poll = &efa_mock_ibv_start_poll_return_mock;
848-
ibv_cqx->wr_id = (uintptr_t)12345;
849+
ctx->completion_flags = FI_RECV | FI_MSG;
849850
will_return(efa_mock_ibv_start_poll_return_mock, 0);
850851
ibv_cqx->status = status;
851852
} else {
852853
ibv_cqx = container_of(base_ep->util_ep.tx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex;
853854
/* this mock will set ibv_cq_ex->wr_id to the wr_id of the head of global send_wr,
854855
* and set ibv_cq_ex->status to mock value */
855856
ibv_cqx->start_poll = &efa_mock_ibv_start_poll_use_saved_send_wr_with_mock_status;
857+
ctx->completion_flags = FI_SEND | FI_MSG;
856858
will_return(efa_mock_ibv_start_poll_use_saved_send_wr_with_mock_status, status);
857859
}
860+
ctx->addr = *addr;
861+
ibv_cqx->wr_id = (uintptr_t) ctx;
858862

859863
ibv_cqx->next_poll = &efa_mock_ibv_next_poll_return_mock;
860864
ibv_cqx->end_poll = &efa_mock_ibv_end_poll_check_mock;
@@ -892,19 +896,29 @@ void test_efa_cq_read_send_success(struct efa_resource **state)
892896
{
893897
struct efa_resource *resource = *state;
894898
struct efa_unit_test_buff send_buff;
899+
struct efa_base_ep *base_ep;
900+
struct efa_context *efa_context;
901+
struct fi_context2 ctx;
895902
struct fi_cq_data_entry cq_entry;
896903
fi_addr_t addr;
897904
int ret;
898905

899-
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0);
906+
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0,
907+
(struct efa_context *) &ctx);
900908
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);
901909

902910
assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
903911
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
904-
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
912+
fi_mr_desc(send_buff.mr), addr, &ctx);
905913
assert_int_equal(ret, 0);
906914
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);
907915

916+
base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
917+
efa_context = (struct efa_context *) base_ep->qp->ibv_qp_ex->wr_id;
918+
assert_true(efa_context->completion_flags & FI_SEND);
919+
assert_true(efa_context->completion_flags & FI_MSG);
920+
assert_true(efa_context->addr == addr);
921+
908922
ret = fi_cq_read(resource->cq, &cq_entry, 1);
909923
/* fi_cq_read() called efa_mock_ibv_start_poll_use_saved_send_wr(), which pulled one send_wr from g_ibv_submitted_wr_idv=_vec */
910924
assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
@@ -921,17 +935,27 @@ void test_efa_cq_read_recv_success(struct efa_resource **state)
921935
{
922936
struct efa_resource *resource = *state;
923937
struct efa_unit_test_buff recv_buff;
938+
struct efa_base_ep *base_ep;
939+
struct efa_context *efa_context;
924940
struct fi_cq_data_entry cq_entry;
941+
struct fi_context2 ctx;
925942
fi_addr_t addr;
926943
int ret;
927944

928-
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0);
945+
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0,
946+
(struct efa_context *) &ctx);
929947
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);
930948

931949
ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
932-
fi_mr_desc(recv_buff.mr), addr, NULL);
950+
fi_mr_desc(recv_buff.mr), addr, &ctx);
933951
assert_int_equal(ret, 0);
934952

953+
base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
954+
efa_context = (struct efa_context *) base_ep->efa_recv_wr_vec[base_ep->recv_wr_index].wr.wr_id;
955+
assert_true(efa_context->completion_flags & FI_RECV);
956+
assert_true(efa_context->completion_flags & FI_MSG);
957+
assert_true(efa_context->addr == addr);
958+
935959
ret = fi_cq_read(resource->cq, &cq_entry, 1);
936960
assert_int_equal(ret, 1);
937961

@@ -971,20 +995,29 @@ void test_efa_cq_read_send_failure(struct efa_resource **state)
971995
{
972996
struct efa_resource *resource = *state;
973997
struct efa_unit_test_buff send_buff;
998+
struct efa_base_ep *base_ep;
999+
struct efa_context *efa_context;
9741000
struct fi_cq_data_entry cq_entry;
1001+
struct fi_context2 ctx;
9751002
fi_addr_t addr;
9761003
int ret;
9771004

9781005
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_GENERAL_ERR,
979-
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
1006+
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx);
9801007
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);
9811008

9821009
assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
9831010
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
984-
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
1011+
fi_mr_desc(send_buff.mr), addr, &ctx);
9851012
assert_int_equal(ret, 0);
9861013
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);
9871014

1015+
base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
1016+
efa_context = (struct efa_context *) base_ep->qp->ibv_qp_ex->wr_id;
1017+
assert_true(efa_context->completion_flags & FI_SEND);
1018+
assert_true(efa_context->completion_flags & FI_MSG);
1019+
assert_true(efa_context->addr == addr);
1020+
9881021
ret = fi_cq_read(resource->cq, &cq_entry, 1);
9891022
/* fi_cq_read() called efa_mock_ibv_start_poll_use_saved_send_wr(), which pulled one send_wr from g_ibv_submitted_wr_idv=_vec */
9901023
assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
@@ -1008,18 +1041,27 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state)
10081041
{
10091042
struct efa_resource *resource = *state;
10101043
struct efa_unit_test_buff recv_buff;
1044+
struct efa_base_ep *base_ep;
1045+
struct efa_context *efa_context;
10111046
struct fi_cq_data_entry cq_entry;
1047+
struct fi_context2 ctx;
10121048
fi_addr_t addr;
10131049
int ret;
10141050

10151051
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_GENERAL_ERR,
1016-
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
1052+
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx);
10171053
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);
10181054

10191055
ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
1020-
fi_mr_desc(recv_buff.mr), addr, NULL);
1056+
fi_mr_desc(recv_buff.mr), addr, &ctx);
10211057
assert_int_equal(ret, 0);
10221058

1059+
base_ep = container_of(resource->ep, struct efa_base_ep, util_ep.ep_fid);
1060+
efa_context = (struct efa_context *) base_ep->efa_recv_wr_vec[base_ep->recv_wr_index].wr.wr_id;
1061+
assert_true(efa_context->completion_flags & FI_RECV);
1062+
assert_true(efa_context->completion_flags & FI_MSG);
1063+
assert_true(efa_context->addr == addr);
1064+
10231065
ret = fi_cq_read(resource->cq, &cq_entry, 1);
10241066
assert_int_equal(ret, -FI_EAVAIL);
10251067

0 commit comments

Comments
 (0)