Skip to content

Commit 3668618

Browse files
committed
fix(server): reject requests before initialization
1 parent 22595b9 commit 3668618

2 files changed

Lines changed: 85 additions & 4 deletions

File tree

packages/server/src/server/server.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export class Server extends Protocol<ServerContext> {
102102
private _instructions?: string;
103103
private _jsonSchemaValidator: jsonSchemaValidator;
104104
private _experimental?: { tasks: ExperimentalServerTasks };
105+
private _initialized = false;
105106

106107
/**
107108
* Callback for when initialization has fully completed (i.e., the client has sent an `notifications/initialized` notification).
@@ -132,7 +133,13 @@ export class Server extends Protocol<ServerContext> {
132133
}
133134

134135
this.setRequestHandler('initialize', request => this._oninitialize(request));
135-
this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.());
136+
this.setNotificationHandler('notifications/initialized', () => {
137+
if (!this._clientCapabilities) {
138+
throw new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Server not initialized');
139+
}
140+
this._initialized = true;
141+
this.oninitialized?.();
142+
});
136143

137144
if (this._capabilities.logging) {
138145
this._registerLoggingHandler();
@@ -226,8 +233,15 @@ export class Server extends Protocol<ServerContext> {
226233
method: string,
227234
handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result>
228235
): (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result> {
236+
const lifecycleHandler: (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result> = async (request, ctx) => {
237+
if (!ctx.http && !this._initialized && method !== 'initialize' && method !== 'ping') {
238+
throw new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Server not initialized');
239+
}
240+
return handler(request, ctx);
241+
};
242+
229243
if (method !== 'tools/call') {
230-
return handler;
244+
return lifecycleHandler;
231245
}
232246
return async (request, ctx) => {
233247
const validatedRequest = parseSchema(CallToolRequestSchema, request);
@@ -239,7 +253,7 @@ export class Server extends Protocol<ServerContext> {
239253

240254
const { params } = validatedRequest.data;
241255

242-
const result = await handler(request, ctx);
256+
const result = await lifecycleHandler(request, ctx);
243257

244258
// When task creation is requested, validate and return CreateTaskResult
245259
if (params.task) {

packages/server/test/server/server.test.ts

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { JSONRPCMessage } from '@modelcontextprotocol/core';
2-
import { InMemoryTransport, LATEST_PROTOCOL_VERSION } from '@modelcontextprotocol/core';
2+
import { InMemoryTransport, LATEST_PROTOCOL_VERSION, ProtocolErrorCode } from '@modelcontextprotocol/core';
33
import { Server } from '../../src/server/server.js';
44

55
describe('Server', () => {
@@ -38,5 +38,72 @@ describe('Server', () => {
3838

3939
await server.close();
4040
});
41+
42+
it('rejects requests before the initialized notification', async () => {
43+
const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: { tools: {} } });
44+
45+
server.setRequestHandler('tools/list', async () => ({ tools: [] }));
46+
47+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
48+
await server.connect(serverTransport);
49+
50+
const responses: JSONRPCMessage[] = [];
51+
clientTransport.onmessage = message => responses.push(message);
52+
await clientTransport.start();
53+
54+
await clientTransport.send({
55+
jsonrpc: '2.0',
56+
method: 'notifications/initialized'
57+
} as JSONRPCMessage);
58+
59+
await clientTransport.send({
60+
jsonrpc: '2.0',
61+
id: 1,
62+
method: 'tools/list',
63+
params: {}
64+
} as JSONRPCMessage);
65+
66+
await vi.waitFor(() => expect(responses.some(message => 'id' in message && message.id === 1)).toBe(true));
67+
68+
const rejected = responses.find(message => 'id' in message && message.id === 1);
69+
expect(rejected).toMatchObject({
70+
error: {
71+
code: ProtocolErrorCode.InvalidRequest,
72+
message: 'Server not initialized'
73+
}
74+
});
75+
76+
await clientTransport.send({
77+
jsonrpc: '2.0',
78+
id: 2,
79+
method: 'initialize',
80+
params: {
81+
protocolVersion: LATEST_PROTOCOL_VERSION,
82+
capabilities: {},
83+
clientInfo: { name: 'test-client', version: '1.0.0' }
84+
}
85+
} as JSONRPCMessage);
86+
await vi.waitFor(() => expect(responses.some(message => 'id' in message && message.id === 2)).toBe(true));
87+
88+
await clientTransport.send({
89+
jsonrpc: '2.0',
90+
method: 'notifications/initialized'
91+
} as JSONRPCMessage);
92+
93+
await clientTransport.send({
94+
jsonrpc: '2.0',
95+
id: 3,
96+
method: 'tools/list',
97+
params: {}
98+
} as JSONRPCMessage);
99+
100+
await vi.waitFor(() => expect(responses.some(message => 'id' in message && message.id === 3)).toBe(true));
101+
102+
expect(responses.find(message => 'id' in message && message.id === 3)).toMatchObject({
103+
result: { tools: [] }
104+
});
105+
106+
await server.close();
107+
});
41108
});
42109
});

0 commit comments

Comments
 (0)