Skip to content

Commit 08f7a06

Browse files
feat(adk): record permission resume decisions (#1070)
1 parent 6ec3268 commit 08f7a06

16 files changed

Lines changed: 954 additions & 205 deletions

adk/cancel_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,12 @@ func TestWithCancel_AgenticResumeStreamableToolTimeout_DoesNotPersistTypedNil(t
348348
}()
349349
select {
350350
case err = <-cancelDone:
351-
assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout), "unexpected cancel wait error: %v", err)
351+
assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout) || errors.Is(err, ErrExecutionEnded),
352+
"unexpected cancel wait error: %v", err)
352353
case <-time.After(5 * time.Second):
353354
t.Fatal("resume cancel handle did not complete")
354355
}
356+
executionCompletedBeforeCancel := errors.Is(err, ErrExecutionEnded)
355357

356358
var hasCancelError bool
357359
for {
@@ -371,7 +373,8 @@ func TestWithCancel_AgenticResumeStreamableToolTimeout_DoesNotPersistTypedNil(t
371373
assert.NotContains(t, errText, "cannot encode nil pointer")
372374
assert.NotContains(t, errText, "*adk.agenticReactInput(nil=true")
373375
}
374-
assert.True(t, hasCancelError, "expected CancelError in resume event stream")
376+
assert.True(t, hasCancelError || executionCompletedBeforeCancel,
377+
"expected CancelError in resume event stream unless execution completed before cancel")
375378
}
376379

377380
func TestCancelContext(t *testing.T) {
@@ -2561,7 +2564,7 @@ func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) {
25612564

25622565
t.Run("unit_SendEvent_no_execCtx", func(t *testing.T) {
25632566
err := SendEvent(context.Background(), &AgentEvent{AgentName: "test"})
2564-
assert.Error(t, err, "SendEvent without execCtx should return error")
2567+
assert.NoError(t, err, "SendEvent without execCtx should be a no-op")
25652568
})
25662569

25672570
t.Run("integration_cancel_escalation_orphans_tool", func(t *testing.T) {

adk/chatmodel.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,7 +1674,6 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput
16741674
co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...)))
16751675
}
16761676
}
1677-
ctx = contextWithToolPermissionDecisionStore(ctx)
16781677

16791678
go func() {
16801679
defer func() {
@@ -1803,7 +1802,6 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o
18031802
return nil
18041803
}))
18051804
}
1806-
ctx = contextWithToolPermissionDecisionStore(ctx)
18071805

18081806
go func() {
18091807
defer func() {

adk/coverage_contract_test.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -156,27 +156,6 @@ func TestCommonOptionsAndFilteringContracts(t *testing.T) {
156156
assert.Len(t, filterOptions("parent", []AgentRunOption{nonCallback.DesignateAgent("parent"), otherCallback, {}}), 2)
157157
}
158158

159-
func TestToolPermissionDecisionStoreContracts(t *testing.T) {
160-
ctx := context.Background()
161-
assert.Empty(t, GetToolPermissionDecision(ctx, "call-1"))
162-
SetToolPermissionDecision(ctx, "call-1", "allowed")
163-
assert.Empty(t, GetToolPermissionDecision(ctx, "call-1"))
164-
165-
ctx = contextWithToolPermissionDecisionStore(ctx)
166-
same := contextWithToolPermissionDecisionStore(ctx)
167-
assert.Same(t, ctx, same)
168-
169-
SetToolPermissionDecision(ctx, "", "allowed")
170-
SetToolPermissionDecision(ctx, "call-1", "")
171-
assert.Empty(t, GetToolPermissionDecision(ctx, "call-1"))
172-
173-
SetToolPermissionDecision(ctx, "call-1", "allowed")
174-
SetToolPermissionDecision(ctx, "call-2", "denied")
175-
assert.Equal(t, "allowed", GetToolPermissionDecision(ctx, "call-1"))
176-
assert.Equal(t, "denied", GetToolPermissionDecision(ctx, "call-2"))
177-
assert.Empty(t, GetToolPermissionDecision(ctx, ""))
178-
}
179-
180159
func TestLocalSessionServiceHandleContracts(t *testing.T) {
181160
ctx := context.Background()
182161
assert.Nil(t, NewLocalSessionService[*schema.Message](nil))

adk/handler.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,12 @@ func DeleteRunLocalValue(ctx context.Context, key string) error {
419419
// via internal wrapper layers. If your middleware constructs its own messages, call
420420
// EnsureMessageID before sending to assign an ID.
421421
//
422-
// This function can only be called from within a TypedChatModelAgentMiddleware during agent execution.
423-
// Returns an error if called outside of an agent execution context.
422+
// When called outside of an agent execution context, or from a path without an
423+
// event generator, this function is a no-op.
424424
func TypedSendEvent[M MessageType](ctx context.Context, event *TypedAgentEvent[M]) error {
425425
execCtx := getTypedChatModelAgentExecCtx[M](ctx)
426426
if execCtx == nil || execCtx.generator == nil {
427-
return fmt.Errorf("TypedSendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context")
427+
return nil
428428
}
429429

430430
execCtx.send(ctx, event)
@@ -437,8 +437,8 @@ func TypedSendEvent[M MessageType](ctx context.Context, event *TypedAgentEvent[M
437437
// For custom session timeline events during a Runner run, set AgentEvent.SessionEvent
438438
// to an extension SessionEvent with an x.* Kind and send it through this function.
439439
//
440-
// This function can only be called from within a ChatModelAgentMiddleware during agent execution.
441-
// Returns an error if called outside of an agent execution context.
440+
// When called outside of an agent execution context, or from a path without an
441+
// event generator, this function is a no-op.
442442
func SendEvent(ctx context.Context, event *AgentEvent) error {
443443
return TypedSendEvent(ctx, event)
444444
}

adk/interrupt.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,15 @@ func encodeRunnerCheckPointImpl(
310310
info *InterruptInfo,
311311
is *core.InterruptSignal,
312312
) ([]byte, error) {
313-
runCtx := getRunCtx(ctx)
313+
return encodeRunnerCheckPointWithRunCtx(enableStreaming, getRunCtx(ctx), info, is)
314+
}
314315

316+
func encodeRunnerCheckPointWithRunCtx(
317+
enableStreaming bool,
318+
runCtx *runContext,
319+
info *InterruptInfo,
320+
is *core.InterruptSignal,
321+
) ([]byte, error) {
315322
id2Addr, id2State := core.SignalToPersistenceMaps(is)
316323

317324
buf := &bytes.Buffer{}

adk/middlewares/permission/permission.go

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
func init() {
3333
schema.RegisterName[*AskInfo]("_eino_adk_permission_ask_info")
3434
schema.RegisterName[*AskState]("_eino_adk_permission_ask_state")
35+
schema.RegisterName[*DecisionEvent]("_eino_adk_permission_decision_event")
3536
}
3637

3738
// GateDecision is the result of a pre-execution permission check.
@@ -47,6 +48,12 @@ const (
4748
GateAsk GateDecision = "ask"
4849
)
4950

51+
const (
52+
// SessionEventPermissionDecision records a valid user resume decision for a
53+
// previously interrupted permission ask.
54+
SessionEventPermissionDecision adk.SessionEventKind = adk.SessionEventKind(adk.SessionEventExtensionPrefix + "permission.decision")
55+
)
56+
5057
// GateCheckResult determines how a tool call should proceed before execution.
5158
type GateCheckResult struct {
5259
Decision GateDecision
@@ -114,6 +121,18 @@ type ResumeResponse struct {
114121
Message string
115122
}
116123

124+
// DecisionEvent is the typed payload for SessionEventPermissionDecision.
125+
// It intentionally omits the original saved tool arguments; only user-provided
126+
// UpdatedInput is carried when it is part of an approval decision.
127+
type DecisionEvent struct {
128+
Action ResumeAction `json:"action"`
129+
ToolName string `json:"tool_name"`
130+
ToolUseID string `json:"tool_use_id,omitempty"`
131+
DecisionText string `json:"decision_text,omitempty"`
132+
UpdatedInput string `json:"updated_input,omitempty"`
133+
HasUpdatedInput bool `json:"has_updated_input,omitempty"`
134+
}
135+
117136
// Middleware gates tool calls with a permission Checker.
118137
type Middleware[M adk.MessageType] struct {
119138
*adk.TypedBaseChatModelAgentMiddleware[M]
@@ -139,6 +158,13 @@ type gateResult struct {
139158
argument *schema.ToolArgument
140159
}
141160

161+
type normalizedResumeDecision struct {
162+
Action ResumeAction
163+
UpdatedInput string
164+
HasUpdatedInput bool
165+
DecisionText string
166+
}
167+
142168
func (m *Middleware[M]) permissionGate(
143169
ctx context.Context,
144170
tCtx *adk.ToolContext,
@@ -162,6 +188,9 @@ func (m *Middleware[M]) permissionGate(
162188
if !hasState || savedState == nil {
163189
return nil, fmt.Errorf("permission: missing AskState for targeted resume of tool %q (call_id=%s)", tCtx.Name, tCtx.CallID)
164190
}
191+
if err := emitDecisionEvent[M](ctx, tCtx, savedState, response); err != nil {
192+
return nil, err
193+
}
165194
return handleResumeResponse(ctx, tCtx, &schema.ToolArgument{Text: savedState.Arguments}, response)
166195
}
167196

@@ -191,16 +220,13 @@ func (m *Middleware[M]) permissionGate(
191220

192221
switch decision.Decision {
193222
case GateAllow:
194-
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateAllow))
195223
return &gateResult{
196224
allowed: true,
197225
argument: withUpdatedInput(argument, decision.UpdatedInput, decision.HasUpdatedInput || decision.UpdatedInput != ""),
198226
}, nil
199227
case GateDeny:
200-
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateDeny))
201228
return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.Message)}, nil
202229
case GateAsk:
203-
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(GateAsk))
204230
info := &AskInfo{
205231
ToolName: tCtx.Name,
206232
Summary: publicSummary(decision.Message, tCtx.CallID, argument.Text),
@@ -250,40 +276,94 @@ func handleResumeResponse(
250276
argument *schema.ToolArgument,
251277
response *ResumeResponse,
252278
) (*gateResult, error) {
253-
if response == nil {
254-
return nil, fmt.Errorf("permission: nil ResumeResponse for tool %q (call_id=%s)", tCtx.Name, tCtx.CallID)
279+
decision, err := normalizeResumeDecision(tCtx, response)
280+
if err != nil {
281+
return nil, err
255282
}
256283

257-
switch response.Action {
284+
switch decision.Action {
258285
case ResumeActionApprove:
259-
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionApprove))
260286
return &gateResult{
261287
allowed: true,
262-
argument: withUpdatedInput(argument, response.UpdatedInput, response.HasUpdatedInput || response.UpdatedInput != ""),
288+
argument: withUpdatedInput(argument, decision.UpdatedInput, decision.HasUpdatedInput),
263289
}, nil
264290
case ResumeActionReject:
265-
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionReject))
266-
message := response.Message
267-
if message == "" {
268-
message = "rejected by user"
291+
return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.DecisionText)}, nil
292+
case ResumeActionRespond:
293+
return &gateResult{denyResult: formatRespondResult(tCtx.Name, decision.DecisionText)}, nil
294+
default:
295+
return nil, fmt.Errorf("permission: unknown resume action %q for tool %q (call_id=%s); expected approve, reject, or respond",
296+
decision.Action, tCtx.Name, tCtx.CallID)
297+
}
298+
}
299+
300+
func normalizeResumeDecision(tCtx *adk.ToolContext, response *ResumeResponse) (*normalizedResumeDecision, error) {
301+
toolName, callID := "", ""
302+
if tCtx != nil {
303+
toolName = tCtx.Name
304+
callID = tCtx.CallID
305+
}
306+
if response == nil {
307+
return nil, fmt.Errorf("permission: nil ResumeResponse for tool %q (call_id=%s)", toolName, callID)
308+
}
309+
310+
decision := &normalizedResumeDecision{Action: response.Action}
311+
switch response.Action {
312+
case ResumeActionApprove:
313+
decision.HasUpdatedInput = response.HasUpdatedInput || response.UpdatedInput != ""
314+
if decision.HasUpdatedInput {
315+
decision.UpdatedInput = response.UpdatedInput
316+
}
317+
return decision, nil
318+
case ResumeActionReject:
319+
decision.DecisionText = response.Message
320+
if decision.DecisionText == "" {
321+
decision.DecisionText = "rejected by user"
269322
}
270-
return &gateResult{denyResult: formatDenyResult(tCtx.Name, message)}, nil
323+
return decision, nil
271324
case ResumeActionRespond:
272-
adk.SetToolPermissionDecision(ctx, tCtx.CallID, string(ResumeActionRespond))
273325
if response.Message == "" {
274326
return nil, fmt.Errorf("permission: empty response message for respond action on tool %q (call_id=%s)",
275-
tCtx.Name, tCtx.CallID)
327+
toolName, callID)
276328
}
277-
return &gateResult{denyResult: formatRespondResult(tCtx.Name, response.Message)}, nil
329+
decision.DecisionText = response.Message
330+
return decision, nil
278331
case "":
279332
return nil, fmt.Errorf("permission: empty resume action for tool %q (call_id=%s); expected approve, reject, or respond",
280-
tCtx.Name, tCtx.CallID)
333+
toolName, callID)
281334
default:
282335
return nil, fmt.Errorf("permission: unknown resume action %q for tool %q (call_id=%s); expected approve, reject, or respond",
283-
response.Action, tCtx.Name, tCtx.CallID)
336+
response.Action, toolName, callID)
284337
}
285338
}
286339

340+
func emitDecisionEvent[M adk.MessageType](ctx context.Context, tCtx *adk.ToolContext, state *AskState, response *ResumeResponse) error {
341+
if tCtx == nil {
342+
return fmt.Errorf("permission: nil ToolContext for resume decision event")
343+
}
344+
if state == nil {
345+
return fmt.Errorf("permission: nil AskState for resume decision event on tool %q (call_id=%s)", tCtx.Name, tCtx.CallID)
346+
}
347+
decision, err := normalizeResumeDecision(tCtx, response)
348+
if err != nil {
349+
return err
350+
}
351+
payload := &DecisionEvent{
352+
Action: decision.Action,
353+
ToolName: state.ToolName,
354+
ToolUseID: state.CallID,
355+
DecisionText: decision.DecisionText,
356+
UpdatedInput: decision.UpdatedInput,
357+
HasUpdatedInput: decision.HasUpdatedInput,
358+
}
359+
return adk.TypedSendEvent[M](ctx, &adk.TypedAgentEvent[M]{
360+
SessionEvent: &adk.SessionEvent[M]{
361+
Kind: SessionEventPermissionDecision,
362+
Extension: &adk.SessionExtensionEvent{Data: payload},
363+
},
364+
})
365+
}
366+
287367
func (m *Middleware[M]) WrapInvokableToolCall(
288368
_ context.Context,
289369
endpoint adk.InvokableToolCallEndpoint,

0 commit comments

Comments
 (0)