Skip to content

Commit bcfa569

Browse files
authored
Merge pull request #87 from NatLabRockies/pp/safer_redirect
Safer redirect
2 parents 982ff8e + 598d5ec commit bcfa569

6 files changed

Lines changed: 91 additions & 5 deletions

File tree

elm/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
ELM version number
33
"""
44

5-
__version__ = "0.0.43"
5+
__version__ = "0.0.44"

elm/web/file_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def fetch(self, source):
110110
raise
111111
except Exception as e:
112112
msg = ("Encountered error of type %r while fetching document from "
113-
"%s:")
113+
"%s :")
114114
err_type = type(e)
115115
logger.exception(msg, err_type, source)
116116
return HTMLDocument(pages=[])

elm/web/search/google.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
logger = logging.getLogger(__name__)
21+
_SERPER_SEMAPHORE = asyncio.Semaphore(50) # default limit is 50 q per second
2122

2223

2324
class PlaywrightGoogleLinkSearch(PlaywrightSearchEngineLinkSearch):
@@ -309,11 +310,13 @@ async def _search(self, query, num_results=10, raw=False):
309310

310311
payload = {"q": query, "num": num_results}
311312
headers = {"X-API-KEY": self.api_key}
313+
c_kwargs = {"verify": self.verify, "timeout": 120}
312314

313-
async with httpx.AsyncClient(verify=self.verify) as client:
315+
async with httpx.AsyncClient(**c_kwargs) as client, _SERPER_SEMAPHORE:
314316
response = await client.post(self._URL, headers=headers,
315317
json=payload)
316318
response.raise_for_status()
319+
await asyncio.sleep(1)
317320

318321
results = response.json().get("organic", [])
319322
return format_search_results(self._SE_NAME, query, results,

elm/web/search/run.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,9 @@ async def search_all_se(queries, search_engines=_DEFAULT_SE,
404404
ELMInputError
405405
If `search_engines` input is empty.
406406
"""
407+
logger.debug("Running %d queries over these search engines: %r\n"
408+
"Queries:\n\t- %s", len(queries), search_engines,
409+
"\n\t- ".join(queries))
407410
num_urls = num_urls or 3 * len(queries)
408411
if len(search_engines) < 1:
409412
msg = f"Must provide at least one search engine! Got {search_engines=}"
@@ -441,7 +444,7 @@ async def load_docs(sources, file_loader):
441444
docs = await file_loader.fetch_all(*sources)
442445
logger.debug("Loaded %d docs from %d sources", len(docs), len(sources))
443446
docs = [doc for doc in docs if not doc.empty]
444-
if len(docs)== 1:
447+
if len(docs) == 1:
445448
logger.debug("%d doc is not empty", len(docs))
446449
else:
447450
logger.debug("%d docs are not empty", len(docs))
@@ -625,4 +628,4 @@ def _handle_old_ignore_key(url_ignore_substrings, kwargs):
625628
else:
626629
url_ignore_substrings = old_ignore_key
627630

628-
return url_ignore_substrings, kwargs
631+
return url_ignore_substrings, kwargs

elm/web/utilities.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import hashlib
55
import logging
66
import asyncio
7+
import socket
8+
import ipaddress
79
from pathlib import Path
810
from copy import deepcopy
911
from random import randint, choice
1012
from contextlib import asynccontextmanager
13+
from urllib.parse import urlparse, urljoin
1114

1215
import httpx
1316
from slugify import slugify
@@ -81,7 +84,13 @@ async def get_redirected_url(url, **kwargs):
8184
The final URL after following redirects, or the original URL if
8285
no redirects are found or an error occurs.
8386
"""
87+
8488
kwargs["follow_redirects"] = True
89+
event_hooks = kwargs.setdefault("event_hooks", {})
90+
response_hooks = list(event_hooks.get("response") or [])
91+
response_hooks.append(_check_redirect_safety)
92+
event_hooks["response"] = response_hooks
93+
8594
try:
8695
async with httpx.AsyncClient(**kwargs) as client:
8796
response = await client.head(url)
@@ -90,6 +99,53 @@ async def get_redirected_url(url, **kwargs):
9099
return url
91100

92101

102+
async def _check_redirect_safety(response):
    """Validate a redirect target before the client follows it.

    This coroutine is appended to the ``"response"`` event hooks of the
    client built in ``get_redirected_url``, so it is called with each
    response the client receives.

    Parameters
    ----------
    response : httpx.Response
        Response to inspect. Only ``is_redirect``, ``headers`` and
        ``url`` are read.

    Raises
    ------
    ValueError
        If the response is a redirect whose ``location`` target is
        rejected by ``_is_safe_url``.
    """
    if not response.is_redirect:
        return

    redirect_url = response.headers.get("location")
    if not redirect_url:
        return

    # Relative (or scheme-relative) targets are resolved against the
    # URL of the response that issued the redirect
    if not redirect_url.startswith(("http://", "https://")):
        redirect_url = urljoin(str(response.url), redirect_url)

    # ``_is_safe_url`` may perform a blocking DNS lookup; run it in the
    # default executor so the event loop is not stalled while resolving
    loop = asyncio.get_running_loop()
    is_safe = await loop.run_in_executor(None, _is_safe_url, redirect_url)
    if not is_safe:
        raise ValueError(f"Redirect target is not allowed: {redirect_url}")
116+
117+
118+
def _is_safe_url(url):
    """Return whether a URL resolves only to globally routable addresses.

    Parameters
    ----------
    url : str
        URL whose host should be validated.

    Returns
    -------
    bool
        ``True`` if the URL has a hostname and every address that
        hostname maps to is globally routable. ``False`` if the
        hostname is missing, cannot be resolved, maps to any private,
        loopback, link-local, reserved, multicast or unspecified
        address, or if any unexpected error occurs.
    """
    try:
        hostname = urlparse(url).hostname
        if not hostname:
            return False

        try:
            # IP literals are checked directly, without a DNS lookup
            addresses = [ipaddress.ip_address(hostname)]
        except ValueError:
            # Not an IP literal: resolve *all* records (IPv4 and IPv6)
            # so one public record cannot mask a private one
            try:
                infos = socket.getaddrinfo(hostname, None)
            except (socket.gaierror, socket.herror):
                return False

            # sockaddr[0] is the address string; drop any IPv6 scope id
            # ("%eth0") before parsing
            addresses = [ipaddress.ip_address(info[4][0].split("%", 1)[0])
                         for info in infos]

        if not addresses:
            return False

        for ip in addresses:
            if not ip.is_global:
                return False
            if (ip.is_private or ip.is_loopback or ip.is_link_local
                    or ip.is_reserved or ip.is_multicast
                    or ip.is_unspecified):
                return False

        return True
    except Exception:
        # Fail closed: anything unexpected means the URL is not trusted
        return False
147+
148+
93149
def clean_search_query(query):
94150
"""Check if the first character is a digit and remove it if so.
95151

tests/web/test_web_utilities.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from pathlib import Path
44

55
import pytest
6+
import httpx
67

78
from elm.web.document import HTMLDocument
89
from elm.web.utilities import (
910
clean_search_query,
1011
compute_fn_from_url,
12+
get_redirected_url,
1113
write_url_doc_to_file,
1214
)
1315

@@ -90,5 +92,27 @@ def test_write_url_doc_to_file(tmp_path):
9092
assert out_fp.name == "examplecom20test.txt"
9193

9294

95+
@pytest.mark.asyncio
async def test_get_redirected_url_preserves_response_hooks():
    """Test that caller-supplied response hooks still run."""
    hook_calls = []

    async def record_url(response):
        # Record every URL the client reports to the user hook
        hook_calls.append(str(response.url))

    def always_ok(request):
        # Mock transport handler: answer every request with a plain 200
        return httpx.Response(200, request=request)

    final_url = await get_redirected_url(
        "https://www.example.com",
        transport=httpx.MockTransport(always_ok),
        event_hooks={"response": [record_url]},
    )

    assert final_url == "https://www.example.com"
    assert hook_calls == ["https://www.example.com"]
115+
116+
93117
if __name__ == "__main__":
94118
pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"])

0 commit comments

Comments
 (0)