Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -1852,21 +1852,23 @@ func (db *DB) SnapshotReader(ctx context.Context) (ltx.Pos, io.Reader, error) {

db.Logger.Debug("snapshot", "txid", pos.TXID.String())

// Prevent internal checkpoints during sync.
db.chkMu.RLock()
defer db.chkMu.RUnlock()

// TODO(ltx): Read database size from database header.

fi, err := db.f.Stat()
if err != nil {
return pos, nil, err
}
commit := uint32(fi.Size() / int64(db.pageSize))

// Execute encoding in a separate goroutine so the caller can initialize before reading.
pr, pw := io.Pipe()
go func() {
// Prevent internal checkpoints for the entire duration of page reading.
// This lock must be held inside the goroutine (not in the outer function)
// because the outer function returns before the goroutine finishes,
// which would release a deferred RUnlock while pages are still being read.
db.chkMu.RLock()
defer db.chkMu.RUnlock()

fi, err := db.f.Stat()
if err != nil {
pw.CloseWithError(err)
return
}
commit := uint32(fi.Size() / int64(db.pageSize))

walFile, err := os.Open(db.WALPath())
if err != nil {
pw.CloseWithError(err)
Expand Down
270 changes: 270 additions & 0 deletions db_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2066,3 +2066,273 @@ func TestDB_Sync_InitErrorMetrics(t *testing.T) {
t.Fatalf("litestream_sync_error_count=%v, want > %v (init error should be counted)", syncErrorValue, baselineErrors)
}
}

// TestSyncRestoreIntegrity exercises the full sync→checkpoint→restore→integrity
// flow, interleaving batches of application writes with litestream syncs, to
// verify no corruption is introduced. This reproduces the scenario from issue
// #1164 where users observed "database disk image is malformed" errors after
// restoring from replicas.
//
// The test is skipped in -short mode.
func TestSyncRestoreIntegrity(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	const (
		iterations       = 20 // number of write-batch + sync rounds
		rowsPerIteration = 10 // rows inserted before each sync
	)

	dir := t.TempDir()
	dbPath := filepath.Join(dir, "test.db")
	replicaDir := t.TempDir()

	db := NewDB(dbPath)
	db.MonitorInterval = 0
	db.ShutdownSyncTimeout = 0
	db.MinCheckpointPageN = 50
	db.CheckpointInterval = 100 * time.Millisecond
	db.Replica = NewReplica(db)
	db.Replica.Client = &testReplicaClient{dir: replicaDir}
	db.Replica.MonitorEnabled = false
	db.Logger = slog.New(slog.NewTextHandler(io.Discard, nil))
	if err := db.Open(); err != nil {
		t.Fatal(err)
	}
	// Safety net for early t.Fatal exits; the happy path closes explicitly below.
	defer func() { _ = db.Close(context.Background()) }()

	sqldb, err := sql.Open("sqlite", dbPath)
	if err != nil {
		t.Fatal(err)
	}
	// Safety net for early exits; closing an already-closed *sql.DB is harmless.
	defer sqldb.Close()
	if _, err := sqldb.Exec(`PRAGMA journal_mode = wal`); err != nil {
		t.Fatal(err)
	}
	if _, err := sqldb.Exec(`PRAGMA busy_timeout = 5000`); err != nil {
		t.Fatal(err)
	}

	schema := `
		CREATE TABLE IF NOT EXISTS data (
			ROWID INTEGER PRIMARY KEY AUTOINCREMENT,
			_uid TEXT NOT NULL,
			_resource_version INTEGER NOT NULL,
			_updated_at DATETIME NOT NULL,
			name TEXT,
			data_json BLOB,
			is_active INTEGER,
			UNIQUE (_uid, _resource_version)
		);
		CREATE INDEX IF NOT EXISTS data_uid_idx ON data (_uid);
		CREATE INDEX IF NOT EXISTS data_name_idx ON data (name);
	`
	if _, err := sqldb.Exec(schema); err != nil {
		t.Fatal(err)
	}

	ctx := context.Background()

	// Alternate write batches and syncs for several iterations.
	// This mirrors the pattern from the reproduction repo.
	for i := 0; i < iterations; i++ {
		// Insert a batch of rows (simulating application writes).
		for j := 0; j < rowsPerIteration; j++ {
			uid := fmt.Sprintf("uid-%d-%d", i, j)
			_, err := sqldb.ExecContext(ctx,
				`INSERT INTO data (_uid, _resource_version, _updated_at, name, data_json, is_active)
				VALUES (?, 1, datetime('now'), ?, ?, ?)`,
				uid, fmt.Sprintf("item-%d-%d", i, j),
				[]byte(fmt.Sprintf(`{"key":"k%d","value":%d}`, j, j)),
				j%2,
			)
			if err != nil {
				t.Fatal(err)
			}
		}

		// Run sync (this is what litestream does periodically).
		if err := db.Sync(ctx); err != nil {
			t.Fatalf("sync iteration %d: %v", i, err)
		}
	}

	// Ensure final sync captures everything.
	if err := db.Sync(ctx); err != nil {
		t.Fatalf("final sync: %v", err)
	}

	// Close the application DB connection.
	if err := sqldb.Close(); err != nil {
		t.Fatalf("close sqldb: %v", err)
	}

	// Stop litestream.
	if err := db.Close(ctx); err != nil {
		t.Fatalf("close db: %v", err)
	}

	// Restore from replica into a fresh location.
	restorePath := filepath.Join(t.TempDir(), "restored.db")
	restoreDB := NewDB(restorePath)
	restoreDB.Replica = NewReplica(restoreDB)
	restoreDB.Replica.Client = &testReplicaClient{dir: replicaDir}
	restoreDB.Logger = slog.New(slog.NewTextHandler(io.Discard, nil))
	if err := restoreDB.Replica.Restore(ctx, RestoreOptions{
		OutputPath: restorePath,
	}); err != nil {
		t.Fatalf("restore: %v", err)
	}

	// Run integrity check on restored database.
	restoredDB, err := sql.Open("sqlite", restorePath)
	if err != nil {
		t.Fatalf("open restored db: %v", err)
	}
	defer restoredDB.Close()

	rows, err := restoredDB.QueryContext(ctx, `PRAGMA integrity_check;`)
	if err != nil {
		t.Fatalf("integrity check: %v", err)
	}
	defer rows.Close()

	// Collect every reported row so a failure shows all problems, not just the first.
	var results []string
	for rows.Next() {
		var result string
		if err := rows.Scan(&result); err != nil {
			t.Fatal(err)
		}
		results = append(results, result)
	}
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}

	if len(results) == 0 {
		t.Fatal("integrity check returned no results")
	}
	if results[0] != "ok" {
		t.Fatalf("integrity check failed on restored database:\n%v", results)
	}

	// Verify data was restored correctly.
	var count int
	if err := restoredDB.QueryRowContext(ctx, `SELECT COUNT(*) FROM data`).Scan(&count); err != nil {
		t.Fatalf("count data: %v", err)
	}
	expectedRows := iterations * rowsPerIteration
	if count != expectedRows {
		t.Fatalf("restored row count=%d, want %d", count, expectedRows)
	}
}

// TestSyncRestoreIntegrity_WithCheckpoints is a stress variant that interleaves
// bursts of writes and syncs with PASSIVE checkpoints issued from the
// application's connection, to maximize the chance of hitting any TOCTOU race
// conditions. It also lowers MinCheckpointPageN, TruncatePageN, and
// CheckpointInterval, presumably so litestream's own checkpoints (including
// TRUNCATE) trigger more readily. The read transaction (db.rtx) held by
// litestream during sync should prevent TRUNCATE checkpoints from the
// application's connection from completing.
//
// The test is skipped in -short mode.
func TestSyncRestoreIntegrity_WithCheckpoints(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	const (
		bursts       = 30 // number of write-burst + sync rounds
		rowsPerBurst = 20 // rows inserted before each sync
	)

	dir := t.TempDir()
	dbPath := filepath.Join(dir, "test.db")
	replicaDir := t.TempDir()

	db := NewDB(dbPath)
	db.MonitorInterval = 0
	db.ShutdownSyncTimeout = 0
	db.MinCheckpointPageN = 20
	db.TruncatePageN = 200
	db.CheckpointInterval = 50 * time.Millisecond
	db.Replica = NewReplica(db)
	db.Replica.Client = &testReplicaClient{dir: replicaDir}
	db.Replica.MonitorEnabled = false
	db.Logger = slog.New(slog.NewTextHandler(io.Discard, nil))
	if err := db.Open(); err != nil {
		t.Fatal(err)
	}
	// Safety net for early t.Fatal exits; the happy path closes explicitly below.
	defer func() { _ = db.Close(context.Background()) }()

	sqldb, err := sql.Open("sqlite", dbPath)
	if err != nil {
		t.Fatal(err)
	}
	// Safety net for early exits; closing an already-closed *sql.DB is harmless.
	defer sqldb.Close()
	if _, err := sqldb.Exec(`PRAGMA journal_mode = wal`); err != nil {
		t.Fatal(err)
	}
	if _, err := sqldb.Exec(`PRAGMA busy_timeout = 5000`); err != nil {
		t.Fatal(err)
	}
	if _, err := sqldb.Exec(`CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT, extra BLOB)`); err != nil {
		t.Fatal(err)
	}

	ctx := context.Background()

	// Write data in bursts with syncs and checkpoints interleaved.
	for i := 0; i < bursts; i++ {
		// Burst of writes.
		for j := 0; j < rowsPerBurst; j++ {
			_, err := sqldb.ExecContext(ctx,
				`INSERT INTO t (val, extra) VALUES (?, ?)`,
				fmt.Sprintf("val-%d-%d", i, j),
				bytes.Repeat([]byte{byte(i)}, 512),
			)
			if err != nil {
				t.Fatal(err)
			}
		}

		// Sync after each burst.
		if err := db.Sync(ctx); err != nil {
			t.Fatalf("sync %d: %v", i, err)
		}

		// Every fifth burst, issue a PASSIVE checkpoint from the app's
		// connection to simulate the pattern from issue #1164. A failure here
		// is only logged: the checkpoint is best-effort and the restore
		// verification below is what the test actually asserts.
		if i%5 == 4 {
			if _, err := sqldb.ExecContext(ctx, `PRAGMA wal_checkpoint(PASSIVE)`); err != nil {
				t.Logf("passive checkpoint %d: %v", i, err)
			}
		}
	}

	if err := db.Sync(ctx); err != nil {
		t.Fatalf("final sync: %v", err)
	}

	// Shut down the application connection, then litestream.
	if err := sqldb.Close(); err != nil {
		t.Fatalf("close sqldb: %v", err)
	}
	if err := db.Close(ctx); err != nil {
		t.Fatalf("close: %v", err)
	}

	// Restore and verify.
	restorePath := filepath.Join(t.TempDir(), "restored.db")
	restoreDB := NewDB(restorePath)
	restoreDB.Replica = NewReplica(restoreDB)
	restoreDB.Replica.Client = &testReplicaClient{dir: replicaDir}
	restoreDB.Logger = slog.New(slog.NewTextHandler(io.Discard, nil))
	if err := restoreDB.Replica.Restore(ctx, RestoreOptions{
		OutputPath: restorePath,
	}); err != nil {
		t.Fatalf("restore: %v", err)
	}

	restoredDB, err := sql.Open("sqlite", restorePath)
	if err != nil {
		t.Fatal(err)
	}
	defer restoredDB.Close()

	// Only the first integrity_check row is read; anything other than "ok"
	// is treated as a failure.
	var result string
	if err := restoredDB.QueryRowContext(ctx, `PRAGMA integrity_check`).Scan(&result); err != nil {
		t.Fatal(err)
	}
	if result != "ok" {
		t.Fatalf("integrity check failed: %s", result)
	}

	var count int
	if err := restoredDB.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&count); err != nil {
		t.Fatal(err)
	}
	if want := bursts * rowsPerBurst; count != want {
		t.Fatalf("row count=%d, want %d", count, want)
	}
}
3 changes: 3 additions & 0 deletions replica.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,9 @@ func (r *Replica) applyLTXFile(ctx context.Context, f *os.File, info *ltx.FileIn
}

if hdr.Commit > 0 {
if err := f.Sync(); err != nil {
return fmt.Errorf("sync before truncate: %w", err)
}
newSize := int64(hdr.Commit) * int64(pageSize)
if err := f.Truncate(newSize); err != nil {
return fmt.Errorf("truncate: %w", err)
Expand Down
Loading
Loading