diff --git a/backend/src/xfd_django/xfd_api/api_methods/blocklist.py b/backend/src/xfd_django/xfd_api/api_methods/blocklist.py index b8ceb0d93..798c7d320 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/blocklist.py +++ b/backend/src/xfd_django/xfd_api/api_methods/blocklist.py @@ -2,6 +2,7 @@ # Standard Python Libraries import ipaddress +import logging # Third-Party Libraries from fastapi import HTTPException @@ -9,39 +10,52 @@ from ..auth import is_global_view_admin +LOGGER = logging.getLogger(__name__) -async def handle_check_ip(ip_address: str, current_user): + +async def handle_bulk_check_ips(ip_addresses: list[str], current_user): """ - Determine if an IP exists in our blocklist table. + Determine if multiple IP's exist within our blocklist table. - Returns: - { reports: int, attacks: int } + Returns: { + ip_address: { + reports: int, + attacks: int + } + } """ - try: - # Validate the IP address format - ipaddress.ip_address(ip_address) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid IP address") - attacks = 0 - reports = 0 - try: - if not is_global_view_admin(current_user): - raise HTTPException(status_code=403, detail="Unauthorized") - record = Blocklist.objects.get(ip=ip_address) - if isinstance(record.attacks, int) and record.attacks > 0: - attacks = record.attacks - if isinstance(record.reports, int) and record.reports > 0: - reports = record.reports - return { + if not is_global_view_admin(current_user): + raise HTTPException(status_code=403, detail="Unauthorized") + + # Validate all IPs first + for ip in ip_addresses: + try: + ipaddress.ip_address(ip) + except ValueError: + LOGGER.error("Invalid IP address provided: %s", ip) + raise HTTPException(status_code=400, detail="Invalid IP address") + + # Initialize results with defaults for all IPs + results = {str(ip): {"attacks": 0, "reports": 0} for ip in ip_addresses} + + # Single query to fetch all matching records + records = Blocklist.objects.filter(ip__in=ip_addresses) + for record in records: + LOGGER.info("Processing blocklist record for IP: %s", record.ip) + attacks = ( + record.attacks + if isinstance(record.attacks, int) and record.attacks > 0 + else 0 + ) + reports = ( + record.reports + if isinstance(record.reports, int) and record.reports > 0 + else 0 + ) + ip_str = str(ipaddress.ip_interface(record.ip).ip) + results[ip_str] = { "attacks": attacks, "reports": reports, } - except HTTPException as http_exc: - raise http_exc - except Blocklist.DoesNotExist: - return { - "attacks": 0, - "reports": 0, - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + + return results diff --git a/backend/src/xfd_django/xfd_api/api_methods/dns_twist_sync.py b/backend/src/xfd_django/xfd_api/api_methods/dns_twist_sync.py index 71ce86fa7..d34fa277d 100644 --- a/backend/src/xfd_django/xfd_api/api_methods/dns_twist_sync.py +++ b/backend/src/xfd_django/xfd_api/api_methods/dns_twist_sync.py @@ -43,7 +43,7 @@ async def dns_twist_sync_post(sync_body, request: Request, current_user): description="DNSTwist is a domain name permutation engine.", last_run=datetime.datetime.now(), ) - orgs_with_dps = json.loads(sync_body.data) + orgs_with_dps = sync_body.data LOGGER.info("DATA: %s", orgs_with_dps) for org in orgs_with_dps: domain_permutations = org.get("domain_permutations", []) @@ -86,8 +86,10 @@ async def dns_twist_sync_post(sync_body, request: Request, current_user): def create_checksum(data): """Validate the checksum from an API response.""" try: - # Recompute the checksum - calculated_checksum = hashlib.sha256((SALT + data).encode()).hexdigest() + # Recompute the checksum - serialize the same way the sender does + payload = {"data": data} + serialized = json.dumps(payload, default=str, sort_keys=True) + calculated_checksum = hashlib.sha256((SALT + serialized).encode()).hexdigest() return calculated_checksum except Exception as e: diff --git a/backend/src/xfd_django/xfd_api/schema_models/blocklist.py b/backend/src/xfd_django/xfd_api/schema_models/blocklist.py index a5ca96a37..d7a78109e 100644 --- a/backend/src/xfd_django/xfd_api/schema_models/blocklist.py +++ b/backend/src/xfd_django/xfd_api/schema_models/blocklist.py @@ -1,6 +1,6 @@ """Blocklist Schemas.""" # Third-Party Libraries -from pydantic import BaseModel +from pydantic import BaseModel, RootModel class BlocklistCheckResponse(BaseModel): @@ -8,3 +8,15 @@ class BlocklistCheckResponse(BaseModel): attacks: int reports: int + + +class BulkBlocklistCheckResponse(RootModel[dict[str, BlocklistCheckResponse]]): + """BulkBlocklistCheckResponse schema.""" + + root: dict[str, BlocklistCheckResponse] + + +class BulkBlocklistCheckRequest(BaseModel): + """BulkBlocklistCheckRequest schema.""" + + ip_addresses: list[str] diff --git a/backend/src/xfd_django/xfd_api/tasks/dns_twist.py b/backend/src/xfd_django/xfd_api/tasks/dns_twist.py index 0fc5cc049..861d59fb4 100644 --- a/backend/src/xfd_django/xfd_api/tasks/dns_twist.py +++ b/backend/src/xfd_django/xfd_api/tasks/dns_twist.py @@ -1,6 +1,6 @@ """Use DNS twist to fuzz domain names and cross check with a blacklist.""" # Standard Python Libraries -import contextlib +from concurrent.futures import ThreadPoolExecutor, as_completed import datetime import json import logging @@ -22,59 +22,96 @@ DMZ_API_KEY = os.getenv("DMZ_API_KEY", "local") -# Update this function to use the new homebrew blocklist checking system -def checkBlocklist(dom, data_source, org, perm_list): - """Cross reference the dnstwist results with DShield Blocklist.""" +def _is_invalid_dom(dom): + if "original" in dom.get("fuzzer", ""): + return True + if "dns_a" not in dom: + return True + if str(dom["dns_a"][0]) == "!ServFail": + return True + return False + + +def _check_blocklist(ip, blocklist_results): + if not blocklist_results or not ip or ip not in blocklist_results: + return False, 0, 0 + + result = blocklist_results[ip] + attacks = result.get("attacks", 0) + reports = result.get("reports", 0) + + malicious = attacks > 0 or reports > 0 + return malicious, attacks, reports + + +def _check_ipv6(dom, blocklist_results, attacks, reports): + dom.setdefault("dns_aaaa", [""]) + + ipv6 = str(dom["dns_aaaa"][0]) + if not ipv6 or ipv6 == "!ServFail": + dom["dns_aaaa"] = [""] + return False, attacks, reports + + malicious, v6_attacks, v6_reports = _check_blocklist(ipv6, blocklist_results) + + return ( + malicious, + max(attacks, v6_attacks), + max(reports, v6_reports), + ) + + +def _query_dshield(ip): + try: + result = dshield.ip(ip, return_format=dshield.JSON) + parsed = json.loads(result) + ip_info = parsed.get("ip", {}) + + attacks = int(ip_info.get("attacks") or 0) + feeds = len(ip_info.get("threatfeeds", [])) + + return attacks, feeds + except Exception as exc: + LOGGER.info("Error querying DShield API: %s", exc) + return 0, 0 + + +def _cleanup_dom(dom): + dom.setdefault("ssdeep_score", "") + dom.setdefault("dns_mx", [""]) + dom.setdefault("dns_ns", [""]) + + +def checkDshield(dom, data_source, org, perm_list, blocklist_results=None): + """Check DShield for the given domain and return a domain dict.""" malicious = False - attacks = 0 - reports = 0 - dshield_attacks = 0 - dshield_count = 0 - if "original" in dom["fuzzer"]: - return None, perm_list - elif "dns_a" not in dom: + attacks = reports = 0 + dshield_attacks = dshield_count = 0 + + if _is_invalid_dom(dom): return None, perm_list - else: - if str(dom["dns_a"][0]) == "!ServFail": - return None, perm_list - # Check IP in Blocklist API - check_domain_in_blocklist( - dom, malicious, attacks, reports, dshield_attacks, dshield_count - ) + ipv4 = str(dom["dns_a"][0]) + v4_malicious, attacks, reports = _check_blocklist(ipv4, blocklist_results) + malicious |= v4_malicious - # Check IPv6 - if "dns_aaaa" not in dom: - dom["dns_aaaa"] = [""] - elif str(dom["dns_aaaa"][0]) == "!ServFail": - dom["dns_aaaa"] = [""] - else: - # Check IP in Blocklist API - # To-Do: Update this function to use the new homebrew blocklist checking system - dom["use_check_ipv6"] = True - check_domain_in_blocklist( - dom, - malicious, - attacks, - reports, - dshield_attacks, - dshield_count, - ) + v6_malicious, attacks, reports = _check_ipv6( + dom, blocklist_results, attacks, reports + ) + malicious |= v6_malicious + + if ipv4: + dshield_attacks, dshield_count = _query_dshield(ipv4) + if dshield_attacks > 0 or dshield_count > 0: + malicious = True - # Clean-up other fields - if "ssdeep_score" not in dom: - dom["ssdeep_score"] = "" - if "dns_mx" not in dom: - dom["dns_mx"] = [""] - if "dns_ns" not in dom: - dom["dns_ns"] = [""] + _cleanup_dom(dom) - # Ignore duplicates permutation = dom["domain"] if permutation in perm_list: return None, perm_list - else: - perm_list.append(permutation) + + perm_list.append(permutation) domain_dict = { "organization": org, @@ -93,19 +130,31 @@ def checkBlocklist(dom, data_source, org, perm_list): "dshield_record_count": dshield_count, "dshield_attack_count": dshield_attacks, } + return domain_dict, perm_list -def execute_dnstwist(root_domain, test=0): - """Run dnstwist on each root domain.""" +def execute_dnstwist(root_domain, test=0, threads=2): + """Run dnstwist on each root domain. + + Args: + root_domain: The domain to run dnstwist on + test: If 1, return early without secondary .gov processing + threads: Number of internal threads for dnstwist (default 2 to allow for + concurrent execution at the caller level without overloading) + """ pathtoDict = str(pathlib.Path(__file__).parent.resolve()) + "/data/common_tlds.dict" dnstwist_result = dnstwist.run( registered=True, tld=pathtoDict, format="json", - threads=8, + threads=threads, domain=root_domain, ) + LOGGER.info( + "DNSTwist found %d permutations for %s", len(dnstwist_result), root_domain + ) + LOGGER.info("DNSTwist data: %s", dnstwist_result) if test == 1: return dnstwist_result finalorglist = dnstwist_result + [] @@ -117,7 +166,7 @@ def execute_dnstwist(root_domain, test=0): registered=True, tld=pathtoDict, format="json", - threads=8, + threads=threads, domain=dom["domain"], ) finalorglist += secondlist @@ -157,58 +206,60 @@ def get_data_source(data_source_name: str) -> Optional[str]: return None -def check_domain_in_blocklist( - dom, malicious, attacks, reports, dshield_attacks, dshield_count -): - """Cross reference the dnstwist results with internal and DShield blocklists.""" - dns_key = "dns_aaaa" if dom.get("use_check_ipv6") else "dns_a" +def check_domains_in_blocklist(domains: list) -> dict: + """ + Check multiple domains against the blocklist API in bulk. - try: - ip_address = str(dom[dns_key][0]) - except (KeyError, IndexError): - return malicious, attacks, reports, dshield_attacks, dshield_count + Args: + domains: List of domain objects from dnstwist results - # Query internal blocklist API + Returns: + Dictionary mapping IP addresses to their blocklist results + Format: { + "ip_address": { + "attacks": int, + "reports": int + } + } + """ + # Collect all unique IPs from the domains (both IPv4 and IPv6) + ip_addresses = set() + + for dom in domains: + # Collect IPv4 addresses + if "dns_a" in dom and dom["dns_a"]: + ip = str(dom["dns_a"][0]) + if ip != "!ServFail" and ip.strip(): + ip_addresses.add(ip) + + # Collect IPv6 addresses + if "dns_aaaa" in dom and dom["dns_aaaa"]: + ip = str(dom["dns_aaaa"][0]) + if ip != "!ServFail" and ip.strip(): + ip_addresses.add(ip) + + # Convert set to list for JSON serialization + ip_list = list(ip_addresses) + + if not ip_list: + LOGGER.info("No IP addresses to check in blocklist") + return {} + + # Make bulk API call try: - response = requests.get( - f"{BACKEND_DOMAIN}?ip_address={ip_address}", + response = requests.post( + "http://backend:3000/blocklist/check", + json={"ip_addresses": ip_list}, timeout=60, headers={"Authorization": DMZ_API_KEY}, ) response.raise_for_status() - data = response.json() - LOGGER.info("Blocklist API response: %s", data) - if isinstance(data, dict) and ( - data.get("attacks", 0) > 0 or data.get("reports", 0) > 0 - ): - malicious = True - attacks = int(data.get("attacks", 0)) - reports = int(data.get("reports", 0)) - + results = response.json() + LOGGER.info("Bulk blocklist API response for %d IPs", len(results)) + return results except Exception as e: - # Optionally log the error - LOGGER.info("Error querying internal blocklist API: %s", str(e)) - attacks = 0 - reports = 0 - - # Query DShield API - try: - dshield_result = dshield.ip(ip_address, return_format=dshield.JSON) - parsed = json.loads(dshield_result) - ip_info = parsed.get("ip", {}) - - dshield_attacks = int(ip_info.get("attacks") or 0) - dshield_count = len(ip_info.get("threatfeeds", [])) - - if dshield_attacks > 0 or dshield_count > 0: - malicious = True - - except Exception as e: - LOGGER.info("Error querying DShield API: %s", str(e)) - dshield_attacks = 0 - dshield_count = 0 - - return malicious, attacks, reports, dshield_attacks, dshield_count + LOGGER.error("Error querying bulk blocklist API: %s", str(e)) + return {} def get_org_root_domains(org_id): @@ -234,30 +285,56 @@ def get_orgs() -> list: return [] -def execute_dnstwist_data(domain_dict): - """Insert the domain permutation into the database.""" - try: - DomainPermutations.objects.update_or_create( - suspected_domain_uid=uuid4(), - organization=domain_dict["organization"], - domain_permutation=domain_dict["domain_permutation"], - ipv4=domain_dict["ipv4"], - ipv6=domain_dict["ipv6"], - mail_server=domain_dict["mail_server"], - name_server=domain_dict["name_server"], - fuzzer=domain_dict["fuzzer"], +def bulk_upsert_domain_permutations(domain_dicts): + """Bulk insert/update domain permutations into the database.""" + if not domain_dicts: + return + + # Build model instances from dicts + instances = [ + DomainPermutations( + organization=d["organization"], + domain_permutation=d["domain_permutation"], + ipv4=d["ipv4"], + ipv6=d["ipv6"], + mail_server=d["mail_server"], + name_server=d["name_server"], + fuzzer=d["fuzzer"], date_observed=datetime.datetime.now(datetime.timezone.utc), - date_active=domain_dict["date_active"], - ssdeep_score=domain_dict["ssdeep_score"], - malicious=domain_dict["malicious"], - blocklist_attack_count=domain_dict["blocklist_attack_count"], - blocklist_report_count=domain_dict["blocklist_report_count"], - data_source=domain_dict["data_source"], - dshield_record_count=domain_dict["dshield_record_count"], - dshield_attack_count=domain_dict["dshield_attack_count"], + date_active=d["date_active"], + ssdeep_score=d["ssdeep_score"], + malicious=d["malicious"], + blocklist_attack_count=d["blocklist_attack_count"], + blocklist_report_count=d["blocklist_report_count"], + data_source=d["data_source"], + dshield_record_count=d["dshield_record_count"], + dshield_attack_count=d["dshield_attack_count"], ) - except Exception as e: - LOGGER.error("Error adding domain permutation to data lake: %s", str(e)) + for d in domain_dicts + ] + + DomainPermutations.objects.bulk_create( + instances, + update_conflicts=True, + unique_fields=["organization", "domain_permutation"], + update_fields=[ + "ipv4", + "ipv6", + "mail_server", + "name_server", + "fuzzer", + "date_observed", + "date_active", + "ssdeep_score", + "malicious", + "blocklist_attack_count", + "blocklist_report_count", + "data_source", + "dshield_record_count", + "dshield_attack_count", + ], + ) + LOGGER.info("Bulk upserted %d domain permutations", len(instances)) def process_org(org, orgs_list, data_source, failures): @@ -272,29 +349,69 @@ def process_org(org, orgs_list, data_source, failures): root_dict = get_org_root_domains(org_id) domain_list = [] perm_list = [] + all_domains = [] - for root in root_dict: + # First pass: collect all domains from all root domains using thread pool + def run_dnstwist_for_root(root): + """Run dnstwist for a single root domain.""" root_domain = root.sub_domain LOGGER.info("\tRunning on root domain: %s", root_domain) - with open("dnstwist_output.txt", "w") as f, contextlib.redirect_stdout( - f - ): - finalorglist = execute_dnstwist(root_domain) - # Get subdomain uid - # Check Blocklist - for dom in finalorglist: - LOGGER.info("Checking Blocklist: %s", dom) - domain_dict, perm_list = checkBlocklist( - dom, data_source, org, perm_list - ) - if domain_dict is not None: - domain_list.append(domain_dict) + return execute_dnstwist(root_domain) + + with ThreadPoolExecutor(max_workers=20) as executor: + futures = { + executor.submit(run_dnstwist_for_root, root): root + for root in root_dict + } + for future in as_completed(futures): + try: + finalorglist = future.result() + all_domains.extend(finalorglist) + except Exception as e: + root = futures[future] + LOGGER.error( + "Error running dnstwist for %s: %s", root.sub_domain, str(e) + ) + + # Perform bulk blocklist check for all domains + LOGGER.info( + "Performing bulk blocklist check for %d domains", len(all_domains) + ) + blocklist_results = check_domains_in_blocklist(all_domains) + + # Second pass: process each domain with pre-fetched blocklist results + # Use thread pool for concurrent DShield checks + seen_permutations = set(perm_list) + + def process_domain(dom): + """Process a single domain through DShield check.""" + # Check for duplicates using domain permutation + permutation = dom.get("domain") + if permutation and permutation in seen_permutations: + return None + LOGGER.info("Checking D Shield: %s", dom) + # Pass empty perm_list since we handle deduplication here + domain_dict, _ = checkDshield( + dom, data_source, org, [], blocklist_results + ) + return domain_dict + + with ThreadPoolExecutor(max_workers=20) as executor: + futures = { + executor.submit(process_domain, dom): dom for dom in all_domains + } + for future in as_completed(futures): + try: + domain_dict = future.result() + if domain_dict is not None: + # Track seen permutations to avoid duplicates + seen_permutations.add(domain_dict["domain_permutation"]) + domain_list.append(domain_dict) + except Exception as e: + LOGGER.error("Error processing domain: %s", str(e)) + try: - for domain in domain_list: - execute_dnstwist_data(domain) - LOGGER.info( - "Inserted %s into database", domain["domain_permutation"] - ) + bulk_upsert_domain_permutations(domain_list) except Exception: # TODO: Create custom exceptions. # Issue 265: https://github.com/cisagov/pe-reports/issues/265 diff --git a/backend/src/xfd_django/xfd_api/tasks/dns_twist_sync.py b/backend/src/xfd_django/xfd_api/tasks/dns_twist_sync.py index c5ac234be..35e27a1b1 100644 --- a/backend/src/xfd_django/xfd_api/tasks/dns_twist_sync.py +++ b/backend/src/xfd_django/xfd_api/tasks/dns_twist_sync.py @@ -78,7 +78,8 @@ def main(event): # continue data = chunk["chunk"] - serialized = json.dumps(data, default=str, sort_keys=True) + payload = {"data": data} + serialized = json.dumps(payload, default=str, sort_keys=True) salted_checksum = hashlib.sha256( (SALT + serialized).encode() ).hexdigest() @@ -89,12 +90,13 @@ def main(event): "Authorization": os.getenv("DMZ_API_KEY", ""), } - requests.post( + res = requests.post( f"{os.getenv('DMZ_SYNC_ENDPOINT')}/dns_twist_sync", headers=headers, - json=serialized, + data=serialized, timeout=60, ) + LOGGER.info("Response status code: %s", res.status_code) # response = requests.post("http://backend:3000/dns_twist_sync", headers=headers, json={"data": serialized}) LOGGER.info( "Sent %s domain permutations to sync endpoint", diff --git a/backend/src/xfd_django/xfd_api/tasks/update_blocklist.py b/backend/src/xfd_django/xfd_api/tasks/update_blocklist.py index 8a86480ac..f8ac43bb1 100644 --- a/backend/src/xfd_django/xfd_api/tasks/update_blocklist.py +++ b/backend/src/xfd_django/xfd_api/tasks/update_blocklist.py @@ -1,6 +1,6 @@ """Update the blocklist with the latest data from blocklist.de.""" # Standard Python Libraries -from datetime import timedelta +from concurrent.futures import ThreadPoolExecutor, as_completed import ipaddress import logging @@ -32,13 +32,14 @@ def query_blocklist_api(ip_str): response = requests.get( "http://api.blocklist.de/api.php?ip=" + ip_str, timeout=60, - ).content - response = str(response) + ) + response.raise_for_status() # Raise exception for HTTP errors (including 429 rate limits) + response_text = str(response.content) # LOGGER.info("Queried blocklist API for IP: %s", ip_str) - # LOGGER.info("Blocklist API response: %s", response) + # LOGGER.info("Blocklist API response: %s", response_text) malicious = False - attacks = int(str(response).split("attacks: ")[1].split("<")[0]) - reports = int(str(response).split("reports: ")[1].split("<")[0]) + attacks = int(response_text.split("attacks: ")[1].split("<")[0]) + reports = int(response_text.split("reports: ")[1].split("<")[0]) if attacks > 0 or reports > 0: malicious = True return malicious, attacks, reports @@ -64,49 +65,102 @@ def create_new_blocklist_records(blocklist, created_count): def main(): - """Download blocklist data and query the blocklist API.""" + """Handle the blocklist update process.""" blocklist = download_blocklist_as_dict() if len(blocklist) == 0: LOGGER.warning("No blocklist data downloaded.") return - LOGGER.info("Blocklist downloaded successfully with %d entries.", len(blocklist)) + blocklist_ips = list(blocklist.keys()) blocklist_records = Blocklist.objects.all() - # Prune blocklist records that are not in the downloaded blocklist data - updated_count = 0 - for ip_record in blocklist_records: - ip_str = str(ipaddress.ip_interface(ip_record.ip).ip) - if ip_str in blocklist: - LOGGER.info("Updating blocklist record for IP: %s", ip_str) - # If the IP is in the blocklist, update the record + blocklist_record_ips = [ + str(ipaddress.ip_interface(x.ip).ip) for x in blocklist_records + ] + ips_to_create = list(filter(lambda x: x not in blocklist_record_ips, blocklist_ips)) + ips_to_update = list(filter(lambda x: x in blocklist_ips, blocklist_record_ips)) + ips_to_delete = list(filter(lambda x: x not in blocklist_ips, blocklist_record_ips)) + + # Create new blocklist records with API queries for accurate counts + # First, collect all API data for the IPs using parallel requests + def fetch_ip_data(ip_str): + """Fetch blocklist data for a single IP.""" + try: malicious, attacks, reports = query_blocklist_api(ip_str) - updated = False - if attacks != ip_record.attacks: - # Update the attacks count - ip_record.attacks = attacks - updated = True - if reports != ip_record.reports: - # Update the reports count - ip_record.reports = reports - updated = True - if malicious != ip_record.malicious: - ip_record.malicious = malicious - updated = True - if updated: - ip_record.updated_at = timezone.now() - - ip_record.save() - updated_count += 1 - # Remove the IP from blocklist to improve performance - del blocklist[ip_str] - # Add new blocklist records based on the downloaded data - LOGGER.info("Updated %d blocklist records.", updated_count) - created_count = 0 - create_new_blocklist_records(blocklist, created_count) - LOGGER.info("Created %d new blocklist records.", created_count) - # Delete all records that have not been updated in the last 30 days - threshold_date = timezone.now() - timedelta(days=30) - deleted_count, _ = Blocklist.objects.filter(updated_at__lt=threshold_date).delete() - LOGGER.info("Deleted %d old blocklist records.", deleted_count) + return ip_str, { + "malicious": malicious, + "attacks": attacks, + "reports": reports, + } + except requests.HTTPError as e: + if e.response.status_code == 429: + LOGGER.warning("Rate limit hit for IP %s: %s", ip_str, e) + else: + LOGGER.warning( + "HTTP error querying blocklist API for IP %s (status %d): %s", + ip_str, + e.response.status_code, + e, + ) + return ip_str, None + except Exception as e: + LOGGER.warning("Failed to query blocklist API for IP %s: %s", ip_str, e) + return ip_str, None + + ip_data = {} + total_ips = len(ips_to_create) + completed = 0 + + # Use ThreadPoolExecutor for parallel API requests + max_workers = 50 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_ip = { + executor.submit(fetch_ip_data, ip_str): ip_str for ip_str in ips_to_create + } + + # Process results as they complete + for future in as_completed(future_to_ip): + completed += 1 + ip_str, data = future.result() + if data is not None: + ip_data[ip_str] = data + if completed % max_workers == 0 or completed == total_ips: + LOGGER.info("API query progress: %d/%d completed", completed, total_ips) + + # Batch create all records in a single database transaction + now = timezone.now() + records_to_create = [ + Blocklist( + ip=ip_str, + created_at=now, + updated_at=now, + malicious=data["malicious"], + attacks=data["attacks"], + reports=data["reports"], + ) + for ip_str, data in ip_data.items() + ] + + if records_to_create: + try: + Blocklist.objects.bulk_create(records_to_create) + except Exception as e: + LOGGER.warning("Failed to bulk create blocklist records: %s", e) + + # Update existing records - just update timestamp and malicious flag + try: + updated_count = Blocklist.objects.filter(ip__in=ips_to_update).update( + updated_at=timezone.now(), malicious=True + ) + LOGGER.info("Updated %d existing blocklist records.", updated_count) + except Exception as e: + LOGGER.warning("Failed to update blocklist records: %s", e) + + # Delete records no longer in the blocklist + try: + deleted_count, _ = Blocklist.objects.filter(ip__in=ips_to_delete).delete() + LOGGER.info("Deleted %d blocklist records no longer in source.", deleted_count) + except Exception as e: + LOGGER.warning("Failed to delete old blocklist records: %s", e) def handler(_): diff --git a/backend/src/xfd_django/xfd_api/tests/test_blocklist.py b/backend/src/xfd_django/xfd_api/tests/test_blocklist.py index f5e9c27e5..435e1b941 100644 --- a/backend/src/xfd_django/xfd_api/tests/test_blocklist.py +++ b/backend/src/xfd_django/xfd_api/tests/test_blocklist.py @@ -17,51 +17,58 @@ @pytest.mark.django_db(transaction=True, databases=["default", "mini_data_lake"]) def test_blocklist_check_blocked(): - """Test blocklist check.""" + """Test blocklist check (blocked IP).""" user = User.objects.create( first_name="first", last_name="last", email="{}@crossfeed.cisa.gov".format(secrets.token_hex(4)), user_type=UserType.GLOBAL_ADMIN, ) - random_ip_address = "111.111.111.111" + blocked_ip = "111.111.111.111" + Blocklist.objects.create( - ip=random_ip_address, + ip=blocked_ip, created_at=datetime.now(timezone.utc), reports=1, attacks=1, ) - response = client.get( + response = client.post( "/blocklist/check/", - params={"ip_address": random_ip_address}, + json={"ip_addresses": [blocked_ip]}, headers={"Authorization": "Bearer {}".format(create_jwt_token(user))}, ) assert response.status_code == 200 assert response.json() == { - "attacks": 1, - "reports": 1, + "111.111.111.111": { + "attacks": 1, + "reports": 1, + } } @pytest.mark.django_db(transaction=True, databases=["default", "mini_data_lake"]) def test_blocklist_check_unblocked(): - """Test blocklist check.""" + """Test blocklist check (unblocked IP).""" user = User.objects.create( first_name="first", last_name="last", email="{}@crossfeed.cisa.gov".format(secrets.token_hex(4)), user_type=UserType.GLOBAL_ADMIN, ) - random_ip_address = "222.222.222.222" - response = client.get( + unblocked_ip = "222.222.222.222" + + response = client.post( "/blocklist/check/", - params={"ip_address": random_ip_address}, + json={"ip_addresses": [unblocked_ip]}, headers={"Authorization": "Bearer {}".format(create_jwt_token(user))}, ) + assert response.status_code == 200 assert response.json() == { - "attacks": 0, - "reports": 0, + unblocked_ip: { + "attacks": 0, + "reports": 0, + } } diff --git a/backend/src/xfd_django/xfd_api/views.py b/backend/src/xfd_django/xfd_api/views.py index e2d872a34..c9ace379e 100644 --- a/backend/src/xfd_django/xfd_api/views.py +++ b/backend/src/xfd_django/xfd_api/views.py @@ -35,7 +35,7 @@ from .api_methods import matomo_proxy_handler from .api_methods import notification as notification_methods from .api_methods import organization, proxy, scan, scan_tasks, user -from .api_methods.blocklist import handle_check_ip +from .api_methods.blocklist import handle_bulk_check_ips from .api_methods.cpe import get_cpes_by_id from .api_methods.cve import get_all_cves, get_cves_by_id, get_cves_by_name from .api_methods.dmz_sync import CybersixSyncParams @@ -112,7 +112,10 @@ from .schema_models import scan_tasks as scanTaskSchema from .schema_models import stat_schema from .schema_models.api_key import ApiKey as ApiKeySchema -from .schema_models.blocklist import BlocklistCheckResponse +from .schema_models.blocklist import ( + BulkBlocklistCheckRequest, + BulkBlocklistCheckResponse, +) from .schema_models.cpe import Cpe as CpeSchema from .schema_models.cve import Cve as CveSchema from .schema_models.cve import GetAllCvesResponse @@ -1823,19 +1826,19 @@ async def get_vulnerability_by_source_id_route( # ======================================== -@api_router.get( +@api_router.post( "/blocklist/check", dependencies=[Depends(get_current_active_user)], - response_model=BlocklistCheckResponse, + response_model=BulkBlocklistCheckResponse, tags=["Blocklist"], ) -async def get_blocklist( +async def post_blocklist_bulk_check( request: Request, - ip_address: str = Query(..., description="IP address to check"), + payload: BulkBlocklistCheckRequest, current_user: User = Depends(get_current_active_user), ): - """Determine if IP is on the blocklist.""" - return await handle_check_ip(ip_address, current_user) + """Determine if multiple IPs are on the blocklist.""" + return await handle_bulk_check_ips(payload.ip_addresses, current_user) # ========================================