Skip to content

Commit 8badbf8

Browse files
gqvzmlodicCopilot
authored
Speed up unit tests. Closes #2958 (#3289)
* Optimize tests * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/test_crons.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Conditionally apply network related patches --------- Co-authored-by: Matteo Lodi <30625432+mlodic@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 00e92a9 commit 8badbf8

File tree

6 files changed

+134
-32
lines changed

6 files changed

+134
-32
lines changed

tests/api_app/analyzers_manager/test_views.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# This file is a part of IntelOwl https://github.com/intelowlproject/IntelOwl
22
# See the file 'LICENSE' for copying permission.
3+
from contextlib import nullcontext
34
from typing import Type
45
from unittest.mock import patch
56

7+
from django.conf import settings
8+
69
from api_app.analyzables_manager.models import Analyzable
710
from api_app.analyzers_manager.models import AnalyzerConfig, AnalyzerReport
811
from api_app.choices import Classification, PythonModuleBasePaths
@@ -29,17 +32,24 @@ def test_pull(self):
2932
from api_app.analyzers_manager.file_analyzers.yara_scan import YaraScan
3033

3134
analyzer = "Yara"
32-
response = self.client.post(f"{self.URL}/{analyzer}/pull")
33-
self.assertEqual(response.status_code, 200)
35+
ctx = (
36+
patch.object(YaraScan, "update", return_value=True)
37+
if settings.MOCK_CONNECTIONS
38+
else nullcontext()
39+
)
40+
with ctx as mock_update:
41+
response = self.client.post(f"{self.URL}/{analyzer}/pull")
42+
self.assertEqual(response.status_code, 200)
3443

35-
self.client.force_authenticate(self.superuser)
44+
self.client.force_authenticate(self.superuser)
3645

37-
with patch.object(YaraScan, "update", return_value=True):
3846
response = self.client.post(f"{self.URL}/{analyzer}/pull")
39-
self.assertEqual(response.status_code, 200, response.json())
40-
result = response.json()
41-
self.assertIn("status", result)
42-
self.assertTrue(result["status"])
47+
self.assertEqual(response.status_code, 200, response.json())
48+
result = response.json()
49+
self.assertIn("status", result)
50+
self.assertTrue(result["status"])
51+
if mock_update is not None:
52+
self.assertEqual(mock_update.call_count, 2)
4353

4454
analyzer = "Doc_Info"
4555
response = self.client.post(f"{self.URL}/{analyzer}/pull")

tests/api_app/analyzers_manager/unit_tests/file_analyzers/test_capa_info.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import subprocess
22
from unittest.mock import MagicMock, patch
33

4+
from django.conf import settings
5+
46
from api_app.analyzers_manager.file_analyzers.capa_info import CapaInfo
57

68
from .base_test_class import BaseFileAnalyzerTest
@@ -29,7 +31,7 @@ def get_mocked_response(self):
2931
mock_requests_get = MagicMock()
3032
mock_requests_get.json.return_value = {"tag_name": "v1.0.0"}
3133

32-
return [
34+
patches = [
3335
patch.object(CapaInfo, "update", return_value=True),
3436
patch("subprocess.run", return_value=response_from_command),
3537
patch(
@@ -39,6 +41,10 @@ def get_mocked_response(self):
3941
patch.object(CapaInfo, "_check_if_latest_version", return_value=True),
4042
]
4143

44+
if settings.MOCK_CONNECTIONS:
45+
patches.insert(1, patch.object(CapaInfo, "_download_signatures", return_value=None))
46+
return patches
47+
4248
def get_extra_config(self):
4349
return {
4450
"shellcode": False,

tests/api_app/analyzers_manager/unit_tests/file_analyzers/test_virushee.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import patch
22

33
import requests
4+
from django.conf import settings
45

56
from api_app.analyzers_manager.file_analyzers.virushee import VirusheeFileUpload
67

@@ -18,7 +19,7 @@ def get_extra_config(self):
1819
}
1920

2021
def get_mocked_response(self):
21-
return [
22+
patches = [
2223
patch(
2324
"requests.Session.get",
2425
side_effect=[
@@ -36,6 +37,10 @@ def get_mocked_response(self):
3637
),
3738
]
3839

40+
if settings.MOCK_CONNECTIONS:
41+
patches.append(patch("time.sleep", return_value=None))
42+
return patches
43+
3944
class MockUpResponse:
4045
"""Simple mock response class to simulate requests.Response"""
4146

tests/api_app/analyzers_manager/unit_tests/observable_analyzers/test_pulsedive.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from unittest.mock import patch
22

3+
from django.conf import settings
4+
35
from api_app.analyzers_manager.observable_analyzers.pulsedive import Pulsedive
46
from tests.api_app.analyzers_manager.unit_tests.observable_analyzers.base_test_class import (
57
BaseAnalyzerTest,
@@ -12,7 +14,7 @@ class PulsediveTestCase(BaseAnalyzerTest):
1214

1315
@staticmethod
1416
def get_mocked_response():
15-
return [
17+
patches = [
1618
patch(
1719
"requests.get",
1820
side_effect=[
@@ -25,6 +27,10 @@ def get_mocked_response():
2527
patch("requests.post", return_value=MockUpResponse({"qid": 1}, 200)),
2628
]
2729

30+
if settings.MOCK_CONNECTIONS:
31+
patches.append(patch("time.sleep", return_value=None))
32+
return patches
33+
2834
@classmethod
2935
def get_extra_config(cls) -> dict:
3036
return {"scan_mode": "active", "_api_key_name": "test_api_key", "probe": 1}

tests/api_app/test_api.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import datetime
55
import hashlib
66
import os
7+
from contextlib import nullcontext
78
from typing import Tuple
9+
from unittest.mock import patch
810

911
from django.conf import settings
1012
from django.contrib.auth import get_user_model
@@ -16,6 +18,7 @@
1618
from api_app.choices import Classification
1719
from api_app.connectors_manager.models import ConnectorConfig
1820
from api_app.playbooks_manager.models import PlaybookConfig
21+
from tests.mock_utils import if_mock_connections
1922

2023
from .. import CustomViewSetTestCase
2124

@@ -101,7 +104,8 @@ def test_ask_analysis_availability__run_all_analyzers(self):
101104
response = self.client.post("/api/ask_analysis_availability", data, format="json")
102105
self.assertEqual(response.status_code, 200)
103106

104-
def test_analyze_file__pcap(self):
107+
@if_mock_connections(patch("intel_owl.tasks.job_pipeline.apply_async"))
108+
def test_analyze_file__pcap(self, mock_apply_async=None):
105109
# set a fake API key or YARAify_File_Scan will be skipped as not configured
106110
models.PluginConfig.objects.create(
107111
owner=self.user,
@@ -144,6 +148,10 @@ def test_analyze_file__pcap(self):
144148
list(job.analyzers_to_execute.all().values_list("name", flat=True)),
145149
)
146150

151+
if mock_apply_async is not None:
152+
mock_apply_async.assert_called_once()
153+
self.assertEqual(mock_apply_async.call_args.kwargs["args"], [job_id])
154+
147155
def test_analyze_file__exe(self):
148156
data = self.analyze_file_data.copy()
149157
response = self.client.post("/api/analyze_file", data, format="multipart")
@@ -232,7 +240,8 @@ def test_analyze_observable__ip(self):
232240
self.assertEqual(data["observable_classification"], job.analyzable.classification, msg=msg)
233241
self.assertEqual(self.observable_md5, job.analyzable.md5, msg=msg)
234242

235-
def test_analyze_observable__guess_optional(self):
243+
@if_mock_connections(patch("intel_owl.tasks.job_pipeline.apply_async"))
244+
def test_analyze_observable__guess_optional(self, mock_apply_async=None):
236245
data = self.analyze_observable_ip_data.copy()
237246
observable_classification = data.pop("observable_classification") # let the server calc it
238247

@@ -252,16 +261,22 @@ def test_analyze_observable__guess_optional(self):
252261
self.assertEqual(observable_classification, job.analyzable.classification, msg=msg)
253262
self.assertEqual(self.observable_md5, job.analyzable.md5, msg=msg)
254263

255-
def test_analyze_multiple_observables(self):
264+
if mock_apply_async is not None:
265+
mock_apply_async.assert_called_once()
266+
self.assertEqual(mock_apply_async.call_args.kwargs["args"], [job_id])
267+
268+
@if_mock_connections(patch("intel_owl.tasks.job_pipeline.apply_async"))
269+
def test_analyze_multiple_observables(self, mock_apply_async=None):
256270
data = self.mixed_observable_data.copy()
257271

258272
response = self.client.post("/api/analyze_multiple_observables", data, format="json")
259273
contents = response.json()
260274
msg = (response.status_code, contents)
261275
self.assertEqual(response.status_code, 200, msg=msg)
276+
if mock_apply_async is not None:
277+
self.assertEqual(mock_apply_async.call_count, len(data["observables"]))
262278

263279
content = contents["results"][0]
264-
265280
job_id = int(content["job_id"])
266281
job = models.Job.objects.get(pk=job_id)
267282
self.assertEqual(data["observables"][0][1], job.analyzable.name, msg=msg)
@@ -275,9 +290,10 @@ def test_analyze_multiple_observables(self):
275290
list(job.analyzers_to_execute.all().values_list("name", flat=True)),
276291
msg=msg,
277292
)
293+
if mock_apply_async is not None:
294+
self.assertEqual(mock_apply_async.call_args_list[0].kwargs["args"], [job_id])
278295

279296
content = contents["results"][1]
280-
281297
job_id = int(content["job_id"])
282298
job = models.Job.objects.get(pk=job_id)
283299
self.assertEqual(data["observables"][1][1], job.analyzable.name, msg=msg)
@@ -286,6 +302,8 @@ def test_analyze_multiple_observables(self):
286302
list(job.analyzers_to_execute.all().values_list("name", flat=True)),
287303
msg=msg,
288304
)
305+
if mock_apply_async is not None:
306+
self.assertEqual(mock_apply_async.call_args_list[1].kwargs["args"], [job_id])
289307
job.delete()
290308

291309
def test_observable_no_analyzers_only_connector(self):
@@ -527,7 +545,12 @@ def test_job_rescan__observable_playbook(self):
527545
"visualizers": {},
528546
},
529547
)
530-
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
548+
# dont actually run the analyzers when mocking connections, they are tested in unit tests
549+
ctx = (
550+
patch("intel_owl.tasks.job_pipeline.apply_async") if settings.MOCK_CONNECTIONS else nullcontext()
551+
)
552+
with ctx as mock_apply_async:
553+
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
531554
contents = response.json()
532555
self.assertEqual(response.status_code, 202, contents)
533556
new_job_id = int(contents["id"])
@@ -543,6 +566,9 @@ def test_job_rescan__observable_playbook(self):
543566
"visualizers": {},
544567
},
545568
)
569+
if mock_apply_async is not None:
570+
mock_apply_async.assert_called_once()
571+
self.assertEqual(mock_apply_async.call_args.kwargs["args"], [new_job_id])
546572
an.delete()
547573

548574
def test_job_rescan__sample_analyzers(self):
@@ -567,7 +593,11 @@ def test_job_rescan__sample_analyzers(self):
567593
)
568594
job.analyzers_requested.set([AnalyzerConfig.objects.get(name="Strings_Info")])
569595

570-
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
596+
ctx = (
597+
patch("intel_owl.tasks.job_pipeline.apply_async") if settings.MOCK_CONNECTIONS else nullcontext()
598+
)
599+
with ctx as mock_apply_async:
600+
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
571601
contents = response.json()
572602
self.assertEqual(response.status_code, 202, contents)
573603
new_job_id = int(contents["id"])
@@ -579,6 +609,9 @@ def test_job_rescan__sample_analyzers(self):
579609
list(new_job.analyzers_requested.all()),
580610
[AnalyzerConfig.objects.get(name="Strings_Info")],
581611
)
612+
if mock_apply_async is not None:
613+
mock_apply_async.assert_called_once()
614+
self.assertEqual(mock_apply_async.call_args.kwargs["args"], [new_job_id])
582615
self.assertEqual(
583616
new_job.runtime_configuration,
584617
{
@@ -617,7 +650,11 @@ def test_job_rescan__sample_playbook(self):
617650
},
618651
)
619652

620-
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
653+
ctx = (
654+
patch("intel_owl.tasks.job_pipeline.apply_async") if settings.MOCK_CONNECTIONS else nullcontext()
655+
)
656+
with ctx as mock_apply_async:
657+
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
621658
contents = response.json()
622659
self.assertEqual(response.status_code, 202, contents)
623660
new_job_id = int(contents["id"])
@@ -642,6 +679,9 @@ def test_job_rescan__sample_playbook(self):
642679
"visualizers": {},
643680
},
644681
)
682+
if mock_apply_async is not None:
683+
mock_apply_async.assert_called_once()
684+
self.assertEqual(mock_apply_async.call_args.kwargs["args"], [new_job_id])
645685
job.delete()
646686
an.delete()
647687

@@ -665,14 +705,24 @@ def test_job_rescan__permission(self):
665705
"visualizers": {},
666706
},
667707
)
668-
# same user
669-
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
670-
contents = response.json()
671-
self.assertEqual(response.status_code, 202, contents)
672-
# another user
673-
self.client.logout()
674-
self.client.force_login(self.guest)
675-
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
676-
contents = response.json()
677-
self.assertEqual(response.status_code, 403, contents)
708+
ctx = (
709+
patch("intel_owl.tasks.job_pipeline.apply_async") if settings.MOCK_CONNECTIONS else nullcontext()
710+
)
711+
with ctx as mock_apply_async:
712+
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
713+
contents = response.json()
714+
self.assertEqual(response.status_code, 202, contents)
715+
new_job_id = int(contents["id"])
716+
if mock_apply_async is not None:
717+
mock_apply_async.assert_called_once()
718+
self.assertEqual(mock_apply_async.call_args.kwargs["args"], [new_job_id])
719+
mock_apply_async.reset_mock()
720+
721+
self.client.logout()
722+
self.client.force_login(self.guest)
723+
response = self.client.post(f"/api/jobs/{job.pk}/rescan", format="json")
724+
contents = response.json()
725+
self.assertEqual(response.status_code, 403, contents)
726+
if mock_apply_async is not None:
727+
mock_apply_async.assert_not_called()
678728
an.delete()

tests/test_crons.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,34 @@ def test_quark_updater(self):
242242
quark_engine.QuarkEngine.update()
243243
self.assertTrue(os.path.exists(DIR_PATH))
244244

245-
def test_yara_updater(self):
246-
yara_scan.YaraScan.update()
247-
self.assertTrue(len(os.listdir(settings.YARA_RULES_PATH)))
245+
@if_mock_connections(
246+
patch("git.Repo"),
247+
patch("requests.get", return_value=MockUpResponse({}, 200)),
248+
patch("zipfile.ZipFile"),
249+
)
250+
def test_yara_updater(self, mock_zipfile=None, mock_get=None, mock_repo=None):
251+
if mock_zipfile is None or mock_get is None or mock_repo is None:
252+
yara_scan.YaraScan.update()
253+
self.assertTrue(os.path.isdir(settings.YARA_RULES_PATH))
254+
else:
255+
256+
def create_yara_file(path):
257+
os.makedirs(path, exist_ok=True)
258+
yara_file = os.path.join(path, "test_rule.yar")
259+
with open(yara_file, "w") as f:
260+
f.write(
261+
"rule TestRule {\n"
262+
" strings:\n"
263+
' $test = "test"\n'
264+
" condition:\n"
265+
" $test\n"
266+
"}\n"
267+
)
268+
269+
mock_repo.clone_from.side_effect = lambda url, path, **kwargs: create_yara_file(path)
270+
mock_zipfile.return_value.extractall.side_effect = create_yara_file
271+
result = yara_scan.YaraScan.update()
272+
self.assertTrue(result)
248273

249274
@if_mock_connections(
250275
patch(

0 commit comments

Comments
 (0)