Skip to content

Commit e6e3e4c

Browse files
authored
refactoring async_mutex (#55)
* refactoring async_mutex
1 parent 1b60fde commit e6e3e4c

9 files changed

+75
-72
lines changed

src/components/async_mutex/async_mutex.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@
1212

1313
namespace NComponents {
1414

15-
AsyncMutex::AsyncMutex() : event() {}
15+
AsyncMutex::AsyncMutex() : mutex_impl() {}
1616

17-
Event& AsyncMutex::lock() {
18-
return event;
17+
AsyncMutexCoroImpl& AsyncMutex::lock() {
18+
return mutex_impl;
1919
}
2020

2121
AsyncMutex::~AsyncMutex() {
22-
assert(event.waiters.empty());
22+
// assert(mutex_impl.waiters.empty());
2323
}
2424

2525
void AsyncMutex::unlock() {
26-
event.Unlock();
26+
mutex_impl.Unlock();
2727
}
2828

2929
} // namespace NComponents

src/components/async_mutex/async_mutex.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <atomic>
88
#include <coroutine>
99

10-
#include <components/async_mutex/event.h>
10+
#include <components/async_mutex/async_mutex_coro_impl.h>
1111
#include <components/async_mutex/resumable_no_own.h>
1212

1313
namespace NComponents {
@@ -17,11 +17,11 @@ class AsyncMutex final {
1717
AsyncMutex();
1818
~AsyncMutex();
1919

20-
Event& lock();
20+
AsyncMutexCoroImpl& lock();
2121
void unlock();
2222

2323
private:
24-
Event event;
24+
AsyncMutexCoroImpl mutex_impl;
2525
};
2626

2727
} // namespace NComponents
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,37 @@
11
//
22
// Created by konstantin on 19.07.24.
33
//
4-
#include "event.h"
5-
64
#include <iostream>
75
#include <syncstream>
86

7+
#include "async_mutex_coro_impl.h"
8+
99
namespace NComponents {
1010

11-
bool Event::TryLock() {
11+
bool AsyncMutexCoroImpl::TryLock() {
1212
if (lock_flag) {
1313
std::osyncstream(std::cout)
14-
<< "[Event::TryLock][thread_id=" << std::this_thread::get_id()
14+
<< "[AsyncMutexCoroImpl::TryLock][thread_id=" << std::this_thread::get_id()
1515
<< "] lock_flag was locked. Need park." << std::endl;
1616
return false;
1717
}
1818

1919
std::osyncstream(std::cout)
20-
<< "[Event::TryLock][thread_id=" << std::this_thread::get_id()
20+
<< "[AsyncMutexCoroImpl::TryLock][thread_id=" << std::this_thread::get_id()
2121
<< "] lock lock_flag." << std::endl;
2222
lock_flag = true;
2323
return true;
2424
}
2525

26-
void Event::Unlock() {
26+
void AsyncMutexCoroImpl::Unlock() {
2727
std::unique_lock lock(spinlock);
2828

2929
std::osyncstream(std::cout)
30-
<< "[Event::Unlock][thread_id=" << std::this_thread::get_id()
30+
<< "[AsyncMutexCoroImpl::Unlock][thread_id=" << std::this_thread::get_id()
3131
<< "] call" << std::endl;
3232

3333
if (!waiters.empty()) {
34-
std::cout << "[Event::Unlock] Waiters size=" << waiters.size()
34+
std::cout << "[AsyncMutexCoroImpl::Unlock] Waiters size=" << waiters.size()
3535
<< ", wake up first" << std::endl;
3636
MutexAwaiter waiter = std::move(waiters.front());
3737
waiters.pop_front();
@@ -41,26 +41,25 @@ void Event::Unlock() {
4141
return;
4242
}
4343

44-
std::cout << "[Event::Unlock] Waiters empty. Set lock_flag to false"
44+
std::cout << "[AsyncMutexCoroImpl::Unlock] Waiters empty. Set lock_flag to false"
4545
<< std::endl;
4646
lock_flag = false;
4747
}
4848

49-
MutexAwaiter Event::operator co_await() {
50-
// std::unique_lock lock(spinlock);
51-
return MutexAwaiter{*this, spinlock};
52-
}
53-
54-
void Event::ParkAwaiter(MutexAwaiter* awaiter) {
49+
void AsyncMutexCoroImpl::ParkAwaiter(MutexAwaiter* awaiter) {
5550
assert(awaiter);
56-
// assert(awaiter->HasLock());
5751

5852
std::osyncstream(std::cout)
59-
<< "[Event::ParkAwaiter][thead_id=" << std::this_thread::get_id()
53+
<< "[AsyncMutexCoroImpl::ParkAwaiter][thead_id=" << std::this_thread::get_id()
6054
<< "] add waiter, new size=" << waiters.size() + 1 << std::endl;
6155

6256
waiters.emplace_back(std::move(*awaiter));
6357
waiters.back().ReleaseLock();
6458
}
6559

60+
MutexAwaiter AsyncMutexCoroImpl::operator co_await() {
61+
MutexAwaiter::LockGuard guard(spinlock);
62+
return MutexAwaiter{*this, std::move(guard)};
63+
}
64+
6665
} // namespace NComponents

src/components/async_mutex/event.h renamed to src/components/async_mutex/async_mutex_coro_impl.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
namespace NComponents {
1919

20-
struct Event final {
20+
class AsyncMutexCoroImpl final {
2121
mutable NSync::SpinLock spinlock;
2222
bool lock_flag{};
2323
std::list<MutexAwaiter> waiters;
2424

25-
Event() = default;
25+
public:
26+
27+
AsyncMutexCoroImpl() = default;
2628

2729
bool TryLock();
2830

@@ -33,9 +35,9 @@ struct Event final {
3335
return lock_flag;
3436
}
3537

36-
MutexAwaiter operator co_await();
37-
3838
void ParkAwaiter(MutexAwaiter* awaiter);
39+
40+
MutexAwaiter operator co_await();
3941
};
4042

4143
} // namespace NComponents

src/components/async_mutex/mutex_awaiter.cpp

+12-14
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,21 @@
99
#include <syncstream>
1010
#include <thread>
1111

12-
#include <components/async_mutex/event.h>
12+
#include <components/async_mutex/async_mutex_coro_impl.h>
1313

1414
namespace NComponents {
1515

16-
MutexAwaiter::MutexAwaiter(Event& event, NSync::SpinLock& guard)
17-
: event(event), guard(guard) {
18-
19-
guard.lock();
16+
MutexAwaiter::MutexAwaiter(AsyncMutexCoroImpl& event, LockGuard&& guard)
17+
: event(event), guard(std::move(guard)) {
2018
std::osyncstream(std::cout) << *this << " create with guard." << std::endl;
2119
}
2220

2321
MutexAwaiter::MutexAwaiter(MutexAwaiter&& o) noexcept
24-
: event(o.event), guard(o.guard), coro(o.coro) {
22+
: event(o.event), coro(o.coro) {
2523
o.coro = nullptr;
24+
if (auto* p = o.guard.release(); p) {
25+
guard = MutexAwaiter::LockGuard(*p, std::adopt_lock);
26+
}
2627
}
2728

2829
MutexAwaiter::~MutexAwaiter() {
@@ -32,12 +33,11 @@ MutexAwaiter::~MutexAwaiter() {
3233
void MutexAwaiter::ReleaseLock() const {
3334
std::osyncstream(std::cout)
3435
<< *this << "[await_suspend] release guard." << std::endl;
35-
guard.unlock();
36-
}
3736

38-
//bool MutexAwaiter::HasLock() const {
39-
// return guard.owns_lock();
40-
//}
37+
if (auto* p = guard.release(); p) {
38+
p->unlock();
39+
}
40+
}
4141

4242
bool MutexAwaiter::await_ready() const {
4343
const bool lock_own = event.TryLock();
@@ -59,9 +59,7 @@ void MutexAwaiter::await_suspend(std::coroutine_handle<> coro_) noexcept {
5959

6060
void MutexAwaiter::await_resume() const noexcept {
6161
std::osyncstream(std::cout)
62-
<< *this
63-
<< "[await_resume] call and just resume, status flag=" << event.IsSet()
64-
<< std::endl;
62+
<< *this << "[await_resume] call and just resume." << std::endl;
6563
}
6664

6765
std::ostream& operator<<(std::ostream& stream, const MutexAwaiter& /*w*/) {

src/components/async_mutex/mutex_awaiter.h

+9-8
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,28 @@
1212

1313
namespace NComponents {
1414

15-
struct Event;
16-
using LockGuard = std::unique_lock<NSync::SpinLock>;
15+
class AsyncMutexCoroImpl;
1716

1817
class MutexAwaiter final {
19-
Event& event;
20-
NSync::SpinLock& guard;
21-
std::coroutine_handle<> coro{};
22-
2318
public:
24-
MutexAwaiter(Event& event, NSync::SpinLock& guard);
19+
using LockGuard = std::unique_lock<NSync::SpinLock>;
20+
21+
MutexAwaiter(AsyncMutexCoroImpl& event, LockGuard&& guard);
2522
MutexAwaiter(MutexAwaiter&& o) noexcept;
2623

2724
~MutexAwaiter();
2825

2926
void Resume() const { coro.resume(); }
3027
void ReleaseLock() const;
31-
// bool HasLock() const;
3228

3329
bool await_ready() const;
3430
void await_suspend(std::coroutine_handle<>) noexcept;
3531
void await_resume() const noexcept;
32+
33+
private:
34+
AsyncMutexCoroImpl& event;
35+
mutable LockGuard guard;
36+
std::coroutine_handle<> coro{};
3637
};
3738

3839
std::ostream& operator<<(std::ostream& stream, const MutexAwaiter& w);

src/components/async_mutex/resumable_no_own.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
namespace NComponents {
1414

15+
class AsyncMutexCoroImpl;
16+
1517
struct ResumableNoOwn {
1618
struct promise_type {
1719
std::suspend_never initial_suspend() const noexcept { return {}; }
@@ -23,10 +25,10 @@ struct ResumableNoOwn {
2325
};
2426

2527
ResumableNoOwn(std::coroutine_handle<promise_type> /*handle*/) {
26-
// std::osyncstream(std::cout)
27-
// << "[ResumableNoOwn][this=" << this
28-
// << "][thread_id=" << std::this_thread::get_id() << "] create"
29-
// << std::endl;
28+
std::osyncstream(std::cout)
29+
<< "[ResumableNoOwn][this=" << this
30+
<< "][thread_id=" << std::this_thread::get_id() << "] create"
31+
<< std::endl;
3032
}
3133
};
3234

src/components/meson.build

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ components_sources = files(
99
'lock_free/hazard/hazard_manager.cpp',
1010
'lock_free/hazard/thread_state.cpp',
1111
'async_mutex/async_mutex.cpp',
12-
'async_mutex/event.cpp',
12+
'async_mutex/async_mutex_coro_impl.cpp',
1313
'async_mutex/mutex_awaiter.cpp',
1414
'async_mutex/resumable_no_own.cpp',
1515
)

test/components/test_async_mutex.cpp

+17-16
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include "gtest/gtest.h"
66

77
#include <condition_variable>
8-
#include <thread>
98
#include <latch>
9+
#include <thread>
1010

1111
#include <components/async_mutex/async_mutex.h>
1212

@@ -28,17 +28,21 @@ struct TestSyncIncrement final {
2828
size_t number{};
2929

3030
std::latch latch;
31+
const size_t count_iterations;
3132

32-
explicit TestSyncIncrement(size_t n) : latch(n) {}
33+
explicit TestSyncIncrement(size_t n, size_t count_iterations)
34+
: latch(n), count_iterations(count_iterations) {}
3335

3436
NComponents::ResumableNoOwn run() {
3537
{
3638
std::unique_lock lock(cv_wait);
3739
while (!wait_flag) cv.wait(lock);
3840
}
39-
co_await mutex.lock();
40-
number += 1;
41-
mutex.unlock();
41+
for (size_t i = 0; i < count_iterations; i++) {
42+
co_await mutex.lock();
43+
number += 1;
44+
mutex.unlock();
45+
}
4246

4347
latch.count_down();
4448
}
@@ -52,9 +56,7 @@ struct TestSyncIncrement final {
5256
mutex.unlock();
5357
}
5458

55-
void Wait() {
56-
latch.wait();
57-
}
59+
void Wait() { latch.wait(); }
5860

5961
private:
6062
std::mutex cv_wait;
@@ -73,13 +75,11 @@ struct TestWait final {
7375

7476
std::this_thread::sleep_for(5s);
7577

76-
if (status)
77-
throw std::logic_error("no wait");
78+
if (status) throw std::logic_error("no wait");
7879

7980
mutex.unlock();
8081

81-
if (!status)
82-
throw std::logic_error("status is false");
82+
if (!status) throw std::logic_error("status is false");
8383
}
8484

8585
private:
@@ -99,20 +99,21 @@ TEST(TestAsyncMutex, JustWorking) {
9999
}
100100

101101
TEST(TestAsyncMutex, SyncIncrementInThreads) {
102-
constexpr size_t MaxCount = 32;
102+
constexpr size_t MaxCountThreads = 32;
103+
constexpr size_t CountIterations = 100;
103104

104-
TestSyncIncrement worker(MaxCount);
105+
TestSyncIncrement worker(MaxCountThreads, CountIterations);
105106
{
106107
std::vector<std::jthread> workers;
107-
for (size_t i = 0; i < MaxCount; i++) {
108+
for (size_t i = 0; i < MaxCountThreads; i++) {
108109
workers.emplace_back(&TestSyncIncrement::run, &worker);
109110
}
110111

111112
worker.StartAll();
112113
worker.Wait();
113114
}
114115

115-
ASSERT_EQ(worker.number, MaxCount);
116+
ASSERT_EQ(worker.number, MaxCountThreads * CountIterations);
116117
}
117118

118119
TEST(TestAsyncMutex, ThreadWait) {

0 commit comments

Comments
 (0)