4141# from aiomysql.utils import _convert_to_str
4242from .cursors import Cursor
4343from .utils import _ConnectionContextManager , _ContextManager
44- # from .log import logger
44+ from .log import logger
4545
4646
4747DEFAULT_USER = getpass .getuser ()
@@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
5555 connect_timeout = None , read_default_group = None ,
5656 no_delay = None , autocommit = False , echo = False ,
5757 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
58- program_name = '' ):
58+ program_name = '' , server_public_key = None ):
5959 """See connections.Connection.__init__() for information about
6060 defaults."""
6161 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
9393 connect_timeout = None , read_default_group = None ,
9494 no_delay = None , autocommit = False , echo = False ,
9595 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
96- program_name = '' ):
96+ program_name = '' , server_public_key = None ):
9797 """
9898 Establish a connection to the MySQL database. Accepts several
9999 arguments:
@@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
134134 (default: Server Default)
135135 :param program_name: Program name string to provide when
136136 handshaking with MySQL. (default: sys.argv[0])
137+ :param server_public_key: SHA256 authentication plugin public
138+ key value.
137139 :param loop: asyncio loop
138140 """
139141 self ._loop = loop or asyncio .get_event_loop ()
@@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
174176 self ._client_auth_plugin = auth_plugin
175177 self ._server_auth_plugin = ""
176178 self ._auth_plugin_used = ""
179+ self .server_public_key = server_public_key
180+ self .salt = None
177181
178182 # TODO somehow import version from __init__.py
179183 self ._connect_attrs = {
@@ -712,6 +716,20 @@ async def _request_authentication(self):
712716 if auth_plugin in ('' , 'mysql_native_password' ):
713717 authresp = _auth .scramble_native_password (
714718 self ._password .encode ('latin1' ), self .salt )
719+ elif auth_plugin == 'caching_sha2_password' :
720+ if self ._password :
721+ authresp = _auth .scramble_caching_sha2 (
722+ self ._password .encode ('latin1' ), self .salt
723+ )
724+ # Else: empty password
725+ elif auth_plugin == 'sha256_password' :
726+ if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
727+ authresp = self ._password .encode ('latin1' ) + b'\0 '
728+ elif self ._password :
729+ authresp = b'\1 ' # request public key
730+ else :
731+ authresp = b'\0 ' # empty password
732+
715733 elif auth_plugin in ('' , 'mysql_clear_password' ):
716734 authresp = self ._password .encode ('latin1' ) + b'\0 '
717735
@@ -768,35 +786,174 @@ async def _request_authentication(self):
768786 auth_packet .read_all ()) + b'\0 '
769787 self .write_packet (data )
770788 await self ._read_packet ()
789+ elif auth_packet .is_extra_auth_data ():
790+ if auth_plugin == "caching_sha2_password" :
791+ await self .caching_sha2_password_auth (auth_packet )
792+ elif auth_plugin == "sha256_password" :
793+ await self .sha256_password_auth (auth_packet )
794+ else :
795+ raise OperationalError ("Received extra packet "
796+ "for auth method %r" , auth_plugin )
771797
772798 async def _process_auth (self , plugin_name , auth_packet ):
773- if plugin_name == b"mysql_native_password" :
774- # https://dev.mysql.com/doc/internals/en/
775- # secure-password-authentication.html#packet-Authentication::
776- # Native41
777- data = _auth .scramble_native_password (
778- self ._password .encode ('latin1' ),
779- auth_packet .read_all ())
780- elif plugin_name == b"mysql_old_password" :
781- # https://dev.mysql.com/doc/internals/en/
782- # old-password-authentication.html
783- data = _auth .scramble_old_password (self ._password .encode ('latin1' ),
784- auth_packet .read_all ()) + b'\0 '
785- elif plugin_name == b"mysql_clear_password" :
786- # https://dev.mysql.com/doc/internals/en/
787- # clear-text-authentication.html
788- data = self ._password .encode ('latin1' ) + b'\0 '
799+ # These auth plugins do their own packet handling
800+ if plugin_name == b"caching_sha2_password" :
801+ await self .caching_sha2_password_auth (auth_packet )
802+ self ._auth_plugin_used = plugin_name .decode ()
803+ elif plugin_name == b"sha256_password" :
804+ await self .sha256_password_auth (auth_packet )
805+ self ._auth_plugin_used = plugin_name .decode ()
789806 else :
807+
808+ if plugin_name == b"mysql_native_password" :
809+ # https://dev.mysql.com/doc/internals/en/
810+ # secure-password-authentication.html#packet-Authentication::
811+ # Native41
812+ data = _auth .scramble_native_password (
813+ self ._password .encode ('latin1' ),
814+ auth_packet .read_all ())
815+ elif plugin_name == b"mysql_old_password" :
816+ # https://dev.mysql.com/doc/internals/en/
817+ # old-password-authentication.html
818+ data = _auth .scramble_old_password (
819+ self ._password .encode ('latin1' ),
820+ auth_packet .read_all ()
821+ ) + b'\0 '
822+ elif plugin_name == b"mysql_clear_password" :
823+ # https://dev.mysql.com/doc/internals/en/
824+ # clear-text-authentication.html
825+ data = self ._password .encode ('latin1' ) + b'\0 '
826+ else :
827+ raise OperationalError (
828+ 2059 , "Authentication plugin '{0}'"
829+ " not configured" .format (plugin_name )
830+ )
831+
832+ self .write_packet (data )
833+ pkt = await self ._read_packet ()
834+ pkt .check_error ()
835+
836+ self ._auth_plugin_used = plugin_name .decode ()
837+
838+ return pkt
839+
840+ async def caching_sha2_password_auth (self , pkt ):
841+ # No password fast path
842+ if not self ._password :
843+ self .write_packet (b'' )
844+ pkt = await self ._read_packet ()
845+ pkt .check_error ()
846+ return pkt
847+
848+ if pkt .is_auth_switch_request ():
849+ # Try from fast auth
850+ logger .debug ("caching sha2: Trying fast path" )
851+ self .salt = pkt .read_all ()
852+ scrambled = _auth .scramble_caching_sha2 (
853+ self ._password .encode ('latin1' ), self .salt
854+ )
855+
856+ self .write_packet (scrambled )
857+ pkt = await self ._read_packet ()
858+ pkt .check_error ()
859+
860+ # else: fast auth is tried in initial handshake
861+
862+ if not pkt .is_extra_auth_data ():
790863 raise OperationalError (
791- 2059 , "Authentication plugin '%s' not configured" % plugin_name
864+ "caching sha2: Unknown packet "
865+ "for fast auth: {0}" .format (pkt ._data [:1 ])
792866 )
793867
868+ # magic numbers:
869+ # 2 - request public key
870+ # 3 - fast auth succeeded
871+ # 4 - need full auth
872+
873+ pkt .advance (1 )
874+ n = pkt .read_uint8 ()
875+
876+ if n == 3 :
877+ logger .debug ("caching sha2: succeeded by fast path." )
878+ pkt = await self ._read_packet ()
879+ pkt .check_error () # pkt must be OK packet
880+ return pkt
881+
882+ if n != 4 :
883+ raise OperationalError ("caching sha2: Unknown "
884+ "result for fast auth: {0}" .format (n ))
885+
886+ logger .debug ("caching sha2: Trying full auth..." )
887+
888+ if self ._ssl_context :
889+ logger .debug ("caching sha2: Sending plain "
890+ "password via secure connection" )
891+ self .write_packet (self ._password .encode ('latin1' ) + b'\0 ' )
892+ pkt = await self ._read_packet ()
893+ pkt .check_error ()
894+ return pkt
895+
896+ if not self .server_public_key :
897+ self .write_packet (b'\x02 ' )
898+ pkt = await self ._read_packet () # Request public key
899+ pkt .check_error ()
900+
901+ if not pkt .is_extra_auth_data ():
902+ raise OperationalError (
903+ "caching sha2: Unknown packet "
904+ "for public key: {0}" .format (pkt ._data [:1 ])
905+ )
906+
907+ self .server_public_key = pkt ._data [1 :]
908+ logger .debug (self .server_public_key .decode ('ascii' ))
909+
910+ data = _auth .sha2_rsa_encrypt (
911+ self ._password .encode ('latin1' ), self .salt ,
912+ self .server_public_key
913+ )
794914 self .write_packet (data )
795915 pkt = await self ._read_packet ()
796916 pkt .check_error ()
797917
798- self ._auth_plugin_used = plugin_name
918+ async def sha256_password_auth (self , pkt ):
919+ if self ._ssl_context :
920+ logger .debug ("sha256: Sending plain password" )
921+ data = self ._password .encode ('latin1' ) + b'\0 '
922+ self .write_packet (data )
923+ pkt = await self ._read_packet ()
924+ pkt .check_error ()
925+ return pkt
926+
927+ if pkt .is_auth_switch_request ():
928+ self .salt = pkt .read_all ()
929+ if not self .server_public_key and self ._password :
930+ # Request server public key
931+ logger .debug ("sha256: Requesting server public key" )
932+ self .write_packet (b'\1 ' )
933+ pkt = await self ._read_packet ()
934+ pkt .check_error ()
935+
936+ if pkt .is_extra_auth_data ():
937+ self .server_public_key = pkt ._data [1 :]
938+ logger .debug (
939+ "Received public key:\n " ,
940+ self .server_public_key .decode ('ascii' )
941+ )
942+
943+ if self ._password :
944+ if not self .server_public_key :
945+ raise OperationalError ("Couldn't receive server's public key" )
946+
947+ data = _auth .sha2_rsa_encrypt (
948+ self ._password .encode ('latin1' ), self .salt ,
949+ self .server_public_key
950+ )
951+ else :
952+ data = b''
799953
954+ self .write_packet (data )
955+ pkt = await self ._read_packet ()
956+ pkt .check_error ()
800957 return pkt
801958
802959 # _mysql support
0 commit comments