Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

memdb: prevent iterator invalidation #1563

Merged
merged 15 commits into from
Feb 19, 2025
16 changes: 16 additions & 0 deletions internal/unionstore/arena/arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ import (
"encoding/binary"
"math"

"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/atomic"
"go.uber.org/zap"
)

const (
Expand Down Expand Up @@ -223,6 +225,20 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool {
return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock
}

// LessThan reports whether cp is positioned strictly before cp2.
// Checkpoints are ordered lexicographically: first by blocks, then by
// offsetInBlock. A nil checkpoint on either side is reported through a
// panic-level log entry.
func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool {
	if cp == nil || cp2 == nil {
		logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", cp2))
	}
	// Different block counts decide the ordering on their own.
	if cp.blocks != cp2.blocks {
		return cp.blocks < cp2.blocks
	}
	// Same block count: fall back to the offset within the block.
	return cp.offsetInBlock < cp2.offsetInBlock
}

func (a *MemdbArena) Checkpoint() MemDBCheckpoint {
snap := MemDBCheckpoint{
blockSize: a.blockSize,
Expand Down
27 changes: 27 additions & 0 deletions internal/unionstore/art/art.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ type ART struct {
lastTraversedNode atomic.Uint64
hitCount atomic.Uint64
missCount atomic.Uint64

// The counter of every write operation, used to invalidate iterators that were created before the write operation.
// The purpose of the counter is to detect interleaving of write and read operations (via iterators).
// It does not protect against data races. If one happens, there must be a bug in the caller code.
// invariant: no concurrent access to it
WriteSeqNo int
// Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens.
// It's used to invalidate snapshot iterators.
// invariant: no concurrent access to it
SnapshotSeqNo int
}

func New() *ART {
Expand Down Expand Up @@ -115,8 +125,12 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error {
}
}

t.WriteSeqNo++
if len(t.stages) == 0 {
t.dirty = true

// note: there is no such usage in TiDB
t.SnapshotSeqNo++
}
// 1. create or search the existing leaf in the tree.
addr, leaf := t.traverse(key, true)
Expand Down Expand Up @@ -479,6 +493,10 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) {
t.allocator.vlogAllocator.RevertToCheckpoint(t, cp)
t.allocator.vlogAllocator.Truncate(cp)
t.allocator.vlogAllocator.OnMemChange()
t.WriteSeqNo++
if len(t.stages) == 0 || t.stages[0].LessThan(cp) {
t.SnapshotSeqNo++
}
}

func (t *ART) Stages() []arena.MemDBCheckpoint {
Expand All @@ -498,7 +516,9 @@ func (t *ART) Release(h int) {
if h != len(t.stages) {
panic("cannot release staging buffer")
}
t.WriteSeqNo++
if h == 1 {
t.SnapshotSeqNo++
tail := t.checkpoint()
if !t.stages[0].IsSamePosition(&tail) {
t.dirty = true
Expand All @@ -519,6 +539,11 @@ func (t *ART) Cleanup(h int) {
panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages)))
}

t.WriteSeqNo++
if h == 1 {
t.SnapshotSeqNo++
}

cp := &t.stages[h-1]
if !t.vlogInvalid {
curr := t.checkpoint()
Expand All @@ -542,6 +567,8 @@ func (t *ART) Reset() {
t.allocator.nodeAllocator.Reset()
t.allocator.vlogAllocator.Reset()
t.lastTraversedNode.Store(arena.NullU64Addr)
t.SnapshotSeqNo++
t.WriteSeqNo++
}

// DiscardValues releases the memory used by all values.
Expand Down
39 changes: 36 additions & 3 deletions internal/unionstore/art/art_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
"sort"

"github.com/pkg/errors"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/zap"
)

func (t *ART) Iter(lowerBound, upperBound []byte) (*Iterator, error) {
Expand Down Expand Up @@ -56,6 +58,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (*
// this avoids the initial value of currAddr equals to endAddr.
currAddr: arena.BadAddr,
endAddr: arena.NullAddr,
seqNo: t.WriteSeqNo,
}
it.init(lowerBound, upperBound)
if !it.valid {
Expand All @@ -76,12 +79,41 @@ type Iterator struct {
currLeaf *artLeaf
currAddr arena.MemdbArenaAddr
endAddr arena.MemdbArenaAddr

// only when seqNo == art.WriteSeqNo, the iterator is valid.
seqNo int
// ignoreSeqNo is used to ignore the seqNo check, used for snapshot iter before its full deprecation.
ignoreSeqNo bool
}

// checkSeqNo verifies that the iterator's recorded seqNo still matches the
// tree's WriteSeqNo. On a mismatch it emits a panic-level log entry with both
// sequence numbers and a stack trace, unless ignoreSeqNo is set.
func (it *Iterator) checkSeqNo() {
	// Nothing to report when the sequence numbers agree or the check is disabled.
	if it.seqNo == it.tree.WriteSeqNo || it.ignoreSeqNo {
		return
	}
	logutil.BgLogger().Panic(
		"seqNo mismatch",
		zap.Int("it seqNo", it.seqNo),
		zap.Int("art seqNo", it.tree.WriteSeqNo),
		zap.Stack("stack"),
	)
}

// Valid returns the iterator's valid flag. It runs checkSeqNo first, so a
// sequence-number mismatch with the tree is reported before the flag is read.
func (it *Iterator) Valid() bool {
it.checkSeqNo()
return it.valid
}

// Key returns the key of the leaf the iterator currently points at
// (currLeaf.GetKey). It runs checkSeqNo first.
// NOTE(review): currLeaf is dereferenced without a nil check — presumably
// callers must confirm Valid() beforehand; verify against callers.
func (it *Iterator) Key() []byte {
it.checkSeqNo()
return it.currLeaf.GetKey()
}

// Flags returns the key flags of the leaf the iterator currently points at
// (currLeaf.GetKeyFlags). It runs checkSeqNo first.
// NOTE(review): currLeaf is dereferenced without a nil check — presumably
// callers must confirm Valid() beforehand; verify against callers.
func (it *Iterator) Flags() kv.KeyFlags {
it.checkSeqNo()
return it.currLeaf.GetKeyFlags()
}

func (it *Iterator) Valid() bool { return it.valid }
func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() }
func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() }
func (it *Iterator) Value() []byte {
it.checkSeqNo()
if it.currLeaf.vLogAddr.IsNull() {
return nil
}
Expand All @@ -102,6 +134,7 @@ func (it *Iterator) Next() error {
// iterate is finished
return errors.New("Art: iterator is finished")
}
it.checkSeqNo()
if it.currAddr == it.endAddr {
it.valid = false
return nil
Expand Down
3 changes: 3 additions & 0 deletions internal/unionstore/art/art_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)

// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
// of stage[0]
func (t *ART) getSnapshot() arena.MemDBCheckpoint {
if len(t.stages) > 0 {
return t.stages[0]
Expand Down Expand Up @@ -49,6 +51,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
if err != nil {
panic(err)
}
inner.ignoreSeqNo = true
it := &SnapIter{
Iterator: inner,
cp: t.getSnapshot(),
Expand Down
Loading
Loading