4 changes: 2 additions & 2 deletions conversation/anthropic/anthropic.go
@@ -61,8 +61,8 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error

a.LLM.Model = llm

if m.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.LLM.Model)
if m.ResponseCacheTTL != nil {
cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, a.LLM.Model)
if cacheErr != nil {
return cacheErr
}
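For context, response caching is now enabled by supplying a duration string in the component metadata, and the decoded `*time.Duration` drives the `CacheResponses` wrapper. The sketch below is illustrative only: the `NewAnthropic` constructor name, the `key`/`model` metadata keys, and the `Metadata`/`Base` construction are assumptions based on how other components-contrib building blocks are wired; only `responseCacheTTL` itself comes from this diff.

```go
package main

import (
	"context"

	"github.com/dapr/components-contrib/conversation"
	"github.com/dapr/components-contrib/conversation/anthropic"
	"github.com/dapr/components-contrib/metadata"
	"github.com/dapr/kit/logger"
)

func main() {
	comp := anthropic.NewAnthropic(logger.NewLogger("anthropic"))

	// responseCacheTTL is decoded into the *time.Duration field; leaving it out
	// keeps the pointer nil and skips the CacheResponses wrapper entirely.
	err := comp.Init(context.Background(), conversation.Metadata{
		Base: metadata.Base{Properties: map[string]string{
			"key":              "<api-key>",         // assumed metadata key
			"model":            "claude-sonnet-4-0", // assumed metadata key and model name
			"responseCacheTTL": "10m",
		}},
	})
	if err != nil {
		panic(err)
	}
}
```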
19 changes: 10 additions & 9 deletions conversation/aws/bedrock/bedrock.go
@@ -17,6 +17,7 @@ package bedrock
import (
"context"
"reflect"
"time"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
"github.com/dapr/components-contrib/conversation"
@@ -37,13 +38,13 @@ type AWSBedrock struct {
}

type AWSBedrockMetadata struct {
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
Model string `json:"model"`
CacheTTL string `json:"cacheTTL"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
Model string `json:"model"`
ResponseCacheTTL *time.Duration `json:"responseCacheTTL,omitempty" mapstructure:"responseCacheTTL" mapstructurealiases:"cacheTTL"`
}
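The `mapstructurealiases:"cacheTTL"` tag keeps existing component definitions that still set the legacy `cacheTTL` key working. A minimal sketch of the equivalent decoding behaviour, assuming the kit decoder resolves the primary key first and falls back to the alias; the helper name and the precedence are assumptions, not part of this diff:

```go
package bedrock

import (
	"fmt"
	"time"
)

// resolveResponseCacheTTL mimics what the metadata decoder is expected to do for
// a *time.Duration field with a legacy alias: a nil result means caching is disabled.
func resolveResponseCacheTTL(props map[string]string) (*time.Duration, error) {
	raw := props["responseCacheTTL"]
	if raw == "" {
		raw = props["cacheTTL"] // legacy alias kept for backwards compatibility
	}
	if raw == "" {
		return nil, nil
	}
	d, err := time.ParseDuration(raw)
	if err != nil {
		return nil, fmt.Errorf("invalid responseCacheTTL %q: %w", raw, err)
	}
	return &d, nil
}
```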

func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
@@ -83,8 +84,8 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error

b.LLM.Model = llm

if m.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.LLM.Model)
if m.ResponseCacheTTL != nil {
cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, b.LLM.Model)
if cacheErr != nil {
return cacheErr
}
14 changes: 9 additions & 5 deletions conversation/converse.go
@@ -17,6 +17,7 @@ package conversation
import (
"context"
"io"
"time"

"github.com/tmc/langchaingo/llms"
"google.golang.org/protobuf/types/known/anypb"
@@ -43,16 +44,19 @@ type Request struct {
ConversationContext string `json:"conversationContext"`
Temperature float64 `json:"temperature"`

// from metadata
Key string `json:"key"`
Model string `json:"model"`
Endpoints []string `json:"endpoints"`
Policy string `json:"loadBalancingPolicy"`
// Request-scoped metadata, separate from the component metadata fields; these
// values are passed through to the LLM with the conversation.
// https://github.com/openai/openai-go/blob/main/chatcompletion.go#L3010
Metadata map[string]string `json:"metadata"`
Model *string `json:"model"`

PromptCacheRetention *time.Duration `json:"promptCacheRetention,omitempty"`
}

type Response struct {
ConversationContext string `json:"conversationContext"`
Outputs []Result `json:"outputs"`
Usage *Usage `json:"usage,omitempty"`
}

type Result struct {
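A caller-side sketch of the new request surface, with `Model` now a pointer and the pass-through metadata and cache-retention fields populated. The values are illustrative, the usual `time` and `conversation` imports are assumed, and all other `Request` fields keep their zero values:

```go
func newExampleRequest() *conversation.Request {
	model := "gpt-4.1"          // illustrative model name
	retention := 24 * time.Hour // forwarded to the provider as prompt cache retention

	return &conversation.Request{
		Model:                &model,
		Metadata:             map[string]string{"team": "platform"}, // passed through to the LLM
		PromptCacheRetention: &retention,
	}
}
```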
26 changes: 24 additions & 2 deletions conversation/echo/echo.go
@@ -51,7 +51,9 @@ func (e *Echo) Init(ctx context.Context, meta conversation.Metadata) error {
return err
}

e.model = r.Model
if r.Model != nil {
e.model = *r.Model
}

return nil
}
@@ -62,6 +64,17 @@ func (e *Echo) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
return
}

// approximateTokensFromWords estimates the number of tokens based on word count.
// ref: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
func approximateTokensFromWords(text string) int64 {
if text == "" {
return 0
}

// split on whitespace to count words
return int64(len(strings.Fields(text)))
}

// Converse returns one output per input message.
func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
if r == nil || r.Message == nil {
@@ -139,6 +152,7 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conv
}
}

responseContent := strings.Join(contentFromMessaged, "\n")
stopReason := "stop"
if len(toolCalls) > 0 {
stopReason = "tool_calls"
@@ -148,7 +162,7 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conv
FinishReason: stopReason,
Index: 0,
Message: conversation.Message{
Content: strings.Join(contentFromMessaged, "\n"),
Content: responseContent,
},
}

@@ -161,9 +175,17 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conv
Choices: []conversation.Choice{choice},
}

tokenCount := approximateTokensFromWords(responseContent)
usage := &conversation.Usage{
CompletionTokens: tokenCount,
PromptTokens: tokenCount,
TotalTokens: tokenCount + tokenCount,
}

res = &conversation.Response{
ConversationContext: r.ConversationContext,
Outputs: []conversation.Result{output},
Usage: usage,
}

return res, nil
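Because echo simply mirrors its input, the same word-count heuristic feeds both sides of the usage report, so prompt and completion counts are always equal. A small illustration (the input string is arbitrary; the `conversation` import is assumed):

```go
tokens := approximateTokensFromWords("hello from dapr") // 3 words -> 3 "tokens"

usage := &conversation.Usage{
	CompletionTokens: tokens,     // 3
	PromptTokens:     tokens,     // 3
	TotalTokens:      tokens * 2, // 6
}
```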
4 changes: 2 additions & 2 deletions conversation/googleai/googleai.go
@@ -66,8 +66,8 @@ func (g *GoogleAI) Init(ctx context.Context, meta conversation.Metadata) error {

g.LLM.Model = llm

if md.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, g.LLM.Model)
if md.ResponseCacheTTL != nil {
cachedModel, cacheErr := conversation.CacheResponses(ctx, md.ResponseCacheTTL, g.LLM.Model)
if cacheErr != nil {
return cacheErr
}
4 changes: 2 additions & 2 deletions conversation/huggingface/huggingface.go
@@ -75,8 +75,8 @@ func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) erro

h.LLM.Model = llm

if m.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, h.LLM.Model)
if m.ResponseCacheTTL != nil {
cachedModel, cacheErr := conversation.CacheResponses(ctx, m.ResponseCacheTTL, h.LLM.Model)
if cacheErr != nil {
return cacheErr
}
89 changes: 73 additions & 16 deletions conversation/langchaingokit/model.go
@@ -16,6 +16,7 @@ package langchaingokit

import (
"context"
"fmt"

"github.com/tmc/langchaingo/llms"

@@ -40,37 +41,61 @@ func (a *LLM) Converse(ctx context.Context, r *conversation.Request) (res *conve
return nil, err
}

outputs := make([]conversation.Result, 0, len(resp.Choices))
for i := range resp.Choices {
outputs, usage, err := a.NormalizeConverseResult(resp.Choices)
if err != nil {
return nil, err
}

res = &conversation.Response{
// TODO: Fix this; the ConversationContext field has never actually been used
// and needs improvements to be useful.
ConversationContext: r.ConversationContext,
Outputs: outputs,
Usage: usage,
}

return res, nil
}

func (a *LLM) NormalizeConverseResult(choices []*llms.ContentChoice) ([]conversation.Result, *conversation.Usage, error) {
if len(choices) == 0 {
return nil, nil, nil
}

// Extract usage from the first choice's GenerationInfo (all choices share the same usage)
var usage *conversation.Usage
if len(choices) > 0 && choices[0].GenerationInfo != nil {
var err error
usage, err = extractUsageFromLangchainGenerationInfo(choices[0].GenerationInfo)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract usage metrics: %v", err)
}
}

outputs := make([]conversation.Result, 0, len(choices))
for i, c := range choices {
choice := conversation.Choice{
FinishReason: resp.Choices[i].StopReason,
FinishReason: c.StopReason,
Index: int64(i),
}

if resp.Choices[i].Content != "" {
choice.Message.Content = resp.Choices[i].Content
if choices[i].Content != "" {
choice.Message.Content = choices[i].Content
}

if resp.Choices[i].ToolCalls != nil {
choice.Message.ToolCallRequest = &resp.Choices[i].ToolCalls
if choices[i].ToolCalls != nil {
choice.Message.ToolCallRequest = &choices[i].ToolCalls
}

output := conversation.Result{
StopReason: resp.Choices[i].StopReason,
StopReason: c.StopReason,
Choices: []conversation.Choice{choice},
}

outputs = append(outputs, output)
}

res = &conversation.Response{
// TODO: Fix this, we never used this ConversationContext field to begin with.
// This needs improvements to be useful.
ConversationContext: r.ConversationContext,
Outputs: outputs,
}

return res, nil
return outputs, usage, nil
}
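`extractUsageFromLangchainGenerationInfo` is referenced above but defined elsewhere in this PR. A hypothetical sketch of the shape such a helper could take; the `GenerationInfo` key names and numeric types vary per langchaingo provider and are assumptions here (the `fmt` and `conversation` imports are assumed), not the actual implementation:

```go
func extractUsageSketch(info map[string]any) (*conversation.Usage, error) {
	// GenerationInfo values may be int, int64 or float64 depending on the provider.
	toInt64 := func(v any) (int64, error) {
		switch n := v.(type) {
		case int:
			return int64(n), nil
		case int64:
			return n, nil
		case float64:
			return int64(n), nil
		default:
			return 0, fmt.Errorf("unexpected token count type %T", v)
		}
	}

	u := &conversation.Usage{}
	if v, ok := info["PromptTokens"]; ok { // assumed key name
		n, err := toInt64(v)
		if err != nil {
			return nil, err
		}
		u.PromptTokens = n
	}
	if v, ok := info["CompletionTokens"]; ok { // assumed key name
		n, err := toInt64(v)
		if err != nil {
			return nil, err
		}
		u.CompletionTokens = n
	}
	u.TotalTokens = u.PromptTokens + u.CompletionTokens
	return u, nil
}
```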

func getOptionsFromRequest(r *conversation.Request, opts ...llms.CallOption) []llms.CallOption {
@@ -90,5 +115,37 @@ func getOptionsFromRequest(r *conversation.Request, opts ...llms.CallOption) []l
opts = append(opts, llms.WithToolChoice(r.ToolChoice))
}

// Handle prompt cache retention for OpenAI's extended prompt caching feature
if r.PromptCacheRetention != nil {
if r.Metadata == nil {
r.Metadata = make(map[string]string)
}
// OpenAI expects this as a top-level parameter, but we are forced to pass it via metadata
// and rely on langchaingo to forward it to the OpenAI client.
// NOTE: this is a workaround. langchaingo's llms.WithPromptCaching(true) option is
// incompatible with OpenAI here: it sends a bool, while OpenAI expects a duration string,
// so the option errors instead of translating properly.
// Once langchaingo fixes this we can switch to the proper option. :)
const metadataPromptCacheKey = "prompt_cache_retention"
r.Metadata[metadataPromptCacheKey] = r.PromptCacheRetention.String()
}

// OpenAI accepts metadata as map[string]string, while langchaingo expects map[string]any,
// so we keep OpenAI's type for our API and convert before handing it to langchaingo.
if r.Metadata != nil {
opts = append(opts, llms.WithMetadata(stringMapToAny(r.Metadata)))
}

return opts
}

func stringMapToAny(m map[string]string) map[string]any {
if m == nil {
return nil
}
out := make(map[string]any, len(m))
for k, v := range m {
out[k] = v
}
return out
}
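Putting the pieces above together: a retention duration set on the request surfaces as a metadata string that langchaingo is expected to forward to the provider. A minimal usage sketch (statements only, assumed to run inside this package with the `time` import available):

```go
retention := 24 * time.Hour
r := &conversation.Request{PromptCacheRetention: &retention}

opts := getOptionsFromRequest(r)

// r.Metadata now contains {"prompt_cache_retention": "24h0m0s"}, and opts includes
// llms.WithMetadata(map[string]any{"prompt_cache_retention": "24h0m0s"}).
_ = opts
```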