Skip to content

Commit 15d5344

Browse files
authored
Graph search can now select target vertices based on root_type (previously only leaf types) (#1065)
1 parent f0df5c4 commit 15d5344

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

core/database_arango.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,8 @@ def neighbors(
546546
count: int = 0,
547547
) -> tuple[
548548
dict[
549-
str, "observable.Observable | entity.Entity | indicator.Indicator | tag.Tag"
549+
str,
550+
"observable.ObservableTypes | entity.EntityTypes | indicator.IndicatorTypes | tag.Tag",
550551
],
551552
List[List["Relationship | TagRelationship"]],
552553
int,
@@ -582,7 +583,9 @@ def neighbors(
582583
query_filter = "FILTER e.type IN @link_types"
583584
if target_types:
584585
args["target_types"] = target_types
585-
query_filter = "FILTER v.type IN @target_types"
586+
query_filter = (
587+
"FILTER (v.type IN @target_types OR v.root_type IN @target_types)"
588+
)
586589

587590
limit = ""
588591
if count != 0:

tests/apiv2/graph.py

+45
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,51 @@ def test_neighbors_strongly_typed(self):
215215
self.assertEqual(neighbor["query_type"], "opensearch")
216216
self.assertEqual(neighbor["target_systems"], ["system1"])
217217

218+
def test_neighbors_target_types(self):
219+
self.entity1.link_to(self.observable1, "uses", "asd")
220+
self.entity1.link_to(self.observable2, "uses", "asd")
221+
response = client.post(
222+
"/api/v2/graph/search",
223+
json={
224+
"source": self.entity1.extended_id,
225+
"hops": 1,
226+
"graph": "links",
227+
"direction": "any",
228+
"target_types": ["hostname"],
229+
"include_original": False,
230+
},
231+
)
232+
data = response.json()
233+
self.assertEqual(response.status_code, 200, data)
234+
self.assertEqual(len(data["vertices"]), 1)
235+
self.assertEqual(
236+
data["vertices"][self.observable1.extended_id]["value"], "tomchop.me"
237+
)
238+
239+
def test_neighbors_target_types_root_type(self):
240+
self.entity1.link_to(self.observable1, "uses", "asd")
241+
self.entity1.link_to(self.observable2, "uses", "asd")
242+
response = client.post(
243+
"/api/v2/graph/search",
244+
json={
245+
"source": self.entity1.extended_id,
246+
"hops": 1,
247+
"graph": "links",
248+
"direction": "any",
249+
"target_types": ["observable"],
250+
"include_original": False,
251+
},
252+
)
253+
data = response.json()
254+
self.assertEqual(response.status_code, 200, data)
255+
self.assertEqual(len(data["vertices"]), 2)
256+
self.assertEqual(
257+
data["vertices"][self.observable1.extended_id]["value"], "tomchop.me"
258+
)
259+
self.assertEqual(
260+
data["vertices"][self.observable2.extended_id]["value"], "127.0.0.1"
261+
)
262+
218263
def test_add_link(self):
219264
response = client.post(
220265
"/api/v2/graph/add",

0 commit comments

Comments
 (0)