feat: mcp support, openai update, refactor #486

Open · wants to merge 22 commits into base: main
Changes from 19 commits
122 changes: 99 additions & 23 deletions anthropic.go
@@ -5,11 +5,13 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/anthropics/anthropic-sdk-go"
 	"github.com/anthropics/anthropic-sdk-go/option"
 	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
-	openai "github.com/sashabaranov/go-openai"
+	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/openai/openai-go"
 )
 
 // AnthropicClientConfig represents the configuration for the Anthropic API client.
@@ -40,7 +42,7 @@ func NewAnthropicClientWithConfig(config AnthropicClientConfig) *AnthropicClient
 		option.WithHTTPClient(config.HTTPClient),
 	}
 	if config.BaseURL != "" {
-		opts = append(opts, option.WithBaseURL(config.BaseURL))
+		opts = append(opts, option.WithBaseURL(strings.TrimSuffix(config.BaseURL, "/v1")))
 	}
 	client := anthropic.NewClient(opts...)
 	return &AnthropicClient{
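
One detail worth noting in this hunk: OpenAI-compatible setups conventionally configure a base URL that ends in `/v1`, while the Anthropic SDK appends its own versioned paths, so the suffix is stripped before the URL reaches `option.WithBaseURL`. A minimal sketch of the behavior (the proxy URL is a made-up example):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	urls := []string{
		"https://llm-proxy.example.com/v1", // hypothetical OpenAI-style base URL
		"https://api.anthropic.com",        // no suffix, passed through unchanged
	}
	for _, u := range urls {
		// Mirrors the normalization applied before option.WithBaseURL.
		fmt.Println(strings.TrimSuffix(u, "/v1"))
	}
}
```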
@@ -54,11 +56,16 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	ctx context.Context,
 	request anthropic.MessageNewParams,
 ) *AnthropicChatCompletionStream {
-	return &AnthropicChatCompletionStream{
-		anthropicStreamReader: &anthropicStreamReader{
-			Stream: c.Messages.NewStreaming(ctx, request),
-		},
+	s := &AnthropicChatCompletionStream{
+		stream:  c.Messages.NewStreaming(ctx, request),
+		request: request,
 	}
+
+	s.factory = func() *ssestream.Stream[anthropic.MessageStreamEventUnion] {
+		return c.Messages.NewStreaming(ctx, s.request)
+	}
+
+	return s
 }

func makeAnthropicSystem(system string) []anthropic.TextBlockParam {
Expand All @@ -74,45 +81,114 @@ func makeAnthropicSystem(system string) []anthropic.TextBlockParam {

// AnthropicChatCompletionStream represents a stream for chat completion.
type AnthropicChatCompletionStream struct {
*anthropicStreamReader
}

type anthropicStreamReader struct {
*ssestream.Stream[anthropic.MessageStreamEventUnion]
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
request anthropic.MessageNewParams
factory func() *ssestream.Stream[anthropic.MessageStreamEventUnion]
message anthropic.Message
}

// Recv reads the next response from the stream.
func (r *anthropicStreamReader) Recv() (response openai.ChatCompletionStreamResponse, err error) {
if err := r.Err(); err != nil {
return openai.ChatCompletionStreamResponse{}, fmt.Errorf("anthropic: %w", err)
func (r *AnthropicChatCompletionStream) Recv() (response openai.ChatCompletionChunk, err error) {
if r.stream == nil {
r.stream = r.factory()
r.message = anthropic.Message{}
}
for r.Next() {
event := r.Current()

if r.stream.Next() {
event := r.stream.Current()
if err := r.message.Accumulate(event); err != nil {
return openai.ChatCompletionChunk{}, fmt.Errorf("anthropic: %w", err)
}
switch eventVariant := event.AsAny().(type) {
case anthropic.ContentBlockDeltaEvent:
switch deltaVariant := eventVariant.Delta.AsAny().(type) {
case anthropic.TextDelta:
return openai.ChatCompletionStreamResponse{
Choices: []openai.ChatCompletionStreamChoice{
return openai.ChatCompletionChunk{
Choices: []openai.ChatCompletionChunkChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Delta: openai.ChatCompletionChunkChoiceDelta{
Content: deltaVariant.Text,
Role: "assistant",
Role: roleAssistant,
},
},
},
}, nil
}
}
return openai.ChatCompletionChunk{}, errNoContent
}
if err := r.stream.Err(); err != nil {
return openai.ChatCompletionChunk{}, fmt.Errorf("anthropic: %w", err)
}
if err := r.stream.Close(); err != nil {
return openai.ChatCompletionChunk{}, fmt.Errorf("anthropic: %w", err)
}
r.request.Messages = append(r.request.Messages, r.message.ToParam())
r.stream = nil

toolResults := []anthropic.ContentBlockParamUnion{}
var sb strings.Builder
for _, block := range r.message.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
content, err := toolCall(variant.Name, []byte(variant.JSON.Input.Raw()))
toolResults = append(toolResults, anthropic.NewToolResultBlock(block.ID, content, err != nil))
_, _ = sb.WriteString("\n> Ran: `" + variant.Name + "`")
if err != nil {
_, _ = sb.WriteString(" (failed: `" + err.Error() + "`)")
}
_, _ = sb.WriteString("\n")
}
}
_, _ = sb.WriteString("\n")

if len(toolResults) == 0 {
return openai.ChatCompletionChunk{}, io.EOF
}
return openai.ChatCompletionStreamResponse{}, io.EOF

msg := anthropic.NewUserMessage(toolResults...)
r.request.Messages = append(r.request.Messages, msg)

return openai.ChatCompletionChunk{
Choices: []openai.ChatCompletionChunkChoice{
{
Index: 0,
Delta: openai.ChatCompletionChunkChoiceDelta{
Content: sb.String(),
Role: roleTool,
},
},
},
}, nil
}

// Close closes the stream.
func (r *anthropicStreamReader) Close() error {
if err := r.Stream.Close(); err != nil {
func (r *AnthropicChatCompletionStream) Close() error {
if r.stream == nil {
return nil
}
if err := r.stream.Close(); err != nil {
return fmt.Errorf("anthropic: %w", err)
}
r.stream = nil
return nil
}

func makeAnthropicMCPTools(mcps map[string][]mcp.Tool) []anthropic.ToolUnionParam {
var tools []anthropic.ToolUnionParam
for name, serverTools := range mcps {
for _, tool := range serverTools {
tools = append(tools, anthropic.ToolUnionParam{
OfTool: &anthropic.ToolParam{
InputSchema: anthropic.ToolInputSchemaParam{
Properties: tool.InputSchema.Properties,
},
Name: fmt.Sprintf("%s_%s", name, tool.Name),
Description: anthropic.String(tool.Description),
},
})
}
}
return tools
}
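
Taken together, the new `Recv` implements a transparent tool-call loop: each call consumes one SSE event and accumulates it into `r.message`; text deltas are translated into OpenAI-style chunks; when the upstream stream is exhausted, any `tool_use` blocks are executed via `toolCall`, their results are appended to `request.Messages` as a user message, and `r.stream` is reset to `nil` so the next `Recv` re-issues the request through `factory`. `io.EOF` only surfaces once a round produces no tool calls. A consumer would look roughly like this; a sketch, not code from this PR, which assumes it lives in the same package as the stream (it uses the unexported `errNoContent` sentinel), and whose model string, prompt, and `mcps` map are placeholders:

```go
// printCompletion drains the stream, skipping non-text events and letting
// Recv run tool rounds until the final io.EOF.
func printCompletion(ctx context.Context, client *AnthropicClient, mcps map[string][]mcp.Tool) error {
	stream := client.CreateChatCompletionStream(ctx, anthropic.MessageNewParams{
		Model:     "claude-3-5-sonnet-latest", // placeholder model name
		MaxTokens: 1024,
		Messages: []anthropic.MessageParam{
			anthropic.NewUserMessage(anthropic.NewTextBlock("list my open PRs")),
		},
		Tools: makeAnthropicMCPTools(mcps), // tools namespaced as "<server>_<tool>"
	})
	defer stream.Close() //nolint:errcheck

	for {
		chunk, err := stream.Recv()
		switch {
		case errors.Is(err, io.EOF):
			return nil // final round produced no tool calls
		case errors.Is(err, errNoContent):
			continue // non-text event (message_start, content_block_start, ...)
		case err != nil:
			return err
		}
		for _, choice := range chunk.Choices {
			fmt.Print(choice.Delta.Content)
		}
	}
}
```

The `fmt.Sprintf("%s_%s", name, tool.Name)` in `makeAnthropicMCPTools` prefixes each tool with its MCP server name, so tools from different servers cannot collide; presumably `toolCall` parses that same prefix to route an invocation back to the right server.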
32 changes: 16 additions & 16 deletions cache.go
@@ -11,7 +11,7 @@
 	"sync"
 	"time"
 
-	openai "github.com/sashabaranov/go-openai"
+	"github.com/openai/openai-go"
 )
 
 // CacheType represents the type of cache being used.
@@ -95,11 +95,11 @@
 }
 
 type convoCache struct {
-	cache *Cache[[]openai.ChatCompletionMessage]
+	cache *Cache[[]modsMessage]
 }
 
 func newCache(dir string) *convoCache {
-	cache, err := NewCache[[]openai.ChatCompletionMessage](dir, ConversationCache)
+	cache, err := NewCache[[]modsMessage](dir, ConversationCache)
 	if err != nil {
 		return nil
 	}
@@ -108,13 +108,13 @@
 	}
 }
 
-func (c *convoCache) read(id string, messages *[]openai.ChatCompletionMessage) error {
+func (c *convoCache) read(id string, messages *[]modsMessage) error {
 	return c.cache.Read(id, func(r io.Reader) error {
 		return decode(r, messages)
 	})
 }
 
-func (c *convoCache) write(id string, messages *[]openai.ChatCompletionMessage) error {
+func (c *convoCache) write(id string, messages *[]modsMessage) error {
 	return c.cache.Write(id, func(w io.Writer) error {
 		return encode(w, messages)
 	})
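
The conversation cache is now parameterized over an internal `modsMessage` type instead of the SDK's message struct, which keeps on-disk history stable across SDK migrations like this one. A rough round-trip sketch in the same package, assuming `modsMessage` carries `Role` and `Content` fields as the tests below suggest; the path and conversation ID are made up:

```go
// demoConvoCache writes a short conversation to disk and reads it back.
func demoConvoCache() error {
	cache := newCache(filepath.Join(os.TempDir(), "mods-convos"))
	msgs := []modsMessage{
		{Role: roleUser, Content: "first 4 natural numbers"},
		{Role: roleAssistant, Content: "1, 2, 3, 4"},
	}
	if err := cache.write("convo-id", &msgs); err != nil {
		return err
	}
	var restored []modsMessage
	return cache.read("convo-id", &restored)
}
```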
@@ -134,38 +134,38 @@
 
 func (c *cachedCompletionStream) Close() error { return nil }
 
-func (c *cachedCompletionStream) Recv() (openai.ChatCompletionStreamResponse, error) {
+func (c *cachedCompletionStream) Recv() (openai.ChatCompletionChunk, error) {
 	c.m.Lock()
 	defer c.m.Unlock()
 
 	if c.read == len(c.messages) {
-		return openai.ChatCompletionStreamResponse{}, io.EOF
+		return openai.ChatCompletionChunk{}, io.EOF
 	}
 
 	msg := c.messages[c.read]
 	prefix := ""
 
 	switch msg.Role {
-	case openai.ChatMessageRoleSystem:
+	case "system":
 		prefix += "\n**System**: "
-	case openai.ChatMessageRoleUser:
+	case "user":
 		prefix += "\n**Prompt**: "
-	case openai.ChatMessageRoleAssistant:
+	case "assistant":
 		prefix += "\n**Assistant**: "
-	case openai.ChatMessageRoleFunction:
+	case "function":
 		prefix += "\n**Function**: "
-	case openai.ChatMessageRoleTool:
+	case "tool":
 		prefix += "\n**Tool**: "
 	}
 
 	c.read++
 
-	return openai.ChatCompletionStreamResponse{
-		Choices: []openai.ChatCompletionStreamChoice{
+	return openai.ChatCompletionChunk{
+		Choices: []openai.ChatCompletionChunkChoice{
 			{
-				Delta: openai.ChatCompletionStreamChoiceDelta{
+				Delta: openai.ChatCompletionChunkChoiceDelta{
 					Content: prefix + msg.Content + "\n",
-					Role:    msg.Role,
+					Role:    string(msg.Role),
 				},
 			},
 		},

Check failure (GitHub Actions / lint, macos-latest) on line 149 in cache.go: string `system` has 6 occurrences, but such constant `roleSystem` already exists (goconst)
Check failure (GitHub Actions / lint, macos-latest) on line 151 in cache.go: string `user` has 6 occurrences, but such constant `roleUser` already exists (goconst)
Check failure (GitHub Actions / lint, macos-latest) on line 153 in cache.go: string `assistant` has 4 occurrences, but such constant `roleAssistant` already exists (goconst)
Check failure (GitHub Actions / lint, macos-latest) on line 157 in cache.go: string `tool` has 3 occurrences, but such constant `roleTool` already exists (goconst)
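
The goconst failures above say the new string literals duplicate constants that already exist on this branch, so the likely fix is simply to use them. A sketch of the rewritten switch, with constant names taken from the lint output and from cache_test.go below:

```go
// Role constants already exist in the branch, per the goconst messages.
// Using them silences the lint and keeps a single source of truth.
switch msg.Role {
case roleSystem:
	prefix += "\n**System**: "
case roleUser:
	prefix += "\n**Prompt**: "
case roleAssistant:
	prefix += "\n**Assistant**: "
case roleFunction:
	prefix += "\n**Function**: "
case roleTool:
	prefix += "\n**Tool**: "
}
```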
28 changes: 14 additions & 14 deletions cache_test.go
@@ -11,7 +11,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/sashabaranov/go-openai"
+	"github.com/openai/openai-go"
 	"github.com/stretchr/testify/require"
 )

@@ -20,33 +20,33 @@ var update = flag.Bool("update", false, "update .golden files")
 func TestCache(t *testing.T) {
 	t.Run("read non-existent", func(t *testing.T) {
 		cache := newCache(t.TempDir())
-		err := cache.read("super-fake", &[]openai.ChatCompletionMessage{})
+		err := cache.read("super-fake", &[]modsMessage{})
 		require.ErrorIs(t, err, os.ErrNotExist)
 	})
 
 	t.Run("write", func(t *testing.T) {
 		cache := newCache(t.TempDir())
-		messages := []openai.ChatCompletionMessage{
+		messages := []modsMessage{
 			{
-				Role:    openai.ChatMessageRoleUser,
+				Role:    roleUser,
 				Content: "first 4 natural numbers",
 			},
 			{
-				Role:    openai.ChatMessageRoleAssistant,
+				Role:    roleAssistant,
 				Content: "1, 2, 3, 4",
 			},
 		}
 		require.NoError(t, cache.write("fake", &messages))
 
-		result := []openai.ChatCompletionMessage{}
+		result := []modsMessage{}
 		require.NoError(t, cache.read("fake", &result))
 
 		require.ElementsMatch(t, messages, result)
 	})
 
 	t.Run("delete", func(t *testing.T) {
 		cache := newCache(t.TempDir())
-		cache.write("fake", &[]openai.ChatCompletionMessage{})
+		cache.write("fake", &[]modsMessage{})
 		require.NoError(t, cache.delete("fake"))
 		require.ErrorIs(t, cache.read("fake", nil), os.ErrNotExist)
 	})
@@ -71,32 +71,32 @@ func TestCachedCompletionStream(t *testing.T) {
 	stream := cachedCompletionStream{
 		messages: []openai.ChatCompletionMessage{
 			{
-				Role:    openai.ChatMessageRoleSystem,
+				Role:    roleSystem,
 				Content: "you are a medieval king",
 			},
 			{
-				Role:    openai.ChatMessageRoleUser,
+				Role:    roleUser,
 				Content: "first 4 natural numbers",
 			},
 			{
-				Role:    openai.ChatMessageRoleAssistant,
+				Role:    roleAssistant,
 				Content: "1, 2, 3, 4",
 			},
 
 			{
-				Role:    openai.ChatMessageRoleUser,
+				Role:    roleUser,
 				Content: "as a json array",
 			},
 			{
-				Role:    openai.ChatMessageRoleAssistant,
+				Role:    roleAssistant,
 				Content: "[ 1, 2, 3, 4 ]",
 			},
 			{
-				Role:    openai.ChatMessageRoleAssistant,
+				Role:    roleAssistant,
 				Content: "something from an assistant",
 			},
 			{
-				Role:    openai.ChatMessageRoleFunction,
+				Role:    roleFunction,
 				Content: "something from a function",
 			},
 		},