Skip to content

Commit dc72a3e

Browse files
committed
feat: adapt agentic message (#1071)
1 parent d1947f2 commit dc72a3e

4 files changed

Lines changed: 156 additions & 109 deletions

File tree

adk/middlewares/automemory/automemory.go

Lines changed: 120 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ type Config[M adk.MessageType] struct {
4949

5050
MemoryBackend Backend
5151

52-
// Model is the default tool-calling model used by topic selection and memory extraction.
52+
// Model is the default model used by topic selection and memory extraction.
5353
// Per-read/per-write overrides can be configured in Read.Model / Write.Model.
54-
Model model.ToolCallingChatModel
54+
Model model.BaseModel[M]
5555

5656
// Read controls how memories are loaded and injected.
5757
// Optional. Defaults to Sync load with topic selection enabled (if Model is set).
58-
Read *ReadConfig
58+
Read *ReadConfig[M]
5959

6060
// Write controls post-run memory extraction and persistence.
6161
// Optional. Default: disabled.
62-
Write *WriteConfig
62+
Write *WriteConfig[M]
6363

6464
// Coordination controls session identity and distributed async extraction coordination.
6565
// Optional. Defaults to a local in-process coordinator.
@@ -78,11 +78,11 @@ const (
7878
ReadModeAsync ReadMode = "async"
7979
)
8080

81-
type ReadConfig struct {
81+
type ReadConfig[M adk.MessageType] struct {
8282
Mode ReadMode
8383

8484
// Model is used for topic selection. Defaults to Config.Model.
85-
Model model.ToolCallingChatModel
85+
Model model.BaseModel[M]
8686

8787
// Instruction overrides the default auto memory instruction block appended to system prompt.
8888
// Optional.
@@ -126,11 +126,11 @@ const (
126126
WriteModeSync WriteMode = "sync"
127127
)
128128

129-
type WriteConfig struct {
129+
type WriteConfig[M adk.MessageType] struct {
130130
Mode WriteMode
131131

132132
// Model is used for memory extraction. Defaults to Config.Model.
133-
Model model.ToolCallingChatModel
133+
Model model.BaseModel[M]
134134

135135
// MaxTurns caps the extractor's tool-call loop.
136136
MaxTurns int
@@ -144,7 +144,7 @@ type WriteConfig struct {
144144
//
145145
// If nil, automemory uses the default drain behavior: ignore all events and
146146
// return the first ev.Err encountered (if any).
147-
HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error
147+
HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.TypedAgentEvent[M]]) error
148148
}
149149

150150
type middleware[M adk.MessageType] struct {
@@ -154,8 +154,8 @@ type middleware[M adk.MessageType] struct {
154154

155155
resolvedMemoryDirectory string
156156

157-
topicSelectionModel model.ToolCallingChatModel
158-
extractionHandler adk.ChatModelAgentMiddleware
157+
topicSelectionModel model.BaseModel[M]
158+
extractionHandler adk.TypedChatModelAgentMiddleware[M]
159159
topicSelectionTool *schema.ToolInfo
160160
coordination *CoordinationConfig[M]
161161
}
@@ -201,7 +201,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
201201
return nil, fmt.Errorf("auto memory config: resolve memory directory: %w", err)
202202
}
203203
if cfg.Read == nil {
204-
cfg.Read = &ReadConfig{}
204+
cfg.Read = &ReadConfig[M]{}
205205
}
206206
applyReadDefaults(cfg)
207207

@@ -214,11 +214,10 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
214214

215215
m.topicSelectionTool = topicSelectionToolInfo()
216216
if cfg.Read.TopicSelection != nil && cfg.Read.Model != nil {
217-
bound, err := cfg.Read.Model.WithTools([]*schema.ToolInfo{m.topicSelectionTool})
218-
if err != nil {
219-
return nil, fmt.Errorf("auto memory topic selection model init failed: %w", err)
217+
m.topicSelectionModel = &modelWithTools[M]{
218+
base: cfg.Read.Model,
219+
tools: []*schema.ToolInfo{m.topicSelectionTool},
220220
}
221-
m.topicSelectionModel = bound
222221
}
223222

224223
if cfg.Write.Mode != WriteModeDisabled && cfg.Write.Model != nil {
@@ -230,7 +229,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
230229
if err != nil {
231230
return nil, err
232231
}
233-
fileSystemMiddleware, err := fsmw.New(ctx, &fsmw.MiddlewareConfig{
232+
fileSystemMiddleware, err := fsmw.NewTyped[M](ctx, &fsmw.MiddlewareConfig{
234233
Backend: writeFSBackend,
235234
LsToolConfig: &fsmw.ToolConfig{Disable: true},
236235
GrepToolConfig: &fsmw.ToolConfig{Disable: true},
@@ -412,7 +411,7 @@ func applyReadDefaults[M adk.MessageType](cfg *Config[M]) {
412411
}
413412

414413
if cfg.Write == nil {
415-
cfg.Write = &WriteConfig{Mode: WriteModeDisabled}
414+
cfg.Write = &WriteConfig[M]{Mode: WriteModeDisabled}
416415
}
417416
if cfg.Write.Mode == "" {
418417
cfg.Write.Mode = WriteModeDisabled
@@ -764,11 +763,11 @@ func (m *middleware[M]) selectTopicCandidates(
764763
toolInfo := topicSelectionToolInfo()
765764
resp, err := m.topicSelectionModel.Generate(
766765
ctx,
767-
[]*schema.Message{
768-
schema.SystemMessage(getTopicSelectionSystemPrompt()),
769-
schema.UserMessage(userMsg),
766+
[]M{
767+
makeSystemMsg[M](getTopicSelectionSystemPrompt()),
768+
makeUserMsg[M](userMsg),
770769
},
771-
model.WithToolChoice(schema.ToolChoiceForced, toolInfo.Name),
770+
makeToolChoiceForced[M](toolInfo.Name),
772771
)
773772
if err != nil {
774773
return nil, err
@@ -864,11 +863,12 @@ func topicSelectionToolInfo() *schema.ToolInfo {
864863
}
865864
}
866865

867-
func parseTopicSelectionFromToolCall(msg *schema.Message, valid map[string]struct{}) ([]string, error) {
868-
if msg == nil || len(msg.ToolCalls) == 0 {
866+
func parseTopicSelectionFromToolCall[M adk.MessageType](msg M, valid map[string]struct{}) ([]string, error) {
867+
toolCalls := messageToolCalls(msg)
868+
if len(toolCalls) == 0 {
869869
return nil, fmt.Errorf("no tool calls")
870870
}
871-
tc := msg.ToolCalls[0]
871+
tc := toolCalls[0]
872872
if tc.Function.Name != topicSelectionToolName {
873873
return nil, fmt.Errorf("unexpected tool call: %s", tc.Function.Name)
874874
}
@@ -1002,6 +1002,31 @@ func setMsgExtra[M adk.MessageType](msg M, key string, value any) {
10021002
}
10031003
}
10041004

1005+
func copyMsgExtra[M adk.MessageType](dst, src M) {
1006+
srcExtra := getMsgExtra(src)
1007+
if len(srcExtra) == 0 {
1008+
return
1009+
}
1010+
switch d := any(dst).(type) {
1011+
case *schema.Message:
1012+
if d.Extra == nil {
1013+
d.Extra = make(map[string]any, len(srcExtra))
1014+
}
1015+
for k, v := range srcExtra {
1016+
d.Extra[k] = v
1017+
}
1018+
case *schema.AgenticMessage:
1019+
if d.Extra == nil {
1020+
d.Extra = make(map[string]any, len(srcExtra))
1021+
}
1022+
for k, v := range srcExtra {
1023+
d.Extra[k] = v
1024+
}
1025+
default:
1026+
panic("unreachable")
1027+
}
1028+
}
1029+
10051030
func makeUserMsg[M adk.MessageType](text string) M {
10061031
var zero M
10071032
switch any(zero).(type) {
@@ -1014,6 +1039,35 @@ func makeUserMsg[M adk.MessageType](text string) M {
10141039
}
10151040
}
10161041

1042+
func makeSystemMsg[M adk.MessageType](text string) M {
1043+
var zero M
1044+
switch any(zero).(type) {
1045+
case *schema.Message:
1046+
return any(schema.SystemMessage(text)).(M)
1047+
case *schema.AgenticMessage:
1048+
return any(schema.SystemAgenticMessage(text)).(M)
1049+
default:
1050+
panic("unreachable")
1051+
}
1052+
}
1053+
1054+
func makeToolChoiceForced[M adk.MessageType](name string) model.Option {
1055+
var zero M
1056+
switch any(zero).(type) {
1057+
case *schema.Message:
1058+
return model.WithToolChoice(schema.ToolChoiceForced, name)
1059+
case *schema.AgenticMessage:
1060+
return model.WithAgenticToolChoice(&schema.AgenticToolChoice{
1061+
Type: schema.ToolChoiceForced,
1062+
Forced: &schema.AgenticForcedToolChoice{
1063+
Tools: []*schema.AllowedTool{{FunctionName: name}},
1064+
},
1065+
})
1066+
default:
1067+
panic("unreachable")
1068+
}
1069+
}
1070+
10171071
func messageToolCalls[M adk.MessageType](msg M) []schema.ToolCall {
10181072
switch m := any(msg).(type) {
10191073
case *schema.Message:
@@ -1069,40 +1123,6 @@ func messageToolNames[M adk.MessageType](msg M) []string {
10691123
}
10701124
}
10711125

1072-
func projectMessagesToSchema[M adk.MessageType](msgs []M) []adk.Message {
1073-
out := make([]adk.Message, 0, len(msgs))
1074-
for _, msg := range msgs {
1075-
if projected := projectMessageToSchema(msg); projected != nil {
1076-
out = append(out, projected)
1077-
}
1078-
}
1079-
return out
1080-
}
1081-
1082-
func projectMessageToSchema[M adk.MessageType](msg M) adk.Message {
1083-
switch m := any(msg).(type) {
1084-
case *schema.Message:
1085-
return m
1086-
case *schema.AgenticMessage:
1087-
if m == nil {
1088-
return nil
1089-
}
1090-
text := m.String()
1091-
switch m.Role {
1092-
case schema.AgenticRoleTypeSystem:
1093-
return schema.SystemMessage(text)
1094-
case schema.AgenticRoleTypeAssistant:
1095-
return schema.AssistantMessage(text, messageToolCalls(msg))
1096-
case schema.AgenticRoleTypeUser:
1097-
return schema.UserMessage(text)
1098-
default:
1099-
return schema.UserMessage(text)
1100-
}
1101-
default:
1102-
panic("unreachable")
1103-
}
1104-
}
1105-
11061126
func alreadyInjected[M adk.MessageType](msgs []M) bool {
11071127
for _, m := range msgs {
11081128
if isMemoryMessage(m) {
@@ -1464,21 +1484,20 @@ func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int
14641484
return countModelVisibleMessages(msgs[cursor:])
14651485
}
14661486

1467-
func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.ChatModelAgent, error) {
1487+
func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.TypedChatModelAgent[M], error) {
14681488
if m.cfg == nil || m.cfg.Write == nil || m.cfg.Write.Model == nil {
14691489
return nil, fmt.Errorf("auto memory extraction agent init failed: missing write model")
14701490
}
14711491
if m.extractionHandler == nil {
14721492
return nil, fmt.Errorf("auto memory extraction agent init failed: missing extraction handler")
14731493
}
14741494

1475-
agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
1476-
Name: "automemory_extractor",
1477-
Description: "Internal auto memory extraction subagent",
1478-
Model: m.cfg.Write.Model,
1479-
Handlers: []adk.ChatModelAgentMiddleware{
1495+
agent, err := adk.NewTypedChatModelAgent[M](ctx, &adk.TypedChatModelAgentConfig[M]{
1496+
Name: "automemory_extractor",
1497+
Model: m.cfg.Write.Model,
1498+
Handlers: []adk.TypedChatModelAgentMiddleware[M]{
14801499
m.extractionHandler, // fs middleware
1481-
&toolInfoOverrideMiddleware{toolInfos: toolInfos}, // tool info override, for prefix cache
1500+
&toolInfoOverrideMiddleware[M]{toolInfos: toolInfos}, // tool info override, for prefix cache
14821501
},
14831502
ToolsConfig: adk.ToolsConfig{
14841503
ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -1506,13 +1525,13 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [
15061525
}
15071526
newMessageCount := countModelVisibleMessagesSince(snapshot, cursor)
15081527
userPrompt := buildExtractAutoOnlyPrompt(m.resolvedMemoryDirectory, newMessageCount, manifest, m.cfg.Write.SkipIndex)
1509-
msgs := append(projectMessagesToSchema(snapshot), schema.UserMessage(userPrompt))
1528+
msgs := append(append([]M{}, snapshot...), makeUserMsg[M](userPrompt))
15101529
extractionAgent, err := m.newExtractionAgent(ctx, toolInfos)
15111530
if err != nil {
15121531
return err
15131532
}
15141533

1515-
iter := extractionAgent.Run(ctx, &adk.AgentInput{
1534+
iter := extractionAgent.Run(ctx, &adk.TypedAgentInput[M]{
15161535
Messages: msgs,
15171536
EnableStreaming: true,
15181537
})
@@ -1583,30 +1602,46 @@ func parseRFC3339NanoBestEffort(s string) time.Time {
15831602
return time.Time{}
15841603
}
15851604

1586-
type toolInfoOverrideMiddleware struct {
1587-
adk.BaseChatModelAgentMiddleware
1605+
type toolInfoOverrideMiddleware[M adk.MessageType] struct {
1606+
adk.TypedBaseChatModelAgentMiddleware[M]
15881607

1589-
once sync.Once
15901608
toolInfos []*schema.ToolInfo
15911609
}
15921610

1593-
func (t *toolInfoOverrideMiddleware) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[*schema.Message], _ *adk.TypedModelContext[*schema.Message]) (
1594-
context.Context, *adk.TypedChatModelAgentState[*schema.Message], error) {
1611+
func (t *toolInfoOverrideMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[M], _ *adk.TypedModelContext[M]) (
1612+
context.Context, *adk.TypedChatModelAgentState[M], error) {
15951613

1596-
t.once.Do(func() {
1597-
toolNameMapping := make(map[string]struct{}, len(t.toolInfos))
1598-
for _, tool := range t.toolInfos {
1599-
toolNameMapping[tool.Name] = struct{}{}
1600-
}
1614+
toolNameMapping := make(map[string]struct{}, len(t.toolInfos))
1615+
for _, tool := range t.toolInfos {
1616+
toolNameMapping[tool.Name] = struct{}{}
1617+
}
16011618

1602-
overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...)
1603-
for _, tool := range state.ToolInfos {
1604-
if _, ok := toolNameMapping[tool.Name]; !ok { // add fs tools if not exists
1605-
overrideTools = append(overrideTools, tool)
1606-
}
1619+
overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...)
1620+
for _, tool := range state.ToolInfos {
1621+
if _, ok := toolNameMapping[tool.Name]; !ok {
1622+
overrideTools = append(overrideTools, tool)
16071623
}
1608-
state.ToolInfos = overrideTools
1609-
})
1624+
}
1625+
state.ToolInfos = overrideTools
16101626

16111627
return ctx, state, nil
16121628
}
1629+
1630+
type modelWithTools[M adk.MessageType] struct {
1631+
base model.BaseModel[M]
1632+
tools []*schema.ToolInfo
1633+
}
1634+
1635+
func (m *modelWithTools[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) {
1636+
newOpts := make([]model.Option, len(opts)+1)
1637+
copy(newOpts, opts)
1638+
newOpts[len(opts)] = model.WithTools(m.tools)
1639+
return m.base.Generate(ctx, input, newOpts...)
1640+
}
1641+
1642+
func (m *modelWithTools[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) {
1643+
newOpts := make([]model.Option, len(opts)+1)
1644+
copy(newOpts, opts)
1645+
newOpts[len(opts)] = model.WithTools(m.tools)
1646+
return m.base.Stream(ctx, input, newOpts...)
1647+
}

0 commit comments

Comments
 (0)