Skip to content

Commit 35bc636

Browse files
authored
Merge pull request #128 from attid/master
✨ feat(aio): implement non-blocking socket stream and async connectio…
2 parents 53df8ce + a7dc480 commit 35bc636

File tree

2 files changed

+88
-15
lines changed

2 files changed

+88
-15
lines changed

firebirdsql/aio/fbcore.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,21 @@ def __init__(self, trans):
6666
DEBUG_OUTPUT("AsyncStatement::__init__()")
6767
self.trans = trans
6868

69+
self._is_open = False
70+
self.stmt_type = None
71+
self.handle = -1
72+
73+
@classmethod
74+
async def create(cls, trans):
75+
self = cls(trans)
6976
self.trans.connection._op_allocate_statement()
7077
if (self.trans.connection.accept_type & ptype_MASK) == ptype_lazy_send:
7178
self.trans.connection.lazy_response_count += 1
7279
self.handle = -1
7380
else:
74-
(h, oid, buf) = self.trans.connection._op_response()
81+
(h, oid, buf) = await self.trans.connection._async_op_response()
7582
self.handle = h
76-
77-
self._is_open = False
78-
self.stmt_type = None
83+
return self
7984

8085
async def fetch_generator(self, rows, more_data):
8186
DEBUG_OUTPUT("AsyncStatement::_fetch_generator()", self.handle, self.trans._trans_handle, self.trans.connection.db_handle)
@@ -172,7 +177,7 @@ class AsyncPreparedStatement(PreparedStatement):
172177
async def __init__(self, cur, sql, explain_plan=False):
173178
DEBUG_OUTPUT("AsyncPreparedStatement::__init__()")
174179
await cur.transaction.check_trans_handle()
175-
self.stmt = await AsyncStatement(cur.transaction)
180+
self.stmt = await AsyncStatement.create(cur.transaction)
176181
await self.stmt.prepare(sql, explain_plan)
177182
self.sql = sql
178183

@@ -217,7 +222,7 @@ async def _get_stmt(self, query):
217222
await self.stmt.drop()
218223
self.stmt = None
219224
if self.stmt is None:
220-
self.stmt = AsyncStatement(self.transaction)
225+
self.stmt = await AsyncStatement.create(self.transaction)
221226
stmt = self.stmt
222227
await stmt.prepare(query)
223228
return stmt
@@ -508,6 +513,17 @@ async def check_trans_handle(self):
508513
if self._trans_handle is None:
509514
await self._begin()
510515

516+
async def close(self):
517+
if self._trans_handle is None:
518+
return
519+
if not self.is_dirty:
520+
return
521+
DEBUG_OUTPUT("AsyncTransaction::close()", self._trans_handle, self.connection.db_handle)
522+
self.connection._op_rollback(self._trans_handle)
523+
(h, oid, buf) = await self.connection._async_op_response()
524+
self._trans_handle = None
525+
self.is_dirty = False
526+
511527

512528
class AsyncConnectionResponseMixin(ConnectionResponseMixin):
513529
async def _async_recv_channel(self, nbytes, word_alignment=False):
@@ -516,9 +532,10 @@ async def _async_recv_channel(self, nbytes, word_alignment=False):
516532
n += 4 - nbytes % 4 # 4 bytes word alignment
517533
r = bytes([])
518534
while n:
519-
if (self.timeout is not None and select.select([self.sock._sock], [], [], self.timeout)[0] == []):
520-
break
521-
b = await self.sock.async_recv(n)
535+
if self.timeout is not None:
536+
b = await asyncio.wait_for(self.sock.async_recv(n), timeout=self.timeout)
537+
else:
538+
b = await self.sock.async_recv(n)
522539
if not b:
523540
break
524541
r += b
@@ -713,7 +730,7 @@ async def _async_parse_connect_response(self):
713730
raise OperationalError(
714731
'Unknown wirecrypt plugin %s' % (enc_plugin.encode("utf-8"))
715732
)
716-
(h, oid, buf) = self._op_response()
733+
(h, oid, buf) = await self._async_op_response()
717734
else:
718735
# no matched wire encription plugin
719736
# self.auth_data use _op_attach() and _op_create()
@@ -983,6 +1000,25 @@ async def drop_database(self):
9831000
self.sock = None
9841001
self.db_handle = None
9851002

1003+
async def close(self):
1004+
DEBUG_OUTPUT("AsyncConnection::close()", id(self), self.db_handle)
1005+
if self.sock is None:
1006+
return
1007+
if self.db_handle is not None:
1008+
# cleanup transaction
1009+
for trans in list(self._cursors.keys()):
1010+
await trans.close()
1011+
if self.is_services:
1012+
self._op_service_detach()
1013+
else:
1014+
self._op_detach()
1015+
(h, oid, buf) = await self._async_op_response()
1016+
self.sock.close()
1017+
self.sock = None
1018+
self.db_handle = None
1019+
9861020
def __del__(self):
9871021
if self.sock:
988-
self.close()
1022+
# Async close cannot be called from __del__
1023+
# self.close()
1024+
pass

firebirdsql/aio/stream.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,53 @@
2525
#
2626
# Python DB-API 2.0 module for Firebird.
2727
##############################################################################
28+
import asyncio
29+
2830
from firebirdsql.stream import SocketStream
31+
from firebirdsql.utils import bytes_to_bint
2932

3033

3134
class AsyncSocketStream(SocketStream):
3235
def __init__(self, host, port, loop, timeout, cloexec):
3336
super().__init__(host, port, timeout, cloexec)
3437
self.loop = loop
38+
self._send_lock = asyncio.Lock()
39+
self._last_send_task = None
40+
self._sock.setblocking(False)
41+
self._buf = b''
42+
43+
async def _await_pending_send(self):
44+
task = self._last_send_task
45+
if task is not None:
46+
await task
3547

3648
async def async_recv(self, nbytes):
37-
b = await self.loop.sock_recv(self._sock, nbytes)
38-
if self.read_translator:
39-
b = self.read_translator.decrypt(b)
40-
return b
49+
await self._await_pending_send()
50+
51+
if len(self._buf) < nbytes:
52+
read_size = max(8192, nbytes - len(self._buf))
53+
chunk = await self.loop.sock_recv(self._sock, read_size)
54+
if self.read_translator:
55+
chunk = self.read_translator.decrypt(chunk)
56+
self._buf += chunk
57+
58+
ret = self._buf[:nbytes]
59+
self._buf = self._buf[nbytes:]
60+
return ret
61+
62+
def send(self, b):
63+
if not self.loop.is_running():
64+
return super().send(b)
65+
if self.write_translator:
66+
b = self.write_translator.encrypt(b)
67+
68+
previous_task = self._last_send_task
69+
70+
async def _send_all(payload):
71+
if previous_task is not None:
72+
await previous_task
73+
async with self._send_lock:
74+
await self.loop.sock_sendall(self._sock, payload)
75+
76+
self._last_send_task = self.loop.create_task(_send_all(b))
77+
return None

0 commit comments

Comments
 (0)