Skip to content

Commit 5bd8607

Browse files
udgovertomchop
andauthored
Extended observables api (#949)
Co-authored-by: Thomas Chopitea <[email protected]>
1 parent 699d79d commit 5bd8607

File tree

3 files changed

+170
-46
lines changed

3 files changed

+170
-46
lines changed

core/schemas/observable.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datetime
44
import re
55
from enum import Enum
6-
from typing import ClassVar, Literal
6+
from typing import ClassVar, Literal, Type
77

88
import validators
99
from core import database_arango
@@ -53,16 +53,19 @@ class Observable(BaseModel, database_arango.ArangoYetiConnector):
5353
context: list[dict] = []
5454
last_analysis: dict[str, datetime.datetime] = {}
5555

56+
5657
@classmethod
57-
def load(cls, object: dict) -> "Observable":
58-
return cls(**object)
58+
def load(cls, object: dict) -> "ObservableTypes":
59+
if object["type"] in TYPE_MAPPING:
60+
return TYPE_MAPPING[object["type"]](**object)
61+
raise ValueError("Attempted to instantiate an undefined observable type.")
5962

6063
@classmethod
6164
def is_valid(cls, object: dict) -> bool:
6265
return validate_observable(object)
6366

6467
@classmethod
65-
def add_text(cls, text: str, tags: list[str] = []) -> "Observable":
68+
def add_text(cls, text: str, tags: list[str] = []) -> "ObservableTypes":
6669
"""Adds and returns an observable for a given string.
6770
6871
Args:
@@ -79,9 +82,8 @@ def add_text(cls, text: str, tags: list[str] = []) -> "Observable":
7982

8083
observable = Observable.find(value=refanged)
8184
if not observable:
82-
observable = Observable(
85+
observable = TYPE_MAPPING[observable_type](
8386
value=refanged,
84-
type=observable_type,
8587
created=datetime.datetime.now(datetime.timezone.utc),
8688
).save()
8789
if tags:
@@ -90,7 +92,7 @@ def add_text(cls, text: str, tags: list[str] = []) -> "Observable":
9092

9193
def add_context(
9294
self, source: str, context: dict, skip_compare: set = set()
93-
) -> "Observable":
95+
) -> "ObservableTypes":
9496
"""Adds context to an observable."""
9597
compare_fields = set(context.keys()) - skip_compare - {"source"}
9698
for idx, db_context in enumerate(list(self.context)):
@@ -111,7 +113,7 @@ def add_context(
111113

112114
def delete_context(
113115
self, source: str, context: dict, skip_compare: set = set()
114-
) -> "Observable":
116+
) -> "ObservableTypes":
115117
"""Deletes context from an observable."""
116118
compare_fields = set(context.keys()) - skip_compare - {"source"}
117119
for idx, db_context in enumerate(list(self.context)):
@@ -140,9 +142,9 @@ def delete_context(
140142

141143
REGEXES_OBSERVABLES = {
142144
# Unix
143-
ObservableType.path : [
145+
ObservableType.path: [
144146
re.compile(r"^(\/[^\/\0]+)+$"),
145-
re.compile(r"^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+")
147+
re.compile(r"^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+"),
146148
]
147149
}
148150

@@ -170,16 +172,14 @@ def find_type(value: str) -> ObservableType | None:
170172
return None
171173

172174

173-
TYPE_MAPPING = {
174-
'observable': Observable,
175-
'observables': Observable
176-
}
175+
TYPE_MAPPING = {"observable": Observable, "observables": Observable}
176+
177177

178178
# Import all observable types, as these register themselves in the TYPE_MAPPING
179179
# disable: pylint=wrong-import-position
180180
from core.schemas.observables import (asn, bitcoin_wallet, certificate, cidr,
181181
command_line, docker_image, email, file,
182-
generic_observable, hostname, imphash, ipv4, ipv6,
183-
mac_address, md5, path, registry_key,
184-
sha1, sha256, ssdeep, tlsh, url,
185-
user_agent)
182+
generic_observable, hostname, imphash,
183+
ipv4, ipv6, mac_address, md5, path,
184+
registry_key, sha1, sha256, ssdeep, tlsh,
185+
url, user_agent)

core/web/apiv2/observables.py

+58-23
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
import datetime
22
from typing import Iterable
33

4+
from core.schemas import graph
5+
from core.schemas.observable import TYPE_MAPPING, Observable, ObservableType
46
from fastapi import APIRouter, HTTPException
57
from pydantic import BaseModel, ConfigDict, field_validator
68

7-
from core.schemas import graph
8-
from core.schemas.observable import Observable, ObservableType
9+
ObservableTypes = ()
910

11+
for key in TYPE_MAPPING:
12+
if key in ["observable", "observables"]:
13+
continue
14+
cls = TYPE_MAPPING[key]
15+
if not ObservableTypes:
16+
ObservableTypes = cls
17+
else:
18+
ObservableTypes |= cls
1019

11-
# Request schemas
12-
class NewObservableRequest(BaseModel):
13-
model_config = ConfigDict(extra='forbid')
1420

15-
value: str
21+
class TagRequestMixin(BaseModel):
22+
1623
tags: list[str] = []
17-
type: ObservableType
1824

1925
@field_validator("tags")
2026
@classmethod
@@ -25,17 +31,30 @@ def validate_tags(cls, value) -> list[str]:
2531
return value
2632

2733

34+
# Request schemas
35+
class NewObservableRequest(TagRequestMixin):
36+
model_config = ConfigDict(extra='forbid')
37+
38+
value: str
39+
type: ObservableType
40+
41+
42+
class NewExtendedObservableRequest(TagRequestMixin):
43+
model_config = ConfigDict(extra='forbid')
44+
45+
observable: ObservableTypes
46+
47+
2848
class NewBulkObservableAddRequest(BaseModel):
2949
model_config = ConfigDict(extra='forbid')
3050

3151
observables: list[NewObservableRequest]
3252

3353

34-
class AddTextRequest(BaseModel):
54+
class AddTextRequest(TagRequestMixin):
3555
model_config = ConfigDict(extra='forbid')
3656

3757
text: str
38-
tags: list[str] = []
3958

4059

4160
class AddContextRequest(BaseModel):
@@ -62,15 +81,14 @@ class ObservableSearchRequest(BaseModel):
6281
class ObservableSearchResponse(BaseModel):
6382
model_config = ConfigDict(extra='forbid')
6483

65-
observables: list[Observable]
84+
observables: list[ObservableTypes]
6685
total: int
6786

6887

69-
class ObservableTagRequest(BaseModel):
88+
class ObservableTagRequest(TagRequestMixin):
7089
model_config = ConfigDict(extra='forbid')
7190

7291
ids: list[str]
73-
tags: list[str]
7492
strict: bool = False
7593

7694
class ObservableTagResponse(BaseModel):
@@ -90,7 +108,7 @@ async def observables_root() -> Iterable[Observable]:
90108

91109

92110
@router.post("/")
93-
async def new(request: NewObservableRequest) -> Observable:
111+
async def new(request: NewObservableRequest) -> ObservableTypes:
94112
"""Creates a new observable in the database.
95113
96114
Raises:
@@ -113,8 +131,28 @@ async def new(request: NewObservableRequest) -> Observable:
113131
return new
114132

115133

134+
@router.post("/extended")
135+
async def new(request: NewExtendedObservableRequest) -> ObservableTypes:
136+
"""Creates a new observable in the database with extended properties.
137+
138+
Raises:
139+
HTTPException(400) if observable already exists.
140+
"""
141+
observable = Observable.find(value=request.observable.value, type=request.observable.type)
142+
if observable:
143+
raise HTTPException(
144+
status_code=400,
145+
detail=f"Observable with value {request.observable.value} already exists",
146+
)
147+
cls = TYPE_MAPPING[request.observable.type]
148+
new = cls(**request.observable.model_dump()).save()
149+
if request.tags:
150+
new.tag(request.tags)
151+
return new
152+
153+
116154
@router.post("/bulk")
117-
async def bulk_add(request: NewBulkObservableAddRequest) -> list[Observable]:
155+
async def bulk_add(request: NewBulkObservableAddRequest) -> list[ObservableTypes]:
118156
"""Bulk-creates new observables in the database."""
119157
added = []
120158
for new_observable in request.observables:
@@ -123,19 +161,16 @@ async def bulk_add(request: NewBulkObservableAddRequest) -> list[Observable]:
123161
new_observable.value, tags=new_observable.tags
124162
)
125163
else:
126-
observable = Observable(
127-
value=new_observable.value,
128-
type=new_observable.type,
129-
created=datetime.datetime.now(datetime.timezone.utc),
130-
).save()
164+
cls = TYPE_MAPPING[new_observable.type]
165+
observable = cls(value=new_observable.value).save()
131166
if new_observable.tags:
132167
observable = observable.tag(new_observable.tags)
133168
added.append(observable)
134169
return added
135170

136171

137172
@router.get("/{observable_id}")
138-
async def details(observable_id) -> Observable:
173+
async def details(observable_id) -> ObservableTypes:
139174
"""Returns details about an observable."""
140175
observable = Observable.get(observable_id)
141176
if not observable:
@@ -145,7 +180,7 @@ async def details(observable_id) -> Observable:
145180

146181

147182
@router.post("/{observable_id}/context")
148-
async def add_context(observable_id, request: AddContextRequest) -> Observable:
183+
async def add_context(observable_id, request: AddContextRequest) -> ObservableTypes:
149184
"""Adds context to an observable."""
150185
observable = Observable.get(observable_id)
151186
if not observable:
@@ -160,7 +195,7 @@ async def add_context(observable_id, request: AddContextRequest) -> Observable:
160195

161196

162197
@router.post("/{observable_id}/context/delete")
163-
async def delete_context(observable_id, request: DeleteContextRequest) -> Observable:
198+
async def delete_context(observable_id, request: DeleteContextRequest) -> ObservableTypes:
164199
"""Removes context to an observable."""
165200
observable = Observable.get(observable_id)
166201
if not observable:
@@ -193,7 +228,7 @@ async def search(request: ObservableSearchRequest) -> ObservableSearchResponse:
193228

194229

195230
@router.post("/add_text")
196-
async def add_text(request: AddTextRequest) -> Observable:
231+
async def add_text(request: AddTextRequest) -> ObservableTypes:
197232
"""Adds and returns an observable for a given string, attempting to guess
198233
its type."""
199234
try:

0 commit comments

Comments
 (0)