diff --git a/include/trompeloeil/coro.hpp b/include/trompeloeil/coro.hpp index 3d6f50b..1e8f51f 100644 --- a/include/trompeloeil/coro.hpp +++ b/include/trompeloeil/coro.hpp @@ -14,6 +14,10 @@ #ifndef TROMPELOEIL_CORO_HPP #define TROMPELOEIL_CORO_HPP +#if __cplusplus < 202002L +# error "C++20 is required" +#endif + #if defined(__cpp_impl_coroutine) # define TROMPELOEIL_COROUTINES_SUPPORTED 1 #else @@ -21,6 +25,7 @@ #endif #ifdef TROMPELOEIL_COROUTINES_SUPPORTED +#include #ifndef TROMPELOEIL_MOCK_HPP_ #include "mock.hpp" @@ -45,14 +50,22 @@ namespace trompeloeil { static auto func() { - if constexpr (requires {std::declval().operator co_await();}) + if constexpr (requires(T coro){ coro.operator co_await(); }) { return type_wrapper().operator co_await().await_resume())>{}; } else + if constexpr (requires(T coro){ coro.await_resume(); }) { return type_wrapper().await_resume())>{}; } + else + { + static_assert( + std::ranges::input_range, + "non-awaitable coroutine shall be a range"); + return type_wrapper>{}; + } } using type = typename decltype(func())::type; }; diff --git a/test/test_co_mock.cpp b/test/test_co_mock.cpp index ba7aed2..ed67e39 100644 --- a/test/test_co_mock.cpp +++ b/test/test_co_mock.cpp @@ -25,6 +25,11 @@ #include "test_reporter.hpp" #include +#include +#ifdef __cpp_lib_generator +#include +#endif + using trompeloeil::_; namespace { @@ -35,6 +40,10 @@ namespace { MAKE_MOCK0 (voidret, coro::task()); MAKE_MOCK1 (unique, coro::task(iptr)); MAKE_MOCK0 (gen, coro::generator()); + +#ifdef __cpp_lib_generator + MAKE_MOCK0 (stdgen, std::generator()); +#endif // __cpp_lib_generator }; } @@ -198,4 +207,46 @@ TEST_CASE_METHOD( REQUIRE(v == 3); REQUIRE(reports.empty()); } + +#ifdef __cpp_lib_generator +TEST_CASE_METHOD( + Fixture, + "CO_YIELD with std::generator", + "[coro]") +{ + co_mock m; + REQUIRE_CALL(m, stdgen()) + .CO_YIELD(5) + .CO_YIELD(8) + .CO_YIELD(3) + .CO_YIELD(0) + .CO_RETURN(); + + auto gen = m.stdgen(); + + SECTION("as iterator") + { + auto it = std::ranges::begin(gen); + REQUIRE(*it == 5); + ++it; + REQUIRE(it != std::ranges::end(gen)); + REQUIRE(*it == 8); + ++it; + REQUIRE(it != std::ranges::end(gen)); + REQUIRE(*it == 3); + ++it; + REQUIRE(it != std::ranges::end(gen)); + REQUIRE(*it == 0); + ++it; + REQUIRE(it == std::ranges::end(gen)); + } + + SECTION("as range") + { + REQUIRE(std::ranges::equal(gen, std::array{5, 8, 3, 0})); + } + + REQUIRE(reports.empty()); +} +#endif // __cpp_lib_generator #endif