Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions libs/core/langchain_core/_security/_ssrf_protection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,21 @@
]

# Cloud provider metadata endpoints
CLOUD_METADATA_RANGES = [
ipaddress.ip_network(
"169.254.0.0/16"
), # IPv4 link-local (used by metadata services)
]

CLOUD_METADATA_IPS = [
"169.254.169.254", # AWS, GCP, Azure, DigitalOcean, Oracle Cloud
"169.254.170.2", # AWS ECS task metadata
"169.254.170.23", # AWS EKS Pod Identity Agent
"100.100.100.200", # Alibaba Cloud metadata
"fd00:ec2::254", # AWS EC2 IMDSv2 over IPv6 (Nitro instances)
"fd00:ec2::23", # AWS EKS Pod Identity Agent (IPv6)
"fe80::a9fe:a9fe", # OpenStack Nova metadata (IPv6 link-local equiv of
# 169.254.169.254)
]

CLOUD_METADATA_HOSTNAMES = [
Expand All @@ -68,6 +79,21 @@
]


def _normalize_ip(ip_str: str) -> str:
"""Normalize IP strings for consistent SSRF checks.

Args:
ip_str: IP address as a string.

Returns:
Canonical string form, converting IPv6-mapped IPv4 to plain IPv4.
"""
ip = ipaddress.ip_address(ip_str)
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
return str(ip.ipv4_mapped)
return str(ip)


def is_private_ip(ip_str: str) -> bool:
"""Check if an IP address is in a private range.

Expand All @@ -78,7 +104,7 @@ def is_private_ip(ip_str: str) -> bool:
True if IP is in a private range, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
ip = ipaddress.ip_address(_normalize_ip(ip_str))
return any(ip in range_ for range_ in PRIVATE_IP_RANGES)
except ValueError:
return False
Expand All @@ -99,8 +125,17 @@ def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool:
return True

# Check IP
if ip_str and ip_str in CLOUD_METADATA_IPS: # noqa: SIM103
return True
if ip_str:
try:
normalized_ip = _normalize_ip(ip_str)
if normalized_ip in CLOUD_METADATA_IPS:
return True

ip = ipaddress.ip_address(normalized_ip)
if any(ip in range_ for range_ in CLOUD_METADATA_RANGES):
return True
except ValueError:
pass

return False

Expand All @@ -122,12 +157,13 @@ def is_localhost(hostname: str, ip_str: str | None = None) -> bool:
# Check IP
if ip_str:
try:
ip = ipaddress.ip_address(ip_str)
normalized_ip = _normalize_ip(ip_str)
ip = ipaddress.ip_address(normalized_ip)
# Check if loopback
if ip.is_loopback:
return True
# Also check common localhost IPs
if ip_str in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
if normalized_ip in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
return True
except ValueError:
pass
Expand Down Expand Up @@ -225,20 +261,21 @@ def validate_safe_url(

for result in addr_info:
ip_str: str = result[4][0] # type: ignore[assignment]
normalized_ip = _normalize_ip(ip_str)

# ALWAYS block cloud metadata IPs
if is_cloud_metadata(hostname, ip_str):
msg = f"URL resolves to cloud metadata IP: {ip_str}"
if is_cloud_metadata(hostname, normalized_ip):
msg = f"URL resolves to cloud metadata IP: {normalized_ip}"
raise ValueError(msg)

# Check for localhost IPs
if is_localhost(hostname, ip_str) and not allow_private:
msg = f"URL resolves to localhost IP: {ip_str}"
if is_localhost(hostname, normalized_ip) and not allow_private:
msg = f"URL resolves to localhost IP: {normalized_ip}"
raise ValueError(msg)

# Check for private IPs
if not allow_private and is_private_ip(ip_str):
msg = f"URL resolves to private IP address: {ip_str}"
if not allow_private and is_private_ip(normalized_ip):
msg = f"URL resolves to private IP address: {normalized_ip}"
raise ValueError(msg)

except socket.gaierror as e:
Expand Down
19 changes: 19 additions & 0 deletions libs/core/tests/unit_tests/test_ssrf_protection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,16 @@ def test_is_cloud_metadata_ips(self) -> None:
"""Test cloud metadata IP detection."""
assert is_cloud_metadata("example.com", "169.254.169.254") is True
assert is_cloud_metadata("example.com", "169.254.170.2") is True
assert is_cloud_metadata("example.com", "169.254.170.23") is True
assert is_cloud_metadata("example.com", "100.100.100.200") is True
assert is_cloud_metadata("example.com", "fd00:ec2::254") is True
assert is_cloud_metadata("example.com", "fd00:ec2::23") is True
assert is_cloud_metadata("example.com", "fe80::a9fe:a9fe") is True

def test_is_cloud_metadata_link_local_range(self) -> None:
"""Test that IPv4 link-local is flagged as cloud metadata."""
assert is_cloud_metadata("example.com", "169.254.1.2") is True
assert is_cloud_metadata("example.com", "169.254.255.254") is True

def test_is_cloud_metadata_hostnames(self) -> None:
"""Test cloud metadata hostname detection."""
Expand Down Expand Up @@ -143,6 +152,16 @@ def test_cloud_metadata_always_blocked(self) -> None:
allow_private=True,
)

def test_ipv6_mapped_ipv4_localhost_blocked(self) -> None:
"""Test that IPv6-mapped IPv4 localhost is blocked."""
with pytest.raises(ValueError, match="localhost"):
validate_safe_url("http://[::ffff:127.0.0.1]:8080/webhook")

def test_ipv6_mapped_ipv4_cloud_metadata_blocked(self) -> None:
"""Test that IPv6-mapped IPv4 cloud metadata is blocked."""
with pytest.raises(ValueError, match="metadata"):
validate_safe_url("http://[::ffff:169.254.169.254]/latest/meta-data/")

def test_invalid_scheme_blocked(self) -> None:
"""Test that non-HTTP(S) schemes are blocked."""
with pytest.raises(ValueError, match="scheme"):
Expand Down
Loading