Skip to content

Commit 2a735f6

Browse files
committed
fix(sync): snapshot WAL into memory to prevent checkpoint race
Instead of buffering each page individually (which caused a severe performance regression due to excessive allocations), snapshot the entire WAL file into memory once at the start of sync. Both PageMap and the page encoding step read from this immutable in-memory copy, eliminating the TOCTOU race where a concurrent checkpoint could rewrite the WAL between the initial read and the encoding step. This approach holds one copy of the WAL in memory, comparable in size to what the OS page cache would hold for the same file, but the copy is immutable and immune to concurrent modification. Fixes #1164
1 parent 9b7db29 commit 2a735f6

File tree

3 files changed

+292
-22
lines changed

3 files changed

+292
-22
lines changed

db.go

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1480,32 +1480,36 @@ func (db *DB) sync(ctx context.Context, checkpointing bool, info syncInfo) (sync
14801480
mode := fi.Mode()
14811481
commit := uint32(fi.Size() / int64(db.pageSize))
14821482

1483-
walFile, err := os.Open(db.WALPath())
1483+
// Snapshot WAL file into memory to prevent TOCTOU races with concurrent
1484+
// checkpoints. The application's SQLite driver can checkpoint (rewriting
1485+
// the WAL) between our initial read and the page encoding step. Reading
1486+
// the entire WAL upfront gives us an immutable copy to work from.
1487+
walData, err := os.ReadFile(db.WALPath())
14841488
if err != nil {
14851489
return false, err
14861490
}
1487-
defer walFile.Close()
1491+
walReader := bytes.NewReader(walData)
14881492

14891493
var rd *WALReader
14901494
if info.offset == WALHeaderSize {
1491-
if rd, err = NewWALReader(walFile, db.Logger); err != nil {
1495+
if rd, err = NewWALReader(walReader, db.Logger); err != nil {
14921496
return false, fmt.Errorf("new wal reader: %w", err)
14931497
}
14941498
} else {
14951499
// If we cannot verify the previous frame
14961500
var pfmError *PrevFrameMismatchError
1497-
if rd, err = NewWALReaderWithOffset(ctx, walFile, info.offset, info.salt1, info.salt2, db.Logger); errors.As(err, &pfmError) {
1501+
if rd, err = NewWALReaderWithOffset(ctx, walReader, info.offset, info.salt1, info.salt2, db.Logger); errors.As(err, &pfmError) {
14981502
db.Logger.Log(ctx, internal.LevelTrace, "prev frame mismatch, snapshotting", "err", pfmError.Err)
14991503
info.offset = WALHeaderSize
1500-
if rd, err = NewWALReader(walFile, db.Logger); err != nil {
1504+
if rd, err = NewWALReader(walReader, db.Logger); err != nil {
15011505
return false, fmt.Errorf("new wal reader, after reset")
15021506
}
15031507
} else if err != nil {
15041508
return false, fmt.Errorf("new wal reader with offset: %w", err)
15051509
}
15061510
}
15071511

1508-
// Build a mapping of changed page numbers and their latest content.
1512+
// Build a mapping of changed page numbers and their offsets.
15091513
pageMap, maxOffset, walCommit, err := rd.PageMap(ctx)
15101514
if err != nil {
15111515
return false, fmt.Errorf("page map: %w", err)
@@ -1578,11 +1582,11 @@ func (db *DB) sync(ctx context.Context, checkpointing bool, info syncInfo) (sync
15781582
// If we need a full snapshot, then copy from the database & WAL.
15791583
// Otherwise, just copy incrementally from the WAL.
15801584
if info.snapshotting {
1581-
if err := db.writeLTXFromDB(ctx, enc, walFile, commit, pageMap); err != nil {
1585+
if err := db.writeLTXFromDB(ctx, enc, walReader, commit, pageMap); err != nil {
15821586
return false, fmt.Errorf("write ltx from db: %w", err)
15831587
}
15841588
} else {
1585-
if err := db.writeLTXFromWAL(ctx, enc, walFile, pageMap); err != nil {
1589+
if err := db.writeLTXFromWAL(ctx, enc, walReader, pageMap); err != nil {
15861590
return false, fmt.Errorf("write ltx from wal: %w", err)
15871591
}
15881592
}
@@ -1640,7 +1644,7 @@ func (db *DB) sync(ctx context.Context, checkpointing bool, info syncInfo) (sync
16401644
return true, nil
16411645
}
16421646

1643-
func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.File, commit uint32, pageMap map[uint32]int64) error {
1647+
func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walReader io.ReaderAt, commit uint32, pageMap map[uint32]int64) error {
16441648
lockPgno := ltx.LockPgno(uint32(db.pageSize))
16451649
data := make([]byte, db.pageSize)
16461650

@@ -1649,7 +1653,6 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16491653
continue
16501654
}
16511655

1652-
// Check if the caller has canceled during processing.
16531656
select {
16541657
case <-ctx.Done():
16551658
return context.Cause(ctx)
@@ -1660,7 +1663,7 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16601663
if offset, ok := pageMap[pgno]; ok {
16611664
db.Logger.Log(ctx, internal.LevelTrace, "encode page from wal", "txid", enc.Header().MinTXID, "offset", offset, "pgno", pgno, "type", "db+wal")
16621665

1663-
if n, err := walFile.ReadAt(data, offset+WALFrameHeaderSize); err != nil {
1666+
if n, err := walReader.ReadAt(data, offset+WALFrameHeaderSize); err != nil {
16641667
return fmt.Errorf("read page %d @ %d: %w", pgno, offset, err)
16651668
} else if n != len(data) {
16661669
return fmt.Errorf("short read page %d @ %d", pgno, offset)
@@ -1675,7 +1678,6 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16751678
offset := int64(pgno-1) * int64(db.pageSize)
16761679
db.Logger.Log(ctx, internal.LevelTrace, "encode page from database", "offset", offset, "pgno", pgno)
16771680

1678-
// Otherwise read directly from the database file.
16791681
if _, err := db.f.ReadAt(data, offset); err != nil {
16801682
return fmt.Errorf("read database page %d: %w", pgno, err)
16811683
}
@@ -1687,8 +1689,7 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16871689
return nil
16881690
}
16891691

1690-
func (db *DB) writeLTXFromWAL(ctx context.Context, enc *ltx.Encoder, walFile *os.File, pageMap map[uint32]int64) error {
1691-
// Create an ordered list of page numbers since the LTX encoder requires it.
1692+
func (db *DB) writeLTXFromWAL(ctx context.Context, enc *ltx.Encoder, walReader io.ReaderAt, pageMap map[uint32]int64) error {
16921693
pgnos := make([]uint32, 0, len(pageMap))
16931694
for pgno := range pageMap {
16941695
pgnos = append(pgnos, pgno)
@@ -1701,14 +1702,12 @@ func (db *DB) writeLTXFromWAL(ctx context.Context, enc *ltx.Encoder, walFile *os
17011702

17021703
db.Logger.Log(ctx, internal.LevelTrace, "encode page from wal", "txid", enc.Header().MinTXID, "offset", offset, "pgno", pgno, "type", "walonly")
17031704

1704-
// Read source page using page map.
1705-
if n, err := walFile.ReadAt(data, offset+WALFrameHeaderSize); err != nil {
1705+
if n, err := walReader.ReadAt(data, offset+WALFrameHeaderSize); err != nil {
17061706
return fmt.Errorf("read page %d @ %d: %w", pgno, offset, err)
17071707
} else if n != len(data) {
17081708
return fmt.Errorf("short read page %d @ %d", pgno, offset)
17091709
}
17101710

1711-
// Write page to LTX encoder.
17121711
if err := enc.EncodePage(ltx.PageHeader{Pgno: pgno}, data); err != nil {
17131712
return fmt.Errorf("encode ltx frame (pgno=%d): %w", pgno, err)
17141713
}
@@ -1867,20 +1866,21 @@ func (db *DB) SnapshotReader(ctx context.Context) (ltx.Pos, io.Reader, error) {
18671866
// Execute encoding in a separate goroutine so the caller can initialize before reading.
18681867
pr, pw := io.Pipe()
18691868
go func() {
1870-
walFile, err := os.Open(db.WALPath())
1869+
// Snapshot WAL into memory to prevent TOCTOU races with concurrent checkpoints.
1870+
walData, err := os.ReadFile(db.WALPath())
18711871
if err != nil {
18721872
pw.CloseWithError(err)
18731873
return
18741874
}
1875-
defer walFile.Close()
1875+
walReader := bytes.NewReader(walData)
18761876

1877-
rd, err := NewWALReader(walFile, db.Logger)
1877+
rd, err := NewWALReader(walReader, db.Logger)
18781878
if err != nil {
18791879
pw.CloseWithError(fmt.Errorf("new wal reader: %w", err))
18801880
return
18811881
}
18821882

1883-
// Build a mapping of changed page numbers and their latest content.
1883+
// Build a mapping of changed page numbers and their offsets.
18841884
pageMap, maxOffset, walCommit, err := rd.PageMap(ctx)
18851885
if err != nil {
18861886
pw.CloseWithError(fmt.Errorf("page map: %w", err))
@@ -1925,7 +1925,7 @@ func (db *DB) SnapshotReader(ctx context.Context) (ltx.Pos, io.Reader, error) {
19251925
return
19261926
}
19271927

1928-
if err := db.writeLTXFromDB(ctx, enc, walFile, commit, pageMap); err != nil {
1928+
if err := db.writeLTXFromDB(ctx, enc, walReader, commit, pageMap); err != nil {
19291929
pw.CloseWithError(fmt.Errorf("write snapshot ltx: %w", err))
19301930
return
19311931
}

0 commit comments

Comments
 (0)