Skip to content

Add logger to context for serverInit #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions modelcontextprotocol/mcpserver/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,14 @@ func (m *McpServer) handleIncomingMessage(

func (m *McpServer) EventMcpRequestInitialize(params *mcp.JsonRpcRequestInitializeParams, reqId *jsonrpc.JsonRpcRequestId) {
// store client information
if params.ProtocolVersion != mcp.ProtocolVersion {
if params.ProtocolVersion < mcp.ProtocolVersion {
m.logger.Error("protocol version mismatch", types.LogArg{
"expected": mcp.ProtocolVersion,
"received": params.ProtocolVersion,
})
m.jsonRpcTransport.SendError(jsonrpc.RpcInvalidRequest, "protocol version mismatch", reqId)

s := fmt.Sprintf("protocol version mismatch: expected %s, received %s", mcp.ProtocolVersion, params.ProtocolVersion)
m.jsonRpcTransport.SendError(jsonrpc.RpcInvalidRequest, "protocol version mismatch: "+s, reqId)
return
}
// we store the client information
Expand All @@ -165,7 +167,8 @@ func (m *McpServer) EventMcpRequestInitialize(params *mcp.JsonRpcRequestInitiali
ListChanged: jsonrpc.BoolPtr(true),
},
},
ServerInfo: mcp.ServerInfo{Name: m.serverName, Version: m.serverVersion},
ServerInfo: mcp.ServerInfo{Name: m.serverName, Version: m.serverVersion},
Instructions: m.instructions,
}
m.jsonRpcTransport.SendJsonRpcResponse(&response, reqId)
}
Expand Down
6 changes: 6 additions & 0 deletions modelcontextprotocol/mcpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type McpServer struct {
logger types.Logger
serverName string
serverVersion string
instructions *string
handler modelcontextprotocol.McpServerEventHandler
// used by protocol
clientName string
Expand Down Expand Up @@ -55,12 +56,17 @@ func NewMcpSdkServer(serverDefinition types.McpSdkServerDefinition, debug bool)
logger: logger,
serverName: sdkServerDefinition.ServerName(),
serverVersion: sdkServerDefinition.ServerVersion(),
instructions: sdkServerDefinition.Instructions(),
handler: mcpServerNotifications,
lastRequestId: 0,
}, nil

}

func (mcp *McpServer) Logger() types.Logger {
return mcp.logger
}

func (mcp *McpServer) StdioTransport() types.Transport {
// we create the transport
transport := transport.NewStdioTransport(
Expand Down
1 change: 1 addition & 0 deletions protocol/mcp/rpcReqInitialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type JsonRpcRequestInitializeParams struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ClientCapabilities `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
Instructions *string `json:"instructions,omitempty"`
}

type ClientCapabilities struct {
Expand Down
14 changes: 12 additions & 2 deletions protocol/mcp/rpcRespInitialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ type JsonRpcResponseInitializeResult struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
Instructions *string `json:"instructions"`
}

type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Name string `json:"name"`
Version string `json:"version"`
Instructions *string `json:"instructions,omitempty"`
}

type ServerCapabilities struct {
Expand Down Expand Up @@ -74,6 +76,14 @@ func ParseJsonRpcResponseInitialize(response *jsonrpc.JsonRpcResponse) (*JsonRpc
}
resp.ServerInfo.Version = version

// read instructions are present
instructions := protocol.GetOptionalStringField(serverInfo, "instructions")
if instructions != nil {
resp.Instructions = instructions
} else {
resp.Instructions = nil
}

// read capabilities
capabilities, err := protocol.CheckIsObject(result, "capabilities")
if err != nil {
Expand Down
13 changes: 11 additions & 2 deletions providers/sdk/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type SdkServerDefinition struct {
serverVersion string
debugLevel string
debugFile string
instructions *string
toolConfigurationData interface{}
toolsInitFunction interface{}
toolDefinitions []*SdkToolDefinition
Expand Down Expand Up @@ -65,6 +66,10 @@ func (s *SdkServerDefinition) ServerVersion() string {
return s.serverVersion
}

func (s *SdkServerDefinition) Instructions() *string {
return s.instructions
}

func (s *SdkServerDefinition) SetDebugLevel(debugLevel string, debugFile string) {
s.debugLevel = debugLevel
s.debugFile = debugFile
Expand All @@ -85,12 +90,16 @@ func (s *SdkServerDefinition) DebugFile() string {
return s.debugFile
}

func (s *SdkServerDefinition) WithTools(toolConfigurationDate interface{}, toolsInitFunction interface{}) types.ToolsDefinition {
s.toolConfigurationData = toolConfigurationDate
func (s *SdkServerDefinition) WithTools(toolConfigurationData interface{}, toolsInitFunction interface{}) types.ToolsDefinition {
s.toolConfigurationData = toolConfigurationData
s.toolsInitFunction = toolsInitFunction
return s
}

func (s *SdkServerDefinition) SetInstructions(instructions string) {
s.instructions = &instructions
}

func (s *SdkServerDefinition) AddTool(toolName string, description string, toolHandler interface{}) error {
s.toolDefinitions = append(s.toolDefinitions, &SdkToolDefinition{
ToolName: toolName,
Expand Down
7 changes: 5 additions & 2 deletions providers/sdk/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ func (s *SdkServerDefinition) serverInitFunction(ctx context.Context, logger typ
var result interface{}
var callErr, err error

// create a new context with the logger
goCtx := types.ContextWithLogger(ctx, logger)

// check if we have a tool configuration data
if s.toolConfigurationData != nil {
result, callErr, err = callFunction(s.toolsInitFunction, ctx, s.toolConfigurationData)
result, callErr, err = callFunction(s.toolsInitFunction, goCtx, s.toolConfigurationData)
} else {
result, callErr, err = callFunction(s.toolsInitFunction, ctx)
result, callErr, err = callFunction(s.toolsInitFunction, goCtx)
}
if err != nil {
return err
Expand Down