conversation/anthropic/anthropic.go (4 changes: 2 additions & 2 deletions)

@@ -61,8 +61,8 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error

 	a.LLM.Model = llm

-	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.LLM.Model)
+	if m.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, a.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}
conversation/aws/bedrock/bedrock.go (9 changes: 5 additions & 4 deletions)

@@ -42,8 +42,9 @@ type AWSBedrockMetadata struct {
 	AccessKey    string `json:"accessKey"`
 	SecretKey    string `json:"secretKey"`
 	SessionToken string `json:"sessionToken"`
-	Model        string `json:"model"`
-	CacheTTL     string `json:"cacheTTL"`
+
+	Model            string `json:"model"`
+	ResponseCacheTTL string `json:"responseCacheTTL" mapstructure:"responseCacheTTL" mapstructurealiases:"cacheTTL"`
 }

@@ -83,8 +84,8 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error

 	b.LLM.Model = llm

-	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.LLM.Model)
+	if m.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, b.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}
conversation/aws/bedrock/metadata.yaml (5 changes: 4 additions & 1 deletion)

@@ -21,7 +21,10 @@ metadata:
   - name: model
     required: false
     description: |
-      The LLM to use. Defaults to Bedrock's default provider model from Amazon.
+      The model identifier or inference profile ARN to use. Defaults to Bedrock's default provider model from Amazon.
+      You can specify either:
+      - A model ID (e.g., "amazon.titan-text-express-v1") that supports on-demand throughput
+      - An inference profile ARN for models that require it (found in the AWS Bedrock console under "Cross-Region Inference")
     type: string
     example: 'amazon.titan-text-express-v1'
   - name: cacheTTL
conversation/converse.go (36 changes: 21 additions & 15 deletions)

@@ -17,9 +17,9 @@ package conversation
 import (
 	"context"
 	"io"
+	"time"

 	"github.com/tmc/langchaingo/llms"
-	"google.golang.org/protobuf/types/known/anypb"

 	"github.com/dapr/components-contrib/metadata"
 )
@@ -36,23 +36,29 @@ type Conversation interface {

 type Request struct {
 	// Message can be user input prompt/instructions and/or tool call responses.
-	Message             *[]llms.MessageContent
-	Tools               *[]llms.Tool
-	ToolChoice          *string
-	Parameters          map[string]*anypb.Any `json:"parameters"`
-	ConversationContext string                `json:"conversationContext"`
-	Temperature         float64               `json:"temperature"`
-
-	// from metadata
-	Key       string   `json:"key"`
-	Model     string   `json:"model"`
-	Endpoints []string `json:"endpoints"`
-	Policy    string   `json:"loadBalancingPolicy"`
Comment on lines -42 to -50 (Contributor Author):

The "from metadata" fields were bubbling the Metadata field up from the Conversations API (see https://github.com/dapr/dapr/blob/master/dapr/proto/runtime/v1/ai.proto#L43), which allowed all of these fields to be passed through the request. This is an incorrect use of metadata: components already expose their own metadata field, so there is no need to bubble this information up from the API; contrib already handles it.

As a result, we can remove the following fields: Key, Endpoints, and Policy.

I also removed Parameters and ConversationContext, since these were surfaced but never actually used. All of this traces back to the original implementation, which I'm now cleaning up.

+	Message     *[]llms.MessageContent
+	Tools       *[]llms.Tool
+	ToolChoice  *string
+	Temperature float64 `json:"temperature"`
+
+	// Metadata fields that are separate from the actual component metadata fields
+	// that get passed to the LLM through the conversation.
+	// https://github.com/openai/openai-go/blob/main/chatcompletion.go#L3010
+	Metadata map[string]string `json:"metadata"`
+
+	ResponseFormatAsJSONSchema map[string]any `json:"responseFormatAsJsonSchema"`
+	PromptCacheRetention       time.Duration  `json:"promptCacheRetention"`
+	Model                      *string        `json:"model"`
+
+	// LlmTimeout specifies the max duration to wait for the LLM to complete a conversation request.
+	// Langchaingo timeout is respected, so if this is set, it will override the provider's timeout.
+	LlmTimeout time.Duration `json:"llmTimeout"`
 }

 type Response struct {
-	ConversationContext string   `json:"conversationContext"`
-	Outputs             []Result `json:"outputs"`
+	Outputs []Result `json:"outputs"`
+	Usage   *Usage   `json:"usage,omitempty"`
+	Model   string   `json:"model"`
 }

 type Result struct {
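
With the metadata-derived fields gone, Request carries only the message content, tool definitions, and per-request tuning knobs. The snippet below is a minimal sketch of how a caller might populate the slimmed-down struct, assuming the field set shown in this diff; the prompt text, model name, metadata values, and timeout are illustrative placeholders, not values taken from the PR.

package main

import (
	"fmt"
	"time"

	"github.com/tmc/langchaingo/llms"

	"github.com/dapr/components-contrib/conversation"
)

// buildExampleRequest populates the fields that remain on Request after this change.
func buildExampleRequest() *conversation.Request {
	// A single user message; llms.TextParts wraps plain text as message content.
	msgs := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeHuman, "Summarize the latest release notes."),
	}

	model := "example-model" // hypothetical per-request model override
	return &conversation.Request{
		Message:     &msgs,
		Temperature: 0.2,
		Model:       &model,
		// Request-scoped metadata forwarded to the LLM, separate from component metadata.
		Metadata:   map[string]string{"purpose": "docs-example"},
		LlmTimeout: 30 * time.Second,
	}
}

func main() {
	req := buildExampleRequest()
	fmt.Printf("temperature=%v timeout=%s\n", req.Temperature, req.LlmTimeout)
}
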
conversation/deepseek/deepseek.go (7 changes: 1 addition & 6 deletions)

@@ -63,12 +63,7 @@ func (d *Deepseek) Init(ctx context.Context, meta conversation.Metadata) error {
 		md.Endpoint = defaultEndpoint
 	}

-	options := []openai.Option{
-		openai.WithModel(model),
-		openai.WithToken(md.Key),
-		openai.WithBaseURL(md.Endpoint),
-	}
-
+	options := conversation.BuildOpenAIClientOptions(model, md.Key, md.Endpoint)
 	llm, err := openai.New(options...)
 	if err != nil {
 		return err
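
Deepseek here, and Google AI and Hugging Face further down, all swap their inline option slices for a shared conversation.BuildOpenAIClientOptions helper. The helper's body is not part of this diff; the sketch below only illustrates what it would need to cover, judging from the three options it replaces (model, token, base URL), and should be read as an assumption rather than the actual implementation.

package conversation

import "github.com/tmc/langchaingo/llms/openai"

// BuildOpenAIClientOptions, as sketched here, bundles the options that each
// component previously assembled inline. The real helper may set additional
// options beyond these three.
func BuildOpenAIClientOptions(model, key, endpoint string) []openai.Option {
	opts := []openai.Option{
		openai.WithModel(model),
		openai.WithToken(key),
	}
	if endpoint != "" {
		// Only override the base URL when an endpoint is supplied.
		opts = append(opts, openai.WithBaseURL(endpoint))
	}
	return opts
}
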
conversation/deepseek/metadata.go (9 changes: 5 additions & 4 deletions)

@@ -17,9 +17,10 @@ limitations under the License.

 package deepseek

+import "github.com/dapr/components-contrib/conversation"
+
 type DeepseekMetadata struct {
-	Key       string `json:"key"`
-	MaxTokens int    `json:"maxTokens"`
-	Model     string `json:"model"`
-	Endpoint  string `json:"endpoint"`
+	conversation.LangchainMetadata `json:",inline" mapstructure:",squash"`
+	Key                            string `json:"key"`
+	MaxTokens                      int    `json:"maxTokens"`
 }
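
DeepseekMetadata now embeds conversation.LangchainMetadata in place of its own Model and Endpoint fields. The embedded struct's definition is not shown in this diff; based on the fields dropped here and the metadata names used by the other components in this PR, it presumably looks roughly like the sketch below, which is an assumption rather than the actual declaration.

package conversation

// LangchainMetadata as assumed from this PR's usage: shared fields that
// langchaingo-backed components previously declared one by one. The real
// struct may contain more, or differently tagged, fields.
type LangchainMetadata struct {
	Model            string `json:"model"`
	Endpoint         string `json:"endpoint"`
	ResponseCacheTTL string `json:"responseCacheTTL" mapstructure:"responseCacheTTL" mapstructurealiases:"cacheTTL"`
}
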
conversation/echo/echo.go (47 changes: 41 additions & 6 deletions)

@@ -51,7 +51,9 @@
 		return err
 	}

-	e.model = r.Model
+	if r.Model != nil {
+		e.model = *r.Model
+	}

 	return nil
 }

@@ -62,12 +64,23 @@
 	return
 }

+// approximateTokensFromWords estimates the number of tokens based on word count.
+// ref: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+func approximateTokensFromWords(text string) int64 {
+	if text == "" {
+		return 0
+	}
+
+	// split on whitespace to count words
+	wordCount := len(strings.Fields(text))
+	return int64(wordCount)
+}
+
 // Converse returns one output per input message.
 func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
 	if r == nil || r.Message == nil {
 		return &conversation.Response{
-			ConversationContext: r.ConversationContext,
-			Outputs:             []conversation.Result{},
+			Outputs: []conversation.Result{},
 		}, nil
 	}

@@ -121,24 +134,33 @@

 	// iterate over each message in the request to echo back the content in the response. We respond with the acummulated content of the message parts and tool responses
 	contentFromMessaged := make([]string, 0, len(*r.Message))
+	var promptTextLength int
 	for _, message := range *r.Message {
 		for _, part := range message.Parts {
 			switch p := part.(type) {
 			case llms.TextContent:
 				// append to slice that we'll join later with new line separators
 				contentFromMessaged = append(contentFromMessaged, p.Text)
+				promptTextLength += len(p.Text)
 			case llms.ToolCall:
 				// in case we added explicit tool calls on the request like on multi-turn conversations. We still append tool calls for each tool defined in the request.
 				toolCalls = append(toolCalls, p)
 			case llms.ToolCallResponse:
 				// show tool responses on the request like on multi-turn conversations
 				contentFromMessaged = append(contentFromMessaged, fmt.Sprintf("Tool Response for tool ID '%s' with name '%s': %s", p.ToolCallID, p.Name, p.Content))
+				promptTextLength += len(p.Content)
 			default:
 				return nil, fmt.Errorf("found invalid content type as input for %v", p)
 			}
 		}
 	}

+	responseContent := strings.Join(contentFromMessaged, "\n")
+
+	promptTokens := approximateTokensFromWords(promptTextLength)

Check failure on line 160 in conversation/echo/echo.go (GitHub Actions / Build linux_amd64 binaries): cannot use promptTextLength (variable of type int) as string value in argument to approximateTokensFromWords

+	completionTokens := approximateTokensFromWords(len(responseContent))

Check failure on line 161 in conversation/echo/echo.go (GitHub Actions / Build linux_amd64 binaries): cannot use len(responseContent) (value of type int) as string value in argument to approximateTokensFromWords

+	totalTokens := promptTokens + completionTokens
+
 	stopReason := "stop"
 	if len(toolCalls) > 0 {
 		stopReason = "tool_calls"
@@ -148,7 +170,7 @@
 		FinishReason: stopReason,
 		Index:        0,
 		Message: conversation.Message{
-			Content: strings.Join(contentFromMessaged, "\n"),
+			Content: responseContent,
 		},
 	}

@@ -161,9 +183,22 @@
 		Choices: []conversation.Choice{choice},
 	}

+	// allow per request model overrides
+	modelName := e.model
+	if r.Model != nil && *r.Model != "" {
+		modelName = *r.Model
+	}
+
+	usage := &conversation.Usage{
+		CompletionTokens: completionTokens,
+		PromptTokens:     promptTokens,
+		TotalTokens:      totalTokens,
+	}
+
 	res = &conversation.Response{
-		ConversationContext: r.ConversationContext,
-		Outputs:             []conversation.Result{output},
+		Outputs: []conversation.Result{output},
+		Usage:   usage,
+		Model:   modelName,
 	}

 	return res, nil
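
The two CI failures above stem from passing integer lengths to approximateTokensFromWords, which takes a string. One way to resolve the mismatch, sketched below, is to accumulate the prompt text itself rather than its length and feed the joined string to the helper; this assumes the file's existing strings and llms imports and its package name, and is illustrative only, not necessarily the fix the PR ultimately adopted.

package echo // assumed package name for conversation/echo/echo.go

import (
	"strings"

	"github.com/tmc/langchaingo/llms"
)

// promptTokensFromMessages gathers the text and tool-response content from the
// request messages and estimates tokens from the joined string, matching the
// string signature of approximateTokensFromWords defined in this PR.
func promptTokensFromMessages(messages []llms.MessageContent) int64 {
	parts := make([]string, 0, len(messages))
	for _, message := range messages {
		for _, part := range message.Parts {
			switch p := part.(type) {
			case llms.TextContent:
				parts = append(parts, p.Text)
			case llms.ToolCallResponse:
				parts = append(parts, p.Content)
			}
		}
	}
	return approximateTokensFromWords(strings.Join(parts, " "))
}

Converse could then call promptTokensFromMessages(*r.Message) in place of the failing approximateTokensFromWords(promptTextLength) line, and pass responseContent directly for the completion side.
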
conversation/googleai/googleai.go (14 changes: 6 additions & 8 deletions)

@@ -50,13 +50,11 @@ func (g *GoogleAI) Init(ctx context.Context, meta conversation.Metadata) error {

 	// Resolve model via central helper (uses metadata, then env var, then default)
 	model := conversation.GetGoogleAIModel(md.Model)
+	key, _ := meta.GetProperty("key")

-	opts := []openai.Option{
-		openai.WithModel(model),
-		openai.WithToken(md.Key),
-		// endpoint from https://ai.google.dev/gemini-api/docs/openai
-		openai.WithBaseURL("https://generativelanguage.googleapis.com/v1beta/openai/"),
-	}
+	// endpoint from https://ai.google.dev/gemini-api/docs/openai
+	const endpoint = "https://generativelanguage.googleapis.com/v1beta/openai/"
+	opts := conversation.BuildOpenAIClientOptions(model, key, endpoint)
 	llm, err := openai.New(
 		opts...,
 	)

@@ -66,8 +64,8 @@

 	g.LLM.Model = llm

-	if md.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, g.LLM.Model)
+	if md.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, md.ResponseCacheTTL, g.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}
conversation/huggingface/huggingface.go (12 changes: 3 additions & 9 deletions)

@@ -54,29 +54,23 @@ func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) error

 	// Resolve model via central helper (uses metadata, then env var, then default)
 	model := conversation.GetHuggingFaceModel(m.Model)
-
 	endpoint := strings.Replace(defaultEndpoint, "{{model}}", model, 1)
 	if m.Endpoint != "" {
 		endpoint = m.Endpoint
 	}

 	// Create options for OpenAI client using HuggingFace's OpenAI-compatible API
 	// This is a workaround for issues with the native HuggingFace langchaingo implementation
-	options := []openai.Option{
-		openai.WithModel(model),
-		openai.WithToken(m.Key),
-		openai.WithBaseURL(endpoint),
-	}
-
+	options := conversation.BuildOpenAIClientOptions(model, m.Key, endpoint)
 	llm, err := openai.New(options...)
 	if err != nil {
 		return err
 	}

 	h.LLM.Model = llm

-	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, h.LLM.Model)
+	if m.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, h.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}