Skip to content

Commit ddec823

Browse files
authored
memdb: prevent iterator invalidation (#1563)
ref pingcap/tidb#59153 Signed-off-by: ekexium <[email protected]>
1 parent 075b19f commit ddec823

File tree

11 files changed

+681
-9
lines changed

11 files changed

+681
-9
lines changed

internal/unionstore/arena/arena.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ import (
3838
"encoding/binary"
3939
"math"
4040

41+
"github.com/tikv/client-go/v2/internal/logutil"
4142
"github.com/tikv/client-go/v2/kv"
4243
"go.uber.org/atomic"
44+
"go.uber.org/zap"
4345
)
4446

4547
const (
@@ -223,6 +225,20 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool {
223225
return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock
224226
}
225227

228+
// LessThan reports whether cp is positioned strictly before cp2: cp has fewer
// blocks, or the same number of blocks and a smaller offsetInBlock.
// Both checkpoints must be non-nil; a nil receiver or argument is reported
// through logutil.BgLogger().Panic (presumably a zap logger, whose Panic logs
// the message and then panics — confirm against logutil).
229+
func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool {
230+
if cp == nil || cp2 == nil {
231+
logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", cp2))
232+
}
233+
if cp.blocks < cp2.blocks {
234+
return true
235+
}
236+
if cp.blocks == cp2.blocks && cp.offsetInBlock < cp2.offsetInBlock {
237+
return true
238+
}
239+
return false
240+
}
241+
226242
func (a *MemdbArena) Checkpoint() MemDBCheckpoint {
227243
snap := MemDBCheckpoint{
228244
blockSize: a.blockSize,

internal/unionstore/art/art.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,23 @@ type ART struct {
4747
len int
4848
size int
4949

50+
// These variables serve the caching mechanism, meaning they can be concurrently updated by read operations, thus
51+
// they are protected by atomic operations.
5052
// lastTraversedNode stores the address (as a uint64) of the last traversed node; traversal here includes both search and recursiveInsert.
51-
// Compare to atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
53+
// Compared with atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
5254
lastTraversedNode atomic.Uint64
5355
hitCount atomic.Uint64
5456
missCount atomic.Uint64
57+
58+
// The counter of every write operation, used to invalidate iterators that were created before the write operation.
59+
// The purpose of the counter is to check interleaving of write and read operations (via iterator).
60+
// It does not protect against data races. If one happens, there must be a bug in the caller code.
61+
// invariant: no concurrent write/write or read/write access to it
62+
WriteSeqNo int
63+
// Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens.
64+
// It's used to invalidate snapshot iterators.
65+
// invariant: no concurrent write/write or read/write access to it
66+
SnapshotSeqNo int
5567
}
5668

5769
func New() *ART {
@@ -115,8 +127,12 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error {
115127
}
116128
}
117129

130+
t.WriteSeqNo++
118131
if len(t.stages) == 0 {
119132
t.dirty = true
133+
134+
// note: there is no such usage in TiDB
135+
t.SnapshotSeqNo++
120136
}
121137
// 1. create or search the existing leaf in the tree.
122138
addr, leaf := t.traverse(key, true)
@@ -479,6 +495,10 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) {
479495
t.allocator.vlogAllocator.RevertToCheckpoint(t, cp)
480496
t.allocator.vlogAllocator.Truncate(cp)
481497
t.allocator.vlogAllocator.OnMemChange()
498+
t.WriteSeqNo++
499+
if len(t.stages) == 0 || t.stages[0].LessThan(cp) {
500+
t.SnapshotSeqNo++
501+
}
482502
}
483503

484504
func (t *ART) Stages() []arena.MemDBCheckpoint {
@@ -498,7 +518,9 @@ func (t *ART) Release(h int) {
498518
if h != len(t.stages) {
499519
panic("cannot release staging buffer")
500520
}
521+
t.WriteSeqNo++
501522
if h == 1 {
523+
t.SnapshotSeqNo++
502524
tail := t.checkpoint()
503525
if !t.stages[0].IsSamePosition(&tail) {
504526
t.dirty = true
@@ -519,6 +541,11 @@ func (t *ART) Cleanup(h int) {
519541
panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages)))
520542
}
521543

544+
t.WriteSeqNo++
545+
if h == 1 {
546+
t.SnapshotSeqNo++
547+
}
548+
522549
cp := &t.stages[h-1]
523550
if !t.vlogInvalid {
524551
curr := t.checkpoint()
@@ -542,6 +569,8 @@ func (t *ART) Reset() {
542569
t.allocator.nodeAllocator.Reset()
543570
t.allocator.vlogAllocator.Reset()
544571
t.lastTraversedNode.Store(arena.NullU64Addr)
572+
t.SnapshotSeqNo++
573+
t.WriteSeqNo++
545574
}
546575

547576
// DiscardValues releases the memory used by all values.

internal/unionstore/art/art_iterator.go

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import (
2020
"sort"
2121

2222
"github.com/pkg/errors"
23+
"github.com/tikv/client-go/v2/internal/logutil"
2324
"github.com/tikv/client-go/v2/internal/unionstore/arena"
2425
"github.com/tikv/client-go/v2/kv"
26+
"go.uber.org/zap"
2527
)
2628

2729
func (t *ART) Iter(lowerBound, upperBound []byte) (*Iterator, error) {
@@ -56,6 +58,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (*
5658
// this prevents the initial value of currAddr from being equal to endAddr.
5759
currAddr: arena.BadAddr,
5860
endAddr: arena.NullAddr,
61+
seqNo: t.WriteSeqNo,
5962
}
6063
it.init(lowerBound, upperBound)
6164
if !it.valid {
@@ -76,12 +79,41 @@ type Iterator struct {
7679
currLeaf *artLeaf
7780
currAddr arena.MemdbArenaAddr
7881
endAddr arena.MemdbArenaAddr
82+
83+
// the iterator is valid only when seqNo == art.WriteSeqNo.
84+
seqNo int
85+
// ignoreSeqNo disables the seqNo check; it is used for snapshot iterators until they are fully deprecated.
86+
ignoreSeqNo bool
87+
}
88+
89+
// checkSeqNo verifies that the tree has not been written to since this
// iterator was created. When it.seqNo differs from it.tree.WriteSeqNo and
// ignoreSeqNo is false, it reports "seqNo mismatch" through
// logutil.BgLogger().Panic, attaching both sequence numbers and a stack trace
// (presumably a zap logger, whose Panic logs and then panics — confirm against logutil).
func (it *Iterator) checkSeqNo() {
90+
if it.seqNo != it.tree.WriteSeqNo && !it.ignoreSeqNo {
91+
logutil.BgLogger().Panic(
92+
"seqNo mismatch",
93+
zap.Int("it seqNo", it.seqNo),
94+
zap.Int("art seqNo", it.tree.WriteSeqNo),
95+
zap.Stack("stack"),
96+
)
97+
}
98+
}
99+
100+
// Valid runs the seqNo check and then returns the iterator's valid flag.
func (it *Iterator) Valid() bool {
101+
it.checkSeqNo()
102+
return it.valid
103+
}
104+
105+
// Key runs the seqNo check and then returns the key of the current leaf.
func (it *Iterator) Key() []byte {
106+
it.checkSeqNo()
107+
return it.currLeaf.GetKey()
108+
}
109+
110+
// Flags runs the seqNo check and then returns the key flags of the current leaf.
func (it *Iterator) Flags() kv.KeyFlags {
111+
it.checkSeqNo()
112+
return it.currLeaf.GetKeyFlags()
79113
}
80114

81-
func (it *Iterator) Valid() bool { return it.valid }
82-
func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() }
83-
func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() }
84115
func (it *Iterator) Value() []byte {
116+
it.checkSeqNo()
85117
if it.currLeaf.vLogAddr.IsNull() {
86118
return nil
87119
}
@@ -102,6 +134,7 @@ func (it *Iterator) Next() error {
102134
// iterate is finished
103135
return errors.New("Art: iterator is finished")
104136
}
137+
it.checkSeqNo()
105138
if it.currAddr == it.endAddr {
106139
it.valid = false
107140
return nil

internal/unionstore/art/art_snapshot.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"github.com/tikv/client-go/v2/internal/unionstore/arena"
2222
)
2323

24+
// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
25+
// of stage[0]
2426
func (t *ART) getSnapshot() arena.MemDBCheckpoint {
2527
if len(t.stages) > 0 {
2628
return t.stages[0]
@@ -49,6 +51,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
4951
if err != nil {
5052
panic(err)
5153
}
54+
inner.ignoreSeqNo = true
5255
it := &SnapIter{
5356
Iterator: inner,
5457
cp: t.getSnapshot(),

0 commit comments

Comments
 (0)