diff --git a/modelcontextprotocol/mcpserver/protocol.go b/modelcontextprotocol/mcpserver/protocol.go index 0921216..0548407 100644 --- a/modelcontextprotocol/mcpserver/protocol.go +++ b/modelcontextprotocol/mcpserver/protocol.go @@ -142,12 +142,14 @@ func (m *McpServer) handleIncomingMessage( func (m *McpServer) EventMcpRequestInitialize(params *mcp.JsonRpcRequestInitializeParams, reqId *jsonrpc.JsonRpcRequestId) { // store client information - if params.ProtocolVersion != mcp.ProtocolVersion { + if params.ProtocolVersion < mcp.ProtocolVersion { m.logger.Error("protocol version mismatch", types.LogArg{ "expected": mcp.ProtocolVersion, "received": params.ProtocolVersion, }) - m.jsonRpcTransport.SendError(jsonrpc.RpcInvalidRequest, "protocol version mismatch", reqId) + + s := fmt.Sprintf("protocol version mismatch: expected %s, received %s", mcp.ProtocolVersion, params.ProtocolVersion) + m.jsonRpcTransport.SendError(jsonrpc.RpcInvalidRequest, "protocol version mismatch: "+s, reqId) return } // we store the client information @@ -165,7 +167,8 @@ func (m *McpServer) EventMcpRequestInitialize(params *mcp.JsonRpcRequestInitiali ListChanged: jsonrpc.BoolPtr(true), }, }, - ServerInfo: mcp.ServerInfo{Name: m.serverName, Version: m.serverVersion}, + ServerInfo: mcp.ServerInfo{Name: m.serverName, Version: m.serverVersion}, + Instructions: m.instructions, } m.jsonRpcTransport.SendJsonRpcResponse(&response, reqId) } diff --git a/modelcontextprotocol/mcpserver/server.go b/modelcontextprotocol/mcpserver/server.go index 452f0b9..b371700 100644 --- a/modelcontextprotocol/mcpserver/server.go +++ b/modelcontextprotocol/mcpserver/server.go @@ -15,6 +15,7 @@ type McpServer struct { logger types.Logger serverName string serverVersion string + instructions *string handler modelcontextprotocol.McpServerEventHandler // used by protocol clientName string @@ -55,12 +56,17 @@ func NewMcpSdkServer(serverDefinition types.McpSdkServerDefinition, debug bool) logger: logger, serverName: sdkServerDefinition.ServerName(), serverVersion: sdkServerDefinition.ServerVersion(), + instructions: sdkServerDefinition.Instructions(), handler: mcpServerNotifications, lastRequestId: 0, }, nil } +func (mcp *McpServer) Logger() types.Logger { + return mcp.logger +} + func (mcp *McpServer) StdioTransport() types.Transport { // we create the transport transport := transport.NewStdioTransport( diff --git a/protocol/mcp/rpcReqInitialize.go b/protocol/mcp/rpcReqInitialize.go index 5f5a9bb..b60687d 100644 --- a/protocol/mcp/rpcReqInitialize.go +++ b/protocol/mcp/rpcReqInitialize.go @@ -18,6 +18,7 @@ type JsonRpcRequestInitializeParams struct { ProtocolVersion string `json:"protocolVersion"` Capabilities ClientCapabilities `json:"capabilities"` ClientInfo ClientInfo `json:"clientInfo"` + Instructions *string `json:"instructions,omitempty"` } type ClientCapabilities struct { diff --git a/protocol/mcp/rpcRespInitialize.go b/protocol/mcp/rpcRespInitialize.go index 209d8d6..6845321 100644 --- a/protocol/mcp/rpcRespInitialize.go +++ b/protocol/mcp/rpcRespInitialize.go @@ -9,11 +9,13 @@ type JsonRpcResponseInitializeResult struct { ProtocolVersion string `json:"protocolVersion"` Capabilities ServerCapabilities `json:"capabilities"` ServerInfo ServerInfo `json:"serverInfo"` + Instructions *string `json:"instructions"` } type ServerInfo struct { - Name string `json:"name"` - Version string `json:"version"` + Name string `json:"name"` + Version string `json:"version"` + Instructions *string `json:"instructions,omitempty"` } type ServerCapabilities struct { @@ -74,6 +76,14 @@ func ParseJsonRpcResponseInitialize(response *jsonrpc.JsonRpcResponse) (*JsonRpc } resp.ServerInfo.Version = version + // read instructions are present + instructions := protocol.GetOptionalStringField(serverInfo, "instructions") + if instructions != nil { + resp.Instructions = instructions + } else { + resp.Instructions = nil + } + // read capabilities capabilities, err := protocol.CheckIsObject(result, "capabilities") if err != nil { diff --git a/providers/sdk/definition.go b/providers/sdk/definition.go index f222095..d40c4ed 100644 --- a/providers/sdk/definition.go +++ b/providers/sdk/definition.go @@ -15,6 +15,7 @@ type SdkServerDefinition struct { serverVersion string debugLevel string debugFile string + instructions *string toolConfigurationData interface{} toolsInitFunction interface{} toolDefinitions []*SdkToolDefinition @@ -65,6 +66,10 @@ func (s *SdkServerDefinition) ServerVersion() string { return s.serverVersion } +func (s *SdkServerDefinition) Instructions() *string { + return s.instructions +} + func (s *SdkServerDefinition) SetDebugLevel(debugLevel string, debugFile string) { s.debugLevel = debugLevel s.debugFile = debugFile @@ -85,12 +90,16 @@ func (s *SdkServerDefinition) DebugFile() string { return s.debugFile } -func (s *SdkServerDefinition) WithTools(toolConfigurationDate interface{}, toolsInitFunction interface{}) types.ToolsDefinition { - s.toolConfigurationData = toolConfigurationDate +func (s *SdkServerDefinition) WithTools(toolConfigurationData interface{}, toolsInitFunction interface{}) types.ToolsDefinition { + s.toolConfigurationData = toolConfigurationData s.toolsInitFunction = toolsInitFunction return s } +func (s *SdkServerDefinition) SetInstructions(instructions string) { + s.instructions = &instructions +} + func (s *SdkServerDefinition) AddTool(toolName string, description string, toolHandler interface{}) error { s.toolDefinitions = append(s.toolDefinitions, &SdkToolDefinition{ ToolName: toolName, diff --git a/providers/sdk/lifecycle.go b/providers/sdk/lifecycle.go index ce14b25..d1522d0 100644 --- a/providers/sdk/lifecycle.go +++ b/providers/sdk/lifecycle.go @@ -12,11 +12,14 @@ func (s *SdkServerDefinition) serverInitFunction(ctx context.Context, logger typ var result interface{} var callErr, err error + // create a new context with the logger + goCtx := types.ContextWithLogger(ctx, logger) + // check if we have a tool configuration data if s.toolConfigurationData != nil { - result, callErr, err = callFunction(s.toolsInitFunction, ctx, s.toolConfigurationData) + result, callErr, err = callFunction(s.toolsInitFunction, goCtx, s.toolConfigurationData) } else { - result, callErr, err = callFunction(s.toolsInitFunction, ctx) + result, callErr, err = callFunction(s.toolsInitFunction, goCtx) } if err != nil { return err