|
| 1 | +package share |
| 2 | + |
| 3 | +import ( |
| 4 | + "compress/gzip" |
| 5 | + "context" |
| 6 | + "encoding/json" |
| 7 | + "fmt" |
| 8 | + "os" |
| 9 | + "path/filepath" |
| 10 | + "strconv" |
| 11 | + "strings" |
| 12 | + "testing" |
| 13 | + "time" |
| 14 | + |
| 15 | + "github.com/stretchr/testify/require" |
| 16 | + |
| 17 | + "github.com/openclaw/discrawl/internal/store" |
| 18 | +) |
| 19 | + |
| 20 | +func TestImportRealSnapshot(t *testing.T) { |
| 21 | + repo := strings.TrimSpace(os.Getenv("DISCRAWL_REAL_REPO")) |
| 22 | + if repo == "" { |
| 23 | + t.Skip("set DISCRAWL_REAL_REPO to run real snapshot import validation") |
| 24 | + } |
| 25 | + |
| 26 | + ctx := context.Background() |
| 27 | + dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db")) |
| 28 | + require.NoError(t, err) |
| 29 | + defer func() { _ = dst.Close() }() |
| 30 | + |
| 31 | + _, changed, err := ImportIfChanged(ctx, dst, Options{ |
| 32 | + RepoPath: repo, |
| 33 | + Branch: "main", |
| 34 | + Progress: func(p ImportProgress) { |
| 35 | + if p.Phase == "start" || p.Phase == "rebuild_fts" || p.Phase == "done" { |
| 36 | + t.Logf("import progress phase=%s total_rows=%d", p.Phase, p.TotalRows) |
| 37 | + } |
| 38 | + }, |
| 39 | + }) |
| 40 | + require.NoError(t, err) |
| 41 | + require.True(t, changed) |
| 42 | + |
| 43 | + var messageCount int |
| 44 | + require.NoError(t, dst.DB().QueryRowContext(ctx, `select count(*) from messages`).Scan(&messageCount)) |
| 45 | + require.Positive(t, messageCount) |
| 46 | + var ftsCount int |
| 47 | + require.NoError(t, dst.DB().QueryRowContext(ctx, `select count(*) from message_fts`).Scan(&ftsCount)) |
| 48 | + require.Equal(t, messageCount, ftsCount) |
| 49 | +} |
| 50 | + |
| 51 | +func TestImportMemoryBounded(t *testing.T) { |
| 52 | + if os.Getenv("DISCRAWL_OOM_REGRESSION") != "1" { |
| 53 | + t.Skip("set DISCRAWL_OOM_REGRESSION=1 to run the memory-bounded import regression") |
| 54 | + } |
| 55 | + |
| 56 | + ctx := context.Background() |
| 57 | + repo := t.TempDir() |
| 58 | + messageRows := envInt(t, "DISCRAWL_OOM_ROWS", 80000) |
| 59 | + textBytes := envInt(t, "DISCRAWL_OOM_TEXT_BYTES", 2048) |
| 60 | + t.Logf("building synthetic snapshot rows=%d text_bytes=%d", messageRows, textBytes) |
| 61 | + writeSyntheticMemorySnapshot(t, repo, messageRows, textBytes) |
| 62 | + t.Log("synthetic snapshot built; starting import") |
| 63 | + |
| 64 | + dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db")) |
| 65 | + require.NoError(t, err) |
| 66 | + defer func() { _ = dst.Close() }() |
| 67 | + |
| 68 | + var progress []ImportProgress |
| 69 | + _, changed, err := ImportIfChanged(ctx, dst, Options{ |
| 70 | + RepoPath: repo, |
| 71 | + Branch: "main", |
| 72 | + Progress: func(p ImportProgress) { progress = append(progress, p) }, |
| 73 | + }) |
| 74 | + require.NoError(t, err) |
| 75 | + require.True(t, changed) |
| 76 | + require.Contains(t, progressPhases(progress), "rebuild_fts") |
| 77 | + |
| 78 | + needle := "oomunique000042" |
| 79 | + results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: needle, Limit: 10}) |
| 80 | + require.NoError(t, err) |
| 81 | + require.Len(t, results, 1) |
| 82 | + require.Contains(t, results[0].Content, needle) |
| 83 | +} |
| 84 | + |
| 85 | +func writeSyntheticMemorySnapshot(t *testing.T, repo string, messageRows, textBytes int) { |
| 86 | + t.Helper() |
| 87 | + generatedAt := time.Now().UTC() |
| 88 | + updatedAt := generatedAt.Format(time.RFC3339Nano) |
| 89 | + guildFile := "tables/guilds/000000.jsonl.gz" |
| 90 | + channelFile := "tables/channels/000000.jsonl.gz" |
| 91 | + memberFile := "tables/members/000000.jsonl.gz" |
| 92 | + messageFile := "tables/messages/000000.jsonl.gz" |
| 93 | + |
| 94 | + writeJSONLGzip(t, repo, guildFile, func(enc *json.Encoder) int { |
| 95 | + require.NoError(t, enc.Encode(map[string]any{ |
| 96 | + "id": "g1", |
| 97 | + "name": "Guild", |
| 98 | + "icon": "", |
| 99 | + "raw_json": `{}`, |
| 100 | + "updated_at": updatedAt, |
| 101 | + })) |
| 102 | + return 1 |
| 103 | + }) |
| 104 | + writeJSONLGzip(t, repo, channelFile, func(enc *json.Encoder) int { |
| 105 | + require.NoError(t, enc.Encode(map[string]any{ |
| 106 | + "id": "c1", |
| 107 | + "guild_id": "g1", |
| 108 | + "parent_id": "", |
| 109 | + "kind": "text", |
| 110 | + "name": "general", |
| 111 | + "topic": "", |
| 112 | + "position": 0, |
| 113 | + "is_nsfw": false, |
| 114 | + "is_archived": false, |
| 115 | + "is_locked": false, |
| 116 | + "is_private_thread": false, |
| 117 | + "thread_parent_id": "", |
| 118 | + "archive_timestamp": "", |
| 119 | + "raw_json": `{}`, |
| 120 | + "updated_at": updatedAt, |
| 121 | + })) |
| 122 | + return 1 |
| 123 | + }) |
| 124 | + writeJSONLGzip(t, repo, memberFile, func(enc *json.Encoder) int { |
| 125 | + require.NoError(t, enc.Encode(map[string]any{ |
| 126 | + "guild_id": "g1", |
| 127 | + "user_id": "u1", |
| 128 | + "username": "peter", |
| 129 | + "global_name": "", |
| 130 | + "display_name": "Peter", |
| 131 | + "nick": "", |
| 132 | + "discriminator": "", |
| 133 | + "avatar": "", |
| 134 | + "bot": false, |
| 135 | + "joined_at": "", |
| 136 | + "role_ids_json": `[]`, |
| 137 | + "raw_json": `{"bio":"memory regression profile"}`, |
| 138 | + "updated_at": updatedAt, |
| 139 | + })) |
| 140 | + return 1 |
| 141 | + }) |
| 142 | + writeJSONLGzip(t, repo, messageFile, func(enc *json.Encoder) int { |
| 143 | + for i := range messageRows { |
| 144 | + messageID := strconv.FormatInt(1456744319972282449+int64(i), 10) |
| 145 | + unique := fmt.Sprintf("oomunique%06d", i) |
| 146 | + content := syntheticMemoryContent(unique, textBytes) |
| 147 | + require.NoError(t, enc.Encode(map[string]any{ |
| 148 | + "id": messageID, |
| 149 | + "guild_id": "g1", |
| 150 | + "channel_id": "c1", |
| 151 | + "author_id": "u1", |
| 152 | + "message_type": 0, |
| 153 | + "created_at": updatedAt, |
| 154 | + "edited_at": "", |
| 155 | + "deleted_at": "", |
| 156 | + "content": content, |
| 157 | + "normalized_content": content, |
| 158 | + "reply_to_message_id": "", |
| 159 | + "pinned": false, |
| 160 | + "has_attachments": false, |
| 161 | + "raw_json": `{"author":{"username":"Peter"}}`, |
| 162 | + "updated_at": updatedAt, |
| 163 | + })) |
| 164 | + } |
| 165 | + return messageRows |
| 166 | + }) |
| 167 | + |
| 168 | + manifest := Manifest{ |
| 169 | + Version: 1, |
| 170 | + GeneratedAt: generatedAt, |
| 171 | + Tables: []TableManifest{ |
| 172 | + {Name: "guilds", Files: []string{guildFile}, Columns: []string{"id", "name", "icon", "raw_json", "updated_at"}, Rows: 1}, |
| 173 | + {Name: "channels", Files: []string{channelFile}, Columns: []string{"id", "guild_id", "parent_id", "kind", "name", "topic", "position", "is_nsfw", "is_archived", "is_locked", "is_private_thread", "thread_parent_id", "archive_timestamp", "raw_json", "updated_at"}, Rows: 1}, |
| 174 | + {Name: "members", Files: []string{memberFile}, Columns: []string{"guild_id", "user_id", "username", "global_name", "display_name", "nick", "discriminator", "avatar", "bot", "joined_at", "role_ids_json", "raw_json", "updated_at"}, Rows: 1}, |
| 175 | + {Name: "messages", Files: []string{messageFile}, Columns: []string{"id", "guild_id", "channel_id", "author_id", "message_type", "created_at", "edited_at", "deleted_at", "content", "normalized_content", "reply_to_message_id", "pinned", "has_attachments", "raw_json", "updated_at"}, Rows: messageRows}, |
| 176 | + }, |
| 177 | + } |
| 178 | + body, err := json.MarshalIndent(manifest, "", " ") |
| 179 | + require.NoError(t, err) |
| 180 | + require.NoError(t, os.WriteFile(filepath.Join(repo, ManifestName), append(body, '\n'), 0o600)) |
| 181 | +} |
| 182 | + |
| 183 | +func writeJSONLGzip(t *testing.T, repo, rel string, writeRows func(*json.Encoder) int) { |
| 184 | + t.Helper() |
| 185 | + path := filepath.Join(repo, filepath.FromSlash(rel)) |
| 186 | + require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755)) |
| 187 | + file, err := os.Create(path) |
| 188 | + require.NoError(t, err) |
| 189 | + gz := gzip.NewWriter(file) |
| 190 | + enc := json.NewEncoder(gz) |
| 191 | + rows := writeRows(enc) |
| 192 | + require.Positive(t, rows) |
| 193 | + require.NoError(t, gz.Close()) |
| 194 | + require.NoError(t, file.Close()) |
| 195 | +} |
| 196 | + |
| 197 | +func syntheticMemoryContent(unique string, size int) string { |
| 198 | + if size <= len(unique) { |
| 199 | + return unique |
| 200 | + } |
| 201 | + var b strings.Builder |
| 202 | + b.Grow(size) |
| 203 | + b.WriteString(unique) |
| 204 | + for i := 0; b.Len() < size; i++ { |
| 205 | + _, _ = fmt.Fprintf(&b, " token_%s_%04d", unique, i) |
| 206 | + } |
| 207 | + return b.String()[:size] |
| 208 | +} |
| 209 | + |
| 210 | +func envInt(t *testing.T, name string, fallback int) int { |
| 211 | + t.Helper() |
| 212 | + raw := strings.TrimSpace(os.Getenv(name)) |
| 213 | + if raw == "" { |
| 214 | + return fallback |
| 215 | + } |
| 216 | + value, err := strconv.Atoi(raw) |
| 217 | + require.NoError(t, err) |
| 218 | + require.Positive(t, value) |
| 219 | + return value |
| 220 | +} |
0 commit comments