forked from docker/docker-agent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathremote_runtime.go
More file actions
460 lines (380 loc) · 14.8 KB
/
remote_runtime.go
File metadata and controls
460 lines (380 loc) · 14.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
package runtime
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
"golang.org/x/oauth2"
"github.com/docker/cagent/pkg/api"
"github.com/docker/cagent/pkg/chat"
"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/sessiontitle"
"github.com/docker/cagent/pkg/team"
"github.com/docker/cagent/pkg/tools"
"github.com/docker/cagent/pkg/tools/mcp"
)
// RemoteRuntime implements the Runtime interface on top of a RemoteClient.
// Any RemoteClient implementation can be used, including the HTTP (Client)
// and Connect-RPC (ConnectRPCClient) clients.
type RemoteRuntime struct {
	// client performs every call made against the remote API.
	client RemoteClient

	// currentAgent is the name of the agent used for subsequent runs.
	currentAgent string
	// agentFilename identifies the agent definition passed to the remote API.
	agentFilename string
	// sessionID is the ID of the session most recently given to RunStream.
	sessionID string

	// team is set to team.New() by NewRemoteRuntime.
	team *team.Team

	// pendingOAuthElicitation is the latest elicitation request observed on
	// the event stream; ResumeElicitation hands it to the OAuth handler.
	pendingOAuthElicitation *ElicitationRequestEvent
}
// RemoteRuntimeOption mutates a RemoteRuntime while it is being constructed
// by NewRemoteRuntime.
type RemoteRuntimeOption func(*RemoteRuntime)
// WithRemoteCurrentAgent returns an option that makes agentName the
// runtime's current agent.
func WithRemoteCurrentAgent(agentName string) RemoteRuntimeOption {
	selectAgent := func(rt *RemoteRuntime) {
		rt.currentAgent = agentName
	}
	return selectAgent
}
// WithRemoteAgentFilename returns an option that sets the agent filename the
// runtime passes to the remote API.
func WithRemoteAgentFilename(filename string) RemoteRuntimeOption {
	useFilename := func(rt *RemoteRuntime) {
		rt.agentFilename = filename
	}
	return useFilename
}
// NewRemoteRuntime creates a RemoteRuntime backed by client, which may be any
// RemoteClient implementation.
//
// The runtime starts with the current agent "root", the agent filename
// "agent.yaml" and a team from team.New(). The options are then applied in
// order and may override those defaults; nil options are skipped. An error is
// returned when client is nil.
func NewRemoteRuntime(client RemoteClient, opts ...RemoteRuntimeOption) (*RemoteRuntime, error) {
	if client == nil {
		return nil, fmt.Errorf("client cannot be nil")
	}

	r := &RemoteRuntime{
		client:        client,
		currentAgent:  "root",
		agentFilename: "agent.yaml",
		team:          team.New(),
	}

	for _, opt := range opts {
		if opt == nil {
			// Calling a nil option would panic; treat it as "no change".
			continue
		}
		opt(r)
	}

	return r, nil
}
// CurrentAgentName reports which agent the runtime currently targets.
func (r *RemoteRuntime) CurrentAgentName() string {
	name := r.currentAgent
	return name
}
// CurrentAgentInfo reports the current agent's name together with the
// description and commands found in its remote configuration. The
// configuration lookup yields zero values when it fails, in which case only
// the name is populated.
func (r *RemoteRuntime) CurrentAgentInfo(ctx context.Context) CurrentAgentInfo {
	agentCfg := r.readCurrentAgentConfig(ctx)

	var info CurrentAgentInfo
	info.Name = r.currentAgent
	info.Description = agentCfg.Description
	info.Commands = agentCfg.Commands
	return info
}
// SetCurrentAgent makes agentName the agent used for subsequent user
// messages. The name is not validated here and the returned error is always
// nil.
func (r *RemoteRuntime) SetCurrentAgent(agentName string) error {
	r.currentAgent = agentName
	slog.Debug("Switched current agent (remote)", "agent", r.currentAgent)

	return nil
}
// CurrentAgentTools always yields a nil tool list and a nil error: with a
// remote runtime the tools are managed server-side.
func (r *RemoteRuntime) CurrentAgentTools(_ context.Context) ([]tools.Tool, error) {
	var none []tools.Tool
	return none, nil
}
// EmitStartupInfo sends three events on events, in this order: the current
// agent's info, the team info and the toolset info. Each send blocks until
// the channel accepts the event.
func (r *RemoteRuntime) EmitStartupInfo(ctx context.Context, events chan Event) {
	agentCfg := r.readCurrentAgentConfig(ctx)
	events <- AgentInfo(r.currentAgent, agentCfg.Model, agentCfg.Description, agentCfg.WelcomeMessage)

	// The team details are fetched only after the agent info has been sent.
	teamDetails := r.agentDetailsFromConfig(ctx)
	events <- TeamInfo(teamDetails, r.currentAgent)

	toolsetCount := len(agentCfg.Toolsets)
	events <- ToolsetInfo(toolsetCount, false, r.currentAgent)
}
// agentDetailsFromConfig fetches the remote agent definition and returns one
// AgentDetails entry per configured agent. A model written as
// "provider/model" is split at the first slash; otherwise the whole string is
// used as the model and the provider is left empty. The result is nil when
// the definition cannot be fetched or lists no agents.
func (r *RemoteRuntime) agentDetailsFromConfig(ctx context.Context) []AgentDetails {
	remoteCfg, err := r.client.GetAgent(ctx, r.agentFilename)
	if err != nil {
		return nil
	}

	var result []AgentDetails
	for _, agent := range remoteCfg.Agents {
		entry := AgentDetails{
			Name:        agent.Name,
			Description: agent.Description,
			Commands:    agent.Commands,
			Model:       agent.Model,
		}
		if provider, model, hasProvider := strings.Cut(agent.Model, "/"); hasProvider {
			entry.Provider, entry.Model = provider, model
		}
		result = append(result, entry)
	}
	return result
}
// readCurrentAgentConfig fetches the remote agent definition and returns the
// configuration of the first agent whose name equals the current agent. A
// zero-value configuration is returned when the fetch fails or no agent
// matches.
func (r *RemoteRuntime) readCurrentAgentConfig(ctx context.Context) latest.AgentConfig {
	var match latest.AgentConfig

	remoteCfg, err := r.client.GetAgent(ctx, r.agentFilename)
	if err != nil {
		return match
	}

	for _, candidate := range remoteCfg.Agents {
		if candidate.Name == r.currentAgent {
			match = candidate
			break
		}
	}
	return match
}
// RunStream starts a remote run for sess and returns a channel carrying the
// events produced by the server. The channel is buffered and is closed once
// the remote stream ends or the run fails to start; a start failure is
// reported as a single error event.
//
// Only user and assistant messages of the session are sent to the server.
// When the current agent is empty or "root" the default run endpoint is
// used, otherwise the run is addressed to the named agent.
func (r *RemoteRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event {
	// Record the session before logging and before the goroutine starts: the
	// log line then shows the session actually being run, and Resume,
	// ResumeElicitation and UpdateSessionTitle can rely on the ID as soon as
	// RunStream has returned.
	r.sessionID = sess.ID
	slog.Debug("Starting remote runtime stream", "agent", r.currentAgent, "session_id", r.sessionID)

	events := make(chan Event, 128)

	go func() {
		defer close(events)

		messages := r.convertSessionMessages(sess)

		var (
			streamChan <-chan Event
			err        error
		)
		if r.currentAgent != "" && r.currentAgent != "root" {
			streamChan, err = r.client.RunAgentWithAgentName(ctx, r.sessionID, r.agentFilename, r.currentAgent, messages)
		} else {
			streamChan, err = r.client.RunAgent(ctx, r.sessionID, r.agentFilename, messages)
		}
		if err != nil {
			events <- Error(fmt.Sprintf("failed to start remote agent: %v", err))
			return
		}

		for streamEvent := range streamChan {
			// Remember the latest elicitation request so that
			// ResumeElicitation can act on it later.
			if elicitationRequest, ok := streamEvent.(*ElicitationRequestEvent); ok {
				r.pendingOAuthElicitation = elicitationRequest
			}
			events <- streamEvent
		}
	}()

	return events
}
// Run executes a remote run for sess and blocks until the event stream ends.
// It returns all messages of the session on success. If an error event is
// received, Run stops immediately and returns that event's message as an
// error.
func (r *RemoteRuntime) Run(ctx context.Context, sess *session.Session) ([]session.Message, error) {
	// Derive a cancellable context so that an early return stops the remote
	// stream instead of leaving it running with nobody reading from it.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	eventsChan := r.RunStream(ctx, sess)
	for event := range eventsChan {
		errEvent, ok := event.(*ErrorEvent)
		if !ok {
			continue
		}

		// Returning without reading the rest of the channel would leave the
		// RunStream goroutine blocked on a send once the buffer fills up.
		// Cancel the run and discard what is left; the drain ends when
		// RunStream closes the channel.
		cancel()
		go func() {
			for range eventsChan {
			}
		}()
		return nil, fmt.Errorf("%s", errEvent.Error)
	}

	return sess.GetAllMessages(), nil
}
// Resume forwards a resume request (type, reason and tool name) for the
// current session to the remote server. Nothing is sent when no session ID
// is known. Failures are logged rather than returned.
func (r *RemoteRuntime) Resume(ctx context.Context, req ResumeRequest) {
	sessionID := r.sessionID
	slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "type", req.Type, "reason", req.Reason, "tool_name", req.ToolName, "session_id", sessionID)

	if sessionID == "" {
		slog.Error("Cannot resume: no session ID available")
		return
	}

	err := r.client.ResumeSession(ctx, sessionID, string(req.Type), req.Reason, req.ToolName)
	if err != nil {
		slog.Error("Failed to resume remote session", "error", err, "session_id", sessionID)
	}
}
// Summarize is not implemented for the remote runtime. It emits a session
// summary event for sess that carries a placeholder message instead of a
// real summary.
func (r *RemoteRuntime) Summarize(_ context.Context, sess *session.Session, _ string, events chan Event) {
	const placeholder = "Summary generation not yet implemented for remote runtime"

	slog.Debug("Summarize not yet implemented for remote runtime", "session_id", r.sessionID)
	events <- SessionSummary(sess.ID, placeholder, r.currentAgent)
}
// convertSessionMessages turns the session's messages into API messages,
// keeping only those with the user or assistant role and copying their role
// and content. The original order is preserved.
func (r *RemoteRuntime) convertSessionMessages(sess *session.Session) []api.Message {
	all := sess.GetAllMessages()
	converted := make([]api.Message, 0, len(all))

	for idx := range all {
		entry := &all[idx]
		switch entry.Message.Role {
		case chat.MessageRoleUser, chat.MessageRoleAssistant:
			converted = append(converted, api.Message{
				Role:    entry.Message.Role,
				Content: entry.Message.Content,
			})
		}
	}

	return converted
}
// ResumeElicitation answers a waiting elicitation request.
//
// If an elicitation request was observed on the event stream, it is first
// handed to the OAuth handler; an error from that handler is returned and
// nothing else is sent. Afterwards the given action and content are sent to
// the server for the current session, and the result of that call is
// returned.
func (r *RemoteRuntime) ResumeElicitation(ctx context.Context, action tools.ElicitationAction, content map[string]any) error {
	slog.Debug("Resuming remote runtime with elicitation response", "agent", r.currentAgent, "action", action, "session_id", r.sessionID)

	// Take the pending request and clear it before handling it, so that a
	// later call never replays the OAuth flow for a request that has already
	// been answered.
	pending := r.pendingOAuthElicitation
	r.pendingOAuthElicitation = nil

	if err := r.handleOAuthElicitation(ctx, pending); err != nil {
		return err
	}

	return r.client.ResumeElicitation(ctx, r.sessionID, action, content)
}
// handleOAuthElicitation runs the client side of an OAuth authorization-code
// flow (with PKCE) for req and reports the outcome to the remote server.
//
// A nil req is a no-op. Otherwise the server URL and the authorization
// server metadata are read from req.Meta, a local callback server is
// started, the client is registered dynamically, an authorization code is
// requested and exchanged for a token, and the token is sent to the server
// as an accepted elicitation. If any step before sending the token fails,
// the elicitation is declined on the server and the error is returned.
// Registration, the code request and the token exchange share a five-minute
// timeout.
func (r *RemoteRuntime) handleOAuthElicitation(ctx context.Context, req *ElicitationRequestEvent) error {
	if req == nil {
		return nil
	}

	// reject logs a failure, declines the elicitation on the server and
	// returns the error to propagate. The result of the decline call is
	// deliberately ignored: the flow has already failed and that failure is
	// what the caller needs to see.
	reject := func(logMsg string, cause, result error) error {
		slog.Error(logMsg, "error", cause)
		_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
		return result
	}

	slog.Debug("Handling OAuth elicitation request", "server_url", req.Meta["cagent/server_url"])

	serverURL, hasURL := req.Meta["cagent/server_url"].(string)
	if !hasURL {
		missing := fmt.Errorf("server_url missing from elicitation metadata")
		return reject("Failed to extract server_url", missing, missing)
	}

	rawMetadata, hasMetadata := req.Meta["auth_server_metadata"].(map[string]any)
	if !hasMetadata {
		missing := fmt.Errorf("auth_server_metadata missing from elicitation metadata")
		return reject("Failed to extract auth_server_metadata", missing, missing)
	}

	// Round-trip through JSON to convert the generic map into the typed
	// metadata struct.
	encoded, err := json.Marshal(rawMetadata)
	if err != nil {
		return reject("Failed to marshal auth_server_metadata", err,
			fmt.Errorf("failed to marshal auth_server_metadata: %w", err))
	}
	var authMetadata mcp.AuthorizationServerMetadata
	if err := json.Unmarshal(encoded, &authMetadata); err != nil {
		return reject("Failed to unmarshal auth_server_metadata", err,
			fmt.Errorf("failed to unmarshal auth_server_metadata: %w", err))
	}
	slog.Debug("Authorization server metadata extracted", "issuer", authMetadata.Issuer)

	oauthCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	slog.Debug("Creating OAuth callback server")
	callbackServer, err := mcp.NewCallbackServer()
	if err != nil {
		return reject("Failed to create callback server", err,
			fmt.Errorf("failed to create callback server: %w", err))
	}
	defer func() {
		// Shut down with a fresh context so that the shutdown still gets its
		// own five seconds when ctx is already done.
		stopCtx, stopCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer stopCancel()
		if err := callbackServer.Shutdown(stopCtx); err != nil {
			slog.Error("Failed to shutdown callback server", "error", err)
		}
	}()

	if err := callbackServer.Start(); err != nil {
		return reject("Failed to start callback server", err,
			fmt.Errorf("failed to start callback server: %w", err))
	}
	redirectURI := callbackServer.GetRedirectURI()
	slog.Debug("Callback server started", "redirect_uri", redirectURI)

	// Dynamic client registration is the only supported way to obtain
	// client credentials here.
	if authMetadata.RegistrationEndpoint == "" {
		unsupported := fmt.Errorf("authorization server does not support dynamic client registration")
		return reject("Client registration not supported", unsupported, unsupported)
	}
	slog.Debug("Attempting dynamic client registration")
	clientID, clientSecret, err := mcp.RegisterClient(oauthCtx, &authMetadata, redirectURI, nil)
	if err != nil {
		return reject("Dynamic client registration failed", err,
			fmt.Errorf("failed to register client: %w", err))
	}
	slog.Debug("Client registered successfully", "client_id", clientID)

	state, err := mcp.GenerateState()
	if err != nil {
		return reject("Failed to generate state", err,
			fmt.Errorf("failed to generate state: %w", err))
	}
	callbackServer.SetExpectedState(state)

	verifier := mcp.GeneratePKCEVerifier()
	challenge := oauth2.S256ChallengeFromVerifier(verifier)
	authURL := mcp.BuildAuthorizationURL(
		authMetadata.AuthorizationEndpoint,
		clientID,
		redirectURI,
		state,
		challenge,
		serverURL,
	)
	slog.Debug("Authorization URL built", "url", authURL)

	slog.Debug("Requesting authorization code")
	code, receivedState, err := mcp.RequestAuthorizationCode(oauthCtx, authURL, callbackServer, state)
	if err != nil {
		return reject("Failed to get authorization code", err,
			fmt.Errorf("failed to get authorization code: %w", err))
	}
	if receivedState != state {
		mismatch := fmt.Errorf("state mismatch: expected %s, got %s", state, receivedState)
		return reject("State mismatch in authorization response", mismatch, mismatch)
	}

	slog.Debug("Authorization code received, exchanging for token")
	token, err := mcp.ExchangeCodeForToken(
		oauthCtx,
		authMetadata.TokenEndpoint,
		code,
		verifier,
		clientID,
		clientSecret,
		redirectURI,
	)
	if err != nil {
		return reject("Failed to exchange code for token", err,
			fmt.Errorf("failed to exchange code for token: %w", err))
	}
	slog.Debug("Token obtained successfully", "token_type", token.TokenType)

	// The optional fields are included only when the token carries them.
	payload := map[string]any{
		"access_token": token.AccessToken,
		"token_type":   token.TokenType,
	}
	if token.ExpiresIn > 0 {
		payload["expires_in"] = token.ExpiresIn
	}
	if token.RefreshToken != "" {
		payload["refresh_token"] = token.RefreshToken
	}

	slog.Debug("Sending token to server")
	// A failure here is not followed by a decline: the accept itself is the
	// response to the elicitation.
	if err := r.client.ResumeElicitation(ctx, r.sessionID, tools.ElicitationActionAccept, payload); err != nil {
		slog.Error("Failed to send token to server", "error", err)
		return fmt.Errorf("failed to send token to server: %w", err)
	}

	slog.Debug("OAuth flow completed successfully")
	return nil
}
// SessionStore always returns nil: with a remote runtime, session storage is
// handled server-side.
func (r *RemoteRuntime) SessionStore() session.Store {
	return nil
}
// PermissionsInfo always returns nil: with a remote runtime, permissions are
// handled server-side.
func (r *RemoteRuntime) PermissionsInfo() *PermissionsInfo {
	var none *PermissionsInfo
	return none
}
// ResetStartupInfo does nothing for the remote runtime.
func (r *RemoteRuntime) ResetStartupInfo() {}
// CurrentAgentSkillsEnabled reports whether skills are switched on for the
// current agent, based on the agent configuration read from the remote API.
// The method has no context parameter, so the lookup uses
// context.Background(). An unset skills setting counts as disabled.
func (r *RemoteRuntime) CurrentAgentSkillsEnabled() bool {
	skills := r.readCurrentAgentConfig(context.Background()).Skills
	if skills == nil {
		return false
	}
	return *skills
}
// UpdateSessionTitle sets title on sess and then updates the title of the
// current session on the remote server. The title on sess is changed even
// when the remote update is not attempted or fails. An error is returned
// when no session ID is known or when the remote call fails.
func (r *RemoteRuntime) UpdateSessionTitle(ctx context.Context, sess *session.Session, title string) error {
	sess.Title = title

	if id := r.sessionID; id != "" {
		return r.client.UpdateSessionTitle(ctx, id, title)
	}
	return fmt.Errorf("cannot update session title: no session ID available")
}
// CurrentMCPPrompts always returns a new, empty map: MCP prompts are not
// supported on remote runtimes.
func (r *RemoteRuntime) CurrentMCPPrompts(context.Context) map[string]mcp.PromptInfo {
	return map[string]mcp.PromptInfo{}
}
// ExecuteMCPPrompt always fails: MCP prompts are not supported on remote
// runtimes. The returned string is empty.
func (r *RemoteRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) {
	err := fmt.Errorf("MCP prompts are not supported by remote runtimes")
	return "", err
}
// TitleGenerator always returns nil: remote runtimes do not provide a title
// generator because titles are generated server-side.
func (r *RemoteRuntime) TitleGenerator() *sessiontitle.Generator {
	var none *sessiontitle.Generator
	return none
}
// Close does nothing for remote runtimes and always returns nil.
func (r *RemoteRuntime) Close() error {
	return nil
}
// Compile-time check that *RemoteRuntime satisfies the Runtime interface.
var _ Runtime = (*RemoteRuntime)(nil)