|
| 1 | +from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_AS_STRING, PyByteArray_GET_SIZE, PyByteArray_Resize |
| 2 | +from cpython.bytes cimport PyBytes_FromStringAndSize |
| 3 | + |
| 4 | + |
1 | 5 | cdef _create_transport_context(server_side, server_hostname): |
2 | 6 | if server_side: |
3 | 7 | raise ValueError('Server side SSL needs a valid SSLContext') |
@@ -199,22 +203,12 @@ cdef class SSLProtocol: |
199 | 203 | buffers which are ssl.MemoryBIO objects. |
200 | 204 | """ |
201 | 205 |
|
202 | | - def __cinit__(self, *args, **kwargs): |
203 | | - self._ssl_buffer_len = SSL_READ_MAX_SIZE |
204 | | - self._ssl_buffer = <char*>PyMem_RawMalloc(self._ssl_buffer_len) |
205 | | - if not self._ssl_buffer: |
206 | | - raise MemoryError() |
207 | | - |
208 | | - def __dealloc__(self): |
209 | | - PyMem_RawFree(self._ssl_buffer) |
210 | | - self._ssl_buffer = NULL |
211 | | - self._ssl_buffer_len = 0 |
212 | | - |
213 | 206 | def __init__(self, loop, app_protocol, sslcontext, waiter, |
214 | 207 | server_side=False, server_hostname=None, |
215 | 208 | call_connection_made=True, |
216 | 209 | ssl_handshake_timeout=None, |
217 | 210 | ssl_shutdown_timeout=None): |
| 211 | + |
218 | 212 | if ssl_handshake_timeout is None: |
219 | 213 | ssl_handshake_timeout = SSL_HANDSHAKE_TIMEOUT |
220 | 214 | elif ssl_handshake_timeout <= 0: |
@@ -261,6 +255,11 @@ cdef class SSLProtocol: |
261 | 255 | self._incoming_write = self._incoming.write |
262 | 256 | self._outgoing = ssl_MemoryBIO() |
263 | 257 | self._outgoing_read = self._outgoing.read |
| 258 | + |
| 259 | + self._plain_read_buffer = PyByteArray_FromStringAndSize( |
| 260 | + NULL, SSL_READ_DEFAULT_SIZE) |
| 261 | + self._ssl_read_max_size_obj = SSL_READ_MAX_SIZE |
| 262 | + |
264 | 263 | self._state = UNWRAPPED |
265 | 264 | self._conn_lost = 0 # Set when connection_lost called |
266 | 265 | if call_connection_made: |
@@ -291,8 +290,12 @@ cdef class SSLProtocol: |
291 | 290 | self._app_protocol_get_buffer = app_protocol.get_buffer |
292 | 291 | self._app_protocol_buffer_updated = app_protocol.buffer_updated |
293 | 292 | self._app_protocol_is_buffer = True |
| 293 | + self._ssl_read_buffer = None |
294 | 294 | else: |
295 | 295 | self._app_protocol_is_buffer = False |
| 296 | + if self._ssl_read_buffer is None: |
| 297 | + self._ssl_read_buffer = PyByteArray_FromStringAndSize( |
| 298 | + NULL, SSL_READ_MAX_SIZE) |
296 | 299 |
|
297 | 300 | cdef _wakeup_waiter(self, exc=None): |
298 | 301 | if self._waiter is None: |
@@ -356,21 +359,24 @@ cdef class SSLProtocol: |
356 | 359 | self._handshake_timeout_handle = None |
357 | 360 |
|
358 | 361 | cdef get_buffer_impl(self, size_t n, char** buf, size_t* buf_size): |
359 | | - cdef size_t want = n |
360 | | - if want > SSL_READ_MAX_SIZE: |
361 | | - want = SSL_READ_MAX_SIZE |
362 | | - if self._ssl_buffer_len < want: |
363 | | - self._ssl_buffer = <char*>PyMem_RawRealloc(self._ssl_buffer, want) |
364 | | - if not self._ssl_buffer: |
365 | | - raise MemoryError() |
366 | | - self._ssl_buffer_len = want |
367 | | - |
368 | | - buf[0] = self._ssl_buffer |
369 | | - buf_size[0] = self._ssl_buffer_len |
| 362 | + cdef Py_ssize_t want = min(<Py_ssize_t>n, SSL_READ_MAX_SIZE) |
| 363 | + |
| 364 | + if len(self._plain_read_buffer) < want: |
| 365 | + PyByteArray_Resize(self._plain_read_buffer, want) |
| 366 | + if self._ssl_read_buffer is not None: |
| 367 | + PyByteArray_Resize(self._ssl_read_buffer, want) |
| 368 | + |
| 369 | + buf[0] = PyByteArray_AS_STRING(self._plain_read_buffer) |
| 370 | + buf_size[0] = PyByteArray_GET_SIZE(self._plain_read_buffer) |
370 | 371 |
|
371 | 372 | cdef buffer_updated_impl(self, size_t nbytes): |
372 | | - self._incoming_write(PyMemoryView_FromMemory( |
373 | | - self._ssl_buffer, nbytes, PyBUF_WRITE)) |
| 373 | + mv = PyMemoryView_FromMemory( |
| 374 | + self._plain_read_buffer, |
| 375 | + nbytes, |
| 376 | + PyBUF_WRITE |
| 377 | + ) |
| 378 | + |
| 379 | + self._incoming_write(mv) |
374 | 380 |
|
375 | 381 | if self._state == DO_HANDSHAKE: |
376 | 382 | self._do_handshake() |
@@ -597,7 +603,7 @@ cdef class SSLProtocol: |
597 | 603 | bint close_notify = False |
598 | 604 | try: |
599 | 605 | while True: |
600 | | - if not self._sslobj_read(SSL_READ_MAX_SIZE): |
| 606 | + if not self._sslobj_read(self._ssl_read_max_size_obj): |
601 | 607 | close_notify = True |
602 | 608 | break |
603 | 609 | except ssl_SSLAgainErrors as exc: |
@@ -787,7 +793,7 @@ cdef class SSLProtocol: |
787 | 793 | PyBUF_WRITE) |
788 | 794 |
|
789 | 795 | last_bytes_read = <Py_ssize_t>self._sslobj_read( |
790 | | - app_buffer_size - total_bytes_read, app_buffer) |
| 796 | + self._ssl_read_max_size_obj, app_buffer) |
791 | 797 | total_bytes_read += last_bytes_read |
792 | 798 |
|
793 | 799 | if last_bytes_read == 0: |
@@ -823,32 +829,36 @@ cdef class SSLProtocol: |
823 | 829 |
|
824 | 830 | cdef _do_read__copied(self): |
825 | 831 | cdef: |
826 | | - list data |
827 | | - bytes first, chunk = b'1' |
828 | | - bint zero = True, one = False |
| 832 | + Py_ssize_t bytes_read = -1 |
| 833 | + list data = None |
| 834 | + bytes first_chunk = None, curr_chunk |
829 | 835 |
|
830 | 836 | try: |
831 | 837 | while (<Py_ssize_t>self._incoming.pending > 0 or |
832 | 838 | <Py_ssize_t>self._sslobj_pending() > 0): |
833 | | - chunk = self._sslobj_read(SSL_READ_MAX_SIZE) |
834 | | - if not chunk: |
| 839 | + bytes_read = self._sslobj_read( |
| 840 | + self._ssl_read_max_size_obj, |
| 841 | + self._ssl_read_buffer) |
| 842 | + if bytes_read == 0: |
835 | 843 | break |
836 | | - if zero: |
837 | | - zero = False |
838 | | - one = True |
839 | | - first = chunk |
840 | | - elif one: |
841 | | - one = False |
842 | | - data = [first, chunk] |
| 844 | + |
| 845 | + curr_chunk = <bytes> PyBytes_FromStringAndSize( |
| 846 | + PyByteArray_AS_STRING(self._ssl_read_buffer), bytes_read) |
| 847 | + |
| 848 | + if first_chunk is None: |
| 849 | + first_chunk = curr_chunk |
| 850 | + elif data is None: |
| 851 | + data = [first_chunk, curr_chunk] |
843 | 852 | else: |
844 | | - data.append(chunk) |
| 853 | + data.append(curr_chunk) |
845 | 854 | except ssl_SSLAgainErrors as exc: |
846 | 855 | pass |
847 | | - if one: |
848 | | - self._app_protocol.data_received(first) |
849 | | - elif not zero: |
| 856 | + |
| 857 | + if data is not None: |
850 | 858 | self._app_protocol.data_received(b''.join(data)) |
851 | | - if not chunk: |
| 859 | + elif first_chunk is not None: |
| 860 | + self._app_protocol.data_received(first_chunk) |
| 861 | + elif bytes_read == 0: |
852 | 862 | # close_notify |
853 | 863 | self._call_eof_received() |
854 | 864 | self._start_shutdown() |
|
0 commit comments