@@ -66,16 +66,21 @@ def __init__(self, trans):
6666 DEBUG_OUTPUT ("AsyncStatement::__init__()" )
6767 self .trans = trans
6868
69+ self ._is_open = False
70+ self .stmt_type = None
71+ self .handle = - 1
72+
73+ @classmethod
74+ async def create (cls , trans ):
75+ self = cls (trans )
6976 self .trans .connection ._op_allocate_statement ()
7077 if (self .trans .connection .accept_type & ptype_MASK ) == ptype_lazy_send :
7178 self .trans .connection .lazy_response_count += 1
7279 self .handle = - 1
7380 else :
74- (h , oid , buf ) = self .trans .connection ._op_response ()
81+ (h , oid , buf ) = await self .trans .connection ._async_op_response ()
7582 self .handle = h
76-
77- self ._is_open = False
78- self .stmt_type = None
83+ return self
7984
8085 async def fetch_generator (self , rows , more_data ):
8186 DEBUG_OUTPUT ("AsyncStatement::_fetch_generator()" , self .handle , self .trans ._trans_handle , self .trans .connection .db_handle )
@@ -172,7 +177,7 @@ class AsyncPreparedStatement(PreparedStatement):
172177 async def __init__ (self , cur , sql , explain_plan = False ):
173178 DEBUG_OUTPUT ("AsyncPreparedStatement::__init__()" )
174179 await cur .transaction .check_trans_handle ()
175- self .stmt = await AsyncStatement (cur .transaction )
180+ self .stmt = await AsyncStatement . create (cur .transaction )
176181 await self .stmt .prepare (sql , explain_plan )
177182 self .sql = sql
178183
@@ -217,7 +222,7 @@ async def _get_stmt(self, query):
217222 await self .stmt .drop ()
218223 self .stmt = None
219224 if self .stmt is None :
220- self .stmt = AsyncStatement (self .transaction )
225+ self .stmt = await AsyncStatement . create (self .transaction )
221226 stmt = self .stmt
222227 await stmt .prepare (query )
223228 return stmt
@@ -508,6 +513,17 @@ async def check_trans_handle(self):
508513 if self ._trans_handle is None :
509514 await self ._begin ()
510515
516+ async def close (self ):
517+ if self ._trans_handle is None :
518+ return
519+ if not self .is_dirty :
520+ return
521+ DEBUG_OUTPUT ("AsyncTransaction::close()" , self ._trans_handle , self .connection .db_handle )
522+ self .connection ._op_rollback (self ._trans_handle )
523+ (h , oid , buf ) = await self .connection ._async_op_response ()
524+ self ._trans_handle = None
525+ self .is_dirty = False
526+
511527
512528class AsyncConnectionResponseMixin (ConnectionResponseMixin ):
513529 async def _async_recv_channel (self , nbytes , word_alignment = False ):
@@ -516,9 +532,10 @@ async def _async_recv_channel(self, nbytes, word_alignment=False):
516532 n += 4 - nbytes % 4 # 4 bytes word alignment
517533 r = bytes ([])
518534 while n :
519- if (self .timeout is not None and select .select ([self .sock ._sock ], [], [], self .timeout )[0 ] == []):
520- break
521- b = await self .sock .async_recv (n )
535+ if self .timeout is not None :
536+ b = await asyncio .wait_for (self .sock .async_recv (n ), timeout = self .timeout )
537+ else :
538+ b = await self .sock .async_recv (n )
522539 if not b :
523540 break
524541 r += b
@@ -713,7 +730,7 @@ async def _async_parse_connect_response(self):
713730 raise OperationalError (
714731 'Unknown wirecrypt plugin %s' % (enc_plugin .encode ("utf-8" ))
715732 )
716- (h , oid , buf ) = self ._op_response ()
733+ (h , oid , buf ) = await self ._async_op_response ()
717734 else :
718735 # no matched wire encription plugin
719736 # self.auth_data use _op_attach() and _op_create()
@@ -983,6 +1000,25 @@ async def drop_database(self):
9831000 self .sock = None
9841001 self .db_handle = None
9851002
1003+ async def close (self ):
1004+ DEBUG_OUTPUT ("AsyncConnection::close()" , id (self ), self .db_handle )
1005+ if self .sock is None :
1006+ return
1007+ if self .db_handle is not None :
1008+ # cleanup transaction
1009+ for trans in list (self ._cursors .keys ()):
1010+ await trans .close ()
1011+ if self .is_services :
1012+ self ._op_service_detach ()
1013+ else :
1014+ self ._op_detach ()
1015+ (h , oid , buf ) = await self ._async_op_response ()
1016+ self .sock .close ()
1017+ self .sock = None
1018+ self .db_handle = None
1019+
9861020 def __del__ (self ):
9871021 if self .sock :
988- self .close ()
1022+ # Async close cannot be called from __del__
1023+ # self.close()
1024+ pass
0 commit comments