Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions backend/src/xfd_django/xfd_api/api_methods/blocklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,60 @@

# Standard Python Libraries
import ipaddress
import logging

# Third-Party Libraries
from fastapi import HTTPException
from xfd_mini_dl.models import Blocklist

from ..auth import is_global_view_admin

LOGGER = logging.getLogger(__name__)

async def handle_check_ip(ip_address: str, current_user):

async def handle_bulk_check_ips(ip_addresses: list[str], current_user):
"""
Determine if an IP exists in our blocklist table.
Determine if multiple IP's exist within our blocklist table.

Returns:
{ reports: int, attacks: int }
Returns: {
ip_address: {
reports: int,
attacks: int
}
}
"""
try:
# Validate the IP address format
ipaddress.ip_address(ip_address)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid IP address")
attacks = 0
reports = 0
try:
if not is_global_view_admin(current_user):
raise HTTPException(status_code=403, detail="Unauthorized")
record = Blocklist.objects.get(ip=ip_address)
if isinstance(record.attacks, int) and record.attacks > 0:
attacks = record.attacks
if isinstance(record.reports, int) and record.reports > 0:
reports = record.reports
return {
if not is_global_view_admin(current_user):
raise HTTPException(status_code=403, detail="Unauthorized")

# Validate all IPs first
for ip in ip_addresses:
try:
ipaddress.ip_address(ip)
except ValueError:
LOGGER.error("Invalid IP address provided: %s", ip)
raise HTTPException(status_code=400, detail="Invalid IP address")

# Initialize results with defaults for all IPs
results = {str(ip): {"attacks": 0, "reports": 0} for ip in ip_addresses}

# Single query to fetch all matching records
records = Blocklist.objects.filter(ip__in=ip_addresses)
for record in records:
LOGGER.info("Processing blocklist record for IP: %s", record.ip)
attacks = (
record.attacks
if isinstance(record.attacks, int) and record.attacks > 0
else 0
)
reports = (
record.reports
if isinstance(record.reports, int) and record.reports > 0
else 0
)
ip_str = str(ipaddress.ip_interface(record.ip).ip)
results[ip_str] = {
"attacks": attacks,
"reports": reports,
}
except HTTPException as http_exc:
raise http_exc
except Blocklist.DoesNotExist:
return {
"attacks": 0,
"reports": 0,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

return results
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def dns_twist_sync_post(sync_body, request: Request, current_user):
description="DNSTwist is a domain name permutation engine.",
last_run=datetime.datetime.now(),
)
orgs_with_dps = json.loads(sync_body.data)
orgs_with_dps = sync_body.data
LOGGER.info("DATA: %s", orgs_with_dps)
for org in orgs_with_dps:
domain_permutations = org.get("domain_permutations", [])
Expand Down Expand Up @@ -86,8 +86,10 @@ async def dns_twist_sync_post(sync_body, request: Request, current_user):
def create_checksum(data):
"""Validate the checksum from an API response."""
try:
# Recompute the checksum
calculated_checksum = hashlib.sha256((SALT + data).encode()).hexdigest()
# Recompute the checksum - serialize the same way the sender does
payload = {"data": data}
serialized = json.dumps(payload, default=str, sort_keys=True)
calculated_checksum = hashlib.sha256((SALT + serialized).encode()).hexdigest()

return calculated_checksum
except Exception as e:
Expand Down
14 changes: 13 additions & 1 deletion backend/src/xfd_django/xfd_api/schema_models/blocklist.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
"""Blocklist Schemas."""
# Third-Party Libraries
from pydantic import BaseModel
from pydantic import BaseModel, RootModel


class BlocklistCheckResponse(BaseModel):
"""BlocklistCheckResponse schema."""

attacks: int
reports: int


class BulkBlocklistCheckResponse(RootModel[dict[str, BlocklistCheckResponse]]):
"""BulkBlocklistCheckResponse schema."""

root: dict[str, BlocklistCheckResponse]


class BulkBlocklistCheckRequest(BaseModel):
"""BulkBlocklistCheckRequest schema."""

ip_addresses: list[str]
Loading
Loading