Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion src/agentscope/_utils/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import asyncio
import base64
import functools
import ipaddress
import inspect
import json
import os
import socket
import tempfile
import types
import typing
import uuid
from datetime import datetime
from typing import Any, Callable, Type, Dict
from urllib.parse import urljoin, urlparse

import numpy as np
import requests
Expand Down Expand Up @@ -170,7 +173,7 @@ def _get_bytes_from_web_url(
"""
for _ in range(max_retries):
try:
response = requests.get(url)
response = _request_url_with_validated_redirects(url)
response.raise_for_status()
return response.content.decode("utf-8")

Expand All @@ -189,6 +192,87 @@ def _get_bytes_from_web_url(
)


def _is_public_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return whether the given IP address is globally routable."""
return ip.is_global


def _validate_external_url(url: str) -> None:
    """Reject URLs that point at local or private-network destinations.

    The URL must use the http or https scheme and carry a hostname. A host
    written as an IP literal is checked directly; any other host is
    resolved and every resolved address must pass `_is_public_ip`.

    Args:
        url (`str`):
            The URL to check.

    Raises:
        `ValueError`:
            If the scheme is unsupported, the hostname is missing, the host
            is `localhost`, the host cannot be resolved, or the host is
            (or resolves to) a non-public address.
    """
    components = urlparse(url)
    scheme = components.scheme

    if scheme != "http" and scheme != "https":
        raise ValueError(
            f"Unsupported URL scheme: {scheme}. "
            "Only http/https are allowed.",
        )

    hostname = components.hostname
    if not hostname:
        raise ValueError(f"Invalid URL without hostname: {url}")

    # A host given as an IP literal can be judged without a DNS lookup.
    try:
        literal_ip = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP literal; fall through to name-based checks.
        pass
    else:
        if _is_public_ip(literal_ip):
            return
        raise ValueError(
            f"Blocked non-public URL host: {hostname}",
        )

    if hostname.lower() == "localhost":
        raise ValueError("Blocked localhost URL host.")

    unresolved_message = f"Failed to resolve URL host: {hostname}"
    try:
        resolved_entries = socket.getaddrinfo(hostname, None)
    except socket.gaierror as exc:
        raise ValueError(unresolved_message) from exc

    if not resolved_entries:
        raise ValueError(unresolved_message)

    # Every resolved address must be public; one bad entry blocks the URL.
    for entry in resolved_entries:
        sockaddr = entry[4]
        candidate = ipaddress.ip_address(sockaddr[0])
        if _is_public_ip(candidate):
            continue
        raise ValueError(
            f"Blocked non-public URL host {hostname} "
            f"(resolved to {candidate}).",
        )


def _request_url_with_validated_redirects(
    url: str,
    max_redirects: int = 5,
) -> requests.Response:
    """Fetch a URL with GET, validating every hop of a redirect chain.

    Automatic redirect handling is turned off so that each target,
    including the initial URL, goes through `_validate_external_url`
    before a request is sent to it.

    Args:
        url (`str`):
            The URL to request first.
        max_redirects (`int`, defaults to `5`):
            How many redirects may be followed after the initial request.

    Returns:
        `requests.Response`:
            The first response that is not a redirect carrying a
            `Location` header.

    Raises:
        `RuntimeError`:
            If the redirect chain is longer than `max_redirects`.
    """
    redirect_statuses = {301, 302, 303, 307, 308}
    target = url
    requests_left = max_redirects + 1  # the initial request plus redirects

    while requests_left > 0:
        requests_left -= 1
        _validate_external_url(target)

        reply = requests.get(
            target,
            allow_redirects=False,
            timeout=(5, 10),
        )

        follows_redirect = (
            reply.status_code in redirect_statuses
            and "Location" in reply.headers
        )
        if not follows_redirect:
            return reply

        # The Location header may be relative; resolve it against the
        # URL that was just requested.
        target = urljoin(target, reply.headers["Location"])

    raise RuntimeError(
        f"Exceeded maximum redirects ({max_redirects}) for URL `{url}`.",
    )


def _save_base64_data(
media_type: str,
base64_data: str,
Expand Down
141 changes: 141 additions & 0 deletions tests/common_utils_security_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
"""Security-focused tests for common URL utility helpers."""
import base64
import socket
from unittest import TestCase
from unittest.mock import Mock, patch

from agentscope._utils._common import _get_bytes_from_web_url


class CommonUtilsSecurityTest(TestCase):
    """Checks that web URL fetching refuses SSRF-style targets and
    unsafe redirects."""

    @staticmethod
    def _addrinfo_for(address: str) -> list:
        """Build a one-entry ``socket.getaddrinfo`` result for ``address``."""
        sockaddr = (address, 0)
        return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", sockaddr)]

    @staticmethod
    def _ok_response(body: bytes) -> Mock:
        """Build a mocked 200 response with ``body`` as its content."""
        reply = Mock()
        reply.status_code = 200
        reply.content = body
        reply.headers = {}
        reply.raise_for_status = Mock()
        return reply

    @patch("agentscope._utils._common.requests.get")
    def test_reject_literal_loopback_ip(self, mock_get: Mock) -> None:
        """A literal loopback address is refused and no request is sent."""
        blocked_url = "http://127.0.0.1/internal"
        with self.assertRaises(RuntimeError):
            _get_bytes_from_web_url(blocked_url, max_retries=1)
        mock_get.assert_not_called()

    @patch("agentscope._utils._common.requests.get")
    @patch("agentscope._utils._common.socket.getaddrinfo")
    def test_reject_private_ip_resolution(
        self,
        mock_getaddrinfo: Mock,
        mock_get: Mock,
    ) -> None:
        """A hostname that resolves to a private address is refused."""
        mock_getaddrinfo.return_value = self._addrinfo_for("10.0.0.2")
        blocked_url = "http://example.internal/resource"
        with self.assertRaises(RuntimeError):
            _get_bytes_from_web_url(blocked_url, max_retries=1)
        mock_get.assert_not_called()

    @patch("agentscope._utils._common.requests.get")
    @patch("agentscope._utils._common.socket.getaddrinfo")
    def test_allow_public_ip_resolution(
        self,
        mock_getaddrinfo: Mock,
        mock_get: Mock,
    ) -> None:
        """A hostname that resolves to a public address is fetched."""
        mock_getaddrinfo.return_value = self._addrinfo_for("93.184.216.34")
        mock_get.return_value = self._ok_response(b"hello")

        fetched = _get_bytes_from_web_url(
            "http://example.com/resource",
            max_retries=1,
        )

        self.assertEqual(fetched, "hello")
        # The request must go out with redirects disabled and a timeout.
        mock_get.assert_called_once_with(
            "http://example.com/resource",
            allow_redirects=False,
            timeout=(5, 10),
        )

    @patch("agentscope._utils._common.requests.get")
    @patch("agentscope._utils._common.socket.getaddrinfo")
    def test_block_redirect_to_loopback(
        self,
        mock_getaddrinfo: Mock,
        mock_get: Mock,
    ) -> None:
        """A redirect target is validated before it is requested."""
        mock_getaddrinfo.return_value = self._addrinfo_for("93.184.216.34")
        mock_get.return_value = Mock(
            status_code=302,
            headers={"Location": "http://127.0.0.1/internal"},
        )

        with self.assertRaises(RuntimeError):
            _get_bytes_from_web_url(
                "http://example.com/redirect",
                max_retries=1,
            )
        # Only the initial request is sent; the loopback hop is never tried.
        self.assertEqual(mock_get.call_count, 1)

    @patch("agentscope._utils._common.requests.get")
    @patch("agentscope._utils._common.socket.getaddrinfo")
    def test_binary_content_falls_back_to_base64(
        self,
        mock_getaddrinfo: Mock,
        mock_get: Mock,
    ) -> None:
        """A payload that is not valid UTF-8 comes back base64-encoded."""
        raw_bytes = b"\xff\x00"
        mock_getaddrinfo.return_value = self._addrinfo_for("93.184.216.34")
        mock_get.return_value = self._ok_response(raw_bytes)

        fetched = _get_bytes_from_web_url(
            "http://example.com/binary",
            max_retries=1,
        )

        expected = base64.b64encode(raw_bytes).decode("ascii")
        self.assertEqual(fetched, expected)