Skip to content

Commit 1585bab

Browse files
committed
Wait for drain before writing
1 parent 92caf68 commit 1585bab

8 files changed

+33
-11
lines changed

Diff for: src/galaxy/api/jsonrpc.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(self, reader, writer, encoder=json.JSONEncoder()):
8888
self._methods = {}
8989
self._notifications = {}
9090
self._task_manager = TaskManager("jsonrpc server")
91+
self._write_lock = asyncio.Lock()
9192

9293
def register_method(self, name, callback, immediate, sensitive_params=False):
9394
"""
@@ -223,12 +224,16 @@ def _parse_request(data):
223224
raise InvalidRequest()
224225

225226
def _send(self, data):
227+
async def send_task(data_):
228+
async with self._write_lock:
229+
self._writer.write(data_)
230+
await self._writer.drain()
231+
226232
try:
227233
line = self._encoder.encode(data)
228234
logging.debug("Sending data: %s", line)
229235
data = (line + "\n").encode("utf-8")
230-
self._writer.write(data)
231-
self._task_manager.create_task(self._writer.drain(), "drain")
236+
self._task_manager.create_task(send_task(data), "send")
232237
except TypeError as error:
233238
logging.error(str(error))
234239

@@ -263,6 +268,7 @@ def __init__(self, writer, encoder=json.JSONEncoder()):
263268
self._encoder = encoder
264269
self._methods = {}
265270
self._task_manager = TaskManager("notification client")
271+
self._write_lock = asyncio.Lock()
266272

267273
def notify(self, method, params, sensitive_params=False):
268274
"""
@@ -286,12 +292,16 @@ async def close(self):
286292
await self._task_manager.wait()
287293

288294
def _send(self, data):
295+
async def send_task(data_):
296+
async with self._write_lock:
297+
self._writer.write(data_)
298+
await self._writer.drain()
299+
289300
try:
290301
line = self._encoder.encode(data)
291302
data = (line + "\n").encode("utf-8")
292303
logging.debug("Sending %d byte of data", len(data))
293-
self._writer.write(data)
294-
self._task_manager.create_task(self._writer.drain(), "drain")
304+
self._task_manager.create_task(send_task(data), "send")
295305
except TypeError as error:
296306
logging.error("Failed to parse outgoing message: %s", str(error))
297307

Diff for: tests/test_achievements.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from galaxy.api.types import Achievement
77
from galaxy.api.errors import BackendError
8-
from galaxy.unittest.mock import async_return_value
8+
from galaxy.unittest.mock import async_return_value, skip_loop
99

1010
from tests import create_message, get_messages
1111

@@ -201,6 +201,7 @@ async def test_import_in_progress(plugin, read, write):
201201
async def test_unlock_achievement(plugin, write):
202202
achievement = Achievement(achievement_id="lvl20", unlock_time=1548422395)
203203
plugin.unlock_achievement("14", achievement)
204+
await skip_loop()
204205
response = json.loads(write.call_args[0][0])
205206

206207
assert response == {

Diff for: tests/test_authenticate.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
UnknownError, InvalidCredentials, NetworkError, LoggedInElsewhere, ProtocolError,
66
BackendNotAvailable, BackendTimeout, BackendError, TemporaryBlocked, Banned, AccessDenied
77
)
8-
from galaxy.unittest.mock import async_return_value
8+
from galaxy.unittest.mock import async_return_value, skip_loop
99

1010
from tests import create_message, get_messages
1111

@@ -97,6 +97,7 @@ async def test_store_credentials(plugin, write):
9797
"token": "ABC"
9898
}
9999
plugin.store_credentials(credentials)
100+
await skip_loop()
100101

101102
assert get_messages(write) == [
102103
{
@@ -110,6 +111,7 @@ async def test_store_credentials(plugin, write):
110111
@pytest.mark.asyncio
111112
async def test_lost_authentication(plugin, write):
112113
plugin.lost_authentication()
114+
await skip_loop()
113115

114116
assert get_messages(write) == [
115117
{

Diff for: tests/test_friends.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from galaxy.api.types import FriendInfo
22
from galaxy.api.errors import UnknownError
3-
from galaxy.unittest.mock import async_return_value
3+
from galaxy.unittest.mock import async_return_value, skip_loop
44

55
import pytest
66

@@ -67,6 +67,7 @@ async def test_add_friend(plugin, write):
6767
friend = FriendInfo("7", "Kuba")
6868

6969
plugin.add_friend(friend)
70+
await skip_loop()
7071

7172
assert get_messages(write) == [
7273
{
@@ -82,6 +83,7 @@ async def test_add_friend(plugin, write):
8283
@pytest.mark.asyncio
8384
async def test_remove_friend(plugin, write):
8485
plugin.remove_friend("5")
86+
await skip_loop()
8587

8688
assert get_messages(write) == [
8789
{

Diff for: tests/test_game_times.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from galaxy.api.types import GameTime
55
from galaxy.api.errors import BackendError
6-
from galaxy.unittest.mock import async_return_value
6+
from galaxy.unittest.mock import async_return_value, skip_loop
77

88
from tests import create_message, get_messages
99

@@ -199,6 +199,7 @@ async def test_import_in_progress(plugin, read, write):
199199
async def test_update_game(plugin, write):
200200
game_time = GameTime("3", 60, 1549550504)
201201
plugin.update_game_time(game_time)
202+
await skip_loop()
202203

203204
assert get_messages(write) == [
204205
{

Diff for: tests/test_local_games.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from galaxy.api.types import LocalGame
44
from galaxy.api.consts import LocalGameState
55
from galaxy.api.errors import UnknownError, FailedParsingManifest
6-
from galaxy.unittest.mock import async_return_value
6+
from galaxy.unittest.mock import async_return_value, skip_loop
77

88
from tests import create_message, get_messages
99

@@ -83,6 +83,7 @@ async def test_failure(plugin, read, write, error, code, message):
8383
async def test_local_game_state_update(plugin, write):
8484
game = LocalGame("1", LocalGameState.Running)
8585
plugin.update_local_game_status(game)
86+
await skip_loop()
8687

8788
assert get_messages(write) == [
8889
{

Diff for: tests/test_owned_games.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from galaxy.api.types import Game, Dlc, LicenseInfo
44
from galaxy.api.consts import LicenseType
55
from galaxy.api.errors import UnknownError
6-
from galaxy.unittest.mock import async_return_value
6+
from galaxy.unittest.mock import async_return_value, skip_loop
77

88
from tests import create_message, get_messages
99

@@ -100,6 +100,7 @@ async def test_failure(plugin, read, write):
100100
async def test_add_game(plugin, write):
101101
game = Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None))
102102
plugin.add_game(game)
103+
await skip_loop()
103104
assert get_messages(write) == [
104105
{
105106
"jsonrpc": "2.0",
@@ -120,6 +121,7 @@ async def test_add_game(plugin, write):
120121
@pytest.mark.asyncio
121122
async def test_remove_game(plugin, write):
122123
plugin.remove_game("5")
124+
await skip_loop()
123125
assert get_messages(write) == [
124126
{
125127
"jsonrpc": "2.0",
@@ -135,6 +137,7 @@ async def test_remove_game(plugin, write):
135137
async def test_update_game(plugin, write):
136138
game = Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None))
137139
plugin.update_game(game)
140+
await skip_loop()
138141
assert get_messages(write) == [
139142
{
140143
"jsonrpc": "2.0",

Diff for: tests/test_persistent_cache.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from galaxy.unittest.mock import async_return_value
3+
from galaxy.unittest.mock import async_return_value, skip_loop
44

55
from tests import create_message, get_messages
66

@@ -57,6 +57,7 @@ async def test_set_cache(plugin, write, cache_data):
5757

5858
plugin.persistent_cache.update(cache_data)
5959
plugin.push_cache()
60+
await skip_loop()
6061

6162
assert_rpc_request(write, "push_cache", cache_data)
6263
assert cache_data == plugin.persistent_cache
@@ -68,6 +69,7 @@ async def test_clear_cache(plugin, write, cache_data):
6869

6970
plugin.persistent_cache.clear()
7071
plugin.push_cache()
72+
await skip_loop()
7173

7274
assert_rpc_request(write, "push_cache", {})
7375
assert {} == plugin.persistent_cache

0 commit comments

Comments
 (0)