Skip to content

Commit faec828

Browse files
hxy91819 authored and steipete committed
fix: bound SQLite import memory
1 parent f55cba8 commit faec828

3 files changed

Lines changed: 233 additions & 4 deletions

File tree

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package share
2+
3+
import (
4+
"compress/gzip"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"os"
9+
"path/filepath"
10+
"strconv"
11+
"strings"
12+
"testing"
13+
"time"
14+
15+
"github.com/stretchr/testify/require"
16+
17+
"github.com/openclaw/discrawl/internal/store"
18+
)
19+
20+
func TestImportRealSnapshot(t *testing.T) {
21+
repo := strings.TrimSpace(os.Getenv("DISCRAWL_REAL_REPO"))
22+
if repo == "" {
23+
t.Skip("set DISCRAWL_REAL_REPO to run real snapshot import validation")
24+
}
25+
26+
ctx := context.Background()
27+
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
28+
require.NoError(t, err)
29+
defer func() { _ = dst.Close() }()
30+
31+
_, changed, err := ImportIfChanged(ctx, dst, Options{
32+
RepoPath: repo,
33+
Branch: "main",
34+
Progress: func(p ImportProgress) {
35+
if p.Phase == "start" || p.Phase == "rebuild_fts" || p.Phase == "done" {
36+
t.Logf("import progress phase=%s total_rows=%d", p.Phase, p.TotalRows)
37+
}
38+
},
39+
})
40+
require.NoError(t, err)
41+
require.True(t, changed)
42+
43+
var messageCount int
44+
require.NoError(t, dst.DB().QueryRowContext(ctx, `select count(*) from messages`).Scan(&messageCount))
45+
require.Positive(t, messageCount)
46+
var ftsCount int
47+
require.NoError(t, dst.DB().QueryRowContext(ctx, `select count(*) from message_fts`).Scan(&ftsCount))
48+
require.Equal(t, messageCount, ftsCount)
49+
}
50+
51+
func TestImportMemoryBounded(t *testing.T) {
52+
if os.Getenv("DISCRAWL_OOM_REGRESSION") != "1" {
53+
t.Skip("set DISCRAWL_OOM_REGRESSION=1 to run the memory-bounded import regression")
54+
}
55+
56+
ctx := context.Background()
57+
repo := t.TempDir()
58+
messageRows := envInt(t, "DISCRAWL_OOM_ROWS", 80000)
59+
textBytes := envInt(t, "DISCRAWL_OOM_TEXT_BYTES", 2048)
60+
t.Logf("building synthetic snapshot rows=%d text_bytes=%d", messageRows, textBytes)
61+
writeSyntheticMemorySnapshot(t, repo, messageRows, textBytes)
62+
t.Log("synthetic snapshot built; starting import")
63+
64+
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
65+
require.NoError(t, err)
66+
defer func() { _ = dst.Close() }()
67+
68+
var progress []ImportProgress
69+
_, changed, err := ImportIfChanged(ctx, dst, Options{
70+
RepoPath: repo,
71+
Branch: "main",
72+
Progress: func(p ImportProgress) { progress = append(progress, p) },
73+
})
74+
require.NoError(t, err)
75+
require.True(t, changed)
76+
require.Contains(t, progressPhases(progress), "rebuild_fts")
77+
78+
needle := "oomunique000042"
79+
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: needle, Limit: 10})
80+
require.NoError(t, err)
81+
require.Len(t, results, 1)
82+
require.Contains(t, results[0].Content, needle)
83+
}
84+
85+
func writeSyntheticMemorySnapshot(t *testing.T, repo string, messageRows, textBytes int) {
86+
t.Helper()
87+
generatedAt := time.Now().UTC()
88+
updatedAt := generatedAt.Format(time.RFC3339Nano)
89+
guildFile := "tables/guilds/000000.jsonl.gz"
90+
channelFile := "tables/channels/000000.jsonl.gz"
91+
memberFile := "tables/members/000000.jsonl.gz"
92+
messageFile := "tables/messages/000000.jsonl.gz"
93+
94+
writeJSONLGzip(t, repo, guildFile, func(enc *json.Encoder) int {
95+
require.NoError(t, enc.Encode(map[string]any{
96+
"id": "g1",
97+
"name": "Guild",
98+
"icon": "",
99+
"raw_json": `{}`,
100+
"updated_at": updatedAt,
101+
}))
102+
return 1
103+
})
104+
writeJSONLGzip(t, repo, channelFile, func(enc *json.Encoder) int {
105+
require.NoError(t, enc.Encode(map[string]any{
106+
"id": "c1",
107+
"guild_id": "g1",
108+
"parent_id": "",
109+
"kind": "text",
110+
"name": "general",
111+
"topic": "",
112+
"position": 0,
113+
"is_nsfw": false,
114+
"is_archived": false,
115+
"is_locked": false,
116+
"is_private_thread": false,
117+
"thread_parent_id": "",
118+
"archive_timestamp": "",
119+
"raw_json": `{}`,
120+
"updated_at": updatedAt,
121+
}))
122+
return 1
123+
})
124+
writeJSONLGzip(t, repo, memberFile, func(enc *json.Encoder) int {
125+
require.NoError(t, enc.Encode(map[string]any{
126+
"guild_id": "g1",
127+
"user_id": "u1",
128+
"username": "peter",
129+
"global_name": "",
130+
"display_name": "Peter",
131+
"nick": "",
132+
"discriminator": "",
133+
"avatar": "",
134+
"bot": false,
135+
"joined_at": "",
136+
"role_ids_json": `[]`,
137+
"raw_json": `{"bio":"memory regression profile"}`,
138+
"updated_at": updatedAt,
139+
}))
140+
return 1
141+
})
142+
writeJSONLGzip(t, repo, messageFile, func(enc *json.Encoder) int {
143+
for i := range messageRows {
144+
messageID := strconv.FormatInt(1456744319972282449+int64(i), 10)
145+
unique := fmt.Sprintf("oomunique%06d", i)
146+
content := syntheticMemoryContent(unique, textBytes)
147+
require.NoError(t, enc.Encode(map[string]any{
148+
"id": messageID,
149+
"guild_id": "g1",
150+
"channel_id": "c1",
151+
"author_id": "u1",
152+
"message_type": 0,
153+
"created_at": updatedAt,
154+
"edited_at": "",
155+
"deleted_at": "",
156+
"content": content,
157+
"normalized_content": content,
158+
"reply_to_message_id": "",
159+
"pinned": false,
160+
"has_attachments": false,
161+
"raw_json": `{"author":{"username":"Peter"}}`,
162+
"updated_at": updatedAt,
163+
}))
164+
}
165+
return messageRows
166+
})
167+
168+
manifest := Manifest{
169+
Version: 1,
170+
GeneratedAt: generatedAt,
171+
Tables: []TableManifest{
172+
{Name: "guilds", Files: []string{guildFile}, Columns: []string{"id", "name", "icon", "raw_json", "updated_at"}, Rows: 1},
173+
{Name: "channels", Files: []string{channelFile}, Columns: []string{"id", "guild_id", "parent_id", "kind", "name", "topic", "position", "is_nsfw", "is_archived", "is_locked", "is_private_thread", "thread_parent_id", "archive_timestamp", "raw_json", "updated_at"}, Rows: 1},
174+
{Name: "members", Files: []string{memberFile}, Columns: []string{"guild_id", "user_id", "username", "global_name", "display_name", "nick", "discriminator", "avatar", "bot", "joined_at", "role_ids_json", "raw_json", "updated_at"}, Rows: 1},
175+
{Name: "messages", Files: []string{messageFile}, Columns: []string{"id", "guild_id", "channel_id", "author_id", "message_type", "created_at", "edited_at", "deleted_at", "content", "normalized_content", "reply_to_message_id", "pinned", "has_attachments", "raw_json", "updated_at"}, Rows: messageRows},
176+
},
177+
}
178+
body, err := json.MarshalIndent(manifest, "", " ")
179+
require.NoError(t, err)
180+
require.NoError(t, os.WriteFile(filepath.Join(repo, ManifestName), append(body, '\n'), 0o600))
181+
}
182+
183+
func writeJSONLGzip(t *testing.T, repo, rel string, writeRows func(*json.Encoder) int) {
184+
t.Helper()
185+
path := filepath.Join(repo, filepath.FromSlash(rel))
186+
require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755))
187+
file, err := os.Create(path)
188+
require.NoError(t, err)
189+
gz := gzip.NewWriter(file)
190+
enc := json.NewEncoder(gz)
191+
rows := writeRows(enc)
192+
require.Positive(t, rows)
193+
require.NoError(t, gz.Close())
194+
require.NoError(t, file.Close())
195+
}
196+
197+
// syntheticMemoryContent returns a string that begins with unique and is padded
// with " token_<unique>_<NNNN>" fragments (NNNN counting up from 0000) until it
// reaches size bytes, then cut to exactly size bytes. When size does not exceed
// len(unique), unique is returned unchanged.
func syntheticMemoryContent(unique string, size int) string {
	if len(unique) >= size {
		return unique
	}

	var sb strings.Builder
	sb.Grow(size)
	sb.WriteString(unique)

	seq := 0
	for sb.Len() < size {
		_, _ = fmt.Fprintf(&sb, " token_%s_%04d", unique, seq)
		seq++
	}

	// The last fragment usually overshoots; trim back to the requested length.
	padded := sb.String()
	return padded[:size]
}
209+
210+
func envInt(t *testing.T, name string, fallback int) int {
211+
t.Helper()
212+
raw := strings.TrimSpace(os.Getenv(name))
213+
if raw == "" {
214+
return fallback
215+
}
216+
value, err := strconv.Atoi(raw)
217+
require.NoError(t, err)
218+
require.Positive(t, value)
219+
return value
220+
}

internal/share/share.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,11 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
252252
func applyImportPragmas(ctx context.Context, db *sql.DB) (func(context.Context) error, error) {
253253
// Snapshot imports touch most of the archive. Keep SQLite's crash recovery
254254
// enabled; journal_mode=off can leave the live DB malformed if the process
255-
// or host dies mid-import.
255+
// or host dies mid-import. Keep temporary storage file-backed and bound the
256+
// page cache so large imports and FTS rebuilds do not exhaust small hosts.
256257
for _, stmt := range []string{
257-
`pragma temp_store = memory`,
258-
`pragma cache_size = -262144`,
258+
`pragma temp_store = file`,
259+
`pragma cache_size = -32768`,
259260
`pragma journal_mode = wal`,
260261
`pragma synchronous = normal`,
261262
} {

internal/share/share_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func TestImportIfChangedInfersLegacyManifestFilesFromGit(t *testing.T) {
311311
require.Len(t, results, 1)
312312
}
313313

314-
func TestApplyImportPragmasKeepCrashRecoveryEnabled(t *testing.T) {
314+
func TestApplyImportPragmasBoundImportMemory(t *testing.T) {
315315
ctx := context.Background()
316316
s := seedStore(t, filepath.Join(t.TempDir(), "dst.db"))
317317
defer func() { _ = s.Close() }()
@@ -320,6 +320,14 @@ func TestApplyImportPragmasKeepCrashRecoveryEnabled(t *testing.T) {
320320
require.NoError(t, err)
321321
defer func() { require.NoError(t, restore(ctx)) }()
322322

323+
var tempStore int
324+
require.NoError(t, s.DB().QueryRowContext(ctx, `pragma temp_store`).Scan(&tempStore))
325+
require.Equal(t, 1, tempStore, "snapshot imports should use file-backed temporary storage")
326+
327+
var cacheSize int
328+
require.NoError(t, s.DB().QueryRowContext(ctx, `pragma cache_size`).Scan(&cacheSize))
329+
require.GreaterOrEqual(t, cacheSize, -65536)
330+
323331
var journalMode string
324332
require.NoError(t, s.DB().QueryRowContext(ctx, `pragma journal_mode`).Scan(&journalMode))
325333
require.NotEqual(t, "off", strings.ToLower(journalMode))

0 commit comments

Comments
 (0)