diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 2b7c75ce5..abbfac8d5 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -348,10 +348,12 @@ func TestWithCancel_AgenticResumeStreamableToolTimeout_DoesNotPersistTypedNil(t }() select { case err = <-cancelDone: - assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout), "unexpected cancel wait error: %v", err) + assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout) || errors.Is(err, ErrExecutionEnded), + "unexpected cancel wait error: %v", err) case <-time.After(5 * time.Second): t.Fatal("resume cancel handle did not complete") } + executionCompletedBeforeCancel := errors.Is(err, ErrExecutionEnded) var hasCancelError bool for { @@ -371,7 +373,8 @@ func TestWithCancel_AgenticResumeStreamableToolTimeout_DoesNotPersistTypedNil(t assert.NotContains(t, errText, "cannot encode nil pointer") assert.NotContains(t, errText, "*adk.agenticReactInput(nil=true") } - assert.True(t, hasCancelError, "expected CancelError in resume event stream") + assert.True(t, hasCancelError || executionCompletedBeforeCancel, + "expected CancelError in resume event stream unless execution completed before cancel") } func TestCancelContext(t *testing.T) { @@ -2561,7 +2564,7 @@ func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { t.Run("unit_SendEvent_no_execCtx", func(t *testing.T) { err := SendEvent(context.Background(), &AgentEvent{AgentName: "test"}) - assert.Error(t, err, "SendEvent without execCtx should return error") + assert.NoError(t, err, "SendEvent without execCtx should be a no-op") }) t.Run("integration_cancel_escalation_orphans_tool", func(t *testing.T) { diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 57d8036d3..8e0cee42f 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -1674,7 +1674,6 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...))) } } - ctx = contextWithToolPermissionDecisionStore(ctx) go func() { defer func() { @@ -1803,7 +1802,6 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o return nil })) } - ctx = contextWithToolPermissionDecisionStore(ctx) go func() { defer func() { diff --git a/adk/coverage_contract_test.go b/adk/coverage_contract_test.go index 3dc9c8a88..c97dbe2c9 100644 --- a/adk/coverage_contract_test.go +++ b/adk/coverage_contract_test.go @@ -156,27 +156,6 @@ func TestCommonOptionsAndFilteringContracts(t *testing.T) { assert.Len(t, filterOptions("parent", []AgentRunOption{nonCallback.DesignateAgent("parent"), otherCallback, {}}), 2) } -func TestToolPermissionDecisionStoreContracts(t *testing.T) { - ctx := context.Background() - assert.Empty(t, GetToolPermissionDecision(ctx, "call-1")) - SetToolPermissionDecision(ctx, "call-1", "allowed") - assert.Empty(t, GetToolPermissionDecision(ctx, "call-1")) - - ctx = contextWithToolPermissionDecisionStore(ctx) - same := contextWithToolPermissionDecisionStore(ctx) - assert.Same(t, ctx, same) - - SetToolPermissionDecision(ctx, "", "allowed") - SetToolPermissionDecision(ctx, "call-1", "") - assert.Empty(t, GetToolPermissionDecision(ctx, "call-1")) - - SetToolPermissionDecision(ctx, "call-1", "allowed") - SetToolPermissionDecision(ctx, "call-2", "denied") - assert.Equal(t, "allowed", GetToolPermissionDecision(ctx, "call-1")) - assert.Equal(t, "denied", GetToolPermissionDecision(ctx, "call-2")) - assert.Empty(t, GetToolPermissionDecision(ctx, "")) -} - func TestLocalSessionServiceHandleContracts(t *testing.T) { ctx := context.Background() assert.Nil(t, NewLocalSessionService[*schema.Message](nil)) diff --git a/adk/handler.go b/adk/handler.go index 5cc55c8c7..472831da2 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -419,12 +419,12 @@ func DeleteRunLocalValue(ctx context.Context, key string) error { // via internal wrapper layers. If your middleware constructs its own messages, call // EnsureMessageID before sending to assign an ID. // -// This function can only be called from within a TypedChatModelAgentMiddleware during agent execution. -// Returns an error if called outside of an agent execution context. +// When called outside of an agent execution context, or from a path without an +// event generator, this function is a no-op. func TypedSendEvent[M MessageType](ctx context.Context, event *TypedAgentEvent[M]) error { execCtx := getTypedChatModelAgentExecCtx[M](ctx) if execCtx == nil || execCtx.generator == nil { - return fmt.Errorf("TypedSendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") + return nil } execCtx.send(ctx, event) @@ -437,8 +437,8 @@ func TypedSendEvent[M MessageType](ctx context.Context, event *TypedAgentEvent[M // For custom session timeline events during a Runner run, set AgentEvent.SessionEvent // to an extension SessionEvent with an x.* Kind and send it through this function. // -// This function can only be called from within a ChatModelAgentMiddleware during agent execution. -// Returns an error if called outside of an agent execution context. +// When called outside of an agent execution context, or from a path without an +// event generator, this function is a no-op. func SendEvent(ctx context.Context, event *AgentEvent) error { return TypedSendEvent(ctx, event) } diff --git a/adk/interrupt.go b/adk/interrupt.go index 97d424db9..9d51108bc 100644 --- a/adk/interrupt.go +++ b/adk/interrupt.go @@ -310,8 +310,15 @@ func encodeRunnerCheckPointImpl( info *InterruptInfo, is *core.InterruptSignal, ) ([]byte, error) { - runCtx := getRunCtx(ctx) + return encodeRunnerCheckPointWithRunCtx(enableStreaming, getRunCtx(ctx), info, is) +} +func encodeRunnerCheckPointWithRunCtx( + enableStreaming bool, + runCtx *runContext, + info *InterruptInfo, + is *core.InterruptSignal, +) ([]byte, error) { id2Addr, id2State := core.SignalToPersistenceMaps(is) buf := &bytes.Buffer{} diff --git a/adk/middlewares/permission/permission.go b/adk/middlewares/permission/permission.go index 833884f81..7f0950cb1 100644 --- a/adk/middlewares/permission/permission.go +++ b/adk/middlewares/permission/permission.go @@ -32,6 +32,7 @@ import ( func init() { schema.RegisterName[*AskInfo]("_eino_adk_permission_ask_info") schema.RegisterName[*AskState]("_eino_adk_permission_ask_state") + schema.RegisterName[*DecisionEvent]("_eino_adk_permission_decision_event") } // GateDecision is the result of a pre-execution permission check. @@ -47,6 +48,12 @@ const ( GateAsk GateDecision = "ask" ) +const ( + // SessionEventPermissionDecision records a valid user resume decision for a + // previously interrupted permission ask. + SessionEventPermissionDecision adk.SessionEventKind = adk.SessionEventKind(adk.SessionEventExtensionPrefix + "permission.decision") +) + // GateCheckResult determines how a tool call should proceed before execution. type GateCheckResult struct { Decision GateDecision @@ -114,6 +121,18 @@ type ResumeResponse struct { Message string } +// DecisionEvent is the typed payload for SessionEventPermissionDecision. +// It intentionally omits the original saved tool arguments; only user-provided +// UpdatedInput is carried when it is part of an approval decision. +type DecisionEvent struct { + Action ResumeAction `json:"action"` + ToolName string `json:"tool_name"` + ToolUseID string `json:"tool_use_id,omitempty"` + DecisionText string `json:"decision_text,omitempty"` + UpdatedInput string `json:"updated_input,omitempty"` + HasUpdatedInput bool `json:"has_updated_input,omitempty"` +} + // Middleware gates tool calls with a permission Checker. type Middleware[M adk.MessageType] struct { *adk.TypedBaseChatModelAgentMiddleware[M] @@ -139,6 +158,13 @@ type gateResult struct { argument *schema.ToolArgument } +type normalizedResumeDecision struct { + Action ResumeAction + UpdatedInput string + HasUpdatedInput bool + DecisionText string +} + func (m *Middleware[M]) permissionGate( ctx context.Context, tCtx *adk.ToolContext, @@ -162,6 +188,9 @@ func (m *Middleware[M]) permissionGate( if !hasState || savedState == nil { return nil, fmt.Errorf("permission: missing AskState for targeted resume of tool %q (call_id=%s)", tCtx.Name, tCtx.CallID) } + if err := emitDecisionEvent[M](ctx, tCtx, savedState, response); err != nil { + return nil, err + } return handleResumeResponse(ctx, tCtx, &schema.ToolArgument{Text: savedState.Arguments}, response) } @@ -191,16 +220,13 @@ func (m *Middleware[M]) permissionGate( switch decision.Decision { case GateAllow: - adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateAllow)) return &gateResult{ allowed: true, argument: withUpdatedInput(argument, decision.UpdatedInput, decision.HasUpdatedInput || decision.UpdatedInput != ""), }, nil case GateDeny: - adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateDeny)) return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.Message)}, nil case GateAsk: - adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateAsk)) info := &AskInfo{ ToolName: tCtx.Name, Summary: publicSummary(decision.Message, tCtx.CallID, argument.Text), @@ -250,40 +276,94 @@ func handleResumeResponse( argument *schema.ToolArgument, response *ResumeResponse, ) (*gateResult, error) { - if response == nil { - return nil, fmt.Errorf("permission: nil ResumeResponse for tool %q (call_id=%s)", tCtx.Name, tCtx.CallID) + decision, err := normalizeResumeDecision(tCtx, response) + if err != nil { + return nil, err } - switch response.Action { + switch decision.Action { case ResumeActionApprove: - adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionApprove)) return &gateResult{ allowed: true, - argument: withUpdatedInput(argument, response.UpdatedInput, response.HasUpdatedInput || response.UpdatedInput != ""), + argument: withUpdatedInput(argument, decision.UpdatedInput, decision.HasUpdatedInput), }, nil case ResumeActionReject: - adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionReject)) - message := response.Message - if message == "" { - message = "rejected by user" + return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.DecisionText)}, nil + case ResumeActionRespond: + return &gateResult{denyResult: formatRespondResult(tCtx.Name, decision.DecisionText)}, nil + default: + return nil, fmt.Errorf("permission: unknown resume action %q for tool %q (call_id=%s); expected approve, reject, or respond", + decision.Action, tCtx.Name, tCtx.CallID) + } +} + +func normalizeResumeDecision(tCtx *adk.ToolContext, response *ResumeResponse) (*normalizedResumeDecision, error) { + toolName, callID := "", "" + if tCtx != nil { + toolName = tCtx.Name + callID = tCtx.CallID + } + if response == nil { + return nil, fmt.Errorf("permission: nil ResumeResponse for tool %q (call_id=%s)", toolName, callID) + } + + decision := &normalizedResumeDecision{Action: response.Action} + switch response.Action { + case ResumeActionApprove: + decision.HasUpdatedInput = response.HasUpdatedInput || response.UpdatedInput != "" + if decision.HasUpdatedInput { + decision.UpdatedInput = response.UpdatedInput + } + return decision, nil + case ResumeActionReject: + decision.DecisionText = response.Message + if decision.DecisionText == "" { + decision.DecisionText = "rejected by user" } - return &gateResult{denyResult: formatDenyResult(tCtx.Name, message)}, nil + return decision, nil case ResumeActionRespond: - adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionRespond)) if response.Message == "" { return nil, fmt.Errorf("permission: empty response message for respond action on tool %q (call_id=%s)", - tCtx.Name, tCtx.CallID) + toolName, callID) } - return &gateResult{denyResult: formatRespondResult(tCtx.Name, response.Message)}, nil + decision.DecisionText = response.Message + return decision, nil case "": return nil, fmt.Errorf("permission: empty resume action for tool %q (call_id=%s); expected approve, reject, or respond", - tCtx.Name, tCtx.CallID) + toolName, callID) default: return nil, fmt.Errorf("permission: unknown resume action %q for tool %q (call_id=%s); expected approve, reject, or respond", - response.Action, tCtx.Name, tCtx.CallID) + response.Action, toolName, callID) } } +func emitDecisionEvent[M adk.MessageType](ctx context.Context, tCtx *adk.ToolContext, state *AskState, response *ResumeResponse) error { + if tCtx == nil { + return fmt.Errorf("permission: nil ToolContext for resume decision event") + } + if state == nil { + return fmt.Errorf("permission: nil AskState for resume decision event on tool %q (call_id=%s)", tCtx.Name, tCtx.CallID) + } + decision, err := normalizeResumeDecision(tCtx, response) + if err != nil { + return err + } + payload := &DecisionEvent{ + Action: decision.Action, + ToolName: state.ToolName, + ToolUseID: state.CallID, + DecisionText: decision.DecisionText, + UpdatedInput: decision.UpdatedInput, + HasUpdatedInput: decision.HasUpdatedInput, + } + return adk.TypedSendEvent[M](ctx, &adk.TypedAgentEvent[M]{ + SessionEvent: &adk.SessionEvent[M]{ + Kind: SessionEventPermissionDecision, + Extension: &adk.SessionExtensionEvent{Data: payload}, + }, + }) +} + func (m *Middleware[M]) WrapInvokableToolCall( _ context.Context, endpoint adk.InvokableToolCallEndpoint, diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go index 96c6f0363..80e7faf4b 100644 --- a/adk/middlewares/permission/permission_test.go +++ b/adk/middlewares/permission/permission_test.go @@ -690,6 +690,261 @@ func TestPermissionDecisionAppearsInToolUseTimeline(t *testing.T) { assert.Equal(t, `{"path":"/tmp/file"}`, captureTool.received) } +func TestPermissionDecisionEventResumeLiveAndPersisted(t *testing.T) { + tests := []struct { + name string + response *ResumeResponse + wantAction ResumeAction + wantDecisionText string + wantUpdatedInput string + wantHasUpdated bool + wantToolInput string + wantToolNotInvoked bool + }{ + { + name: "approve with updated input", + response: &ResumeResponse{ + Action: ResumeActionApprove, + UpdatedInput: `{"path":"/tmp/safe.txt"}`, + }, + wantAction: ResumeActionApprove, + wantUpdatedInput: `{"path":"/tmp/safe.txt"}`, + wantHasUpdated: true, + wantToolInput: `{"path":"/tmp/safe.txt"}`, + }, + { + name: "approve with explicit empty updated input", + response: &ResumeResponse{Action: ResumeActionApprove, HasUpdatedInput: true}, + wantAction: ResumeActionApprove, + wantHasUpdated: true, + wantToolInput: "", + wantUpdatedInput: "", + }, + { + name: "reject with default text", + response: &ResumeResponse{Action: ResumeActionReject}, + wantAction: ResumeActionReject, + wantDecisionText: "rejected by user", + wantToolNotInvoked: true, + }, + { + name: "respond with decision text", + response: &ResumeResponse{ + Action: ResumeActionRespond, + Message: "Please explain first.", + }, + wantAction: ResumeActionRespond, + wantDecisionText: "Please explain first.", + wantToolNotInvoked: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + captureTool := &permissionCaptureTool{name: "permission_tool"} + info, err := captureTool.Info(ctx) + require.NoError(t, err) + + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return schema.AssistantMessage("calling tool", []schema.ToolCall{ + {ID: "permission_call", Function: schema.FunctionCall{Name: info.Name, Arguments: `{"path":"/etc/passwd"}`}}, + }), nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "PermissionDecisionAgent", + Instruction: "use tools", + Model: cm, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + }, + }, + Handlers: []adk.ChatModelAgentMiddleware{ + New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*GateCheckResult, error) { + return &GateCheckResult{Decision: GateAsk, Message: `Approve permission_call with {"path":"/etc/passwd"}?`}, nil + }), + }, + }) + require.NoError(t, err) + + sessionStore := &permissionSessionService{} + checkpointStore := newPermissionCheckpointStore() + checkpointID := "permission-decision-" + strings.ReplaceAll(tt.name, " ", "-") + runner := adk.NewRunner(ctx, adk.RunnerConfig{ + Agent: agent, + CheckPointStore: checkpointStore, + SessionID: checkpointID, + SessionService: adk.NewLocalSessionService[*schema.Message](sessionStore), + }) + + var interruptID string + iter := runner.Query(ctx, "use the tool", adk.WithCheckPointID(checkpointID), adk.WithTimelineEvents()) + for { + event, ok := iter.Next() + if !ok { + break + } + require.NoError(t, event.Err) + if event.SessionEvent == nil || event.SessionEvent.Kind != adk.SessionEventAgentInterrupt { + continue + } + require.NotNil(t, event.SessionEvent.AgentInterrupt) + require.Len(t, event.SessionEvent.AgentInterrupt.Contexts, 1) + interruptID = event.SessionEvent.AgentInterrupt.Contexts[0].InterruptID + } + require.NotEmpty(t, interruptID) + + resumeIter, err := runner.ResumeWithParams(ctx, checkpointID, &adk.ResumeParams{ + Targets: map[string]any{interruptID: tt.response}, + }, adk.WithTimelineEvents()) + require.NoError(t, err) + + var liveDecision *adk.SessionEvent[*schema.Message] + for { + event, ok := resumeIter.Next() + if !ok { + break + } + require.NoError(t, event.Err) + if event.SessionEvent != nil && event.SessionEvent.Kind == SessionEventPermissionDecision { + liveDecision = event.SessionEvent + } + } + requireDecisionEvent(t, liveDecision, tt.wantAction, tt.wantDecisionText, tt.wantUpdatedInput, tt.wantHasUpdated) + + decisions := filterPermissionDecisionEvents(sessionStore.events) + require.Len(t, decisions, 1) + requireDecisionEvent(t, decisions[0], tt.wantAction, tt.wantDecisionText, tt.wantUpdatedInput, tt.wantHasUpdated) + assert.Equal(t, liveDecision.EventID, decisions[0].EventID) + assert.Equal(t, liveDecision.TurnID, decisions[0].TurnID) + + decisionJSON, err := json.Marshal(decisions[0].Extension.Data) + require.NoError(t, err) + assert.NotContains(t, string(decisionJSON), `{"path":"/etc/passwd"}`) + assert.NotContains(t, string(decisionJSON), "Arguments") + assert.NotContains(t, string(decisionJSON), "CallID") + + decisionIndex, idleAfterDecisionIndex := -1, -1 + for i, event := range sessionStore.events { + if event.Kind == SessionEventPermissionDecision { + decisionIndex = i + } + if decisionIndex >= 0 && i > decisionIndex && event.Kind == adk.SessionEventSessionStatusIdle { + idleAfterDecisionIndex = i + break + } + } + require.NotEqual(t, -1, decisionIndex) + require.NotEqual(t, -1, idleAfterDecisionIndex) + assert.Less(t, decisionIndex, idleAfterDecisionIndex) + + if tt.wantToolNotInvoked { + assert.Empty(t, captureTool.received) + } else { + assert.Equal(t, tt.wantToolInput, captureTool.received) + } + }) + } +} + +func TestAttack_InvalidRespondDoesNotPersistDecisionEvent(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + captureTool := &permissionCaptureTool{name: "permission_tool"} + info, err := captureTool.Info(ctx) + require.NoError(t, err) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling tool", []schema.ToolCall{ + {ID: "permission_call", Function: schema.FunctionCall{Name: info.Name, Arguments: `{"path":"/etc/passwd"}`}}, + }), nil).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "PermissionInvalidRespondAgent", + Instruction: "use tools", + Model: cm, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + }, + }, + Handlers: []adk.ChatModelAgentMiddleware{ + New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*GateCheckResult, error) { + return &GateCheckResult{Decision: GateAsk, Message: "approve?"}, nil + }), + }, + }) + require.NoError(t, err) + + sessionStore := &permissionSessionService{} + checkpointStore := newPermissionCheckpointStore() + const checkpointID = "permission-invalid-respond" + runner := adk.NewRunner(ctx, adk.RunnerConfig{ + Agent: agent, + CheckPointStore: checkpointStore, + SessionID: checkpointID, + SessionService: adk.NewLocalSessionService[*schema.Message](sessionStore), + }) + + var interruptID string + iter := runner.Query(ctx, "use the tool", adk.WithCheckPointID(checkpointID), adk.WithTimelineEvents()) + for { + event, ok := iter.Next() + if !ok { + break + } + require.NoError(t, event.Err) + if event.SessionEvent != nil && event.SessionEvent.Kind == adk.SessionEventAgentInterrupt { + require.NotNil(t, event.SessionEvent.AgentInterrupt) + require.Len(t, event.SessionEvent.AgentInterrupt.Contexts, 1) + interruptID = event.SessionEvent.AgentInterrupt.Contexts[0].InterruptID + } + } + require.NotEmpty(t, interruptID) + + resumeIter, err := runner.ResumeWithParams(ctx, checkpointID, &adk.ResumeParams{ + Targets: map[string]any{interruptID: &ResumeResponse{Action: ResumeActionRespond}}, + }, adk.WithTimelineEvents()) + require.NoError(t, err) + + var resumeErr error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErr = event.Err + continue + } + if event.SessionEvent != nil { + assert.NotEqual(t, SessionEventPermissionDecision, event.SessionEvent.Kind) + } + } + require.Error(t, resumeErr) + assert.Contains(t, resumeErr.Error(), "empty response message") + assert.Empty(t, filterPermissionDecisionEvents(sessionStore.events)) + assert.Empty(t, captureTool.received) +} + // TestToolSpan_PermissionDenyEmitsBothSpansOnSameRun verifies plan ยง4.5.1 #6: // when the permission gate denies on first invocation (no interrupt), the // tool wrapper emits a tool_call_start + tool_call_end pair on the SAME run. @@ -926,6 +1181,60 @@ func (s *permissionSessionService) LoadEvents(_ context.Context, _ *adk.LoadSess return &adk.LoadSessionEventsResult[*schema.Message]{Events: nil}, nil } +type permissionCheckpointStore struct { + data map[string][]byte +} + +func newPermissionCheckpointStore() *permissionCheckpointStore { + return &permissionCheckpointStore{data: make(map[string][]byte)} +} + +func (s *permissionCheckpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { + data, ok := s.data[key] + if !ok { + return nil, false, nil + } + return append([]byte(nil), data...), true, nil +} + +func (s *permissionCheckpointStore) Set(_ context.Context, key string, data []byte) error { + s.data[key] = append([]byte(nil), data...) + return nil +} + +func filterPermissionDecisionEvents(events []*adk.SessionEvent[*schema.Message]) []*adk.SessionEvent[*schema.Message] { + var decisions []*adk.SessionEvent[*schema.Message] + for _, event := range events { + if event.Kind == SessionEventPermissionDecision { + decisions = append(decisions, event) + } + } + return decisions +} + +func requireDecisionEvent( + t *testing.T, + event *adk.SessionEvent[*schema.Message], + action ResumeAction, + decisionText string, + updatedInput string, + hasUpdatedInput bool, +) { + t.Helper() + require.NotNil(t, event) + require.NotEmpty(t, event.EventID) + require.NotEmpty(t, event.TurnID) + require.NotNil(t, event.Extension) + payload, ok := event.Extension.Data.(*DecisionEvent) + require.True(t, ok) + assert.Equal(t, action, payload.Action) + assert.Equal(t, "permission_tool", payload.ToolName) + assert.Equal(t, "permission_call", payload.ToolUseID) + assert.Equal(t, decisionText, payload.DecisionText) + assert.Equal(t, updatedInput, payload.UpdatedInput) + assert.Equal(t, hasUpdatedInput, payload.HasUpdatedInput) +} + func requireAskInfo(t *testing.T, err error) *AskInfo { t.Helper() var signal *core.InterruptSignal diff --git a/adk/middlewares/summarization/summarization_test.go b/adk/middlewares/summarization/summarization_test.go index d70f396b0..11ae84d14 100644 --- a/adk/middlewares/summarization/summarization_test.go +++ b/adk/middlewares/summarization/summarization_test.go @@ -1425,11 +1425,10 @@ func TestPostProcessSummary(t *testing.T) { func TestEventHelpers(t *testing.T) { ctx := context.Background() - t.Run("emitEvent returns wrapped error outside execution context", func(t *testing.T) { + t.Run("emitEvent is no-op outside execution context", func(t *testing.T) { mw := &TypedMiddleware[*schema.Message]{cfg: &Config{}} err := mw.emitEvent(ctx, &CustomizedAction{Type: ActionTypeBeforeSummarize}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to send internal event") + assert.NoError(t, err) }) t.Run("emitGenerateSummaryEvent is skipped when internal events are disabled", func(t *testing.T) { @@ -1438,11 +1437,10 @@ func TestEventHelpers(t *testing.T) { assert.NoError(t, err) }) - t.Run("emitGenerateSummaryEvent returns wrapped error when enabled outside execution context", func(t *testing.T) { + t.Run("emitGenerateSummaryEvent is no-op when enabled outside execution context", func(t *testing.T) { mw := &TypedMiddleware[*schema.Message]{cfg: &Config{EmitInternalEvents: true}} err := mw.emitGenerateSummaryEvent(ctx, 1, GenerateSummaryPhasePrimary, schema.AssistantMessage("ok", nil), nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to send internal event") + assert.NoError(t, err) }) } @@ -1937,7 +1935,7 @@ func TestSummarizationGeneric(t *testing.T) { }) } -func TestEmitInternalEvents_AgenticMessage_RequiresExecContext(t *testing.T) { +func TestEmitInternalEvents_AgenticMessage_NoopOutsideExecContext(t *testing.T) { ctx := context.Background() longContent := strings.Repeat("x", 800000) @@ -1967,9 +1965,12 @@ func TestEmitInternalEvents_AgenticMessage_RequiresExecContext(t *testing.T) { require.NoError(t, err) state := &adk.TypedChatModelAgentState[*schema.AgenticMessage]{Messages: msgs} - _, _, err = mw.BeforeModelRewriteState(ctx, state, nil) - assert.Error(t, err, "should error without exec context when EmitInternalEvents is true") - assert.Contains(t, err.Error(), "send internal event") + _, gotState, err := mw.BeforeModelRewriteState(ctx, state, nil) + require.NoError(t, err) + require.NotNil(t, gotState) + require.Len(t, gotState.Messages, 2) + assert.Equal(t, schema.AgenticRoleTypeSystem, gotState.Messages[0].Role) + assert.Equal(t, schema.AgenticRoleTypeUser, gotState.Messages[1].Role) } func testSummarizationHelpers[M adk.MessageType](t *testing.T) { diff --git a/adk/runctx.go b/adk/runctx.go index 8affe4432..9cdd24efa 100644 --- a/adk/runctx.go +++ b/adk/runctx.go @@ -377,6 +377,131 @@ func (rc *runContext) deepCopy() *runContext { return copied } +func sanitizeRunContextForSessionCheckpoint[M MessageType](rc *runContext) *runContext { + if rc == nil { + return nil + } + copied := &runContext{ + RootInput: rc.RootInput, + AgenticRootInput: rc.AgenticRootInput, + RunPath: append([]RunStep(nil), rc.RunPath...), + Session: sanitizeRunSessionForSessionCheckpoint[M](rc.Session), + } + return copied +} + +func sanitizeRunSessionForSessionCheckpoint[M MessageType](rs *runSession) *runSession { + if rs == nil { + return nil + } + + copied := &runSession{ + Values: make(map[string]any), + valuesMtx: &sync.Mutex{}, + } + + if rs.valuesMtx != nil { + rs.valuesMtx.Lock() + for k, v := range rs.Values { + copied.Values[k] = v + } + rs.valuesMtx.Unlock() + } else { + for k, v := range rs.Values { + copied.Values[k] = v + } + } + + var events []*agentEventWrapper + var typedEvents any + rs.mtx.Lock() + events = append(events, rs.Events...) + typedEvents = rs.TypedEvents + rs.mtx.Unlock() + + for _, event := range events { + if sanitized := sanitizeAgentEventWrapperForSessionCheckpoint(event); sanitized != nil { + copied.Events = append(copied.Events, sanitized) + } + } + copied.LaneEvents = sanitizeLaneEventsForSessionCheckpoint(rs.LaneEvents) + + if store, ok := typedEvents.(*[]*typedAgentEventWrapper[M]); ok { + if store == nil { + copied.TypedEvents = store + return copied + } + sanitized := make([]*typedAgentEventWrapper[M], 0, len(*store)) + for _, event := range *store { + if copiedEvent := sanitizeTypedAgentEventWrapperForSessionCheckpoint(event); copiedEvent != nil { + sanitized = append(sanitized, copiedEvent) + } + } + copied.TypedEvents = &sanitized + } else { + copied.TypedEvents = typedEvents + } + + return copied +} + +func sanitizeLaneEventsForSessionCheckpoint(le *laneEvents) *laneEvents { + if le == nil { + return nil + } + copied := &laneEvents{ + Parent: sanitizeLaneEventsForSessionCheckpoint(le.Parent), + } + for _, event := range le.Events { + if sanitized := sanitizeAgentEventWrapperForSessionCheckpoint(event); sanitized != nil { + copied.Events = append(copied.Events, sanitized) + } + } + return copied +} + +func sanitizeAgentEventWrapperForSessionCheckpoint(w *agentEventWrapper) *agentEventWrapper { + if w == nil || w.AgentEvent == nil { + return nil + } + + event := *w.AgentEvent + event.RunPath = append([]RunStep(nil), w.AgentEvent.RunPath...) + event.SessionEvent = nil + if event.Output == nil && event.Action == nil && event.Err == nil { + return nil + } + + return &agentEventWrapper{ + AgentEvent: &event, + concatenatedMessage: w.concatenatedMessage, + TS: w.TS, + StreamErr: w.StreamErr, + } +} + +func sanitizeTypedAgentEventWrapperForSessionCheckpoint[M MessageType]( + w *typedAgentEventWrapper[M], +) *typedAgentEventWrapper[M] { + if w == nil || w.event == nil { + return nil + } + + event := *w.event + event.RunPath = append([]RunStep(nil), w.event.RunPath...) + event.SessionEvent = nil + if event.Output == nil && event.Action == nil && event.Err == nil { + return nil + } + + return &typedAgentEventWrapper[M]{ + event: &event, + concatenatedMessage: w.concatenatedMessage, + TS: w.TS, + StreamErr: w.StreamErr, + } +} + type runCtxKey struct{} func getRunCtx(ctx context.Context) *runContext { diff --git a/adk/runctx_test.go b/adk/runctx_test.go index bef1f44eb..292e86f2b 100644 --- a/adk/runctx_test.go +++ b/adk/runctx_test.go @@ -21,10 +21,12 @@ import ( "context" "encoding/gob" "errors" + "sync" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/schema" ) @@ -632,3 +634,219 @@ func TestGobEncodeStreamErrors(t *testing.T) { assert.NoError(t, err, "encoding runSession with WillRetryError stream should succeed") }) } + +func TestSanitizeRunContextForSessionCheckpointStripsSessionEvents(t *testing.T) { + now := time.Now().UTC() + output := &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage("kept", nil), + Role: schema.Assistant, + }, + } + kept := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "event-output", + Timestamp: now, + AgentName: "agent", + RunPath: []RunStep{{agentName: "root"}}, + Output: output, + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "event-output", + Kind: SessionEventMessage, + Message: schema.AssistantMessage("kept", nil), + }, + }, + TS: 10, + } + dropped := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "event-session-only", + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "event-session-only", + Kind: SessionEventSessionStatusRunning, + }, + }, + TS: 11, + } + interrupt := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "event-interrupt", + Action: &AgentAction{Interrupted: &InterruptInfo{Data: "pause"}}, + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "event-interrupt", + Kind: SessionEventAgentInterrupt, + }, + }, + TS: 12, + } + session := newRunSession() + session.Values["k"] = "v" + session.Events = []*agentEventWrapper{kept, dropped, interrupt} + rc := &runContext{ + RootInput: &AgentInput{Messages: []*schema.Message{schema.UserMessage("q")}}, + RunPath: []RunStep{{agentName: "root"}}, + Session: session, + } + + sanitized := sanitizeRunContextForSessionCheckpoint[*schema.Message](rc) + + require.NotNil(t, sanitized) + require.NotSame(t, rc, sanitized) + require.NotSame(t, session, sanitized.Session) + require.Len(t, sanitized.Session.Events, 2) + assert.Equal(t, "event-output", sanitized.Session.Events[0].EventID) + assert.Nil(t, sanitized.Session.Events[0].SessionEvent) + assert.Same(t, output, sanitized.Session.Events[0].Output) + assert.Equal(t, "event-interrupt", sanitized.Session.Events[1].EventID) + assert.NotNil(t, sanitized.Session.Events[1].Action.Interrupted) + assert.Nil(t, sanitized.Session.Events[1].SessionEvent) + assert.Equal(t, map[string]any{"k": "v"}, sanitized.Session.Values) + + assert.NotNil(t, kept.SessionEvent, "sanitizer must not mutate the original output event") + assert.NotNil(t, dropped.SessionEvent, "sanitizer must not mutate the original timeline event") + assert.NotNil(t, interrupt.SessionEvent, "sanitizer must not mutate the original interrupt event") +} + +func TestSanitizeRunContextForSessionCheckpointTypedEvents(t *testing.T) { + output := &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: schema.UserAgenticMessage("kept"), + AgenticRole: schema.AgenticRoleTypeUser, + }, + } + events := []*typedAgentEventWrapper[*schema.AgenticMessage]{ + { + event: &TypedAgentEvent[*schema.AgenticMessage]{ + EventID: "typed-output", + Output: output, + SessionEvent: &SessionEvent[*schema.AgenticMessage]{ + EventID: "typed-output", + Kind: SessionEventMessage, + Message: schema.UserAgenticMessage("kept"), + }, + }, + TS: 20, + }, + { + event: &TypedAgentEvent[*schema.AgenticMessage]{ + EventID: "typed-session-only", + SessionEvent: &SessionEvent[*schema.AgenticMessage]{ + EventID: "typed-session-only", + Kind: SessionEventSessionStatusRunning, + }, + }, + TS: 21, + }, + } + session := newRunSession() + session.TypedEvents = &events + rc := &runContext{Session: session} + + sanitized := sanitizeRunContextForSessionCheckpoint[*schema.AgenticMessage](rc) + + store, ok := sanitized.Session.TypedEvents.(*[]*typedAgentEventWrapper[*schema.AgenticMessage]) + require.True(t, ok) + require.Len(t, *store, 1) + assert.Equal(t, "typed-output", (*store)[0].event.EventID) + assert.Nil(t, (*store)[0].event.SessionEvent) + assert.Same(t, output, (*store)[0].event.Output) + assert.NotNil(t, events[0].event.SessionEvent, "sanitizer must not mutate the original typed event") + assert.NotNil(t, events[1].event.SessionEvent, "sanitizer must not mutate the original typed timeline event") +} + +func TestSanitizeRunContextForSessionCheckpointReducesEncodedPayload(t *testing.T) { + sessionEvent := &SessionEvent[*schema.Message]{ + EventID: "large-session-event", + Kind: SessionEventMessage, + Message: schema.AssistantMessage("large duplicated durable session payload", nil), + } + rc := &runContext{Session: newRunSession()} + rc.Session.Events = []*agentEventWrapper{ + { + AgentEvent: &AgentEvent{ + EventID: "large-session-event", + SessionEvent: sessionEvent, + }, + }, + { + AgentEvent: &AgentEvent{ + EventID: "mixed-event", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage("kept output", nil), + Role: schema.Assistant, + }, + }, + SessionEvent: sessionEvent, + }, + }, + } + + unsanitized, err := encodeRunnerCheckPointWithRunCtx(false, rc, nil, nil) + require.NoError(t, err) + sanitized, err := encodeRunnerCheckPointWithRunCtx( + false, + sanitizeRunContextForSessionCheckpoint[*schema.Message](rc), + nil, + nil, + ) + require.NoError(t, err) + assert.Less(t, len(sanitized), len(unsanitized)) + + _, decoded, _, err := runnerLoadCheckPointBytes(context.Background(), sanitized) + require.NoError(t, err) + require.Len(t, decoded.Session.Events, 1) + assert.Nil(t, decoded.Session.Events[0].SessionEvent) + assert.NotNil(t, decoded.Session.Events[0].Output) +} + +func TestSanitizeRunContextForSessionCheckpointPreservesLaneChain(t *testing.T) { + parentTimelineOnly := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "parent-session-only", + SessionEvent: &SessionEvent[*schema.Message]{Kind: SessionEventSessionStatusRunning}, + }, + } + parentOutput := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "parent-output", + Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("parent", nil)}}, + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "parent-output", + Kind: SessionEventMessage, + }, + }, + } + childTimelineOnly := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "child-session-only", + SessionEvent: &SessionEvent[*schema.Message]{Kind: SessionEventSessionStatusIdle}, + }, + } + childOutput := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + EventID: "child-output", + Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("child", nil)}}, + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "child-output", + Kind: SessionEventMessage, + }, + }, + } + parent := &laneEvents{Events: []*agentEventWrapper{parentTimelineOnly, parentOutput}} + child := &laneEvents{Events: []*agentEventWrapper{childTimelineOnly, childOutput}, Parent: parent} + rc := &runContext{Session: &runSession{LaneEvents: child, valuesMtx: &sync.Mutex{}}} + + sanitized := sanitizeRunContextForSessionCheckpoint[*schema.Message](rc) + + require.NotNil(t, sanitized.Session.LaneEvents) + require.NotNil(t, sanitized.Session.LaneEvents.Parent) + require.Len(t, sanitized.Session.LaneEvents.Events, 1) + require.Len(t, sanitized.Session.LaneEvents.Parent.Events, 1) + assert.Equal(t, "child-output", sanitized.Session.LaneEvents.Events[0].EventID) + assert.Nil(t, sanitized.Session.LaneEvents.Events[0].SessionEvent) + assert.Equal(t, "parent-output", sanitized.Session.LaneEvents.Parent.Events[0].EventID) + assert.Nil(t, sanitized.Session.LaneEvents.Parent.Events[0].SessionEvent) + assert.NotNil(t, childTimelineOnly.SessionEvent, "sanitizer must not mutate original child lane") + assert.NotNil(t, parentTimelineOnly.SessionEvent, "sanitizer must not mutate original parent lane") +} diff --git a/adk/runner.go b/adk/runner.go index 2adcd2587..a8f136f66 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -569,7 +569,12 @@ func saveRunnerCheckpoint[M MessageType]( //nolint:revive // argument-limit if isNilCheckPointStore(store) { return nil } - payload, err := encodeRunnerCheckPointImpl(enableStreaming, ctx, info, is) + payload, err := encodeRunnerCheckPointWithRunCtx( + enableStreaming, + sanitizeRunContextForSessionCheckpoint[M](getRunCtx(ctx)), + info, + is, + ) if err != nil { return err } @@ -638,7 +643,6 @@ func typedRunnerRunImpl[M MessageType](a TypedAgent[M], enableStreaming bool, st } concreteInput := any(input).(*AgentInput) ctx = ctxWithNewTypedRunCtx(ctx, input, o.sharedParentSession) - ctx = contextWithToolPermissionDecisionStore(ctx) AddSessionValues(ctx, o.sessionValues) iter := fa.Run(ctx, concreteInput, opts...) @@ -665,7 +669,6 @@ func typedRunnerRunImpl[M MessageType](a TypedAgent[M], enableStreaming bool, st } ctx = ctxWithNewTypedRunCtx(ctx, input, o.sharedParentSession) - ctx = contextWithToolPermissionDecisionStore(ctx) AddSessionValues(ctx, o.sessionValues) iter := fa.Run(ctx, input, opts...) @@ -733,7 +736,6 @@ func typedRunnerResumeInternalImpl[M MessageType](a TypedAgent[M], store CheckPo } ctx = setRunCtx(ctx, runCtx) - ctx = contextWithToolPermissionDecisionStore(ctx) AddSessionValues(ctx, o.sessionValues) if sessionState.enabled { diff --git a/adk/session.go b/adk/session.go index 47a97edd1..fe938102d 100644 --- a/adk/session.go +++ b/adk/session.go @@ -20,7 +20,6 @@ import ( "bytes" "context" "encoding/gob" - "encoding/json" "errors" "fmt" "strings" @@ -487,10 +486,13 @@ type AgentInterruptContext struct { // SessionExtensionEvent carries application-owned timeline event payloads. // The SessionEvent.Kind field is the application event type and must use the -// SessionEventExtensionPrefix namespace. Data is raw JSON and is not -// schema-decoded by ADK. +// SessionEventExtensionPrefix namespace. Data is application-owned typed payload +// data. Custom payload types that need durable round-trip behavior must be +// registered with schema.RegisterName before session events are encoded and +// decoded. Consumers can inspect SessionEvent.Kind and type-assert Data to the +// registered concrete payload type. type SessionExtensionEvent struct { - Data json.RawMessage `json:"data,omitempty"` + Data any `json:"data,omitempty"` } // MessageUpdatedEvent represents a single message replacement within the messages array. @@ -924,28 +926,6 @@ func NormalizeSessionEventKind[M MessageType](event *SessionEvent[M]) error { return fmt.Errorf("session event kind %q does not match payload %q", event.Kind, kind) } event.Kind = kind - if err := normalizeSessionExtensionEvent(event.Extension); err != nil { - return err - } - return nil -} - -func normalizeSessionExtensionEvent(event *SessionExtensionEvent) error { - if event == nil { - return nil - } - if len(event.Data) == 0 { - event.Data = nil - return nil - } - if !json.Valid(event.Data) { - return errors.New("session extension event data must be valid JSON") - } - var compact bytes.Buffer - if err := json.Compact(&compact, event.Data); err != nil { - return err - } - event.Data = append(event.Data[:0], compact.Bytes()...) return nil } diff --git a/adk/session/conformance.go b/adk/session/conformance.go index 6056ef9f8..3c9a4d188 100644 --- a/adk/session/conformance.go +++ b/adk/session/conformance.go @@ -29,6 +29,14 @@ import ( "github.com/cloudwego/eino/schema" ) +type conformanceExtensionPayload struct { + OK bool `json:"ok"` +} + +func init() { + schema.RegisterName[*conformanceExtensionPayload]("_eino_adk_session_conformance_extension_payload") +} + // RunConformanceTests validates the SessionEventStore contract shared by // provider-facing session persistence implementations. // @@ -502,7 +510,7 @@ func extensionEvent[M adk.MessageType](id, kind string) *adk.SessionEvent[M] { EventID: id, Kind: adk.SessionEventKind(kind), Extension: &adk.SessionExtensionEvent{ - Data: []byte(`{"ok":true}`), + Data: &conformanceExtensionPayload{OK: true}, }, } } diff --git a/adk/session_test.go b/adk/session_test.go index 7d9bcb69f..7540297cb 100644 --- a/adk/session_test.go +++ b/adk/session_test.go @@ -1375,6 +1375,52 @@ func (a *runnerInterruptAgent) Resume(ctx context.Context, info *ResumeInfo, _ . return iter } +type runnerCheckpointSanitizeAgent struct{} + +func (a *runnerCheckpointSanitizeAgent) Name(_ context.Context) string { + return "CheckpointSanitizeAgent" +} + +func (a *runnerCheckpointSanitizeAgent) Description(_ context.Context) string { + return "session checkpoint sanitizer test agent" +} + +func (a *runnerCheckpointSanitizeAgent) Run(ctx context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + gen.Send(&AgentEvent{ + EventID: "checkpoint-session-only", + AgentName: "CheckpointSanitizeAgent", + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "checkpoint-session-only", + Kind: SessionEventSessionStatusRunning, + Lifecycle: &LifecycleEvent{ + Scope: LifecycleScopeSession, + State: SessionRunStateRunning, + }, + }, + }) + gen.Send(&AgentEvent{ + EventID: "checkpoint-output", + AgentName: "CheckpointSanitizeAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage("mixed output", nil), + Role: schema.Assistant, + }, + }, + SessionEvent: &SessionEvent[*schema.Message]{ + EventID: "checkpoint-output", + Kind: SessionEventMessage, + Message: schema.AssistantMessage("mixed output", nil), + }, + }) + gen.Send(Interrupt(ctx, "confirm?")) + }() + return iter +} + func TestRunnerSessionModeResumeWithEmptyCheckpointID(t *testing.T) { ctx := context.Background() store := newSessionHelperStore() @@ -2795,6 +2841,80 @@ func TestRunnerSessionInterruptCheckpointTailIsFinalIdle(t *testing.T) { store.sessionHelperStore.mu.Unlock() assert.Equal(t, SessionEventSessionStatusIdle, tail.Kind) assert.Equal(t, tail.EventID, cp.SessionTailEventID) + + _, runCtx, _, err := runnerLoadCheckPointBytes(ctx, cp.Payload) + require.NoError(t, err) + require.NotNil(t, runCtx) + require.NotNil(t, runCtx.Session) + for _, event := range runCtx.Session.Events { + require.NotNil(t, event.AgentEvent) + assert.Nil(t, event.SessionEvent) + } +} + +func TestRunnerSessionCheckpointPayloadStripsSessionEvents(t *testing.T) { + ctx := context.Background() + store := newRecordingHelperStore() + sid := "checkpoint-strip-session-events" + + runner := NewRunner(ctx, RunnerConfig{ + Agent: &runnerCheckpointSanitizeAgent{}, + CheckPointStore: store, + SessionID: sid, + SessionService: store, + }) + iter := runner.Query(ctx, "hi", WithTimelineEvents()) + var liveSessionEventIDs []string + for { + event, ok := iter.Next() + if !ok { + break + } + require.NoError(t, event.Err) + if event.SessionEvent != nil { + liveSessionEventIDs = append(liveSessionEventIDs, event.SessionEvent.EventID) + } + } + assert.Contains(t, liveSessionEventIDs, "checkpoint-session-only") + assert.Contains(t, liveSessionEventIDs, "checkpoint-output") + + cpKey := sessionRunnerCheckpointID(sid) + raw, ok := store.checkpoints[cpKey] + require.True(t, ok, "expected interrupt checkpoint to be saved") + cp, err := decodeRunnerSessionCheckpoint(raw) + require.NoError(t, err) + assert.NotEmpty(t, cp.SessionTailEventID) + + _, runCtx, _, err := runnerLoadCheckPointBytes(ctx, cp.Payload) + require.NoError(t, err) + require.NotNil(t, runCtx) + require.NotNil(t, runCtx.Session) + + var checkpointEventIDs []string + var foundOutput bool + for _, event := range runCtx.Session.Events { + require.NotNil(t, event.AgentEvent) + assert.Nil(t, event.SessionEvent) + assert.True(t, event.Output != nil || event.Action != nil || event.Err != nil) + checkpointEventIDs = append(checkpointEventIDs, event.EventID) + if event.Output != nil && + event.Output.MessageOutput != nil && + event.Output.MessageOutput.Message != nil && + event.Output.MessageOutput.Message.Content == "mixed output" { + foundOutput = true + } + } + assert.NotContains(t, checkpointEventIDs, "checkpoint-session-only") + assert.True(t, foundOutput) + + var persistedKinds []SessionEventKind + store.sessionHelperStore.mu.Lock() + for _, event := range store.events { + persistedKinds = append(persistedKinds, event.Kind) + } + store.sessionHelperStore.mu.Unlock() + assert.Contains(t, persistedKinds, SessionEventSessionStatusRunning) + assert.Contains(t, persistedKinds, SessionEventMessage) } func TestRunnerSessionAgentInterruptBoundaryFailureNotExposed(t *testing.T) { diff --git a/adk/session_timeline_test.go b/adk/session_timeline_test.go index 27369616f..3b23be811 100644 --- a/adk/session_timeline_test.go +++ b/adk/session_timeline_test.go @@ -34,6 +34,15 @@ import ( "github.com/cloudwego/eino/schema" ) +type sessionTimelineExtensionPayload struct { + OutcomeName string `json:"outcome_name,omitempty"` + Attempt int `json:"attempt,omitempty"` +} + +func init() { + schema.RegisterName[*sessionTimelineExtensionPayload]("_eino_adk_session_timeline_extension_payload") +} + func requireStoredIdleStopReason(t *testing.T, raw []storedSessionEvent, want string) *SessionEvent[*schema.Message] { t.Helper() idleEvents := filterStoredSessionEvents(t, raw, func(se *SessionEvent[*schema.Message]) bool { @@ -114,7 +123,7 @@ func TestSessionTimeline_ClassifyAndSerializeVariants(t *testing.T) { name: "extension", se: &SessionEvent[*schema.Message]{ Kind: SessionEventKind("x.outcome.started"), - Extension: &SessionExtensionEvent{Data: []byte(`{"outcome_name":"code_review"}`)}, + Extension: &SessionExtensionEvent{Data: &sessionTimelineExtensionPayload{OutcomeName: "code_review"}}, }, kind: SessionEventKind("x.outcome.started"), }, @@ -133,7 +142,9 @@ func TestSessionTimeline_ClassifyAndSerializeVariants(t *testing.T) { assert.Equal(t, tc.kind, decoded.Kind) if tc.se.Extension != nil { require.NotNil(t, decoded.Extension) - assert.Equal(t, []byte(`{"outcome_name":"code_review"}`), []byte(decoded.Extension.Data)) + payload, ok := decoded.Extension.Data.(*sessionTimelineExtensionPayload) + require.True(t, ok) + assert.Equal(t, "code_review", payload.OutcomeName) } }) } @@ -185,41 +196,12 @@ func TestSessionTimeline_ExtensionValidation(t *testing.T) { assert.Contains(t, err.Error(), "exactly one active payload") }) - t.Run("invalid data rejected", func(t *testing.T) { - err := NormalizeSessionEventKind(&SessionEvent[*schema.Message]{ - Kind: SessionEventKind("x.outcome.started"), - Extension: &SessionExtensionEvent{Data: []byte(`{"broken"`)}, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "must be valid JSON") - }) - - t.Run("zero length data becomes marker", func(t *testing.T) { - se := &SessionEvent[*schema.Message]{ - Kind: SessionEventKind("x.outcome.started"), - Extension: &SessionExtensionEvent{Data: []byte{}}, - } - require.NoError(t, NormalizeSessionEventKind(se)) - assert.Nil(t, se.Extension.Data) - }) - - t.Run("pretty data is compacted", func(t *testing.T) { - se := &SessionEvent[*schema.Message]{ - Kind: SessionEventKind("x.outcome.started"), - Extension: &SessionExtensionEvent{ - Data: []byte("{\n \"outcome_name\": \"code_review\",\n \"attempt\": 1\n}"), - }, - } - require.NoError(t, NormalizeSessionEventKind(se)) - assert.Equal(t, []byte(`{"outcome_name":"code_review","attempt":1}`), []byte(se.Extension.Data)) - }) - - t.Run("human readable round trip", func(t *testing.T) { + t.Run("human readable typed round trip", func(t *testing.T) { se := &SessionEvent[*schema.Message]{ EventID: uuid.NewString(), Timestamp: time.Now().UTC(), Kind: SessionEventKind("x.outcome.grading"), - Extension: &SessionExtensionEvent{Data: []byte(`{"attempt":1}`)}, + Extension: &SessionExtensionEvent{Data: &sessionTimelineExtensionPayload{Attempt: 1}}, } require.NoError(t, NormalizeSessionEventKind(se)) data, err := encodeSessionEvent(se) @@ -228,7 +210,9 @@ func TestSessionTimeline_ExtensionValidation(t *testing.T) { require.NoError(t, err) require.NotNil(t, decoded.Extension) assert.Equal(t, se.Kind, decoded.Kind) - assert.Equal(t, []byte(`{"attempt":1}`), []byte(decoded.Extension.Data)) + payload, ok := decoded.Extension.Data.(*sessionTimelineExtensionPayload) + require.True(t, ok) + assert.Equal(t, 1, payload.Attempt) }) } @@ -243,7 +227,7 @@ func TestSessionTimeline_ReconstructionIgnoresNonContextVariants(t *testing.T) { {EventID: uuid.NewString(), Kind: SessionEventSessionStatusRunning, Lifecycle: &LifecycleEvent{Scope: LifecycleScopeSession, State: SessionRunStateRunning}}, {EventID: uuid.NewString(), Kind: SessionEventMessage, Message: msg}, {EventID: uuid.NewString(), Kind: SessionEventSpanModelRequestStart, Span: &SpanEvent{SpanID: uuid.NewString(), Kind: SpanKindModel, StartedAt: time.Now().UTC(), Model: &ModelSpanMeta{}}}, - {EventID: uuid.NewString(), Kind: SessionEventKind("x.outcome.started"), Extension: &SessionExtensionEvent{Data: []byte(`{"attempt":1}`)}}, + {EventID: uuid.NewString(), Kind: SessionEventKind("x.outcome.started"), Extension: &SessionExtensionEvent{Data: &sessionTimelineExtensionPayload{Attempt: 1}}}, {EventID: uuid.NewString(), Kind: SessionEventAgentInterrupt, AgentInterrupt: &AgentInterruptEvent{ Contexts: []*AgentInterruptContext{ { @@ -563,7 +547,10 @@ func TestRunner_ExtensionEventSentWithTypedSendEventIsLiveAndPersisted(t *testin SessionEvent: &SessionEvent[*schema.Message]{ Kind: extensionKind, Extension: &SessionExtensionEvent{ - Data: []byte("{\n \"outcome_name\": \"code_review\",\n \"attempt\": 1\n}"), + Data: &sessionTimelineExtensionPayload{ + OutcomeName: "code_review", + Attempt: 1, + }, }, }, }) @@ -598,7 +585,10 @@ func TestRunner_ExtensionEventSentWithTypedSendEventIsLiveAndPersisted(t *testin require.NotEmpty(t, liveExtension.EventID) require.NotEmpty(t, liveExtension.TurnID) require.NotNil(t, liveExtension.Extension) - assert.Equal(t, []byte(`{"outcome_name":"code_review","attempt":1}`), []byte(liveExtension.Extension.Data)) + livePayload, ok := liveExtension.Extension.Data.(*sessionTimelineExtensionPayload) + require.True(t, ok) + assert.Equal(t, "code_review", livePayload.OutcomeName) + assert.Equal(t, 1, livePayload.Attempt) stored := filterStoredSessionEvents(t, store.events, func(se *SessionEvent[*schema.Message]) bool { return se.Kind == extensionKind @@ -607,7 +597,10 @@ func TestRunner_ExtensionEventSentWithTypedSendEventIsLiveAndPersisted(t *testin assert.Equal(t, liveExtension.EventID, stored[0].EventID) assert.Equal(t, liveExtension.TurnID, stored[0].TurnID) require.NotNil(t, stored[0].Extension) - assert.Equal(t, []byte(`{"outcome_name":"code_review","attempt":1}`), []byte(stored[0].Extension.Data)) + storedPayload, ok := stored[0].Extension.Data.(*sessionTimelineExtensionPayload) + require.True(t, ok) + assert.Equal(t, "code_review", storedPayload.OutcomeName) + assert.Equal(t, 1, storedPayload.Attempt) var extensionIndex, idleIndex = -1, -1 for i, payload := range store.events { @@ -648,15 +641,14 @@ func TestRunner_ExtensionEventSentWithTypedSendEventIsLiveAndPersisted(t *testin }) } -func TestTypedSendEventOutsideExecutionReturnsError(t *testing.T) { +func TestTypedSendEventOutsideExecutionIsNoop(t *testing.T) { err := SendEvent(context.Background(), &AgentEvent{ SessionEvent: &SessionEvent[*schema.Message]{ Kind: SessionEventKind("x.outcome.started"), Extension: &SessionExtensionEvent{}, }, }) - require.Error(t, err) - assert.Contains(t, err.Error(), "must be called within") + require.NoError(t, err) } func TestSessionTimeline_SpanMetaMustBeOneOf(t *testing.T) { @@ -676,16 +668,6 @@ func TestSessionTimeline_SpanMetaMustBeOneOf(t *testing.T) { assert.Contains(t, err.Error(), "exactly one of Model or Tool") } -func TestToolPermissionDecisionScopedByToolUseID(t *testing.T) { - ctx := contextWithToolPermissionDecisionStore(context.Background()) - SetToolPermissionDecision(ctx, "call_1", "allowed") - SetToolPermissionDecision(ctx, "call_2", "denied") - - assert.Equal(t, "allowed", GetToolPermissionDecision(ctx, "call_1")) - assert.Equal(t, "denied", GetToolPermissionDecision(ctx, "call_2")) - assert.Empty(t, GetToolPermissionDecision(ctx, "missing")) -} - func TestRetryTimelineEmitsRescheduleSequence(t *testing.T) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() ctx := withTypedChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ diff --git a/adk/tool_permission.go b/adk/tool_permission.go deleted file mode 100644 index f43dc1e19..000000000 --- a/adk/tool_permission.go +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2026 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package adk - -import ( - "context" - "sync" -) - -type toolPermissionDecisionStore struct { - mu sync.RWMutex - decision map[string]string -} - -type toolPermissionDecisionKey struct{} - -func contextWithToolPermissionDecisionStore(ctx context.Context) context.Context { - if ctx.Value(toolPermissionDecisionKey{}) != nil { - return ctx - } - return context.WithValue(ctx, toolPermissionDecisionKey{}, &toolPermissionDecisionStore{decision: map[string]string{}}) -} - -// SetToolPermissionDecision records the final permission decision for one tool -// call. Decisions are keyed by ToolContext.CallID / tool-use ID. -func SetToolPermissionDecision(ctx context.Context, toolCallID, decision string) { - if toolCallID == "" || decision == "" { - return - } - store, _ := ctx.Value(toolPermissionDecisionKey{}).(*toolPermissionDecisionStore) - if store == nil { - return - } - store.mu.Lock() - store.decision[toolCallID] = decision - store.mu.Unlock() -} - -// GetToolPermissionDecision returns the decision recorded for a single tool -// call, or an empty string when no middleware participated. -func GetToolPermissionDecision(ctx context.Context, toolCallID string) string { - store, _ := ctx.Value(toolPermissionDecisionKey{}).(*toolPermissionDecisionStore) - if store == nil || toolCallID == "" { - return "" - } - store.mu.RLock() - defer store.mu.RUnlock() - return store.decision[toolCallID] -}