66 "fmt"
77 "sort"
88 "strings"
9+ "sync"
910 "time"
1011
1112 "github.com/milvus-io/milvus-sdk-go/v2/client"
@@ -22,6 +23,9 @@ const (
2223 DefaultRetryBaseDelay = 100
2324)
2425
// DefaultMaxConcurrentPrunes is the fallback number of pruneIfOverCap
// goroutines allowed to run at the same time. NewMilvusStore applies it when
// the configured MaxConcurrentPrunes value is zero or negative.
const DefaultMaxConcurrentPrunes = 10
28+
2529// MilvusStore provides memory retrieval from Milvus with similarity threshold filtering
2630type MilvusStore struct {
2731 client client.Client
@@ -31,6 +35,8 @@ type MilvusStore struct {
3135 maxRetries int
3236 retryBaseDelay time.Duration
3337 embeddingConfig EmbeddingConfig // Unified embedding configuration
38+ pruneSem chan struct {} // bounds concurrent prune goroutines
39+ pruneInFlight sync.Map // tracks userIDs with an active prune goroutine (dedup)
3440}
3541
3642// MilvusStoreOptions contains configuration for creating a MilvusStore
@@ -79,6 +85,11 @@ func NewMilvusStore(options MilvusStoreOptions) (*MilvusStore, error) {
7985 embeddingCfg = EmbeddingConfig {Model : EmbeddingModelBERT }
8086 }
8187
88+ maxPrunes := cfg .QualityScoring .MaxConcurrentPrunes
89+ if maxPrunes <= 0 {
90+ maxPrunes = DefaultMaxConcurrentPrunes
91+ }
92+
8293 store := & MilvusStore {
8394 client : options .Client ,
8495 collectionName : options .CollectionName ,
@@ -87,6 +98,7 @@ func NewMilvusStore(options MilvusStoreOptions) (*MilvusStore, error) {
8798 maxRetries : DefaultMaxRetries ,
8899 retryBaseDelay : DefaultRetryBaseDelay * time .Millisecond ,
89100 embeddingConfig : embeddingCfg ,
101+ pruneSem : make (chan struct {}, maxPrunes ),
90102 }
91103
92104 // Auto-create collection if it doesn't exist
@@ -684,9 +696,93 @@ func (m *MilvusStore) Store(ctx context.Context, memory *Memory) error {
684696 }
685697
686698 logging .Debugf ("MilvusStore.Store: successfully stored memory id=%s" , memory .ID )
699+
700+ // Path 1: event-driven cap enforcement — async prune if user exceeds max_memories_per_user.
701+ // Uses context.Background() intentionally: the goroutine must outlive the request ctx.
702+ // Two layers of protection against Milvus pressure:
703+ // 1. pruneInFlight (sync.Map): dedup — at most one goroutine per user at any time
704+ // 2. pruneSem (channel): semaphore — at most maxConcurrentPrunes goroutines globally
705+ if m .config .QualityScoring .MaxMemoriesPerUser > 0 {
706+ if _ , alreadyRunning := m .pruneInFlight .LoadOrStore (memory .UserID , struct {}{}); ! alreadyRunning {
707+ select {
708+ case m .pruneSem <- struct {}{}:
709+ go func (userID string ) {
710+ defer func () {
711+ <- m .pruneSem
712+ m .pruneInFlight .Delete (userID )
713+ }()
714+ m .pruneIfOverCap (context .Background (), userID )
715+ }(memory .UserID )
716+ default :
717+ m .pruneInFlight .Delete (memory .UserID )
718+ logging .Debugf ("MilvusStore.Store: prune semaphore full, skipping cap check for user_id=%s" , memory .UserID )
719+ }
720+ }
721+ }
722+
687723 return nil
688724}
689725
726+ // pruneIfOverCap counts the user's memories and calls PruneUser if over MaxMemoriesPerUser.
727+ // Designed to run in a goroutine triggered by Store().
728+ func (m * MilvusStore ) pruneIfOverCap (ctx context.Context , userID string ) {
729+ cap := m .config .QualityScoring .MaxMemoriesPerUser
730+ if cap <= 0 {
731+ return
732+ }
733+
734+ count , err := m .countUserMemories (ctx , userID )
735+ if err != nil {
736+ logging .Warnf ("MilvusStore.pruneIfOverCap: count failed for user_id=%s: %v" , userID , err )
737+ return
738+ }
739+
740+ if count <= cap {
741+ return
742+ }
743+
744+ PruneCapTriggeredTotal .Inc ()
745+ logging .Infof ("MilvusStore.pruneIfOverCap: user_id=%s has %d memories (cap=%d), pruning" , userID , count , cap )
746+
747+ deleted , err := m .PruneUser (ctx , userID )
748+ if err != nil {
749+ logging .Warnf ("MilvusStore.pruneIfOverCap: PruneUser failed for user_id=%s: %v" , userID , err )
750+ return
751+ }
752+ if deleted > 0 {
753+ PruneDeletedTotal .WithLabelValues ("cap" ).Add (float64 (deleted ))
754+ logging .Infof ("MilvusStore.pruneIfOverCap: user_id=%s pruned %d memories" , userID , deleted )
755+ }
756+ }
757+
758+ // countUserMemories returns the number of memories stored for a given user.
759+ func (m * MilvusStore ) countUserMemories (ctx context.Context , userID string ) (int , error ) {
760+ filterExpr := fmt .Sprintf ("user_id == \" %s\" " , userID )
761+
762+ var queryResult []entity.Column
763+ err := m .retryWithBackoff (ctx , func () error {
764+ var retryErr error
765+ queryResult , retryErr = m .client .Query (
766+ ctx ,
767+ m .collectionName ,
768+ []string {},
769+ filterExpr ,
770+ []string {"id" },
771+ )
772+ return retryErr
773+ })
774+ if err != nil {
775+ return 0 , fmt .Errorf ("milvus query failed: %w" , err )
776+ }
777+
778+ for _ , col := range queryResult {
779+ if col .Name () == "id" {
780+ return col .Len (), nil
781+ }
782+ }
783+ return 0 , nil
784+ }
785+
690786// upsert atomically replaces a row in Milvus by primary key.
691787// The memory must be fully populated (including Embedding, timestamps, etc.).
692788// Used by Update to avoid the delete+insert data-loss window.
@@ -1485,6 +1581,57 @@ func (m *MilvusStore) PruneUser(ctx context.Context, userID string) (deleted int
14851581 return deleted , nil
14861582}
14871583
1584+ // ListStaleUserIDs queries Milvus for memories with created_at older than cutoffUnix
1585+ // and returns the deduplicated set of user_id values. This targets users whose oldest
1586+ // memories may have decayed below the prune threshold, without iterating all users.
1587+ func (m * MilvusStore ) ListStaleUserIDs (ctx context.Context , cutoffUnix int64 ) ([]string , error ) {
1588+ if ! m .enabled {
1589+ return nil , fmt .Errorf ("milvus store is not enabled" )
1590+ }
1591+
1592+ filterExpr := fmt .Sprintf ("created_at < %d" , cutoffUnix )
1593+ outputFields := []string {"user_id" }
1594+
1595+ var queryResult []entity.Column
1596+ err := m .retryWithBackoff (ctx , func () error {
1597+ var retryErr error
1598+ queryResult , retryErr = m .client .Query (
1599+ ctx ,
1600+ m .collectionName ,
1601+ []string {},
1602+ filterExpr ,
1603+ outputFields ,
1604+ )
1605+ return retryErr
1606+ })
1607+ if err != nil {
1608+ return nil , fmt .Errorf ("milvus query for stale users failed: %w" , err )
1609+ }
1610+
1611+ seen := make (map [string ]struct {})
1612+ for _ , col := range queryResult {
1613+ if col .Name () == "user_id" {
1614+ vc , ok := col .(* entity.ColumnVarChar )
1615+ if ! ok {
1616+ continue
1617+ }
1618+ for i := 0 ; i < vc .Len (); i ++ {
1619+ uid , _ := vc .ValueByIdx (i )
1620+ if uid != "" {
1621+ seen [uid ] = struct {}{}
1622+ }
1623+ }
1624+ }
1625+ }
1626+
1627+ userIDs := make ([]string , 0 , len (seen ))
1628+ for uid := range seen {
1629+ userIDs = append (userIDs , uid )
1630+ }
1631+ logging .Debugf ("MilvusStore.ListStaleUserIDs: found %d users with memories older than %d" , len (userIDs ), cutoffUnix )
1632+ return userIDs , nil
1633+ }
1634+
14881635// isTransientError checks if an error is transient and should be retried
14891636func isTransientError (err error ) bool {
14901637 if err == nil {
0 commit comments