@@ -1480,32 +1480,36 @@ func (db *DB) sync(ctx context.Context, checkpointing bool, info syncInfo) (sync
14801480 mode := fi .Mode ()
14811481 commit := uint32 (fi .Size () / int64 (db .pageSize ))
14821482
1483- walFile , err := os .Open (db .WALPath ())
1483+ // Snapshot WAL file into memory to prevent TOCTOU races with concurrent
1484+ // checkpoints. The application's SQLite driver can checkpoint (rewriting
1485+ // the WAL) between our initial read and the page encoding step. Reading
1486+ // the entire WAL upfront gives us an immutable copy to work from.
1487+ walData , err := os .ReadFile (db .WALPath ())
14841488 if err != nil {
14851489 return false , err
14861490 }
1487- defer walFile . Close ( )
1491+ walReader := bytes . NewReader ( walData )
14881492
14891493 var rd * WALReader
14901494 if info .offset == WALHeaderSize {
1491- if rd , err = NewWALReader (walFile , db .Logger ); err != nil {
1495+ if rd , err = NewWALReader (walReader , db .Logger ); err != nil {
14921496 return false , fmt .Errorf ("new wal reader: %w" , err )
14931497 }
14941498 } else {
14951499 // If the previous frame cannot be verified, reset the offset to the WAL header and re-read the WAL from the start.
14961500 var pfmError * PrevFrameMismatchError
1497- if rd , err = NewWALReaderWithOffset (ctx , walFile , info .offset , info .salt1 , info .salt2 , db .Logger ); errors .As (err , & pfmError ) {
1501+ if rd , err = NewWALReaderWithOffset (ctx , walReader , info .offset , info .salt1 , info .salt2 , db .Logger ); errors .As (err , & pfmError ) {
14981502 db .Logger .Log (ctx , internal .LevelTrace , "prev frame mismatch, snapshotting" , "err" , pfmError .Err )
14991503 info .offset = WALHeaderSize
1500- if rd , err = NewWALReader (walFile , db .Logger ); err != nil {
1504+ if rd , err = NewWALReader (walReader , db .Logger ); err != nil {
15011505 return false , fmt .Errorf ("new wal reader, after reset" )
15021506 }
15031507 } else if err != nil {
15041508 return false , fmt .Errorf ("new wal reader with offset: %w" , err )
15051509 }
15061510 }
15071511
1508- // Build a mapping of changed page numbers and their latest content .
1512+ // Build a mapping of changed page numbers and their offsets .
15091513 pageMap , maxOffset , walCommit , err := rd .PageMap (ctx )
15101514 if err != nil {
15111515 return false , fmt .Errorf ("page map: %w" , err )
@@ -1578,11 +1582,11 @@ func (db *DB) sync(ctx context.Context, checkpointing bool, info syncInfo) (sync
15781582 // If we need a full snapshot, then copy from the database & WAL.
15791583 // Otherwise, just copy incrementally from the WAL.
15801584 if info .snapshotting {
1581- if err := db .writeLTXFromDB (ctx , enc , walFile , commit , pageMap ); err != nil {
1585+ if err := db .writeLTXFromDB (ctx , enc , walReader , commit , pageMap ); err != nil {
15821586 return false , fmt .Errorf ("write ltx from db: %w" , err )
15831587 }
15841588 } else {
1585- if err := db .writeLTXFromWAL (ctx , enc , walFile , pageMap ); err != nil {
1589+ if err := db .writeLTXFromWAL (ctx , enc , walReader , pageMap ); err != nil {
15861590 return false , fmt .Errorf ("write ltx from wal: %w" , err )
15871591 }
15881592 }
@@ -1640,7 +1644,7 @@ func (db *DB) sync(ctx context.Context, checkpointing bool, info syncInfo) (sync
16401644 return true , nil
16411645}
16421646
1643- func (db * DB ) writeLTXFromDB (ctx context.Context , enc * ltx.Encoder , walFile * os. File , commit uint32 , pageMap map [uint32 ]int64 ) error {
1647+ func (db * DB ) writeLTXFromDB (ctx context.Context , enc * ltx.Encoder , walReader io. ReaderAt , commit uint32 , pageMap map [uint32 ]int64 ) error {
16441648 lockPgno := ltx .LockPgno (uint32 (db .pageSize ))
16451649 data := make ([]byte , db .pageSize )
16461650
@@ -1649,7 +1653,6 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16491653 continue
16501654 }
16511655
1652- // Check if the caller has canceled during processing.
16531656 select {
16541657 case <- ctx .Done ():
16551658 return context .Cause (ctx )
@@ -1660,7 +1663,7 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16601663 if offset , ok := pageMap [pgno ]; ok {
16611664 db .Logger .Log (ctx , internal .LevelTrace , "encode page from wal" , "txid" , enc .Header ().MinTXID , "offset" , offset , "pgno" , pgno , "type" , "db+wal" )
16621665
1663- if n , err := walFile .ReadAt (data , offset + WALFrameHeaderSize ); err != nil {
1666+ if n , err := walReader .ReadAt (data , offset + WALFrameHeaderSize ); err != nil {
16641667 return fmt .Errorf ("read page %d @ %d: %w" , pgno , offset , err )
16651668 } else if n != len (data ) {
16661669 return fmt .Errorf ("short read page %d @ %d" , pgno , offset )
@@ -1675,7 +1678,6 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16751678 offset := int64 (pgno - 1 ) * int64 (db .pageSize )
16761679 db .Logger .Log (ctx , internal .LevelTrace , "encode page from database" , "offset" , offset , "pgno" , pgno )
16771680
1678- // Otherwise read directly from the database file.
16791681 if _ , err := db .f .ReadAt (data , offset ); err != nil {
16801682 return fmt .Errorf ("read database page %d: %w" , pgno , err )
16811683 }
@@ -1687,8 +1689,7 @@ func (db *DB) writeLTXFromDB(ctx context.Context, enc *ltx.Encoder, walFile *os.
16871689 return nil
16881690}
16891691
1690- func (db * DB ) writeLTXFromWAL (ctx context.Context , enc * ltx.Encoder , walFile * os.File , pageMap map [uint32 ]int64 ) error {
1691- // Create an ordered list of page numbers since the LTX encoder requires it.
1692+ func (db * DB ) writeLTXFromWAL (ctx context.Context , enc * ltx.Encoder , walReader io.ReaderAt , pageMap map [uint32 ]int64 ) error {
16921693 pgnos := make ([]uint32 , 0 , len (pageMap ))
16931694 for pgno := range pageMap {
16941695 pgnos = append (pgnos , pgno )
@@ -1701,14 +1702,12 @@ func (db *DB) writeLTXFromWAL(ctx context.Context, enc *ltx.Encoder, walFile *os
17011702
17021703 db .Logger .Log (ctx , internal .LevelTrace , "encode page from wal" , "txid" , enc .Header ().MinTXID , "offset" , offset , "pgno" , pgno , "type" , "walonly" )
17031704
1704- // Read source page using page map.
1705- if n , err := walFile .ReadAt (data , offset + WALFrameHeaderSize ); err != nil {
1705+ if n , err := walReader .ReadAt (data , offset + WALFrameHeaderSize ); err != nil {
17061706 return fmt .Errorf ("read page %d @ %d: %w" , pgno , offset , err )
17071707 } else if n != len (data ) {
17081708 return fmt .Errorf ("short read page %d @ %d" , pgno , offset )
17091709 }
17101710
1711- // Write page to LTX encoder.
17121711 if err := enc .EncodePage (ltx.PageHeader {Pgno : pgno }, data ); err != nil {
17131712 return fmt .Errorf ("encode ltx frame (pgno=%d): %w" , pgno , err )
17141713 }
@@ -1867,20 +1866,21 @@ func (db *DB) SnapshotReader(ctx context.Context) (ltx.Pos, io.Reader, error) {
18671866 // Execute encoding in a separate goroutine so the caller can initialize before reading.
18681867 pr , pw := io .Pipe ()
18691868 go func () {
1870- walFile , err := os .Open (db .WALPath ())
1869+ // Snapshot WAL into memory to prevent TOCTOU races with concurrent checkpoints.
1870+ walData , err := os .ReadFile (db .WALPath ())
18711871 if err != nil {
18721872 pw .CloseWithError (err )
18731873 return
18741874 }
1875- defer walFile . Close ( )
1875+ walReader := bytes . NewReader ( walData )
18761876
1877- rd , err := NewWALReader (walFile , db .Logger )
1877+ rd , err := NewWALReader (walReader , db .Logger )
18781878 if err != nil {
18791879 pw .CloseWithError (fmt .Errorf ("new wal reader: %w" , err ))
18801880 return
18811881 }
18821882
1883- // Build a mapping of changed page numbers and their latest content .
1883+ // Build a mapping of changed page numbers and their offsets .
18841884 pageMap , maxOffset , walCommit , err := rd .PageMap (ctx )
18851885 if err != nil {
18861886 pw .CloseWithError (fmt .Errorf ("page map: %w" , err ))
@@ -1925,7 +1925,7 @@ func (db *DB) SnapshotReader(ctx context.Context) (ltx.Pos, io.Reader, error) {
19251925 return
19261926 }
19271927
1928- if err := db .writeLTXFromDB (ctx , enc , walFile , commit , pageMap ); err != nil {
1928+ if err := db .writeLTXFromDB (ctx , enc , walReader , commit , pageMap ); err != nil {
19291929 pw .CloseWithError (fmt .Errorf ("write snapshot ltx: %w" , err ))
19301930 return
19311931 }
0 commit comments