Skip to content

Commit 674aedc

Browse files
authored
Merge pull request #2363 from trungutt/steer-mid-turn-messages
Add mid-turn message steering for running agent sessions
2 parents f781eca + 5dac5cd commit 674aedc

12 files changed

Lines changed: 309 additions & 0 deletions

File tree

pkg/api/types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ type ResumeElicitationRequest struct {
160160
Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept")
161161
}
162162

163+
// SteerSessionRequest represents a request to inject user messages into a
164+
// running agent session. The messages are picked up by the agent loop between
165+
// tool execution and the next LLM call.
166+
type SteerSessionRequest struct {
167+
Messages []Message `json:"messages"`
168+
}
169+
163170
// UpdateSessionTitleRequest represents a request to update a session's title
164171
type UpdateSessionTitleRequest struct {
165172
Title string `json:"title"`

pkg/app/app_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ func (m *mockRuntime) UpdateSessionTitle(_ context.Context, sess *session.Sessio
6767
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
6868
func (m *mockRuntime) Close() error { return nil }
6969
func (m *mockRuntime) Stop() {}
70+
// Steer is a no-op test stub: the queued message is discarded and nil is
// always returned.
func (m *mockRuntime) Steer(runtime.QueuedMessage) error {
	return nil
}
71+
// FollowUp is a no-op test stub: the queued message is discarded and nil is
// always returned.
func (m *mockRuntime) FollowUp(runtime.QueuedMessage) error {
	return nil
}
7072

7173
// Verify mockRuntime implements runtime.Runtime
7274
var _ runtime.Runtime = (*mockRuntime)(nil)

pkg/cli/runner_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]strin
6060
func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil }
6161
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
6262
func (m *mockRuntime) Close() error { return nil }
63+
// Steer is a no-op test stub: the queued message is discarded and nil is
// always returned.
func (m *mockRuntime) Steer(_ runtime.QueuedMessage) error {
	return nil
}
64+
// FollowUp is a no-op test stub: the queued message is discarded and nil is
// always returned.
func (m *mockRuntime) FollowUp(_ runtime.QueuedMessage) error {
	return nil
}
6365
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}
6466

6567
func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {

pkg/runtime/client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,18 @@ func (c *Client) ResumeSession(ctx context.Context, id, confirmation, reason, to
266266
return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+id+"/resume", req, nil)
267267
}
268268

269+
// SteerSession injects user messages into a running session mid-turn.
270+
func (c *Client) SteerSession(ctx context.Context, sessionID string, messages []api.Message) error {
271+
req := api.SteerSessionRequest{Messages: messages}
272+
return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+sessionID+"/steer", req, nil)
273+
}
274+
275+
// FollowUpSession queues messages for end-of-turn processing.
276+
func (c *Client) FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error {
277+
req := api.SteerSessionRequest{Messages: messages}
278+
return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+sessionID+"/followup", req, nil)
279+
}
280+
269281
// DeleteSession deletes a session by ID
270282
func (c *Client) DeleteSession(ctx context.Context, id string) error {
271283
return c.doRequest(ctx, "DELETE", "/api/sessions/"+id, nil, nil)

pkg/runtime/commands_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, stri
6969
}
7070
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
7171
func (m *mockRuntime) Close() error { return nil }
72+
// Steer is a no-op test stub: the queued message is discarded and nil is
// always returned.
func (m *mockRuntime) Steer(_ QueuedMessage) error {
	return nil
}
73+
// FollowUp is a no-op test stub: the queued message is discarded and nil is
// always returned.
func (m *mockRuntime) FollowUp(_ QueuedMessage) error {
	return nil
}
7274

7375
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan Event) {
7476
}

pkg/runtime/loop.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,9 +404,43 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
404404
// Record per-toolset model override for the next LLM turn.
405405
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)
406406

407+
// --- STEERING: mid-turn injection ---
408+
// Drain ALL pending steer messages. These are urgent course-
409+
// corrections that the model should see on the very next
410+
// iteration, wrapped in <system-reminder> tags.
411+
if steered := r.steerQueue.Drain(ctx); len(steered) > 0 {
412+
for _, sm := range steered {
413+
wrapped := fmt.Sprintf(
414+
"<system-reminder>\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n</system-reminder>",
415+
sm.Content,
416+
)
417+
userMsg := session.UserMessage(wrapped, sm.MultiContent...)
418+
sess.AddMessage(userMsg)
419+
events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1)
420+
}
421+
422+
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
423+
continue
424+
}
425+
407426
if res.Stopped {
408427
slog.Debug("Conversation stopped", "agent", a.Name())
409428
r.executeStopHooks(ctx, sess, a, res.Content, events)
429+
430+
// --- FOLLOW-UP: end-of-turn injection ---
431+
// Pop exactly one follow-up message. Unlike steered
432+
// messages, follow-ups are plain user messages that start
433+
// a new turn — the model sees them as fresh input, not a
434+
// mid-stream interruption. Each follow-up gets a full
435+
// undivided agent turn.
436+
if followUp, ok := r.followUpQueue.Dequeue(ctx); ok {
437+
userMsg := session.UserMessage(followUp.Content, followUp.MultiContent...)
438+
sess.AddMessage(userMsg)
439+
events <- UserMessage(followUp.Content, sess.ID, followUp.MultiContent, len(sess.Messages)-1)
440+
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
441+
continue // re-enter the loop for a new turn
442+
}
443+
410444
break
411445
}
412446

pkg/runtime/message_queue.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
6+
"github.com/docker/docker-agent/pkg/chat"
7+
)
8+
9+
// QueuedMessage is a user message waiting to be injected into the agent loop,
10+
// either mid-turn (via the steer queue) or at end-of-turn (via the follow-up
11+
// queue).
12+
type QueuedMessage struct {
13+
Content string
14+
MultiContent []chat.MessagePart
15+
}
16+
17+
// MessageQueue is the interface for storing messages that are injected into
18+
// the agent loop. Implementations must be safe for concurrent use: Enqueue
19+
// is called from API handlers while Dequeue/Drain are called from the agent
20+
// loop goroutine.
21+
//
22+
// The default implementation is NewInMemoryMessageQueue. Callers that need
23+
// durable or distributed storage can provide their own implementation
24+
// via the WithSteerQueue or WithFollowUpQueue options.
25+
type MessageQueue interface {
26+
// Enqueue adds a message to the queue. Returns false if the queue is
27+
// full or the context is cancelled.
28+
Enqueue(ctx context.Context, msg QueuedMessage) bool
29+
// Dequeue removes and returns the next message from the queue.
30+
// Returns the message and true, or a zero value and false if the
31+
// queue is empty. Must not block.
32+
Dequeue(ctx context.Context) (QueuedMessage, bool)
33+
// Drain returns all pending messages and removes them from the queue.
34+
// Must not block — if the queue is empty it returns nil.
35+
Drain(ctx context.Context) []QueuedMessage
36+
}
37+
38+
// inMemoryMessageQueue is the default MessageQueue backed by a buffered channel.
39+
type inMemoryMessageQueue struct {
40+
ch chan QueuedMessage
41+
}
42+
43+
// Buffer sizes used for the default in-memory queues.
const (
	// defaultSteerQueueCapacity bounds the default in-memory steer queue.
	defaultSteerQueueCapacity = 5

	// defaultFollowUpQueueCapacity bounds the default in-memory follow-up
	// queue. It is larger than the steer capacity because follow-ups pile
	// up while waiting for the current turn to end.
	defaultFollowUpQueueCapacity = 20
)
50+
51+
// NewInMemoryMessageQueue creates a MessageQueue backed by a buffered channel
52+
// with the given capacity.
53+
func NewInMemoryMessageQueue(capacity int) MessageQueue {
54+
return &inMemoryMessageQueue{ch: make(chan QueuedMessage, capacity)}
55+
}
56+
57+
func (q *inMemoryMessageQueue) Enqueue(_ context.Context, msg QueuedMessage) bool {
58+
select {
59+
case q.ch <- msg:
60+
return true
61+
default:
62+
return false
63+
}
64+
}
65+
66+
func (q *inMemoryMessageQueue) Dequeue(_ context.Context) (QueuedMessage, bool) {
67+
select {
68+
case m := <-q.ch:
69+
return m, true
70+
default:
71+
return QueuedMessage{}, false
72+
}
73+
}
74+
75+
func (q *inMemoryMessageQueue) Drain(_ context.Context) []QueuedMessage {
76+
var msgs []QueuedMessage
77+
for {
78+
select {
79+
case m := <-q.ch:
80+
msgs = append(msgs, m)
81+
default:
82+
return msgs
83+
}
84+
}
85+
}

pkg/runtime/remote_client.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ type RemoteClient interface {
3030
// RunAgentWithAgentName executes an agent with a specific agent name
3131
RunAgentWithAgentName(ctx context.Context, sessionID, agent, agentName string, messages []api.Message) (<-chan Event, error)
3232

33+
// SteerSession injects user messages into a running session mid-turn
34+
SteerSession(ctx context.Context, sessionID string, messages []api.Message) error
35+
36+
// FollowUpSession queues messages for end-of-turn processing
37+
FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error
38+
3339
// UpdateSessionTitle updates the title of a session
3440
UpdateSessionTitle(ctx context.Context, sessionID, title string) error
3541

pkg/runtime/remote_runtime.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,27 @@ func (r *RemoteRuntime) Run(ctx context.Context, sess *session.Session) ([]sessi
211211
return sess.GetAllMessages(), nil
212212
}
213213

214+
// Steer enqueues a user message for mid-turn injection into the running
215+
// agent loop on the remote server.
216+
func (r *RemoteRuntime) Steer(msg QueuedMessage) error {
217+
if r.sessionID == "" {
218+
return errors.New("no active session")
219+
}
220+
return r.client.SteerSession(context.Background(), r.sessionID, []api.Message{
221+
{Content: msg.Content, MultiContent: msg.MultiContent},
222+
})
223+
}
224+
225+
// FollowUp enqueues a message for end-of-turn processing on the remote server.
226+
func (r *RemoteRuntime) FollowUp(msg QueuedMessage) error {
227+
if r.sessionID == "" {
228+
return errors.New("no active session")
229+
}
230+
return r.client.FollowUpSession(context.Background(), r.sessionID, []api.Message{
231+
{Content: msg.Content, MultiContent: msg.MultiContent},
232+
})
233+
}
234+
214235
// Resume allows resuming execution after user confirmation
215236
func (r *RemoteRuntime) Resume(ctx context.Context, req ResumeRequest) {
216237
slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "type", req.Type, "reason", req.Reason, "tool_name", req.ToolName, "session_id", r.sessionID)

pkg/runtime/runtime.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ type Runtime interface {
139139
// if the runtime does not support local title generation (e.g. remote runtimes).
140140
TitleGenerator() *sessiontitle.Generator
141141

142+
// Steer enqueues a user message for urgent mid-turn injection into the
143+
// running agent loop. Returns an error if the queue is full or steering
144+
// is not available.
145+
Steer(msg QueuedMessage) error
146+
// FollowUp enqueues a message for end-of-turn processing. Each follow-up
147+
// gets a full undivided agent turn. Returns an error if the queue is full.
148+
FollowUp(msg QueuedMessage) error
149+
142150
// Close releases resources held by the runtime (e.g., session store connections).
143151
Close() error
144152
}
@@ -201,6 +209,14 @@ type LocalRuntime struct {
201209

202210
currentAgentMu sync.RWMutex
203211

212+
// steerQueue stores urgent mid-turn messages. The agent loop drains
213+
// ALL pending messages after tool execution, before the stop check.
214+
steerQueue MessageQueue
215+
216+
// followUpQueue stores end-of-turn messages. The agent loop pops
217+
// exactly ONE message after the model stops and stop-hooks have run.
218+
followUpQueue MessageQueue
219+
204220
// onToolsChanged is called when an MCP toolset reports a tool list change.
205221
onToolsChanged func(Event)
206222

@@ -228,6 +244,22 @@ func WithTracer(t trace.Tracer) Opt {
228244
}
229245
}
230246

247+
// WithSteerQueue sets a custom MessageQueue for mid-turn message injection.
248+
// If not provided, an in-memory buffered queue is used.
249+
func WithSteerQueue(q MessageQueue) Opt {
250+
return func(r *LocalRuntime) {
251+
r.steerQueue = q
252+
}
253+
}
254+
255+
// WithFollowUpQueue sets a custom MessageQueue for end-of-turn follow-up
256+
// messages. If not provided, an in-memory buffered queue is used.
257+
func WithFollowUpQueue(q MessageQueue) Opt {
258+
return func(r *LocalRuntime) {
259+
r.followUpQueue = q
260+
}
261+
}
262+
231263
func WithSessionCompaction(sessionCompaction bool) Opt {
232264
return func(r *LocalRuntime) {
233265
r.sessionCompaction = sessionCompaction
@@ -291,6 +323,8 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
291323
currentAgent: defaultAgent.Name(),
292324
resumeChan: make(chan ResumeRequest),
293325
elicitationRequestCh: make(chan ElicitationResult),
326+
steerQueue: NewInMemoryMessageQueue(defaultSteerQueueCapacity),
327+
followUpQueue: NewInMemoryMessageQueue(defaultFollowUpQueueCapacity),
294328
sessionCompaction: true,
295329
managedOAuth: true,
296330
sessionStore: session.NewInMemorySessionStore(),
@@ -1015,6 +1049,26 @@ func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.Elici
10151049
}
10161050
}
10171051

1052+
// Steer enqueues a user message for urgent mid-turn injection into the
1053+
// running agent loop. The message will be picked up after the current batch
1054+
// of tool calls finishes but before the loop checks whether to stop.
1055+
func (r *LocalRuntime) Steer(msg QueuedMessage) error {
1056+
if !r.steerQueue.Enqueue(context.Background(), msg) {
1057+
return errors.New("steer queue full")
1058+
}
1059+
return nil
1060+
}
1061+
1062+
// FollowUp enqueues a message to be processed after the current agent turn
1063+
// finishes. Unlike Steer, follow-ups are popped one at a time and each gets
1064+
// a full undivided agent turn.
1065+
func (r *LocalRuntime) FollowUp(msg QueuedMessage) error {
1066+
if !r.followUpQueue.Enqueue(context.Background(), msg) {
1067+
return errors.New("follow-up queue full")
1068+
}
1069+
return nil
1070+
}
1071+
10181072
// Run starts the agent's interaction loop
10191073

10201074
func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {

0 commit comments

Comments
 (0)