Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow methods returning IIterable<T> to be used as synchronous generators #361

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
80 changes: 53 additions & 27 deletions include/wil/cppwinrt_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,16 +321,14 @@ namespace wil
namespace wil::details
{
template<typename TResult>
struct iterable_promise : winrt::implements<
iterable_promise<TResult>,
winrt::Windows::Foundation::Collections::IIterable<TResult>,
struct iterator_promise : winrt::implements<
iterator_promise<TResult>,
winrt::Windows::Foundation::Collections::IIterator<TResult>
>
{
private:
enum class IterationStatus
{
Initial,
Producing,
Value,
Done
Expand All @@ -344,18 +342,18 @@ namespace wil::details
if (remaining == 0)
{
std::atomic_thread_fence(std::memory_order_acquire);
coroutine_handle<iterable_promise>::from_promise(*this).destroy();
coroutine_handle<iterator_promise>::from_promise(*this).destroy();
}

return remaining;
}

winrt::Windows::Foundation::Collections::IIterable<TResult> get_return_object() noexcept
winrt::Windows::Foundation::Collections::IIterator<TResult> get_return_object() noexcept
{
return { winrt::get_abi(static_cast<winrt::Windows::Foundation::Collections::IIterable<TResult> const&>(*this)), winrt::take_ownership_from_abi };
return { winrt::get_abi(static_cast<winrt::Windows::Foundation::Collections::IIterator<TResult> const&>(*this)), winrt::take_ownership_from_abi };
}

suspend_always initial_suspend() const noexcept
suspend_never initial_suspend() const noexcept
{
return {};
}
Expand Down Expand Up @@ -421,7 +419,7 @@ namespace wil::details

uint32_t produce_values(winrt::array_view<TResult> const& view)
{
if (m_status != IterationStatus::Initial && m_status != IterationStatus::Value)
if (m_status != IterationStatus::Value)
{
return 0;
}
Expand All @@ -430,23 +428,11 @@ namespace wil::details
m_current = m_values.begin();
m_status = IterationStatus::Producing;

coroutine_handle<iterable_promise>::from_promise(*this).resume();
coroutine_handle<iterator_promise>::from_promise(*this).resume();

return static_cast<uint32_t>(m_current - m_values.begin());
}

winrt::Windows::Foundation::Collections::IIterator<TResult> First()
{
if (m_status != IterationStatus::Initial)
{
throw winrt::hresult_changed_state();
}

MoveNext();

return *this;
}

bool HasCurrent() const noexcept
{
return m_status == IterationStatus::Value;
Expand Down Expand Up @@ -487,11 +473,51 @@ namespace wil::details
}

private:
IterationStatus m_status{ IterationStatus::Initial };
winrt::array_view<TResult> m_values;
TResult* m_current{ nullptr };
IterationStatus m_status{ IterationStatus::Producing };
winrt::array_view<TResult> m_values{ &m_last_value, 1 };
TResult* m_current{ &m_last_value };
TResult m_last_value{ empty<TResult>() };
};

template<typename TResult, typename Func, typename... Args>
struct iterable_iterator_helper : winrt::implements<
iterable_iterator_helper<TResult, Func, Args...>,
winrt::Windows::Foundation::Collections::IIterable<TResult>
>
{
iterable_iterator_helper(Func&& func, Args&&... args) :
m_func{ std::forward<Func>(func) },
m_args{ std::forward<Args>(args)... }
{
}

auto First()
{
return std::apply(m_func, m_args);
}

private:
Func m_func;
std::tuple<Args...> m_args;
};

template<typename>
struct iterator_result;

template<typename TResult>
struct iterator_result<winrt::Windows::Foundation::Collections::IIterator<TResult>>
{
using type = TResult;
};
}

namespace wil
{
template<typename Func, typename... Args, typename TResult = typename details::iterator_result<std::invoke_result_t<Func, Args...>>::type>
winrt::Windows::Foundation::Collections::IIterable<TResult> make_iterable_from_iterator(Func&& func, Args&&... args)
{
return winrt::make<details::iterable_iterator_helper<TResult, Func, Args...>>(std::forward<Func>(func), std::forward<Args>(args)...);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding a way to keep a strong or weak reference to a class here, as well as a pointer. Maybe something similar to C++/WinRT's delegate implementation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Func accepts those arguments, it's already possible (e.g. calling a member function with a strong reference that is stored in the IIterable<T> helper with wil::make_iterable_from_iterator(&Foo::Bar, get_strong()). That doesn't work with weak references, but what would the semantics be if the weak pointer's object has been destroyed? Should First return an empty iterator?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A member function pointer won't take a strong reference as argument directly via std::invoke/std::apply.

Good point reguarding weak references, lets just ignore that for now.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah looks like std::invoke will try (*arg).(*f)(), I did not know that. Nothing to do here, then.

}
}

#ifdef __cpp_lib_coroutine
Expand All @@ -501,9 +527,9 @@ namespace std::experimental
#endif
{
template<typename T, typename... Args>
struct coroutine_traits<winrt::Windows::Foundation::Collections::IIterable<T>, Args...>
struct coroutine_traits<winrt::Windows::Foundation::Collections::IIterator<T>, Args...>
{
using promise_type = wil::details::iterable_promise<T>;
using promise_type = wil::details::iterator_promise<T>;
};
}
#endif
Expand Down
72 changes: 39 additions & 33 deletions tests/CppWinRTTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ TEST_CASE("CppWinRTTests::ResumeForegroundTests", "[cppwinrt]")

namespace test
{
winrt::Windows::Foundation::Collections::IIterable<winrt::hstring> hello_world_generator()
winrt::Windows::Foundation::Collections::IIterator<winrt::hstring> hello_world_generator()
{
co_yield L"Hello";
co_yield L"World!";
Expand Down Expand Up @@ -694,8 +694,7 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")
{
SECTION("Hello World")
{
auto generator = test::hello_world_generator();
auto iterator = generator.First();
auto iterator = test::hello_world_generator();
REQUIRE(iterator.HasCurrent());
REQUIRE(iterator.Current() == L"Hello");

Expand All @@ -709,15 +708,14 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")

SECTION("Value types")
{
auto generator = []() -> winrt::Windows::Foundation::Collections::IIterable<int>
auto iterator = []() -> winrt::Windows::Foundation::Collections::IIterator<int>
{
for (int i = 0; i < 10; ++i)
{
co_yield i;
}
}();

auto iterator = generator.First();
REQUIRE(iterator.HasCurrent());
REQUIRE(iterator.MoveNext());
REQUIRE(iterator.Current() == 1);
Expand All @@ -732,17 +730,10 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")
REQUIRE(!iterator.HasCurrent());
}

SECTION("First can only be called once")
{
auto generator = test::hello_world_generator();
auto iterator = generator.First();
REQUIRE_THROWS_AS(generator.First(), winrt::hresult_changed_state);
}

SECTION("GetMany")
{
{
auto iterator = test::hello_world_generator().First();
auto iterator = test::hello_world_generator();

std::array<winrt::hstring, 2> values;
REQUIRE(iterator.GetMany(values) == 2);
Expand All @@ -752,7 +743,7 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")
}

{
auto iterator = test::hello_world_generator().First();
auto iterator = test::hello_world_generator();
std::array<winrt::hstring, 1> values;
REQUIRE(iterator.GetMany(values) == 1);
REQUIRE(values[0] == L"Hello");
Expand All @@ -767,7 +758,7 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")

SECTION("Coroutine destruction")
{
auto set_true_on_destruction_generator = [](test::set_true_on_destruction) -> winrt::Windows::Foundation::Collections::IIterable<winrt::hstring>
auto set_true_on_destruction_generator = [](test::set_true_on_destruction) -> winrt::Windows::Foundation::Collections::IIterator<winrt::hstring>
{
co_yield L"Hello";
co_yield L"World!";
Expand All @@ -779,26 +770,11 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")
}

REQUIRE(destroyed);

destroyed = false;
{
winrt::Windows::Foundation::Collections::IIterator<winrt::hstring> iterator;
{
auto generator = set_true_on_destruction_generator(destroyed);
iterator = generator.First();
}

REQUIRE(!destroyed);
REQUIRE(iterator.HasCurrent());
REQUIRE(iterator.Current() == L"Hello");
}

REQUIRE(destroyed);
}

SECTION("Coroutine destruction with exception")
{
auto set_true_on_destruction_generator = [](test::set_true_on_destruction) -> winrt::Windows::Foundation::Collections::IIterable<winrt::hstring>
auto set_true_on_destruction_generator = [](test::set_true_on_destruction) -> winrt::Windows::Foundation::Collections::IIterator<winrt::hstring>
{
co_yield L"Hello";

Expand All @@ -808,14 +784,44 @@ TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]")
};

bool destroyed = false;
auto iterator = set_true_on_destruction_generator(destroyed).First();
auto iterator = set_true_on_destruction_generator(destroyed);
REQUIRE_THROWS_AS(iterator.MoveNext(), winrt::hresult_invalid_argument);
}

SECTION("make_iterable_from_iterator")
{
auto generator = wil::make_iterable_from_iterator(&test::hello_world_generator);
auto iterator = generator.First();

REQUIRE(iterator.Current() == L"Hello");
REQUIRE(iterator.MoveNext());
REQUIRE(iterator.Current() == L"World!");
REQUIRE(!iterator.MoveNext());

auto iterator2 = generator.First();
REQUIRE(iterator2.Current() == L"Hello");
REQUIRE(iterator2.MoveNext());
REQUIRE(iterator2.Current() == L"World!");
REQUIRE(!iterator2.MoveNext());
}

SECTION("make_iterable_from_iterator with arguments")
{
auto ptr = std::make_unique<int>(3);
auto const_ref_generator = wil::make_iterable_from_iterator([](const std::unique_ptr<int> &ptr) -> winrt::Windows::Foundation::Collections::IIterator<int>
{
co_yield *ptr;
}, ptr);

REQUIRE(const_ref_generator.First().Current() == 3);
*ptr = 4;
REQUIRE(const_ref_generator.First().Current() == 4);
}

SECTION("Range-based for loop")
{
std::wstring result;
for (const auto &i : test::hello_world_generator())
for (const auto &i : wil::make_iterable_from_iterator(&test::hello_world_generator))
{
result += i;
}
Expand Down