Skip to content

Commit 855e957

Browse files
authored
Merge pull request #1416 from cisagov/Domain/IP-Filter-Refactor-CRASM-2704
Domain/IP Filter Refactor (CRASM-3594)
2 parents e098dd6 + 28cf509 commit 855e957

File tree

17 files changed

+1387
-332
lines changed

17 files changed

+1387
-332
lines changed

backend/src/xfd_django/xfd_api/api_methods/domain.py

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import csv
55
import io
66
import logging
7-
from typing import Optional
7+
from typing import Any, Dict, Optional
88

99
# Third-Party Libraries
1010
from django.core.paginator import Paginator
@@ -13,12 +13,15 @@
1313
from django.db.models.fields import GenericIPAddressField
1414
from django.db.models.functions import Cast
1515
from fastapi import HTTPException, status
16-
from xfd_mini_dl.models import Domain, DomainSearchView, Organization, Service
16+
from xfd_mini_dl.models import Domain, DomainSearchView, Organization, Service, UserType
1717

18+
from ..api_methods.organization import escape_special_characters
19+
from ..api_methods.search import is_valid_org, is_valid_region
1820
from ..auth import get_org_memberships, is_global_view_admin
1921
from ..helpers.filter_helpers import apply_domain_filters
2022
from ..helpers.s3_client import S3Client
21-
from ..schema_models.domain import DomainSearch
23+
from ..schema_models.domain import DomainNameSearch, DomainSearch
24+
from ..tasks.es_client import ESClient
2225

2326
LOGGER = logging.getLogger(__name__)
2427

@@ -316,3 +319,99 @@ def export_domains(domain_search: DomainSearch, current_user):
316319
# Log the exception for debugging (optional)
317320
LOGGER.error("Error exporting domains: %s", e)
318321
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
322+
323+
324+
def search_domains_name(search_body: DomainNameSearch, current_user):
325+
"""Handle the logic for searching organizations in Elasticsearch."""
326+
try:
327+
if search_body.regions is not None and len(search_body.regions) > 0:
328+
# Validate regions
329+
for region in search_body.regions:
330+
if not is_valid_region(region, current_user):
331+
raise HTTPException(
332+
status_code=403,
333+
detail="Unauthorized",
334+
)
335+
if search_body.organizations is not None and len(search_body.organizations) > 0:
336+
# Validate organizations
337+
for org in search_body.organizations:
338+
if not is_valid_org(org, current_user):
339+
raise HTTPException(
340+
status_code=403,
341+
detail="Unauthorized",
342+
)
343+
# Check if user is GlobalViewAdmin or has memberships
344+
if not is_global_view_admin(current_user) and not get_org_memberships(
345+
current_user
346+
):
347+
return []
348+
349+
# Initialize Elasticsearch client
350+
client = ESClient()
351+
352+
# Construct the Elasticsearch query
353+
354+
query_body: Dict[str, Any] = {
355+
# "_source": ["id", "name"],
356+
"query": {"bool": {"must": [], "filter": []}},
357+
}
358+
359+
validated_search_field = (
360+
search_body.search_field
361+
if search_body.search_field in ["name", "ip"]
362+
else "name"
363+
)
364+
365+
# Use match_all if searchTerm is empty
366+
if search_body.search_term.strip():
367+
sanitized_search_term = escape_special_characters(search_body.search_term)
368+
query_body["query"]["bool"]["must"].append(
369+
{
370+
"query_string": {
371+
"query": "*{}*".format(sanitized_search_term),
372+
"fields": [validated_search_field],
373+
"fuzziness": "AUTO",
374+
"analyze_wildcard": True,
375+
}
376+
}
377+
)
378+
else:
379+
query_body["query"]["bool"]["must"].append({"match_all": {}})
380+
381+
# Apply region filters if provided
382+
if search_body.regions:
383+
query_body["query"]["bool"]["filter"].append(
384+
{"terms": {"organization.region_id": search_body.regions}}
385+
)
386+
if search_body.organizations:
387+
query_body["query"]["bool"]["filter"].append(
388+
{"terms": {"organization.id.keyword": search_body.organizations}}
389+
)
390+
391+
if current_user.user_type == UserType.STANDARD:
392+
if search_body.regions == [] and search_body.organizations == []:
393+
orgs = get_org_memberships(current_user)
394+
if not orgs:
395+
return []
396+
query_body["query"]["bool"]["filter"].append(
397+
{"terms": {"organization.id.keyword": orgs}}
398+
)
399+
400+
if current_user.user_type == UserType.REGIONAL_ADMIN:
401+
if search_body.regions == [] and search_body.organizations == []:
402+
query_body["query"]["bool"]["filter"].append(
403+
{"terms": {"organization.region_id": [current_user.region_id]}}
404+
)
405+
406+
# Log the query for debugging
407+
LOGGER.debug("Query body: %s", query_body)
408+
409+
# Execute the search
410+
search_results = client.search_domains(query_body)
411+
412+
return {"body": search_results}
413+
except HTTPException as http_exc:
414+
raise http_exc
415+
except Exception as e:
416+
LOGGER.exception("Error occurred while searching organizations: %s", e)
417+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

backend/src/xfd_django/xfd_api/helpers/link_ips_from_subs.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import socket
66

77
# Third-Party Libraries
8+
from django.conf import settings
89
from django.core.exceptions import ObjectDoesNotExist
910
import dns.resolver
1011
from xfd_api.helpers.asset_inserts import create_or_update_ip
@@ -16,6 +17,8 @@
1617

1718
def get_matching_cidr(ip, org):
1819
"""Return cidr that contains the ip owned by the org."""
20+
if settings.IS_LOCAL:
21+
return Cidr.objects.filter().first()
1922
try:
2023
# Use .get() to find a single CIDR network that contains the IP
2124
matching_cidr = Cidr.objects.get(
@@ -35,6 +38,12 @@ def get_matching_cidr(ip, org):
3538
def resolve_domain(domain, nameservers=None):
3639
"""Identify ips linked to a given domain."""
3740
ip_addresses = set()
41+
42+
# If local add ip address based on string value instead of lookup
43+
if settings.IS_LOCAL:
44+
ip_addresses.add((domain.ip_address, "IPv4"))
45+
return ip_addresses
46+
3847
if not nameservers:
3948
nameservers = ["8.8.8.8"]
4049
# Create a resolver instance and optionally set a custom DNS server
@@ -46,7 +55,7 @@ def resolve_domain(domain, nameservers=None):
4655

4756
try:
4857
# Resolve IPv4 addresses (A records)
49-
ipv4_answers = dns.resolver.resolve(domain, "A")
58+
ipv4_answers = dns.resolver.resolve(domain.sub_domain, "A")
5059
for rdata in ipv4_answers:
5160
ip_addresses.add((rdata.address, "IPv4"))
5261
except dns.resolver.NoAnswer:
@@ -56,7 +65,7 @@ def resolve_domain(domain, nameservers=None):
5665

5766
try:
5867
# Resolve IPv6 addresses (AAAA records)
59-
ipv6_answers = dns.resolver.resolve(domain, "AAAA")
68+
ipv6_answers = dns.resolver.resolve(domain.sub_domain, "AAAA")
6069
for rdata in ipv6_answers:
6170
ip_addresses.add((rdata.address, "IPv6"))
6271
except dns.resolver.NoAnswer:
@@ -74,9 +83,10 @@ def get_ips_and_type_dns(subdomain, org):
7483
for ip_address, version in ip_set:
7584
cidr = get_matching_cidr(ip_address, org)
7685
if cidr:
77-
LOGGER.warning(
78-
"Found matching cidr for %s: %s", str(ip_address), cidr.network
79-
)
86+
if not settings.IS_LOCAL:
87+
LOGGER.warning(
88+
"Found matching cidr for %s: %s", str(ip_address), cidr.network
89+
)
8090
ip_info.append((ip_address, version, cidr))
8191
return ip_info
8292

@@ -121,7 +131,7 @@ def get_ips_and_type_socket(subdomain, org):
121131

122132
def link_ip_from_domain(sub, org):
123133
"""Link IP from domain."""
124-
ips = get_ips_and_type_dns(sub.sub_domain, org)
134+
ips = get_ips_and_type_dns(sub, org)
125135

126136
if not ips:
127137
return 0

backend/src/xfd_django/xfd_api/schema_models/domain.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Standard Python Libraries
33
from datetime import datetime
44
import logging
5-
from typing import Any, List, Optional
5+
from typing import Any, List, Literal, Optional
66
from uuid import UUID
77

88
# Third-Party Libraries
@@ -57,6 +57,15 @@ class Config:
5757
from_attributes = True
5858

5959

60+
class DomainNameSearch(BaseModel):
61+
"""DomainNameSearch Schema."""
62+
63+
regions: Optional[List[str]]
64+
organizations: Optional[List[str]]
65+
search_term: str
66+
search_field: Literal["name", "ip"] = "name"
67+
68+
6069
class DomainSearch(BaseModel):
6170
"""DomainSearch schema."""
6271

backend/src/xfd_django/xfd_api/tasks/helpers/syncdb_helpers/create_sample_data.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
from django.db import IntegrityError, transaction
2222
from django.utils import timezone
2323
from faker import Faker
24+
from xfd_api.helpers.link_ips_from_subs import connect_ips_from_subs
2425
from xfd_api.helpers.regionStateMap import REGION_STATE_MAP, STATE_ABBR_MAP
25-
from xfd_api.models import Domain, Service, Vulnerability
26+
from xfd_api.models import Service, Vulnerability
2627
from xfd_api.schema_models.scan import SCAN_SCHEMA
2728
from xfd_api.tasks.refresh_material_views import handler as refresh_materialized_views
2829
from xfd_api.tasks.refresh_vs_summaries import handler as refresh_vs_summaries
@@ -33,6 +34,7 @@
3334
CidrOrgs,
3435
Cve,
3536
CveSsvc,
37+
DataSource,
3638
Host,
3739
Ip,
3840
LatestPortScan,
@@ -42,6 +44,7 @@
4244
Role,
4345
Scan,
4446
ScanResult,
47+
SubDomains,
4548
Ticket,
4649
TicketEvent,
4750
User,
@@ -751,6 +754,17 @@ def populate_sample_data():
751754
cidrs = generate_cidr_blocks()
752755
create_cidrs_for_org(org, cidrs)
753756

757+
data_source, _ = DataSource.objects.using("mini_data_lake").get_or_create(
758+
name="findomain",
759+
description="findomain enumerates domains into subs.",
760+
last_run=datetime.now(),
761+
)
762+
for _ in range(NUM_SAMPLE_DOMAINS):
763+
create_sample_domain(org, data_source)
764+
765+
# COnnect created sub domains with ips
766+
connect_ips_from_subs(orgs)
767+
754768
LOGGER.info("Populating vuln_scans, port_scans, tickets, and ticket_events...")
755769
for idx, org in enumerate(orgs, start=1):
756770
try:
@@ -1067,18 +1081,20 @@ def generate_random_name():
10671081
return "{} {} {}".format(adjective.capitalize(), entity, noun.capitalize())
10681082

10691083

1070-
def create_sample_domain(organization):
1084+
def create_sample_domain(organization, data_source):
10711085
"""Create a sample domain linked to an organization."""
10721086
domain_name = "{}-{}.crossfeed.local".format(
10731087
random.choice(adjectives), random.choice(nouns)
10741088
).lower()
10751089
ip = ".".join(map(str, (random.randint(0, 255) for _ in range(4))))
1076-
return Domain.objects.create(
1077-
name=domain_name,
1078-
ip=ip,
1079-
fromRootDomain="crossfeed.local",
1080-
subdomainSource="findomain",
1090+
return SubDomains.objects.create(
1091+
sub_domain=domain_name,
1092+
ip_address=ip,
1093+
from_root_domain="crossfeed.local",
1094+
subdomain_source="findomain",
10811095
organization=organization,
1096+
data_source=data_source,
1097+
current=True,
10821098
)
10831099

10841100

backend/src/xfd_django/xfd_api/tests/test_api_integrity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_endpoints_require_auth(method, route):
7979
("DELETE", "/notifications/{notification_id}"),
8080
("POST", "/v2/organizations/{organization_id}/users"),
8181
("POST", "/search/organizations"),
82+
("POST", "/search/domains"),
8283
("DELETE", "/saved-searches/{saved_search_id}"),
8384
("POST", "/scheduler/invoke"),
8485
("POST", "/scan-tasks/{scan_task_id}/kill"),

backend/src/xfd_django/xfd_api/tests/test_domain.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import datetime
44
import logging
55
import secrets
6+
from unittest.mock import patch
67

78
# Third-Party Libraries
89
from django.db import transaction
@@ -329,3 +330,80 @@ def test_search_domains_does_not_exist(user, domain, refresh_vuln_views):
329330
assert response.status_code == 200
330331
data = response.json()
331332
assert len(data["result"]) == 0, "No result found for the given organization name"
333+
334+
335+
@pytest.mark.django_db(transaction=True, databases=["default", "mini_data_lake"])
336+
@patch("xfd_api.tasks.es_client.ESClient.search_domains")
337+
def test_domains_search_autofill_endpoint_auth_user_200(mock_search):
338+
"""Test domain search autocomplete endpoint."""
339+
user = User.objects.create(
340+
first_name="",
341+
last_name="",
342+
email="{}@example.com".format(secrets.token_hex(4)),
343+
user_type=UserType.STANDARD,
344+
created_at=datetime.now(),
345+
updated_at=datetime.now(),
346+
)
347+
response = client.post(
348+
"/search/domains",
349+
json={
350+
"search_term": "127",
351+
"search_field": "name",
352+
"regions": [],
353+
"organizations": [],
354+
},
355+
headers={"Authorization": "Bearer " + create_jwt_token(user)},
356+
)
357+
assert response.status_code == 200
358+
359+
360+
@pytest.mark.django_db(transaction=True, databases=["default", "mini_data_lake"])
361+
@patch("xfd_api.tasks.es_client.ESClient.search_domains")
362+
def test_domains_search_autofill_endpoint_regional_auth(mock_search):
363+
"""Test domain search autocomplete endpoint."""
364+
user = User.objects.create(
365+
first_name="",
366+
last_name="",
367+
email="{}@example.com".format(secrets.token_hex(4)),
368+
user_type=UserType.REGIONAL_ADMIN,
369+
created_at=datetime.now(),
370+
updated_at=datetime.now(),
371+
region_id="8",
372+
)
373+
response = client.post(
374+
"/search/domains",
375+
json={
376+
"search_term": "127",
377+
"search_field": "name",
378+
"regions": ["3"],
379+
"organizations": [],
380+
},
381+
headers={"Authorization": "Bearer " + create_jwt_token(user)},
382+
)
383+
assert response.status_code == 200
384+
385+
386+
@pytest.mark.django_db(transaction=True, databases=["default", "mini_data_lake"])
387+
@patch("xfd_api.tasks.es_client.ESClient.search_domains")
388+
def test_domains_search_autofill_endpoint_global_auth(mock_search):
389+
"""Test domain search autocomplete endpoint."""
390+
user = User.objects.create(
391+
first_name="",
392+
last_name="",
393+
email="{}@example.com".format(secrets.token_hex(4)),
394+
user_type=UserType.GLOBAL_ADMIN,
395+
created_at=datetime.now(),
396+
updated_at=datetime.now(),
397+
region_id="8",
398+
)
399+
response = client.post(
400+
"/search/domains",
401+
json={
402+
"search_term": "127",
403+
"search_field": "name",
404+
"regions": ["3"],
405+
"organizations": [],
406+
},
407+
headers={"Authorization": "Bearer " + create_jwt_token(user)},
408+
)
409+
assert response.status_code == 200

0 commit comments

Comments
 (0)