Skip to content

refactoring async_mutex #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/components/async_mutex/async_mutex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@

namespace NComponents {

AsyncMutex::AsyncMutex() : event() {}
AsyncMutex::AsyncMutex() : mutex_impl() {}

Event& AsyncMutex::lock() {
return event;
AsyncMutexCoroImpl& AsyncMutex::lock() {
return mutex_impl;
}

AsyncMutex::~AsyncMutex() {
assert(event.waiters.empty());
// assert(mutex_impl.waiters.empty());
}

void AsyncMutex::unlock() {
event.Unlock();
mutex_impl.Unlock();
}

} // namespace NComponents
6 changes: 3 additions & 3 deletions src/components/async_mutex/async_mutex.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <atomic>
#include <coroutine>

#include <components/async_mutex/event.h>
#include <components/async_mutex/async_mutex_coro_impl.h>
#include <components/async_mutex/resumable_no_own.h>

namespace NComponents {
Expand All @@ -17,11 +17,11 @@ class AsyncMutex final {
AsyncMutex();
~AsyncMutex();

Event& lock();
AsyncMutexCoroImpl& lock();
void unlock();

private:
Event event;
AsyncMutexCoroImpl mutex_impl;
};

} // namespace NComponents
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
//
// Created by konstantin on 19.07.24.
//
#include "event.h"

#include <iostream>
#include <syncstream>

#include "async_mutex_coro_impl.h"

namespace NComponents {

bool Event::TryLock() {
bool AsyncMutexCoroImpl::TryLock() {
if (lock_flag) {
std::osyncstream(std::cout)
<< "[Event::TryLock][thread_id=" << std::this_thread::get_id()
<< "[AsyncMutexCoroImpl::TryLock][thread_id=" << std::this_thread::get_id()
<< "] lock_flag was locked. Need park." << std::endl;
return false;
}

std::osyncstream(std::cout)
<< "[Event::TryLock][thread_id=" << std::this_thread::get_id()
<< "[AsyncMutexCoroImpl::TryLock][thread_id=" << std::this_thread::get_id()
<< "] lock lock_flag." << std::endl;
lock_flag = true;
return true;
}

void Event::Unlock() {
void AsyncMutexCoroImpl::Unlock() {
std::unique_lock lock(spinlock);

std::osyncstream(std::cout)
<< "[Event::Unlock][thread_id=" << std::this_thread::get_id()
<< "[AsyncMutexCoroImpl::Unlock][thread_id=" << std::this_thread::get_id()
<< "] call" << std::endl;

if (!waiters.empty()) {
std::cout << "[Event::Unlock] Waiters size=" << waiters.size()
std::cout << "[AsyncMutexCoroImpl::Unlock] Waiters size=" << waiters.size()
<< ", wake up first" << std::endl;
MutexAwaiter waiter = std::move(waiters.front());
waiters.pop_front();
Expand All @@ -41,26 +41,25 @@ void Event::Unlock() {
return;
}

std::cout << "[Event::Unlock] Waiters empty. Set lock_flag to false"
std::cout << "[AsyncMutexCoroImpl::Unlock] Waiters empty. Set lock_flag to false"
<< std::endl;
lock_flag = false;
}

MutexAwaiter Event::operator co_await() {
// std::unique_lock lock(spinlock);
return MutexAwaiter{*this, spinlock};
}

void Event::ParkAwaiter(MutexAwaiter* awaiter) {
void AsyncMutexCoroImpl::ParkAwaiter(MutexAwaiter* awaiter) {
assert(awaiter);
// assert(awaiter->HasLock());

std::osyncstream(std::cout)
<< "[Event::ParkAwaiter][thead_id=" << std::this_thread::get_id()
<< "[AsyncMutexCoroImpl::ParkAwaiter][thead_id=" << std::this_thread::get_id()
<< "] add waiter, new size=" << waiters.size() + 1 << std::endl;

waiters.emplace_back(std::move(*awaiter));
waiters.back().ReleaseLock();
}

MutexAwaiter AsyncMutexCoroImpl::operator co_await() {
MutexAwaiter::LockGuard guard(spinlock);
return MutexAwaiter{*this, std::move(guard)};
}

} // namespace NComponents
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

namespace NComponents {

struct Event final {
class AsyncMutexCoroImpl final {
mutable NSync::SpinLock spinlock;
bool lock_flag{};
std::list<MutexAwaiter> waiters;

Event() = default;
public:

AsyncMutexCoroImpl() = default;

bool TryLock();

Expand All @@ -33,9 +35,9 @@ struct Event final {
return lock_flag;
}

MutexAwaiter operator co_await();

void ParkAwaiter(MutexAwaiter* awaiter);

MutexAwaiter operator co_await();
};

} // namespace NComponents
26 changes: 12 additions & 14 deletions src/components/async_mutex/mutex_awaiter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@
#include <syncstream>
#include <thread>

#include <components/async_mutex/event.h>
#include <components/async_mutex/async_mutex_coro_impl.h>

namespace NComponents {

MutexAwaiter::MutexAwaiter(Event& event, NSync::SpinLock& guard)
: event(event), guard(guard) {

guard.lock();
MutexAwaiter::MutexAwaiter(AsyncMutexCoroImpl& event, LockGuard&& guard)
: event(event), guard(std::move(guard)) {
std::osyncstream(std::cout) << *this << " create with guard." << std::endl;
}

MutexAwaiter::MutexAwaiter(MutexAwaiter&& o) noexcept
: event(o.event), guard(o.guard), coro(o.coro) {
: event(o.event), coro(o.coro) {
o.coro = nullptr;
if (auto* p = o.guard.release(); p) {
guard = MutexAwaiter::LockGuard(*p, std::adopt_lock);
}
}

MutexAwaiter::~MutexAwaiter() {
Expand All @@ -32,12 +33,11 @@ MutexAwaiter::~MutexAwaiter() {
void MutexAwaiter::ReleaseLock() const {
std::osyncstream(std::cout)
<< *this << "[await_suspend] release guard." << std::endl;
guard.unlock();
}

//bool MutexAwaiter::HasLock() const {
// return guard.owns_lock();
//}
if (auto* p = guard.release(); p) {
p->unlock();
}
}

bool MutexAwaiter::await_ready() const {
const bool lock_own = event.TryLock();
Expand All @@ -59,9 +59,7 @@ void MutexAwaiter::await_suspend(std::coroutine_handle<> coro_) noexcept {

void MutexAwaiter::await_resume() const noexcept {
std::osyncstream(std::cout)
<< *this
<< "[await_resume] call and just resume, status flag=" << event.IsSet()
<< std::endl;
<< *this << "[await_resume] call and just resume." << std::endl;
}

std::ostream& operator<<(std::ostream& stream, const MutexAwaiter& /*w*/) {
Expand Down
17 changes: 9 additions & 8 deletions src/components/async_mutex/mutex_awaiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,28 @@

namespace NComponents {

struct Event;
using LockGuard = std::unique_lock<NSync::SpinLock>;
class AsyncMutexCoroImpl;

class MutexAwaiter final {
Event& event;
NSync::SpinLock& guard;
std::coroutine_handle<> coro{};

public:
MutexAwaiter(Event& event, NSync::SpinLock& guard);
using LockGuard = std::unique_lock<NSync::SpinLock>;

MutexAwaiter(AsyncMutexCoroImpl& event, LockGuard&& guard);
MutexAwaiter(MutexAwaiter&& o) noexcept;

~MutexAwaiter();

void Resume() const { coro.resume(); }
void ReleaseLock() const;
// bool HasLock() const;

bool await_ready() const;
void await_suspend(std::coroutine_handle<>) noexcept;
void await_resume() const noexcept;

private:
AsyncMutexCoroImpl& event;
mutable LockGuard guard;
std::coroutine_handle<> coro{};
};

std::ostream& operator<<(std::ostream& stream, const MutexAwaiter& w);
Expand Down
10 changes: 6 additions & 4 deletions src/components/async_mutex/resumable_no_own.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace NComponents {

class AsyncMutexCoroImpl;

struct ResumableNoOwn {
struct promise_type {
std::suspend_never initial_suspend() const noexcept { return {}; }
Expand All @@ -23,10 +25,10 @@ struct ResumableNoOwn {
};

ResumableNoOwn(std::coroutine_handle<promise_type> /*handle*/) {
// std::osyncstream(std::cout)
// << "[ResumableNoOwn][this=" << this
// << "][thread_id=" << std::this_thread::get_id() << "] create"
// << std::endl;
std::osyncstream(std::cout)
<< "[ResumableNoOwn][this=" << this
<< "][thread_id=" << std::this_thread::get_id() << "] create"
<< std::endl;
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/components/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ components_sources = files(
'lock_free/hazard/hazard_manager.cpp',
'lock_free/hazard/thread_state.cpp',
'async_mutex/async_mutex.cpp',
'async_mutex/event.cpp',
'async_mutex/async_mutex_coro_impl.cpp',
'async_mutex/mutex_awaiter.cpp',
'async_mutex/resumable_no_own.cpp',
)
Expand Down
33 changes: 17 additions & 16 deletions test/components/test_async_mutex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#include "gtest/gtest.h"

#include <condition_variable>
#include <thread>
#include <latch>
#include <thread>

#include <components/async_mutex/async_mutex.h>

Expand All @@ -28,17 +28,21 @@ struct TestSyncIncrement final {
size_t number{};

std::latch latch;
const size_t count_iterations;

explicit TestSyncIncrement(size_t n) : latch(n) {}
explicit TestSyncIncrement(size_t n, size_t count_iterations)
: latch(n), count_iterations(count_iterations) {}

NComponents::ResumableNoOwn run() {
{
std::unique_lock lock(cv_wait);
while (!wait_flag) cv.wait(lock);
}
co_await mutex.lock();
number += 1;
mutex.unlock();
for (size_t i = 0; i < count_iterations; i++) {
co_await mutex.lock();
number += 1;
mutex.unlock();
}

latch.count_down();
}
Expand All @@ -52,9 +56,7 @@ struct TestSyncIncrement final {
mutex.unlock();
}

void Wait() {
latch.wait();
}
void Wait() { latch.wait(); }

private:
std::mutex cv_wait;
Expand All @@ -73,13 +75,11 @@ struct TestWait final {

std::this_thread::sleep_for(5s);

if (status)
throw std::logic_error("no wait");
if (status) throw std::logic_error("no wait");

mutex.unlock();

if (!status)
throw std::logic_error("status is false");
if (!status) throw std::logic_error("status is false");
}

private:
Expand All @@ -99,20 +99,21 @@ TEST(TestAsyncMutex, JustWorking) {
}

TEST(TestAsyncMutex, SyncIncrementInThreads) {
constexpr size_t MaxCount = 32;
constexpr size_t MaxCountThreads = 32;
constexpr size_t CountIterations = 100;

TestSyncIncrement worker(MaxCount);
TestSyncIncrement worker(MaxCountThreads, CountIterations);
{
std::vector<std::jthread> workers;
for (size_t i = 0; i < MaxCount; i++) {
for (size_t i = 0; i < MaxCountThreads; i++) {
workers.emplace_back(&TestSyncIncrement::run, &worker);
}

worker.StartAll();
worker.Wait();
}

ASSERT_EQ(worker.number, MaxCount);
ASSERT_EQ(worker.number, MaxCountThreads * CountIterations);
}

TEST(TestAsyncMutex, ThreadWait) {
Expand Down
Loading