diff --git a/authcaptureproxy/auth_capture_proxy.py b/authcaptureproxy/auth_capture_proxy.py index b38de57..8ae6836 100644 --- a/authcaptureproxy/auth_capture_proxy.py +++ b/authcaptureproxy/auth_capture_proxy.py @@ -3,16 +3,15 @@ import asyncio import logging import re +from json import JSONDecodeError from functools import partial from ssl import SSLContext, create_default_context -from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Text, Tuple, Union import httpx from aiohttp import ( - ClientConnectionError, MultipartReader, MultipartWriter, - TooManyRedirects, hdrs, web, ) @@ -47,8 +46,7 @@ class AuthCaptureProxy: This class relies on tests to be provided to indicate the proxy has completed. At proxy completion all data can be found in self.session, self.data, and self.query. """ - def __init__( - self, + def __init__(self, proxy_url: URL, host_url: URL, session: Optional[httpx.AsyncClient] = None, @@ -58,7 +56,10 @@ def __init__( """Initialize proxy object. Args: - proxy_url (URL): url for proxy location. e.g., http://192.168.1.1/. If there is any path, the path is considered part of the base url. If no explicit port is specified, a random port will be generated. If https is passed in, ssl_context must be provided at start_proxy() or the url will be downgraded to http. + proxy_url (URL): url for proxy location. e.g., http://192.168.1.1/. + If there is any path, the path is considered part of the base url. + If no explicit port is specified, a random port will be generated. + If https is passed in, ssl_context must be provided at start_proxy() or the url will be downgraded to http. host_url (URL): original url for login, e.g., http://amazon.com session (httpx.AsyncClient): httpx client to make queries. Optional session_factory (lambda: httpx.AsyncClient): factory to create the aforementioned httpx client if having one fixed session is insufficient. 
@@ -102,6 +103,7 @@ def __init__( self.redirect_filters: Dict[Text, List[Text]] = { "url": [] } # dictionary of lists of regex strings to filter against + self._background_tasks: Set[asyncio.Task] = set() @property def active(self) -> bool: @@ -146,7 +148,7 @@ def tests(self, value: Dict[Text, Callable]) -> None: def modifiers(self) -> Dict[Text, Union[Callable, Dict[Text, Callable]]]: """Return modifiers setting. - :setter: value (Dict[Text, Dict[Text, Callable]): A nested dictionary of modifiers. The key shoud be a MIME type and the value should be a dictionary of modifiers for that MIME type where the key should be the name of the modifier and the value should be a function or couroutine that takes a string and returns a modified string. If parameters are necessary, functools.partial should be used. See :mod:`authcaptureproxy.examples.modifiers` for examples. + :setter: value (Dict[Text, Dict[Text, Callable]]): A nested dictionary of modifiers. The key should be a MIME type and the value should be a dictionary of modifiers for that MIME type where the key should be the name of the modifier and the value should be a function or coroutine that takes a string and returns a modified string. If parameters are necessary, functools.partial should be used. See :mod:`authcaptureproxy.examples.modifiers` for examples. """ return self._modifiers @@ -284,7 +286,7 @@ async def _build_response( async def all_handler(self, request: web.Request, **kwargs) -> web.Response: """Handle all requests. - This handler will exit on succesful test found in self.tests or if a /stop url is seen. This handler can be used with any aiohttp webserver and disabled after registered using self.all_handler_active. + This handler will exit on successful test found in self.tests or if a /stop url is seen. This handler can be used with any aiohttp webserver and disabled after registered using self.all_handler_active. 
Args request (web.Request): The request to process @@ -324,16 +326,27 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) - break if isinstance(part, MultipartReader): await _process_multipart(part, writer) - elif part.headers.get("hdrs.CONTENT_TYPE"): - if part.headers[hdrs.CONTENT_TYPE] == "application/json": - part_data: Optional[ - Union[Text, Dict[Text, Any], List[Tuple[Text, Text]], bytes] - ] = await part.json() - writer.append_json(part_data) - elif part.headers[hdrs.CONTENT_TYPE].startswith("text"): + elif hdrs.CONTENT_TYPE in part.headers: + content_type = part.headers.get(hdrs.CONTENT_TYPE, "") + mime_type = content_type.split(";", 1)[0].strip() + if mime_type == "application/json": + try: + part_data: Optional[ + Union[Text, Dict[Text, Any], List[Tuple[Text, Text]], bytes] + ] = await part.json() + writer.append_json(part_data) + except Exception: + # Best-effort fallback: text, then bytes + try: + part_text = await part.text() + writer.append(part_text) + except Exception: + part_data = await part.read() + writer.append(part_data) + elif mime_type.startswith("text"): part_data = await part.text() writer.append(part_data) - elif part.headers[hdrs.CONTENT_TYPE] == "application/www-urlform-encode": + elif mime_type == "application/x-www-form-urlencoded": part_data = await part.form() writer.append_form(part_data) else: @@ -390,8 +403,15 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) - else: data = convert_multidict_to_dict(await request.post()) json_data = None - if request.has_body: - json_data = await request.json() + # Only attempt JSON decoding for JSON requests; avoid raising for form posts. 
+ if request.has_body and ( + request.content_type == "application/json" + or request.content_type.endswith("+json") + ): + try: + json_data = await request.json() + except (JSONDecodeError, ValueError): + json_data = None if data: self.data.update(data) _LOGGER.debug("Storing data %s", data) @@ -403,7 +423,9 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) - ): self.all_handler_active = False if self.active: - asyncio.create_task(self.stop_proxy(3)) + task = asyncio.create_task(self.stop_proxy(3)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) return await self._build_response(text="Proxy stopped.") elif ( URL(str(request.url)).path @@ -434,41 +456,43 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) - if skip_auto_headers: _LOGGER.debug("Discovered skip_auto_headers %s", skip_auto_headers) headers.pop(SKIP_AUTO_HEADERS) + # Avoid accidental header mutation across branches/calls + req_headers: dict[str, Any] = dict(headers) _LOGGER.debug( "Attempting %s to %s\nheaders: %s \ncookies: %s", method, site, - headers, + req_headers, self.session.cookies.jar, ) try: if mpwriter: resp = await getattr(self.session, method)( - site, data=mpwriter, headers=headers, follow_redirects=True + site, data=mpwriter, headers=req_headers, follow_redirects=True ) elif data: resp = await getattr(self.session, method)( - site, data=data, headers=headers, follow_redirects=True + site, data=data, headers=req_headers, follow_redirects=True ) elif json_data: for item in ["Host", "Origin", "User-Agent", "dnt", "Accept-Encoding"]: # remove proxy headers - if headers.get(item): - headers.pop(item) + if req_headers.get(item): + req_headers.pop(item) resp = await getattr(self.session, method)( - site, json=json_data, headers=headers, follow_redirects=True + site, json=json_data, headers=req_headers, follow_redirects=True ) else: resp = await getattr(self.session, method)( - site, 
headers=headers, follow_redirects=True + site, headers=req_headers, follow_redirects=True ) - except ClientConnectionError as ex: + except httpx.ConnectError as ex: return await self._build_response( text=f"Error connecting to {site}; please retry: {ex}" ) - except TooManyRedirects as ex: + except httpx.TooManyRedirects as ex: return await self._build_response( - text=f"Error connecting to {site}; too may redirects: {ex}" + text=f"Error connecting to {site}; too many redirects: {ex}" ) except httpx.TimeoutException as ex: _LOGGER.warning( @@ -484,6 +508,10 @@ async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) - "and that the service endpoint is reachable from this host." ) ) + except httpx.HTTPError as ex: + return await self._build_response( + text=f"Error connecting to {site}: {ex}" + ) if resp is None: return await self._build_response(text=f"Error connecting to {site}; please retry") self.last_resp = resp @@ -621,8 +649,7 @@ def _swap_proxy_and_host(self, text: Text, domain_only: bool = False) -> Text: """ host_string: Text = str(self._host_url.with_path("/")) proxy_string: Text = str( - self.access_url() if not domain_only else self.access_url().with_path("/") - ) + self.access_url() if not domain_only else self.access_url().with_path("/")) if str(self.access_url().with_path("/")).replace("https", "http") in text: _LOGGER.debug( "Replacing %s with %s", diff --git a/tests/test_regression_headers_and_json_parsing.py b/tests/test_regression_headers_and_json_parsing.py new file mode 100644 index 0000000..de4c213 --- /dev/null +++ b/tests/test_regression_headers_and_json_parsing.py @@ -0,0 +1,272 @@ +import asyncio + +from typing import Any +import pytest +import httpx + +from aiohttp.streams import StreamReader +from aiohttp.test_utils import make_mocked_request +from multidict import CIMultiDict +from yarl import URL + + +class DummyAsyncClient: + """Capture outbound requests without real network I/O.""" + + def __init__(self) -> None: 
+ self.calls: list[dict[str, Any]] = [] + # match attribute access used in logging + self.cookies = type("Cookies", (), {"jar": {}})() + + async def aclose(self) -> None: + return + + async def post(self, url: str, **kwargs): + self.calls.append( + { + "method": "POST", + "url": url, + "headers": dict(kwargs.get("headers") or {}), + "json": kwargs.get("json"), + "data": kwargs.get("data"), + } + ) + req = httpx.Request("POST", url) + return httpx.Response( + 200, request=req, text="ok", headers={"Content-Type": "text/plain"} + ) + + +async def _make_request( + *, + method: str, + path: str, + content_type: str, + headers=None, + body: bytes = b"", +): + """ + Build a mocked aiohttp Request with a real StreamReader payload. + Request.has_body works (it calls request._payload.at_eof()). + + CI uses aiohttp 3.9.x where StreamReader requires a `limit` argument. + """ + hdrs = CIMultiDict(headers or {}) + hdrs["Content-Type"] = content_type + hdrs.setdefault("Content-Length", str(len(body))) + + loop = asyncio.get_running_loop() + + # aiohttp 3.9: StreamReader(protocol, limit, loop) + # newer aiohttp: signature varies; keep this compatible. + try: + payload = StreamReader(None, 2**16, loop=loop) # type: ignore[arg-type] + except TypeError: + payload = StreamReader(protocol=None, limit=2**16, loop=loop) # type: ignore[arg-type] + + if body: + payload.feed_data(body) + payload.feed_eof() + + return make_mocked_request(method, path, headers=hdrs, payload=payload) + + +@pytest.fixture +def proxy(monkeypatch): + """ + Regression note. + + These tests cover cross-request header contamination caused by in-place mutation + of the headers mapping inside AuthCaptureProxy.all_handler(). 
+ + Specifically, the JSON request path removes proxy-related headers before sending + the upstream request: + + for item in ["Host", "Origin", "User-Agent", "dnt", "Accept-Encoding"]: + if req_headers.get(item): + req_headers.pop(item) + + Prior to the fix, this mutation could occur on a shared headers dict returned + from modify_headers(), leaking into subsequent requests. The fix copies the + headers mapping (req_headers = dict(headers)) before mutation. + + These tests fail on the pre-fix behavior and pass once the copy is introduced. + """ + from authcaptureproxy.auth_capture_proxy import AuthCaptureProxy + + class Proxy(AuthCaptureProxy): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shared_headers = { + "Host": "example.com", + "Origin": "https://example.com", + "User-Agent": "ua", + "dnt": "1", + "Accept-Encoding": "gzip", + "X-Custom": "keep", + } + + async def modify_headers(self, site: URL, request): # type: ignore[override] + # Return the same dict instance every time to expose in-place mutation leaks. 
+ return self.shared_headers + + p = Proxy( + proxy_url=URL("http://127.0.0.1:12345"), + host_url=URL("https://example.com"), + session=DummyAsyncClient(), + ) + + # Keep output quiet and avoid side-effects not relevant to regression + monkeypatch.setattr("authcaptureproxy.auth_capture_proxy.print_resp", lambda *_: None) + + # Keep behavior focused (tests/modifiers are unrelated to the regression) + p._tests = {} + p._modifiers = {} + + return p + + +@pytest.mark.asyncio +async def test_cross_request_header_contamination_across_json_posts(proxy): + # JSON request #1 + req1 = await _make_request( + method="POST", + path="/login", + content_type="application/json", + body=b'{"a": 1}', + ) + + async def _json1(): + return {"a": 1} + + req1.json = _json1 # type: ignore[attr-defined] + await proxy.all_handler(req1) + + # Shared dict must remain intact after request #1 (core regression assertion) + shared = proxy.shared_headers + assert "Host" in shared + assert "Origin" in shared + assert "User-Agent" in shared + assert "dnt" in shared + assert "Accept-Encoding" in shared + assert shared["X-Custom"] == "keep" + + # JSON request #2 + req2 = await _make_request( + method="POST", + path="/login", + content_type="application/json", + body=b'{"b": 2}', + ) + + async def _json2(): + return {"b": 2} + + req2.json = _json2 # type: ignore[attr-defined] + await proxy.all_handler(req2) + + # Both outbound requests must have proxy headers stripped + calls = proxy.session.calls # type: ignore[attr-defined] + assert len(calls) >= 2 + for call in calls[-2:]: + out = call["headers"] + assert "Host" not in out + assert "Origin" not in out + assert "User-Agent" not in out + assert "dnt" not in out + assert "Accept-Encoding" not in out + assert out.get("X-Custom") == "keep" + + +@pytest.mark.asyncio +async def test_cross_request_header_contamination_between_request_types(proxy): + # First JSON request + req_json = await _make_request( + method="POST", + path="/login", + 
content_type="application/json", + body=b'{"a": 1}', + ) + + async def _json(): + return {"a": 1} + + req_json.json = _json # type: ignore[attr-defined] + await proxy.all_handler(req_json) + + # Then a form post; provide post() to keep it on the form path. + req_form = await _make_request( + method="POST", + path="/login", + content_type="application/x-www-form-urlencoded", + body=b"field=value", + ) + + async def _post(): + return {"field": "value"} + + req_form.post = _post # type: ignore[attr-defined] + await proxy.all_handler(req_form) + + form_out = proxy.session.calls[-1]["headers"] # type: ignore[attr-defined] + assert form_out.get("User-Agent") == "ua" + assert form_out.get("X-Custom") == "keep" + + +@pytest.mark.asyncio +async def test_json_parsing_guards_on_non_json_content(proxy): + req_form = await _make_request( + method="POST", + path="/login", + content_type="application/x-www-form-urlencoded", + body=b"field=value", + ) + + async def _json_raises(): + raise RuntimeError("json() must not be called for form posts") + + async def _post(): + return {"field": "value"} + + req_form.json = _json_raises # type: ignore[attr-defined] + req_form.post = _post # type: ignore[attr-defined] + + await proxy.all_handler(req_form) + + +@pytest.mark.asyncio +async def test_json_parsing_for_json_content_types(proxy): + req_json = await _make_request( + method="POST", + path="/login", + content_type="application/json", + body=b'{"ok": true}', + ) + called = {"count": 0} + + async def _json(): + called["count"] += 1 + return {"ok": True} + + req_json.json = _json # type: ignore[attr-defined] + await proxy.all_handler(req_json) + assert called["count"] == 1 + + +@pytest.mark.asyncio +async def test_json_parsing_for_json_plus_suffix_content_types(proxy): + req_json = await _make_request( + method="POST", + path="/login", + content_type="application/vnd.api+json", + body=b'{"v": 1}', + ) + called = {"count": 0} + + async def _json(): + called["count"] += 1 + return {"v": 1} 
+ + req_json.json = _json # type: ignore[attr-defined] + await proxy.all_handler(req_json) + assert called["count"] == 1