Skip to content

Commit f98d032

Browse files
martindemellometa-codesync[bot]
authored andcommitted
Add some stubs to improve analysis for the requests package
Summary: After fixing some false positives, the only genuine import-time effect in `requests` is the CA-bundle SSL preload at `requests.adapters` module scope: ``` _preloaded_ssl_context = create_urllib3_context() _preloaded_ssl_context.load_verify_locations(extract_zipped_paths(DEFAULT_CA_BUNDLE_PATH)) ``` This is idempotent local initialization, which is safe to defer for lazy imports. This diff adds some stubs to let lifeguard fully analyse the code: - Add effects to `stdlib/ssl.pyi` - Add `urllib3/util/ssl_.pyi` with a single annotation for `create_urllib3_context()` - Add `requests/utils.pyi` with a single annotation for `extract_zipped_paths()` NOTE: We fully annotated `stdlib/ssl.pyi` since it was a standard library module with a complete .pyi file, and we eventually want a fully annotated set of stdlib stubs. The other files are targeted fixes for single methods that let us mark `requests` as safe, in keeping with the existing pattern of doing minimal stubs for third party libraries. Reviewed By: brittanyrey Differential Revision: D108188446 fbshipit-source-id: 794fae9d517e0d73d98b84d21c594cbf8d077c66
1 parent 5f4270a commit f98d032

3 files changed

Lines changed: 110 additions & 78 deletions

File tree

resources/stubs/shared/requests/utils.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
def extract_zipped_paths(path: str) -> str: no_effects()
16+
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import ssl
16+
17+
def create_urllib3_context(*args: object, **kwargs: object) -> ssl.SSLContext: no_effects()

resources/stubs/stdlib/ssl.pyi

Lines changed: 91 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,17 @@ if sys.version_info < (3, 12):
9191
do_handshake_on_connect: bool = True,
9292
suppress_ragged_eofs: bool = True,
9393
ciphers: str | None = None,
94-
) -> SSLSocket: ...
94+
) -> SSLSocket: unsafe()
9595

96+
# Builds and configures a fresh, locally-created SSLContext (reading CA files /
97+
# env). No external mutation or network, so safe to defer for lazy imports.
9698
def create_default_context(
9799
purpose: Purpose = ...,
98100
*,
99101
cafile: StrOrBytesPath | None = None,
100102
capath: StrOrBytesPath | None = None,
101103
cadata: str | ReadableBuffer | None = None,
102-
) -> SSLContext: ...
104+
) -> SSLContext: no_effects()
103105

104106
if sys.version_info >= (3, 10):
105107
def _create_unverified_context(
@@ -113,7 +115,7 @@ if sys.version_info >= (3, 10):
113115
cafile: StrOrBytesPath | None = None,
114116
capath: StrOrBytesPath | None = None,
115117
cadata: str | ReadableBuffer | None = None,
116-
) -> SSLContext: ...
118+
) -> SSLContext: no_effects()
117119

118120
else:
119121
def _create_unverified_context(
@@ -127,25 +129,26 @@ else:
127129
cafile: StrOrBytesPath | None = None,
128130
capath: StrOrBytesPath | None = None,
129131
cadata: str | ReadableBuffer | None = None,
130-
) -> SSLContext: ...
132+
) -> SSLContext: no_effects()
131133

132134
_create_default_https_context: Callable[..., SSLContext]
133135

134136
if sys.version_info < (3, 12):
135-
def match_hostname(cert: _PeerCertRetDictType, hostname: str) -> None: ...
137+
def match_hostname(cert: _PeerCertRetDictType, hostname: str) -> None: no_effects()
136138

137-
def cert_time_to_seconds(cert_time: str) -> int: ...
139+
def cert_time_to_seconds(cert_time: str) -> int: no_effects()
138140

141+
# Opens a network connection to the server to fetch its certificate.
139142
if sys.version_info >= (3, 10):
140143
def get_server_certificate(
141144
addr: tuple[str, int], ssl_version: int = ..., ca_certs: str | None = None, timeout: float = ...
142-
) -> str: ...
145+
) -> str: unsafe()
143146

144147
else:
145-
def get_server_certificate(addr: tuple[str, int], ssl_version: int = ..., ca_certs: str | None = None) -> str: ...
148+
def get_server_certificate(addr: tuple[str, int], ssl_version: int = ..., ca_certs: str | None = None) -> str: unsafe()
146149

147-
def DER_cert_to_PEM_cert(der_cert_bytes: ReadableBuffer) -> str: ...
148-
def PEM_cert_to_DER_cert(pem_cert_string: str) -> bytes: ...
150+
def DER_cert_to_PEM_cert(der_cert_bytes: ReadableBuffer) -> str: no_effects()
151+
def PEM_cert_to_DER_cert(pem_cert_string: str) -> bytes: no_effects()
149152

150153
class DefaultVerifyPaths(NamedTuple):
151154
cafile: str
@@ -155,7 +158,7 @@ class DefaultVerifyPaths(NamedTuple):
155158
openssl_capath_env: str
156159
openssl_capath: str
157160

158-
def get_default_verify_paths() -> DefaultVerifyPaths: ...
161+
def get_default_verify_paths() -> DefaultVerifyPaths: no_effects()
159162

160163
class VerifyMode(enum.IntEnum):
161164
CERT_NONE = 0
@@ -318,17 +321,17 @@ class _ASN1ObjectBase(NamedTuple):
318321
oid: str
319322

320323
class _ASN1Object(_ASN1ObjectBase):
321-
def __new__(cls, oid: str) -> Self: ...
324+
def __new__(cls, oid: str) -> Self: no_effects()
322325
@classmethod
323-
def fromnid(cls, nid: int) -> Self: ...
326+
def fromnid(cls, nid: int) -> Self: no_effects()
324327
@classmethod
325-
def fromname(cls, name: str) -> Self: ...
328+
def fromname(cls, name: str) -> Self: no_effects()
326329

327330
class Purpose(_ASN1Object, enum.Enum):
328331
# Normally this class would inherit __new__ from _ASN1Object, but
329332
# because this is an enum, the inherited __new__ is replaced at runtime with
330333
# Enum.__new__.
331-
def __new__(cls, value: object) -> Self: ...
334+
def __new__(cls, value: object) -> Self: no_effects()
332335
SERVER_AUTH = (129, "serverAuth", "TLS Web Server Authentication", "1.3.6.1.5.5.7.3.2") # pyright: ignore[reportCallIssue]
333336
CLIENT_AUTH = (130, "clientAuth", "TLS Web Client Authentication", "1.3.6.1.5.5.7.3.1") # pyright: ignore[reportCallIssue]
334337

@@ -340,53 +343,57 @@ class SSLSocket(socket.socket):
340343
@property
341344
def session_reused(self) -> bool | None: ...
342345
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
343-
def connect(self, addr: socket._Address) -> None: ...
344-
def connect_ex(self, addr: socket._Address) -> int: ...
345-
def recv(self, buflen: int = 1024, flags: int = 0) -> bytes: ...
346-
def recv_into(self, buffer: WriteableBuffer, nbytes: int | None = None, flags: int = 0) -> int: ...
347-
def recvfrom(self, buflen: int = 1024, flags: int = 0) -> tuple[bytes, socket._RetAddress]: ...
346+
# Network reads just retrieve data — their timing is irrelevant for lazy
347+
# imports, so they are safe to defer. Writes / connection setup / teardown
348+
# are externally-observable effects whose timing matters, so they are unsafe.
349+
def connect(self, addr: socket._Address) -> None: unsafe()
350+
def connect_ex(self, addr: socket._Address) -> int: unsafe()
351+
def recv(self, buflen: int = 1024, flags: int = 0) -> bytes: no_effects()
352+
def recv_into(self, buffer: WriteableBuffer, nbytes: int | None = None, flags: int = 0) -> int: no_effects()
353+
def recvfrom(self, buflen: int = 1024, flags: int = 0) -> tuple[bytes, socket._RetAddress]: no_effects()
348354
def recvfrom_into(
349355
self, buffer: WriteableBuffer, nbytes: int | None = None, flags: int = 0
350-
) -> tuple[int, socket._RetAddress]: ...
351-
def send(self, data: ReadableBuffer, flags: int = 0) -> int: ...
352-
def sendall(self, data: ReadableBuffer, flags: int = 0) -> None: ...
356+
) -> tuple[int, socket._RetAddress]: no_effects()
357+
def send(self, data: ReadableBuffer, flags: int = 0) -> int: unsafe()
358+
def sendall(self, data: ReadableBuffer, flags: int = 0) -> None: unsafe()
353359
@overload
354-
def sendto(self, data: ReadableBuffer, flags_or_addr: socket._Address, addr: None = None) -> int: ...
360+
def sendto(self, data: ReadableBuffer, flags_or_addr: socket._Address, addr: None = None) -> int: unsafe()
355361
@overload
356362
def sendto(self, data: ReadableBuffer, flags_or_addr: int, addr: socket._Address) -> int: ...
357-
def shutdown(self, how: int) -> None: ...
358-
def read(self, len: int = 1024, buffer: bytearray | None = None) -> bytes: ...
359-
def write(self, data: ReadableBuffer) -> int: ...
360-
def do_handshake(self, block: bool = False) -> None: ... # block is undocumented
363+
def shutdown(self, how: int) -> None: unsafe()
364+
def read(self, len: int = 1024, buffer: bytearray | None = None) -> bytes: no_effects()
365+
def write(self, data: ReadableBuffer) -> int: unsafe()
366+
def do_handshake(self, block: bool = False) -> None: unsafe() # block is undocumented
367+
# Read accessors for connection/cert state.
361368
@overload
362-
def getpeercert(self, binary_form: Literal[False] = False) -> _PeerCertRetDictType | None: ...
369+
def getpeercert(self, binary_form: Literal[False] = False) -> _PeerCertRetDictType | None: no_effects()
363370
@overload
364371
def getpeercert(self, binary_form: Literal[True]) -> bytes | None: ...
365372
@overload
366373
def getpeercert(self, binary_form: bool) -> _PeerCertRetType: ...
367-
def cipher(self) -> tuple[str, str, int] | None: ...
368-
def shared_ciphers(self) -> list[tuple[str, str, int]] | None: ...
369-
def compression(self) -> str | None: ...
370-
def get_channel_binding(self, cb_type: str = "tls-unique") -> bytes | None: ...
371-
def selected_alpn_protocol(self) -> str | None: ...
374+
def cipher(self) -> tuple[str, str, int] | None: no_effects()
375+
def shared_ciphers(self) -> list[tuple[str, str, int]] | None: no_effects()
376+
def compression(self) -> str | None: no_effects()
377+
def get_channel_binding(self, cb_type: str = "tls-unique") -> bytes | None: no_effects()
378+
def selected_alpn_protocol(self) -> str | None: no_effects()
372379
if sys.version_info >= (3, 10):
373380
@deprecated("Deprecated in 3.10. Use ALPN instead.")
374-
def selected_npn_protocol(self) -> str | None: ...
381+
def selected_npn_protocol(self) -> str | None: no_effects()
375382
else:
376-
def selected_npn_protocol(self) -> str | None: ...
383+
def selected_npn_protocol(self) -> str | None: no_effects()
377384

378-
def accept(self) -> tuple[SSLSocket, socket._RetAddress]: ...
379-
def unwrap(self) -> socket.socket: ...
380-
def version(self) -> str | None: ...
381-
def pending(self) -> int: ...
382-
def verify_client_post_handshake(self) -> None: ...
385+
def accept(self) -> tuple[SSLSocket, socket._RetAddress]: unsafe()
386+
def unwrap(self) -> socket.socket: unsafe()
387+
def version(self) -> str | None: no_effects()
388+
def pending(self) -> int: no_effects()
389+
def verify_client_post_handshake(self) -> None: unsafe()
383390
# These methods always raise `NotImplementedError`:
384391
def recvmsg(self, *args: Never, **kwargs: Never) -> Never: ... # type: ignore[override]
385392
def recvmsg_into(self, *args: Never, **kwargs: Never) -> Never: ... # type: ignore[override]
386393
def sendmsg(self, *args: Never, **kwargs: Never) -> Never: ... # type: ignore[override]
387394
if sys.version_info >= (3, 13):
388-
def get_verified_chain(self) -> list[bytes]: ...
389-
def get_unverified_chain(self) -> list[bytes]: ...
395+
def get_verified_chain(self) -> list[bytes]: no_effects()
396+
def get_unverified_chain(self) -> list[bytes]: no_effects()
390397

391398
class TLSVersion(enum.IntEnum):
392399
MINIMUM_SUPPORTED = -2
@@ -418,36 +425,39 @@ class SSLContext(_SSLContext):
418425
if sys.version_info >= (3, 10):
419426
# Using the default (None) for the `protocol` parameter is deprecated,
420427
# but there isn't a good way of marking that in the stub unless/until PEP 702 is accepted
421-
def __new__(cls, protocol: int | None = None, *args: Any, **kwargs: Any) -> Self: ...
428+
def __new__(cls, protocol: int | None = None, *args: Any, **kwargs: Any) -> Self: no_effects()
422429
else:
423-
def __new__(cls, protocol: int = ..., *args: Any, **kwargs: Any) -> Self: ...
430+
def __new__(cls, protocol: int = ..., *args: Any, **kwargs: Any) -> Self: no_effects()
424431

425-
def load_default_certs(self, purpose: Purpose = ...) -> None: ...
432+
# load_* / set_* configure (mutate) the context object. Safe on a freshly
433+
# constructed / own context; flagged when mutating an imported/shared one.
434+
def load_default_certs(self, purpose: Purpose = ...) -> None: mutation()
426435
def load_verify_locations(
427436
self,
428437
cafile: StrOrBytesPath | None = None,
429438
capath: StrOrBytesPath | None = None,
430439
cadata: str | ReadableBuffer | None = None,
431-
) -> None: ...
440+
) -> None: mutation()
432441
@overload
433-
def get_ca_certs(self, binary_form: Literal[False] = False) -> list[_PeerCertRetDictType]: ...
442+
def get_ca_certs(self, binary_form: Literal[False] = False) -> list[_PeerCertRetDictType]: no_effects()
434443
@overload
435444
def get_ca_certs(self, binary_form: Literal[True]) -> list[bytes]: ...
436445
@overload
437446
def get_ca_certs(self, binary_form: bool = False) -> Any: ...
438-
def get_ciphers(self) -> list[_Cipher]: ...
439-
def set_default_verify_paths(self) -> None: ...
440-
def set_ciphers(self, cipherlist: str, /) -> None: ...
441-
def set_alpn_protocols(self, alpn_protocols: Iterable[str]) -> None: ...
447+
def get_ciphers(self) -> list[_Cipher]: no_effects()
448+
def set_default_verify_paths(self) -> None: mutation()
449+
def set_ciphers(self, cipherlist: str, /) -> None: mutation()
450+
def set_alpn_protocols(self, alpn_protocols: Iterable[str]) -> None: mutation()
442451
if sys.version_info >= (3, 10):
443452
@deprecated("Deprecated in 3.10. Use ALPN instead.")
444-
def set_npn_protocols(self, npn_protocols: Iterable[str]) -> None: ...
453+
def set_npn_protocols(self, npn_protocols: Iterable[str]) -> None: mutation()
445454
else:
446-
def set_npn_protocols(self, npn_protocols: Iterable[str]) -> None: ...
455+
def set_npn_protocols(self, npn_protocols: Iterable[str]) -> None: mutation()
447456

448-
def set_servername_callback(self, server_name_callback: _SrvnmeCbType | None) -> None: ...
449-
def load_dh_params(self, path: str, /) -> None: ...
450-
def set_ecdh_curve(self, name: str, /) -> None: ...
457+
def set_servername_callback(self, server_name_callback: _SrvnmeCbType | None) -> None: mutation()
458+
def load_dh_params(self, path: str, /) -> None: mutation()
459+
def set_ecdh_curve(self, name: str, /) -> None: mutation()
460+
# Creates an SSLSocket; may perform a TLS handshake over the network.
451461
def wrap_socket(
452462
self,
453463
sock: socket.socket,
@@ -456,16 +466,19 @@ class SSLContext(_SSLContext):
456466
suppress_ragged_eofs: bool = True,
457467
server_hostname: str | bytes | None = None,
458468
session: SSLSession | None = None,
459-
) -> SSLSocket: ...
469+
) -> SSLSocket: unsafe()
470+
# Memory-BIO wrapper construction; no socket or network involved.
460471
def wrap_bio(
461472
self,
462473
incoming: MemoryBIO,
463474
outgoing: MemoryBIO,
464475
server_side: bool = False,
465476
server_hostname: str | bytes | None = None,
466477
session: SSLSession | None = None,
467-
) -> SSLObject: ...
478+
) -> SSLObject: no_effects()
468479

480+
# SSLObject operates on in-memory BIO buffers (no socket). read/write/handshake
481+
# mutate the object's internal TLS state; the accessors are pure reads.
469482
class SSLObject:
470483
context: SSLContext
471484
@property
@@ -476,33 +489,33 @@ class SSLObject:
476489
@property
477490
def session_reused(self) -> bool: ...
478491
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
479-
def read(self, len: int = 1024, buffer: bytearray | None = None) -> bytes: ...
480-
def write(self, data: ReadableBuffer) -> int: ...
492+
def read(self, len: int = 1024, buffer: bytearray | None = None) -> bytes: mutation()
493+
def write(self, data: ReadableBuffer) -> int: mutation()
481494
@overload
482-
def getpeercert(self, binary_form: Literal[False] = False) -> _PeerCertRetDictType | None: ...
495+
def getpeercert(self, binary_form: Literal[False] = False) -> _PeerCertRetDictType | None: no_effects()
483496
@overload
484497
def getpeercert(self, binary_form: Literal[True]) -> bytes | None: ...
485498
@overload
486499
def getpeercert(self, binary_form: bool) -> _PeerCertRetType: ...
487-
def selected_alpn_protocol(self) -> str | None: ...
500+
def selected_alpn_protocol(self) -> str | None: no_effects()
488501
if sys.version_info >= (3, 10):
489502
@deprecated("Deprecated in 3.10. Use ALPN instead.")
490-
def selected_npn_protocol(self) -> str | None: ...
503+
def selected_npn_protocol(self) -> str | None: no_effects()
491504
else:
492-
def selected_npn_protocol(self) -> str | None: ...
493-
494-
def cipher(self) -> tuple[str, str, int] | None: ...
495-
def shared_ciphers(self) -> list[tuple[str, str, int]] | None: ...
496-
def compression(self) -> str | None: ...
497-
def pending(self) -> int: ...
498-
def do_handshake(self) -> None: ...
499-
def unwrap(self) -> None: ...
500-
def version(self) -> str | None: ...
501-
def get_channel_binding(self, cb_type: str = "tls-unique") -> bytes | None: ...
502-
def verify_client_post_handshake(self) -> None: ...
505+
def selected_npn_protocol(self) -> str | None: no_effects()
506+
507+
def cipher(self) -> tuple[str, str, int] | None: no_effects()
508+
def shared_ciphers(self) -> list[tuple[str, str, int]] | None: no_effects()
509+
def compression(self) -> str | None: no_effects()
510+
def pending(self) -> int: no_effects()
511+
def do_handshake(self) -> None: mutation()
512+
def unwrap(self) -> None: mutation()
513+
def version(self) -> str | None: no_effects()
514+
def get_channel_binding(self, cb_type: str = "tls-unique") -> bytes | None: no_effects()
515+
def verify_client_post_handshake(self) -> None: mutation()
503516
if sys.version_info >= (3, 13):
504-
def get_verified_chain(self) -> list[bytes]: ...
505-
def get_unverified_chain(self) -> list[bytes]: ...
517+
def get_verified_chain(self) -> list[bytes]: no_effects()
518+
def get_unverified_chain(self) -> list[bytes]: no_effects()
506519

507520
class SSLErrorNumber(enum.IntEnum):
508521
SSL_ERROR_EOF = 8
@@ -525,7 +538,7 @@ SSL_ERROR_WANT_WRITE: SSLErrorNumber # undocumented
525538
SSL_ERROR_WANT_X509_LOOKUP: SSLErrorNumber # undocumented
526539
SSL_ERROR_ZERO_RETURN: SSLErrorNumber # undocumented
527540

528-
def get_protocol_name(protocol_code: int) -> str: ...
541+
def get_protocol_name(protocol_code: int) -> str: no_effects()
529542

530543
PEM_FOOTER: str
531544
PEM_HEADER: str

0 commit comments

Comments
 (0)