Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions adk/cancel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,12 @@ func TestWithCancel_AgenticResumeStreamableToolTimeout_DoesNotPersistTypedNil(t
}()
select {
case err = <-cancelDone:
assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout), "unexpected cancel wait error: %v", err)
assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout) || errors.Is(err, ErrExecutionEnded),
"unexpected cancel wait error: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("resume cancel handle did not complete")
}
executionCompletedBeforeCancel := errors.Is(err, ErrExecutionEnded)

var hasCancelError bool
for {
Expand All @@ -371,7 +373,8 @@ func TestWithCancel_AgenticResumeStreamableToolTimeout_DoesNotPersistTypedNil(t
assert.NotContains(t, errText, "cannot encode nil pointer")
assert.NotContains(t, errText, "*adk.agenticReactInput(nil=true")
}
assert.True(t, hasCancelError, "expected CancelError in resume event stream")
assert.True(t, hasCancelError || executionCompletedBeforeCancel,
"expected CancelError in resume event stream unless execution completed before cancel")
}

func TestCancelContext(t *testing.T) {
Expand Down Expand Up @@ -2561,7 +2564,7 @@ func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) {

t.Run("unit_SendEvent_no_execCtx", func(t *testing.T) {
err := SendEvent(context.Background(), &AgentEvent{AgentName: "test"})
assert.Error(t, err, "SendEvent without execCtx should return error")
assert.NoError(t, err, "SendEvent without execCtx should be a no-op")
})
Comment thread
shentongmartin marked this conversation as resolved.

t.Run("integration_cancel_escalation_orphans_tool", func(t *testing.T) {
Expand Down
2 changes: 0 additions & 2 deletions adk/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,6 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput
co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...)))
}
}
ctx = contextWithToolPermissionDecisionStore(ctx)

go func() {
defer func() {
Expand Down Expand Up @@ -1803,7 +1802,6 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o
return nil
}))
}
ctx = contextWithToolPermissionDecisionStore(ctx)

go func() {
defer func() {
Expand Down
21 changes: 0 additions & 21 deletions adk/coverage_contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,6 @@ func TestCommonOptionsAndFilteringContracts(t *testing.T) {
assert.Len(t, filterOptions("parent", []AgentRunOption{nonCallback.DesignateAgent("parent"), otherCallback, {}}), 2)
}

func TestToolPermissionDecisionStoreContracts(t *testing.T) {
ctx := context.Background()
assert.Empty(t, GetToolPermissionDecision(ctx, "call-1"))
SetToolPermissionDecision(ctx, "call-1", "allowed")
assert.Empty(t, GetToolPermissionDecision(ctx, "call-1"))

ctx = contextWithToolPermissionDecisionStore(ctx)
same := contextWithToolPermissionDecisionStore(ctx)
assert.Same(t, ctx, same)

SetToolPermissionDecision(ctx, "", "allowed")
SetToolPermissionDecision(ctx, "call-1", "")
assert.Empty(t, GetToolPermissionDecision(ctx, "call-1"))

SetToolPermissionDecision(ctx, "call-1", "allowed")
SetToolPermissionDecision(ctx, "call-2", "denied")
assert.Equal(t, "allowed", GetToolPermissionDecision(ctx, "call-1"))
assert.Equal(t, "denied", GetToolPermissionDecision(ctx, "call-2"))
assert.Empty(t, GetToolPermissionDecision(ctx, ""))
}

func TestLocalSessionServiceHandleContracts(t *testing.T) {
ctx := context.Background()
assert.Nil(t, NewLocalSessionService[*schema.Message](nil))
Expand Down
10 changes: 5 additions & 5 deletions adk/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,12 +419,12 @@ func DeleteRunLocalValue(ctx context.Context, key string) error {
// via internal wrapper layers. If your middleware constructs its own messages, call
// EnsureMessageID before sending to assign an ID.
//
// This function can only be called from within a TypedChatModelAgentMiddleware during agent execution.
// Returns an error if called outside of an agent execution context.
// When called outside of an agent execution context, or from a path without an
// event generator, this function is a no-op.
func TypedSendEvent[M MessageType](ctx context.Context, event *TypedAgentEvent[M]) error {
execCtx := getTypedChatModelAgentExecCtx[M](ctx)
if execCtx == nil || execCtx.generator == nil {
return fmt.Errorf("TypedSendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context")
return nil
}

execCtx.send(ctx, event)
Expand All @@ -437,8 +437,8 @@ func TypedSendEvent[M MessageType](ctx context.Context, event *TypedAgentEvent[M
// For custom session timeline events during a Runner run, set AgentEvent.SessionEvent
// to an extension SessionEvent with an x.* Kind and send it through this function.
//
// This function can only be called from within a ChatModelAgentMiddleware during agent execution.
// Returns an error if called outside of an agent execution context.
// When called outside of an agent execution context, or from a path without an
// event generator, this function is a no-op.
func SendEvent(ctx context.Context, event *AgentEvent) error {
return TypedSendEvent(ctx, event)
}
Expand Down
9 changes: 8 additions & 1 deletion adk/interrupt.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,15 @@ func encodeRunnerCheckPointImpl(
info *InterruptInfo,
is *core.InterruptSignal,
) ([]byte, error) {
runCtx := getRunCtx(ctx)
return encodeRunnerCheckPointWithRunCtx(enableStreaming, getRunCtx(ctx), info, is)
}

func encodeRunnerCheckPointWithRunCtx(
enableStreaming bool,
runCtx *runContext,
info *InterruptInfo,
is *core.InterruptSignal,
) ([]byte, error) {
id2Addr, id2State := core.SignalToPersistenceMaps(is)

buf := &bytes.Buffer{}
Expand Down
116 changes: 98 additions & 18 deletions adk/middlewares/permission/permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
func init() {
schema.RegisterName[*AskInfo]("_eino_adk_permission_ask_info")
schema.RegisterName[*AskState]("_eino_adk_permission_ask_state")
schema.RegisterName[*DecisionEvent]("_eino_adk_permission_decision_event")
}

// GateDecision is the result of a pre-execution permission check.
Expand All @@ -47,6 +48,12 @@ const (
GateAsk GateDecision = "ask"
)

const (
// SessionEventPermissionDecision records a valid user resume decision for a
// previously interrupted permission ask.
SessionEventPermissionDecision adk.SessionEventKind = adk.SessionEventKind(adk.SessionEventExtensionPrefix + "permission.decision")
)

// GateCheckResult determines how a tool call should proceed before execution.
type GateCheckResult struct {
Decision GateDecision
Expand Down Expand Up @@ -114,6 +121,18 @@ type ResumeResponse struct {
Message string
}

// DecisionEvent is the typed payload for SessionEventPermissionDecision.
// It intentionally omits the original saved tool arguments; only user-provided
// UpdatedInput is carried when it is part of an approval decision.
type DecisionEvent struct {
Action ResumeAction `json:"action"`
ToolName string `json:"tool_name"`
ToolUseID string `json:"tool_use_id,omitempty"`
DecisionText string `json:"decision_text,omitempty"`
UpdatedInput string `json:"updated_input,omitempty"`
HasUpdatedInput bool `json:"has_updated_input,omitempty"`
}

// Middleware gates tool calls with a permission Checker.
type Middleware[M adk.MessageType] struct {
*adk.TypedBaseChatModelAgentMiddleware[M]
Expand All @@ -139,6 +158,13 @@ type gateResult struct {
argument *schema.ToolArgument
}

type normalizedResumeDecision struct {
Action ResumeAction
UpdatedInput string
HasUpdatedInput bool
DecisionText string
}

func (m *Middleware[M]) permissionGate(
ctx context.Context,
tCtx *adk.ToolContext,
Expand All @@ -162,6 +188,9 @@ func (m *Middleware[M]) permissionGate(
if !hasState || savedState == nil {
return nil, fmt.Errorf("permission: missing AskState for targeted resume of tool %q (call_id=%s)", tCtx.Name, tCtx.CallID)
}
if err := emitDecisionEvent[M](ctx, tCtx, savedState, response); err != nil {
return nil, err
}
return handleResumeResponse(ctx, tCtx, &schema.ToolArgument{Text: savedState.Arguments}, response)
}

Expand Down Expand Up @@ -191,16 +220,13 @@ func (m *Middleware[M]) permissionGate(

switch decision.Decision {
case GateAllow:
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateAllow))
return &gateResult{
allowed: true,
argument: withUpdatedInput(argument, decision.UpdatedInput, decision.HasUpdatedInput || decision.UpdatedInput != ""),
}, nil
case GateDeny:
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateDeny))
return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.Message)}, nil
case GateAsk:
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateAsk))
info := &AskInfo{
ToolName: tCtx.Name,
Summary: publicSummary(decision.Message, tCtx.CallID, argument.Text),
Expand Down Expand Up @@ -250,40 +276,94 @@ func handleResumeResponse(
argument *schema.ToolArgument,
response *ResumeResponse,
) (*gateResult, error) {
if response == nil {
return nil, fmt.Errorf("permission: nil ResumeResponse for tool %q (call_id=%s)", tCtx.Name, tCtx.CallID)
decision, err := normalizeResumeDecision(tCtx, response)
if err != nil {
return nil, err
}

switch response.Action {
switch decision.Action {
case ResumeActionApprove:
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionApprove))
return &gateResult{
allowed: true,
argument: withUpdatedInput(argument, response.UpdatedInput, response.HasUpdatedInput || response.UpdatedInput != ""),
argument: withUpdatedInput(argument, decision.UpdatedInput, decision.HasUpdatedInput),
}, nil
case ResumeActionReject:
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionReject))
message := response.Message
if message == "" {
message = "rejected by user"
return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.DecisionText)}, nil
case ResumeActionRespond:
return &gateResult{denyResult: formatRespondResult(tCtx.Name, decision.DecisionText)}, nil
default:
return nil, fmt.Errorf("permission: unknown resume action %q for tool %q (call_id=%s); expected approve, reject, or respond",
decision.Action, tCtx.Name, tCtx.CallID)
}
}

func normalizeResumeDecision(tCtx *adk.ToolContext, response *ResumeResponse) (*normalizedResumeDecision, error) {
toolName, callID := "", ""
if tCtx != nil {
toolName = tCtx.Name
callID = tCtx.CallID
}
if response == nil {
return nil, fmt.Errorf("permission: nil ResumeResponse for tool %q (call_id=%s)", toolName, callID)
}

decision := &normalizedResumeDecision{Action: response.Action}
switch response.Action {
case ResumeActionApprove:
decision.HasUpdatedInput = response.HasUpdatedInput || response.UpdatedInput != ""
if decision.HasUpdatedInput {
decision.UpdatedInput = response.UpdatedInput
}
return decision, nil
case ResumeActionReject:
decision.DecisionText = response.Message
if decision.DecisionText == "" {
decision.DecisionText = "rejected by user"
}
return &gateResult{denyResult: formatDenyResult(tCtx.Name, message)}, nil
return decision, nil
case ResumeActionRespond:
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionRespond))
if response.Message == "" {
return nil, fmt.Errorf("permission: empty response message for respond action on tool %q (call_id=%s)",
tCtx.Name, tCtx.CallID)
toolName, callID)
}
return &gateResult{denyResult: formatRespondResult(tCtx.Name, response.Message)}, nil
decision.DecisionText = response.Message
return decision, nil
case "":
return nil, fmt.Errorf("permission: empty resume action for tool %q (call_id=%s); expected approve, reject, or respond",
tCtx.Name, tCtx.CallID)
toolName, callID)
default:
return nil, fmt.Errorf("permission: unknown resume action %q for tool %q (call_id=%s); expected approve, reject, or respond",
response.Action, tCtx.Name, tCtx.CallID)
response.Action, toolName, callID)
}
}

func emitDecisionEvent[M adk.MessageType](ctx context.Context, tCtx *adk.ToolContext, state *AskState, response *ResumeResponse) error {
if tCtx == nil {
return fmt.Errorf("permission: nil ToolContext for resume decision event")
}
if state == nil {
return fmt.Errorf("permission: nil AskState for resume decision event on tool %q (call_id=%s)", tCtx.Name, tCtx.CallID)
}
decision, err := normalizeResumeDecision(tCtx, response)
if err != nil {
return err
}
payload := &DecisionEvent{
Action: decision.Action,
ToolName: state.ToolName,
ToolUseID: state.CallID,
DecisionText: decision.DecisionText,
UpdatedInput: decision.UpdatedInput,
HasUpdatedInput: decision.HasUpdatedInput,
}
return adk.TypedSendEvent[M](ctx, &adk.TypedAgentEvent[M]{
SessionEvent: &adk.SessionEvent[M]{
Kind: SessionEventPermissionDecision,
Extension: &adk.SessionExtensionEvent{Data: payload},
},
})
}

func (m *Middleware[M]) WrapInvokableToolCall(
_ context.Context,
endpoint adk.InvokableToolCallEndpoint,
Expand Down
Loading
Loading