Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

memdb: prevent iterator invalidation #1563

Merged
merged 15 commits into from
Feb 19, 2025
16 changes: 16 additions & 0 deletions internal/unionstore/arena/arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ import (
"encoding/binary"
"math"

"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/atomic"
"go.uber.org/zap"
)

const (
Expand Down Expand Up @@ -223,6 +225,20 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool {
return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock
}

// LessThan reports whether checkpoint cp is positioned strictly before cp2,
// ordering first by block index and then by offset inside the block.
// Both arguments must be non-nil; a nil checkpoint indicates a caller bug
// and triggers a panic through the background logger.
func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool {
	if cp == nil || cp2 == nil {
		logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", cp2))
	}
	// Different blocks: the block index alone decides the order.
	if cp.blocks != cp2.blocks {
		return cp.blocks < cp2.blocks
	}
	// Same block: fall back to the offset within it.
	return cp.offsetInBlock < cp2.offsetInBlock
}

func (a *MemdbArena) Checkpoint() MemDBCheckpoint {
snap := MemDBCheckpoint{
blockSize: a.blockSize,
Expand Down
31 changes: 30 additions & 1 deletion internal/unionstore/art/art.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,23 @@ type ART struct {
len int
size int

// These variables serve the caching mechanism, meaning they can be concurrently updated by read operations, thus
// they are protected by atomic operations.
// The lastTraversedNode stores addr in uint64 of the last traversed node, includes search and recursiveInsert.
// Compare to atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
// Compared with atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
lastTraversedNode atomic.Uint64
hitCount atomic.Uint64
missCount atomic.Uint64

// The counter of every write operation, used to invalidate iterators that were created before the write operation.
// The purpose of the counter is to check interleaving of write and read operations (via iterator).
// It does not protect against data race. If it happens, there must be a bug in the caller code.
// invariant: no concurrent write/write or read/write access to it
WriteSeqNo int
// Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens.
// It's used to invalidate snapshot iterators.
// invariant: no concurrent write/write or read/write access to it
SnapshotSeqNo int
}

func New() *ART {
Expand Down Expand Up @@ -115,8 +127,12 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error {
}
}

t.WriteSeqNo++
if len(t.stages) == 0 {
t.dirty = true

// note: there is no such usage in TiDB
t.SnapshotSeqNo++
}
// 1. create or search the existing leaf in the tree.
addr, leaf := t.traverse(key, true)
Expand Down Expand Up @@ -479,6 +495,10 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) {
t.allocator.vlogAllocator.RevertToCheckpoint(t, cp)
t.allocator.vlogAllocator.Truncate(cp)
t.allocator.vlogAllocator.OnMemChange()
t.WriteSeqNo++
if len(t.stages) == 0 || t.stages[0].LessThan(cp) {
t.SnapshotSeqNo++
}
}

func (t *ART) Stages() []arena.MemDBCheckpoint {
Expand All @@ -498,7 +518,9 @@ func (t *ART) Release(h int) {
if h != len(t.stages) {
panic("cannot release staging buffer")
}
t.WriteSeqNo++
if h == 1 {
t.SnapshotSeqNo++
tail := t.checkpoint()
if !t.stages[0].IsSamePosition(&tail) {
t.dirty = true
Expand All @@ -519,6 +541,11 @@ func (t *ART) Cleanup(h int) {
panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages)))
}

t.WriteSeqNo++
if h == 1 {
t.SnapshotSeqNo++
}

cp := &t.stages[h-1]
if !t.vlogInvalid {
curr := t.checkpoint()
Expand All @@ -542,6 +569,8 @@ func (t *ART) Reset() {
t.allocator.nodeAllocator.Reset()
t.allocator.vlogAllocator.Reset()
t.lastTraversedNode.Store(arena.NullU64Addr)
t.SnapshotSeqNo++
t.WriteSeqNo++
}

// DiscardValues releases the memory used by all values.
Expand Down
39 changes: 36 additions & 3 deletions internal/unionstore/art/art_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
"sort"

"github.com/pkg/errors"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/zap"
)

func (t *ART) Iter(lowerBound, upperBound []byte) (*Iterator, error) {
Expand Down Expand Up @@ -56,6 +58,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (*
// this avoids the initial value of currAddr equals to endAddr.
currAddr: arena.BadAddr,
endAddr: arena.NullAddr,
seqNo: t.WriteSeqNo,
}
it.init(lowerBound, upperBound)
if !it.valid {
Expand All @@ -76,12 +79,41 @@ type Iterator struct {
currLeaf *artLeaf
currAddr arena.MemdbArenaAddr
endAddr arena.MemdbArenaAddr

// only when seqNo == art.WriteSeqNo, the iterator is valid.
seqNo int
// ignoreSeqNo is used to ignore the seqNo check, used for snapshot iter before its full deprecation.
ignoreSeqNo bool
}

// checkSeqNo verifies that no write has happened on the tree since this
// iterator was created (it.seqNo must still equal the tree's WriteSeqNo).
// A mismatch means the iterator has been invalidated by an interleaved
// write, which is a caller bug, so it panics with the stack attached.
// The check is skipped entirely when ignoreSeqNo is set (snapshot iterators).
func (it *Iterator) checkSeqNo() {
	if it.ignoreSeqNo || it.seqNo == it.tree.WriteSeqNo {
		return
	}
	logutil.BgLogger().Panic(
		"seqNo mismatch",
		zap.Int("it seqNo", it.seqNo),
		zap.Int("art seqNo", it.tree.WriteSeqNo),
		zap.Stack("stack"),
	)
}

// Valid reports whether the iterator currently points at an entry.
// It first runs the write-sequence check, which panics if the tree was
// written to after this iterator was created (unless ignoreSeqNo is set).
func (it *Iterator) Valid() bool {
	it.checkSeqNo()
	return it.valid
}

// Key returns the key of the entry the iterator currently points at.
// It panics (via checkSeqNo) if the iterator has been invalidated by a
// later write to the tree.
func (it *Iterator) Key() []byte {
	it.checkSeqNo()
	return it.currLeaf.GetKey()
}

// Flags returns the key flags of the entry the iterator currently points
// at. It panics (via checkSeqNo) if the iterator has been invalidated by
// a later write to the tree.
func (it *Iterator) Flags() kv.KeyFlags {
	it.checkSeqNo()
	return it.currLeaf.GetKeyFlags()
}

func (it *Iterator) Valid() bool { return it.valid }
func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() }
func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() }
func (it *Iterator) Value() []byte {
it.checkSeqNo()
if it.currLeaf.vLogAddr.IsNull() {
return nil
}
Expand All @@ -102,6 +134,7 @@ func (it *Iterator) Next() error {
// iterate is finished
return errors.New("Art: iterator is finished")
}
it.checkSeqNo()
if it.currAddr == it.endAddr {
it.valid = false
return nil
Expand Down
3 changes: 3 additions & 0 deletions internal/unionstore/art/art_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)

// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
// of stage[0]
func (t *ART) getSnapshot() arena.MemDBCheckpoint {
if len(t.stages) > 0 {
return t.stages[0]
Expand Down Expand Up @@ -49,6 +51,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
if err != nil {
panic(err)
}
inner.ignoreSeqNo = true
it := &SnapIter{
Iterator: inner,
cp: t.getSnapshot(),
Expand Down
Loading
Loading