Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions fdbcli/ExcludeCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Future<bool> excludeServersAndLocalities(Reference<IDatabase> db,
}
}

Future<std::vector<std::string>> getExcludedServers(Reference<IDatabase> db) {
AsyncResult<std::vector<std::string>> getExcludedServers(Reference<IDatabase> db) {
Reference<ITransaction> tr = db->createTransaction();
while (true) {
Error err;
Expand All @@ -110,7 +110,7 @@ Future<std::vector<std::string>> getExcludedServers(Reference<IDatabase> db) {
}

// Get the list of excluded localities by reading the keys.
Future<std::vector<std::string>> getExcludedLocalities(Reference<IDatabase> db) {
AsyncResult<std::vector<std::string>> getExcludedLocalities(Reference<IDatabase> db) {
Reference<ITransaction> tr = db->createTransaction();
while (true) {
Error err;
Expand All @@ -133,7 +133,7 @@ Future<std::vector<std::string>> getExcludedLocalities(Reference<IDatabase> db)
}
}

Future<std::vector<std::string>> getFailedServers(Reference<IDatabase> db) {
AsyncResult<std::vector<std::string>> getFailedServers(Reference<IDatabase> db) {
Reference<ITransaction> tr = db->createTransaction();
while (true) {
Error err;
Expand All @@ -157,7 +157,7 @@ Future<std::vector<std::string>> getFailedServers(Reference<IDatabase> db) {
}

// Get the list of failed localities by reading the keys.
Future<std::vector<std::string>> getFailedLocalities(Reference<IDatabase> db) {
AsyncResult<std::vector<std::string>> getFailedLocalities(Reference<IDatabase> db) {
Reference<ITransaction> tr = db->createTransaction();
while (true) {
Error err;
Expand Down
76 changes: 76 additions & 0 deletions fdbrpc/CoroTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,29 @@ struct Tracker {
}
};

AsyncResult<Tracker> immediateAsyncResultTracker() {
co_return Tracker{};
}

AsyncResult<Tracker> delayedAsyncResultTracker(Future<Void> signal) {
co_await signal;
co_return Tracker{};
}

AsyncResult<int> immediateAsyncResultInt(int value) {
co_return value;
}

AsyncResult<int> delayedAsyncResultInt(Future<Void> signal, int value) {
co_await signal;
co_return value;
}

AsyncResult<int> failingAsyncResultInt(Future<Void> signal) {
co_await signal;
throw io_error();
}

} // namespace

TEST_CASE("/flow/coro/PromiseStream/move") {
Expand Down Expand Up @@ -1202,6 +1225,23 @@ TEST_CASE("/flow/coro/PromiseStream/move2") {
ASSERT(movedTracker.copied == 0);
}

TEST_CASE("/flow/coro/AsyncResult/move") {
{
Tracker tracker = co_await immediateAsyncResultTracker();
ASSERT(!tracker.moved);
ASSERT(tracker.copied == 0);
}

{
Promise<Void> signal;
AsyncResult<Tracker> result = delayedAsyncResultTracker(signal.getFuture());
signal.send(Void());
Tracker tracker = co_await result;
ASSERT(!tracker.moved);
ASSERT(tracker.copied == 0);
}
}

namespace {

constexpr double mutexTestDelay = 0.00001;
Expand Down Expand Up @@ -2081,3 +2121,39 @@ TEST_CASE("/flow/coro/raceStreamSuccess") {
ASSERT_EQ(std::get<0>(result), 13);
co_return;
}

TEST_CASE("/flow/coro/raceAsyncResultReady") {
Future<std::variant<int, std::string>> raced = race(immediateAsyncResultInt(17), Future<std::string>("later"));
ASSERT(raced.isReady());
auto result = raced.get();
ASSERT_EQ(result.index(), 0);
ASSERT_EQ(std::get<0>(result), 17);
return Void();
}

TEST_CASE("/flow/coro/raceAsyncResultSuccess") {
Promise<Void> signal;
Promise<std::string> stringPromise;
Future<std::variant<int, std::string>> raced =
race(delayedAsyncResultInt(signal.getFuture(), 19), stringPromise.getFuture());
signal.send(Void());
auto result = co_await raced;
ASSERT_EQ(result.index(), 0);
ASSERT_EQ(std::get<0>(result), 19);
co_return;
}

TEST_CASE("/flow/coro/raceAsyncResultError") {
Promise<Void> signal;
Promise<std::string> stringPromise;
Future<std::variant<int, std::string>> raced =
race(failingAsyncResultInt(signal.getFuture()), stringPromise.getFuture());
signal.send(Void());
try {
co_await raced;
ASSERT(false);
} catch (Error const& e) {
ASSERT_EQ(e.code(), error_code_io_error);
}
co_return;
}
68 changes: 56 additions & 12 deletions flow/include/flow/CoroUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,48 @@

namespace coro {

template <class Parent, int Idx, class ValueType>
struct ActorAsyncResultCallback : AsyncResultCallback<ValueType> {
AsyncResultState<ValueType>* state = nullptr;

void bind(AsyncResult<ValueType>& result) { state = result.state; }

void remove() {
if (state) {
state->clearCallback(this);
state = nullptr;
}
}

void fire(ValueType const& value) override {
#ifdef ENABLE_SAMPLING
LineageScope _(static_cast<Parent*>(this)->lineageAddr());
#endif
static_cast<Parent*>(this)->a_callback_fire(this, value);
}

void fire(ValueType&& value) override {
#ifdef ENABLE_SAMPLING
LineageScope _(static_cast<Parent*>(this)->lineageAddr());
#endif
static_cast<Parent*>(this)->a_callback_fire(this, std::move(value));
}

void error(Error e) override {
#ifdef ENABLE_SAMPLING
LineageScope _(static_cast<Parent*>(this)->lineageAddr());
#endif
static_cast<Parent*>(this)->a_callback_error(this, e);
}
};

template <class Parent, int Idx, class F>
using ConditionalActorCallback = std::conditional_t<GetFutureTypeV<F> == FutureType::Future,
ActorCallback<Parent, Idx, FutureReturnTypeT<F>>,
ActorSingleCallback<Parent, Idx, FutureReturnTypeT<F>>>;
using ConditionalActorCallback =
std::conditional_t<GetFutureTypeV<F> == FutureType::Future,
ActorCallback<Parent, Idx, FutureReturnTypeT<F>>,
std::conditional_t<GetFutureTypeV<F> == FutureType::FutureStream,
ActorSingleCallback<Parent, Idx, FutureReturnTypeT<F>>,
ActorAsyncResultCallback<Parent, Idx, FutureReturnTypeT<F>>>>;

template <class Parent, int Idx, class... Args>
struct ChooseImplCallback;
Expand Down Expand Up @@ -216,25 +254,27 @@ template <class... Futures>
using RaceResult = std::variant<FutureReturnTypeT<std::decay_t<Futures>>...>;

template <std::size_t Idx, class Result, class F>
Future<Result> raceReadyResult(F const& future) {
Future<Result> raceReadyResult(F&& future) {
if (future.isError()) {
return future.getError();
}
if constexpr (GetFutureTypeV<F> == FutureType::Future) {
if constexpr (GetFutureTypeV<std::remove_cvref_t<F>> == FutureType::Future) {
return Result(std::in_place_index<Idx>, future.get());
} else {
} else if constexpr (GetFutureTypeV<std::remove_cvref_t<F>> == FutureType::FutureStream) {
auto fs = future;
return Result(std::in_place_index<Idx>, fs.pop());
} else {
return Result(std::in_place_index<Idx>, std::forward<F>(future).get());
}
}

template <std::size_t Idx, class Result, class First, class... Rest>
Future<Result> raceReady(First const& first, Rest const&... rest) {
Future<Result> raceReady(First&& first, Rest&&... rest) {
if (first.isReady()) {
return raceReadyResult<Idx, Result>(first);
return raceReadyResult<Idx, Result>(std::forward<First>(first));
}
if constexpr (sizeof...(Rest) > 0) {
return raceReady<Idx + 1, Result>(rest...);
return raceReady<Idx + 1, Result>(std::forward<Rest>(rest)...);
}
return Future<Result>();
}
Expand All @@ -256,16 +296,20 @@ struct RaceImplCallback<Parent, Idx, F, Futures...>
if constexpr (futureType == FutureType::Future) {
StrictFuture<ValueType> sf = std::get<Idx>(getParent()->futures);
sf.addCallbackAndClear(static_cast<ThisCallback*>(this));
} else {
} else if constexpr (futureType == FutureType::FutureStream) {
auto sf = std::get<Idx>(getParent()->futures);
sf.addCallbackAndClear(static_cast<ThisCallback*>(this));
} else {
ThisCallback::bind(std::get<Idx>(getParent()->futures));
std::move(std::get<Idx>(getParent()->futures)).addCallbackAndClear(static_cast<ThisCallback*>(this));
}
if constexpr (sizeof...(Futures) > 0) {
RaceImplCallback<Parent, Idx + 1, Futures...>::registerCallbacks();
}
}

void a_callback_fire(ThisCallback*, ValueType const& value) { getParent()->template finish<Idx>(value); }
void a_callback_fire(ThisCallback*, ValueType&& value) { getParent()->template finish<Idx>(std::move(value)); }

void a_callback_error(ThisCallback*, Error e) { getParent()->fail(e); }

Expand Down Expand Up @@ -299,10 +343,10 @@ struct RaceImplActor final : Actor<Result>,
}

template <std::size_t Idx, class T>
void finish(T const& value) {
void finish(T&& value) {
this->actor_wait_state = ACTOR_WAIT_STATE_NOT_WAITING;
RaceImplCallback<RaceImplActor<Result, Futures...>, 0, Futures...>::removeCallbacks();
this->SAV<Result>::sendAndDelPromiseRef(Result(std::in_place_index<Idx>, value));
this->SAV<Result>::sendAndDelPromiseRef(Result(std::in_place_index<Idx>, std::forward<T>(value)));
}

void fail(Error e) {
Expand Down
Loading