Skip to content

Commit d94d638

Browse files
fix(adk): stabilize session message ids
Change-Id: I898c2fe90d3457176f59f3641b6562f36f16b282
1 parent 82ba4df commit d94d638

4 files changed

Lines changed: 40 additions & 9 deletions

File tree

adk/chatmodel.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ func newDefaultGenModelInput[M MessageType]() TypedGenModelInput[M] {
280280
}
281281
}
282282

283+
func ensureGeneratedMessageIDs[M MessageType](messages []M) {
284+
for _, msg := range messages {
285+
EnsureMessageID(msg)
286+
}
287+
}
288+
283289
// TypedChatModelAgentState represents the state of a chat model agent during conversation.
284290
// This is the primary state type for both TypedChatModelAgentMiddleware and AgentMiddleware callbacks.
285291
type TypedChatModelAgentState[M MessageType] struct {
@@ -1121,6 +1127,9 @@ func (a *TypedChatModelAgent[M]) buildNoToolsRunFunc(_ context.Context) (typedRu
11211127
if err != nil {
11221128
return nil, err
11231129
}
1130+
if p.sessionEvents {
1131+
ensureGeneratedMessageIDs(messages)
1132+
}
11241133
if err := compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error {
11251134
st.Messages = append(st.Messages, messages...)
11261135
return nil
@@ -1289,6 +1298,9 @@ func (a *TypedChatModelAgent[M]) buildMessageReActRunFunc(_ context.Context, bc
12891298
if genErr != nil {
12901299
return nil, genErr
12911300
}
1301+
if mp.sessionEvents {
1302+
ensureGeneratedMessageIDs(messages)
1303+
}
12921304
return &reactInput{
12931305
Messages: messages,
12941306
}, nil
@@ -1443,6 +1455,9 @@ func (a *TypedChatModelAgent[M]) buildAgenticReActRunFunc(_ context.Context, bc
14431455
if genErr != nil {
14441456
return nil, genErr
14451457
}
1458+
if ap.sessionEvents {
1459+
ensureGeneratedMessageIDs(messages)
1460+
}
14461461
return &agenticReactInput{
14471462
Messages: messages,
14481463
}, nil

adk/middlewares/patchtoolcalls/patchtoolcalls.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ type normalizationPlan[M adk.MessageType] struct {
165165
}
166166

167167
func buildMessageNormalizationPlan(ctx context.Context, cfg Config, messages []*schema.Message) (*normalizationPlan[*schema.Message], error) {
168+
ensureMessageIDs(messages)
169+
168170
counts := analyzeMessages(messages)
169171
if cfg.Strict && counts.hasMismatch() {
170172
return nil, counts.strictError()
@@ -248,6 +250,12 @@ func analyzeMessages(messages []*schema.Message) mismatchCounts {
248250
return counts
249251
}
250252

253+
func ensureMessageIDs[M adk.MessageType](messages []M) {
254+
for _, msg := range messages {
255+
adk.EnsureMessageID(msg)
256+
}
257+
}
258+
251259
func keptMessages(messages []*schema.Message, cfg Config) []bool {
252260
keep := make([]bool, len(messages))
253261
previousCalls := make(map[string]struct{})
@@ -281,6 +289,8 @@ func keptMessages(messages []*schema.Message, cfg Config) []bool {
281289
}
282290

283291
func buildAgenticNormalizationPlan(ctx context.Context, cfg Config, messages []*schema.AgenticMessage) (*normalizationPlan[*schema.AgenticMessage], error) {
292+
ensureMessageIDs(messages)
293+
284294
counts := analyzeAgenticMessages(messages)
285295
if cfg.Strict && counts.hasMismatch() {
286296
return nil, counts.strictError()

adk/runner.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,13 @@ func prepareRunnerSessionRun[M MessageType]( //nolint:revive // argument-limit
319319
TurnID: state.turnID,
320320
Lifecycle: &LifecycleEvent{Scope: LifecycleScopeSession, State: SessionRunStateRunning},
321321
}
322-
if err := assignSessionEventID(ctx, runningEvent, state.sessionConfig.EventIDGenerator); err != nil {
322+
err = assignSessionEventID(ctx, runningEvent, state.sessionConfig.EventIDGenerator)
323+
if err != nil {
323324
_ = state.sessionHandle.close(ctx)
324325
return nil, err
325326
}
326-
if err := appendRunnerSessionControlEvent(ctx, state, runningEvent, ""); err != nil {
327+
err = appendRunnerSessionControlEvent(ctx, state, runningEvent, "")
328+
if err != nil {
327329
_ = state.sessionHandle.close(ctx)
328330
return nil, err
329331
}
@@ -352,7 +354,7 @@ func prepareRunnerSessionRun[M MessageType]( //nolint:revive // argument-limit
352354
return state, nil
353355
}
354356

355-
func prepareRunnerSessionResume[M MessageType](
357+
func prepareRunnerSessionResume[M MessageType]( //nolint:revive // argument-limit
356358
ctx context.Context,
357359
checkPointStore CheckPointStore,
358360
sessionID string,
@@ -568,12 +570,12 @@ func typedRunnerRunImpl[M MessageType](a TypedAgent[M], enableStreaming bool, st
568570
// Capture caller-provided messages BEFORE prepending history. These will be
569571
// emitted as session events at turn start so they appear in the event log.
570572
sessionState.inputMessages = append([]M{}, messages...)
571-
// Assign eino message IDs to input messages (needed for BeforeMessageID references
572-
// emitted by middlewares that anchor on user messages).
573-
for _, msg := range sessionState.inputMessages {
573+
messages = append(append([]M{}, sessionState.latestState.Messages...), sessionState.inputMessages...)
574+
// Assign IDs before messages can be both persisted and inspected by
575+
// middleware, avoiding concurrent lazy ID mutation during event snapshotting.
576+
for _, msg := range messages {
574577
EnsureMessageID(msg)
575578
}
576-
messages = append(append([]M{}, sessionState.latestState.Messages...), sessionState.inputMessages...)
577579
o.sessionValues = mergeSessionValues(sessionState.latestState.SessionValues, o.sessionValues)
578580
opts = append(opts, withEnableSessionEvents())
579581
opts = append(opts, withEnableInternalTimelineEvents())

adk/wrappers.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -911,9 +911,13 @@ func GetMessageID[M MessageType](msg M) string {
911911
func EnsureMessageID[M MessageType](msg M) {
912912
switch v := any(msg).(type) {
913913
case *schema.Message:
914-
v.Extra = internal.EnsureMessageID(v.Extra)
914+
if internal.GetMessageID(v.Extra) == "" {
915+
v.Extra = internal.EnsureMessageID(v.Extra)
916+
}
915917
case *schema.AgenticMessage:
916-
v.Extra = internal.EnsureMessageID(v.Extra)
918+
if internal.GetMessageID(v.Extra) == "" {
919+
v.Extra = internal.EnsureMessageID(v.Extra)
920+
}
917921
}
918922
}
919923

0 commit comments

Comments
 (0)