Skip to content

Commit d563648

Browse files
author
Bret Ambrose
committed
Refactor result to be a Crt::Variant; don't send terminate stream when response is a terminate stream
1 parent 8ccccb2 commit d563648

7 files changed

Lines changed: 597 additions & 568 deletions

File tree

eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h

Lines changed: 5 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <aws/crt/JsonObject.h>
1313
#include <aws/crt/StlAllocator.h>
1414
#include <aws/crt/Types.h>
15+
#include <aws/crt/Variant.h>
1516
#include <aws/crt/io/SocketOptions.h>
1617
#include <aws/crt/io/TlsOptions.h>
1718

@@ -373,69 +374,10 @@ namespace Aws
373374
RPC_ERROR
374375
};
375376

376-
/**
377-
* A wrapper for operation result.
378-
*/
379-
class AWS_EVENTSTREAMRPC_API TaggedResult
380-
{
381-
public:
382-
TaggedResult() noexcept;
383-
explicit TaggedResult(Crt::ScopedResource<AbstractShapeBase> response) noexcept;
384-
explicit TaggedResult(Crt::ScopedResource<OperationError> error) noexcept;
385-
explicit TaggedResult(RpcError rpcError) noexcept;
386-
TaggedResult(TaggedResult &&rhs) noexcept;
387-
TaggedResult &operator=(TaggedResult &&rhs) noexcept;
388-
~TaggedResult() noexcept;
389-
/**
390-
* @return true if the response is associated with an expected response;
391-
* false if the response is associated with an error.
392-
*/
393-
explicit operator bool() const noexcept;
377+
using EventstreamResultVariantType =
378+
Crt::Variant<Crt::ScopedResource<AbstractShapeBase>, Crt::ScopedResource<OperationError>, RpcError>;
394379

395-
/**
396-
* Get operation result.
397-
* @return A pointer to the resulting object in case of success, nullptr otherwise.
398-
*/
399-
AbstractShapeBase *GetOperationResponse() const noexcept;
400-
401-
/**
402-
* Get error for a failed operation.
403-
* @return A pointer to the error object in case of failure, nullptr otherwise.
404-
*/
405-
OperationError *GetOperationError() const noexcept;
406-
407-
/**
408-
* Get RPC-level error.
409-
* @return A pointer to the error object in case of RPC-level failure, nullptr otherwise.
410-
*/
411-
RpcError GetRpcError() const noexcept;
412-
413-
/**
414-
* Get the type of the result with which the operation has completed.
415-
* @return Result type.
416-
*/
417-
ResultType GetResultType() const noexcept { return m_responseType; }
418-
419-
private:
420-
union AWS_EVENTSTREAMRPC_API OperationResult
421-
{
422-
explicit OperationResult(Crt::ScopedResource<AbstractShapeBase> &&response) noexcept
423-
: m_response(std::move(response))
424-
{
425-
}
426-
explicit OperationResult(Crt::ScopedResource<OperationError> &&error) noexcept
427-
: m_error(std::move(error))
428-
{
429-
}
430-
OperationResult() noexcept : m_response(nullptr) {}
431-
~OperationResult() noexcept {}
432-
Crt::ScopedResource<AbstractShapeBase> m_response;
433-
Crt::ScopedResource<OperationError> m_error;
434-
};
435-
ResultType m_responseType;
436-
OperationResult m_operationResult;
437-
RpcError m_rpcError;
438-
};
380+
AWS_EVENTSTREAMRPC_API ResultType ResultVariantToResultType(const EventstreamResultVariantType &resultVariant);
439381

440382
using ExpectedResponseFactory = std::function<
441383
Crt::ScopedResource<AbstractShapeBase>(const Crt::StringView &payload, Crt::Allocator *allocator)>;
@@ -591,7 +533,7 @@ namespace Aws
591533
std::future<RpcError> Activate(
592534
const AbstractShapeBase *shape,
593535
OnMessageFlushCallback &&onMessageFlushCallback,
594-
std::function<void(TaggedResult &&)> &&onResultCallback) noexcept;
536+
std::function<void(EventstreamResultVariantType &&)> &&onResultCallback) noexcept;
595537

596538
/**
597539
* Sends a message on the stream

eventstream_rpc/source/EventStreamClient.cpp

Lines changed: 43 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,100 +1298,6 @@ namespace Aws
12981298
}
12991299
}
13001300

1301-
TaggedResult::TaggedResult(Crt::ScopedResource<AbstractShapeBase> operationResponse) noexcept
1302-
: m_responseType(OPERATION_RESPONSE), m_rpcError({})
1303-
{
1304-
m_operationResult.m_response = std::move(operationResponse);
1305-
}
1306-
1307-
TaggedResult::~TaggedResult() noexcept
1308-
{
1309-
if (m_responseType == OPERATION_RESPONSE)
1310-
{
1311-
m_operationResult.m_response.~unique_ptr();
1312-
}
1313-
else if (m_responseType == OPERATION_ERROR)
1314-
{
1315-
m_operationResult.m_error.~unique_ptr();
1316-
}
1317-
}
1318-
1319-
TaggedResult::TaggedResult(Crt::ScopedResource<OperationError> operationError) noexcept
1320-
: m_responseType(OPERATION_ERROR), m_rpcError({EVENT_STREAM_RPC_UNINITIALIZED, 0})
1321-
{
1322-
m_operationResult.m_error = std::move(operationError);
1323-
}
1324-
1325-
TaggedResult &TaggedResult::operator=(TaggedResult &&rhs) noexcept
1326-
{
1327-
m_responseType = rhs.m_responseType;
1328-
if (m_responseType == OPERATION_RESPONSE)
1329-
{
1330-
m_operationResult.m_response = std::move(rhs.m_operationResult.m_response);
1331-
}
1332-
else if (m_responseType == OPERATION_ERROR)
1333-
{
1334-
m_operationResult.m_error = std::move(rhs.m_operationResult.m_error);
1335-
}
1336-
m_rpcError = rhs.m_rpcError;
1337-
rhs.m_rpcError = {EVENT_STREAM_RPC_UNINITIALIZED, 0};
1338-
1339-
return *this;
1340-
}
1341-
1342-
TaggedResult::TaggedResult(RpcError rpcError) noexcept
1343-
: m_responseType(RPC_ERROR), m_operationResult(), m_rpcError(rpcError)
1344-
{
1345-
}
1346-
1347-
TaggedResult::TaggedResult() noexcept
1348-
: m_responseType(RPC_ERROR), m_operationResult(), m_rpcError({EVENT_STREAM_RPC_UNINITIALIZED, 0})
1349-
{
1350-
}
1351-
1352-
TaggedResult::TaggedResult(TaggedResult &&rhs) noexcept
1353-
{
1354-
m_responseType = rhs.m_responseType;
1355-
if (m_responseType == OPERATION_RESPONSE)
1356-
{
1357-
m_operationResult.m_response = std::move(rhs.m_operationResult.m_response);
1358-
}
1359-
else if (m_responseType == OPERATION_ERROR)
1360-
{
1361-
m_operationResult.m_error = std::move(rhs.m_operationResult.m_error);
1362-
}
1363-
m_rpcError = rhs.m_rpcError;
1364-
rhs.m_rpcError = {EVENT_STREAM_RPC_UNINITIALIZED, 0};
1365-
}
1366-
1367-
TaggedResult::operator bool() const noexcept
1368-
{
1369-
return m_responseType == OPERATION_RESPONSE;
1370-
}
1371-
1372-
AbstractShapeBase *TaggedResult::GetOperationResponse() const noexcept
1373-
{
1374-
return (m_responseType == OPERATION_RESPONSE) ? m_operationResult.m_response.get() : nullptr;
1375-
}
1376-
1377-
OperationError *TaggedResult::GetOperationError() const noexcept
1378-
{
1379-
return (m_responseType == OPERATION_ERROR) ? m_operationResult.m_error.get() : nullptr;
1380-
}
1381-
1382-
RpcError TaggedResult::GetRpcError() const noexcept
1383-
{
1384-
if (m_responseType == RPC_ERROR)
1385-
{
1386-
return m_rpcError;
1387-
}
1388-
else
1389-
{
1390-
/* Assume success since an application response or error was received. */
1391-
return {EVENT_STREAM_RPC_SUCCESS, 0};
1392-
}
1393-
}
1394-
13951301
bool StreamResponseHandler::OnStreamError(Crt::ScopedResource<OperationError> operationError, RpcError rpcError)
13961302
{
13971303
(void)operationError;
@@ -1488,17 +1394,15 @@ namespace Aws
14881394
ContinuationSharedState();
14891395

14901396
ContinuationStateType m_currentState;
1491-
ContinuationStateType m_desiredState;
14921397
struct aws_event_stream_rpc_client_continuation_token *m_continuation;
14931398
std::shared_ptr<OnMessageFlushCallbackContainer> m_activationCallbackContainer;
1494-
std::function<void(TaggedResult &&)> m_activationResponseCallback;
1399+
std::function<void(EventstreamResultVariantType &&)> m_activationResponseCallback;
14951400
std::shared_ptr<OnMessageFlushCallbackContainer> m_closeCallbackContainer;
14961401
};
14971402

14981403
ContinuationSharedState::ContinuationSharedState()
1499-
: m_currentState(ContinuationStateType::None), m_desiredState(ContinuationStateType::None),
1500-
m_continuation(nullptr), m_activationCallbackContainer(nullptr), m_activationResponseCallback(),
1501-
m_closeCallbackContainer(nullptr)
1404+
: m_currentState(ContinuationStateType::None), m_continuation(nullptr),
1405+
m_activationCallbackContainer(nullptr), m_activationResponseCallback(), m_closeCallbackContainer(nullptr)
15021406
{
15031407
}
15041408

@@ -1520,7 +1424,7 @@ namespace Aws
15201424
const Crt::Optional<Crt::ByteBuf> &payload,
15211425
MessageType messageType,
15221426
uint32_t messageFlags,
1523-
std::function<void(TaggedResult &&)> &&onResultCallback,
1427+
std::function<void(EventstreamResultVariantType &&)> &&onResultCallback,
15241428
OnMessageFlushCallback &&onMessageFlushCallback) noexcept;
15251429

15261430
std::future<RpcError> SendStreamMessage(
@@ -1636,7 +1540,6 @@ namespace Aws
16361540
{
16371541
m_continuationValid = false;
16381542
m_sharedState.m_currentState = ContinuationStateType::Closed;
1639-
m_sharedState.m_desiredState = ContinuationStateType::Closed;
16401543
}
16411544
}
16421545

@@ -1707,7 +1610,7 @@ namespace Aws
17071610
struct aws_event_stream_rpc_client_continuation_token *releaseContinuation = nullptr;
17081611
std::shared_ptr<OnMessageFlushCallbackContainer> closeCallbackContainer = nullptr;
17091612
std::shared_ptr<OnMessageFlushCallbackContainer> activationCallbackContainer = nullptr;
1710-
std::function<void(TaggedResult &&)> activationResponseCallback = nullptr;
1613+
std::function<void(EventstreamResultVariantType &&)> activationResponseCallback = nullptr;
17111614

17121615
// This block prevents streaming event callbacks from triggering after scope exit
17131616
{
@@ -1739,7 +1642,6 @@ namespace Aws
17391642
releaseContinuation = m_sharedState.m_continuation;
17401643
m_sharedState.m_continuation = nullptr;
17411644
m_sharedState.m_currentState = ContinuationStateType::Closed;
1742-
m_sharedState.m_desiredState = ContinuationStateType::Closed;
17431645
}
17441646

17451647
activationCallbackContainer = m_sharedState.m_activationCallbackContainer;
@@ -1757,7 +1659,7 @@ namespace Aws
17571659
if (activationResponseCallback)
17581660
{
17591661
activationResponseCallback(
1760-
TaggedResult(RpcError{EVENT_STREAM_RPC_CONTINUATION_CLOSED, AWS_ERROR_SUCCESS}));
1662+
EventstreamResultVariantType(RpcError{EVENT_STREAM_RPC_CONTINUATION_CLOSED, AWS_ERROR_SUCCESS}));
17611663
}
17621664

17631665
// Short-circuit and simulate both activate and close callbacks as necessary.
@@ -1795,7 +1697,7 @@ namespace Aws
17951697
const Crt::Optional<Crt::ByteBuf> &payload,
17961698
MessageType messageType,
17971699
uint32_t messageFlags,
1798-
std::function<void(TaggedResult &&)> &&onResultCallback,
1700+
std::function<void(EventstreamResultVariantType &&)> &&onResultCallback,
17991701
OnMessageFlushCallback &&onMessageFlushCallback) noexcept
18001702
{
18011703
AWS_FATAL_ASSERT(static_cast<bool>(onResultCallback));
@@ -1818,7 +1720,6 @@ namespace Aws
18181720
activateContinuation = m_sharedState.m_continuation;
18191721
aws_event_stream_rpc_client_continuation_acquire(activateContinuation);
18201722
m_sharedState.m_currentState = ContinuationStateType::PendingActivate;
1821-
m_sharedState.m_desiredState = ContinuationStateType::Activated;
18221723
m_sharedState.m_activationCallbackContainer = activationContainerWrapper->GetContainer();
18231724
m_sharedState.m_activationResponseCallback = std::move(onResultCallback);
18241725
}
@@ -1877,7 +1778,6 @@ namespace Aws
18771778
// ah shucks, we failed, rollback our optimistic shared state update
18781779
std::lock_guard<std::mutex> lock(m_sharedStateLock);
18791780
m_sharedState.m_currentState = ContinuationStateType::None;
1880-
m_sharedState.m_desiredState = ContinuationStateType::None;
18811781
m_sharedState.m_activationCallbackContainer = nullptr;
18821782
m_sharedState.m_activationResponseCallback = nullptr;
18831783

@@ -2080,13 +1980,12 @@ namespace Aws
20801980
struct aws_event_stream_rpc_client_continuation_token *releaseContinuation = nullptr;
20811981
std::shared_ptr<OnMessageFlushCallbackContainer> closeCallbackContainer = nullptr;
20821982
std::shared_ptr<OnMessageFlushCallbackContainer> activationCallbackContainer = nullptr;
2083-
std::function<void(TaggedResult &&)> activationResponseCallback;
1983+
std::function<void(EventstreamResultVariantType &&)> activationResponseCallback;
20841984

20851985
{
20861986
std::lock_guard<std::mutex> lock(m_sharedStateLock);
20871987

20881988
m_sharedState.m_currentState = ContinuationStateType::Closed;
2089-
m_sharedState.m_desiredState = ContinuationStateType::Closed;
20901989
releaseContinuation = m_sharedState.m_continuation;
20911990
m_sharedState.m_continuation = nullptr;
20921991

@@ -2103,7 +2002,7 @@ namespace Aws
21032002
if (activationResponseCallback)
21042003
{
21052004
activationResponseCallback(
2106-
TaggedResult(RpcError{EVENT_STREAM_RPC_CONTINUATION_CLOSED, AWS_ERROR_SUCCESS}));
2005+
EventstreamResultVariantType(RpcError{EVENT_STREAM_RPC_CONTINUATION_CLOSED, AWS_ERROR_SUCCESS}));
21072006
}
21082007

21092008
OnMessageFlushCallbackContainer::Complete(
@@ -2264,7 +2163,7 @@ namespace Aws
22642163
MessageResult result;
22652164
bool isResponse = false;
22662165
bool shouldClose = false;
2267-
std::function<void(TaggedResult &&)> activationResultCallback = nullptr;
2166+
std::function<void(EventstreamResultVariantType &&)> activationResultCallback = nullptr;
22682167

22692168
{
22702169
std::lock_guard<std::mutex> lock(m_sharedStateLock);
@@ -2280,7 +2179,18 @@ namespace Aws
22802179
if (result.m_statusCode == EVENT_STREAM_RPC_SUCCESS &&
22812180
result.m_message.value().m_route == EventStreamMessageRoutingType::Response)
22822181
{
2283-
m_sharedState.m_currentState = ContinuationStateType::Activated;
2182+
if ((messageArgs->message_flags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM) == 0)
2183+
{
2184+
m_sharedState.m_currentState = ContinuationStateType::Activated;
2185+
}
2186+
else
2187+
{
2188+
/*
2189+
* The underlying implementation is going to close underneath us. No need to send
2190+
* an empty terminate stream message, which older server versions can fail on.
2191+
*/
2192+
m_sharedState.m_currentState = ContinuationStateType::PendingClose;
2193+
}
22842194
}
22852195
else
22862196
{
@@ -2296,7 +2206,8 @@ namespace Aws
22962206
const auto &message = result.m_message.value();
22972207
if (message.m_route == EventStreamMessageRoutingType::Response)
22982208
{
2299-
activationResultCallback(TaggedResult(std::move(result.m_message.value().m_shape)));
2209+
activationResultCallback(
2210+
EventstreamResultVariantType(std::move(result.m_message.value().m_shape)));
23002211
}
23012212
else
23022213
{
@@ -2306,12 +2217,12 @@ namespace Aws
23062217
static_cast<OperationError *>(result.m_message.value().m_shape.release()),
23072218
[allocator](OperationError *shape) { Crt::Delete(shape, allocator); });
23082219

2309-
activationResultCallback(TaggedResult(std::move(errorResponse)));
2220+
activationResultCallback(EventstreamResultVariantType(std::move(errorResponse)));
23102221
}
23112222
}
23122223
else
23132224
{
2314-
activationResultCallback(TaggedResult(RpcError{result.m_statusCode, 0}));
2225+
activationResultCallback(EventstreamResultVariantType(RpcError{result.m_statusCode, 0}));
23152226
}
23162227
}
23172228

@@ -2439,7 +2350,7 @@ namespace Aws
24392350
std::future<RpcError> ClientOperation::Activate(
24402351
const AbstractShapeBase *shape,
24412352
OnMessageFlushCallback &&onMessageFlushCallback,
2442-
std::function<void(TaggedResult &&)> &&onResultCallback) noexcept
2353+
std::function<void(EventstreamResultVariantType &&)> &&onResultCallback) noexcept
24432354
{
24442355
Crt::List<EventStreamHeader> headers;
24452356
headers.emplace_back(
@@ -2478,5 +2389,21 @@ namespace Aws
24782389
0,
24792390
std::move(onMessageFlushCallback));
24802391
}
2392+
2393+
ResultType ResultVariantToResultType(const EventstreamResultVariantType &resultVariant)
2394+
{
2395+
if (resultVariant.holds_alternative<Crt::ScopedResource<AbstractShapeBase>>())
2396+
{
2397+
return OPERATION_RESPONSE;
2398+
}
2399+
else if (resultVariant.holds_alternative<Crt::ScopedResource<OperationError>>())
2400+
{
2401+
return OPERATION_ERROR;
2402+
}
2403+
else
2404+
{
2405+
return RPC_ERROR;
2406+
}
2407+
}
24812408
} /* namespace Eventstreamrpc */
24822409
} // namespace Aws

0 commit comments

Comments
 (0)