3535from dataclasses import dataclass
3636from logHandler import log
3737from queue import Queue
38- from typing import Any , Literal , Optional , Self
38+ from typing import Any , Literal , Optional , Self , cast
3939
4040import wx
4141from extensionPoints import Action , HandlerRegistrar
@@ -263,13 +263,16 @@ def __init__(
263263 address : tuple [str , int ],
264264 timeout : int = 0 ,
265265 insecure : bool = False ,
266+ * ,
267+ trustedFingerprint : str | None = None ,
266268 ) -> None :
267269 """Initialize the TCP transport.
268270
269271 :param serializer: Message serializer instance
270272 :param address: Remote address to connect to, as (host, port) tuple
271273 :param timeout: Connection timeout in seconds, defaults to 0
272274 :param insecure: Skip certificate verification, defaults to False
275+ :param trustedFingerprint: Certificate fingerprint to trust this session if connecting insecurely.
273276 """
274277 super ().__init__ (serializer = serializer )
275278 self .closed : bool = False
@@ -304,6 +307,8 @@ def __init__(
304307 self .insecure : bool = insecure
305308 """Whether to skip certificate verification"""
306309
310+ self ._trustedFingerprint = trustedFingerprint
311+
307312 def run (self ) -> None :
308313 """
309314 Establishes a connection to the server and manages the transport lifecycle.
@@ -319,10 +324,9 @@ def run(self) -> None:
319324 thread, and enters the read loop. Upon disconnection, it clears the connected
320325 event, notifies about the transport disconnection, and performs cleanup.
321326
322- Raises:
323- ssl.SSLCertVerificationError: If SSL certificate verification fails and
324- the fingerprint is not trusted.
325- Exception: For any other exceptions during the connection process.
327+ :raises ssl.SSLCertVerificationError: If SSL certificate verification fails and
328+ the fingerprint is not trusted.
329+ :raises Exception: For any other exceptions during the connection process.
326330 """
327331 self .closed = False
328332 try :
@@ -338,6 +342,7 @@ def run(self) -> None:
338342 except Exception :
339343 pass
340344 if self .isFingerprintTrusted (fingerprint ):
345+ self ._trustedFingerprint = fingerprint
341346 self .insecure = True
342347 return self .run ()
343348 self .lastFailFingerprint = fingerprint
@@ -346,6 +351,20 @@ def run(self) -> None:
346351 except Exception :
347352 self .transportConnectionFailed .notify ()
348353 raise
354+ # If connecting without certificate verification and we were given a fingerprint to trust,
355+ # check that the server's certificate matches it.
356+ if (
357+ self .insecure
358+ and self ._trustedFingerprint is not None
359+ # Since this is a client-side socket, if connection was successful it will always return a certificate.
360+ and (fingerprint := self ._derCert2fingerprint (cast (bytes , self .serverSock .getpeercert (True ))))
361+ != self ._trustedFingerprint
362+ ):
363+ self ._disconnect ()
364+ self .lastFailFingerprint = fingerprint
365+ self .transportCertificateAuthenticationFailed .notify ()
366+ self .transportConnectionFailed .notify ()
367+ return
349368 self .onTransportConnected ()
350369 self .startQueueThread ()
351370 self ._readLoop ()
@@ -366,13 +385,18 @@ def isFingerprintTrusted(self, fingerprint: str) -> bool:
366385 and config ["trustedCertificates" ][hostPortToAddress (self .address )] == fingerprint
367386 )
368387
388+ @staticmethod
389+ def _derCert2fingerprint (cert : bytes ) -> str :
390+ """Convert a DER-encoded certificate to a certificate fingerprint."""
391+ return hashlib .sha256 (cert ).hexdigest ().lower ()
392+
369393 def getHostFingerprint (self ) -> str :
370394 tempConnection = self .createOutboundSocket (* self .address , insecure = True )
371395 tempConnection .connect (self .address )
372396 certBin = tempConnection .getpeercert (True )
373397 tempConnection .close ()
374- fingerprint = hashlib . sha256 ( certBin ). hexdigest (). lower ()
375- return fingerprint
398+ # Since this is a client-side socket, if connection was successful it will always return a certificate.
399+ return self . _derCert2fingerprint ( cast ( bytes , certBin ))
376400
377401 def startQueueThread (self ) -> None :
378402 """Start the outbound message queue thread."""
@@ -609,6 +633,8 @@ def __init__(
609633 connectionType : str | None = None ,
610634 protocolVersion : int = PROTOCOL_VERSION ,
611635 insecure : bool = False ,
636+ * ,
637+ trustedFingerprint : str | None = None ,
612638 ) -> None :
613639 """Initialize a new RelayTransport instance.
614640
@@ -619,12 +645,14 @@ def __init__(
619645 :param connectionType: Connection type identifier, defaults to ``None``
620646 :param protocolVersion: Protocol version to use, defaults to :const:`PROTOCOL_VERSION`
621647 :param insecure: Whether to skip certificate verification, defaults to ``False``
648+ :param trustedFingerprint: Certificate fingerprint to trust this session if connecting insecurely.
622649 """
623650 super ().__init__ (
624651 address = address ,
625652 serializer = serializer ,
626653 timeout = timeout ,
627654 insecure = insecure ,
655+ trustedFingerprint = trustedFingerprint ,
628656 )
629657 log .info (f"Connecting to { address } channel { channel } " )
630658 self .channel : str | None = channel
@@ -652,6 +680,7 @@ def create(cls, connectionInfo: ConnectionInfo, serializer: Serializer) -> Self:
652680 channel = connectionInfo .key ,
653681 connectionType = connectionInfo .mode ,
654682 insecure = connectionInfo .insecure ,
683+ trustedFingerprint = connectionInfo .trustedFingerprint ,
655684 )
656685
657686 def onConnected (self ) -> None :
0 commit comments