|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import json |
6 | | -import time |
7 | 6 |
|
8 | 7 | from bernstein.core.server.sse_events import SSEEvent, SSEEventType |
9 | 8 |
|
10 | 9 |
|
11 | 10 | class TestSSEEventType: |
12 | | - """Tests for the SSEEventType enum.""" |
13 | | - |
14 | | - def test_all_14_event_types_defined(self) -> None: |
| 11 | + def test_has_14_members(self) -> None: |
15 | 12 | assert len(SSEEventType) == 14 |
16 | 13 |
|
17 | 14 | def test_event_type_values_are_dotted(self) -> None: |
18 | 15 | for member in SSEEventType: |
| 16 | + if member == SSEEventType.HEARTBEAT: |
| 17 | + continue |
19 | 18 | assert "." in member.value, f"{member.name} should have dotted value" |
20 | 19 |
|
21 | | - def test_event_type_is_str_enum(self) -> None: |
22 | | - assert isinstance(SSEEventType.TASK_CREATED, str) |
23 | | - assert SSEEventType.TASK_CREATED == "task.created" |
24 | | - |
25 | | - |
26 | | -class TestSSEEventToSSE: |
27 | | - """Tests for SSE wire format output.""" |
28 | | - |
29 | | - def test_to_sse_starts_with_event(self) -> None: |
30 | | - event = SSEEvent.task_created("t1", "do stuff", "backend", "medium") |
31 | | - wire = event.to_sse() |
| 20 | + def test_all_expected_types_exist(self) -> None: |
| 21 | + expected = { |
| 22 | + "TASK_CREATED", "TASK_CLAIMED", "TASK_COMPLETED", "TASK_FAILED", |
| 23 | + "TASK_RETRIED", "AGENT_SPAWNED", "AGENT_EXITED", "GATE_RESULT", |
| 24 | + "COST_UPDATE", "MERGE_STARTED", "MERGE_COMPLETED", |
| 25 | + "RUN_STARTED", "RUN_COMPLETED", "HEARTBEAT", |
| 26 | + } |
| 27 | + actual = {m.name for m in SSEEventType} |
| 28 | + assert expected == actual |
| 29 | + |
| 30 | + |
| 31 | +class TestSSEEvent: |
| 32 | + def test_to_sse_wire_format(self) -> None: |
| 33 | + evt = SSEEvent(event=SSEEventType.TASK_CREATED, data={"task_id": "t1"}, timestamp=1000.0) |
| 34 | + wire = evt.to_sse() |
32 | 35 | assert wire.startswith("event: task.created\n") |
33 | | - |
34 | | - def test_to_sse_has_data_line(self) -> None: |
35 | | - event = SSEEvent.task_created("t1", "do stuff", "backend", "medium") |
36 | | - wire = event.to_sse() |
37 | | - lines = wire.strip().split("\n") |
38 | | - assert lines[1].startswith("data: ") |
39 | | - |
40 | | - def test_to_sse_ends_with_double_newline(self) -> None: |
41 | | - event = SSEEvent.task_created("t1", "do stuff", "backend", "medium") |
42 | | - wire = event.to_sse() |
| 36 | + assert "data: " in wire |
43 | 37 | assert wire.endswith("\n\n") |
44 | 38 |
|
45 | | - def test_to_sse_data_is_valid_json(self) -> None: |
46 | | - event = SSEEvent.task_created("t1", "do stuff", "backend", "medium") |
47 | | - wire = event.to_sse() |
48 | | - data_line = wire.strip().split("\n")[1] |
49 | | - payload = json.loads(data_line.removeprefix("data: ")) |
50 | | - assert isinstance(payload, dict) |
51 | | - |
52 | | - def test_to_sse_payload_contains_timestamp(self) -> None: |
53 | | - event = SSEEvent.task_created("t1", "do stuff", "backend", "medium") |
54 | | - wire = event.to_sse() |
55 | | - data_line = wire.strip().split("\n")[1] |
56 | | - payload = json.loads(data_line.removeprefix("data: ")) |
57 | | - assert "timestamp" in payload |
58 | | - assert isinstance(payload["timestamp"], float) |
59 | | - |
60 | | - |
61 | | -class TestSSEEventTimestamp: |
62 | | - """Tests for auto-generated timestamps.""" |
| 39 | + def test_to_sse_json_payload_valid(self) -> None: |
| 40 | + evt = SSEEvent(event=SSEEventType.TASK_COMPLETED, data={"task_id": "t2"}, timestamp=2000.0) |
| 41 | + wire = evt.to_sse() |
| 42 | + data_line = next(line for line in wire.split("\n") if line.startswith("data: ")) |
| 43 | + payload = json.loads(data_line[6:]) |
| 44 | + assert payload["task_id"] == "t2" |
| 45 | + assert payload["timestamp"] == 2000.0 |
63 | 46 |
|
64 | 47 | def test_timestamp_auto_generated(self) -> None: |
65 | | - before = time.time() |
66 | | - event = SSEEvent.task_created("t1", "goal", "role", "low") |
67 | | - after = time.time() |
68 | | - assert before <= event.timestamp <= after |
| 48 | + evt = SSEEvent(event=SSEEventType.HEARTBEAT, data={}) |
| 49 | + assert evt.timestamp > 0 |
69 | 50 |
|
70 | | - def test_timestamp_preserved_when_provided(self) -> None: |
71 | | - event = SSEEvent(SSEEventType.TASK_CREATED, {"task_id": "t1"}, timestamp=123.0) |
72 | | - assert event.timestamp == 123.0 |
| 51 | + def test_id_field_in_wire_format(self) -> None: |
| 52 | + evt = SSEEvent(event=SSEEventType.HEARTBEAT, data={}, id="evt-42") |
| 53 | + wire = evt.to_sse() |
| 54 | + assert "id: evt-42\n" in wire |
73 | 55 |
|
| 56 | + def test_no_id_by_default(self) -> None: |
| 57 | + evt = SSEEvent(event=SSEEventType.HEARTBEAT, data={}) |
| 58 | + wire = evt.to_sse() |
| 59 | + assert "id: " not in wire |
74 | 60 |
|
75 | | -class TestSSEEventFactories: |
76 | | - """Tests for each factory method.""" |
77 | 61 |
|
| 62 | +class TestSSEEventFactories: |
78 | 63 | def test_task_created(self) -> None: |
79 | | - event = SSEEvent.task_created("t1", "build API", "backend", "high") |
80 | | - assert event.event_type == SSEEventType.TASK_CREATED |
81 | | - assert event.data["task_id"] == "t1" |
82 | | - assert event.data["goal"] == "build API" |
83 | | - assert event.data["role"] == "backend" |
84 | | - assert event.data["complexity"] == "high" |
| 64 | + evt = SSEEvent.task_created(task_id="abc", title="Fix bug") |
| 65 | + assert evt.event == SSEEventType.TASK_CREATED |
| 66 | + assert evt.data["task_id"] == "abc" |
| 67 | + assert evt.data["title"] == "Fix bug" |
85 | 68 |
|
86 | 69 | def test_task_completed(self) -> None: |
87 | | - event = SSEEvent.task_completed("t1", "agent-1", "opus", 42.567, 0.12345) |
88 | | - assert event.event_type == SSEEventType.TASK_COMPLETED |
89 | | - assert event.data["task_id"] == "t1" |
90 | | - assert event.data["agent_id"] == "agent-1" |
91 | | - assert event.data["model"] == "opus" |
92 | | - assert event.data["duration_s"] == 42.57 |
93 | | - assert event.data["cost_usd"] == 0.1235 |
| 70 | + evt = SSEEvent.task_completed(task_id="abc", cost_usd=0.12) |
| 71 | + assert evt.event == SSEEventType.TASK_COMPLETED |
| 72 | + assert evt.data["task_id"] == "abc" |
| 73 | + assert evt.data["cost_usd"] == 0.12 |
94 | 74 |
|
95 | 75 | def test_task_failed(self) -> None: |
96 | | - event = SSEEvent.task_failed("t1", "timeout", True) |
97 | | - assert event.event_type == SSEEventType.TASK_FAILED |
98 | | - assert event.data["task_id"] == "t1" |
99 | | - assert event.data["reason"] == "timeout" |
100 | | - assert event.data["will_retry"] is True |
| 76 | + evt = SSEEvent.task_failed(task_id="abc", reason="timeout") |
| 77 | + assert evt.event == SSEEventType.TASK_FAILED |
| 78 | + assert evt.data["reason"] == "timeout" |
101 | 79 |
|
102 | 80 | def test_agent_spawned(self) -> None: |
103 | | - event = SSEEvent.agent_spawned("a1", "t1", "sonnet", "claude") |
104 | | - assert event.event_type == SSEEventType.AGENT_SPAWNED |
105 | | - assert event.data["agent_id"] == "a1" |
106 | | - assert event.data["task_id"] == "t1" |
107 | | - assert event.data["model"] == "sonnet" |
108 | | - assert event.data["adapter"] == "claude" |
| 81 | + evt = SSEEvent.agent_spawned(agent_id="a1", role="backend") |
| 82 | + assert evt.event == SSEEventType.AGENT_SPAWNED |
| 83 | + assert evt.data["agent_id"] == "a1" |
| 84 | + assert evt.data["role"] == "backend" |
109 | 85 |
|
110 | 86 | def test_gate_result_passed(self) -> None: |
111 | | - event = SSEEvent.gate_result("t1", "ruff", passed=True, details="clean") |
112 | | - assert event.event_type == SSEEventType.GATE_PASSED |
113 | | - assert event.data["passed"] is True |
114 | | - assert event.data["gate"] == "ruff" |
| 87 | + evt = SSEEvent.gate_result(gate_name="lint", passed=True) |
| 88 | + assert evt.event == SSEEventType.GATE_RESULT |
| 89 | + assert evt.data["passed"] is True |
115 | 90 |
|
116 | 91 | def test_gate_result_failed(self) -> None: |
117 | | - event = SSEEvent.gate_result("t1", "pytest", passed=False, details="3 failures") |
118 | | - assert event.event_type == SSEEventType.GATE_FAILED |
119 | | - assert event.data["passed"] is False |
| 92 | + evt = SSEEvent.gate_result(gate_name="test", passed=False) |
| 93 | + assert evt.event == SSEEventType.GATE_RESULT |
| 94 | + assert evt.data["passed"] is False |
120 | 95 |
|
121 | 96 | def test_cost_update(self) -> None: |
122 | | - event = SSEEvent.cost_update(1.23456, 10.0, 12.3456) |
123 | | - assert event.event_type == SSEEventType.COST_UPDATE |
124 | | - assert event.data["total_usd"] == 1.2346 |
125 | | - assert event.data["budget_usd"] == 10.0 |
126 | | - assert event.data["budget_pct"] == 12.3 |
| 97 | + evt = SSEEvent.cost_update(total_usd=1.23) |
| 98 | + assert evt.event == SSEEventType.COST_UPDATE |
| 99 | + assert evt.data["total_usd"] == 1.23 |
127 | 100 |
|
128 | 101 | def test_merge_completed(self) -> None: |
129 | | - event = SSEEvent.merge_completed("t1", "feat/x", "abc1234") |
130 | | - assert event.event_type == SSEEventType.MERGE_COMPLETED |
131 | | - assert event.data["branch"] == "feat/x" |
132 | | - assert event.data["commit_sha"] == "abc1234" |
| 102 | + evt = SSEEvent.merge_completed(branch="feat/x", result="success") |
| 103 | + assert evt.event == SSEEventType.MERGE_COMPLETED |
| 104 | + assert evt.data["branch"] == "feat/x" |
133 | 105 |
|
134 | 106 | def test_run_completed(self) -> None: |
135 | | - event = SSEEvent.run_completed(10, 8, 2, 5.6789) |
136 | | - assert event.event_type == SSEEventType.RUN_COMPLETED |
137 | | - assert event.data["total_tasks"] == 10 |
138 | | - assert event.data["passed"] == 8 |
139 | | - assert event.data["failed"] == 2 |
140 | | - assert event.data["total_cost_usd"] == 5.6789 |
| 107 | + evt = SSEEvent.run_completed(run_id="run-1") |
| 108 | + assert evt.event == SSEEventType.RUN_COMPLETED |
| 109 | + assert evt.data["run_id"] == "run-1" |
| 110 | + |
| 111 | + def test_extra_kwargs(self) -> None: |
| 112 | + evt = SSEEvent.task_created(task_id="t1", title="X", custom_field="val") |
| 113 | + assert evt.data["custom_field"] == "val" |
0 commit comments