11# test that socket.connect() on a non-blocking socket raises EINPROGRESS
22# and that an immediate write/send/read/recv does the right thing
33
4+ import unittest
45import errno
56import select
67import socket
78import ssl
89
910# only mbedTLS supports non-blocking mode
10- if not hasattr (ssl , "MBEDTLS_VERSION" ):
11- print ("SKIP" )
12- raise SystemExit
11+ ssl_supports_nonblocking = hasattr (ssl , "MBEDTLS_VERSION" )
1312
1413
1514# get the name of an errno error code
@@ -24,34 +23,43 @@ def errno_name(er):
2423# do_connect establishes the socket and wraps it if tls is True.
2524# If handshake is true, the initial connect (and TLS handshake) is
2625# allowed to be performed before returning.
27- def do_connect (peer_addr , tls , handshake ):
26+ def do_connect (self , peer_addr , tls , handshake ):
2827 s = socket .socket ()
2928 s .setblocking (False )
3029 try :
31- # print("Connecting to", peer_addr)
30+ print ("Connecting to" , peer_addr )
3231 s .connect (peer_addr )
32+ self .fail ()
3333 except OSError as er :
3434 print ("connect:" , errno_name (er .errno ))
35+ self .assertEqual (er .errno , errno .EINPROGRESS )
36+
3537 # wrap with ssl/tls if desired
3638 if tls :
39+ print ("wrap socket" )
3740 ssl_context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
38- try :
39- s = ssl_context .wrap_socket (s , do_handshake_on_connect = handshake )
40- print ("wrap ok: True" )
41- except Exception as e :
42- print ("wrap er:" , e )
41+ s = ssl_context .wrap_socket (s , do_handshake_on_connect = handshake )
42+
4343 return s
4444
4545
46- # poll a socket and print out the result
47- def poll (s ):
46+ # poll a socket and check the result
47+ def poll (self , s , expect_writable ):
4848 poller = select .poll ()
4949 poller .register (s )
50- print ("poll: " , poller .poll (0 ))
50+ result = poller .poll (0 )
51+ print ("poll:" , result )
52+ if expect_writable :
53+ self .assertEqual (len (result ), 1 )
54+ self .assertEqual (result [0 ][1 ], select .POLLOUT )
55+ else :
56+ self .assertEqual (result , [])
57+
5158
59+ # do_test runs the test against a specific peer address.
60+ def do_test (self , peer_addr , tls , handshake ):
61+ print ()
5262
53- # test runs the test against a specific peer address.
54- def test (peer_addr , tls , handshake ):
5563 # MicroPython plain and TLS sockets have read/write
5664 hasRW = True
5765
@@ -62,54 +70,66 @@ def test(peer_addr, tls, handshake):
6270 # connect + send
6371 # non-blocking send should raise EAGAIN
6472 if hasSR :
65- s = do_connect (peer_addr , tls , handshake )
66- poll (s )
67- try :
73+ s = do_connect (self , peer_addr , tls , handshake )
74+ poll (self , s , False )
75+ with self . assertRaises ( OSError ) as ctx :
6876 ret = s .send (b"1234" )
69- print ("send ok:" , ret ) # shouldn't get here
70- except OSError as er :
71- print ("send er:" , errno_name (er .errno ))
77+ print ("send error:" , errno_name (ctx .exception .errno ))
78+ self .assertEqual (ctx .exception .errno , errno .EAGAIN )
7279 s .close ()
7380
7481 # connect + write
7582 # non-blocking write should return None
7683 if hasRW :
77- s = do_connect (peer_addr , tls , handshake )
78- poll (s )
84+ s = do_connect (self , peer_addr , tls , handshake )
85+ poll (self , s , tls and handshake )
7986 ret = s .write (b"1234" )
80- print ("write: " , ret )
87+ print ("write:" , ret )
88+ if tls and handshake :
89+ self .assertEqual (ret , 4 )
90+ else :
91+ self .assertIsNone (ret )
8192 s .close ()
8293
8394 # connect + recv
8495 # non-blocking recv should raise EAGAIN
8596 if hasSR :
86- s = do_connect (peer_addr , tls , handshake )
87- poll (s )
88- try :
97+ s = do_connect (self , peer_addr , tls , handshake )
98+ poll (self , s , False )
99+ with self . assertRaises ( OSError ) as ctx :
89100 ret = s .recv (10 )
90- print ("recv ok:" , ret ) # shouldn't get here
91- except OSError as er :
92- print ("recv er:" , errno_name (er .errno ))
101+ print ("recv error:" , errno_name (ctx .exception .errno ))
102+ self .assertEqual (ctx .exception .errno , errno .EAGAIN )
93103 s .close ()
94104
95105 # connect + read
96106 # non-blocking read should return None
97107 if hasRW :
98- s = do_connect (peer_addr , tls , handshake )
99- poll (s )
108+ s = do_connect (self , peer_addr , tls , handshake )
109+ poll (self , s , tls and handshake )
100110 ret = s .read (10 )
101- print ("read: " , ret )
111+ print ("read:" , ret )
112+ self .assertIsNone (ret )
102113 s .close ()
103114
104115
105- if __name__ == "__main__" :
116+ class Test ( unittest . TestCase ) :
106117 # these tests use a non-existent test IP address, this way the connect takes forever and
107118 # we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
108- print ("--- Plain sockets to nowhere ---" )
109- test (socket .getaddrinfo ("192.0.2.1" , 80 )[0 ][- 1 ], False , False )
110- print ("--- SSL sockets to nowhere ---" )
111- test (socket .getaddrinfo ("192.0.2.1" , 443 )[0 ][- 1 ], True , False )
112- print ("--- Plain sockets ---" )
113- test (socket .getaddrinfo ("micropython.org" , 80 )[0 ][- 1 ], False , False )
114- print ("--- SSL sockets ---" )
115- test (socket .getaddrinfo ("micropython.org" , 443 )[0 ][- 1 ], True , True )
119+ def test_plain_sockets_to_nowhere (self ):
120+ do_test (self , socket .getaddrinfo ("192.0.2.1" , 80 )[0 ][- 1 ], False , False )
121+
122+ @unittest .skipIf (not ssl_supports_nonblocking , "SSL doesn't support non-blocking" )
123+ def test_ssl_sockets_to_nowhere (self ):
124+ do_test (self , socket .getaddrinfo ("192.0.2.1" , 443 )[0 ][- 1 ], True , False )
125+
126+ def test_plain_sockets (self ):
127+ do_test (self , socket .getaddrinfo ("micropython.org" , 80 )[0 ][- 1 ], False , False )
128+
129+ @unittest .skipIf (not ssl_supports_nonblocking , "SSL doesn't support non-blocking" )
130+ def test_ssl_sockets (self ):
131+ do_test (self , socket .getaddrinfo ("micropython.org" , 443 )[0 ][- 1 ], True , True )
132+
133+
134+ if __name__ == "__main__" :
135+ unittest .main ()
0 commit comments