@@ -49,17 +49,17 @@ type Config[M adk.MessageType] struct {
4949
5050 MemoryBackend Backend
5151
52- // Model is the default tool-calling model used by topic selection and memory extraction.
52+ // Model is the default model used by topic selection and memory extraction.
5353 // Per-read/per-write overrides can be configured in Read.Model / Write.Model.
54- Model model.ToolCallingChatModel
54+ Model model.BaseModel [ M ]
5555
5656 // Read controls how memories are loaded and injected.
5757 // Optional. Defaults to Sync load with topic selection enabled (if Model is set).
58- Read * ReadConfig
58+ Read * ReadConfig [ M ]
5959
6060 // Write controls post-run memory extraction and persistence.
6161 // Optional. Default: disabled.
62- Write * WriteConfig
62+ Write * WriteConfig [ M ]
6363
6464 // Coordination controls session identity and distributed async extraction coordination.
6565 // Optional. Defaults to a local in-process coordinator.
@@ -78,11 +78,11 @@ const (
7878 ReadModeAsync ReadMode = "async"
7979)
8080
81- type ReadConfig struct {
81+ type ReadConfig [ M adk. MessageType ] struct {
8282 Mode ReadMode
8383
8484 // Model is used for topic selection. Defaults to Config.Model.
85- Model model.ToolCallingChatModel
85+ Model model.BaseModel [ M ]
8686
8787 // Instruction overrides the default auto memory instruction block appended to system prompt.
8888 // Optional.
@@ -126,11 +126,11 @@ const (
126126 WriteModeSync WriteMode = "sync"
127127)
128128
129- type WriteConfig struct {
129+ type WriteConfig [ M adk. MessageType ] struct {
130130 Mode WriteMode
131131
132132 // Model is used for memory extraction. Defaults to Config.Model.
133- Model model.ToolCallingChatModel
133+ Model model.BaseModel [ M ]
134134
135135 // MaxTurns caps the extractor's tool-call loop.
136136 MaxTurns int
@@ -144,7 +144,7 @@ type WriteConfig struct {
144144 //
145145 // If nil, automemory uses the default drain behavior: ignore all events and
146146 // return the first ev.Err encountered (if any).
147- HandleExtractionIterator func (ctx context.Context , iter * adk.AsyncIterator [* adk.AgentEvent ]) error
147+ HandleExtractionIterator func (ctx context.Context , iter * adk.AsyncIterator [* adk.TypedAgentEvent [ M ] ]) error
148148}
149149
150150type middleware [M adk.MessageType ] struct {
@@ -154,8 +154,8 @@ type middleware[M adk.MessageType] struct {
154154
155155 resolvedMemoryDirectory string
156156
157- topicSelectionModel model.ToolCallingChatModel
158- extractionHandler adk.ChatModelAgentMiddleware
157+ topicSelectionModel model.BaseModel [ M ]
158+ extractionHandler adk.TypedChatModelAgentMiddleware [ M ]
159159 topicSelectionTool * schema.ToolInfo
160160 coordination * CoordinationConfig [M ]
161161}
@@ -201,7 +201,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
201201 return nil , fmt .Errorf ("auto memory config: resolve memory directory: %w" , err )
202202 }
203203 if cfg .Read == nil {
204- cfg .Read = & ReadConfig {}
204+ cfg .Read = & ReadConfig [ M ] {}
205205 }
206206 applyReadDefaults (cfg )
207207
@@ -214,11 +214,10 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
214214
215215 m .topicSelectionTool = topicSelectionToolInfo ()
216216 if cfg .Read .TopicSelection != nil && cfg .Read .Model != nil {
217- bound , err := cfg . Read . Model . WithTools ([] * schema. ToolInfo { m . topicSelectionTool })
218- if err != nil {
219- return nil , fmt . Errorf ( "auto memory topic selection model init failed: %w" , err )
217+ m . topicSelectionModel = & modelWithTools [ M ]{
218+ base : cfg . Read . Model ,
219+ tools : [] * schema. ToolInfo { m . topicSelectionTool },
220220 }
221- m .topicSelectionModel = bound
222221 }
223222
224223 if cfg .Write .Mode != WriteModeDisabled && cfg .Write .Model != nil {
@@ -230,7 +229,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
230229 if err != nil {
231230 return nil , err
232231 }
233- fileSystemMiddleware , err := fsmw .New (ctx , & fsmw.MiddlewareConfig {
232+ fileSystemMiddleware , err := fsmw .NewTyped [ M ] (ctx , & fsmw.MiddlewareConfig {
234233 Backend : writeFSBackend ,
235234 LsToolConfig : & fsmw.ToolConfig {Disable : true },
236235 GrepToolConfig : & fsmw.ToolConfig {Disable : true },
@@ -412,7 +411,7 @@ func applyReadDefaults[M adk.MessageType](cfg *Config[M]) {
412411 }
413412
414413 if cfg .Write == nil {
415- cfg .Write = & WriteConfig {Mode : WriteModeDisabled }
414+ cfg .Write = & WriteConfig [ M ] {Mode : WriteModeDisabled }
416415 }
417416 if cfg .Write .Mode == "" {
418417 cfg .Write .Mode = WriteModeDisabled
@@ -764,11 +763,11 @@ func (m *middleware[M]) selectTopicCandidates(
764763 toolInfo := topicSelectionToolInfo ()
765764 resp , err := m .topicSelectionModel .Generate (
766765 ctx ,
767- []* schema. Message {
768- schema . SystemMessage (getTopicSelectionSystemPrompt ()),
769- schema . UserMessage (userMsg ),
766+ []M {
767+ makeSystemMsg [ M ] (getTopicSelectionSystemPrompt ()),
768+ makeUserMsg [ M ] (userMsg ),
770769 },
771- model . WithToolChoice ( schema . ToolChoiceForced , toolInfo .Name ),
770+ makeToolChoiceForced [ M ]( toolInfo .Name ),
772771 )
773772 if err != nil {
774773 return nil , err
@@ -864,11 +863,12 @@ func topicSelectionToolInfo() *schema.ToolInfo {
864863 }
865864}
866865
867- func parseTopicSelectionFromToolCall (msg * schema.Message , valid map [string ]struct {}) ([]string , error ) {
868- if msg == nil || len (msg .ToolCalls ) == 0 {
866+ func parseTopicSelectionFromToolCall [M adk.MessageType ](msg M , valid map [string ]struct {}) ([]string , error ) {
867+ toolCalls := messageToolCalls (msg )
868+ if len (toolCalls ) == 0 {
869869 return nil , fmt .Errorf ("no tool calls" )
870870 }
871- tc := msg . ToolCalls [0 ]
871+ tc := toolCalls [0 ]
872872 if tc .Function .Name != topicSelectionToolName {
873873 return nil , fmt .Errorf ("unexpected tool call: %s" , tc .Function .Name )
874874 }
@@ -1002,6 +1002,31 @@ func setMsgExtra[M adk.MessageType](msg M, key string, value any) {
10021002 }
10031003}
10041004
1005+ func copyMsgExtra [M adk.MessageType ](dst , src M ) {
1006+ srcExtra := getMsgExtra (src )
1007+ if len (srcExtra ) == 0 {
1008+ return
1009+ }
1010+ switch d := any (dst ).(type ) {
1011+ case * schema.Message :
1012+ if d .Extra == nil {
1013+ d .Extra = make (map [string ]any , len (srcExtra ))
1014+ }
1015+ for k , v := range srcExtra {
1016+ d .Extra [k ] = v
1017+ }
1018+ case * schema.AgenticMessage :
1019+ if d .Extra == nil {
1020+ d .Extra = make (map [string ]any , len (srcExtra ))
1021+ }
1022+ for k , v := range srcExtra {
1023+ d .Extra [k ] = v
1024+ }
1025+ default :
1026+ panic ("unreachable" )
1027+ }
1028+ }
1029+
10051030func makeUserMsg [M adk.MessageType ](text string ) M {
10061031 var zero M
10071032 switch any (zero ).(type ) {
@@ -1014,6 +1039,35 @@ func makeUserMsg[M adk.MessageType](text string) M {
10141039 }
10151040}
10161041
1042+ func makeSystemMsg [M adk.MessageType ](text string ) M {
1043+ var zero M
1044+ switch any (zero ).(type ) {
1045+ case * schema.Message :
1046+ return any (schema .SystemMessage (text )).(M )
1047+ case * schema.AgenticMessage :
1048+ return any (schema .SystemAgenticMessage (text )).(M )
1049+ default :
1050+ panic ("unreachable" )
1051+ }
1052+ }
1053+
1054+ func makeToolChoiceForced [M adk.MessageType ](name string ) model.Option {
1055+ var zero M
1056+ switch any (zero ).(type ) {
1057+ case * schema.Message :
1058+ return model .WithToolChoice (schema .ToolChoiceForced , name )
1059+ case * schema.AgenticMessage :
1060+ return model .WithAgenticToolChoice (& schema.AgenticToolChoice {
1061+ Type : schema .ToolChoiceForced ,
1062+ Forced : & schema.AgenticForcedToolChoice {
1063+ Tools : []* schema.AllowedTool {{FunctionName : name }},
1064+ },
1065+ })
1066+ default :
1067+ panic ("unreachable" )
1068+ }
1069+ }
1070+
10171071func messageToolCalls [M adk.MessageType ](msg M ) []schema.ToolCall {
10181072 switch m := any (msg ).(type ) {
10191073 case * schema.Message :
@@ -1069,40 +1123,6 @@ func messageToolNames[M adk.MessageType](msg M) []string {
10691123 }
10701124}
10711125
1072- func projectMessagesToSchema [M adk.MessageType ](msgs []M ) []adk.Message {
1073- out := make ([]adk.Message , 0 , len (msgs ))
1074- for _ , msg := range msgs {
1075- if projected := projectMessageToSchema (msg ); projected != nil {
1076- out = append (out , projected )
1077- }
1078- }
1079- return out
1080- }
1081-
1082- func projectMessageToSchema [M adk.MessageType ](msg M ) adk.Message {
1083- switch m := any (msg ).(type ) {
1084- case * schema.Message :
1085- return m
1086- case * schema.AgenticMessage :
1087- if m == nil {
1088- return nil
1089- }
1090- text := m .String ()
1091- switch m .Role {
1092- case schema .AgenticRoleTypeSystem :
1093- return schema .SystemMessage (text )
1094- case schema .AgenticRoleTypeAssistant :
1095- return schema .AssistantMessage (text , messageToolCalls (msg ))
1096- case schema .AgenticRoleTypeUser :
1097- return schema .UserMessage (text )
1098- default :
1099- return schema .UserMessage (text )
1100- }
1101- default :
1102- panic ("unreachable" )
1103- }
1104- }
1105-
11061126func alreadyInjected [M adk.MessageType ](msgs []M ) bool {
11071127 for _ , m := range msgs {
11081128 if isMemoryMessage (m ) {
@@ -1464,21 +1484,20 @@ func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int
14641484 return countModelVisibleMessages (msgs [cursor :])
14651485}
14661486
1467- func (m * middleware [M ]) newExtractionAgent (ctx context.Context , toolInfos []* schema.ToolInfo ) (* adk.ChatModelAgent , error ) {
1487+ func (m * middleware [M ]) newExtractionAgent (ctx context.Context , toolInfos []* schema.ToolInfo ) (* adk.TypedChatModelAgent [ M ] , error ) {
14681488 if m .cfg == nil || m .cfg .Write == nil || m .cfg .Write .Model == nil {
14691489 return nil , fmt .Errorf ("auto memory extraction agent init failed: missing write model" )
14701490 }
14711491 if m .extractionHandler == nil {
14721492 return nil , fmt .Errorf ("auto memory extraction agent init failed: missing extraction handler" )
14731493 }
14741494
1475- agent , err := adk .NewChatModelAgent (ctx , & adk.ChatModelAgentConfig {
1476- Name : "automemory_extractor" ,
1477- Description : "Internal auto memory extraction subagent" ,
1478- Model : m .cfg .Write .Model ,
1479- Handlers : []adk.ChatModelAgentMiddleware {
1495+ agent , err := adk .NewTypedChatModelAgent [M ](ctx , & adk.TypedChatModelAgentConfig [M ]{
1496+ Name : "automemory_extractor" ,
1497+ Model : m .cfg .Write .Model ,
1498+ Handlers : []adk.TypedChatModelAgentMiddleware [M ]{
14801499 m .extractionHandler , // fs middleware
1481- & toolInfoOverrideMiddleware {toolInfos : toolInfos }, // tool info override, for prefix cache
1500+ & toolInfoOverrideMiddleware [ M ] {toolInfos : toolInfos }, // tool info override, for prefix cache
14821501 },
14831502 ToolsConfig : adk.ToolsConfig {
14841503 ToolsNodeConfig : compose.ToolsNodeConfig {
@@ -1506,13 +1525,13 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [
15061525 }
15071526 newMessageCount := countModelVisibleMessagesSince (snapshot , cursor )
15081527 userPrompt := buildExtractAutoOnlyPrompt (m .resolvedMemoryDirectory , newMessageCount , manifest , m .cfg .Write .SkipIndex )
1509- msgs := append (projectMessagesToSchema ( snapshot ), schema . UserMessage (userPrompt ))
1528+ msgs := append (append ([] M {}, snapshot ... ), makeUserMsg [ M ] (userPrompt ))
15101529 extractionAgent , err := m .newExtractionAgent (ctx , toolInfos )
15111530 if err != nil {
15121531 return err
15131532 }
15141533
1515- iter := extractionAgent .Run (ctx , & adk.AgentInput {
1534+ iter := extractionAgent .Run (ctx , & adk.TypedAgentInput [ M ] {
15161535 Messages : msgs ,
15171536 EnableStreaming : true ,
15181537 })
@@ -1583,30 +1602,46 @@ func parseRFC3339NanoBestEffort(s string) time.Time {
15831602 return time.Time {}
15841603}
15851604
1586- type toolInfoOverrideMiddleware struct {
1587- adk.BaseChatModelAgentMiddleware
1605+ type toolInfoOverrideMiddleware [ M adk. MessageType ] struct {
1606+ adk.TypedBaseChatModelAgentMiddleware [ M ]
15881607
1589- once sync.Once
15901608 toolInfos []* schema.ToolInfo
15911609}
15921610
1593- func (t * toolInfoOverrideMiddleware ) BeforeModelRewriteState (ctx context.Context , state * adk.TypedChatModelAgentState [* schema. Message ], _ * adk.TypedModelContext [* schema. Message ]) (
1594- context.Context , * adk.TypedChatModelAgentState [* schema. Message ], error ) {
1611+ func (t * toolInfoOverrideMiddleware [ M ] ) BeforeModelRewriteState (ctx context.Context , state * adk.TypedChatModelAgentState [M ], _ * adk.TypedModelContext [M ]) (
1612+ context.Context , * adk.TypedChatModelAgentState [M ], error ) {
15951613
1596- t .once .Do (func () {
1597- toolNameMapping := make (map [string ]struct {}, len (t .toolInfos ))
1598- for _ , tool := range t .toolInfos {
1599- toolNameMapping [tool .Name ] = struct {}{}
1600- }
1614+ toolNameMapping := make (map [string ]struct {}, len (t .toolInfos ))
1615+ for _ , tool := range t .toolInfos {
1616+ toolNameMapping [tool .Name ] = struct {}{}
1617+ }
16011618
1602- overrideTools := append ([]* schema.ToolInfo {}, t .toolInfos ... )
1603- for _ , tool := range state .ToolInfos {
1604- if _ , ok := toolNameMapping [tool .Name ]; ! ok { // add fs tools if not exists
1605- overrideTools = append (overrideTools , tool )
1606- }
1619+ overrideTools := append ([]* schema.ToolInfo {}, t .toolInfos ... )
1620+ for _ , tool := range state .ToolInfos {
1621+ if _ , ok := toolNameMapping [tool .Name ]; ! ok {
1622+ overrideTools = append (overrideTools , tool )
16071623 }
1608- state . ToolInfos = overrideTools
1609- })
1624+ }
1625+ state . ToolInfos = overrideTools
16101626
16111627 return ctx , state , nil
16121628}
1629+
1630+ type modelWithTools [M adk.MessageType ] struct {
1631+ base model.BaseModel [M ]
1632+ tools []* schema.ToolInfo
1633+ }
1634+
1635+ func (m * modelWithTools [M ]) Generate (ctx context.Context , input []M , opts ... model.Option ) (M , error ) {
1636+ newOpts := make ([]model.Option , len (opts )+ 1 )
1637+ copy (newOpts , opts )
1638+ newOpts [len (opts )] = model .WithTools (m .tools )
1639+ return m .base .Generate (ctx , input , newOpts ... )
1640+ }
1641+
1642+ func (m * modelWithTools [M ]) Stream (ctx context.Context , input []M , opts ... model.Option ) (* schema.StreamReader [M ], error ) {
1643+ newOpts := make ([]model.Option , len (opts )+ 1 )
1644+ copy (newOpts , opts )
1645+ newOpts [len (opts )] = model .WithTools (m .tools )
1646+ return m .base .Stream (ctx , input , newOpts ... )
1647+ }
0 commit comments