diff --git a/.changeset/fair-pianos-smile.md b/.changeset/fair-pianos-smile.md new file mode 100644 index 00000000..4a5b5a22 --- /dev/null +++ b/.changeset/fair-pianos-smile.md @@ -0,0 +1,5 @@ +--- +'@storybook/mcp': patch +--- + +Forward `sources` through `createStorybookMcpHandler()` into the per-request transport context. diff --git a/packages/mcp/src/index.test.ts b/packages/mcp/src/index.test.ts index 0a25424d..8715883e 100644 --- a/packages/mcp/src/index.test.ts +++ b/packages/mcp/src/index.test.ts @@ -2,6 +2,7 @@ import { describe, it, expect, vi, afterEach } from 'vitest'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { createStorybookMcpHandler } from './index.ts'; +import type { StorybookContext } from './types.ts'; import smallManifestFixture from '../fixtures/small-manifest.fixture.json' with { type: 'json' }; import smallDocsManifestFixture from '../fixtures/small-docs-manifest.fixture.json' with { type: 'json' }; @@ -42,12 +43,15 @@ describe('createStorybookMcpHandler', () => { /** * Helper to setup client with a mock fetch that routes to our handler */ - async function setupClient(handler: Awaited>) { + async function setupClient( + handler: Awaited>, + context?: StorybookContext, + ) { // Mock global fetch to route to our handler fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => { const url = typeof input === 'string' ? input : input instanceof URL ? input.href : input.url; const request = new Request(url, init); - return await handler(request); + return await handler(request, context); }); (global as any).fetch = fetchMock; @@ -168,6 +172,85 @@ describe('createStorybookMcpHandler', () => { }); }); + it('should forward handler-level sources into the transport context', async () => { + const manifestProvider = createManifestProviderMockWithDocs(); + const sources = [ + { id: 'local', title: 'Local' }, + { id: 'remote', title: 'Remote', url: 'https://example.com/storybook' }, + ]; + + const handler = await createStorybookMcpHandler({ + manifestProvider, + sources, + }); + await setupClient(handler); + + await client.callTool({ + name: 'list-all-documentation', + arguments: {}, + }); + + expect(manifestProvider).toHaveBeenCalledWith( + expect.any(Request), + './manifests/components.json', + sources[0], + ); + expect(manifestProvider).toHaveBeenCalledWith( + expect.any(Request), + './manifests/docs.json', + sources[0], + ); + expect(manifestProvider).toHaveBeenCalledWith( + expect.any(Request), + './manifests/components.json', + sources[1], + ); + expect(manifestProvider).toHaveBeenCalledWith( + expect.any(Request), + './manifests/docs.json', + sources[1], + ); + }); + + it('should allow per-request sources to override handler-level sources', async () => { + const manifestProvider = createManifestProviderMockWithDocs(); + const handlerSources = [ + { id: 'handler', title: 'Handler', url: 'https://handler.example.com' }, + ]; + const requestSources = [ + { id: 'request', title: 'Request', url: 'https://request.example.com' }, + ]; + + const handler = await createStorybookMcpHandler({ + manifestProvider, + sources: handlerSources, + }); + await setupClient(handler, { + sources: requestSources, + }); + + await client.callTool({ + name: 'list-all-documentation', + arguments: {}, + }); + + expect(manifestProvider).toHaveBeenCalledWith( + expect.any(Request), + './manifests/components.json', + requestSources[0], + ); + expect(manifestProvider).toHaveBeenCalledWith( + expect.any(Request), + './manifests/docs.json', + requestSources[0], + ); + expect(manifestProvider).not.toHaveBeenCalledWith( + expect.any(Request), + expect.any(String), + handlerSources[0], + ); + }); + it('should call onListAllDocumentation handler when tool is invoked', async () => { const onListAllDocumentation = vi.fn(); const manifestProvider = createManifestProviderMock(); diff --git a/packages/mcp/src/index.ts b/packages/mcp/src/index.ts index 57d55d58..ffe9f378 100644 --- a/packages/mcp/src/index.ts +++ b/packages/mcp/src/index.ts @@ -76,6 +76,7 @@ type Handler = (req: Request, context?: StorybookContext) => Promise; export const createStorybookMcpHandler = async ( options: StorybookMcpHandlerOptions = {}, ): Promise => { + const { onSessionInitialize, ...defaultContext } = options; const adapter = new ValibotJsonSchemaAdapter(); const server = new McpServer( { @@ -92,8 +93,8 @@ export const createStorybookMcpHandler = async ( }, ).withContext(); - if (options.onSessionInitialize) { - server.on('initialize', options.onSessionInitialize); + if (onSessionInitialize) { + server.on('initialize', onSessionInitialize); } await addListAllDocumentationTool(server); @@ -104,10 +105,9 @@ export const createStorybookMcpHandler = async ( return (async (req, context) => { return await transport.respond(req, { + ...defaultContext, + ...context, request: req, - manifestProvider: context?.manifestProvider ?? options.manifestProvider, - onListAllDocumentation: context?.onListAllDocumentation ?? options.onListAllDocumentation, - onGetDocumentation: context?.onGetDocumentation ?? options.onGetDocumentation, }); }) as Handler; };