Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions cmd/litestream/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
timestampStr := fs.String("timestamp", "", "timestamp")
fs.BoolVar(&opt.Follow, "f", false, "follow mode")
fs.DurationVar(&opt.FollowInterval, "follow-interval", opt.FollowInterval, "polling interval for follow mode")
integrityCheck := fs.String("integrity-check", "none", "post-restore integrity check: none, quick, or full")
fs.Usage = c.Usage
if err := fs.Parse(args); err != nil {
return err
Expand Down Expand Up @@ -57,6 +58,17 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
defer cancel()
}

switch *integrityCheck {
case "none":
opt.IntegrityCheck = litestream.IntegrityCheckNone
case "quick":
opt.IntegrityCheck = litestream.IntegrityCheckQuick
case "full":
opt.IntegrityCheck = litestream.IntegrityCheckFull
default:
return fmt.Errorf("invalid -integrity-check value: %s", *integrityCheck)
}

// Parse timestamp, if specified.
if *timestampStr != "" {
if opt.Timestamp, err = time.Parse(time.RFC3339, *timestampStr); err != nil {
Expand Down Expand Up @@ -211,6 +223,11 @@ Arguments:
Determines the number of WAL files downloaded in parallel.
Defaults to `+strconv.Itoa(litestream.DefaultRestoreParallelism)+`.

-integrity-check MODE
Run a post-restore integrity check on the database.
MODE is one of: none, quick, full.
Defaults to none.


Examples:

Expand Down
14 changes: 14 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ func (db *DB) EnsureExists(ctx context.Context) error {

opt := NewRestoreOptions()
opt.OutputPath = db.Path()
opt.IntegrityCheck = IntegrityCheckQuick

if err := db.Replica.Restore(ctx, opt); err != nil {
if errors.Is(err, ErrTxNotAvailable) || errors.Is(err, ErrNoSnapshots) {
Expand Down Expand Up @@ -2367,6 +2368,15 @@ const DefaultRestoreParallelism = 8
// DefaultFollowInterval is the default polling interval for follow mode:
// one second.
const DefaultFollowInterval = time.Second

// IntegrityCheckMode selects which integrity check, if any, is run against a
// database after it has been restored.
type IntegrityCheckMode int

// Integrity check modes, in increasing order of thoroughness.
const (
	// IntegrityCheckNone skips the check entirely. It is the zero value.
	IntegrityCheckNone IntegrityCheckMode = iota

	// IntegrityCheckQuick runs "PRAGMA quick_check".
	IntegrityCheckQuick

	// IntegrityCheckFull runs "PRAGMA integrity_check".
	IntegrityCheckFull
)

// RestoreOptions represents options for DB.Restore().
type RestoreOptions struct {
// Target path to restore into.
Expand All @@ -2390,6 +2400,10 @@ type RestoreOptions struct {

// FollowInterval specifies how often to poll for new LTX files in follow mode.
FollowInterval time.Duration

// IntegrityCheck specifies the level of integrity checking after restore.
// Zero value (IntegrityCheckNone) skips the check for backward compatibility.
IntegrityCheck IntegrityCheckMode
}

// NewRestoreOptions returns a new instance of RestoreOptions with defaults.
Expand Down
65 changes: 65 additions & 0 deletions replica.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) {
return fmt.Errorf("cannot use follow mode with -txid")
} else if opt.Follow && !opt.Timestamp.IsZero() {
return fmt.Errorf("cannot use follow mode with -timestamp")
} else if opt.IntegrityCheck != IntegrityCheckNone && opt.IntegrityCheck != IntegrityCheckQuick && opt.IntegrityCheck != IntegrityCheckFull {
return fmt.Errorf("unsupported integrity check mode: %d", opt.IntegrityCheck)
}

// In follow mode, if the database already exists, attempt crash recovery
Expand Down Expand Up @@ -680,6 +682,18 @@ func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) {
return err
}

if opt.IntegrityCheck != IntegrityCheckNone {
if err := checkIntegrity(ctx, opt.OutputPath, opt.IntegrityCheck); err != nil {
if ctx.Err() == nil {
_ = os.Remove(opt.OutputPath)
_ = os.Remove(opt.OutputPath + "-shm")
_ = os.Remove(opt.OutputPath + "-wal")
}
return fmt.Errorf("post-restore integrity check: %w", err)
}
r.Logger().Info("post-restore integrity check passed")
}

// Enter follow mode if enabled, continuously applying new LTX files.
if opt.Follow {
for _, rd := range rdrs {
Expand Down Expand Up @@ -971,6 +985,8 @@ func (r *Replica) RestoreV3(ctx context.Context, opt RestoreOptions) error {
// Validate options.
if opt.OutputPath == "" {
return fmt.Errorf("output path required")
} else if opt.IntegrityCheck != IntegrityCheckNone && opt.IntegrityCheck != IntegrityCheckQuick && opt.IntegrityCheck != IntegrityCheckFull {
return fmt.Errorf("unsupported integrity check mode: %d", opt.IntegrityCheck)
}

// Ensure output path does not already exist.
Expand Down Expand Up @@ -1053,6 +1069,18 @@ func (r *Replica) RestoreV3(ctx context.Context, opt RestoreOptions) error {
return fmt.Errorf("rename to output path: %w", err)
}

if opt.IntegrityCheck != IntegrityCheckNone {
if err := checkIntegrity(ctx, opt.OutputPath, opt.IntegrityCheck); err != nil {
if ctx.Err() == nil {
_ = os.Remove(opt.OutputPath)
_ = os.Remove(opt.OutputPath + "-shm")
_ = os.Remove(opt.OutputPath + "-wal")
}
return fmt.Errorf("post-restore integrity check: %w", err)
}
r.Logger().Info("post-restore integrity check passed")
}

return nil
}

Expand Down Expand Up @@ -1177,6 +1205,43 @@ func checkpointV3(dbPath string) error {
return err
}

// checkIntegrity runs a SQLite integrity check on the database at dbPath.
//
// mode selects the pragma: IntegrityCheckQuick runs "PRAGMA quick_check" and
// IntegrityCheckFull runs "PRAGMA integrity_check". IntegrityCheckNone returns
// nil without touching dbPath. Any other mode is rejected before the database
// is opened.
//
// The check passes only when the first result row of the pragma is "ok";
// otherwise the returned error carries that row, or the underlying open/query
// error. On success the -shm and -wal files next to dbPath are removed on a
// best-effort basis. On failure they are left in place for the caller.
func checkIntegrity(ctx context.Context, dbPath string, mode IntegrityCheckMode) error {
	// Resolve the pragma first so an invalid mode never opens the database.
	var pragma string
	switch mode {
	case IntegrityCheckNone:
		return nil
	case IntegrityCheckQuick:
		pragma = "quick_check"
	case IntegrityCheckFull:
		pragma = "integrity_check"
	default:
		return fmt.Errorf("unsupported integrity check mode: %d", mode)
	}

	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return fmt.Errorf("open database for integrity check: %w", err)
	}
	// Covers the error returns below. sql.DB.Close is idempotent, so the
	// explicit Close on the success path does not conflict with this one.
	defer func() { _ = db.Close() }()

	var result string
	if err := db.QueryRowContext(ctx, "PRAGMA "+pragma).Scan(&result); err != nil {
		return fmt.Errorf("integrity check: %w", err)
	}
	if result != "ok" {
		return fmt.Errorf("integrity check failed: %s", result)
	}

	// Close before deleting the side files: the deferred Close would only run
	// after the removals, while pooled connections could still have them open.
	// The close error is ignored because the check itself has already passed.
	_ = db.Close()

	// Clean up -shm and -wal files that SQLite may create during the PRAGMA.
	// Best effort: the files usually do not exist, so errors are ignored.
	_ = os.Remove(dbPath + "-shm")
	_ = os.Remove(dbPath + "-wal")

	return nil
}

// findBestV3SnapshotForTimestamp returns the best v0.3.x snapshot for the given timestamp.
// Returns nil if no suitable snapshot exists.
func (r *Replica) findBestV3SnapshotForTimestamp(ctx context.Context, client ReplicaClientV3, timestamp time.Time) (*SnapshotInfoV3, error) {
Expand Down
83 changes: 83 additions & 0 deletions replica_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package litestream
import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"os"
Expand All @@ -11,6 +12,8 @@ import (
"time"

"github.com/superfly/ltx"

_ "modernc.org/sqlite"
)

func TestReplica_ApplyNewLTXFiles_FillGapWithOverlappingCompactedFile(t *testing.T) {
Expand Down Expand Up @@ -277,3 +280,83 @@ func mustCreateWritableDBFile(tb testing.TB) *os.File {
// ltxFixtureKey renders level and the TXID range as
// "<level>:<minTXID>:<maxTXID>", formatting both TXIDs with %s.
func ltxFixtureKey(level int, minTXID, maxTXID ltx.TXID) string {
	const layout = "%d:%s:%s"
	return fmt.Sprintf(layout, level, minTXID, maxTXID)
}

// mustCreateValidSQLiteDB creates a SQLite database file named test.db inside
// tb.TempDir(), fills it with a table t holding three rows plus an index on
// t.name, and returns the file's path. The connection is closed before the
// function returns. Any failure aborts the calling test via tb.Fatal.
func mustCreateValidSQLiteDB(tb testing.TB) string {
	tb.Helper()

	path := filepath.Join(tb.TempDir(), "test.db")

	conn, err := sql.Open("sqlite", path)
	if err != nil {
		tb.Fatal(err)
	}
	defer func() { _ = conn.Close() }()

	// Statements run in order; the first failure stops the test.
	stmts := [...]string{
		"CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)",
		"INSERT INTO t (name) VALUES ('a'), ('b'), ('c')",
		"CREATE INDEX idx_t_name ON t(name)",
	}
	for _, stmt := range stmts {
		if _, err := conn.Exec(stmt); err != nil {
			tb.Fatal(err)
		}
	}

	return path
}

// TestCheckIntegrity_Quick_ValidDB verifies that the quick check reports no
// error for a freshly created, uncorrupted database.
func TestCheckIntegrity_Quick_ValidDB(t *testing.T) {
	path := mustCreateValidSQLiteDB(t)

	err := checkIntegrity(context.Background(), path, IntegrityCheckQuick)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}
}

// TestCheckIntegrity_Full_ValidDB verifies that the full check reports no
// error for a freshly created, uncorrupted database.
func TestCheckIntegrity_Full_ValidDB(t *testing.T) {
	path := mustCreateValidSQLiteDB(t)

	err := checkIntegrity(context.Background(), path, IntegrityCheckFull)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}
}

// TestCheckIntegrity_None_Skips verifies that IntegrityCheckNone returns nil
// even when the database path does not exist.
func TestCheckIntegrity_None_Skips(t *testing.T) {
	const missingPath = "/nonexistent/path.db"

	err := checkIntegrity(context.Background(), missingPath, IntegrityCheckNone)
	if err != nil {
		t.Fatalf("expected nil for IntegrityCheckNone, got: %v", err)
	}
}

// TestCheckIntegrity_CorruptDB verifies that the full integrity check reports
// an error once every page after the first has been overwritten with garbage.
func TestCheckIntegrity_CorruptDB(t *testing.T) {
	dbPath := mustCreateValidSQLiteDB(t)

	// Remove any WAL/SHM files so we have a clean single-file database.
	_ = os.Remove(dbPath + "-wal")
	_ = os.Remove(dbPath + "-shm")

	f, err := os.OpenFile(dbPath, os.O_RDWR, 0o600)
	if err != nil {
		t.Fatal(err)
	}
	// Releases the handle on every t.Fatal path. After the explicit Close
	// below this only yields an ignored "already closed" error.
	defer func() { _ = f.Close() }()

	// Read the page size from the database header (bytes 16-17, big-endian).
	// A stored value of 1 stands for a 65536-byte page.
	var hdr [18]byte
	if _, err := f.ReadAt(hdr[:], 0); err != nil {
		t.Fatalf("read database header: %v", err)
	}
	pageSize := int64(hdr[16])<<8 | int64(hdr[17])
	if pageSize == 1 {
		pageSize = 65536
	}

	info, err := f.Stat()
	if err != nil {
		t.Fatal(err)
	}

	// Without at least one page after the first there is nothing to corrupt.
	// Fail here rather than let the final assertion fail for an unrelated
	// reason.
	if info.Size() <= pageSize {
		t.Fatalf("database too small to corrupt: size=%d, page size=%d", info.Size(), pageSize)
	}

	// Page 1 is the header/schema page. Overwrite everything after it, i.e.
	// the pages that hold table and index data, with garbage.
	garbage := bytes.Repeat([]byte{0xDE}, int(info.Size()-pageSize))
	if _, err := f.WriteAt(garbage, pageSize); err != nil {
		t.Fatal(err)
	}
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	if err := checkIntegrity(context.Background(), dbPath, IntegrityCheckFull); err == nil {
		t.Fatal("expected integrity check to fail on corrupt database")
	}
}
Loading