Skip to content

Commit

Permalink
memdb: prevent iterator invalidation (#1563)
Browse files Browse the repository at this point in the history
ref pingcap/tidb#59153

Signed-off-by: ekexium <[email protected]>
  • Loading branch information
ekexium authored Feb 19, 2025
1 parent 075b19f commit ddec823
Show file tree
Hide file tree
Showing 11 changed files with 681 additions and 9 deletions.
16 changes: 16 additions & 0 deletions internal/unionstore/arena/arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ import (
"encoding/binary"
"math"

"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/atomic"
"go.uber.org/zap"
)

const (
Expand Down Expand Up @@ -223,6 +225,20 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool {
return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock
}

// LessThan reports whether cp is positioned strictly before cp2.
//
// Checkpoints are ordered by their blocks field first; when both checkpoints
// have the same blocks value, offsetInBlock breaks the tie. Two checkpoints at
// the same position are therefore not less than each other.
//
// A nil receiver or a nil cp2 is treated as a caller bug and is reported via
// the background logger's Panic.
func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool {
	if cp == nil || cp2 == nil {
		logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", cp2))
	}
	// Different blocks values decide the order on their own.
	if cp.blocks != cp2.blocks {
		return cp.blocks < cp2.blocks
	}
	// Same blocks value: fall back to the offset inside the block.
	return cp.offsetInBlock < cp2.offsetInBlock
}

func (a *MemdbArena) Checkpoint() MemDBCheckpoint {
snap := MemDBCheckpoint{
blockSize: a.blockSize,
Expand Down
31 changes: 30 additions & 1 deletion internal/unionstore/art/art.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,23 @@ type ART struct {
len int
size int

// These variables serve the caching mechanism, meaning they can be concurrently updated by read operations, thus
// they are protected by atomic operations.
// The lastTraversedNode stores addr in uint64 of the last traversed node, includes search and recursiveInsert.
// Compare to atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
// Compared with atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
lastTraversedNode atomic.Uint64
hitCount atomic.Uint64
missCount atomic.Uint64

// The counter of every write operation, used to invalidate iterators that were created before the write operation.
// The purpose of the counter is to check interleaving of write and read operations (via iterator).
// It does not protect against data race. If it happens, there must be a bug in the caller code.
// invariant: no concurrent write/write or read/write access to it
WriteSeqNo int
// Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens.
// It's used to invalidate snapshot iterators.
// invariant: no concurrent write/write or read/write access to it
SnapshotSeqNo int
}

func New() *ART {
Expand Down Expand Up @@ -115,8 +127,12 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error {
}
}

t.WriteSeqNo++
if len(t.stages) == 0 {
t.dirty = true

// note: there is no such usage in TiDB
t.SnapshotSeqNo++
}
// 1. create or search the existing leaf in the tree.
addr, leaf := t.traverse(key, true)
Expand Down Expand Up @@ -479,6 +495,10 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) {
t.allocator.vlogAllocator.RevertToCheckpoint(t, cp)
t.allocator.vlogAllocator.Truncate(cp)
t.allocator.vlogAllocator.OnMemChange()
t.WriteSeqNo++
if len(t.stages) == 0 || t.stages[0].LessThan(cp) {
t.SnapshotSeqNo++
}
}

func (t *ART) Stages() []arena.MemDBCheckpoint {
Expand All @@ -498,7 +518,9 @@ func (t *ART) Release(h int) {
if h != len(t.stages) {
panic("cannot release staging buffer")
}
t.WriteSeqNo++
if h == 1 {
t.SnapshotSeqNo++
tail := t.checkpoint()
if !t.stages[0].IsSamePosition(&tail) {
t.dirty = true
Expand All @@ -519,6 +541,11 @@ func (t *ART) Cleanup(h int) {
panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages)))
}

t.WriteSeqNo++
if h == 1 {
t.SnapshotSeqNo++
}

cp := &t.stages[h-1]
if !t.vlogInvalid {
curr := t.checkpoint()
Expand All @@ -542,6 +569,8 @@ func (t *ART) Reset() {
t.allocator.nodeAllocator.Reset()
t.allocator.vlogAllocator.Reset()
t.lastTraversedNode.Store(arena.NullU64Addr)
t.SnapshotSeqNo++
t.WriteSeqNo++
}

// DiscardValues releases the memory used by all values.
Expand Down
39 changes: 36 additions & 3 deletions internal/unionstore/art/art_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
"sort"

"github.com/pkg/errors"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/zap"
)

func (t *ART) Iter(lowerBound, upperBound []byte) (*Iterator, error) {
Expand Down Expand Up @@ -56,6 +58,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (*
// this prevents the initial value of currAddr from being equal to endAddr.
currAddr: arena.BadAddr,
endAddr: arena.NullAddr,
seqNo: t.WriteSeqNo,
}
it.init(lowerBound, upperBound)
if !it.valid {
Expand All @@ -76,12 +79,41 @@ type Iterator struct {
currLeaf *artLeaf
currAddr arena.MemdbArenaAddr
endAddr arena.MemdbArenaAddr

// only when seqNo == art.WriteSeqNo, the iterator is valid.
seqNo int
// ignoreSeqNo is used to ignore the seqNo check, used for snapshot iter before its full deprecation.
ignoreSeqNo bool
}

// checkSeqNo verifies that the iterator is still in sync with its tree.
//
// The iterator's seqNo is compared with the tree's current WriteSeqNo. When
// they differ and ignoreSeqNo is not set, the mismatch is reported through the
// background logger's Panic together with both sequence numbers and a stack
// trace. When they match, or when ignoreSeqNo is set, the call is a no-op.
func (it *Iterator) checkSeqNo() {
	itSeqNo, treeSeqNo := it.seqNo, it.tree.WriteSeqNo
	if itSeqNo == treeSeqNo || it.ignoreSeqNo {
		return
	}
	logutil.BgLogger().Panic(
		"seqNo mismatch",
		zap.Int("it seqNo", itSeqNo),
		zap.Int("art seqNo", treeSeqNo),
		zap.Stack("stack"),
	)
}

// Valid returns the iterator's valid flag.
//
// It runs checkSeqNo first, so calling it on an iterator whose seqNo no longer
// matches the tree's WriteSeqNo (and that does not ignore the check) panics
// via the logger instead of returning.
func (it *Iterator) Valid() bool {
	it.checkSeqNo()
	stillValid := it.valid
	return stillValid
}

// Key returns the key of the leaf the iterator currently points at, as
// produced by currLeaf.GetKey().
//
// It runs checkSeqNo before touching the leaf.
// NOTE(review): presumably only meaningful while Valid() is true — confirm
// against callers.
func (it *Iterator) Key() []byte {
	it.checkSeqNo()
	leaf := it.currLeaf
	return leaf.GetKey()
}

// Flags returns the key flags of the leaf the iterator currently points at,
// as produced by currLeaf.GetKeyFlags().
//
// It runs checkSeqNo before touching the leaf.
// NOTE(review): presumably only meaningful while Valid() is true — confirm
// against callers.
func (it *Iterator) Flags() kv.KeyFlags {
	it.checkSeqNo()
	leaf := it.currLeaf
	return leaf.GetKeyFlags()
}

// NOTE(review): these one-line forms duplicate the Valid, Key and Flags methods defined
// just above and appear to be the removed (pre-change) side of the diff. Unlike the
// versions above, they read iterator state without calling checkSeqNo first.
func (it *Iterator) Valid() bool { return it.valid }
func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() }
func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() }
func (it *Iterator) Value() []byte {
it.checkSeqNo()
if it.currLeaf.vLogAddr.IsNull() {
return nil
}
Expand All @@ -102,6 +134,7 @@ func (it *Iterator) Next() error {
// iterate is finished
return errors.New("Art: iterator is finished")
}
it.checkSeqNo()
if it.currAddr == it.endAddr {
it.valid = false
return nil
Expand Down
3 changes: 3 additions & 0 deletions internal/unionstore/art/art_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)

// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
// of stage[0]
func (t *ART) getSnapshot() arena.MemDBCheckpoint {
if len(t.stages) > 0 {
return t.stages[0]
Expand Down Expand Up @@ -49,6 +51,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
if err != nil {
panic(err)
}
inner.ignoreSeqNo = true
it := &SnapIter{
Iterator: inner,
cp: t.getSnapshot(),
Expand Down
Loading

0 comments on commit ddec823

Please sign in to comment.