Skip to content

Commit 2bb7953

Browse files
committed
Add mid-turn message steering for running agent sessions
Addresses #2223. Allow API clients to inject user messages into an active agent session without waiting for the current turn to finish. This is a common pattern in agentic coding tools where the user can steer or provide follow-up context while the agent is executing tool calls. New API endpoint: POST /sessions/:id/steer Runtime changes: - SteeredMessage type + buffered channel on LocalRuntime - Steer() enqueues, DrainSteeredMessages() batch-drains - Agent loop injects steered messages after tool execution and before the stop-condition check; emits user_message events so clients know when the LLM actually picks them up - Messages wrapped in <system-reminder> tags for clear LLM attribution Server changes: - POST /sessions/:id/steer endpoint (202 Accepted) - SteerSession() on SessionManager with GetLocalRuntime() helper for PersistentRuntime unwrapping - Concurrent stream guard on RunSession (rejects if already streaming) - Proper defer ordering: streaming flag cleared before channel close No behavioral change to the TUI — the existing client-side message queue continues to work as before. The TUI can adopt mid-turn steering in a future change by calling LocalRuntime.Steer() directly.
1 parent 3fac361 commit 2bb7953

File tree

6 files changed

+173
-8
lines changed

6 files changed

+173
-8
lines changed

pkg/api/types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ type ResumeElicitationRequest struct {
160160
Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept")
161161
}
162162

163+
// SteerSessionRequest represents a request to inject user messages into a
164+
// running agent session. The messages are picked up by the agent loop between
165+
// tool execution and the next LLM call.
166+
type SteerSessionRequest struct {
167+
Messages []Message `json:"messages"`
168+
}
169+
163170
// UpdateSessionTitleRequest represents a request to update a session's title
164171
type UpdateSessionTitleRequest struct {
165172
Title string `json:"title"`

pkg/runtime/loop.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,42 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
386386
// Record per-toolset model override for the next LLM turn.
387387
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)
388388

389+
// Only compact proactively when the model will continue (has
390+
// tool calls to process on the next turn). If the model stopped
391+
// and no steered messages override that, compaction is wasteful
392+
// because no further LLM call follows.
393+
if !res.Stopped {
394+
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
395+
}
396+
397+
// Drain any steered (mid-turn) user messages that arrived while
398+
// the current iteration was in progress. Injecting them here —
399+
// after tool execution, before the stop check — ensures the LLM
400+
// sees the new messages on the next iteration via GetMessages().
401+
if steered := r.DrainSteeredMessages(); len(steered) > 0 {
402+
for _, sm := range steered {
403+
wrapped := fmt.Sprintf(
404+
"<system-reminder>\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n</system-reminder>",
405+
sm.Content,
406+
)
407+
userMsg := session.UserMessage(wrapped, sm.MultiContent...)
408+
sess.AddMessage(userMsg)
409+
events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1)
410+
}
411+
412+
// Force the loop to continue — the model must respond to
413+
// the injected messages even if it was about to stop.
414+
res.Stopped = false
415+
416+
// Now that the loop will continue, compact if needed.
417+
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
418+
}
419+
389420
if res.Stopped {
390421
slog.Debug("Conversation stopped", "agent", a.Name())
391422
r.executeStopHooks(ctx, sess, a, res.Content, events)
392423
break
393424
}
394-
395-
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
396425
}
397426
}()
398427

pkg/runtime/persistent_runtime.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ type streamingState struct {
2525
messageID int64 // ID of the current streaming message (0 if none)
2626
}
2727

28+
// GetLocalRuntime extracts the underlying *LocalRuntime from a Runtime
29+
// implementation. It handles both *LocalRuntime and *PersistentRuntime
30+
// (which embeds *LocalRuntime). Returns nil if the runtime type is not
31+
// supported (e.g. RemoteRuntime).
32+
func GetLocalRuntime(rt Runtime) *LocalRuntime {
33+
switch r := rt.(type) {
34+
case *LocalRuntime:
35+
return r
36+
case *PersistentRuntime:
37+
return r.LocalRuntime
38+
default:
39+
return nil
40+
}
41+
}
42+
2843
// New creates a new runtime for an agent and its team.
2944
// The runtime automatically persists session changes to the configured store.
3045
// Returns a Runtime interface which wraps LocalRuntime with persistence handling.

pkg/runtime/runtime.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ func ResumeReject(reason string) ResumeRequest {
8080
return ResumeRequest{Type: ResumeTypeReject, Reason: reason}
8181
}
8282

83+
// SteeredMessage is a user message injected mid-turn while the agent loop is
84+
// running. It is enqueued via Steer() and drained inside the loop between
85+
// tool execution and the stop-condition check.
86+
type SteeredMessage struct {
87+
Content string
88+
MultiContent []chat.MessagePart
89+
}
90+
91+
// maxSteeredMessages is the maximum number of steered messages that can be
92+
// buffered before Steer() starts rejecting new messages.
93+
const maxSteeredMessages = 5
94+
8395
// ToolHandlerFunc is a function type for handling tool calls
8496
type ToolHandlerFunc func(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error)
8597

@@ -201,6 +213,12 @@ type LocalRuntime struct {
201213

202214
currentAgentMu sync.RWMutex
203215

216+
// steerCh receives user messages injected mid-turn via Steer().
217+
// The agent loop drains this channel after tool execution, before
218+
// checking the stop condition, so the LLM sees the new message on
219+
// its next iteration.
220+
steerCh chan SteeredMessage
221+
204222
// onToolsChanged is called when an MCP toolset reports a tool list change.
205223
onToolsChanged func(Event)
206224

@@ -291,6 +309,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
291309
currentAgent: defaultAgent.Name(),
292310
resumeChan: make(chan ResumeRequest),
293311
elicitationRequestCh: make(chan ElicitationResult),
312+
steerCh: make(chan SteeredMessage, maxSteeredMessages),
294313
sessionCompaction: true,
295314
managedOAuth: true,
296315
sessionStore: session.NewInMemorySessionStore(),
@@ -1015,6 +1034,34 @@ func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.Elici
10151034
}
10161035
}
10171036

1037+
// Steer enqueues a user message for mid-turn injection into the running
1038+
// agent loop. The message will be picked up after the current batch of tool
1039+
// calls finishes but before the loop checks whether to stop. Returns false
1040+
// if the steer buffer is full and the message was not enqueued.
1041+
func (r *LocalRuntime) Steer(msg SteeredMessage) bool {
1042+
select {
1043+
case r.steerCh <- msg:
1044+
return true
1045+
default:
1046+
return false
1047+
}
1048+
}
1049+
1050+
// DrainSteeredMessages returns all pending steered messages without blocking.
1051+
// It is called inside the agent loop to batch-inject any messages that arrived
1052+
// while the current iteration was in progress.
1053+
func (r *LocalRuntime) DrainSteeredMessages() []SteeredMessage {
1054+
var msgs []SteeredMessage
1055+
for {
1056+
select {
1057+
case m := <-r.steerCh:
1058+
msgs = append(msgs, m)
1059+
default:
1060+
return msgs
1061+
}
1062+
}
1063+
}
1064+
10181065
// Run starts the agent's interaction loop
10191066

10201067
func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {

pkg/server/server.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
6262
group.POST("/sessions/:id/agent/:agent", s.runAgent)
6363
group.POST("/sessions/:id/agent/:agent/:agent_name", s.runAgent)
6464
group.POST("/sessions/:id/elicitation", s.elicitation)
65+
// Steer: inject user messages into a running agent session mid-turn
66+
group.POST("/sessions/:id/steer", s.steerSession)
6567

6668
// Agent tool count
6769
group.GET("/agents/:id/:agent_name/tools/count", s.getAgentToolCount)
@@ -317,3 +319,21 @@ func (s *Server) elicitation(c echo.Context) error {
317319

318320
return c.JSON(http.StatusOK, nil)
319321
}
322+
323+
func (s *Server) steerSession(c echo.Context) error {
324+
sessionID := c.Param("id")
325+
var req api.SteerSessionRequest
326+
if err := c.Bind(&req); err != nil {
327+
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
328+
}
329+
330+
if len(req.Messages) == 0 {
331+
return echo.NewHTTPError(http.StatusBadRequest, "at least one message is required")
332+
}
333+
334+
if err := s.sm.SteerSession(c.Request().Context(), sessionID, req.Messages); err != nil {
335+
return echo.NewHTTPError(http.StatusConflict, fmt.Sprintf("failed to steer session: %v", err))
336+
}
337+
338+
return c.JSON(http.StatusAccepted, map[string]string{"status": "queued"})
339+
}

pkg/server/session_manager.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ import (
2323
)
2424

2525
type activeRuntimes struct {
26-
runtime runtime.Runtime
27-
cancel context.CancelFunc
28-
session *session.Session // The actual session object used by the runtime
29-
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
26+
runtime runtime.Runtime
27+
cancel context.CancelFunc
28+
session *session.Session // The actual session object used by the runtime
29+
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
30+
streaming bool // True while RunStream is active; prevents concurrent runs
3031
}
3132

3233
// SessionManager manages sessions for HTTP and Connect-RPC servers.
@@ -160,6 +161,14 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
160161
}
161162

162163
runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
164+
165+
// Reject if a stream is already active for this session. The caller
166+
// should use POST /sessions/:id/steer to inject follow-up messages
167+
// into a running session instead of starting a second concurrent stream.
168+
if exists && runtimeSession.streaming {
169+
return nil, errors.New("session is already streaming; use /steer to send follow-up messages")
170+
}
171+
163172
streamCtx, cancel := context.WithCancel(ctx)
164173
var titleGen *sessiontitle.Generator
165174
if !exists {
@@ -182,6 +191,8 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
182191
titleGen = runtimeSession.titleGen
183192
}
184193

194+
runtimeSession.streaming = true
195+
185196
streamChan := make(chan runtime.Event)
186197

187198
// Check if we need to generate a title
@@ -194,8 +205,17 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
194205
}
195206

196207
stream := runtimeSession.runtime.RunStream(streamCtx, sess)
197-
defer cancel()
198-
defer close(streamChan)
208+
// Single defer to control ordering: clear the streaming flag
209+
// BEFORE closing streamChan. When the client sees the channel
210+
// close it may immediately call RunSession for the next queued
211+
// message; streaming must already be false by then.
212+
defer func() {
213+
sm.mux.Lock()
214+
runtimeSession.streaming = false
215+
sm.mux.Unlock()
216+
close(streamChan)
217+
cancel()
218+
}()
199219
for event := range stream {
200220
if streamCtx.Err() != nil {
201221
return
@@ -230,6 +250,33 @@ func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirma
230250
return nil
231251
}
232252

253+
// SteerSession enqueues user messages for mid-turn injection into a running
254+
// session. The messages are picked up by the agent loop after the current tool
255+
// calls finish but before the next LLM call. Returns an error if the session
256+
// is not actively running or if the steer buffer is full.
257+
func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, messages []api.Message) error {
258+
rt, exists := sm.runtimeSessions.Load(sessionID)
259+
if !exists {
260+
return errors.New("session not found or not running")
261+
}
262+
263+
localRT := runtime.GetLocalRuntime(rt.runtime)
264+
if localRT == nil {
265+
return errors.New("steering not supported for this runtime type")
266+
}
267+
268+
for _, msg := range messages {
269+
if !localRT.Steer(runtime.SteeredMessage{
270+
Content: msg.Content,
271+
MultiContent: msg.MultiContent,
272+
}) {
273+
return errors.New("steer queue full")
274+
}
275+
}
276+
277+
return nil
278+
}
279+
233280
// ResumeElicitation resumes an elicitation request.
234281
func (sm *SessionManager) ResumeElicitation(ctx context.Context, sessionID, action string, content map[string]any) error {
235282
sm.mux.Lock()

0 commit comments

Comments
 (0)