@@ -49,17 +49,17 @@ type Config[M adk.MessageType] struct {
4949
5050 MemoryBackend Backend
5151
52- // Model is the default tool-calling model used by topic selection and memory extraction.
52+ // Model is the default model used by topic selection and memory extraction.
5353 // Per-read/per-write overrides can be configured in Read.Model / Write.Model.
54- Model model.ToolCallingChatModel
54+ Model model.BaseModel [ M ]
5555
5656 // Read controls how memories are loaded and injected.
5757 // Optional. Defaults to Sync load with topic selection enabled (if Model is set).
58- Read * ReadConfig
58+ Read * ReadConfig [ M ]
5959
6060 // Write controls post-run memory extraction and persistence.
6161 // Optional. Default: disabled.
62- Write * WriteConfig
62+ Write * WriteConfig [ M ]
6363
6464 // Coordination controls session identity and distributed async extraction coordination.
6565 // Optional. Defaults to a local in-process coordinator.
@@ -78,11 +78,11 @@ const (
7878 ReadModeAsync ReadMode = "async"
7979)
8080
81- type ReadConfig struct {
81+ type ReadConfig [ M adk. MessageType ] struct {
8282 Mode ReadMode
8383
8484 // Model is used for topic selection. Defaults to Config.Model.
85- Model model.ToolCallingChatModel
85+ Model model.BaseModel [ M ]
8686
8787 // Instruction overrides the default auto memory instruction block appended to system prompt.
8888 // Optional.
@@ -126,11 +126,11 @@ const (
126126 WriteModeSync WriteMode = "sync"
127127)
128128
129- type WriteConfig struct {
129+ type WriteConfig [ M adk. MessageType ] struct {
130130 Mode WriteMode
131131
132132 // Model is used for memory extraction. Defaults to Config.Model.
133- Model model.ToolCallingChatModel
133+ Model model.BaseModel [ M ]
134134
135135 // MaxTurns caps the extractor's tool-call loop.
136136 MaxTurns int
@@ -144,7 +144,7 @@ type WriteConfig struct {
144144 //
145145 // If nil, automemory uses the default drain behavior: ignore all events and
146146 // return the first ev.Err encountered (if any).
147- HandleExtractionIterator func (ctx context.Context , iter * adk.AsyncIterator [* adk.AgentEvent ]) error
147+ HandleExtractionIterator func (ctx context.Context , iter * adk.AsyncIterator [* adk.TypedAgentEvent [ M ] ]) error
148148}
149149
150150type middleware [M adk.MessageType ] struct {
@@ -154,8 +154,8 @@ type middleware[M adk.MessageType] struct {
154154
155155 resolvedMemoryDirectory string
156156
157- topicSelectionModel model.ToolCallingChatModel
158- extractionHandler adk.ChatModelAgentMiddleware
157+ topicSelectionModel model.BaseModel [ M ]
158+ extractionHandler adk.TypedChatModelAgentMiddleware [ M ]
159159 topicSelectionTool * schema.ToolInfo
160160 coordination * CoordinationConfig [M ]
161161}
@@ -201,7 +201,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
201201 return nil , fmt .Errorf ("auto memory config: resolve memory directory: %w" , err )
202202 }
203203 if cfg .Read == nil {
204- cfg .Read = & ReadConfig {}
204+ cfg .Read = & ReadConfig [ M ] {}
205205 }
206206 applyReadDefaults (cfg )
207207
@@ -214,11 +214,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
214214
215215 m .topicSelectionTool = topicSelectionToolInfo ()
216216 if cfg .Read .TopicSelection != nil && cfg .Read .Model != nil {
217- bound , err := cfg .Read .Model .WithTools ([]* schema.ToolInfo {m .topicSelectionTool })
218- if err != nil {
219- return nil , fmt .Errorf ("auto memory topic selection model init failed: %w" , err )
220- }
221- m .topicSelectionModel = bound
217+ m .topicSelectionModel = cfg .Read .Model
222218 }
223219
224220 if cfg .Write .Mode != WriteModeDisabled && cfg .Write .Model != nil {
@@ -230,7 +226,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
230226 if err != nil {
231227 return nil , err
232228 }
233- fileSystemMiddleware , err := fsmw .New (ctx , & fsmw.MiddlewareConfig {
229+ fileSystemMiddleware , err := fsmw .NewTyped [ M ] (ctx , & fsmw.MiddlewareConfig {
234230 Backend : writeFSBackend ,
235231 LsToolConfig : & fsmw.ToolConfig {Disable : true },
236232 GrepToolConfig : & fsmw.ToolConfig {Disable : true },
@@ -412,7 +408,7 @@ func applyReadDefaults[M adk.MessageType](cfg *Config[M]) {
412408 }
413409
414410 if cfg .Write == nil {
415- cfg .Write = & WriteConfig {Mode : WriteModeDisabled }
411+ cfg .Write = & WriteConfig [ M ] {Mode : WriteModeDisabled }
416412 }
417413 if cfg .Write .Mode == "" {
418414 cfg .Write .Mode = WriteModeDisabled
@@ -764,11 +760,12 @@ func (m *middleware[M]) selectTopicCandidates(
764760 toolInfo := topicSelectionToolInfo ()
765761 resp , err := m .topicSelectionModel .Generate (
766762 ctx ,
767- []* schema. Message {
768- schema . SystemMessage (getTopicSelectionSystemPrompt ()),
769- schema . UserMessage (userMsg ),
763+ []M {
764+ makeSystemMsg [ M ] (getTopicSelectionSystemPrompt ()),
765+ makeUserMsg [ M ] (userMsg ),
770766 },
771- model .WithToolChoice (schema .ToolChoiceForced , toolInfo .Name ),
767+ model .WithTools ([]* schema.ToolInfo {toolInfo }),
768+ makeToolChoiceForced [M ](toolInfo .Name ),
772769 )
773770 if err != nil {
774771 return nil , err
@@ -864,11 +861,12 @@ func topicSelectionToolInfo() *schema.ToolInfo {
864861 }
865862}
866863
867- func parseTopicSelectionFromToolCall (msg * schema.Message , valid map [string ]struct {}) ([]string , error ) {
868- if msg == nil || len (msg .ToolCalls ) == 0 {
864+ func parseTopicSelectionFromToolCall [M adk.MessageType ](msg M , valid map [string ]struct {}) ([]string , error ) {
865+ toolCalls := messageToolCalls (msg )
866+ if len (toolCalls ) == 0 {
869867 return nil , fmt .Errorf ("no tool calls" )
870868 }
871- tc := msg . ToolCalls [0 ]
869+ tc := toolCalls [0 ]
872870 if tc .Function .Name != topicSelectionToolName {
873871 return nil , fmt .Errorf ("unexpected tool call: %s" , tc .Function .Name )
874872 }
@@ -1002,6 +1000,31 @@ func setMsgExtra[M adk.MessageType](msg M, key string, value any) {
10021000 }
10031001}
10041002
1003+ func copyMsgExtra [M adk.MessageType ](dst , src M ) {
1004+ srcExtra := getMsgExtra (src )
1005+ if len (srcExtra ) == 0 {
1006+ return
1007+ }
1008+ switch d := any (dst ).(type ) {
1009+ case * schema.Message :
1010+ if d .Extra == nil {
1011+ d .Extra = make (map [string ]any , len (srcExtra ))
1012+ }
1013+ for k , v := range srcExtra {
1014+ d .Extra [k ] = v
1015+ }
1016+ case * schema.AgenticMessage :
1017+ if d .Extra == nil {
1018+ d .Extra = make (map [string ]any , len (srcExtra ))
1019+ }
1020+ for k , v := range srcExtra {
1021+ d .Extra [k ] = v
1022+ }
1023+ default :
1024+ panic ("unreachable" )
1025+ }
1026+ }
1027+
10051028func makeUserMsg [M adk.MessageType ](text string ) M {
10061029 var zero M
10071030 switch any (zero ).(type ) {
@@ -1014,6 +1037,35 @@ func makeUserMsg[M adk.MessageType](text string) M {
10141037 }
10151038}
10161039
1040+ func makeSystemMsg [M adk.MessageType ](text string ) M {
1041+ var zero M
1042+ switch any (zero ).(type ) {
1043+ case * schema.Message :
1044+ return any (schema .SystemMessage (text )).(M )
1045+ case * schema.AgenticMessage :
1046+ return any (schema .SystemAgenticMessage (text )).(M )
1047+ default :
1048+ panic ("unreachable" )
1049+ }
1050+ }
1051+
1052+ func makeToolChoiceForced [M adk.MessageType ](name string ) model.Option {
1053+ var zero M
1054+ switch any (zero ).(type ) {
1055+ case * schema.Message :
1056+ return model .WithToolChoice (schema .ToolChoiceForced , name )
1057+ case * schema.AgenticMessage :
1058+ return model .WithAgenticToolChoice (& schema.AgenticToolChoice {
1059+ Type : schema .ToolChoiceForced ,
1060+ Forced : & schema.AgenticForcedToolChoice {
1061+ Tools : []* schema.AllowedTool {{FunctionName : name }},
1062+ },
1063+ })
1064+ default :
1065+ panic ("unreachable" )
1066+ }
1067+ }
1068+
10171069func messageToolCalls [M adk.MessageType ](msg M ) []schema.ToolCall {
10181070 switch m := any (msg ).(type ) {
10191071 case * schema.Message :
@@ -1069,40 +1121,6 @@ func messageToolNames[M adk.MessageType](msg M) []string {
10691121 }
10701122}
10711123
1072- func projectMessagesToSchema [M adk.MessageType ](msgs []M ) []adk.Message {
1073- out := make ([]adk.Message , 0 , len (msgs ))
1074- for _ , msg := range msgs {
1075- if projected := projectMessageToSchema (msg ); projected != nil {
1076- out = append (out , projected )
1077- }
1078- }
1079- return out
1080- }
1081-
1082- func projectMessageToSchema [M adk.MessageType ](msg M ) adk.Message {
1083- switch m := any (msg ).(type ) {
1084- case * schema.Message :
1085- return m
1086- case * schema.AgenticMessage :
1087- if m == nil {
1088- return nil
1089- }
1090- text := m .String ()
1091- switch m .Role {
1092- case schema .AgenticRoleTypeSystem :
1093- return schema .SystemMessage (text )
1094- case schema .AgenticRoleTypeAssistant :
1095- return schema .AssistantMessage (text , messageToolCalls (msg ))
1096- case schema .AgenticRoleTypeUser :
1097- return schema .UserMessage (text )
1098- default :
1099- return schema .UserMessage (text )
1100- }
1101- default :
1102- panic ("unreachable" )
1103- }
1104- }
1105-
11061124func alreadyInjected [M adk.MessageType ](msgs []M ) bool {
11071125 for _ , m := range msgs {
11081126 if isMemoryMessage (m ) {
@@ -1464,21 +1482,20 @@ func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int
14641482 return countModelVisibleMessages (msgs [cursor :])
14651483}
14661484
1467- func (m * middleware [M ]) newExtractionAgent (ctx context.Context , toolInfos []* schema.ToolInfo ) (* adk.ChatModelAgent , error ) {
1485+ func (m * middleware [M ]) newExtractionAgent (ctx context.Context , toolInfos []* schema.ToolInfo ) (* adk.TypedChatModelAgent [ M ] , error ) {
14681486 if m .cfg == nil || m .cfg .Write == nil || m .cfg .Write .Model == nil {
14691487 return nil , fmt .Errorf ("auto memory extraction agent init failed: missing write model" )
14701488 }
14711489 if m .extractionHandler == nil {
14721490 return nil , fmt .Errorf ("auto memory extraction agent init failed: missing extraction handler" )
14731491 }
14741492
1475- agent , err := adk .NewChatModelAgent (ctx , & adk.ChatModelAgentConfig {
1476- Name : "automemory_extractor" ,
1477- Description : "Internal auto memory extraction subagent" ,
1478- Model : m .cfg .Write .Model ,
1479- Handlers : []adk.ChatModelAgentMiddleware {
1493+ agent , err := adk .NewTypedChatModelAgent [M ](ctx , & adk.TypedChatModelAgentConfig [M ]{
1494+ Name : "automemory_extractor" ,
1495+ Model : m .cfg .Write .Model ,
1496+ Handlers : []adk.TypedChatModelAgentMiddleware [M ]{
14801497 m .extractionHandler , // fs middleware
1481- & toolInfoOverrideMiddleware {toolInfos : toolInfos }, // tool info override, for prefix cache
1498+ & toolInfoOverrideMiddleware [ M ] {toolInfos : toolInfos }, // tool info override, for prefix cache
14821499 },
14831500 ToolsConfig : adk.ToolsConfig {
14841501 ToolsNodeConfig : compose.ToolsNodeConfig {
@@ -1506,13 +1523,13 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [
15061523 }
15071524 newMessageCount := countModelVisibleMessagesSince (snapshot , cursor )
15081525 userPrompt := buildExtractAutoOnlyPrompt (m .resolvedMemoryDirectory , newMessageCount , manifest , m .cfg .Write .SkipIndex )
1509- msgs := append (projectMessagesToSchema ( snapshot ), schema . UserMessage (userPrompt ))
1526+ msgs := append (append ([] M {}, snapshot ... ), makeUserMsg [ M ] (userPrompt ))
15101527 extractionAgent , err := m .newExtractionAgent (ctx , toolInfos )
15111528 if err != nil {
15121529 return err
15131530 }
15141531
1515- iter := extractionAgent .Run (ctx , & adk.AgentInput {
1532+ iter := extractionAgent .Run (ctx , & adk.TypedAgentInput [ M ] {
15161533 Messages : msgs ,
15171534 EnableStreaming : true ,
15181535 })
@@ -1583,30 +1600,27 @@ func parseRFC3339NanoBestEffort(s string) time.Time {
15831600 return time.Time {}
15841601}
15851602
1586- type toolInfoOverrideMiddleware struct {
1587- adk.BaseChatModelAgentMiddleware
1603+ type toolInfoOverrideMiddleware [ M adk. MessageType ] struct {
1604+ adk.TypedBaseChatModelAgentMiddleware [ M ]
15881605
1589- once sync.Once
15901606 toolInfos []* schema.ToolInfo
15911607}
15921608
1593- func (t * toolInfoOverrideMiddleware ) BeforeModelRewriteState (ctx context.Context , state * adk.TypedChatModelAgentState [* schema. Message ], _ * adk.TypedModelContext [* schema. Message ]) (
1594- context.Context , * adk.TypedChatModelAgentState [* schema. Message ], error ) {
1609+ func (t * toolInfoOverrideMiddleware [ M ] ) BeforeModelRewriteState (ctx context.Context , state * adk.TypedChatModelAgentState [M ], _ * adk.TypedModelContext [M ]) (
1610+ context.Context , * adk.TypedChatModelAgentState [M ], error ) {
15951611
1596- t .once .Do (func () {
1597- toolNameMapping := make (map [string ]struct {}, len (t .toolInfos ))
1598- for _ , tool := range t .toolInfos {
1599- toolNameMapping [tool .Name ] = struct {}{}
1600- }
1612+ toolNameMapping := make (map [string ]struct {}, len (t .toolInfos ))
1613+ for _ , tool := range t .toolInfos {
1614+ toolNameMapping [tool .Name ] = struct {}{}
1615+ }
16011616
1602- overrideTools := append ([]* schema.ToolInfo {}, t .toolInfos ... )
1603- for _ , tool := range state .ToolInfos {
1604- if _ , ok := toolNameMapping [tool .Name ]; ! ok { // add fs tools if not exists
1605- overrideTools = append (overrideTools , tool )
1606- }
1617+ overrideTools := append ([]* schema.ToolInfo {}, t .toolInfos ... )
1618+ for _ , tool := range state .ToolInfos {
1619+ if _ , ok := toolNameMapping [tool .Name ]; ! ok {
1620+ overrideTools = append (overrideTools , tool )
16071621 }
1608- state . ToolInfos = overrideTools
1609- })
1622+ }
1623+ state . ToolInfos = overrideTools
16101624
16111625 return ctx , state , nil
16121626}
0 commit comments