Skip to content

Commit 6d8ed6c

Browse files
authored
Merge pull request #2117 from dan-acsc/feat/add-full-tree-inspection-option
feat: add full tree inspection option
2 parents 51ff3b2 + db32c2a commit 6d8ed6c

2 files changed

Lines changed: 175 additions & 6 deletions

File tree

assemblyline/datastore/helper.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def get_signature_last_modified(self, sig_type=None):
598598

599599
@elasticapm.capture_span(span_type='datastore')
600600
def get_or_create_file_tree(self, submission, max_depth, cl_engine=forge.get_classification(),
601-
user_classification=None):
601+
user_classification=None, get_full_tree: bool = False):
602602
# Generate cache key
603603
if user_classification is not None:
604604
user_classification = cl_engine.normalize_classification(user_classification, long_format=False)
@@ -616,8 +616,12 @@ def get_or_create_file_tree(self, submission, max_depth, cl_engine=forge.get_cla
616616
num_files = len(list({x[:64] for x in submission['results']}))
617617
max_score = submission['max_score']
618618

619-
# Load / Validate cache tree if exist
620-
cached_tree = self.submission_tree.get_if_exists(cache_key, as_obj=False)
619+
# bypass cache if the full tree is requested as the cache may hold a non-full version
620+
if get_full_tree:
621+
cached_tree = None
622+
else:
623+
# Load / validate the cached tree if it exists
624+
cached_tree = self.submission_tree.get_if_exists(cache_key, as_obj=False)
621625
if cached_tree:
622626
tree = json.loads(cached_tree['tree'])
623627
if self._is_valid_tree(tree, num_files, max_score):
@@ -704,8 +708,8 @@ def process_file(current_file, tree_branch, partial, lvl=0):
704708
file_sha256 = current_file['sha256']
705709
file_name = current_file['name']
706710

707-
# Check if the file not already in the tree and if its allowed to be processed
708-
if file_sha256 not in tree_branch \
711+
# Check that the file is not already in the tree (unless the full tree is requested) and that it is allowed to be processed
712+
if (get_full_tree or file_sha256 not in tree_branch) \
709713
and file_sha256 not in forbidden_files \
710714
and file_sha256 not in missing_files:
711715

@@ -726,7 +730,7 @@ def process_file(current_file, tree_branch, partial, lvl=0):
726730
# Process each children of the file
727731
for new_child in files.get(file_sha256, []):
728732
# Check if the file has already been processed elsewhere in the tree
729-
if new_child['sha256'] in tree_cache:
733+
if not get_full_tree and new_child['sha256'] in tree_cache:
730734
truncated = True
731735
else:
732736
# Process file children

test/test_datastore_helper.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
import hashlib
3+
from typing import Iterator
34
from assemblyline.common.isotime import now_as_iso
45
from assemblyline.odm.models.file import File
56
import pytest
@@ -16,6 +17,7 @@
1617
from assemblyline.odm.models.result import File as ResultFile
1718
from assemblyline.odm.models.service import Service
1819
from assemblyline.odm.models.submission import Submission
20+
from assemblyline.odm.models.submission import File as SubmissionFile
1921
from assemblyline.remote.datatypes.hash import Hash
2022
from assemblyline.odm.randomizer import SERVICES, random_minimal_obj
2123
from assemblyline.odm.random_data import create_signatures, create_submission, create_heuristics, create_services
@@ -69,6 +71,115 @@ def ds(request, config):
6971

7072
return pytest.skip("Connection to the Elasticsearch server failed. This test cannot be performed...")
7173

74+
@pytest.fixture
def submission_with_duplicate_extracted_files(ds: AssemblylineDatastore) -> Iterator[Submission]:
    """Yield a submission whose two root files share one extracted child.

    Layout persisted by this fixture:

    * root 1 has two results, extracting ``extracted.bin`` and ``extracted_2.bin``
    * root 2 has one result, extracting ``extracted.bin`` (same sha256 as under root 1)

    The shared child should appear under both roots when the full tree is requested
    and not when it isn't. After the test, the submission tree is deleted and the
    touched indices are committed.
    """
    filestore = forge.get_filestore()
    classification = forge.get_classification()
    two_weeks = 60 * 60 * 24 * 14

    def store_file(content: bytes):
        """Save a random File document keyed by the sha256 of *content* and upload the bytes."""
        file_doc = random_minimal_obj(File)
        file_doc.expiry_ts = now_as_iso(two_weeks)
        digest = hashlib.sha256(content).hexdigest()
        file_doc.sha256 = digest
        ds.file.save(digest, file_doc)
        filestore.put(digest, content)
        return digest, file_doc

    def store_result(parent_sha256, child_sha256, child_name, child_description):
        """Save and commit a random Result on *parent_sha256* extracting one child; return its key."""
        result_doc = random_minimal_obj(Result)
        result_doc.sha256 = parent_sha256
        result_doc.classification = classification.UNRESTRICTED
        result_doc.response.extracted = [
            ResultFile({
                "sha256": child_sha256,
                "name": child_name,
                "description": child_description,
                "classification": classification.UNRESTRICTED
            })
        ]
        key = result_doc.build_key()
        ds.result.save(key, result_doc)
        ds.result.commit()
        return key

    # Two root files ...
    root_sha256, _ = store_file(b"root file content for extracted test")
    root_sha256_2, root_file_2 = store_file(b"root file content for extracted test number 2")
    # ... one child extracted from both roots, and one extracted from the first root only
    shared_child_sha256, _ = store_file(b"extracted file content")
    first_root_only_sha256, _ = store_file(b"extracted file content 2")

    # Submission referencing both roots
    submission = random_minimal_obj(Submission)
    submission.expiry_ts = now_as_iso(two_weeks)
    submission.files[0].sha256 = root_sha256
    second_root_entry = SubmissionFile({
        "sha256": root_file_2.sha256,
        "name": "root_2_file_name",
        "size": root_file_2.size
    })
    submission.files.append(second_root_entry)
    ds.submission.save(submission.sid, submission)

    # Results are created in order: root 1 -> shared child, root 1 -> second child, root 2 -> shared child
    submission.results = [
        store_result(root_sha256, shared_child_sha256, "extracted.bin", "Test extracted file"),
        store_result(root_sha256, first_root_only_sha256, "extracted_2.bin", "Test extracted file 2"),
        store_result(root_sha256_2, shared_child_sha256, "extracted.bin", "Test extracted file 2"),
    ]
    ds.submission.save(submission.sid, submission)
    ds.submission.commit()

    yield submission

    # Teardown: ensure the submission tree is deleted, then commit the touched indices
    ds.delete_submission_tree(submission.sid, transport=filestore)
    for index in (ds.submission, ds.error, ds.emptyresult, ds.result, ds.file):
        index.commit()
72183

73184
def test_index_archive_status(ds: AssemblylineDatastore, config: Config):
74185
"""Save a new document atomically, then try to save it again and detect the failure."""
@@ -351,6 +462,60 @@ def test_get_or_create_file_tree(ds: AssemblylineDatastore, config: Config):
351462
for f in submission.files:
352463
assert f.sha256 in tree['tree']
353464

465+
def test_get_or_create_full_file_tree(ds: AssemblylineDatastore, config: Config):
    """Compare the normal and full file trees of the first submission found in the datastore.

    The full tree must expose the standard top-level keys, contain every submitted file as
    a root, and have no root flagged as truncated. Every root of the normal tree is expected
    to be flagged as truncated.
    """
    # Take the first submission returned by the search
    submission: Submission = ds.submission.search("id:*", rows=1, fl="*")['items'][0]
    depth = config.submission.max_extraction_depth

    # Build the normal tree first, then the full one
    tree_normal = ds.get_or_create_file_tree(submission, depth, get_full_tree=False)
    full_tree = ds.get_or_create_file_tree(submission, depth, get_full_tree=True)

    # The full tree exposes the expected top-level sections
    for section in ('tree', 'classification', 'filtered', 'partial', 'supplementary'):
        assert section in full_tree

    # Every submitted file is a root of the full tree
    full_roots = full_tree['tree']
    for submitted in submission.files:
        assert submitted.sha256 in full_roots

    # A full tree is never truncated; a missing flag counts as a failure
    for node in full_roots.values():
        assert node.get("truncated", True) is False

    # The normal tree is expected to be truncated at the first level
    # NOTE(review): this presumably depends on the stored test data re-using extracted
    # files across the tree -- confirm it holds for whichever submission the search returns.
    for node in tree_normal['tree'].values():
        assert node.get("truncated", False) is True
489+
490+
491+
def test_get_or_create_full_file_tree_guaranteed_tree(
        submission_with_duplicate_extracted_files: Submission,
        ds: AssemblylineDatastore,
        config: Config):
    """Check that a full tree lists a shared extracted file under every root that extracted it.

    Uses a submission with two roots that both extract the same child. For each root, the
    full tree must list at least as many first-level children as the normal tree, identical
    children when the counts match, and strictly more children for at least one root.
    No first-level child of the full tree may be flagged as truncated.
    """
    submission = submission_with_duplicate_extracted_files
    depth = config.submission.max_extraction_depth

    # Normal tree: root sha256 -> first-level child sha256s
    tree_normal = ds.get_or_create_file_tree(submission, depth, get_full_tree=False)
    normal_children_by_root = {
        root_sha256: list(root_node.get("children", {}))
        for root_sha256, root_node in tree_normal['tree'].items()
    }

    # Full tree: same mapping, and no child may be truncated (a missing flag counts as a failure)
    full_tree = ds.get_or_create_file_tree(submission, depth, get_full_tree=True)
    full_children_by_root = {}
    for root_sha256, root_node in full_tree['tree'].items():
        children = []
        for child_sha256, child_node in root_node.get("children", {}).items():
            assert child_node.get("truncated", True) is False
            children.append(child_sha256)
        full_children_by_root[root_sha256] = children

    # Both trees have exactly the two submitted roots
    assert len(normal_children_by_root) == 2
    assert len(full_children_by_root) == 2

    more_found_in_full_result = False
    for root_sha256, full_children in full_children_by_root.items():
        normal_children = normal_children_by_root[root_sha256]
        # The full tree can only add children, never drop them
        assert len(full_children) >= len(normal_children)
        if len(full_children) == len(normal_children):
            assert sorted(full_children) == sorted(normal_children)
        else:
            # Strictly more children: the duplicate was kept under this root
            more_found_in_full_result = True
    assert more_found_in_full_result is True
518+
354519

355520
def test_get_summary_from_keys(ds: AssemblylineDatastore):
356521
# Get a random submission

0 commit comments

Comments
 (0)