Skip to content

Commit d44040c

Browse files
authored
feat: agent model selection (#1651)
1 parent 86d0e9d commit d44040c

33 files changed

+7399
-1278
lines changed

api/v1/api.gen.go

Lines changed: 1476 additions & 482 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/v1/api.yaml

Lines changed: 451 additions & 20 deletions
Large diffs are not rendered by default.

internal/agent/api.go

Lines changed: 180 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"cmp"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"log/slog"
910
"net/http"
@@ -14,6 +15,7 @@ import (
1415
api "github.com/dagu-org/dagu/api/v1"
1516
"github.com/dagu-org/dagu/internal/auth"
1617
"github.com/dagu-org/dagu/internal/core/exec"
18+
"github.com/dagu-org/dagu/internal/llm"
1719
"github.com/go-chi/chi/v5"
1820
"github.com/google/uuid"
1921
)
@@ -30,6 +32,9 @@ func respondErrorDirect(w http.ResponseWriter, status int, code api.ErrorCode, m
3032
}
3133
}
3234

35+
// maxRequestBodySize is the maximum allowed size for JSON request bodies (1 MB).
36+
const maxRequestBodySize = 1 << 20
37+
3338
// defaultUserID is used when no user is authenticated (e.g., auth disabled).
3439
// This value should match the system's expected default user identifier.
3540
const defaultUserID = "admin"
@@ -59,6 +64,8 @@ type API struct {
5964
conversations sync.Map // id -> *ConversationManager (active conversations)
6065
store ConversationStore
6166
configStore ConfigStore
67+
modelStore ModelStore
68+
providers *ProviderCache
6269
workingDir string
6370
logger *slog.Logger
6471
dagStore exec.DAGStore // For resolving DAG file paths
@@ -69,6 +76,7 @@ type API struct {
6976
// APIConfig contains configuration for the API.
7077
type APIConfig struct {
7178
ConfigStore ConfigStore
79+
ModelStore ModelStore
7280
WorkingDir string
7381
Logger *slog.Logger
7482
ConversationStore ConversationStore
@@ -82,6 +90,7 @@ type ConversationWithState struct {
8290
Conversation Conversation `json:"conversation"`
8391
Working bool `json:"working"`
8492
Model string `json:"model,omitempty"`
93+
TotalCost float64 `json:"total_cost"`
8594
}
8695

8796
// NewAPI creates a new API instance.
@@ -93,6 +102,8 @@ func NewAPI(cfg APIConfig) *API {
93102

94103
return &API{
95104
configStore: cfg.ConfigStore,
105+
modelStore: cfg.ModelStore,
106+
providers: NewProviderCache(),
96107
workingDir: cfg.WorkingDir,
97108
logger: logger,
98109
store: cfg.ConversationStore,
@@ -214,6 +225,46 @@ func selectModel(requestModel, conversationModel, configModel string) string {
214225
return cmp.Or(requestModel, conversationModel, configModel)
215226
}
216227

228+
// getDefaultModelID returns the default model ID from config.
229+
func (a *API) getDefaultModelID(ctx context.Context) string {
230+
cfg, err := a.configStore.Load(ctx)
231+
if err != nil {
232+
a.logger.Warn("Failed to load agent config for default model", "error", err)
233+
return ""
234+
}
235+
return cfg.DefaultModelID
236+
}
237+
238+
// resolveProvider resolves a model ID to an LLM provider and model config.
239+
// If modelID is empty, uses the default from config.
240+
// If the requested model is not found (e.g., deleted), falls back to the default.
241+
func (a *API) resolveProvider(ctx context.Context, modelID string) (llm.Provider, *ModelConfig, error) {
242+
if a.modelStore == nil {
243+
return nil, nil, errors.New("model store not configured")
244+
}
245+
246+
defaultID := a.getDefaultModelID(ctx)
247+
modelID = cmp.Or(modelID, defaultID)
248+
if modelID == "" {
249+
return nil, nil, errors.New("no model configured")
250+
}
251+
252+
model, err := a.modelStore.GetByID(ctx, modelID)
253+
if errors.Is(err, ErrModelNotFound) && defaultID != "" && defaultID != modelID {
254+
// Requested model was deleted; fall back to default
255+
model, err = a.modelStore.GetByID(ctx, defaultID)
256+
}
257+
if err != nil {
258+
return nil, nil, err
259+
}
260+
261+
provider, _, err := a.providers.GetOrCreate(model.ToLLMConfig())
262+
if err != nil {
263+
return nil, nil, err
264+
}
265+
return provider, model, nil
266+
}
267+
217268
// createMessageCallback returns a persistence callback for the given conversation ID.
218269
// Returns nil if no store is configured.
219270
func (a *API) createMessageCallback(id string) func(ctx context.Context, msg Message) error {
@@ -265,6 +316,7 @@ func (a *API) respondError(w http.ResponseWriter, status int, code api.ErrorCode
265316
// POST /api/v1/agent/conversations/new
266317
func (a *API) handleNewConversation(w http.ResponseWriter, r *http.Request) {
267318
var req ChatRequest
319+
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)
268320
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
269321
a.respondError(w, http.StatusBadRequest, api.ErrorCodeBadRequest, "Invalid JSON")
270322
return
@@ -275,36 +327,39 @@ func (a *API) handleNewConversation(w http.ResponseWriter, r *http.Request) {
275327
return
276328
}
277329

278-
provider, configModel, err := a.configStore.GetProvider(r.Context())
330+
userID, username, ipAddress := getUserContextFromRequest(r)
331+
model := selectModel(req.Model, "", a.getDefaultModelID(r.Context()))
332+
333+
provider, modelCfg, err := a.resolveProvider(r.Context(), model)
279334
if err != nil {
280335
a.logger.Error("Failed to get LLM provider", "error", err)
281336
a.respondError(w, http.StatusServiceUnavailable, api.ErrorCodeInternalError, "Agent is not configured properly")
282337
return
283338
}
284339

285-
userID, username, ipAddress := getUserContextFromRequest(r)
286-
model := selectModel(req.Model, "", configModel)
287340
id := uuid.New().String()
288341
now := time.Now()
289342

290343
mgr := NewConversationManager(ConversationManagerConfig{
291-
ID: id,
292-
UserID: userID,
293-
Logger: a.logger,
294-
WorkingDir: a.workingDir,
295-
OnMessage: a.createMessageCallback(id),
296-
Environment: a.environment,
297-
SafeMode: req.SafeMode,
298-
Hooks: a.hooks,
299-
Username: username,
300-
IPAddress: ipAddress,
344+
ID: id,
345+
UserID: userID,
346+
Logger: a.logger,
347+
WorkingDir: a.workingDir,
348+
OnMessage: a.createMessageCallback(id),
349+
Environment: a.environment,
350+
SafeMode: req.SafeMode,
351+
Hooks: a.hooks,
352+
Username: username,
353+
IPAddress: ipAddress,
354+
InputCostPer1M: modelCfg.InputCostPer1M,
355+
OutputCostPer1M: modelCfg.OutputCostPer1M,
301356
})
302357

303358
a.persistNewConversation(r.Context(), id, userID, now)
304359
a.conversations.Store(id, mgr)
305360

306361
messageWithContext := a.formatMessage(r.Context(), req.Message, req.DAGContexts)
307-
if err := mgr.AcceptUserMessage(r.Context(), provider, model, messageWithContext); err != nil {
362+
if err := mgr.AcceptUserMessage(r.Context(), provider, model, modelCfg.Model, messageWithContext); err != nil {
308363
a.logger.Error("Failed to accept user message", "error", err)
309364
a.respondError(w, http.StatusInternalServerError, api.ErrorCodeInternalError, "Failed to process message")
310365
return
@@ -325,6 +380,11 @@ func (a *API) handleListConversations(w http.ResponseWriter, r *http.Request) {
325380
conversations := a.collectActiveConversations(userID, activeIDs)
326381
conversations = a.appendPersistedConversations(r.Context(), userID, activeIDs, conversations)
327382

383+
// Ensure we return [] instead of null in JSON when no conversations exist
384+
if conversations == nil {
385+
conversations = []ConversationWithState{}
386+
}
387+
328388
a.respondJSON(w, http.StatusOK, conversations)
329389
}
330390

@@ -350,6 +410,7 @@ func (a *API) collectActiveConversations(userID string, activeIDs map[string]str
350410
Conversation: mgr.GetConversation(),
351411
Working: mgr.IsWorking(),
352412
Model: mgr.GetModel(),
413+
TotalCost: mgr.GetTotalCost(),
353414
})
354415
return true
355416
})
@@ -397,6 +458,7 @@ func (a *API) handleGetConversation(w http.ResponseWriter, r *http.Request) {
397458
ConversationID: id,
398459
Working: mgr.IsWorking(),
399460
Model: mgr.GetModel(),
461+
TotalCost: mgr.GetTotalCost(),
400462
},
401463
})
402464
return
@@ -425,7 +487,10 @@ func (a *API) getActiveConversation(id, userID string) (*ConversationManager, bo
425487
if !ok {
426488
return nil, false
427489
}
428-
mgr := mgrValue.(*ConversationManager)
490+
mgr, ok := mgrValue.(*ConversationManager)
491+
if !ok {
492+
return nil, false
493+
}
429494
if mgr.UserID() != userID {
430495
return nil, false
431496
}
@@ -465,6 +530,7 @@ func (a *API) handleChat(w http.ResponseWriter, r *http.Request) {
465530
}
466531

467532
var req ChatRequest
533+
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)
468534
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
469535
a.respondError(w, http.StatusBadRequest, api.ErrorCodeBadRequest, "Invalid JSON")
470536
return
@@ -475,20 +541,20 @@ func (a *API) handleChat(w http.ResponseWriter, r *http.Request) {
475541
return
476542
}
477543

478-
provider, configModel, err := a.configStore.GetProvider(r.Context())
544+
model := selectModel(req.Model, mgr.GetModel(), a.getDefaultModelID(r.Context()))
545+
546+
provider, modelCfg, err := a.resolveProvider(r.Context(), model)
479547
if err != nil {
480548
a.logger.Error("Failed to get LLM provider", "error", err)
481549
a.respondError(w, http.StatusServiceUnavailable, api.ErrorCodeInternalError, "Agent is not configured properly")
482550
return
483551
}
484-
485-
model := selectModel(req.Model, mgr.GetModel(), configModel)
486552
messageWithContext := a.formatMessage(r.Context(), req.Message, req.DAGContexts)
487553

488-
// Update safe mode setting per request (allows toggling mid-conversation)
489554
mgr.SetSafeMode(req.SafeMode)
555+
mgr.UpdatePricing(modelCfg.InputCostPer1M, modelCfg.OutputCostPer1M)
490556

491-
if err := mgr.AcceptUserMessage(r.Context(), provider, model, messageWithContext); err != nil {
557+
if err := mgr.AcceptUserMessage(r.Context(), provider, model, modelCfg.Model, messageWithContext); err != nil {
492558
a.logger.Error("Failed to accept user message", "error", err)
493559
a.respondError(w, http.StatusInternalServerError, api.ErrorCodeInternalError, "Failed to process message")
494560
return
@@ -539,6 +605,7 @@ func (a *API) reactivateConversation(ctx context.Context, id, userID, username,
539605
History: messages,
540606
SequenceID: seqID,
541607
Environment: a.environment,
608+
SafeMode: true, // Default to safe mode for reactivated conversations
542609
Hooks: a.hooks,
543610
Username: username,
544611
IPAddress: ipAddress,
@@ -567,12 +634,45 @@ func (a *API) handleStream(w http.ResponseWriter, r *http.Request) {
567634
snapshot, next := mgr.SubscribeWithSnapshot(r.Context())
568635
a.sendSSEMessage(w, snapshot)
569636

637+
type streamResult struct {
638+
resp StreamResponse
639+
cont bool
640+
}
641+
642+
heartbeat := time.NewTicker(15 * time.Second)
643+
defer heartbeat.Stop()
644+
645+
ch := make(chan streamResult, 1)
646+
go func() {
647+
for {
648+
resp, cont := next()
649+
ch <- streamResult{resp, cont}
650+
if !cont {
651+
return
652+
}
653+
}
654+
}()
655+
570656
for {
571-
resp, cont := next()
572-
if !cont {
573-
break
657+
select {
658+
case res := <-ch:
659+
if !res.cont {
660+
return
661+
}
662+
a.sendSSEMessage(w, res.resp)
663+
heartbeat.Reset(15 * time.Second)
664+
case <-heartbeat.C:
665+
// SSE comment as heartbeat to keep connection alive
666+
if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil {
667+
a.logger.Debug("SSE heartbeat write failed", "error", err)
668+
return
669+
}
670+
if f, ok := w.(http.Flusher); ok {
671+
f.Flush()
672+
}
673+
case <-r.Context().Done():
674+
return
574675
}
575-
a.sendSSEMessage(w, resp)
576676
}
577677
}
578678

@@ -594,7 +694,10 @@ func (a *API) sendSSEMessage(w http.ResponseWriter, data any) {
594694
slog.Error("failed to marshal SSE data", "error", err)
595695
return
596696
}
597-
fmt.Fprintf(w, "data: %s\n\n", jsonData)
697+
if _, err := fmt.Fprintf(w, "data: %s\n\n", jsonData); err != nil {
698+
a.logger.Debug("SSE write failed", "error", err)
699+
return
700+
}
598701
if f, ok := w.(http.Flusher); ok {
599702
f.Flush()
600703
}
@@ -634,6 +737,7 @@ func (a *API) handleUserResponse(w http.ResponseWriter, r *http.Request) {
634737
}
635738

636739
var req UserPromptResponse
740+
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)
637741
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
638742
a.respondError(w, http.StatusBadRequest, api.ErrorCodeBadRequest, "Invalid JSON")
639743
return
@@ -652,6 +756,57 @@ func (a *API) handleUserResponse(w http.ResponseWriter, r *http.Request) {
652756
a.respondJSON(w, http.StatusOK, map[string]string{"status": "accepted"})
653757
}
654758

759+
// idleConversationTimeout is the duration after which idle conversations are cleaned up.
760+
const idleConversationTimeout = 30 * time.Minute
761+
762+
// cleanupInterval is how often the cleanup goroutine runs.
763+
const cleanupInterval = 5 * time.Minute
764+
765+
// StartCleanup begins periodic cleanup of idle conversations.
766+
// It should be called once when the API is initialized and will
767+
// stop when the context is cancelled.
768+
func (a *API) StartCleanup(ctx context.Context) {
769+
go func() {
770+
ticker := time.NewTicker(cleanupInterval)
771+
defer ticker.Stop()
772+
773+
for {
774+
select {
775+
case <-ctx.Done():
776+
return
777+
case <-ticker.C:
778+
a.cleanupIdleConversations()
779+
}
780+
}
781+
}()
782+
}
783+
784+
// cleanupIdleConversations removes conversations that have been idle too long and are not working.
785+
func (a *API) cleanupIdleConversations() {
786+
cutoff := time.Now().Add(-idleConversationTimeout)
787+
var toDelete []string
788+
789+
a.conversations.Range(func(key, value any) bool {
790+
id, ok := key.(string)
791+
if !ok {
792+
return true
793+
}
794+
mgr, ok := value.(*ConversationManager)
795+
if !ok {
796+
return true
797+
}
798+
if !mgr.IsWorking() && mgr.LastActivity().Before(cutoff) {
799+
toDelete = append(toDelete, id)
800+
}
801+
return true
802+
})
803+
804+
for _, id := range toDelete {
805+
a.conversations.Delete(id)
806+
a.logger.Debug("Cleaned up idle conversation", "conversation_id", id)
807+
}
808+
}
809+
655810
// ptrTo returns a pointer to the given value.
656811
func ptrTo[T any](v T) *T {
657812
return &v

0 commit comments

Comments
 (0)