
Commit 588aace

Merge pull request #922 from cisagov/DJ_optimize-asm-sync_WIP
optimize asm_sync to enumerate live ips
2 parents: 2be8203 + 42b6154

File tree

9 files changed: +213 −45 lines


backend/src/xfd_django/xfd_api/api_methods/sync.py

Lines changed: 3 additions & 3 deletions

@@ -48,14 +48,14 @@ async def sync_post(sync_body, request: Request, current_user):
 
     except HTTPException as http_exc:
         raise http_exc
-    except Exception as e:
-        print(e)
-        raise HTTPException(status_code=500, detail=str(e))
     except SyncError as sync_exc:
         raise HTTPException(
            status_code=500,
            detail=f"SyncError: {sync_exc.message} - {sync_exc.error_message}",
        )
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 def process_request(headers, sync_body):

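Note on the sync.py reordering: Python tries `except` clauses top to bottom, so while the broad `except Exception` preceded `except SyncError`, a SyncError (assuming it subclasses Exception, as its usage here suggests) could never reach its dedicated handler. A minimal, self-contained sketch of the principle:

    class SyncError(Exception):
        """Stand-in for the project's SyncError; assumed to subclass Exception."""

        def __init__(self, message, error_message):
            super().__init__(message)
            self.message = message
            self.error_message = error_message


    def handle(exc):
        """Handlers are checked in order, so the broad clause must come last."""
        try:
            raise exc
        except SyncError as sync_exc:
            return f"SyncError: {sync_exc.message} - {sync_exc.error_message}"
        except Exception as e:
            return f"generic: {e}"


    print(handle(SyncError("bad sync", "timeout")))  # SyncError: bad sync - timeout
    print(handle(ValueError("oops")))  # generic: oops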
backend/src/xfd_django/xfd_api/helpers/asset_inserts.py

Lines changed: 3 additions & 0 deletions

@@ -39,6 +39,9 @@ def create_or_update_ip(create_defaults, update_dict, linked_sub=None):
     if not created:
         for key, value in update_dict.items():
             if key == "origin_cidr":
+                if ip_object.origin_cidr is None:
+                    ip_object.origin_cidr = value
+                    continue
                 if value.id == ip_object.origin_cidr.id:
                     continue
                 if ip_object.origin_cidr.retired is True:

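Note on the asset_inserts.py guard: the pre-existing code dereferenced `ip_object.origin_cidr.id` even when the stored IP had no origin CIDR, which raises AttributeError on None. A small sketch of the patched branch, using hypothetical dataclass stand-ins for the Django models (the diff cuts off after the `retired` check, so the replacement shown in that branch is an assumption):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class CidrRecord:  # hypothetical stand-in for the Cidr model
        id: int
        retired: bool = False


    @dataclass
    class IpRecord:  # hypothetical stand-in for the Ip model
        origin_cidr: Optional[CidrRecord] = None


    def apply_origin_cidr(ip_object, value):
        """Mirror the patched branch: adopt the CIDR when none is set yet."""
        if ip_object.origin_cidr is None:
            ip_object.origin_cidr = value  # old code raised AttributeError below
            return
        if value.id == ip_object.origin_cidr.id:
            return  # already linked to this CIDR
        if ip_object.origin_cidr.retired:
            ip_object.origin_cidr = value  # assumed: replace a retired CIDR


    ip = IpRecord()
    apply_origin_cidr(ip, CidrRecord(id=7))
    assert ip.origin_cidr.id == 7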
backend/src/xfd_django/xfd_api/helpers/link_subs_from_ips.py

Lines changed: 59 additions & 34 deletions

@@ -35,51 +35,76 @@
 )
 
 
-def process_ips(thread_id, org, cidr, ip_gen):
-    """Process ips through WhoisXML and save them to DB."""
+def process_ips(thread_id, org, cidr, ip_list):
+    """Process IPs through WhoisXML and save them to DB."""
     count = 0
     failed_ips = []
     chunk_start = time.time()
-    while True:
+    for ip in ip_list:
+        count += 1
         try:
-            # Get the next IP from the generator or break if exhausted
-            ip = str(next(ip_gen))
-            count += 1
-            try:
-                domain_list, failed_ips = search_whois_for_domains(ip, failed_ips)
-            except Exception as e:
-                LOGGER.error("Thread %d: Error identifying domains: %s", thread_id, e)
-
-                failed_ips.append(ip)
-                continue
-            if domain_list:
-                LOGGER.info("Found %d domains associated with %s", len(domain_list), ip)
-                save_and_link_ip_and_subdomain(ip, cidr, org, domain_list)
-
-        except StopIteration:
-            # Stop when the generator is exhausted
-            LOGGER.info(
-                "Thread %d has completed. Processed %d ips in %s seconds.",
-                thread_id,
-                count,
-                round(time.time() - chunk_start, 2),
+            domain_list, failed_ips = search_whois_for_domains(ip, failed_ips)
+        except Exception as e:
+            LOGGER.error(
+                "Thread %d: Error identifying domains for %s: %s", thread_id, ip, e
             )
-            if len(failed_ips) > 0:
-                LOGGER.warning("%d IPs failed to process", len(failed_ips))
-            break
+            failed_ips.append(ip)
+            continue
+
+        if domain_list:
+            LOGGER.info("Found %d domains associated with %s", len(domain_list), ip)
+            save_and_link_ip_and_subdomain(ip, cidr, org, domain_list)
+
+    LOGGER.info(
+        "Thread %d completed. Processed %d IPs in %.2f seconds.",
+        thread_id,
+        count,
+        time.time() - chunk_start,
+    )
+
+    if failed_ips:
+        LOGGER.warning(
+            "%d IPs failed to process in thread %d", len(failed_ips), thread_id
+        )
+
+
+def split_into_balanced_chunks(items, num_chunks):
+    """Split `items` into `num_chunks` parts, distributing remainder fairly."""
+    n = len(items)
+    base_chunk_size = n // num_chunks
+    remainder = n % num_chunks
+
+    chunks = []
+    start = 0
+    for i in range(num_chunks):
+        # Give one extra item to the first `remainder` chunks
+        end = start + base_chunk_size + (1 if i < remainder else 0)
+        chunks.append(items[start:end])
+        start = end
+    return chunks
 
 
 def process_cidr(cidr, org):
-    """Process a given cidr."""
-    ip_gen = generate_ips(cidr.network)
+    """Process a given CIDR using stored live IPs."""
+    if not cidr.live_ips:
+        LOGGER.warning("No live IPs for CIDR: %s", cidr.network)
+        return
+
+    ip_list = list(cidr.live_ips)
+    if not ip_list:
+        return
+
+    chunks = split_into_balanced_chunks(ip_list, THREAD_COUNT)
 
     threads = []
-    for i in range(THREAD_COUNT):
-        thread = threading.Thread(target=process_ips, args=(i + 1, org, cidr, ip_gen))
-        threads.append(thread)
-        thread.start()
+    for i, chunk in enumerate(chunks):
+        if chunk:  # Only create a thread if the chunk has work
+            thread = threading.Thread(
+                target=process_ips, args=(i + 1, org, cidr, chunk)
+            )
+            threads.append(thread)
+            thread.start()
 
-    # Wait for all threads to complete
     for thread in threads:
         thread.join()
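The chunking helper keeps thread workloads within one IP of each other, and `process_cidr` skips empty chunks, so fewer live IPs than THREAD_COUNT simply means fewer threads. A quick check of the distribution, reusing the function exactly as added above:

    def split_into_balanced_chunks(items, num_chunks):
        """Split `items` into `num_chunks` parts, distributing remainder fairly."""
        n = len(items)
        base_chunk_size = n // num_chunks
        remainder = n % num_chunks

        chunks = []
        start = 0
        for i in range(num_chunks):
            # Give one extra item to the first `remainder` chunks
            end = start + base_chunk_size + (1 if i < remainder else 0)
            chunks.append(items[start:end])
            start = end
        return chunks


    ips = ["10.0.0.{}".format(i) for i in range(10)]
    print([len(c) for c in split_into_balanced_chunks(ips, 4)])  # [3, 3, 2, 2]
    print([len(c) for c in split_into_balanced_chunks(ips[:2], 4)])  # [1, 1, 0, 0]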
backend/src/xfd_django/xfd_api/management/commands/syncmdl.py

Lines changed: 9 additions & 0 deletions

@@ -92,6 +92,15 @@ def handle(self, *args, **options):
         except Exception as e:
             self.stdout.write("Granting privileges failed: {}".format(e))
 
+        # 👉 Step 1.5: Enable btree_gist extension
+        self.stdout.write("Enabling btree_gist extension for GiST indexing...")
+        try:
+            with connection.cursor() as cursor:
+                cursor.execute("CREATE EXTENSION IF NOT EXISTS btree_gist;")
+            self.stdout.write("btree_gist extension enabled.")
+        except Exception as e:
+            self.stdout.write(f"Failed to enable btree_gist extension: {e}")
+
         # Step 2: Synchronize or Reset the Database
         self.stdout.write("Synchronizing the MDL database schema...")
         if dangerouslyforce:

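`CREATE EXTENSION IF NOT EXISTS` is idempotent, so rerunning syncmdl is safe even when btree_gist is already installed. A quick way to confirm the extension landed, sketched against Django's default connection (pg_extension is a standard PostgreSQL catalog):

    from django.db import connection

    with connection.cursor() as cursor:
        cursor.execute(
            "SELECT extversion FROM pg_extension WHERE extname = 'btree_gist';"
        )
        row = cursor.fetchone()

    print("btree_gist installed:", row is not None)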
backend/src/xfd_django/xfd_api/schema_models/scan.py

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ class GenericMessageResponseModel(BaseModel):
         cpu="1024",
         memory="8192",
         description="Enumerate and sync org assets.",
-        max_concurrent_tasks=1,
+        max_concurrent_tasks=3,
     ),
     "censys": ScanSchema(
         type="fargate",

backend/src/xfd_django/xfd_api/tasks/syncdb_helpers.py

Lines changed: 24 additions & 1 deletion

@@ -28,6 +28,9 @@
 from xfd_api.helpers.regionStateMap import REGION_STATE_MAP
 from xfd_api.models import Domain, Service, Vulnerability
 from xfd_api.tasks.es_client import ESClient
+from xfd_api.utils.scan_utils.vuln_scanning_sync_utils import (  # fill_cidr_live_ips,
+    fill_cidr_live_ips_bulk_update,
+)
 from xfd_mini_dl.models import (
     ApiKey,
     Cidr,

@@ -668,7 +671,7 @@ def create_cidrs_for_org(org, cidr_list, data_source=None, ips_per_cidr=4):
             last_seen_timestamp=timezone.now(),
             last_reverse_lookup=timezone.now(),
             has_shodan_results=random.choice([True, False]),
-            current=random.choice([True, False]),
+            current=random.choice([True, True, True, False]),
             conflict_alerts=[],
             synced_at=timezone.now(),
         )

@@ -741,6 +744,9 @@ def populate_sample_data():
         )
         sys.stdout.flush()
 
+    # fill_cidr_live_ips()
+    fill_cidr_live_ips_bulk_update()
+
     print("\n✅ Done populating all data.")
 
 
@@ -959,6 +965,22 @@ def synchronize(target_app_label=None, using=None):
         process_m2m_tables(schema_editor, ordered_models, database)
 
     if target_app_label == "xfd_mini_dl":
+        print("Ensuring GiST index exists on ip.ip...")
+        with connections[database].cursor() as cursor:
+            cursor.execute(
+                """
+                DO $$
+                BEGIN
+                    IF NOT EXISTS (
+                        SELECT 1 FROM pg_indexes
+                        WHERE tablename = 'ip' AND indexname = 'ip_ip_gist_idx'
+                    ) THEN
+                        EXECUTE 'CREATE INDEX ip_ip_gist_idx ON ip USING gist (ip inet_ops)';
+                    END IF;
+                END
+                $$;
+                """
+            )
         create_domain_view(database)
         create_service_view(database)
         create_vuln_normal_views(database)

@@ -1039,6 +1061,7 @@ def process_model(
         else:
             print("Creating table for model: {}".format(model.__name__))
             schema_editor.create_model(model)
+
     except Exception as e:
         print("Error processing model {}: {}".format(model.__name__, e))

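The `DO $$ ... $$` block keeps index creation idempotent by checking pg_indexes first, and `inet_ops` is PostgreSQL's built-in GiST operator class for inet columns; it is what lets the planner use the index for containment operators like the `<<` that the bulk live-IPs query below relies on. An illustrative lookup the index can serve (the CIDR literal is made up):

    from django.db import connections

    with connections["mini_data_lake"].cursor() as cursor:
        # `<<` means "is contained within" for inet/cidr; with ip_ip_gist_idx
        # in place this can use the index instead of a sequential scan.
        cursor.execute(
            "SELECT ip FROM ip WHERE ip << %s::cidr LIMIT 10;",
            ["192.0.2.0/24"],
        )
        print(cursor.fetchall())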
backend/src/xfd_django/xfd_api/tasks/vulnScanningSync.py

Lines changed: 13 additions & 6 deletions

@@ -39,9 +39,10 @@
     ScanExecutionError,
     SyncError,
 )
-from xfd_api.utils.scan_utils.vuln_scanning_sync_utils import (
+from xfd_api.utils.scan_utils.vuln_scanning_sync_utils import (  # fill_cidr_live_ips,
     enforce_latest_flag_port_scan,
     fetch_orgs_and_relations,
+    fill_cidr_live_ips_bulk_update,
     get_latest_os_type,
     load_test_data,
     save_cve_to_datalake,

@@ -138,7 +139,7 @@ def fetch_in_chunks(base_query: str, chunk_size: int = 5000):
         offset += chunk_size
 
 
-def main():
+def main():  # pylint: disable=R0915
     """Execute the vulnerability scanning synchronization task."""
     LOGGER.info("Started VulnScanningSync scan...")
 
@@ -150,9 +151,6 @@ def main():
     org_id_dict = process_orgs(request_list)
     LOGGER.info("Completed saving organizations to the LZ MDL.")
 
-    # Process Organizations & Relations
-    send_organizations_to_dmz()
-
     # Process Vulnerability Scans
     LOGGER.info("Started processing vulnerability scans...")
     vuln_scans = fetch_from_redshift(

@@ -199,6 +197,12 @@ def main():
     create_port_scan_service_summaries()
     LOGGER.info("Finished processing port scans")
 
+    # fill_cidr_live_ips()
+    fill_cidr_live_ips_bulk_update()
+
+    # Process Organizations & Relations
+    send_organizations_to_dmz()
+
     # Process Tickets (Chunked)
     LOGGER.info("Started processing tickets...")
     base_query = (

@@ -343,7 +347,10 @@ def send_csv_to_sync(csv_data, bounds):
 
     try:
         response = requests.post(
-            os.getenv("DMZ_SYNC_ENDPOINT"), json=body, headers=headers, timeout=60
+            os.getenv("DMZ_SYNC_ENDPOINT") + "/sync",
+            json=body,
+            headers=headers,
+            timeout=60,
         )
         response.raise_for_status()
         LOGGER.info("Successfully sent chunk to sync API")

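One caveat with the new URL construction in send_csv_to_sync: `os.getenv("DMZ_SYNC_ENDPOINT")` returns None when the variable is unset, so `None + "/sync"` raises TypeError before the POST is attempted. A guarded variant, sketched as a suggestion rather than what the commit does:

    import os

    import requests


    def post_sync_chunk(body, headers):
        """Hypothetical guarded version of the send_csv_to_sync POST."""
        base = os.getenv("DMZ_SYNC_ENDPOINT")
        if not base:
            # Fail fast with a clear message instead of `None + "/sync"`.
            raise RuntimeError("DMZ_SYNC_ENDPOINT is not set")
        url = base.rstrip("/") + "/sync"  # tolerate a trailing slash in the env var
        response = requests.post(url, json=body, headers=headers, timeout=60)
        response.raise_for_status()
        return response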
backend/src/xfd_django/xfd_api/utils/scan_utils/vuln_scanning_sync_utils.py

Lines changed: 87 additions & 0 deletions

@@ -10,6 +10,7 @@
 import json
 import logging
 import os
+import time
 from typing import Dict
 from uuid import uuid1
 
@@ -18,6 +19,7 @@
 from django.db import connections, models, transaction
 from django.db.models import Exists, OuterRef, Prefetch
 from django.db.utils import IntegrityError
+from django.utils import timezone
 from xfd_mini_dl.models import (
     Cidr,
     CidrOrgs,

@@ -584,6 +586,7 @@ def save_cidr_to_mdl(cidr_dict: dict, org: Organization, db_name="mini_data_lake
         cidr_obj.start_ip = cidr_dict["start_ip"]
         cidr_obj.end_ip = cidr_dict["end_ip"]
         cidr_obj.retired = False
+        cidr_obj.live_ips = cidr_dict["live_ips"]
         cidr_obj.save(using=db_name)  # Save updates
 
     else:

@@ -592,6 +595,7 @@ def save_cidr_to_mdl(cidr_dict: dict, org: Organization, db_name="mini_data_lake
             network=cidr_dict["network"],
             start_ip=cidr_dict["start_ip"],
             end_ip=cidr_dict["end_ip"],
+            live_ips=cidr_dict["live_ips"],
             retired=False,
         )
         # cidr_obj.organizations.add(org, through_defaults={})

@@ -690,3 +694,86 @@ def map_severity(severity):
     if severity < 9:
         return "High"
     return "Critical"
+
+
+def fill_cidr_live_ips():
+    """Update live_ips field for all current CIDRs based on recent open PortScans."""
+    start_time = time.time()
+
+    # Define the 90-day threshold
+    time_threshold = timezone.now() - datetime.timedelta(days=90)
+
+    # Get all Cidrs with at least one related CidrOrgs marked as current
+    current_cidrs = Cidr.objects.filter(cidrorgs__current=True).distinct()
+
+    for cidr in current_cidrs:
+        if not cidr.network:
+            continue
+
+        scans = (
+            PortScan.objects.filter(
+                state="open",
+                time_scanned__gte=time_threshold,
+                ip__ip__net_contained=cidr.network,
+            )
+            .values_list("ip__ip", flat=True)
+            .distinct()
+        )
+
+        # If live_ips is empty or not set, initialize it as an empty set
+        current_live_ips = set(cidr.live_ips or [])
+
+        # Add the new IPs from the scans to the existing set (no duplicates)
+        current_live_ips.update(scans)
+
+        # Convert all IP objects to strings for JSON serialization
+        cidr.live_ips = [str(ip.ip) for ip in current_live_ips]
+        cidr.save()
+
+    duration = time.time() - start_time
+    LOGGER.info("fill_cidr_live_ips completed in %.2f seconds", duration)
+
+
+def fill_cidr_live_ips_bulk_update():
+    """Fill live_ips field in the cidr table based on recent port scans."""
+    start_time = time.time()
+
+    with transaction.atomic(using="mini_data_lake"):
+        with connections["mini_data_lake"].cursor() as cursor:
+            cursor.execute(
+                """
+                WITH new_ips AS (
+                    SELECT
+                        cidr.id AS cidr_id,
+                        array_agg(DISTINCT ip.ip) AS new_ip_list
+                    FROM cidr
+                    JOIN cidr_orgs ON cidr_orgs.cidr_id = cidr.id
+                    JOIN port_scan ON port_scan.state = 'open'
+                        AND port_scan.time_scanned >= NOW() - INTERVAL '90 days'
+                    JOIN ip ON port_scan.ip_id = ip.id
+                    WHERE cidr_orgs.current = TRUE
+                        AND cidr.network IS NOT NULL
+                        AND ip.ip << cidr.network
+                    GROUP BY cidr.id
+                ),
+                merged_ips AS (
+                    SELECT
+                        cidr.id,
+                        ARRAY(
+                            SELECT DISTINCT ip_address::inet
+                            FROM jsonb_array_elements_text(
+                                COALESCE(cidr.live_ips, '[]'::jsonb) || to_jsonb(new_ips.new_ip_list)
+                            ) AS ip_address
+                        ) AS updated_ips
+                    FROM cidr
+                    JOIN new_ips ON cidr.id = new_ips.cidr_id
+                )
+                UPDATE cidr
+                SET live_ips = to_jsonb(merged_ips.updated_ips)
+                FROM merged_ips
+                WHERE cidr.id = merged_ips.id;
+                """
+            )
+
+    duration = time.time() - start_time
+    LOGGER.info("fill_cidr_live_ips_bulk_update completed in %.2f seconds", duration)

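The bulk variant replaces the per-CIDR Python loop with a single set-based statement: `new_ips` collects the distinct IPs with open port scans from the last 90 days for each current CIDR, `merged_ips` unions them with whatever the jsonb `live_ips` column already holds, and the final UPDATE writes the merged array back as jsonb. A rough way to exercise and spot-check it, assuming this project's Django environment is configured:

    from xfd_api.utils.scan_utils.vuln_scanning_sync_utils import (
        fill_cidr_live_ips_bulk_update,
    )
    from xfd_mini_dl.models import Cidr

    fill_cidr_live_ips_bulk_update()

    # Every current CIDR with recent open scans should now carry live IPs.
    for cidr in Cidr.objects.filter(cidrorgs__current=True).distinct()[:5]:
        print(cidr.network, len(cidr.live_ips or []))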