Skip to content

Commit eb78f52

Browse files
committed
feat: adapt agentic message
1 parent abed3ee commit eb78f52

4 files changed

Lines changed: 136 additions & 110 deletions

File tree

adk/middlewares/automemory/automemory.go

Lines changed: 100 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ type Config[M adk.MessageType] struct {
4949

5050
MemoryBackend Backend
5151

52-
// Model is the default tool-calling model used by topic selection and memory extraction.
52+
// Model is the default model used by topic selection and memory extraction.
5353
// Per-read/per-write overrides can be configured in Read.Model / Write.Model.
54-
Model model.ToolCallingChatModel
54+
Model model.BaseModel[M]
5555

5656
// Read controls how memories are loaded and injected.
5757
// Optional. Defaults to Sync load with topic selection enabled (if Model is set).
58-
Read *ReadConfig
58+
Read *ReadConfig[M]
5959

6060
// Write controls post-run memory extraction and persistence.
6161
// Optional. Default: disabled.
62-
Write *WriteConfig
62+
Write *WriteConfig[M]
6363

6464
// Coordination controls session identity and distributed async extraction coordination.
6565
// Optional. Defaults to a local in-process coordinator.
@@ -78,11 +78,11 @@ const (
7878
ReadModeAsync ReadMode = "async"
7979
)
8080

81-
type ReadConfig struct {
81+
type ReadConfig[M adk.MessageType] struct {
8282
Mode ReadMode
8383

8484
// Model is used for topic selection. Defaults to Config.Model.
85-
Model model.ToolCallingChatModel
85+
Model model.BaseModel[M]
8686

8787
// Instruction overrides the default auto memory instruction block appended to system prompt.
8888
// Optional.
@@ -126,11 +126,11 @@ const (
126126
WriteModeSync WriteMode = "sync"
127127
)
128128

129-
type WriteConfig struct {
129+
type WriteConfig[M adk.MessageType] struct {
130130
Mode WriteMode
131131

132132
// Model is used for memory extraction. Defaults to Config.Model.
133-
Model model.ToolCallingChatModel
133+
Model model.BaseModel[M]
134134

135135
// MaxTurns caps the extractor's tool-call loop.
136136
MaxTurns int
@@ -144,7 +144,7 @@ type WriteConfig struct {
144144
//
145145
// If nil, automemory uses the default drain behavior: ignore all events and
146146
// return the first ev.Err encountered (if any).
147-
HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error
147+
HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.TypedAgentEvent[M]]) error
148148
}
149149

150150
type middleware[M adk.MessageType] struct {
@@ -154,8 +154,8 @@ type middleware[M adk.MessageType] struct {
154154

155155
resolvedMemoryDirectory string
156156

157-
topicSelectionModel model.ToolCallingChatModel
158-
extractionHandler adk.ChatModelAgentMiddleware
157+
topicSelectionModel model.BaseModel[M]
158+
extractionHandler adk.TypedChatModelAgentMiddleware[M]
159159
topicSelectionTool *schema.ToolInfo
160160
coordination *CoordinationConfig[M]
161161
}
@@ -201,7 +201,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
201201
return nil, fmt.Errorf("auto memory config: resolve memory directory: %w", err)
202202
}
203203
if cfg.Read == nil {
204-
cfg.Read = &ReadConfig{}
204+
cfg.Read = &ReadConfig[M]{}
205205
}
206206
applyReadDefaults(cfg)
207207

@@ -214,11 +214,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
214214

215215
m.topicSelectionTool = topicSelectionToolInfo()
216216
if cfg.Read.TopicSelection != nil && cfg.Read.Model != nil {
217-
bound, err := cfg.Read.Model.WithTools([]*schema.ToolInfo{m.topicSelectionTool})
218-
if err != nil {
219-
return nil, fmt.Errorf("auto memory topic selection model init failed: %w", err)
220-
}
221-
m.topicSelectionModel = bound
217+
m.topicSelectionModel = cfg.Read.Model
222218
}
223219

224220
if cfg.Write.Mode != WriteModeDisabled && cfg.Write.Model != nil {
@@ -230,7 +226,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
230226
if err != nil {
231227
return nil, err
232228
}
233-
fileSystemMiddleware, err := fsmw.New(ctx, &fsmw.MiddlewareConfig{
229+
fileSystemMiddleware, err := fsmw.NewTyped[M](ctx, &fsmw.MiddlewareConfig{
234230
Backend: writeFSBackend,
235231
LsToolConfig: &fsmw.ToolConfig{Disable: true},
236232
GrepToolConfig: &fsmw.ToolConfig{Disable: true},
@@ -412,7 +408,7 @@ func applyReadDefaults[M adk.MessageType](cfg *Config[M]) {
412408
}
413409

414410
if cfg.Write == nil {
415-
cfg.Write = &WriteConfig{Mode: WriteModeDisabled}
411+
cfg.Write = &WriteConfig[M]{Mode: WriteModeDisabled}
416412
}
417413
if cfg.Write.Mode == "" {
418414
cfg.Write.Mode = WriteModeDisabled
@@ -764,11 +760,12 @@ func (m *middleware[M]) selectTopicCandidates(
764760
toolInfo := topicSelectionToolInfo()
765761
resp, err := m.topicSelectionModel.Generate(
766762
ctx,
767-
[]*schema.Message{
768-
schema.SystemMessage(getTopicSelectionSystemPrompt()),
769-
schema.UserMessage(userMsg),
763+
[]M{
764+
makeSystemMsg[M](getTopicSelectionSystemPrompt()),
765+
makeUserMsg[M](userMsg),
770766
},
771-
model.WithToolChoice(schema.ToolChoiceForced, toolInfo.Name),
767+
model.WithTools([]*schema.ToolInfo{toolInfo}),
768+
makeToolChoiceForced[M](toolInfo.Name),
772769
)
773770
if err != nil {
774771
return nil, err
@@ -864,11 +861,12 @@ func topicSelectionToolInfo() *schema.ToolInfo {
864861
}
865862
}
866863

867-
func parseTopicSelectionFromToolCall(msg *schema.Message, valid map[string]struct{}) ([]string, error) {
868-
if msg == nil || len(msg.ToolCalls) == 0 {
864+
func parseTopicSelectionFromToolCall[M adk.MessageType](msg M, valid map[string]struct{}) ([]string, error) {
865+
toolCalls := messageToolCalls(msg)
866+
if len(toolCalls) == 0 {
869867
return nil, fmt.Errorf("no tool calls")
870868
}
871-
tc := msg.ToolCalls[0]
869+
tc := toolCalls[0]
872870
if tc.Function.Name != topicSelectionToolName {
873871
return nil, fmt.Errorf("unexpected tool call: %s", tc.Function.Name)
874872
}
@@ -1002,6 +1000,31 @@ func setMsgExtra[M adk.MessageType](msg M, key string, value any) {
10021000
}
10031001
}
10041002

1003+
func copyMsgExtra[M adk.MessageType](dst, src M) {
1004+
srcExtra := getMsgExtra(src)
1005+
if len(srcExtra) == 0 {
1006+
return
1007+
}
1008+
switch d := any(dst).(type) {
1009+
case *schema.Message:
1010+
if d.Extra == nil {
1011+
d.Extra = make(map[string]any, len(srcExtra))
1012+
}
1013+
for k, v := range srcExtra {
1014+
d.Extra[k] = v
1015+
}
1016+
case *schema.AgenticMessage:
1017+
if d.Extra == nil {
1018+
d.Extra = make(map[string]any, len(srcExtra))
1019+
}
1020+
for k, v := range srcExtra {
1021+
d.Extra[k] = v
1022+
}
1023+
default:
1024+
panic("unreachable")
1025+
}
1026+
}
1027+
10051028
func makeUserMsg[M adk.MessageType](text string) M {
10061029
var zero M
10071030
switch any(zero).(type) {
@@ -1014,6 +1037,35 @@ func makeUserMsg[M adk.MessageType](text string) M {
10141037
}
10151038
}
10161039

1040+
func makeSystemMsg[M adk.MessageType](text string) M {
1041+
var zero M
1042+
switch any(zero).(type) {
1043+
case *schema.Message:
1044+
return any(schema.SystemMessage(text)).(M)
1045+
case *schema.AgenticMessage:
1046+
return any(schema.SystemAgenticMessage(text)).(M)
1047+
default:
1048+
panic("unreachable")
1049+
}
1050+
}
1051+
1052+
func makeToolChoiceForced[M adk.MessageType](name string) model.Option {
1053+
var zero M
1054+
switch any(zero).(type) {
1055+
case *schema.Message:
1056+
return model.WithToolChoice(schema.ToolChoiceForced, name)
1057+
case *schema.AgenticMessage:
1058+
return model.WithAgenticToolChoice(&schema.AgenticToolChoice{
1059+
Type: schema.ToolChoiceForced,
1060+
Forced: &schema.AgenticForcedToolChoice{
1061+
Tools: []*schema.AllowedTool{{FunctionName: name}},
1062+
},
1063+
})
1064+
default:
1065+
panic("unreachable")
1066+
}
1067+
}
1068+
10171069
func messageToolCalls[M adk.MessageType](msg M) []schema.ToolCall {
10181070
switch m := any(msg).(type) {
10191071
case *schema.Message:
@@ -1069,40 +1121,6 @@ func messageToolNames[M adk.MessageType](msg M) []string {
10691121
}
10701122
}
10711123

1072-
func projectMessagesToSchema[M adk.MessageType](msgs []M) []adk.Message {
1073-
out := make([]adk.Message, 0, len(msgs))
1074-
for _, msg := range msgs {
1075-
if projected := projectMessageToSchema(msg); projected != nil {
1076-
out = append(out, projected)
1077-
}
1078-
}
1079-
return out
1080-
}
1081-
1082-
func projectMessageToSchema[M adk.MessageType](msg M) adk.Message {
1083-
switch m := any(msg).(type) {
1084-
case *schema.Message:
1085-
return m
1086-
case *schema.AgenticMessage:
1087-
if m == nil {
1088-
return nil
1089-
}
1090-
text := m.String()
1091-
switch m.Role {
1092-
case schema.AgenticRoleTypeSystem:
1093-
return schema.SystemMessage(text)
1094-
case schema.AgenticRoleTypeAssistant:
1095-
return schema.AssistantMessage(text, messageToolCalls(msg))
1096-
case schema.AgenticRoleTypeUser:
1097-
return schema.UserMessage(text)
1098-
default:
1099-
return schema.UserMessage(text)
1100-
}
1101-
default:
1102-
panic("unreachable")
1103-
}
1104-
}
1105-
11061124
func alreadyInjected[M adk.MessageType](msgs []M) bool {
11071125
for _, m := range msgs {
11081126
if isMemoryMessage(m) {
@@ -1464,21 +1482,20 @@ func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int
14641482
return countModelVisibleMessages(msgs[cursor:])
14651483
}
14661484

1467-
func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.ChatModelAgent, error) {
1485+
func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.TypedChatModelAgent[M], error) {
14681486
if m.cfg == nil || m.cfg.Write == nil || m.cfg.Write.Model == nil {
14691487
return nil, fmt.Errorf("auto memory extraction agent init failed: missing write model")
14701488
}
14711489
if m.extractionHandler == nil {
14721490
return nil, fmt.Errorf("auto memory extraction agent init failed: missing extraction handler")
14731491
}
14741492

1475-
agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
1476-
Name: "automemory_extractor",
1477-
Description: "Internal auto memory extraction subagent",
1478-
Model: m.cfg.Write.Model,
1479-
Handlers: []adk.ChatModelAgentMiddleware{
1493+
agent, err := adk.NewTypedChatModelAgent[M](ctx, &adk.TypedChatModelAgentConfig[M]{
1494+
Name: "automemory_extractor",
1495+
Model: m.cfg.Write.Model,
1496+
Handlers: []adk.TypedChatModelAgentMiddleware[M]{
14801497
m.extractionHandler, // fs middleware
1481-
&toolInfoOverrideMiddleware{toolInfos: toolInfos}, // tool info override, for prefix cache
1498+
&toolInfoOverrideMiddleware[M]{toolInfos: toolInfos}, // tool info override, for prefix cache
14821499
},
14831500
ToolsConfig: adk.ToolsConfig{
14841501
ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -1506,13 +1523,13 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [
15061523
}
15071524
newMessageCount := countModelVisibleMessagesSince(snapshot, cursor)
15081525
userPrompt := buildExtractAutoOnlyPrompt(m.resolvedMemoryDirectory, newMessageCount, manifest, m.cfg.Write.SkipIndex)
1509-
msgs := append(projectMessagesToSchema(snapshot), schema.UserMessage(userPrompt))
1526+
msgs := append(append([]M{}, snapshot...), makeUserMsg[M](userPrompt))
15101527
extractionAgent, err := m.newExtractionAgent(ctx, toolInfos)
15111528
if err != nil {
15121529
return err
15131530
}
15141531

1515-
iter := extractionAgent.Run(ctx, &adk.AgentInput{
1532+
iter := extractionAgent.Run(ctx, &adk.TypedAgentInput[M]{
15161533
Messages: msgs,
15171534
EnableStreaming: true,
15181535
})
@@ -1583,30 +1600,27 @@ func parseRFC3339NanoBestEffort(s string) time.Time {
15831600
return time.Time{}
15841601
}
15851602

1586-
type toolInfoOverrideMiddleware struct {
1587-
adk.BaseChatModelAgentMiddleware
1603+
type toolInfoOverrideMiddleware[M adk.MessageType] struct {
1604+
adk.TypedBaseChatModelAgentMiddleware[M]
15881605

1589-
once sync.Once
15901606
toolInfos []*schema.ToolInfo
15911607
}
15921608

1593-
func (t *toolInfoOverrideMiddleware) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[*schema.Message], _ *adk.TypedModelContext[*schema.Message]) (
1594-
context.Context, *adk.TypedChatModelAgentState[*schema.Message], error) {
1609+
func (t *toolInfoOverrideMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[M], _ *adk.TypedModelContext[M]) (
1610+
context.Context, *adk.TypedChatModelAgentState[M], error) {
15951611

1596-
t.once.Do(func() {
1597-
toolNameMapping := make(map[string]struct{}, len(t.toolInfos))
1598-
for _, tool := range t.toolInfos {
1599-
toolNameMapping[tool.Name] = struct{}{}
1600-
}
1612+
toolNameMapping := make(map[string]struct{}, len(t.toolInfos))
1613+
for _, tool := range t.toolInfos {
1614+
toolNameMapping[tool.Name] = struct{}{}
1615+
}
16011616

1602-
overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...)
1603-
for _, tool := range state.ToolInfos {
1604-
if _, ok := toolNameMapping[tool.Name]; !ok { // add fs tools if not exists
1605-
overrideTools = append(overrideTools, tool)
1606-
}
1617+
overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...)
1618+
for _, tool := range state.ToolInfos {
1619+
if _, ok := toolNameMapping[tool.Name]; !ok {
1620+
overrideTools = append(overrideTools, tool)
16071621
}
1608-
state.ToolInfos = overrideTools
1609-
})
1622+
}
1623+
state.ToolInfos = overrideTools
16101624

16111625
return ctx, state, nil
16121626
}

0 commit comments

Comments
 (0)