Skip to content
This repository was archived by the owner on Nov 1, 2023. It is now read-only.

Commit a92c84d

Browse files
authored
work around issue with discriminated typed unions (#939)
We're experiencing a bug where Unions of sub-models are getting downcast, which causes a loss of information. As an example, EventScalesetCreated was getting downcast to EventScalesetDeleted. I have not figured out why, nor can I replicate it locally to minimize the bug send upstream, but I was able to reliably replicate it on the service. While working through this issue, I noticed that deserialization of SignalR events was frequently wrong, leaving things like tasks as "init" in `status top`. Both of these issues are Unions of models with a type field, so it's likely these are related.
1 parent 3d191c3 commit a92c84d

File tree

7 files changed

+82
-40
lines changed

7 files changed

+82
-40
lines changed

src/api-service/__app__/onefuzzlib/events.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_events() -> Optional[str]:
2424
for _ in range(5):
2525
try:
2626
event = EVENTS.get(block=False)
27-
events.append(json.loads(event.json(exclude_none=True)))
27+
events.append(json.loads(event))
2828
EVENTS.task_done()
2929
except Empty:
3030
break
@@ -36,13 +36,13 @@ def get_events() -> Optional[str]:
3636

3737

3838
def log_event(event: Event, event_type: EventType) -> None:
39-
scrubbed_event = filter_event(event, event_type)
39+
scrubbed_event = filter_event(event)
4040
logging.info(
4141
"sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True)
4242
)
4343

4444

45-
def filter_event(event: Event, event_type: EventType) -> BaseModel:
45+
def filter_event(event: Event) -> BaseModel:
4646
clone_event = event.copy(deep=True)
4747
filtered_event = filter_event_recurse(clone_event)
4848
return filtered_event
@@ -73,12 +73,18 @@ def filter_event_recurse(entry: BaseModel) -> BaseModel:
7373

7474
def send_event(event: Event) -> None:
7575
event_type = get_event_type(event)
76-
log_event(event, event_type)
76+
7777
event_message = EventMessage(
7878
event_type=event_type,
79-
event=event,
79+
event=event.copy(deep=True),
8080
instance_id=get_instance_id(),
8181
instance_name=get_instance_name(),
8282
)
83-
EVENTS.put(event_message)
83+
84+
# work around odd bug with Event Message creation. See PR 939
85+
if event_message.event != event:
86+
event_message.event = event.copy(deep=True)
87+
88+
EVENTS.put(event_message.json())
8489
Webhook.send_event(event_message)
90+
log_event(event, event_type)

src/api-service/__app__/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ azure-storage-queue==12.1.6
2323
jinja2~=2.11.3
2424
msrestazure~=0.6.3
2525
opencensus-ext-azure~=1.0.2
26-
pydantic~=1.8.1 --no-binary=pydantic
26+
pydantic==1.8.2 --no-binary=pydantic
2727
PyJWT~=1.7.1
2828
requests~=2.25.1
2929
memoization~=0.3.1

src/api-service/tests/test_userinfo_filter.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from uuid import uuid4
99

1010
from onefuzztypes.enums import ContainerType, TaskType
11-
from onefuzztypes.events import EventTaskCreated, get_event_type
11+
from onefuzztypes.events import EventTaskCreated
1212
from onefuzztypes.models import (
1313
TaskConfig,
1414
TaskContainers,
@@ -65,9 +65,7 @@ def test_user_info_filter(self) -> None:
6565
user_info=None,
6666
)
6767

68-
test_event_type = get_event_type(test_event)
69-
70-
scrubbed_test_event = filter_event(test_event, test_event_type)
68+
scrubbed_test_event = filter_event(test_event)
7169

7270
self.assertEqual(scrubbed_test_event, control_test_event)
7371

src/cli/onefuzz/status/cache.py

+44-22
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright (c) Microsoft Corporation.
44
# Licensed under the MIT License.
55

6+
import json
67
import logging
78
from datetime import datetime
89
from enum import Enum
@@ -15,7 +16,6 @@
1516
EventFileAdded,
1617
EventJobCreated,
1718
EventJobStopped,
18-
EventMessage,
1919
EventNodeCreated,
2020
EventNodeDeleted,
2121
EventNodeStateUpdated,
@@ -26,6 +26,7 @@
2626
EventTaskStateUpdated,
2727
EventTaskStopped,
2828
EventType,
29+
parse_event_message,
2930
)
3031
from onefuzztypes.models import (
3132
Job,
@@ -152,31 +153,43 @@ def add_container(self, name: Container) -> None:
152153

153154
self.add_files_set(name, set(files.files))
154155

155-
def add_message(self, message: EventMessage) -> None:
156-
events = {
157-
EventPoolCreated: lambda x: self.pool_created(x),
158-
EventPoolDeleted: lambda x: self.pool_deleted(x),
159-
EventTaskCreated: lambda x: self.task_created(x),
160-
EventTaskStopped: lambda x: self.task_stopped(x),
161-
EventTaskFailed: lambda x: self.task_stopped(x),
162-
EventTaskStateUpdated: lambda x: self.task_state_updated(x),
163-
EventJobCreated: lambda x: self.job_created(x),
164-
EventJobStopped: lambda x: self.job_stopped(x),
165-
EventNodeStateUpdated: lambda x: self.node_state_updated(x),
166-
EventNodeCreated: lambda x: self.node_created(x),
167-
EventNodeDeleted: lambda x: self.node_deleted(x),
168-
EventCrashReported: lambda x: self.file_added(x),
169-
EventFileAdded: lambda x: self.file_added(x),
170-
}
171-
172-
for event_cls in events:
173-
if isinstance(message.event, event_cls):
174-
events[event_cls](message.event)
156+
def add_message(self, message_obj: Any) -> None:
157+
message = parse_event_message(message_obj)
158+
159+
event = message.event
160+
if isinstance(event, EventPoolCreated):
161+
self.pool_created(event)
162+
elif isinstance(event, EventPoolDeleted):
163+
self.pool_deleted(event)
164+
elif isinstance(event, EventTaskCreated):
165+
self.task_created(event)
166+
elif isinstance(event, EventTaskStopped):
167+
self.task_stopped(event)
168+
elif isinstance(event, EventTaskFailed):
169+
self.task_failed(event)
170+
elif isinstance(event, EventTaskStateUpdated):
171+
self.task_state_updated(event)
172+
elif isinstance(event, EventJobCreated):
173+
self.job_created(event)
174+
elif isinstance(event, EventJobStopped):
175+
self.job_stopped(event)
176+
elif isinstance(event, EventNodeStateUpdated):
177+
self.node_state_updated(event)
178+
elif isinstance(event, EventNodeCreated):
179+
self.node_created(event)
180+
elif isinstance(event, EventNodeDeleted):
181+
self.node_deleted(event)
182+
elif isinstance(event, (EventCrashReported, EventFileAdded)):
183+
self.file_added(event)
175184

176185
self.last_update = datetime.now()
177186
messages = [x for x in self.messages][-99:]
178187
messages += [
179-
(datetime.now(), message.event_type, message.event.json(exclude_none=True))
188+
(
189+
datetime.now(),
190+
message.event_type,
191+
json.dumps(message_obj, sort_keys=True),
192+
)
180193
]
181194
self.messages = messages
182195

@@ -301,6 +314,10 @@ def task_stopped(self, event: EventTaskStopped) -> None:
301314
if event.task_id in self.tasks:
302315
del self.tasks[event.task_id]
303316

317+
def task_failed(self, event: EventTaskFailed) -> None:
318+
if event.task_id in self.tasks:
319+
del self.tasks[event.task_id]
320+
304321
def render_tasks(self) -> List:
305322
results = []
306323
for task in self.tasks.values():
@@ -352,6 +369,11 @@ def job_stopped(self, event: EventJobStopped) -> None:
352369
if event.job_id in self.jobs:
353370
del self.jobs[event.job_id]
354371

372+
to_remove = [x.task_id for x in self.tasks.values() if x.job_id == event.job_id]
373+
374+
for task_id in to_remove:
375+
del self.tasks[task_id]
376+
355377
def render_jobs(self) -> List[Tuple]:
356378
results: List[Tuple] = []
357379

src/cli/onefuzz/status/top.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from threading import Thread
1010
from typing import Any, Optional
1111

12-
from onefuzztypes.events import EventMessage
13-
1412
from .cache import JobFilter, TopCache
1513
from .signalr import Stream
1614
from .top_view import render
@@ -50,8 +48,7 @@ def add_container(self, name: str) -> None:
5048

5149
def handler(self, message: Any) -> None:
5250
for event_raw in message:
53-
message = EventMessage.parse_obj(event_raw)
54-
self.cache.add_message(message)
51+
self.cache.add_message(event_raw)
5552

5653
def setup(self) -> Stream:
5754
client = Stream(self.onefuzz, self.logger)

src/pytypes/onefuzztypes/events.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from datetime import datetime
77
from enum import Enum
8-
from typing import List, Optional, Union
8+
from typing import Any, Dict, List, Optional, Union
99
from uuid import UUID, uuid4
1010

1111
from pydantic import BaseModel, Field
@@ -91,7 +91,7 @@ class EventTaskHeartbeat(BaseEvent):
9191
config: TaskConfig
9292

9393

94-
class EventPing(BaseResponse):
94+
class EventPing(BaseEvent, BaseResponse):
9595
ping_id: UUID
9696

9797

@@ -300,3 +300,22 @@ class EventMessage(BaseEvent):
300300
event: Event
301301
instance_id: UUID
302302
instance_name: str
303+
304+
305+
# because Pydantic does not yet have discriminated union types yet, parse events
306+
# by hand. https://github.com/samuelcolvin/pydantic/issues/619
307+
def parse_event_message(data: Dict[str, Any]) -> EventMessage:
308+
instance_id = UUID(data["instance_id"])
309+
instance_name = data["instance_name"]
310+
event_id = UUID(data["event_id"])
311+
event_type = EventType[data["event_type"]]
312+
# mypy incorrectly identifies this as having not supported parse_obj yet
313+
event = EventTypeMap[event_type].parse_obj(data["event"]) # type: ignore
314+
315+
return EventMessage(
316+
event_id=event_id,
317+
event_type=event_type,
318+
event=event,
319+
instance_id=instance_id,
320+
instance_name=instance_name,
321+
)

src/pytypes/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
pydantic~=1.8.1 --no-binary=pydantic
1+
pydantic==1.8.2 --no-binary=pydantic

0 commit comments

Comments
 (0)