44 "cmp"
55 "context"
66 "encoding/json"
7+ "errors"
78 "fmt"
89 "log/slog"
910 "net/http"
@@ -14,6 +15,7 @@ import (
1415 api "github.com/dagu-org/dagu/api/v1"
1516 "github.com/dagu-org/dagu/internal/auth"
1617 "github.com/dagu-org/dagu/internal/core/exec"
18+ "github.com/dagu-org/dagu/internal/llm"
1719 "github.com/go-chi/chi/v5"
1820 "github.com/google/uuid"
1921)
@@ -30,6 +32,9 @@ func respondErrorDirect(w http.ResponseWriter, status int, code api.ErrorCode, m
3032 }
3133}
3234
35+ // maxRequestBodySize is the maximum allowed size for JSON request bodies (1 MB).
36+ const maxRequestBodySize = 1 << 20
37+
3338// defaultUserID is used when no user is authenticated (e.g., auth disabled).
3439// This value should match the system's expected default user identifier.
3540const defaultUserID = "admin"
@@ -59,6 +64,8 @@ type API struct {
5964 conversations sync.Map // id -> *ConversationManager (active conversations)
6065 store ConversationStore
6166 configStore ConfigStore
67+ modelStore ModelStore
68+ providers * ProviderCache
6269 workingDir string
6370 logger * slog.Logger
6471 dagStore exec.DAGStore // For resolving DAG file paths
@@ -69,6 +76,7 @@ type API struct {
6976// APIConfig contains configuration for the API.
7077type APIConfig struct {
7178 ConfigStore ConfigStore
79+ ModelStore ModelStore
7280 WorkingDir string
7381 Logger * slog.Logger
7482 ConversationStore ConversationStore
@@ -82,6 +90,7 @@ type ConversationWithState struct {
8290 Conversation Conversation `json:"conversation"`
8391 Working bool `json:"working"`
8492 Model string `json:"model,omitempty"`
93+ TotalCost float64 `json:"total_cost"`
8594}
8695
8796// NewAPI creates a new API instance.
@@ -93,6 +102,8 @@ func NewAPI(cfg APIConfig) *API {
93102
94103 return & API {
95104 configStore : cfg .ConfigStore ,
105+ modelStore : cfg .ModelStore ,
106+ providers : NewProviderCache (),
96107 workingDir : cfg .WorkingDir ,
97108 logger : logger ,
98109 store : cfg .ConversationStore ,
@@ -214,6 +225,46 @@ func selectModel(requestModel, conversationModel, configModel string) string {
214225 return cmp .Or (requestModel , conversationModel , configModel )
215226}
216227
228+ // getDefaultModelID returns the default model ID from config.
229+ func (a * API ) getDefaultModelID (ctx context.Context ) string {
230+ cfg , err := a .configStore .Load (ctx )
231+ if err != nil {
232+ a .logger .Warn ("Failed to load agent config for default model" , "error" , err )
233+ return ""
234+ }
235+ return cfg .DefaultModelID
236+ }
237+
238+ // resolveProvider resolves a model ID to an LLM provider and model config.
239+ // If modelID is empty, uses the default from config.
240+ // If the requested model is not found (e.g., deleted), falls back to the default.
241+ func (a * API ) resolveProvider (ctx context.Context , modelID string ) (llm.Provider , * ModelConfig , error ) {
242+ if a .modelStore == nil {
243+ return nil , nil , errors .New ("model store not configured" )
244+ }
245+
246+ defaultID := a .getDefaultModelID (ctx )
247+ modelID = cmp .Or (modelID , defaultID )
248+ if modelID == "" {
249+ return nil , nil , errors .New ("no model configured" )
250+ }
251+
252+ model , err := a .modelStore .GetByID (ctx , modelID )
253+ if errors .Is (err , ErrModelNotFound ) && defaultID != "" && defaultID != modelID {
254+ // Requested model was deleted; fall back to default
255+ model , err = a .modelStore .GetByID (ctx , defaultID )
256+ }
257+ if err != nil {
258+ return nil , nil , err
259+ }
260+
261+ provider , _ , err := a .providers .GetOrCreate (model .ToLLMConfig ())
262+ if err != nil {
263+ return nil , nil , err
264+ }
265+ return provider , model , nil
266+ }
267+
217268// createMessageCallback returns a persistence callback for the given conversation ID.
218269// Returns nil if no store is configured.
219270func (a * API ) createMessageCallback (id string ) func (ctx context.Context , msg Message ) error {
@@ -265,6 +316,7 @@ func (a *API) respondError(w http.ResponseWriter, status int, code api.ErrorCode
265316// POST /api/v1/agent/conversations/new
266317func (a * API ) handleNewConversation (w http.ResponseWriter , r * http.Request ) {
267318 var req ChatRequest
319+ r .Body = http .MaxBytesReader (w , r .Body , maxRequestBodySize )
268320 if err := json .NewDecoder (r .Body ).Decode (& req ); err != nil {
269321 a .respondError (w , http .StatusBadRequest , api .ErrorCodeBadRequest , "Invalid JSON" )
270322 return
@@ -275,36 +327,39 @@ func (a *API) handleNewConversation(w http.ResponseWriter, r *http.Request) {
275327 return
276328 }
277329
278- provider , configModel , err := a .configStore .GetProvider (r .Context ())
330+ userID , username , ipAddress := getUserContextFromRequest (r )
331+ model := selectModel (req .Model , "" , a .getDefaultModelID (r .Context ()))
332+
333+ provider , modelCfg , err := a .resolveProvider (r .Context (), model )
279334 if err != nil {
280335 a .logger .Error ("Failed to get LLM provider" , "error" , err )
281336 a .respondError (w , http .StatusServiceUnavailable , api .ErrorCodeInternalError , "Agent is not configured properly" )
282337 return
283338 }
284339
285- userID , username , ipAddress := getUserContextFromRequest (r )
286- model := selectModel (req .Model , "" , configModel )
287340 id := uuid .New ().String ()
288341 now := time .Now ()
289342
290343 mgr := NewConversationManager (ConversationManagerConfig {
291- ID : id ,
292- UserID : userID ,
293- Logger : a .logger ,
294- WorkingDir : a .workingDir ,
295- OnMessage : a .createMessageCallback (id ),
296- Environment : a .environment ,
297- SafeMode : req .SafeMode ,
298- Hooks : a .hooks ,
299- Username : username ,
300- IPAddress : ipAddress ,
344+ ID : id ,
345+ UserID : userID ,
346+ Logger : a .logger ,
347+ WorkingDir : a .workingDir ,
348+ OnMessage : a .createMessageCallback (id ),
349+ Environment : a .environment ,
350+ SafeMode : req .SafeMode ,
351+ Hooks : a .hooks ,
352+ Username : username ,
353+ IPAddress : ipAddress ,
354+ InputCostPer1M : modelCfg .InputCostPer1M ,
355+ OutputCostPer1M : modelCfg .OutputCostPer1M ,
301356 })
302357
303358 a .persistNewConversation (r .Context (), id , userID , now )
304359 a .conversations .Store (id , mgr )
305360
306361 messageWithContext := a .formatMessage (r .Context (), req .Message , req .DAGContexts )
307- if err := mgr .AcceptUserMessage (r .Context (), provider , model , messageWithContext ); err != nil {
362+ if err := mgr .AcceptUserMessage (r .Context (), provider , model , modelCfg . Model , messageWithContext ); err != nil {
308363 a .logger .Error ("Failed to accept user message" , "error" , err )
309364 a .respondError (w , http .StatusInternalServerError , api .ErrorCodeInternalError , "Failed to process message" )
310365 return
@@ -325,6 +380,11 @@ func (a *API) handleListConversations(w http.ResponseWriter, r *http.Request) {
325380 conversations := a .collectActiveConversations (userID , activeIDs )
326381 conversations = a .appendPersistedConversations (r .Context (), userID , activeIDs , conversations )
327382
383+ // Ensure we return [] instead of null in JSON when no conversations exist
384+ if conversations == nil {
385+ conversations = []ConversationWithState {}
386+ }
387+
328388 a .respondJSON (w , http .StatusOK , conversations )
329389}
330390
@@ -350,6 +410,7 @@ func (a *API) collectActiveConversations(userID string, activeIDs map[string]str
350410 Conversation : mgr .GetConversation (),
351411 Working : mgr .IsWorking (),
352412 Model : mgr .GetModel (),
413+ TotalCost : mgr .GetTotalCost (),
353414 })
354415 return true
355416 })
@@ -397,6 +458,7 @@ func (a *API) handleGetConversation(w http.ResponseWriter, r *http.Request) {
397458 ConversationID : id ,
398459 Working : mgr .IsWorking (),
399460 Model : mgr .GetModel (),
461+ TotalCost : mgr .GetTotalCost (),
400462 },
401463 })
402464 return
@@ -425,7 +487,10 @@ func (a *API) getActiveConversation(id, userID string) (*ConversationManager, bo
425487 if ! ok {
426488 return nil , false
427489 }
428- mgr := mgrValue .(* ConversationManager )
490+ mgr , ok := mgrValue .(* ConversationManager )
491+ if ! ok {
492+ return nil , false
493+ }
429494 if mgr .UserID () != userID {
430495 return nil , false
431496 }
@@ -465,6 +530,7 @@ func (a *API) handleChat(w http.ResponseWriter, r *http.Request) {
465530 }
466531
467532 var req ChatRequest
533+ r .Body = http .MaxBytesReader (w , r .Body , maxRequestBodySize )
468534 if err := json .NewDecoder (r .Body ).Decode (& req ); err != nil {
469535 a .respondError (w , http .StatusBadRequest , api .ErrorCodeBadRequest , "Invalid JSON" )
470536 return
@@ -475,20 +541,20 @@ func (a *API) handleChat(w http.ResponseWriter, r *http.Request) {
475541 return
476542 }
477543
478- provider , configModel , err := a .configStore .GetProvider (r .Context ())
544+ model := selectModel (req .Model , mgr .GetModel (), a .getDefaultModelID (r .Context ()))
545+
546+ provider , modelCfg , err := a .resolveProvider (r .Context (), model )
479547 if err != nil {
480548 a .logger .Error ("Failed to get LLM provider" , "error" , err )
481549 a .respondError (w , http .StatusServiceUnavailable , api .ErrorCodeInternalError , "Agent is not configured properly" )
482550 return
483551 }
484-
485- model := selectModel (req .Model , mgr .GetModel (), configModel )
486552 messageWithContext := a .formatMessage (r .Context (), req .Message , req .DAGContexts )
487553
488- // Update safe mode setting per request (allows toggling mid-conversation)
489554 mgr .SetSafeMode (req .SafeMode )
555+ mgr .UpdatePricing (modelCfg .InputCostPer1M , modelCfg .OutputCostPer1M )
490556
491- if err := mgr .AcceptUserMessage (r .Context (), provider , model , messageWithContext ); err != nil {
557+ if err := mgr .AcceptUserMessage (r .Context (), provider , model , modelCfg . Model , messageWithContext ); err != nil {
492558 a .logger .Error ("Failed to accept user message" , "error" , err )
493559 a .respondError (w , http .StatusInternalServerError , api .ErrorCodeInternalError , "Failed to process message" )
494560 return
@@ -539,6 +605,7 @@ func (a *API) reactivateConversation(ctx context.Context, id, userID, username,
539605 History : messages ,
540606 SequenceID : seqID ,
541607 Environment : a .environment ,
608+ SafeMode : true , // Default to safe mode for reactivated conversations
542609 Hooks : a .hooks ,
543610 Username : username ,
544611 IPAddress : ipAddress ,
@@ -567,12 +634,45 @@ func (a *API) handleStream(w http.ResponseWriter, r *http.Request) {
567634 snapshot , next := mgr .SubscribeWithSnapshot (r .Context ())
568635 a .sendSSEMessage (w , snapshot )
569636
637+ type streamResult struct {
638+ resp StreamResponse
639+ cont bool
640+ }
641+
642+ heartbeat := time .NewTicker (15 * time .Second )
643+ defer heartbeat .Stop ()
644+
645+ ch := make (chan streamResult , 1 )
646+ go func () {
647+ for {
648+ resp , cont := next ()
649+ ch <- streamResult {resp , cont }
650+ if ! cont {
651+ return
652+ }
653+ }
654+ }()
655+
570656 for {
571- resp , cont := next ()
572- if ! cont {
573- break
657+ select {
658+ case res := <- ch :
659+ if ! res .cont {
660+ return
661+ }
662+ a .sendSSEMessage (w , res .resp )
663+ heartbeat .Reset (15 * time .Second )
664+ case <- heartbeat .C :
665+ // SSE comment as heartbeat to keep connection alive
666+ if _ , err := fmt .Fprint (w , ": heartbeat\n \n " ); err != nil {
667+ a .logger .Debug ("SSE heartbeat write failed" , "error" , err )
668+ return
669+ }
670+ if f , ok := w .(http.Flusher ); ok {
671+ f .Flush ()
672+ }
673+ case <- r .Context ().Done ():
674+ return
574675 }
575- a .sendSSEMessage (w , resp )
576676 }
577677}
578678
@@ -594,7 +694,10 @@ func (a *API) sendSSEMessage(w http.ResponseWriter, data any) {
594694 slog .Error ("failed to marshal SSE data" , "error" , err )
595695 return
596696 }
597- fmt .Fprintf (w , "data: %s\n \n " , jsonData )
697+ if _ , err := fmt .Fprintf (w , "data: %s\n \n " , jsonData ); err != nil {
698+ a .logger .Debug ("SSE write failed" , "error" , err )
699+ return
700+ }
598701 if f , ok := w .(http.Flusher ); ok {
599702 f .Flush ()
600703 }
@@ -634,6 +737,7 @@ func (a *API) handleUserResponse(w http.ResponseWriter, r *http.Request) {
634737 }
635738
636739 var req UserPromptResponse
740+ r .Body = http .MaxBytesReader (w , r .Body , maxRequestBodySize )
637741 if err := json .NewDecoder (r .Body ).Decode (& req ); err != nil {
638742 a .respondError (w , http .StatusBadRequest , api .ErrorCodeBadRequest , "Invalid JSON" )
639743 return
@@ -652,6 +756,57 @@ func (a *API) handleUserResponse(w http.ResponseWriter, r *http.Request) {
652756 a .respondJSON (w , http .StatusOK , map [string ]string {"status" : "accepted" })
653757}
654758
759+ // idleConversationTimeout is the duration after which idle conversations are cleaned up.
760+ const idleConversationTimeout = 30 * time .Minute
761+
762+ // cleanupInterval is how often the cleanup goroutine runs.
763+ const cleanupInterval = 5 * time .Minute
764+
765+ // StartCleanup begins periodic cleanup of idle conversations.
766+ // It should be called once when the API is initialized and will
767+ // stop when the context is cancelled.
768+ func (a * API ) StartCleanup (ctx context.Context ) {
769+ go func () {
770+ ticker := time .NewTicker (cleanupInterval )
771+ defer ticker .Stop ()
772+
773+ for {
774+ select {
775+ case <- ctx .Done ():
776+ return
777+ case <- ticker .C :
778+ a .cleanupIdleConversations ()
779+ }
780+ }
781+ }()
782+ }
783+
784+ // cleanupIdleConversations removes conversations that have been idle too long and are not working.
785+ func (a * API ) cleanupIdleConversations () {
786+ cutoff := time .Now ().Add (- idleConversationTimeout )
787+ var toDelete []string
788+
789+ a .conversations .Range (func (key , value any ) bool {
790+ id , ok := key .(string )
791+ if ! ok {
792+ return true
793+ }
794+ mgr , ok := value .(* ConversationManager )
795+ if ! ok {
796+ return true
797+ }
798+ if ! mgr .IsWorking () && mgr .LastActivity ().Before (cutoff ) {
799+ toDelete = append (toDelete , id )
800+ }
801+ return true
802+ })
803+
804+ for _ , id := range toDelete {
805+ a .conversations .Delete (id )
806+ a .logger .Debug ("Cleaned up idle conversation" , "conversation_id" , id )
807+ }
808+ }
809+
655810// ptrTo returns a pointer to the given value.
656811func ptrTo [T any ](v T ) * T {
657812 return & v
0 commit comments