Skip to content

Commit 98e2904

Browse files
authored
Fix: calling save() doesn't overwrite all the object's contents (fixes #947) (#954)
1 parent 5bd8607 commit 98e2904

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

core/database_arango.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,19 @@ def _update(self, document_json):
204204
newdoc["id"] = newdoc.pop("_key")
205205
return newdoc
206206

207-
def save(self: TYetiObject) -> TYetiObject:
207+
def save(
208+
self: TYetiObject,
209+
exclude_overwrite: list[str] = ['created', 'tags', 'context']
210+
) -> TYetiObject:
208211
"""Inserts or updates a Yeti object into the database.
209212
210213
We need to pass the JSON representation of the object to the database
211214
because it may contain fields that are not JSON serializable by arango.
212215
216+
Args:
217+
exclude_overwrite: Exclude overwriting these fields if observable
218+
already exists in the database.
219+
213220
Returns:
214221
The created Yeti object.
215222
"""
@@ -219,8 +226,12 @@ def save(self: TYetiObject) -> TYetiObject:
219226
else:
220227
result = self._insert(self.model_dump_json())
221228
if not result:
222-
result = self._update(self.model_dump_json(exclude={"created"}))
223-
return self.__class__(**result)
229+
result = self._update(self.model_dump_json(exclude=exclude_overwrite))
230+
yeti_object = self.__class__(**result)
231+
#TODO: Override this if we decide to implement YetiTagModel
232+
if hasattr(self, 'tags'):
233+
yeti_object.get_tags()
234+
return yeti_object
224235

225236
@classmethod
226237
def list(cls: Type[TYetiObject]) -> Iterable[TYetiObject]:

tests/schemas/observable.py

+23
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,29 @@ def test_observable_create(self) -> None:
2323
self.assertIsNotNone(result.id)
2424
self.assertEqual(result.value, "toto.com")
2525

26+
def test_observable_update(self) -> None:
27+
"""Tests that calling save() on an observable treats it as PATCH."""
28+
result = registry_key.RegistryKey(
29+
key="Microsoft\\Windows\\CurrentVersion\\Run",
30+
value="persist",
31+
data=b"cmd.exe",
32+
hive=registry_key.RegistryHive.HKEY_LOCAL_MACHINE_Software).save()
33+
result.tag(['tag1'])
34+
result.add_context(source='source1', context={'some': 'info'})
35+
self.assertEqual(list(result.tags.keys()), ['tag1'])
36+
self.assertEqual(
37+
result.context[0], {'source': 'source1', 'some': 'info'})
38+
result = registry_key.RegistryKey(
39+
key="Microsoft\\Windows\\CurrentVersion\\RunOnce",
40+
value="persist",
41+
data=b"other.exe",
42+
hive=registry_key.RegistryHive.HKEY_LOCAL_MACHINE_Software).save()
43+
self.assertEqual(result.key, "Microsoft\\Windows\\CurrentVersion\\RunOnce")
44+
self.assertEqual(result.data, b"other.exe")
45+
self.assertEqual(list(result.tags.keys()), ['tag1'])
46+
self.assertEqual(
47+
result.context[0], {'source': 'source1', 'some': 'info'})
48+
2649
def test_create_generic_observable(self):
2750
result = generic_observable.GenericObservable(value="Some_String").save()
2851
self.assertIsNotNone(result.id)

0 commit comments

Comments
 (0)