4 changes: 2 additions & 2 deletions conversation/anthropic/anthropic.go
@@ -61,8 +61,8 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error

 	a.LLM.Model = llm

-	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.LLM.Model)
+	if m.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, a.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}
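The CacheModel to CacheResponses rename (here and in the other providers below) makes it explicit that what gets cached for the TTL is the LLM's responses, not the model client itself. The helper's implementation is not part of this diff; the following is only an illustrative sketch of the idea, wrapping a langchaingo model with a naive TTL response cache. The type names and the cache-key scheme are assumptions, not the PR's code.

```go
package cachesketch

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/tmc/langchaingo/llms"
)

// cachedEntry pairs a generated response with its expiry time.
type cachedEntry struct {
	resp    *llms.ContentResponse
	expires time.Time
}

// responseCache memoizes GenerateContent results for a fixed TTL.
// Illustration only: conversation.CacheResponses may work differently.
type responseCache struct {
	inner llms.Model
	ttl   time.Duration

	mu      sync.Mutex
	entries map[string]cachedEntry
}

func newResponseCache(inner llms.Model, ttl time.Duration) *responseCache {
	return &responseCache{inner: inner, ttl: ttl, entries: map[string]cachedEntry{}}
}

func (c *responseCache) GenerateContent(ctx context.Context, msgs []llms.MessageContent, opts ...llms.CallOption) (*llms.ContentResponse, error) {
	key := fmt.Sprintf("%+v", msgs) // naive key derived from the prompt; real code would hash it

	c.mu.Lock()
	if e, ok := c.entries[key]; ok && time.Now().Before(e.expires) {
		c.mu.Unlock()
		return e.resp, nil // cache hit within the TTL: skip the provider call
	}
	c.mu.Unlock()

	resp, err := c.inner.GenerateContent(ctx, msgs, opts...)
	if err != nil {
		return nil, err
	}

	c.mu.Lock()
	c.entries[key] = cachedEntry{resp: resp, expires: time.Now().Add(c.ttl)}
	c.mu.Unlock()
	return resp, nil
}
```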
9 changes: 5 additions & 4 deletions conversation/aws/bedrock/bedrock.go
@@ -42,8 +42,9 @@ type AWSBedrockMetadata struct {
 	AccessKey    string `json:"accessKey"`
 	SecretKey    string `json:"secretKey"`
 	SessionToken string `json:"sessionToken"`
-	Model        string `json:"model"`
-	CacheTTL     string `json:"cacheTTL"`
+
+	Model            string `json:"model"`
+	ResponseCacheTTL string `json:"responseCacheTTL" mapstructure:"responseCacheTTL" mapstructurealiases:"cacheTTL"`
 }

 func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
@@ -83,8 +84,8 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error

 	b.LLM.Model = llm

-	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.LLM.Model)
+	if m.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, b.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}
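The mapstructurealiases:"cacheTTL" tag is what keeps existing component specs working after the rename. A small sketch of the expected decode behaviour, assuming the shared DecodeMetadata helper in components-contrib resolves alias tags the way it does for other components; the model ID and TTL value here are made up.

```go
package main

import (
	"fmt"

	"github.com/dapr/components-contrib/metadata"
)

// Trimmed-down copy of the AWSBedrockMetadata fields touched in this diff.
type bedrockMeta struct {
	Model            string `json:"model"`
	ResponseCacheTTL string `json:"responseCacheTTL" mapstructure:"responseCacheTTL" mapstructurealiases:"cacheTTL"`
}

func main() {
	// A legacy component spec that still uses the old cacheTTL key.
	legacy := map[string]string{
		"model":    "anthropic.claude-3-5-sonnet", // hypothetical model ID
		"cacheTTL": "10m",
	}

	var m bedrockMeta
	if err := metadata.DecodeMetadata(legacy, &m); err != nil {
		panic(err)
	}
	// Expected: the alias maps the old key onto the new field.
	fmt.Println(m.ResponseCacheTTL) // "10m"
}
```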
36 changes: 21 additions & 15 deletions conversation/converse.go
@@ -17,9 +17,9 @@ package conversation
 import (
 	"context"
 	"io"
+	"time"

 	"github.com/tmc/langchaingo/llms"
-	"google.golang.org/protobuf/types/known/anypb"

 	"github.com/dapr/components-contrib/metadata"
 )
@@ -36,23 +36,29 @@ type Conversation interface {

 type Request struct {
 	// Message can be user input prompt/instructions and/or tool call responses.
-	Message             *[]llms.MessageContent
-	Tools               *[]llms.Tool
-	ToolChoice          *string
-	Parameters          map[string]*anypb.Any `json:"parameters"`
-	ConversationContext string                `json:"conversationContext"`
-	Temperature         float64               `json:"temperature"`
-
-	// from metadata
-	Key       string   `json:"key"`
-	Model     string   `json:"model"`
-	Endpoints []string `json:"endpoints"`
-	Policy    string   `json:"loadBalancingPolicy"`
Comment on lines -42 to -50 (Contributor Author):

The "from metadata" fields were bubbling up the Metadata field from the Conversations API (see https://github.com/dapr/dapr/blob/master/dapr/proto/runtime/v1/ai.proto#L43), which allowed all of these fields to be passed through on every request. That is an incorrect use of metadata: components already expose their own metadata fields, so there is no need to bubble this information up from the API; contrib already handles it.

As a result, we can remove the Key, Endpoints, and Policy fields.

I also removed Parameters and ConversationContext, since they were surfaced but never actually used. All of this traces back to the original implementation, which I'm now cleaning up. A minimal sketch of building the slimmed-down Request follows this file's diff.

+	Message     *[]llms.MessageContent
+	Tools       *[]llms.Tool
+	ToolChoice  *string
+	Temperature float64 `json:"temperature"`
+
+	// Metadata fields that are separate from the actual component metadata fields
+	// that get passed to the LLM through the conversation.
+	// https://github.com/openai/openai-go/blob/main/chatcompletion.go#L3010
+	Metadata map[string]string `json:"metadata"`
+
+	ResponseFormatAsJsonSchema map[string]any `json:"responseFormatAsJsonSchema"`
+	PromptCacheRetention       time.Duration  `json:"promptCacheRetention"`
+	Model                      *string        `json:"model"`
+
+	// LlmTimeout specifies the max duration to wait for the LLM to complete a conversation request.
+	// Langchaingo timeout is respected, so if this is set, it will override the provider's timeout.
+	LlmTimeout time.Duration `json:"llmTimeout"`
 }

 type Response struct {
-	ConversationContext string   `json:"conversationContext"`
-	Outputs             []Result `json:"outputs"`
+	Outputs []Result `json:"outputs"`
+	Usage   *Usage   `json:"usage,omitempty"`
+	Model   string   `json:"model"`
 }

 type Result struct {
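As referenced in the review comment above, here is a minimal sketch of constructing the slimmed-down Request against the new field set. The message text, model name, metadata values, and timeout are hypothetical; only the field names come from this diff.

```go
package example

import (
	"time"

	"github.com/tmc/langchaingo/llms"

	"github.com/dapr/components-contrib/conversation"
)

func buildRequest() *conversation.Request {
	msgs := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeHuman, "Summarize the release notes."),
	}
	model := "gpt-4o-mini" // optional per-request model override

	return &conversation.Request{
		Message:     &msgs,
		Temperature: 0.2,
		// Request-scoped metadata forwarded to the LLM, distinct from component metadata.
		Metadata:   map[string]string{"user": "example-user"},
		Model:      &model,
		LlmTimeout: 30 * time.Second,
	}
}
```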
7 changes: 1 addition & 6 deletions conversation/deepseek/deepseek.go
@@ -63,12 +63,7 @@ func (d *Deepseek) Init(ctx context.Context, meta conversation.Metadata) error {
 		md.Endpoint = defaultEndpoint
 	}

-	options := []openai.Option{
-		openai.WithModel(model),
-		openai.WithToken(md.Key),
-		openai.WithBaseURL(md.Endpoint),
-	}
-
+	options := conversation.BuildOpenAIClientOptions(model, md.Key, md.Endpoint)
 	llm, err := openai.New(options...)
 	if err != nil {
 		return err
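BuildOpenAIClientOptions itself is not shown in this diff. Judging from the call sites it replaces here and in the other OpenAI-compatible providers, a plausible shape for the helper is the sketch below; this is an assumption, not the PR's actual implementation.

```go
package conversation

import "github.com/tmc/langchaingo/llms/openai"

// BuildOpenAIClientOptions bundles the options that each OpenAI-compatible
// component previously assembled by hand: model, API token, and base URL.
// Sketch only; the real helper may set additional options.
func BuildOpenAIClientOptions(model, key, endpoint string) []openai.Option {
	return []openai.Option{
		openai.WithModel(model),
		openai.WithToken(key),
		openai.WithBaseURL(endpoint),
	}
}
```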
9 changes: 5 additions & 4 deletions conversation/deepseek/metadata.go
@@ -17,9 +17,10 @@ limitations under the License.

 package deepseek

+import "github.com/dapr/components-contrib/conversation"
+
 type DeepseekMetadata struct {
-	Key       string `json:"key"`
-	MaxTokens int    `json:"maxTokens"`
-	Model     string `json:"model"`
-	Endpoint  string `json:"endpoint"`
+	conversation.LangchainMetadata `json:",inline" mapstructure:",squash"`
+	Key                            string `json:"key"`
+	MaxTokens                      int    `json:"maxTokens"`
 }
49 changes: 43 additions & 6 deletions conversation/echo/echo.go
@@ -51,7 +51,9 @@ func (e *Echo) Init(ctx context.Context, meta conversation.Metadata) error {
 		return err
 	}

-	e.model = r.Model
+	if r.Model != nil {
+		e.model = *r.Model
+	}

 	return nil
 }
@@ -62,12 +64,25 @@ func (e *Echo) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
 	return
 }

+// approximateTokensFromLength estimates the number of tokens based on text length.
+// This uses a rough approximation: ~1 token per 4 chars.
+// Reasoning behind 4 char per token:
+// - LLM tokens are subword units, not individual characters
+// - Text averages ~4-5 chars per token
+// ref: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+// We round up division to avoid undercounting tokens.
+func approximateTokensFromLength(textLength int) int64 {
+	if textLength == 0 {
+		return 0
+	}
+	return int64((textLength + 3) / 4)
+}
+
 // Converse returns one output per input message.
 func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
 	if r == nil || r.Message == nil {
 		return &conversation.Response{
-			ConversationContext: r.ConversationContext,
-			Outputs:             []conversation.Result{},
+			Outputs: []conversation.Result{},
 		}, nil
 	}
@@ -121,24 +136,33 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {

 	// iterate over each message in the request to echo back the content in the response. We respond with the acummulated content of the message parts and tool responses
 	contentFromMessaged := make([]string, 0, len(*r.Message))
+	var promptTextLength int
 	for _, message := range *r.Message {
 		for _, part := range message.Parts {
 			switch p := part.(type) {
 			case llms.TextContent:
 				// append to slice that we'll join later with new line separators
 				contentFromMessaged = append(contentFromMessaged, p.Text)
+				promptTextLength += len(p.Text)
 			case llms.ToolCall:
 				// in case we added explicit tool calls on the request like on multi-turn conversations. We still append tool calls for each tool defined in the request.
 				toolCalls = append(toolCalls, p)
 			case llms.ToolCallResponse:
 				// show tool responses on the request like on multi-turn conversations
 				contentFromMessaged = append(contentFromMessaged, fmt.Sprintf("Tool Response for tool ID '%s' with name '%s': %s", p.ToolCallID, p.Name, p.Content))
+				promptTextLength += len(p.Content)
 			default:
 				return nil, fmt.Errorf("found invalid content type as input for %v", p)
 			}
 		}
 	}

+	responseContent := strings.Join(contentFromMessaged, "\n")
+
+	promptTokens := approximateTokensFromLength(promptTextLength)
+	completionTokens := approximateTokensFromLength(len(responseContent))
+	totalTokens := promptTokens + completionTokens
+
 	stopReason := "stop"
 	if len(toolCalls) > 0 {
 		stopReason = "tool_calls"
@@ -148,7 +172,7 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
 		FinishReason: stopReason,
 		Index:        0,
 		Message: conversation.Message{
-			Content: strings.Join(contentFromMessaged, "\n"),
+			Content: responseContent,
 		},
 	}

@@ -161,9 +185,22 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
 		Choices: []conversation.Choice{choice},
 	}

+	// allow per request model overrides
+	modelName := e.model
+	if r.Model != nil && *r.Model != "" {
+		modelName = *r.Model
+	}
+
+	usage := &conversation.Usage{
+		CompletionTokens: completionTokens,
+		PromptTokens:     promptTokens,
+		TotalTokens:      totalTokens,
+	}
+
 	res = &conversation.Response{
-		ConversationContext: r.ConversationContext,
-		Outputs:             []conversation.Result{output},
+		Outputs: []conversation.Result{output},
+		Usage:   usage,
+		Model:   modelName,
 	}

 	return res, nil
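The approximateTokensFromLength helper added above is a pure length heuristic, so its behaviour is easy to pin down. A small sketch of the expected values follows; these are illustrative checks, not the PR's tests.

```go
package main

import "fmt"

// Same shape as the helper added in echo.go: ~1 token per 4 characters,
// rounded up via integer arithmetic, with 0 for empty input.
func approximateTokensFromLength(textLength int) int64 {
	if textLength == 0 {
		return 0
	}
	return int64((textLength + 3) / 4)
}

func main() {
	fmt.Println(approximateTokensFromLength(0))  // 0
	fmt.Println(approximateTokensFromLength(1))  // 1 (rounds up)
	fmt.Println(approximateTokensFromLength(4))  // 1
	fmt.Println(approximateTokensFromLength(5))  // 2
	fmt.Println(approximateTokensFromLength(42)) // 11: a 42-char prompt counts as ~11 tokens
}
```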
14 changes: 6 additions & 8 deletions conversation/googleai/googleai.go
@@ -50,13 +50,11 @@ func (g *GoogleAI) Init(ctx context.Context, meta conversation.Metadata) error {

 	// Resolve model via central helper (uses metadata, then env var, then default)
 	model := conversation.GetGoogleAIModel(md.Model)
+	key, _ := meta.GetProperty("key")

-	opts := []openai.Option{
-		openai.WithModel(model),
-		openai.WithToken(md.Key),
-		// endpoint from https://ai.google.dev/gemini-api/docs/openai
-		openai.WithBaseURL("https://generativelanguage.googleapis.com/v1beta/openai/"),
-	}
+	// endpoint from https://ai.google.dev/gemini-api/docs/openai
+	const endpoint = "https://generativelanguage.googleapis.com/v1beta/openai/"
+	opts := conversation.BuildOpenAIClientOptions(model, key, endpoint)
 	llm, err := openai.New(
 		opts...,
 	)
@@ -66,8 +64,8 @@ func (g *GoogleAI) Init(ctx context.Context, meta conversation.Metadata) error {

 	g.LLM.Model = llm

-	if md.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, g.LLM.Model)
+	if md.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, md.ResponseCacheTTL, g.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}
12 changes: 3 additions & 9 deletions conversation/huggingface/huggingface.go
@@ -54,29 +54,23 @@ func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) error {

 	// Resolve model via central helper (uses metadata, then env var, then default)
 	model := conversation.GetHuggingFaceModel(m.Model)
-
 	endpoint := strings.Replace(defaultEndpoint, "{{model}}", model, 1)
 	if m.Endpoint != "" {
 		endpoint = m.Endpoint
 	}

 	// Create options for OpenAI client using HuggingFace's OpenAI-compatible API
 	// This is a workaround for issues with the native HuggingFace langchaingo implementation
-	options := []openai.Option{
-		openai.WithModel(model),
-		openai.WithToken(m.Key),
-		openai.WithBaseURL(endpoint),
-	}
-
+	options := conversation.BuildOpenAIClientOptions(model, m.Key, endpoint)
 	llm, err := openai.New(options...)
 	if err != nil {
 		return err
 	}

 	h.LLM.Model = llm

-	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, h.LLM.Model)
+	if m.ResponseCacheTTL != "" {
+		cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, h.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}