1
1
import datetime
2
2
from typing import Iterable
3
3
4
+ from core .schemas import graph
5
+ from core .schemas .observable import TYPE_MAPPING , Observable , ObservableType
4
6
from fastapi import APIRouter , HTTPException
5
7
from pydantic import BaseModel , ConfigDict , field_validator
6
8
7
- from core .schemas import graph
8
- from core .schemas .observable import Observable , ObservableType
9
+ ObservableTypes = ()
9
10
11
+ for key in TYPE_MAPPING :
12
+ if key in ["observable" , "observables" ]:
13
+ continue
14
+ cls = TYPE_MAPPING [key ]
15
+ if not ObservableTypes :
16
+ ObservableTypes = cls
17
+ else :
18
+ ObservableTypes |= cls
10
19
11
- # Request schemas
12
- class NewObservableRequest (BaseModel ):
13
- model_config = ConfigDict (extra = 'forbid' )
14
20
15
- value : str
21
+ class TagRequestMixin (BaseModel ):
22
+
16
23
tags : list [str ] = []
17
- type : ObservableType
18
24
19
25
@field_validator ("tags" )
20
26
@classmethod
@@ -25,17 +31,30 @@ def validate_tags(cls, value) -> list[str]:
25
31
return value
26
32
27
33
34
+ # Request schemas
35
+ class NewObservableRequest (TagRequestMixin ):
36
+ model_config = ConfigDict (extra = 'forbid' )
37
+
38
+ value : str
39
+ type : ObservableType
40
+
41
+
42
+ class NewExtendedObservableRequest (TagRequestMixin ):
43
+ model_config = ConfigDict (extra = 'forbid' )
44
+
45
+ observable : ObservableTypes
46
+
47
+
28
48
class NewBulkObservableAddRequest (BaseModel ):
29
49
model_config = ConfigDict (extra = 'forbid' )
30
50
31
51
observables : list [NewObservableRequest ]
32
52
33
53
34
- class AddTextRequest (BaseModel ):
54
+ class AddTextRequest (TagRequestMixin ):
35
55
model_config = ConfigDict (extra = 'forbid' )
36
56
37
57
text : str
38
- tags : list [str ] = []
39
58
40
59
41
60
class AddContextRequest (BaseModel ):
@@ -62,15 +81,14 @@ class ObservableSearchRequest(BaseModel):
62
81
class ObservableSearchResponse (BaseModel ):
63
82
model_config = ConfigDict (extra = 'forbid' )
64
83
65
- observables : list [Observable ]
84
+ observables : list [ObservableTypes ]
66
85
total : int
67
86
68
87
69
- class ObservableTagRequest (BaseModel ):
88
+ class ObservableTagRequest (TagRequestMixin ):
70
89
model_config = ConfigDict (extra = 'forbid' )
71
90
72
91
ids : list [str ]
73
- tags : list [str ]
74
92
strict : bool = False
75
93
76
94
class ObservableTagResponse (BaseModel ):
@@ -90,7 +108,7 @@ async def observables_root() -> Iterable[Observable]:
90
108
91
109
92
110
@router .post ("/" )
93
- async def new (request : NewObservableRequest ) -> Observable :
111
+ async def new (request : NewObservableRequest ) -> ObservableTypes :
94
112
"""Creates a new observable in the database.
95
113
96
114
Raises:
@@ -113,8 +131,28 @@ async def new(request: NewObservableRequest) -> Observable:
113
131
return new
114
132
115
133
134
+ @router .post ("/extended" )
135
+ async def new (request : NewExtendedObservableRequest ) -> ObservableTypes :
136
+ """Creates a new observable in the database with extended properties.
137
+
138
+ Raises:
139
+ HTTPException(400) if observable already exists.
140
+ """
141
+ observable = Observable .find (value = request .observable .value , type = request .observable .type )
142
+ if observable :
143
+ raise HTTPException (
144
+ status_code = 400 ,
145
+ detail = f"Observable with value { request .observable .value } already exists" ,
146
+ )
147
+ cls = TYPE_MAPPING [request .observable .type ]
148
+ new = cls (** request .observable .model_dump ()).save ()
149
+ if request .tags :
150
+ new .tag (request .tags )
151
+ return new
152
+
153
+
116
154
@router .post ("/bulk" )
117
- async def bulk_add (request : NewBulkObservableAddRequest ) -> list [Observable ]:
155
+ async def bulk_add (request : NewBulkObservableAddRequest ) -> list [ObservableTypes ]:
118
156
"""Bulk-creates new observables in the database."""
119
157
added = []
120
158
for new_observable in request .observables :
@@ -123,19 +161,16 @@ async def bulk_add(request: NewBulkObservableAddRequest) -> list[Observable]:
123
161
new_observable .value , tags = new_observable .tags
124
162
)
125
163
else :
126
- observable = Observable (
127
- value = new_observable .value ,
128
- type = new_observable .type ,
129
- created = datetime .datetime .now (datetime .timezone .utc ),
130
- ).save ()
164
+ cls = TYPE_MAPPING [new_observable .type ]
165
+ observable = cls (value = new_observable .value ).save ()
131
166
if new_observable .tags :
132
167
observable = observable .tag (new_observable .tags )
133
168
added .append (observable )
134
169
return added
135
170
136
171
137
172
@router .get ("/{observable_id}" )
138
- async def details (observable_id ) -> Observable :
173
+ async def details (observable_id ) -> ObservableTypes :
139
174
"""Returns details about an observable."""
140
175
observable = Observable .get (observable_id )
141
176
if not observable :
@@ -145,7 +180,7 @@ async def details(observable_id) -> Observable:
145
180
146
181
147
182
@router .post ("/{observable_id}/context" )
148
- async def add_context (observable_id , request : AddContextRequest ) -> Observable :
183
+ async def add_context (observable_id , request : AddContextRequest ) -> ObservableTypes :
149
184
"""Adds context to an observable."""
150
185
observable = Observable .get (observable_id )
151
186
if not observable :
@@ -160,7 +195,7 @@ async def add_context(observable_id, request: AddContextRequest) -> Observable:
160
195
161
196
162
197
@router .post ("/{observable_id}/context/delete" )
163
- async def delete_context (observable_id , request : DeleteContextRequest ) -> Observable :
198
+ async def delete_context (observable_id , request : DeleteContextRequest ) -> ObservableTypes :
164
199
"""Removes context to an observable."""
165
200
observable = Observable .get (observable_id )
166
201
if not observable :
@@ -193,7 +228,7 @@ async def search(request: ObservableSearchRequest) -> ObservableSearchResponse:
193
228
194
229
195
230
@router .post ("/add_text" )
196
- async def add_text (request : AddTextRequest ) -> Observable :
231
+ async def add_text (request : AddTextRequest ) -> ObservableTypes :
197
232
"""Adds and returns an observable for a given string, attempting to guess
198
233
its type."""
199
234
try :
0 commit comments