@@ -56,7 +56,7 @@ def connect(host="localhost", user=None, password="",
5656 client_flag = 0 , cursorclass = Cursor , init_command = None ,
5757 connect_timeout = None , read_default_group = None ,
5858 no_delay = None , autocommit = False , echo = False ,
59- local_infile = False , loop = None ):
59+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
6060 """See connections.Connection.__init__() for information about
6161 defaults."""
6262 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -68,7 +68,8 @@ def connect(host="localhost", user=None, password="",
6868 connect_timeout = connect_timeout ,
6969 read_default_group = read_default_group ,
7070 no_delay = no_delay , autocommit = autocommit , echo = echo ,
71- local_infile = local_infile , loop = loop )
71+ local_infile = local_infile , loop = loop , ssl = ssl ,
72+ auth_plugin = auth_plugin )
7273 return _ConnectionContextManager (coro )
7374
7475
@@ -93,7 +94,7 @@ def __init__(self, host="localhost", user=None, password="",
9394 client_flag = 0 , cursorclass = Cursor , init_command = None ,
9495 connect_timeout = None , read_default_group = None ,
9596 no_delay = None , autocommit = False , echo = False ,
96- local_infile = False , loop = None ):
97+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
9798 """
9899 Establish a connection to the MySQL database. Accepts several
99100 arguments:
@@ -164,6 +165,9 @@ def __init__(self, host="localhost", user=None, password="",
164165 self ._no_delay = no_delay
165166 self ._echo = echo
166167 self ._last_usage = self ._loop .time ()
168+ self ._client_auth_plugin = auth_plugin
169+ self ._server_auth_plugin = ""
170+ self ._auth_plugin_used = ""
167171
168172 self ._unix_socket = unix_socket
169173 if charset :
@@ -176,6 +180,10 @@ def __init__(self, host="localhost", user=None, password="",
176180 if use_unicode is not None :
177181 self .use_unicode = use_unicode
178182
183+ self ._ssl_context = ssl
184+ if ssl :
185+ client_flag |= CLIENT .SSL
186+
179187 self ._encoding = charset_by_name (self ._charset ).encoding
180188
181189 if local_infile :
@@ -209,8 +217,6 @@ def __init__(self, host="localhost", user=None, password="",
209217 # user
210218 self ._close_reason = None
211219
212- self ._auth_plugin_name = ""
213-
214220 @property
215221 def host (self ):
216222 """MySQL server IP address or name"""
@@ -663,6 +669,31 @@ def _request_authentication(self):
663669 if self .user is None :
664670 raise ValueError ("Did not specify a username" )
665671
672+ if self ._ssl_context :
673+ # capablities, max packet, charset
674+ data = struct .pack ('<IIB' , self .client_flag , 16777216 , 33 )
675+ data += b'\x00 ' * (32 - len (data ))
676+
677+ self .write_packet (data )
678+
679+ # Stop sending events to data_received
680+ self ._writer .transport .pause_reading ()
681+
682+ # Get the raw socket from the transport
683+ raw_sock = self ._writer .transport .get_extra_info ('socket' ,
684+ default = None )
685+ if raw_sock is None :
686+ raise RuntimeError ("Transport does not expose socket instance" )
687+
688+ # MySQL expects TLS negotiation to happen in the middle of a
689+ # TCP connection not at start. Passing in a socket to
690+ # open_connection will cause it to negotiate TLS on an existing
691+ # connection not initiate a new one.
692+ self ._reader , self ._writer = yield from asyncio .open_connection (
693+ sock = raw_sock , ssl = self ._ssl_context , loop = self ._loop ,
694+ server_hostname = self ._host
695+ )
696+
666697 charset_id = charset_by_name (self .charset ).id
667698 if isinstance (self .user , str ):
668699 _user = self .user .encode (self .encoding )
@@ -673,8 +704,16 @@ def _request_authentication(self):
673704 data = data_init + _user + b'\0 '
674705
675706 authresp = b''
676- if self ._auth_plugin_name in ('' , 'mysql_native_password' ):
707+
708+ auth_plugin = self ._client_auth_plugin
709+ if not self ._client_auth_plugin :
710+ # Contains the auth plugin from handshake
711+ auth_plugin = self ._server_auth_plugin
712+
713+ if auth_plugin in ('' , 'mysql_native_password' ):
677714 authresp = _scramble (self ._password .encode ('latin1' ), self .salt )
715+ elif auth_plugin in ('' , 'mysql_clear_password' ):
716+ authresp = self ._password .encode ('latin1' ) + b'\0 '
678717
679718 if self .server_capabilities & CLIENT .PLUGIN_AUTH_LENENC_CLIENT_DATA :
680719 data += lenenc_int (len (authresp )) + authresp
@@ -693,11 +732,13 @@ def _request_authentication(self):
693732 data += db + b'\0 '
694733
695734 if self .server_capabilities & CLIENT .PLUGIN_AUTH :
696- name = self . _auth_plugin_name
735+ name = auth_plugin
697736 if isinstance (name , str ):
698737 name = name .encode ('ascii' )
699738 data += name + b'\0 '
700739
740+ self ._auth_plugin_used = auth_plugin
741+
701742 self .write_packet (data )
702743 auth_packet = yield from self ._read_packet ()
703744
@@ -710,14 +751,45 @@ def _request_authentication(self):
710751 plugin_name = auth_packet .read_string ()
711752 if (self .server_capabilities & CLIENT .PLUGIN_AUTH and
712753 plugin_name is not None ):
713- auth_packet = self ._process_auth (plugin_name , auth_packet )
754+ auth_packet = yield from self ._process_auth (
755+ plugin_name , auth_packet )
714756 else :
715757 # send legacy handshake
716758 data = _scramble_323 (self ._password .encode ('latin1' ),
717759 self .salt ) + b'\0 '
718760 self .write_packet (data )
719761 auth_packet = yield from self ._read_packet ()
720762
763+ @asyncio .coroutine
764+ def _process_auth (self , plugin_name , auth_packet ):
765+ if plugin_name == b"mysql_native_password" :
766+ # https://dev.mysql.com/doc/internals/en/
767+ # secure-password-authentication.html#packet-Authentication::
768+ # Native41
769+ data = _scramble (self ._password .encode ('latin1' ),
770+ auth_packet .read_all ())
771+ elif plugin_name == b"mysql_old_password" :
772+ # https://dev.mysql.com/doc/internals/en/
773+ # old-password-authentication.html
774+ data = _scramble_323 (self ._password .encode ('latin1' ),
775+ auth_packet .read_all ()) + b'\0 '
776+ elif plugin_name == b"mysql_clear_password" :
777+ # https://dev.mysql.com/doc/internals/en/
778+ # clear-text-authentication.html
779+ data = self ._password .encode ('latin1' ) + b'\0 '
780+ else :
781+ raise OperationalError (
782+ 2059 , "Authentication plugin '%s' not configured" % plugin_name
783+ )
784+
785+ self .write_packet (data )
786+ pkt = yield from self ._read_packet ()
787+ pkt .check_error ()
788+
789+ self ._auth_plugin_used = plugin_name
790+
791+ return pkt
792+
721793 # _mysql support
722794 def thread_id (self ):
723795 return self .server_thread_id [0 ]
@@ -786,9 +858,9 @@ def _get_server_information(self):
786858 server_end = data .find (b'\0 ' , i )
787859 if server_end < 0 : # pragma: no cover - very specific upstream bug
788860 # not found \0 and last field so take it all
789- self ._auth_plugin_name = data [i :].decode ('latin1' )
861+ self ._server_auth_plugin = data [i :].decode ('latin1' )
790862 else :
791- self ._auth_plugin_name = data [i :server_end ].decode ('latin1' )
863+ self ._server_auth_plugin = data [i :server_end ].decode ('latin1' )
792864
793865 def get_transaction_status (self ):
794866 return bool (self .server_status & SERVER_STATUS .SERVER_STATUS_IN_TRANS )
0 commit comments