|
5 | 5 | * 2.0. |
6 | 6 | */ |
7 | 7 |
|
8 | | -import type { JSONValue } from '../types'; |
| 8 | +import type { UIMessageChunk } from 'ai'; |
9 | 9 |
|
10 | | -export interface StreamPart<CODE extends string, NAME extends string, TYPE> { |
11 | | - code: CODE; |
12 | | - name: NAME; |
13 | | - parse: (value: JSONValue) => { type: NAME; value: TYPE }; |
| 10 | +interface ReadDataStreamOptions { |
| 11 | + isAborted?: () => boolean; |
14 | 12 | } |
15 | 13 |
|
16 | | -type StreamParts = |
17 | | - | typeof textStreamPart |
18 | | - | typeof errorStreamPart |
19 | | - | typeof messageAnnotationsStreamPart; |
20 | | -/** |
21 | | - * Maps the type of a stream part to its value type. |
22 | | - */ |
23 | | -type StreamPartValueType = { |
24 | | - [P in StreamParts as P['name']]: ReturnType<P['parse']>['value']; |
25 | | -}; |
26 | | - |
27 | | -export type StreamPartType = |
28 | | - | ReturnType<typeof textStreamPart.parse> |
29 | | - | ReturnType<typeof errorStreamPart.parse> |
30 | | - | ReturnType<typeof messageAnnotationsStreamPart.parse> |
31 | | - | ReturnType<typeof bufferStreamPart.parse>; |
32 | | - |
33 | | -const NEWLINE = '\n'.charCodeAt(0); |
34 | | - |
35 | | -const concatChunks = (chunks: Uint8Array[], totalLength: number) => { |
36 | | - const concatenatedChunks = new Uint8Array(totalLength); |
37 | | - |
38 | | - let offset = 0; |
39 | | - for (const chunk of chunks) { |
40 | | - concatenatedChunks.set(chunk, offset); |
41 | | - offset += chunk.length; |
42 | | - } |
43 | | - chunks.length = 0; |
44 | | - |
45 | | - return concatenatedChunks; |
46 | | -}; |
| 14 | +const EVENT_SEPARATOR = '\n\n'; |
| 15 | +const EVENT_DATA_PREFIX = 'data: '; |
| 16 | +const STREAM_END_PAYLOAD = '[DONE]'; |
47 | 17 |
|
48 | 18 | export async function* readDataStream( |
49 | 19 | reader: ReadableStreamDefaultReader<Uint8Array>, |
50 | | - { isAborted }: { isAborted?: () => boolean } = {} |
51 | | -): AsyncGenerator<StreamPartType> { |
| 20 | + { isAborted }: ReadDataStreamOptions = {} |
| 21 | +): AsyncGenerator<UIMessageChunk> { |
52 | 22 | const decoder = new TextDecoder(); |
53 | | - const chunks: Uint8Array[] = []; |
54 | | - let totalLength = 0; |
| 23 | + let buffer = ''; |
55 | 24 |
|
56 | 25 | while (true) { |
57 | | - const { value } = await reader.read(); |
| 26 | + const { done, value } = await reader.read(); |
58 | 27 |
|
59 | 28 | if (value) { |
60 | | - chunks.push(value); |
61 | | - totalLength += value.length; |
62 | | - if (value[value.length - 1] !== NEWLINE) { |
63 | | - continue; |
| 29 | + buffer += decoder.decode(value, { stream: true }); |
| 30 | + |
| 31 | + let separatorIndex = buffer.indexOf(EVENT_SEPARATOR); |
| 32 | + while (separatorIndex !== -1) { |
| 33 | + const event = buffer.slice(0, separatorIndex); |
| 34 | + buffer = buffer.slice(separatorIndex + EVENT_SEPARATOR.length); |
| 35 | + separatorIndex = buffer.indexOf(EVENT_SEPARATOR); |
| 36 | + |
| 37 | + if (!event.startsWith(EVENT_DATA_PREFIX)) { |
| 38 | + continue; |
| 39 | + } |
| 40 | + |
| 41 | + const payload = event.slice(EVENT_DATA_PREFIX.length); |
| 42 | + |
| 43 | + if (payload === STREAM_END_PAYLOAD) { |
| 44 | + return; |
| 45 | + } |
| 46 | + |
| 47 | + const message = JSON.parse(payload); |
| 48 | + if (isUIMessageChunk(message)) { |
| 49 | + yield message; |
| 50 | + } else { |
| 51 | + throw new Error(`Unsupported stream event: ${payload}`); |
| 52 | + } |
64 | 53 | } |
65 | 54 | } |
66 | 55 |
|
67 | | - if (chunks.length === 0) { |
| 56 | + if (done) { |
| 57 | + if (buffer.length) { |
| 58 | + let payload = buffer.trim(); |
| 59 | + if (payload.startsWith(EVENT_DATA_PREFIX)) { |
| 60 | + payload = payload.slice(EVENT_DATA_PREFIX.length).trim(); |
| 61 | + } |
| 62 | + if (payload && payload !== STREAM_END_PAYLOAD) { |
| 63 | + const message = JSON.parse(payload); |
| 64 | + if (isUIMessageChunk(message)) { |
| 65 | + yield message; |
| 66 | + } else { |
| 67 | + throw new Error(`Unsupported stream event: ${payload}`); |
| 68 | + } |
| 69 | + } |
| 70 | + } |
68 | 71 | break; |
69 | 72 | } |
70 | 73 |
|
71 | | - const concatenatedChunks = concatChunks(chunks, totalLength); |
72 | | - totalLength = 0; |
73 | | - |
74 | | - const streamParts = decoder |
75 | | - .decode(concatenatedChunks, { stream: true }) |
76 | | - .split('\n') |
77 | | - .filter((line) => line !== '') |
78 | | - .map(parseStreamPart); |
79 | | - |
80 | | - for (const streamPart of streamParts) { |
81 | | - yield streamPart; |
82 | | - } |
83 | | - |
84 | 74 | if (isAborted?.()) { |
85 | | - reader.cancel(); |
| 75 | + await reader.cancel(); |
86 | 76 | break; |
87 | 77 | } |
88 | 78 | } |
89 | 79 | } |
90 | 80 |
|
91 | | -const createStreamPart = <CODE extends string, NAME extends string, TYPE>( |
92 | | - code: CODE, |
93 | | - name: NAME, |
94 | | - parse: (value: JSONValue) => { type: NAME; value: TYPE } |
95 | | -): StreamPart<CODE, NAME, TYPE> => { |
96 | | - return { |
97 | | - code, |
98 | | - name, |
99 | | - parse, |
100 | | - }; |
101 | | -}; |
102 | | - |
103 | | -const textStreamPart = createStreamPart('0', 'text', (value) => { |
104 | | - if (typeof value !== 'string') { |
105 | | - throw new Error('"text" parts expect a string value.'); |
106 | | - } |
107 | | - return { type: 'text', value }; |
108 | | -}); |
109 | | - |
110 | | -const errorStreamPart = createStreamPart('3', 'error', (value) => { |
111 | | - if (typeof value !== 'string') { |
112 | | - throw new Error('"error" parts expect a string value.'); |
113 | | - } |
114 | | - return { type: 'error', value }; |
115 | | -}); |
116 | | - |
117 | | -const messageAnnotationsStreamPart = createStreamPart('8', 'message_annotations', (value) => { |
118 | | - if (!Array.isArray(value)) { |
119 | | - throw new Error('"message_annotations" parts expect an array value.'); |
120 | | - } |
121 | | - |
122 | | - return { type: 'message_annotations', value }; |
123 | | -}); |
124 | | - |
125 | | -const bufferStreamPart = createStreamPart('10', 'buffer', (value) => { |
126 | | - if (typeof value !== 'string') { |
127 | | - throw new Error('"buffer" parts expect a string value.'); |
128 | | - } |
129 | | - |
130 | | - return { type: 'buffer', value }; |
131 | | -}); |
132 | | - |
133 | | -const streamParts = [ |
134 | | - textStreamPart, |
135 | | - errorStreamPart, |
136 | | - bufferStreamPart, |
137 | | - messageAnnotationsStreamPart, |
138 | | -] as const; |
139 | | - |
140 | | -type StreamPartMap = { |
141 | | - [P in StreamParts as P['code']]: P; |
142 | | -}; |
143 | | - |
144 | | -const streamPartsByCode: StreamPartMap = streamParts.reduce( |
145 | | - (acc, part) => ({ |
146 | | - ...acc, |
147 | | - [part.code]: part, |
148 | | - }), |
149 | | - {} as StreamPartMap |
150 | | -); |
151 | | - |
152 | | -const validCodes = streamParts.map((part) => part.code); |
153 | | - |
154 | | -export const parseStreamPart = (line: string): StreamPartType => { |
155 | | - const firstSeparatorIndex = line.indexOf(':'); |
156 | | - |
157 | | - if (firstSeparatorIndex === -1) { |
158 | | - throw new Error('Failed to parse stream string. No separator found.'); |
159 | | - } |
160 | | - |
161 | | - const prefix = line.slice(0, firstSeparatorIndex) as keyof StreamPartMap; |
162 | | - |
163 | | - if (!validCodes.includes(prefix)) { |
164 | | - throw new Error(`Failed to parse stream string. Invalid code ${prefix}.`); |
165 | | - } |
166 | | - |
167 | | - const code = prefix as keyof StreamPartMap; |
168 | | - |
169 | | - const textValue = line.slice(firstSeparatorIndex + 1); |
170 | | - const jsonValue: JSONValue = JSON.parse(textValue); |
171 | | - |
172 | | - return streamPartsByCode[code].parse(jsonValue); |
173 | | -}; |
174 | | - |
175 | | -export const formatStreamPart = <T extends keyof StreamPartValueType>( |
176 | | - type: T, |
177 | | - value: StreamPartValueType[T] |
178 | | -): string => { |
179 | | - const streamPart = streamParts.find((part) => part.name === type); |
180 | | - |
181 | | - if (!streamPart) { |
182 | | - throw new Error(`Invalid stream part type: ${type as string}`); |
183 | | - } |
184 | | - |
185 | | - return `${streamPart.code}:${JSON.stringify(value)}\n`; |
186 | | -}; |
| 81 | +function isUIMessageChunk(message: unknown): message is UIMessageChunk { |
| 82 | + return Boolean(message && typeof message === 'object' && 'type' in message); |
| 83 | +} |
0 commit comments