-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy pathcommon_utils_security_test.py
More file actions
141 lines (130 loc) · 4.48 KB
/
common_utils_security_test.py
File metadata and controls
141 lines (130 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-
"""Security-focused tests for common URL utility helpers."""
import base64
import socket
from unittest import TestCase
from unittest.mock import Mock, patch
from agentscope._utils._common import _get_bytes_from_web_url
class CommonUtilsSecurityTest(TestCase):
"""Test URL fetch hardening against SSRF and unsafe redirects."""
@patch("agentscope._utils._common.requests.get")
def test_reject_literal_loopback_ip(self, mock_get: Mock) -> None:
"""Loopback IP URLs should be blocked before request."""
with self.assertRaises(RuntimeError):
_get_bytes_from_web_url(
"http://127.0.0.1/internal",
max_retries=1,
)
mock_get.assert_not_called()
@patch("agentscope._utils._common.requests.get")
@patch("agentscope._utils._common.socket.getaddrinfo")
def test_reject_private_ip_resolution(
self,
mock_getaddrinfo: Mock,
mock_get: Mock,
) -> None:
"""Hostnames resolving to private IP addresses should be blocked."""
mock_getaddrinfo.return_value = [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("10.0.0.2", 0),
),
]
with self.assertRaises(RuntimeError):
_get_bytes_from_web_url(
"http://example.internal/resource",
max_retries=1,
)
mock_get.assert_not_called()
@patch("agentscope._utils._common.requests.get")
@patch("agentscope._utils._common.socket.getaddrinfo")
def test_allow_public_ip_resolution(
self,
mock_getaddrinfo: Mock,
mock_get: Mock,
) -> None:
"""Publicly routable hostnames should be fetched successfully."""
mock_getaddrinfo.return_value = [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("93.184.216.34", 0),
),
]
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"hello"
mock_response.headers = {}
mock_response.raise_for_status = Mock()
mock_get.return_value = mock_response
result = _get_bytes_from_web_url(
"http://example.com/resource",
max_retries=1,
)
self.assertEqual(result, "hello")
mock_get.assert_called_once_with(
"http://example.com/resource",
allow_redirects=False,
timeout=(5, 10),
)
@patch("agentscope._utils._common.requests.get")
@patch("agentscope._utils._common.socket.getaddrinfo")
def test_block_redirect_to_loopback(
self,
mock_getaddrinfo: Mock,
mock_get: Mock,
) -> None:
"""Redirect targets should be validated before follow-up requests."""
mock_getaddrinfo.return_value = [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("93.184.216.34", 0),
),
]
mock_response = Mock()
mock_response.status_code = 302
mock_response.headers = {"Location": "http://127.0.0.1/internal"}
mock_get.return_value = mock_response
with self.assertRaises(RuntimeError):
_get_bytes_from_web_url(
"http://example.com/redirect",
max_retries=1,
)
self.assertEqual(mock_get.call_count, 1)
@patch("agentscope._utils._common.requests.get")
@patch("agentscope._utils._common.socket.getaddrinfo")
def test_binary_content_falls_back_to_base64(
self,
mock_getaddrinfo: Mock,
mock_get: Mock,
) -> None:
"""Non-UTF8 payloads should return base64-encoded content."""
mock_getaddrinfo.return_value = [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("93.184.216.34", 0),
),
]
payload = b"\xff\x00"
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = payload
mock_response.headers = {}
mock_response.raise_for_status = Mock()
mock_get.return_value = mock_response
result = _get_bytes_from_web_url(
"http://example.com/binary",
max_retries=1,
)
self.assertEqual(result, base64.b64encode(payload).decode("ascii"))