diff --git a/.changeset/quick-teachers-listen.md b/.changeset/quick-teachers-listen.md
new file mode 100644
index 000000000000..470894c92079
--- /dev/null
+++ b/.changeset/quick-teachers-listen.md
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+Add LlamaIndexWorkflowAdapter
diff --git a/content/providers/04-adapters/02-llamaindex.mdx b/content/providers/04-adapters/02-llamaindex.mdx
index d9b40201598f..443b6f1c7407 100644
--- a/content/providers/04-adapters/02-llamaindex.mdx
+++ b/content/providers/04-adapters/02-llamaindex.mdx
@@ -7,11 +7,45 @@ description: Learn how to use LlamaIndex with the AI SDK.
 
 [LlamaIndex](https://ts.llamaindex.ai/) is a framework for building LLM-powered applications. LlamaIndex helps you ingest, structure, and access private or domain-specific data. LlamaIndex.TS offers the core features of LlamaIndex for Python for popular runtimes like Node.js (official support), Vercel Edge Functions (experimental), and Deno (experimental).
 
-## Example: Completion
+## Example: Agent Workflow
 
-Here is a basic example that uses both AI SDK and LlamaIndex together with the [Next.js](https://nextjs.org/docs) App Router.
+Here is a basic example that uses both the AI SDK and LlamaIndex together with the [Next.js](https://nextjs.org/docs) App Router.
 
-The AI SDK [`LlamaIndexAdapter`](/docs/reference/stream-helpers/llamaindex-adapter) uses the stream result from calling the `chat` method on a [LlamaIndex ChatEngine](https://ts.llamaindex.ai/modules/chat_engine) or the `query` method on a [LlamaIndex QueryEngine](https://ts.llamaindex.ai/modules/query_engines) to pipe text to the client.
+You can use the AI SDK's [`LlamaIndexWorkflowAdapter`](/docs/reference/stream-helpers/llamaindex-workflow-adapter) to stream events from a LlamaIndex [AgentWorkflow or Workflow](https://ts.llamaindex.ai/docs/llamaindex/modules/agents/agent_workflow) to the client.
+
+```tsx filename="app/api/completion/route.ts" highlight="28"
+import { tool, openai } from 'llamaindex';
+import { agent, startAgentEvent, run } from '@llamaindex/workflow';
+import { LlamaIndexWorkflowAdapter } from 'ai';
+import { z } from 'zod';
+
+const calculatorAgent = agent({
+  tools: [
+    tool({
+      name: 'add',
+      description: 'Adds two numbers',
+      parameters: z.object({ x: z.number(), y: z.number() }),
+      execute: ({ x, y }) => x + y,
+    }),
+  ],
+  llm: openai({ model: 'gpt-4o' }),
+});
+
+export async function POST(req: Request) {
+  const { prompt } = await req.json();
+
+  const eventStream = await run(
+    calculatorAgent,
+    startAgentEvent.with({
+      userInput: prompt,
+    }),
+  );
+
+  return LlamaIndexWorkflowAdapter.toDataStreamResponse(eventStream);
+}
+```
+
+Alternatively, you can use the AI SDK's [`LlamaIndexAdapter`](/docs/reference/stream-helpers/llamaindex-adapter) to stream the result from calling the `chat` method on a [LlamaIndex ChatEngine](https://ts.llamaindex.ai/modules/chat_engine) or the `query` method on a [LlamaIndex QueryEngine](https://ts.llamaindex.ai/modules/query_engines), piping text directly to the client.
 
 ```tsx filename="app/api/completion/route.ts" highlight="17"
 import { OpenAI, SimpleChatEngine } from 'llamaindex';
@@ -34,7 +68,7 @@ export async function POST(req: Request) {
 }
 ```
 
-Then, we use the AI SDK's [`useCompletion`](/docs/ai-sdk-ui/completion) method in the page component to handle the completion:
+Finally, use the AI SDK's [`useCompletion`](/docs/ai-sdk-ui/completion) hook in your page component to handle completions:
 
 ```tsx filename="app/page.tsx"
 'use client';
@@ -58,4 +92,4 @@ export default function Chat() {
 
 ## More Examples
 
-[create-llama](https://github.com/run-llama/create-llama) is the easiest way to get started with LlamaIndex. It uses the AI SDK to connect to LlamaIndex in all its generated code.
+[create-llama](https://github.com/run-llama/create-llama) is the easiest way to get started with LlamaIndex. It uses the AI SDK to connect to LlamaIndex in all of its generated code.
diff --git a/packages/ai/streams/index.ts b/packages/ai/streams/index.ts
index 91eb3b5233bf..84d9c6f97a6f 100644
--- a/packages/ai/streams/index.ts
+++ b/packages/ai/streams/index.ts
@@ -4,4 +4,5 @@ export * from '../errors/index';
 export * from './assistant-response';
 export * as LangChainAdapter from './langchain-adapter';
 export * as LlamaIndexAdapter from './llamaindex-adapter';
+export * as LlamaIndexWorkflowAdapter from './llamaindex-workflow-adapter';
 export * from './stream-data';
diff --git a/packages/ai/streams/llamaindex-workflow-adapter.test.ts b/packages/ai/streams/llamaindex-workflow-adapter.test.ts
new file mode 100644
index 000000000000..0cc1e850de88
--- /dev/null
+++ b/packages/ai/streams/llamaindex-workflow-adapter.test.ts
@@ -0,0 +1,106 @@
+import {
+  convertReadableStreamToArray,
+  convertArrayToAsyncIterable,
+  convertResponseStreamToArray,
+} from '@ai-sdk/provider-utils/test';
+import {
+  toDataStreamResponse,
+  toDataStream,
+} from './llamaindex-workflow-adapter';
+
+describe('toWorkflowDataStream', () => {
+  it('should convert AsyncIterable', async () => {
+    const inputStream = convertArrayToAsyncIterable([
+      { data: { delta: 'Hello' } },
+      { data: { foo: 'bar' } },
+      { data: { delta: 'World' } },
+    ]);
+
+    assert.deepStrictEqual(
+      await convertReadableStreamToArray(toDataStream(inputStream)),
+      ['0:"Hello"\n', '8:[{"foo":"bar"}]\n', '0:"World"\n'],
+    );
+  });
+
+  it('should support callbacks', async () => {
+    const inputStream = convertArrayToAsyncIterable([
+      { data: { delta: 'Hello' } },
+      { data: { delta: 'World' } },
+    ]);
+
+    const callbacks = {
+      onText: vi.fn(),
+    };
+
+    const dataStream = toDataStream(inputStream, callbacks);
+
+    await convertReadableStreamToArray(dataStream);
+
+    expect(callbacks.onText).toHaveBeenCalledTimes(2);
+    expect(callbacks.onText).toHaveBeenNthCalledWith(
+      1,
+      'Hello',
+      expect.anything(),
+    );
+    expect(callbacks.onText).toHaveBeenNthCalledWith(
+      2,
+      'World',
+      expect.anything(),
+    );
+  });
+});
+
+describe('toDataStreamResponse', () => {
+  it('should convert AsyncIterable to a Response', async () => {
+    const inputStream = convertArrayToAsyncIterable([
+      { data: { delta: 'Hello' } },
+      { data: { delta: 'World' } },
+    ]);
+
+    const response = toDataStreamResponse(inputStream);
+
+    assert.strictEqual(response.status, 200);
+    assert.deepStrictEqual(Object.fromEntries(response.headers.entries()), {
+      'content-type': 'text/plain; charset=utf-8',
+      'x-vercel-ai-data-stream': 'v1',
+    });
+
+    assert.deepStrictEqual(await convertResponseStreamToArray(response), [
+      '0:"Hello"\n',
+      '0:"World"\n',
+    ]);
+  });
+
+  it('should support callbacks', async () => {
+    const inputStream = convertArrayToAsyncIterable([
+      { data: { delta: 'Hello' } },
+      { data: { delta: 'World' } },
+    ]);
+
+    const callbacks = {
+      onText: vi.fn(),
+    };
+
+    const response = toDataStreamResponse(inputStream, {
+      callbacks,
+    });
+
+    assert.strictEqual(response.status, 200);
+    assert.deepStrictEqual(Object.fromEntries(response.headers.entries()), {
+      'content-type': 'text/plain; charset=utf-8',
+      'x-vercel-ai-data-stream': 'v1',
+    });
+
+    assert.deepStrictEqual(await convertResponseStreamToArray(response), [
+      '0:"Hello"\n',
+      '0:"World"\n',
+    ]);
+
+    expect(callbacks.onText).toHaveBeenCalledTimes(2);
+    expect(callbacks.onText).toHaveBeenNthCalledWith(
+      1,
+      'Hello',
+      expect.anything(),
+    );
+  });
+});
diff --git a/packages/ai/streams/llamaindex-workflow-adapter.ts b/packages/ai/streams/llamaindex-workflow-adapter.ts
new file mode 100644
index 000000000000..f3b969684d27
--- /dev/null
+++ b/packages/ai/streams/llamaindex-workflow-adapter.ts
@@ -0,0 +1,106 @@
+import { formatDataStreamPart, JSONValue } from '@ai-sdk/ui-utils';
+import { DataStreamWriter } from '../core/data-stream/data-stream-writer';
+import { createDataStream } from '../core';
+import { prepareResponseHeaders } from '../core/util/prepare-response-headers';
+
+type WorkflowEventData<T = unknown> = {
+  data: T;
+};
+
+interface StreamCallbacks {
+  /** `onStart`: Called once when the stream is initialized. */
+  onStart?: (dataStreamWriter: DataStreamWriter) => Promise<void> | void;
+
+  /** `onFinal`: Called once when the stream is closed with the final completion message. */
+  onFinal?: (
+    completion: string,
+    dataStreamWriter: DataStreamWriter,
+  ) => Promise<void> | void;
+
+  /** `onToken`: Called for each tokenized message. */
+  onToken?: (
+    token: string,
+    dataStreamWriter: DataStreamWriter,
+  ) => Promise<void> | void;
+
+  /** `onText`: Called for each text chunk. */
+  onText?: (
+    text: string,
+    dataStreamWriter: DataStreamWriter,
+  ) => Promise<void> | void;
+}
+
+export function toDataStream(
+  stream: AsyncIterable<WorkflowEventData>,
+  callbacks?: StreamCallbacks,
+) {
+  let completionText = '';
+  let hasStarted = false;
+
+  return createDataStream({
+    execute: async (dataStreamWriter: DataStreamWriter) => {
+      if (!hasStarted && callbacks?.onStart) {
+        await callbacks.onStart(dataStreamWriter);
+        hasStarted = true;
+      }
+
+      for await (const event of stream) {
+        const data = event.data;
+
+        if (isTextStream(data)) {
+          const content = data.delta;
+          if (content) {
+            completionText += content;
+            dataStreamWriter.write(formatDataStreamPart('text', content));
+
+            if (callbacks?.onText) {
+              await callbacks.onText(content, dataStreamWriter);
+            }
+          }
+        } else {
+          dataStreamWriter.writeMessageAnnotation(data as JSONValue);
+        }
+      }
+
+      if (callbacks?.onFinal) {
+        await callbacks.onFinal(completionText, dataStreamWriter);
+      }
+    },
+    onError: (error: unknown) => {
+      return error instanceof Error
+        ? error.message
+        : 'An unknown error occurred during stream finalization';
+    },
+  });
+}
+
+export function toDataStreamResponse(
+  stream: AsyncIterable<WorkflowEventData>,
+  options: {
+    init?: ResponseInit;
+    callbacks?: StreamCallbacks;
+  } = {},
+) {
+  const { init, callbacks } = options;
+  const dataStream = toDataStream(stream, callbacks).pipeThrough(
+    new TextEncoderStream(),
+  );
+
+  return new Response(dataStream, {
+    status: init?.status ?? 200,
+    statusText: init?.statusText,
+    headers: prepareResponseHeaders(init?.headers, {
+      contentType: 'text/plain; charset=utf-8',
+      dataStreamVersion: 'v1',
+    }),
+  });
+}
+
+function isTextStream(data: unknown): data is { delta: string } {
+  return (
+    typeof data === 'object' &&
+    data !== null &&
+    'delta' in data &&
+    typeof (data as any).delta === 'string'
+  );
+}
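
For reviewers, here is a minimal sketch of the end-to-end behavior the tests above assert: text deltas become `0:` text parts, and any other workflow event is forwarded as an `8:` message annotation part. This is not part of the diff; it assumes the adapter is published from `ai` as `LlamaIndexWorkflowAdapter` per the export added above, and the event payloads are made up for illustration.

```ts
import { LlamaIndexWorkflowAdapter } from 'ai';

// Hand-rolled stand-in for a LlamaIndex workflow event stream (shapes are illustrative).
async function* workflowEvents() {
  yield { data: { delta: 'Hello ' } }; // text delta  -> 0:"Hello " part
  yield { data: { toolName: 'add' } }; // other event -> 8:[{"toolName":"add"}] annotation part
  yield { data: { delta: 'World' } }; // text delta  -> 0:"World" part
}

async function main() {
  // toDataStream returns a ReadableStream of data stream parts as strings.
  const reader = LlamaIndexWorkflowAdapter.toDataStream(workflowEvents()).getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    console.log(value); // 0:"Hello "\n, 8:[{"toolName":"add"}]\n, 0:"World"\n
  }
}

main().catch(console.error);
```

Because non-text events are forwarded as message annotations rather than dropped, clients can still observe tool calls and other agent events alongside the streamed text.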