From e63c35206551f1ae7cd0043f4cc79aa13d7977d0 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Tue, 24 Feb 2026 17:10:43 -0600 Subject: [PATCH 1/4] fix(restore): add post-restore integrity validation Add IntegrityCheckMode type with None/Quick/Full modes and a checkIntegrity() function that runs PRAGMA quick_check or integrity_check on restored databases. EnsureExists() now defaults to IntegrityCheckQuick so K8s init containers catch corrupt restores before the application starts. Adds -integrity-check CLI flag to the restore command for manual use. Closes #1164 --- cmd/litestream/restore.go | 17 ++++++++ db.go | 14 +++++++ replica.go | 68 ++++++++++++++++++++++++++++++++ replica_internal_test.go | 83 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 182 insertions(+) diff --git a/cmd/litestream/restore.go b/cmd/litestream/restore.go index d6ac7e2af..71d02f105 100644 --- a/cmd/litestream/restore.go +++ b/cmd/litestream/restore.go @@ -30,6 +30,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { timestampStr := fs.String("timestamp", "", "timestamp") fs.BoolVar(&opt.Follow, "f", false, "follow mode") fs.DurationVar(&opt.FollowInterval, "follow-interval", opt.FollowInterval, "polling interval for follow mode") + integrityCheck := fs.String("integrity-check", "none", "post-restore integrity check: none, quick, or full") fs.Usage = c.Usage if err := fs.Parse(args); err != nil { return err @@ -57,6 +58,17 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { defer cancel() } + switch *integrityCheck { + case "none": + opt.IntegrityCheck = litestream.IntegrityCheckNone + case "quick": + opt.IntegrityCheck = litestream.IntegrityCheckQuick + case "full": + opt.IntegrityCheck = litestream.IntegrityCheckFull + default: + return fmt.Errorf("invalid -integrity-check value: %s", *integrityCheck) + } + // Parse timestamp, if specified. if *timestampStr != "" { if opt.Timestamp, err = time.Parse(time.RFC3339, *timestampStr); err != nil { @@ -211,6 +223,11 @@ Arguments: Determines the number of WAL files downloaded in parallel. Defaults to `+strconv.Itoa(litestream.DefaultRestoreParallelism)+`. + -integrity-check MODE + Run a post-restore integrity check on the database. + MODE is one of: none, quick, full. + Defaults to none. + Examples: diff --git a/db.go b/db.go index 968ffe63f..2c3ed78d4 100644 --- a/db.go +++ b/db.go @@ -500,6 +500,7 @@ func (db *DB) EnsureExists(ctx context.Context) error { opt := NewRestoreOptions() opt.OutputPath = db.Path() + opt.IntegrityCheck = IntegrityCheckQuick if err := db.Replica.Restore(ctx, opt); err != nil { if errors.Is(err, ErrTxNotAvailable) || errors.Is(err, ErrNoSnapshots) { @@ -2367,6 +2368,15 @@ const DefaultRestoreParallelism = 8 // DefaultFollowInterval is the default polling interval for follow mode. const DefaultFollowInterval = 1 * time.Second +// IntegrityCheckMode specifies the level of integrity checking after restore. +type IntegrityCheckMode int + +const ( + IntegrityCheckNone IntegrityCheckMode = iota + IntegrityCheckQuick + IntegrityCheckFull +) + // RestoreOptions represents options for DB.Restore(). type RestoreOptions struct { // Target path to restore into. @@ -2390,6 +2400,10 @@ type RestoreOptions struct { // FollowInterval specifies how often to poll for new LTX files in follow mode. FollowInterval time.Duration + + // IntegrityCheck specifies the level of integrity checking after restore. + // Zero value (IntegrityCheckNone) skips the check for backward compatibility. + IntegrityCheck IntegrityCheckMode } // NewRestoreOptions returns a new instance of RestoreOptions with defaults. diff --git a/replica.go b/replica.go index 1c028d302..73d0c00ae 100644 --- a/replica.go +++ b/replica.go @@ -680,6 +680,16 @@ func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) { return err } + if opt.IntegrityCheck != IntegrityCheckNone { + if err := checkIntegrity(opt.OutputPath, opt.IntegrityCheck); err != nil { + _ = os.Remove(opt.OutputPath) + _ = os.Remove(opt.OutputPath + "-shm") + _ = os.Remove(opt.OutputPath + "-wal") + return fmt.Errorf("post-restore integrity check: %w", err) + } + r.Logger().Info("post-restore integrity check passed") + } + // Enter follow mode if enabled, continuously applying new LTX files. if opt.Follow { for _, rd := range rdrs { @@ -1053,6 +1063,16 @@ func (r *Replica) RestoreV3(ctx context.Context, opt RestoreOptions) error { return fmt.Errorf("rename to output path: %w", err) } + if opt.IntegrityCheck != IntegrityCheckNone { + if err := checkIntegrity(opt.OutputPath, opt.IntegrityCheck); err != nil { + _ = os.Remove(opt.OutputPath) + _ = os.Remove(opt.OutputPath + "-shm") + _ = os.Remove(opt.OutputPath + "-wal") + return fmt.Errorf("post-restore integrity check: %w", err) + } + r.Logger().Info("post-restore integrity check passed") + } + return nil } @@ -1177,6 +1197,54 @@ func checkpointV3(dbPath string) error { return err } +// checkIntegrity runs a SQLite integrity check on the database at dbPath. +func checkIntegrity(dbPath string, mode IntegrityCheckMode) error { + if mode == IntegrityCheckNone { + return nil + } + + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return fmt.Errorf("open database for integrity check: %w", err) + } + defer func() { _ = db.Close() }() + + pragma := "quick_check" + if mode == IntegrityCheckFull { + pragma = "integrity_check" + } + + rows, err := db.Query("PRAGMA " + pragma) + if err != nil { + return fmt.Errorf("integrity check: %w", err) + } + defer func() { _ = rows.Close() }() + + var results []string + for rows.Next() { + var result string + if err := rows.Scan(&result); err != nil { + return fmt.Errorf("scan integrity check result: %w", err) + } + if result != "ok" { + results = append(results, result) + } + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate integrity check results: %w", err) + } + + if len(results) > 0 { + return fmt.Errorf("integrity check failed: %s", results[0]) + } + + // Clean up -shm and -wal files that SQLite may create during the PRAGMA. + _ = os.Remove(dbPath + "-shm") + _ = os.Remove(dbPath + "-wal") + + return nil +} + // findBestV3SnapshotForTimestamp returns the best v0.3.x snapshot for the given timestamp. // Returns nil if no suitable snapshot exists. func (r *Replica) findBestV3SnapshotForTimestamp(ctx context.Context, client ReplicaClientV3, timestamp time.Time) (*SnapshotInfoV3, error) { diff --git a/replica_internal_test.go b/replica_internal_test.go index 9ef6f5e20..f2db3696c 100644 --- a/replica_internal_test.go +++ b/replica_internal_test.go @@ -3,6 +3,7 @@ package litestream import ( "bytes" "context" + "database/sql" "fmt" "io" "os" @@ -11,6 +12,8 @@ import ( "time" "github.com/superfly/ltx" + + _ "modernc.org/sqlite" ) func TestReplica_ApplyNewLTXFiles_FillGapWithOverlappingCompactedFile(t *testing.T) { @@ -277,3 +280,83 @@ func mustCreateWritableDBFile(tb testing.TB) *os.File { func ltxFixtureKey(level int, minTXID, maxTXID ltx.TXID) string { return fmt.Sprintf("%d:%s:%s", level, minTXID, maxTXID) } + +func mustCreateValidSQLiteDB(tb testing.TB) string { + tb.Helper() + dbPath := filepath.Join(tb.TempDir(), "test.db") + db, err := sql.Open("sqlite", dbPath) + if err != nil { + tb.Fatal(err) + } + defer func() { _ = db.Close() }() + if _, err := db.Exec("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)"); err != nil { + tb.Fatal(err) + } + if _, err := db.Exec("INSERT INTO t (name) VALUES ('a'), ('b'), ('c')"); err != nil { + tb.Fatal(err) + } + if _, err := db.Exec("CREATE INDEX idx_t_name ON t(name)"); err != nil { + tb.Fatal(err) + } + return dbPath +} + +func TestCheckIntegrity_Quick_ValidDB(t *testing.T) { + dbPath := mustCreateValidSQLiteDB(t) + if err := checkIntegrity(dbPath, IntegrityCheckQuick); err != nil { + t.Fatalf("expected no error, got: %v", err) + } +} + +func TestCheckIntegrity_Full_ValidDB(t *testing.T) { + dbPath := mustCreateValidSQLiteDB(t) + if err := checkIntegrity(dbPath, IntegrityCheckFull); err != nil { + t.Fatalf("expected no error, got: %v", err) + } +} + +func TestCheckIntegrity_None_Skips(t *testing.T) { + if err := checkIntegrity("/nonexistent/path.db", IntegrityCheckNone); err != nil { + t.Fatalf("expected nil for IntegrityCheckNone, got: %v", err) + } +} + +func TestCheckIntegrity_CorruptDB(t *testing.T) { + dbPath := mustCreateValidSQLiteDB(t) + + // Remove any WAL/SHM files so we have a clean single-file database. + _ = os.Remove(dbPath + "-wal") + _ = os.Remove(dbPath + "-shm") + + // Read the page size from the database header (bytes 16-17, big-endian). + f, err := os.OpenFile(dbPath, os.O_RDWR, 0o600) + if err != nil { + t.Fatal(err) + } + + // Corrupt page 2 onwards. Page 1 is the header/schema page. Corrupting + // pages that contain table/index data triggers integrity check failures. + // We overwrite from byte offset 4096 (start of page 2 for 4096-byte pages, + // which is the default) with garbage data. + info, err := f.Stat() + if err != nil { + _ = f.Close() + t.Fatal(err) + } + + // Overwrite everything after the first page with garbage to ensure corruption. + pageSize := int64(4096) + if info.Size() > pageSize { + garbage := bytes.Repeat([]byte{0xDE}, int(info.Size()-pageSize)) + if _, err := f.WriteAt(garbage, pageSize); err != nil { + _ = f.Close() + t.Fatal(err) + } + } + _ = f.Close() + + err = checkIntegrity(dbPath, IntegrityCheckFull) + if err == nil { + t.Fatal("expected integrity check to fail on corrupt database") + } +} From 9e40321e6cbaaafe4008e7425a7135d1452d35f9 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Thu, 26 Feb 2026 16:59:07 -0600 Subject: [PATCH 2/4] fix(restore): accept context in checkIntegrity and reject invalid modes - Add context.Context parameter to checkIntegrity() so cancellation/timeout is respected during long-running PRAGMAs - Use explicit switch on IntegrityCheckMode to reject unsupported values instead of silently falling through to quick_check --- replica.go | 17 +++++++++++------ replica_internal_test.go | 8 ++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/replica.go b/replica.go index 73d0c00ae..f178cbe20 100644 --- a/replica.go +++ b/replica.go @@ -681,7 +681,7 @@ func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) { } if opt.IntegrityCheck != IntegrityCheckNone { - if err := checkIntegrity(opt.OutputPath, opt.IntegrityCheck); err != nil { + if err := checkIntegrity(ctx, opt.OutputPath, opt.IntegrityCheck); err != nil { _ = os.Remove(opt.OutputPath) _ = os.Remove(opt.OutputPath + "-shm") _ = os.Remove(opt.OutputPath + "-wal") @@ -1064,7 +1064,7 @@ func (r *Replica) RestoreV3(ctx context.Context, opt RestoreOptions) error { } if opt.IntegrityCheck != IntegrityCheckNone { - if err := checkIntegrity(opt.OutputPath, opt.IntegrityCheck); err != nil { + if err := checkIntegrity(ctx, opt.OutputPath, opt.IntegrityCheck); err != nil { _ = os.Remove(opt.OutputPath) _ = os.Remove(opt.OutputPath + "-shm") _ = os.Remove(opt.OutputPath + "-wal") @@ -1198,7 +1198,7 @@ func checkpointV3(dbPath string) error { } // checkIntegrity runs a SQLite integrity check on the database at dbPath. -func checkIntegrity(dbPath string, mode IntegrityCheckMode) error { +func checkIntegrity(ctx context.Context, dbPath string, mode IntegrityCheckMode) error { if mode == IntegrityCheckNone { return nil } @@ -1209,12 +1209,17 @@ func checkIntegrity(dbPath string, mode IntegrityCheckMode) error { } defer func() { _ = db.Close() }() - pragma := "quick_check" - if mode == IntegrityCheckFull { + var pragma string + switch mode { + case IntegrityCheckQuick: + pragma = "quick_check" + case IntegrityCheckFull: pragma = "integrity_check" + default: + return fmt.Errorf("unsupported integrity check mode: %d", mode) } - rows, err := db.Query("PRAGMA " + pragma) + rows, err := db.QueryContext(ctx, "PRAGMA "+pragma) if err != nil { return fmt.Errorf("integrity check: %w", err) } diff --git a/replica_internal_test.go b/replica_internal_test.go index f2db3696c..ba9a0f208 100644 --- a/replica_internal_test.go +++ b/replica_internal_test.go @@ -303,20 +303,20 @@ func mustCreateValidSQLiteDB(tb testing.TB) string { func TestCheckIntegrity_Quick_ValidDB(t *testing.T) { dbPath := mustCreateValidSQLiteDB(t) - if err := checkIntegrity(dbPath, IntegrityCheckQuick); err != nil { + if err := checkIntegrity(context.Background(), dbPath, IntegrityCheckQuick); err != nil { t.Fatalf("expected no error, got: %v", err) } } func TestCheckIntegrity_Full_ValidDB(t *testing.T) { dbPath := mustCreateValidSQLiteDB(t) - if err := checkIntegrity(dbPath, IntegrityCheckFull); err != nil { + if err := checkIntegrity(context.Background(), dbPath, IntegrityCheckFull); err != nil { t.Fatalf("expected no error, got: %v", err) } } func TestCheckIntegrity_None_Skips(t *testing.T) { - if err := checkIntegrity("/nonexistent/path.db", IntegrityCheckNone); err != nil { + if err := checkIntegrity(context.Background(), "/nonexistent/path.db", IntegrityCheckNone); err != nil { t.Fatalf("expected nil for IntegrityCheckNone, got: %v", err) } } @@ -355,7 +355,7 @@ func TestCheckIntegrity_CorruptDB(t *testing.T) { } _ = f.Close() - err = checkIntegrity(dbPath, IntegrityCheckFull) + err = checkIntegrity(context.Background(), dbPath, IntegrityCheckFull) if err == nil { t.Fatal("expected integrity check to fail on corrupt database") } From c0720ad96a04ebe8adb505dfc8a46d04d9b5ba78 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Sat, 7 Mar 2026 16:46:26 -0600 Subject: [PATCH 3/4] refactor(restore): simplify integrity check to use QueryRowContext Use QueryRowContext() instead of QueryContext() with row iteration since the integrity check returns a single result row. --- replica.go | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/replica.go b/replica.go index f178cbe20..bfd3d3da8 100644 --- a/replica.go +++ b/replica.go @@ -1219,28 +1219,12 @@ func checkIntegrity(ctx context.Context, dbPath string, mode IntegrityCheckMode) return fmt.Errorf("unsupported integrity check mode: %d", mode) } - rows, err := db.QueryContext(ctx, "PRAGMA "+pragma) - if err != nil { + var result string + if err := db.QueryRowContext(ctx, "PRAGMA "+pragma).Scan(&result); err != nil { return fmt.Errorf("integrity check: %w", err) } - defer func() { _ = rows.Close() }() - - var results []string - for rows.Next() { - var result string - if err := rows.Scan(&result); err != nil { - return fmt.Errorf("scan integrity check result: %w", err) - } - if result != "ok" { - results = append(results, result) - } - } - if err := rows.Err(); err != nil { - return fmt.Errorf("iterate integrity check results: %w", err) - } - - if len(results) > 0 { - return fmt.Errorf("integrity check failed: %s", results[0]) + if result != "ok" { + return fmt.Errorf("integrity check failed: %s", result) } // Clean up -shm and -wal files that SQLite may create during the PRAGMA. From d9ed7f2c526a21589a6a05001c9976689141ffc5 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Sat, 7 Mar 2026 19:02:34 -0600 Subject: [PATCH 4/4] fix(restore): validate integrity mode upfront and preserve DB on cancellation - Validate IntegrityCheckMode in Restore() and RestoreV3() before doing any restore work, preventing invalid modes from causing unnecessary restore followed by deletion - Only delete restored DB on actual integrity failures, not on context cancellation or timeout, preserving valid restores interrupted by Ctrl+C or deadline --- replica.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/replica.go b/replica.go index bfd3d3da8..e00589a76 100644 --- a/replica.go +++ b/replica.go @@ -529,6 +529,8 @@ func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) { return fmt.Errorf("cannot use follow mode with -txid") } else if opt.Follow && !opt.Timestamp.IsZero() { return fmt.Errorf("cannot use follow mode with -timestamp") + } else if opt.IntegrityCheck != IntegrityCheckNone && opt.IntegrityCheck != IntegrityCheckQuick && opt.IntegrityCheck != IntegrityCheckFull { + return fmt.Errorf("unsupported integrity check mode: %d", opt.IntegrityCheck) } // In follow mode, if the database already exists, attempt crash recovery @@ -682,9 +684,11 @@ func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) { if opt.IntegrityCheck != IntegrityCheckNone { if err := checkIntegrity(ctx, opt.OutputPath, opt.IntegrityCheck); err != nil { - _ = os.Remove(opt.OutputPath) - _ = os.Remove(opt.OutputPath + "-shm") - _ = os.Remove(opt.OutputPath + "-wal") + if ctx.Err() == nil { + _ = os.Remove(opt.OutputPath) + _ = os.Remove(opt.OutputPath + "-shm") + _ = os.Remove(opt.OutputPath + "-wal") + } return fmt.Errorf("post-restore integrity check: %w", err) } r.Logger().Info("post-restore integrity check passed") @@ -981,6 +985,8 @@ func (r *Replica) RestoreV3(ctx context.Context, opt RestoreOptions) error { // Validate options. if opt.OutputPath == "" { return fmt.Errorf("output path required") + } else if opt.IntegrityCheck != IntegrityCheckNone && opt.IntegrityCheck != IntegrityCheckQuick && opt.IntegrityCheck != IntegrityCheckFull { + return fmt.Errorf("unsupported integrity check mode: %d", opt.IntegrityCheck) } // Ensure output path does not already exist. @@ -1065,9 +1071,11 @@ func (r *Replica) RestoreV3(ctx context.Context, opt RestoreOptions) error { if opt.IntegrityCheck != IntegrityCheckNone { if err := checkIntegrity(ctx, opt.OutputPath, opt.IntegrityCheck); err != nil { - _ = os.Remove(opt.OutputPath) - _ = os.Remove(opt.OutputPath + "-shm") - _ = os.Remove(opt.OutputPath + "-wal") + if ctx.Err() == nil { + _ = os.Remove(opt.OutputPath) + _ = os.Remove(opt.OutputPath + "-shm") + _ = os.Remove(opt.OutputPath + "-wal") + } return fmt.Errorf("post-restore integrity check: %w", err) } r.Logger().Info("post-restore integrity check passed")