|
10 | 10 | import http from "node:http"; |
11 | 11 | import { createServer, type ServerInstance } from "../../server.js"; |
12 | 12 | import type { Fixture } from "../../types.js"; |
| 13 | +import type { WSTestClient } from "../ws-test-client.js"; |
| 14 | +import { extractShape, type SSEEventShape } from "./schema.js"; |
| 15 | + |
| 16 | +import { classifyGeminiMessage } from "./ws-providers.js"; |
| 17 | + |
| 18 | +export { classifyGeminiMessage }; |
13 | 19 |
|
14 | 20 | // --------------------------------------------------------------------------- |
15 | 21 | // HTTP helpers |
@@ -101,3 +107,77 @@ export async function startDriftServer(): Promise<ServerInstance> { |
101 | 107 | export async function stopDriftServer(instance: ServerInstance): Promise<void> { |
102 | 108 | await new Promise<void>((r) => instance.server.close(() => r())); |
103 | 109 | } |
| 110 | + |
| 111 | +// --------------------------------------------------------------------------- |
| 112 | +// WebSocket helpers |
| 113 | +// --------------------------------------------------------------------------- |
| 114 | + |
| 115 | +export const GEMINI_WS_PATH = |
| 116 | + "/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"; |
| 117 | + |
| 118 | +/** |
| 119 | + * Collect mock WS messages until a terminal predicate fires. |
| 120 | + * |
| 121 | + * Uses a polling loop on waitForMessages() since ws-test-client doesn't |
| 122 | + * support predicate-based collection. The `skip` parameter tells us how |
| 123 | + * many messages have already been consumed so we don't re-read them. |
| 124 | + * |
| 125 | + * Throws if the terminal predicate never fires before the timeout expires. |
| 126 | + */ |
| 127 | +export async function collectMockWSMessages( |
| 128 | + client: WSTestClient, |
| 129 | + terminal: (msg: unknown) => boolean, |
| 130 | + timeoutMs = 15000, |
| 131 | + skip = 0, |
| 132 | +): Promise<{ events: SSEEventShape[]; rawMessages: unknown[] }> { |
| 133 | + const rawMessages: unknown[] = []; |
| 134 | + const deadline = Date.now() + timeoutMs; |
| 135 | + let count = skip; |
| 136 | + let terminated = false; |
| 137 | + |
| 138 | + while (Date.now() < deadline) { |
| 139 | + const nextCount = count + 1; |
| 140 | + let msgs: string[]; |
| 141 | + try { |
| 142 | + msgs = await client.waitForMessages(nextCount, Math.min(2000, deadline - Date.now())); |
| 143 | + } catch (e: unknown) { |
| 144 | + // Only suppress waitForMessages timeout — rethrow anything else |
| 145 | + if (e instanceof Error && e.message.includes("Timeout waiting for")) { |
| 146 | + if (Date.now() >= deadline) break; |
| 147 | + continue; |
| 148 | + } |
| 149 | + throw e; |
| 150 | + } |
| 151 | + // Only increment count after successful receipt |
| 152 | + count = nextCount; |
| 153 | + const latest = msgs[count - 1]; |
| 154 | + let parsed: unknown; |
| 155 | + try { |
| 156 | + parsed = typeof latest === "string" ? JSON.parse(latest) : latest; |
| 157 | + } catch { |
| 158 | + throw new Error( |
| 159 | + `collectMockWSMessages: failed to parse message ${count}: ${String(latest).slice(0, 200)}`, |
| 160 | + ); |
| 161 | + } |
| 162 | + rawMessages.push(parsed); |
| 163 | + if (terminal(parsed)) { |
| 164 | + terminated = true; |
| 165 | + break; |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + if (!terminated) { |
| 170 | + throw new Error( |
| 171 | + `collectMockWSMessages timed out after ${timeoutMs}ms without terminal message. ` + |
| 172 | + `Collected ${rawMessages.length} messages.`, |
| 173 | + ); |
| 174 | + } |
| 175 | + |
| 176 | + const events: SSEEventShape[] = rawMessages.map((msg) => { |
| 177 | + const m = msg as Record<string, any>; |
| 178 | + const type = m.type ?? classifyGeminiMessage(m as Record<string, unknown>); |
| 179 | + return { type, dataShape: extractShape(msg) }; |
| 180 | + }); |
| 181 | + |
| 182 | + return { events, rawMessages }; |
| 183 | +} |
0 commit comments