5757 AuthenticationWrongNumberOfArgsError ,
5858 ConnectionError ,
5959 DataError ,
60+ MaxConnectionsError ,
6061 RedisError ,
6162 ResponseError ,
6263 TimeoutError ,
@@ -295,7 +296,14 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
295296
296297 async def connect (self ):
297298 """Connects to the Redis server if not already connected"""
298- await self .connect_check_health (check_health = True )
299+ # try once the socket connect with the handshake, retry the whole
300+ # connect/handshake flow based on retry policy
301+ await self .retry .call_with_retry (
302+ lambda : self .connect_check_health (
303+ check_health = True , retry_socket_connect = False
304+ ),
305+ lambda error : self .disconnect (),
306+ )
299307
300308 async def connect_check_health (
301309 self , check_health : bool = True , retry_socket_connect : bool = True
@@ -805,9 +813,11 @@ def __init__(
805813 ssl_exclude_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
806814 ssl_ca_certs : Optional [str ] = None ,
807815 ssl_ca_data : Optional [str ] = None ,
816+ ssl_ca_path : Optional [str ] = None ,
808817 ssl_check_hostname : bool = True ,
809818 ssl_min_version : Optional [TLSVersion ] = None ,
810819 ssl_ciphers : Optional [str ] = None ,
820+ ssl_password : Optional [str ] = None ,
811821 ** kwargs ,
812822 ):
813823 if not SSL_AVAILABLE :
@@ -821,9 +831,11 @@ def __init__(
821831 exclude_verify_flags = ssl_exclude_verify_flags ,
822832 ca_certs = ssl_ca_certs ,
823833 ca_data = ssl_ca_data ,
834+ ca_path = ssl_ca_path ,
824835 check_hostname = ssl_check_hostname ,
825836 min_version = ssl_min_version ,
826837 ciphers = ssl_ciphers ,
838+ password = ssl_password ,
827839 )
828840 super ().__init__ (** kwargs )
829841
@@ -878,10 +890,12 @@ class RedisSSLContext:
878890 "exclude_verify_flags" ,
879891 "ca_certs" ,
880892 "ca_data" ,
893+ "ca_path" ,
881894 "context" ,
882895 "check_hostname" ,
883896 "min_version" ,
884897 "ciphers" ,
898+ "password" ,
885899 )
886900
887901 def __init__ (
@@ -893,9 +907,11 @@ def __init__(
893907 exclude_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
894908 ca_certs : Optional [str ] = None ,
895909 ca_data : Optional [str ] = None ,
910+ ca_path : Optional [str ] = None ,
896911 check_hostname : bool = False ,
897912 min_version : Optional [TLSVersion ] = None ,
898913 ciphers : Optional [str ] = None ,
914+ password : Optional [str ] = None ,
899915 ):
900916 if not SSL_AVAILABLE :
901917 raise RedisError ("Python wasn't built with SSL support" )
@@ -920,11 +936,13 @@ def __init__(
920936 self .exclude_verify_flags = exclude_verify_flags
921937 self .ca_certs = ca_certs
922938 self .ca_data = ca_data
939+ self .ca_path = ca_path
923940 self .check_hostname = (
924941 check_hostname if self .cert_reqs != ssl .CERT_NONE else False
925942 )
926943 self .min_version = min_version
927944 self .ciphers = ciphers
945+ self .password = password
928946 self .context : Optional [SSLContext ] = None
929947
930948 def get (self ) -> SSLContext :
@@ -938,10 +956,16 @@ def get(self) -> SSLContext:
938956 if self .exclude_verify_flags :
939957 for flag in self .exclude_verify_flags :
940958 context .verify_flags &= ~ flag
941- if self .certfile and self .keyfile :
942- context .load_cert_chain (certfile = self .certfile , keyfile = self .keyfile )
943- if self .ca_certs or self .ca_data :
944- context .load_verify_locations (cafile = self .ca_certs , cadata = self .ca_data )
959+ if self .certfile or self .keyfile :
960+ context .load_cert_chain (
961+ certfile = self .certfile ,
962+ keyfile = self .keyfile ,
963+ password = self .password ,
964+ )
965+ if self .ca_certs or self .ca_data or self .ca_path :
966+ context .load_verify_locations (
967+ cafile = self .ca_certs , capath = self .ca_path , cadata = self .ca_data
968+ )
945969 if self .min_version is not None :
946970 context .minimum_version = self .min_version
947971 if self .ciphers is not None :
@@ -1208,7 +1232,7 @@ def get_available_connection(self):
12081232 connection = self ._available_connections .pop ()
12091233 except IndexError :
12101234 if len (self ._in_use_connections ) >= self .max_connections :
1211- raise ConnectionError ("Too many connections" ) from None
1235+ raise MaxConnectionsError ("Too many connections" ) from None
12121236 connection = self .make_connection ()
12131237 self ._in_use_connections .add (connection )
12141238 return connection
0 commit comments