66 "fmt"
77 "sort"
88 "strings"
9+ "sync"
910 "time"
1011
1112 "github.com/milvus-io/milvus-sdk-go/v2/client"
@@ -22,6 +23,9 @@ const (
2223 DefaultRetryBaseDelay = 100
2324)
2425
26+ // DefaultMaxConcurrentPrunes is the default limit for concurrent pruneIfOverCap goroutines.
27+ const DefaultMaxConcurrentPrunes = 10
28+
2529// MilvusStore provides memory retrieval from Milvus with similarity threshold filtering
2630type MilvusStore struct {
2731 client client.Client
@@ -31,6 +35,8 @@ type MilvusStore struct {
3135 maxRetries int
3236 retryBaseDelay time.Duration
3337 embeddingConfig EmbeddingConfig // Unified embedding configuration
38+ pruneSem chan struct {} // bounds concurrent prune goroutines
39+ pruneInFlight sync.Map // tracks userIDs with an active prune goroutine (dedup)
3440}
3541
3642// MilvusStoreOptions contains configuration for creating a MilvusStore
@@ -79,6 +85,11 @@ func NewMilvusStore(options MilvusStoreOptions) (*MilvusStore, error) {
7985 embeddingCfg = EmbeddingConfig {Model : EmbeddingModelBERT }
8086 }
8187
88+ maxPrunes := cfg .QualityScoring .MaxConcurrentPrunes
89+ if maxPrunes <= 0 {
90+ maxPrunes = DefaultMaxConcurrentPrunes
91+ }
92+
8293 store := & MilvusStore {
8394 client : options .Client ,
8495 collectionName : options .CollectionName ,
@@ -87,6 +98,7 @@ func NewMilvusStore(options MilvusStoreOptions) (*MilvusStore, error) {
8798 maxRetries : DefaultMaxRetries ,
8899 retryBaseDelay : DefaultRetryBaseDelay * time .Millisecond ,
89100 embeddingConfig : embeddingCfg ,
101+ pruneSem : make (chan struct {}, maxPrunes ),
90102 }
91103
92104 // Auto-create collection if it doesn't exist
@@ -638,9 +650,93 @@ func (m *MilvusStore) Store(ctx context.Context, memory *Memory) error {
638650 }
639651
640652 logging .Debugf ("MilvusStore.Store: successfully stored memory id=%s" , memory .ID )
653+
654+ // Path 1: event-driven cap enforcement — async prune if user exceeds max_memories_per_user.
655+ // Uses context.Background() intentionally: the goroutine must outlive the request ctx.
656+ // Two layers of protection against Milvus pressure:
657+ // 1. pruneInFlight (sync.Map): dedup — at most one goroutine per user at any time
658+ // 2. pruneSem (channel): semaphore — at most maxConcurrentPrunes goroutines globally
659+ if m .config .QualityScoring .MaxMemoriesPerUser > 0 {
660+ if _ , alreadyRunning := m .pruneInFlight .LoadOrStore (memory .UserID , struct {}{}); ! alreadyRunning {
661+ select {
662+ case m .pruneSem <- struct {}{}:
663+ go func (userID string ) {
664+ defer func () {
665+ <- m .pruneSem
666+ m .pruneInFlight .Delete (userID )
667+ }()
668+ m .pruneIfOverCap (context .Background (), userID )
669+ }(memory .UserID )
670+ default :
671+ m .pruneInFlight .Delete (memory .UserID )
672+ logging .Debugf ("MilvusStore.Store: prune semaphore full, skipping cap check for user_id=%s" , memory .UserID )
673+ }
674+ }
675+ }
676+
641677 return nil
642678}
643679
680+ // pruneIfOverCap counts the user's memories and calls PruneUser if over MaxMemoriesPerUser.
681+ // Designed to run in a goroutine triggered by Store().
682+ func (m * MilvusStore ) pruneIfOverCap (ctx context.Context , userID string ) {
683+ cap := m .config .QualityScoring .MaxMemoriesPerUser
684+ if cap <= 0 {
685+ return
686+ }
687+
688+ count , err := m .countUserMemories (ctx , userID )
689+ if err != nil {
690+ logging .Warnf ("MilvusStore.pruneIfOverCap: count failed for user_id=%s: %v" , userID , err )
691+ return
692+ }
693+
694+ if count <= cap {
695+ return
696+ }
697+
698+ PruneCapTriggeredTotal .Inc ()
699+ logging .Infof ("MilvusStore.pruneIfOverCap: user_id=%s has %d memories (cap=%d), pruning" , userID , count , cap )
700+
701+ deleted , err := m .PruneUser (ctx , userID )
702+ if err != nil {
703+ logging .Warnf ("MilvusStore.pruneIfOverCap: PruneUser failed for user_id=%s: %v" , userID , err )
704+ return
705+ }
706+ if deleted > 0 {
707+ PruneDeletedTotal .WithLabelValues ("cap" ).Add (float64 (deleted ))
708+ logging .Infof ("MilvusStore.pruneIfOverCap: user_id=%s pruned %d memories" , userID , deleted )
709+ }
710+ }
711+
712+ // countUserMemories returns the number of memories stored for a given user.
713+ func (m * MilvusStore ) countUserMemories (ctx context.Context , userID string ) (int , error ) {
714+ filterExpr := fmt .Sprintf ("user_id == \" %s\" " , userID )
715+
716+ var queryResult []entity.Column
717+ err := m .retryWithBackoff (ctx , func () error {
718+ var retryErr error
719+ queryResult , retryErr = m .client .Query (
720+ ctx ,
721+ m .collectionName ,
722+ []string {},
723+ filterExpr ,
724+ []string {"id" },
725+ )
726+ return retryErr
727+ })
728+ if err != nil {
729+ return 0 , fmt .Errorf ("milvus query failed: %w" , err )
730+ }
731+
732+ for _ , col := range queryResult {
733+ if col .Name () == "id" {
734+ return col .Len (), nil
735+ }
736+ }
737+ return 0 , nil
738+ }
739+
644740// upsert atomically replaces a row in Milvus by primary key.
645741// The memory must be fully populated (including Embedding, timestamps, etc.).
646742// Used by Update to avoid the delete+insert data-loss window.
@@ -1388,6 +1484,57 @@ func (m *MilvusStore) PruneUser(ctx context.Context, userID string) (deleted int
13881484 return deleted , nil
13891485}
13901486
1487+ // ListStaleUserIDs queries Milvus for memories with created_at older than cutoffUnix
1488+ // and returns the deduplicated set of user_id values. This targets users whose oldest
1489+ // memories may have decayed below the prune threshold, without iterating all users.
1490+ func (m * MilvusStore ) ListStaleUserIDs (ctx context.Context , cutoffUnix int64 ) ([]string , error ) {
1491+ if ! m .enabled {
1492+ return nil , fmt .Errorf ("milvus store is not enabled" )
1493+ }
1494+
1495+ filterExpr := fmt .Sprintf ("created_at < %d" , cutoffUnix )
1496+ outputFields := []string {"user_id" }
1497+
1498+ var queryResult []entity.Column
1499+ err := m .retryWithBackoff (ctx , func () error {
1500+ var retryErr error
1501+ queryResult , retryErr = m .client .Query (
1502+ ctx ,
1503+ m .collectionName ,
1504+ []string {},
1505+ filterExpr ,
1506+ outputFields ,
1507+ )
1508+ return retryErr
1509+ })
1510+ if err != nil {
1511+ return nil , fmt .Errorf ("milvus query for stale users failed: %w" , err )
1512+ }
1513+
1514+ seen := make (map [string ]struct {})
1515+ for _ , col := range queryResult {
1516+ if col .Name () == "user_id" {
1517+ vc , ok := col .(* entity.ColumnVarChar )
1518+ if ! ok {
1519+ continue
1520+ }
1521+ for i := 0 ; i < vc .Len (); i ++ {
1522+ uid , _ := vc .ValueByIdx (i )
1523+ if uid != "" {
1524+ seen [uid ] = struct {}{}
1525+ }
1526+ }
1527+ }
1528+ }
1529+
1530+ userIDs := make ([]string , 0 , len (seen ))
1531+ for uid := range seen {
1532+ userIDs = append (userIDs , uid )
1533+ }
1534+ logging .Debugf ("MilvusStore.ListStaleUserIDs: found %d users with memories older than %d" , len (userIDs ), cutoffUnix )
1535+ return userIDs , nil
1536+ }
1537+
13911538// isTransientError checks if an error is transient and should be retried
13921539func isTransientError (err error ) bool {
13931540 if err == nil {
0 commit comments