Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions cmd/serve_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1051,22 +1051,12 @@ func (s *serveServer) handleSessionsSearch(w http.ResponseWriter, r *http.Reques
}
}

summaries, err := s.store.List(r.Context(), session.ListOptions{
Limit: 2000,
Archived: includeArchived,
Categories: categories,
SortByActivity: true,
matches, err := s.store.Search(r.Context(), session.SearchOptions{
Query: query,
Categories: categories,
Limit: limit,
Archived: includeArchived,
})
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "server_error", "failed to list sessions")
return
}
allowed := make(map[string]session.SessionSummary, len(summaries))
for _, summary := range summaries {
allowed[summary.ID] = summary
}

matches, err := s.store.Search(r.Context(), query, limit*4)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "invalid search query")
return
Expand All @@ -1090,17 +1080,29 @@ func (s *serveServer) handleSessionsSearch(w http.ResponseWriter, r *http.Reques
MessageID int64 `json:"message_id,omitempty"`
}

result := make([]sessionSearchEntry, 0, min(limit, len(matches)))
seen := make(map[string]bool, len(matches))
result := make([]sessionSearchEntry, 0, len(matches))
for _, match := range matches {
if seen[match.SessionID] {
continue
summary := session.SessionSummary{
ID: match.SessionID,
Number: match.SessionNumber,
Name: match.SessionName,
Summary: match.Summary,
GeneratedShortTitle: match.GeneratedShortTitle,
GeneratedLongTitle: match.GeneratedLongTitle,
TitleSource: match.TitleSource,
Provider: match.Provider,
ProviderKey: match.ProviderKey,
Model: match.Model,
Mode: match.Mode,
Origin: match.Origin,
Archived: match.Archived,
Pinned: match.Pinned,
MessageCount: match.MessageCount,
Status: match.Status,
CreatedAt: match.SessionCreatedAt,
UpdatedAt: match.UpdatedAt,
LastMessageAt: match.LastMessageAt,
}
summary, ok := allowed[match.SessionID]
if !ok {
continue
}
seen[match.SessionID] = true
lastMessageAt := sessionSummaryLastMessageAt(summary)
result = append(result, sessionSearchEntry{
ID: summary.ID,
Expand All @@ -1119,9 +1121,6 @@ func (s *serveServer) handleSessionsSearch(w http.ResponseWriter, r *http.Reques
Snippet: match.Snippet,
MessageID: match.MessageID,
})
if len(result) >= limit {
break
}
}

writeJSONConditional(w, r, http.StatusOK, map[string]any{"sessions": result})
Expand Down
2 changes: 1 addition & 1 deletion cmd/serve_runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (s *serveRuntimeTestStore) List(ctx context.Context, opts session.ListOptio
return nil, nil
}

func (s *serveRuntimeTestStore) Search(ctx context.Context, query string, limit int) ([]session.SearchResult, error) {
func (s *serveRuntimeTestStore) Search(ctx context.Context, opts session.SearchOptions) ([]session.SearchResult, error) {
return nil, nil
}

Expand Down
73 changes: 73 additions & 0 deletions cmd/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4129,6 +4129,79 @@ func TestHandleSessionsSearch_UsesFTSAndReturnsSessionSummaries(t *testing.T) {
}
}

func TestHandleSessionsSearch_DoesNotDropOlderMatchesOutsideRecent2000ListWindow(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "sessions.db")
store, err := session.NewStore(session.Config{Enabled: true, Path: dbPath})
if err != nil {
t.Fatalf("NewStore: %v", err)
}
defer store.Close()

ctx := context.Background()
base := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Second)
matching := &session.Session{
ID: "older-match",
Provider: "mock",
Model: "mock-model",
Mode: session.ModeChat,
Summary: "older matching session",
CreatedAt: base,
UpdatedAt: base,
Status: session.StatusActive,
}
if err := store.Create(ctx, matching); err != nil {
t.Fatalf("Create matching session: %v", err)
}
if err := store.AddMessage(ctx, matching.ID, &session.Message{
Role: llm.RoleUser,
TextContent: "needle older searchable session",
Parts: []llm.Part{{Type: llm.PartText, Text: "needle older searchable session"}},
CreatedAt: base,
}); err != nil {
t.Fatalf("AddMessage matching session: %v", err)
}

for i := 0; i < 2000; i++ {
createdAt := base.Add(time.Duration(i+1) * time.Second)
sess := &session.Session{
ID: fmt.Sprintf("newer-%04d", i),
Provider: "mock",
Model: "mock-model",
Mode: session.ModeChat,
Summary: fmt.Sprintf("newer session %d", i),
CreatedAt: createdAt,
UpdatedAt: createdAt,
Status: session.StatusActive,
}
if err := store.Create(ctx, sess); err != nil {
t.Fatalf("Create newer session %d: %v", i, err)
}
}

srv := &serveServer{store: store}
req := httptest.NewRequest(http.MethodGet, "/v1/sessions/search?q=needle&limit=5", nil)
rr := httptest.NewRecorder()
srv.handleSessionsSearch(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want 200, body: %s", rr.Code, rr.Body.String())
}

var body struct {
Sessions []struct {
ID string `json:"id"`
} `json:"sessions"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
t.Fatalf("decode: %v", err)
}
if len(body.Sessions) != 1 {
t.Fatalf("session count = %d, want 1; body: %s", len(body.Sessions), rr.Body.String())
}
if body.Sessions[0].ID != matching.ID {
t.Fatalf("id = %q, want %q", body.Sessions[0].ID, matching.ID)
}
}

func TestHandleSessions_GzipCompressed(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "sessions.db")
store, err := session.NewStore(session.Config{Enabled: true, Path: dbPath})
Expand Down
2 changes: 1 addition & 1 deletion cmd/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func runSessionsSearch(cmd *cobra.Command, args []string) error {

query := strings.Join(args, " ")
ctx := context.Background()
results, err := store.Search(ctx, query, 20)
results, err := store.Search(ctx, session.SearchOptions{Query: query, Limit: 20})
if err != nil {
return fmt.Errorf("search failed: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/session/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (s *NoopStore) List(ctx context.Context, opts ListOptions) ([]SessionSummar
return nil, nil
}

func (s *NoopStore) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) {
func (s *NoopStore) Search(ctx context.Context, opts SearchOptions) ([]SearchResult, error) {
return nil, nil
}

Expand Down
135 changes: 121 additions & 14 deletions internal/session/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -1505,12 +1505,12 @@ func (s *SQLiteStore) List(ctx context.Context, opts ListOptions) ([]SessionSumm
}

// Search finds sessions containing the query text using FTS5.
func (s *SQLiteStore) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) {
if limit == 0 {
limit = 20
func (s *SQLiteStore) Search(ctx context.Context, opts SearchOptions) ([]SearchResult, error) {
if opts.Limit == 0 {
opts.Limit = 20
}

ftsQuery := sqlitefts.LiteralQuery(query)
ftsQuery := sqlitefts.LiteralQuery(opts.Query)
if ftsQuery == "" {
return []SearchResult{}, nil
}
Expand All @@ -1528,15 +1528,99 @@ func (s *SQLiteStore) Search(ctx context.Context, query string, limit int) ([]Se
messageCountCol = "(SELECT COUNT(*) FROM messages WHERE session_id = s.id AND role IN ('user', 'assistant')" + countCompactionTailClause + ")"
}

originCol := "'tui'"
if s.hasOrigin {
originCol = "COALESCE(NULLIF(TRIM(s.origin), ''), 'tui')"
}
pinnedCol := "FALSE"
if s.hasPinned {
pinnedCol = "COALESCE(s.pinned, FALSE)"
}
generatedShortCol := "''"
generatedLongCol := "''"
titleSourceCol := "''"
if s.hasGeneratedTitles {
generatedShortCol = "s.generated_short_title"
generatedLongCol = "s.generated_long_title"
titleSourceCol = "s.title_source"
}
lastMessageAtCol := "NULL"
if s.hasLastMessageAt {
lastMessageAtCol = "s.last_message_at"
}

filterClause := ""
args := []any{ftsQuery}
if len(opts.Categories) > 0 {
clauses := make([]string, 0, len(opts.Categories))
sawSpecificCategory := false
for _, raw := range opts.Categories {
category := strings.ToLower(strings.TrimSpace(raw))
switch category {
case "", "all":
clauses = nil
case "chat":
sawSpecificCategory = true
if s.hasOrigin {
clauses = append(clauses, "(s.mode = 'chat' AND COALESCE(NULLIF(TRIM(s.origin), ''), 'tui') = 'tui')")
} else {
clauses = append(clauses, "(s.mode = 'chat')")
}
case "web":
sawSpecificCategory = true
if s.hasOrigin {
clauses = append(clauses, "(COALESCE(NULLIF(TRIM(s.origin), ''), 'tui') = 'web')")
}
case "ask", "plan", "exec":
sawSpecificCategory = true
clauses = append(clauses, "(s.mode = ?)")
args = append(args, category)
}
if clauses == nil {
break
}
}
if len(clauses) > 0 {
filterClause += " AND (" + strings.Join(clauses, " OR ") + ")"
} else if sawSpecificCategory {
filterClause += " AND 1 = 0"
}
}
if !opts.Archived {
filterClause += " AND s.archived = FALSE"
}
args = append(args, opts.Limit)

rows, err := s.db.QueryContext(ctx, `
SELECT m.session_id, s.number, m.id, s.name, s.summary, snippet(messages_fts, 0, '**', '**', '...', 32),
s.provider, s.model, s.mode, s.status, `+messageCountCol+` as message_count, s.created_at, s.updated_at, m.created_at
FROM messages_fts f
JOIN messages m ON m.id = f.rowid
JOIN sessions s ON s.id = m.session_id
WHERE messages_fts MATCH ?`+compactionTailClause+`
ORDER BY rank
LIMIT ?`, ftsQuery, limit)
WITH raw_matches AS (
SELECT rowid AS message_id, rank AS match_rank
FROM messages_fts
WHERE messages_fts MATCH ?
), ranked_matches AS (
SELECT m.id AS message_id, m.session_id, raw_matches.match_rank,
ROW_NUMBER() OVER (PARTITION BY m.session_id ORDER BY raw_matches.match_rank, m.id) AS session_row
FROM raw_matches
JOIN messages m ON m.id = raw_matches.message_id
JOIN sessions s ON s.id = m.session_id
WHERE 1=1`+compactionTailClause+filterClause+`
), session_matches AS (
SELECT message_id, session_id, match_rank
FROM ranked_matches
WHERE session_row = 1
ORDER BY match_rank, message_id
LIMIT ?
)
SELECT m.session_id, s.number, m.id, s.name, s.summary, `+generatedShortCol+` AS generated_short_title,
`+generatedLongCol+` AS generated_long_title, `+titleSourceCol+` AS title_source,
snippet(messages_fts, 0, '**', '**', '...', 32) AS snippet, s.provider, COALESCE(s.provider_key, '') AS provider_key,
s.model, s.mode, `+originCol+` AS origin, s.archived, `+pinnedCol+` AS pinned, s.status,
`+messageCountCol+` AS message_count, s.created_at, s.updated_at, `+lastMessageAtCol+` AS last_message_at,
m.created_at AS message_created_at
FROM session_matches sm
JOIN messages m ON m.id = sm.message_id
JOIN sessions s ON s.id = sm.session_id
JOIN messages_fts ON messages_fts.rowid = sm.message_id
ORDER BY sm.match_rank, sm.message_id`, args...)
if err != nil {
return nil, fmt.Errorf("search messages: %w", err)
}
Expand All @@ -1546,21 +1630,44 @@ func (s *SQLiteStore) Search(ctx context.Context, query string, limit int) ([]Se
for rows.Next() {
var r SearchResult
var number sql.NullInt64
var mode, status sql.NullString
var generatedShortTitle, generatedLongTitle, titleSource, providerKey, mode, origin, status sql.NullString
var lastMessageAt sql.NullTime
err := rows.Scan(&r.SessionID, &number, &r.MessageID, &r.SessionName, &r.Summary,
&r.Snippet, &r.Provider, &r.Model, &mode, &status, &r.MessageCount, &r.SessionCreatedAt, &r.UpdatedAt, &r.CreatedAt)
&generatedShortTitle, &generatedLongTitle, &titleSource, &r.Snippet, &r.Provider, &providerKey,
&r.Model, &mode, &origin, &r.Archived, &r.Pinned, &status, &r.MessageCount,
&r.SessionCreatedAt, &r.UpdatedAt, &lastMessageAt, &r.CreatedAt)
if err != nil {
return nil, fmt.Errorf("scan search result: %w", err)
}
if number.Valid {
r.SessionNumber = number.Int64
}
if generatedShortTitle.Valid {
r.GeneratedShortTitle = generatedShortTitle.String
}
if generatedLongTitle.Valid {
r.GeneratedLongTitle = generatedLongTitle.String
}
if titleSource.Valid {
r.TitleSource = SessionTitleSource(titleSource.String)
}
if providerKey.Valid {
r.ProviderKey = providerKey.String
}
if mode.Valid {
r.Mode = SessionMode(mode.String)
}
if origin.Valid {
r.Origin = SessionOrigin(origin.String)
} else {
r.Origin = OriginTUI
}
if status.Valid {
r.Status = SessionStatus(status.String)
}
if lastMessageAt.Valid {
r.LastMessageAt = lastMessageAt.Time
}
results = append(results, r)
}
return results, rows.Err()
Expand Down
Loading
Loading