forked from docker/docker-agent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsession_manager.go
More file actions
429 lines (358 loc) · 12.5 KB
/
session_manager.go
File metadata and controls
429 lines (358 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/docker/docker-agent/pkg/api"
"github.com/docker/docker-agent/pkg/concurrent"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/runtime"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/sessiontitle"
"github.com/docker/docker-agent/pkg/team"
"github.com/docker/docker-agent/pkg/teamloader"
"github.com/docker/docker-agent/pkg/tools"
)
// activeRuntimes bundles the state the SessionManager keeps for one session
// that has a live runtime.
type activeRuntimes struct {
	// runtime is the runtime whose RunStream and Resume methods are invoked
	// for this session.
	runtime runtime.Runtime
	// cancel is the cancel func of a stream context created by RunSession;
	// DeleteSession calls it before dropping the entry.
	cancel context.CancelFunc
	// session is the actual session object used by the runtime.
	session *session.Session
	// titleGen is the title generator (includes fallback models).
	titleGen *sessiontitle.Generator
	// streaming is held while a RunStream is in progress; it serialises
	// concurrent requests.
	streaming sync.Mutex
}
// SessionManager owns the session store and the per-session runtimes on
// behalf of the HTTP and Connect-RPC servers.
type SessionManager struct {
	// runtimeSessions holds the live runtime state, keyed by session ID.
	runtimeSessions *concurrent.Map[string, *activeRuntimes]
	// sessionStore is where sessions are loaded from and persisted to.
	sessionStore session.Store
	// Sources maps an agent filename to the source its team is loaded from.
	Sources config.Sources

	// TODO: We have to do something about this, it's weird, session creation should send everything that is needed.
	// This is only used for the working directory...
	runConfig *config.RuntimeConfig

	// refreshInterval is the interval handed to every source loader when the
	// manager is constructed.
	refreshInterval time.Duration

	// mux serialises the methods that lock it (delete, run, resume and the
	// update operations).
	mux sync.Mutex
}
// NewSessionManager builds a SessionManager backed by sessionStore.
//
// Every entry of sources is wrapped with newSourceLoader (given ctx and
// refreshInterval) and exposed under the same name through the Sources field.
// runConfig is kept as the base runtime configuration for sessions.
func NewSessionManager(ctx context.Context, sources config.Sources, sessionStore session.Store, refreshInterval time.Duration, runConfig *config.RuntimeConfig) *SessionManager {
	wrapped := make(config.Sources)
	for key := range sources {
		wrapped[key] = newSourceLoader(ctx, sources[key], refreshInterval)
	}

	return &SessionManager{
		runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
		sessionStore:    sessionStore,
		Sources:         wrapped,
		refreshInterval: refreshInterval,
		runConfig:       runConfig,
	}
}
// GetSession looks up the session with the given id in the session store.
// On failure the store's error is returned and the session is nil.
func (sm *SessionManager) GetSession(ctx context.Context, id string) (*session.Session, error) {
	found, lookupErr := sm.sessionStore.GetSession(ctx, id)
	if lookupErr != nil {
		// Never hand back a session alongside an error.
		return nil, lookupErr
	}

	return found, nil
}
// CreateSession builds a new session from sessionTemplate and adds it to the
// session store.
//
// The iteration limit, tool-call limits and tool-approval flag are copied from
// the template. A non-blank WorkingDir is trimmed, made absolute and must
// refer to an existing directory; otherwise an error is returned and nothing
// is stored. Permissions are copied only when the template sets them.
//
// The new session is returned together with the error (if any) reported by
// the store when adding it.
func (sm *SessionManager) CreateSession(ctx context.Context, sessionTemplate *session.Session) (*session.Session, error) {
	opts := []session.Opt{
		session.WithMaxIterations(sessionTemplate.MaxIterations),
		session.WithMaxConsecutiveToolCalls(sessionTemplate.MaxConsecutiveToolCalls),
		session.WithMaxOldToolCallTokens(sessionTemplate.MaxOldToolCallTokens),
		session.WithToolsApproved(sessionTemplate.ToolsApproved),
	}

	requestedDir := strings.TrimSpace(sessionTemplate.WorkingDir)
	if requestedDir != "" {
		absDir, err := filepath.Abs(requestedDir)
		if err != nil {
			return nil, err
		}

		// The directory has to exist already and actually be a directory.
		switch stat, err := os.Stat(absDir); {
		case err != nil:
			return nil, err
		case !stat.IsDir():
			return nil, errors.New("working directory must be a directory")
		}

		opts = append(opts, session.WithWorkingDir(absDir))
	}

	if perms := sessionTemplate.Permissions; perms != nil {
		opts = append(opts, session.WithPermissions(perms))
	}

	created := session.New(opts...)
	addErr := sm.sessionStore.AddSession(ctx, created)
	return created, addErr
}
// GetSessions returns every session known to the session store.
// On failure the store's error is returned and the slice is nil.
func (sm *SessionManager) GetSessions(ctx context.Context) ([]*session.Session, error) {
	all, listErr := sm.sessionStore.GetSessions(ctx)
	if listErr != nil {
		return nil, listErr
	}

	return all, nil
}
// DeleteSession removes the session identified by sessionID from the store.
// If a runtime is registered for that session, its cancel func is invoked and
// the runtime entry is dropped as well.
//
// The session must exist in the store; lookup and deletion errors are
// returned as-is, in which case the runtime entry is left untouched.
func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) error {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	stored, err := sm.sessionStore.GetSession(ctx, sessionID)
	if err != nil {
		return err
	}

	err = sm.sessionStore.DeleteSession(ctx, sessionID)
	if err != nil {
		return err
	}

	active, running := sm.runtimeSessions.Load(stored.ID)
	if !running {
		return nil
	}

	// Stop the in-flight stream before forgetting about the runtime.
	active.cancel()
	sm.runtimeSessions.Delete(stored.ID)
	return nil
}
// ErrSessionBusy is returned by RunSession when the session's runtime is
// still streaming a previous request and the new one is therefore rejected.
var ErrSessionBusy = errors.New("session is already processing a request")
// RunSession appends messages to the stored session identified by sessionID
// as user messages, persists the session and starts streaming the runtime's
// events for it.
//
// A runtime is created for the session on first use (from agentFilename and
// currentAgent) and reused by later calls. Only one stream may be active per
// session: when another one is still in progress ErrSessionBusy is returned
// and the session is left unmodified.
//
// The returned channel carries the runtime events and is closed after the
// runtime's stream ends, or at the first event observed after the stream
// context has been cancelled (through ctx or DeleteSession). When the session
// has no title yet and at least one message carries text, a title is generated
// concurrently; the channel is only closed once that generation has finished,
// so a title event is never sent on a closed channel.
func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	sess, err := sm.sessionStore.GetSession(ctx, sessionID)
	if err != nil {
		return nil, err
	}

	rc := sm.runConfig.Clone()
	rc.WorkingDir = sess.WorkingDir

	runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
	streamCtx, cancel := context.WithCancel(ctx)
	if !exists {
		rt, gen, err := sm.runtimeForSession(ctx, sess, agentFilename, currentAgent, rc)
		if err != nil {
			cancel()
			return nil, err
		}
		runtimeSession = &activeRuntimes{
			runtime:  rt,
			cancel:   cancel,
			session:  sess,
			titleGen: gen,
		}
		sm.runtimeSessions.Store(sessionID, runtimeSession)
	}
	titleGen := runtimeSession.titleGen

	// Reject the request immediately if the session is already streaming.
	// This prevents interleaving user messages while a tool call is in
	// progress, which would produce a tool_use without a matching
	// tool_result and cause provider errors.
	if !runtimeSession.streaming.TryLock() {
		cancel()
		return nil, ErrSessionBusy
	}

	// Now that we hold the streaming lock, it is safe to mutate the session.
	// Collect user messages for potential title generation.
	var userMessages []string
	for _, msg := range messages {
		sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
		if msg.Content != "" {
			userMessages = append(userMessages, msg.Content)
		}
	}

	if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
		runtimeSession.streaming.Unlock()
		cancel()
		return nil, err
	}

	// Update the session pointer so the runtime sees the latest messages, and
	// point cancel at this stream's context. Without the latter a reused
	// runtime entry would keep the cancel func of its first (already finished)
	// stream and DeleteSession could not stop the current one.
	runtimeSession.session = sess
	runtimeSession.cancel = cancel

	streamChan := make(chan runtime.Event)

	// Check if we need to generate a title.
	needsTitle := sess.Title == "" && len(userMessages) > 0 && titleGen != nil

	go func() {
		// Deferred calls run in reverse order: wait for the title goroutine,
		// close the event channel, cancel the stream context and finally
		// release the streaming lock.
		defer runtimeSession.streaming.Unlock()
		defer cancel()
		defer close(streamChan)

		// The title goroutine sends on streamChan, so the channel must not be
		// closed while it is still running. It uses streamCtx so that
		// cancelling the stream also unblocks it.
		var titleDone sync.WaitGroup
		defer titleDone.Wait()
		if needsTitle {
			titleDone.Add(1)
			go func() {
				defer titleDone.Done()
				sm.generateTitle(streamCtx, sess, titleGen, userMessages, streamChan)
			}()
		}

		stream := runtimeSession.runtime.RunStream(streamCtx, sess)
		for event := range stream {
			if streamCtx.Err() != nil {
				return
			}
			streamChan <- event
		}

		if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
			slog.Error("Failed to persist session after stream", "session_id", sess.ID, "error", err)
		}
	}()

	return streamChan, nil
}
// ResumeSession forwards a resume request to the runtime registered for
// sessionID. confirmation is converted to a runtime.ResumeType; reason and
// toolName are passed through unchanged.
//
// An error is returned when no runtime is registered for the session.
func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirmation, reason, toolName string) error {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	// A session can only be resumed if its runtime is registered.
	active, found := sm.runtimeSessions.Load(sessionID)
	if !found {
		return errors.New("session not found")
	}

	request := runtime.ResumeRequest{
		Type:     runtime.ResumeType(confirmation),
		Reason:   reason,
		ToolName: toolName,
	}
	active.runtime.Resume(ctx, request)

	return nil
}
// ResumeElicitation hands the answer to an elicitation request (action plus
// its content) to the runtime registered for sessionID and returns the
// runtime's result.
//
// An error is returned when no runtime is registered for the session.
func (sm *SessionManager) ResumeElicitation(ctx context.Context, sessionID, action string, content map[string]any) error {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	active, found := sm.runtimeSessions.Load(sessionID)
	if found {
		return active.runtime.ResumeElicitation(ctx, tools.ElicitationAction(action), content)
	}

	return errors.New("session not found")
}
// ToggleToolApproval flips the ToolsApproved flag of the stored session
// identified by sessionID and persists the result.
func (sm *SessionManager) ToggleToolApproval(ctx context.Context, sessionID string) error {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	stored, err := sm.sessionStore.GetSession(ctx, sessionID)
	if err != nil {
		return err
	}

	approved := stored.ToolsApproved
	stored.ToolsApproved = !approved

	return sm.sessionStore.UpdateSession(ctx, stored)
}
// UpdateSessionPermissions replaces the permissions of the stored session
// identified by sessionID with perms and persists the result.
func (sm *SessionManager) UpdateSessionPermissions(ctx context.Context, sessionID string, perms *session.PermissionsConfig) error {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	stored, lookupErr := sm.sessionStore.GetSession(ctx, sessionID)
	if lookupErr != nil {
		return lookupErr
	}

	stored.Permissions = perms

	return sm.sessionStore.UpdateSession(ctx, stored)
}
// UpdateSessionTitle sets the title of the session identified by sessionID
// and persists it.
//
// When a runtime entry with a session object is registered for the session,
// that in-memory object is updated and saved, so that later saves of the same
// object do not overwrite the manual edit. Otherwise the session is loaded
// from the store, updated and saved.
func (sm *SessionManager) UpdateSessionTitle(ctx context.Context, sessionID, title string) error {
	sm.mux.Lock()
	defer sm.mux.Unlock()

	active, running := sm.runtimeSessions.Load(sessionID)
	if !running || active.session == nil {
		// No live session object: go through the store.
		stored, err := sm.sessionStore.GetSession(ctx, sessionID)
		if err != nil {
			return err
		}
		stored.Title = title
		return sm.sessionStore.UpdateSession(ctx, stored)
	}

	// Edit the object the runtime is working on, then persist that object.
	live := active.session
	live.Title = title
	slog.Debug("Updated title for active session", "session_id", sessionID, "title", title)

	return sm.sessionStore.UpdateSession(ctx, live)
}
// generateTitle asks gen for a title based on userMessages, stores it on sess,
// persists the session and then emits a session-title event on events.
//
// It is a no-op when gen is nil or there are no user messages. Generation and
// persistence failures are logged and end the call without an event; an empty
// title is dropped silently. The event send is abandoned when ctx is done.
func (sm *SessionManager) generateTitle(ctx context.Context, sess *session.Session, gen *sessiontitle.Generator, userMessages []string, events chan<- runtime.Event) {
	if len(userMessages) == 0 || gen == nil {
		return
	}

	generated, err := gen.Generate(ctx, sess.ID, userMessages)
	switch {
	case err != nil:
		slog.Error("Failed to generate session title", "session_id", sess.ID, "error", err)
		return
	case generated == "":
		return
	}

	// Record the title in memory first, then in the store.
	sess.Title = generated
	if saveErr := sm.sessionStore.UpdateSession(ctx, sess); saveErr != nil {
		slog.Error("Failed to persist generated title", "session_id", sess.ID, "error", saveErr)
		return
	}

	// Notify the client unless the context ends before the event is taken.
	titleEvent := runtime.SessionTitle(sess.ID, generated)
	select {
	case <-ctx.Done():
		slog.Debug("Context cancelled while emitting title event", "session_id", sess.ID)
	case events <- titleEvent:
		slog.Debug("Generated and emitted session title", "session_id", sess.ID, "title", generated)
	}
}
// runtimeForSession returns the runtime and title generator to use for sess.
//
// An already registered entry with a non-nil runtime is reused. Otherwise the
// team is loaded from agentFilename with rc, the agent named currentAgent is
// looked up, its iteration and tool-call limits are copied onto sess, and a
// new runtime plus a title generator (agent model and fallback models) are
// created and registered under the session ID.
//
// NOTE(review): the entry registered here carries no cancel func; RunSession
// replaces it with a cancellable entry right after this call returns.
func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.Session, agentFilename, currentAgent string, rc *config.RuntimeConfig) (runtime.Runtime, *sessiontitle.Generator, error) {
	if cached, ok := sm.runtimeSessions.Load(sess.ID); ok && cached.runtime != nil {
		return cached.runtime, cached.titleGen, nil
	}

	loadedTeam, err := sm.loadTeam(ctx, agentFilename, rc)
	if err != nil {
		return nil, nil, err
	}

	current, err := loadedTeam.Agent(currentAgent)
	if err != nil {
		return nil, nil, err
	}

	// The agent's limits take precedence over whatever the session carried.
	sess.MaxIterations = current.MaxIterations()
	sess.MaxConsecutiveToolCalls = current.MaxConsecutiveToolCalls()
	sess.MaxOldToolCallTokens = current.MaxOldToolCallTokens()

	created, err := runtime.New(
		loadedTeam,
		runtime.WithCurrentAgent(currentAgent),
		runtime.WithManagedOAuth(false),
		runtime.WithSessionStore(sm.sessionStore),
	)
	if err != nil {
		return nil, nil, err
	}

	generator := sessiontitle.New(current.Model(), current.FallbackModels()...)

	entry := &activeRuntimes{
		runtime:  created,
		session:  sess,
		titleGen: generator,
	}
	sm.runtimeSessions.Store(sess.ID, entry)

	slog.Debug("Runtime created for session", "session_id", sess.ID)

	return created, generator, nil
}
// loadTeam loads the team for the source registered under agentFilename,
// using runConfig. An error is returned when no such source is registered.
func (sm *SessionManager) loadTeam(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*team.Team, error) {
	if src, ok := sm.Sources[agentFilename]; ok {
		return teamloader.Load(ctx, src, runConfig)
	}

	return nil, fmt.Errorf("agent not found: %s", agentFilename)
}
// GetAgentToolCount loads the team behind agentFilename with the manager's
// run configuration and reports how many tools the agent named agentName
// exposes.
//
// The team's tool sets are stopped again before returning; a failure to stop
// them is only logged.
func (sm *SessionManager) GetAgentToolCount(ctx context.Context, agentFilename, agentName string) (int, error) {
	loaded, err := sm.loadTeam(ctx, agentFilename, sm.runConfig)
	if err != nil {
		return 0, err
	}
	// The team was loaded just for counting, so always stop its tool sets.
	defer func() {
		if err := loaded.StopToolSets(ctx); err != nil {
			slog.Error("Failed to stop tool sets", "error", err)
		}
	}()

	target, err := loaded.Agent(agentName)
	if err != nil {
		return 0, err
	}

	available, err := target.Tools(ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to get tools: %w", err)
	}

	return len(available), nil
}