From 1a6d10078d75b0581c96c5bde5dbd4668dcaf438 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 14:58:12 +0800 Subject: [PATCH 01/14] memdb: prevent iterator invalidation Signed-off-by: ekexium --- examples/gcworker/go.mod | 4 +- examples/rawkv/go.mod | 4 +- examples/txnkv/1pc_txn/go.mod | 4 +- examples/txnkv/async_commit/go.mod | 4 +- examples/txnkv/delete_range/go.mod | 4 +- examples/txnkv/go.mod | 4 +- examples/txnkv/pessimistic_txn/go.mod | 4 +- examples/txnkv/unsafedestoryrange/go.mod | 4 +- internal/unionstore/arena/arena.go | 16 ++ internal/unionstore/art/art.go | 21 +++ internal/unionstore/art/art_iterator.go | 9 +- internal/unionstore/memdb_art.go | 159 ++++++++++++++++ internal/unionstore/memdb_bench_test.go | 80 +++++++- internal/unionstore/memdb_rbt.go | 35 ++++ internal/unionstore/memdb_test.go | 229 ++++++++++++++++++++++- internal/unionstore/pipelined_memdb.go | 8 + internal/unionstore/union_store.go | 31 ++- tikv/unionstore_export.go | 2 + 18 files changed, 600 insertions(+), 22 deletions(-) diff --git a/examples/gcworker/go.mod b/examples/gcworker/go.mod index 7af1d6b76e..81562a4bdd 100644 --- a/examples/gcworker/go.mod +++ b/examples/gcworker/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/rawkv/go.mod b/examples/rawkv/go.mod index b779427b3c..041c3cb224 100644 --- a/examples/rawkv/go.mod +++ b/examples/rawkv/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git 
a/examples/txnkv/1pc_txn/go.mod b/examples/txnkv/1pc_txn/go.mod index 4926479a1b..2fe90783db 100644 --- a/examples/txnkv/1pc_txn/go.mod +++ b/examples/txnkv/1pc_txn/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/async_commit/go.mod b/examples/txnkv/async_commit/go.mod index 066120e1ec..236e443831 100644 --- a/examples/txnkv/async_commit/go.mod +++ b/examples/txnkv/async_commit/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/delete_range/go.mod b/examples/txnkv/delete_range/go.mod index 2f9d244b9b..f599c50606 100644 --- a/examples/txnkv/delete_range/go.mod +++ b/examples/txnkv/delete_range/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - 
github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/go.mod b/examples/txnkv/go.mod index 5a23a1978e..06bf6c7f53 100644 --- a/examples/txnkv/go.mod +++ b/examples/txnkv/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/pessimistic_txn/go.mod b/examples/txnkv/pessimistic_txn/go.mod index 016964ad93..da6c997c67 100644 --- a/examples/txnkv/pessimistic_txn/go.mod +++ b/examples/txnkv/pessimistic_txn/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/unsafedestoryrange/go.mod b/examples/txnkv/unsafedestoryrange/go.mod index 7f5c8d11ea..91a04ec8cc 100644 --- a/examples/txnkv/unsafedestoryrange/go.mod +++ b/examples/txnkv/unsafedestoryrange/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log 
v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/internal/unionstore/arena/arena.go b/internal/unionstore/arena/arena.go index ef5c081921..9671ccb492 100644 --- a/internal/unionstore/arena/arena.go +++ b/internal/unionstore/arena/arena.go @@ -38,6 +38,9 @@ import ( "encoding/binary" "math" + "github.com/tikv/client-go/v2/internal/logutil" + "go.uber.org/zap" + "github.com/tikv/client-go/v2/kv" "go.uber.org/atomic" ) @@ -223,6 +226,19 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool { return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock } +func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool { + if cp == nil || cp2 == nil { + logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", cp2)) + } + if cp.blocks < cp2.blocks { + return true + } + if cp.blocks == cp2.blocks && cp.offsetInBlock < cp2.offsetInBlock { + return true + } + return false +} + func (a *MemdbArena) Checkpoint() MemDBCheckpoint { snap := MemDBCheckpoint{ blockSize: a.blockSize, diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index 36acb907a5..ee98bf1e64 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -52,6 +52,13 @@ type ART struct { lastTraversedNode atomic.Uint64 hitCount atomic.Uint64 missCount atomic.Uint64 + + // The counter of every write operation, used to invalidate iterators that were created before the write operation. + SeqNo int + // increased by 1 when an operation that may affect the content returned by "snapshot iter" (i.e. stage[0]) happens. + // It's used to invalidate snapshot iterators. 
+ // invariant: no concurrent access to it + SnapshotSeqNo int } func New() *ART { @@ -115,6 +122,7 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { } } + t.SeqNo++ if len(t.stages) == 0 { t.dirty = true } @@ -479,6 +487,10 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) { t.allocator.vlogAllocator.RevertToCheckpoint(t, cp) t.allocator.vlogAllocator.Truncate(cp) t.allocator.vlogAllocator.OnMemChange() + t.SeqNo++ + if len(t.stages) == 0 || t.stages[0].LessThan(cp) { + t.SnapshotSeqNo++ + } } func (t *ART) Stages() []arena.MemDBCheckpoint { @@ -498,7 +510,9 @@ func (t *ART) Release(h int) { if h != len(t.stages) { panic("cannot release staging buffer") } + t.SeqNo++ if h == 1 { + t.SnapshotSeqNo++ tail := t.checkpoint() if !t.stages[0].IsSamePosition(&tail) { t.dirty = true @@ -519,6 +533,11 @@ func (t *ART) Cleanup(h int) { panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages))) } + t.SeqNo++ + if h == 1 { + t.SnapshotSeqNo++ + } + cp := &t.stages[h-1] if !t.vlogInvalid { curr := t.checkpoint() @@ -542,6 +561,8 @@ func (t *ART) Reset() { t.allocator.nodeAllocator.Reset() t.allocator.vlogAllocator.Reset() t.lastTraversedNode.Store(arena.NullU64Addr) + t.SnapshotSeqNo++ + t.SeqNo++ } // DiscardValues releases the memory used by all values. diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index 2bf4fdba64..ba7ff0b855 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -56,6 +56,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (* // this avoids the initial value of currAddr equals to endAddr. currAddr: arena.BadAddr, endAddr: arena.NullAddr, + seqNo: t.SeqNo, } it.init(lowerBound, upperBound) if !it.valid { @@ -76,9 +77,12 @@ type Iterator struct { currLeaf *artLeaf currAddr arena.MemdbArenaAddr endAddr arena.MemdbArenaAddr + + // only when seqNo == art.seqNo, the iterator is valid. 
+ seqNo int } -func (it *Iterator) Valid() bool { return it.valid } +func (it *Iterator) Valid() bool { return it.valid && it.seqNo == it.tree.SeqNo } func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } func (it *Iterator) Value() []byte { @@ -102,6 +106,9 @@ func (it *Iterator) Next() error { // iterate is finished return errors.New("Art: iterator is finished") } + if it.seqNo != it.tree.SeqNo { + return errors.New(fmt.Sprintf("seqNo mismatch: iter=%d, art=%d", it.seqNo, it.tree.SeqNo)) + } if it.currAddr == it.endAddr { it.valid = false return nil diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index c7c1b21d98..ebd63edce8 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -16,8 +16,11 @@ package unionstore import ( "context" + "fmt" "sync" + "github.com/pingcap/errors" + tikverr "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/internal/unionstore/arena" "github.com/tikv/client-go/v2/internal/unionstore/art" @@ -151,6 +154,32 @@ func (db *artDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) { return db.ART.IterReverse(upper, lower) } +func (db *artDBWithContext) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error { + db.RLock() + defer db.RUnlock() + var iter Iterator + if reverse { + iter = db.SnapshotIterReverse(upper, lower) + } else { + iter = db.SnapshotIter(lower, upper) + } + defer iter.Close() + for iter.Valid() { + stop, err := f(iter.Key(), iter.Value()) + if err != nil { + return err + } + err = iter.Next() + if err != nil { + return err + } + if stop { + break + } + } + return nil +} + // SnapshotIter returns an Iterator for a snapshot of MemBuffer. 
func (db *artDBWithContext) SnapshotIter(lower, upper []byte) Iterator { return db.ART.SnapshotIter(lower, upper) @@ -165,3 +194,133 @@ func (db *artDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator { func (db *artDBWithContext) SnapshotGetter() Getter { return db.ART.SnapshotGetter() } + +type snapshotBatchedIter struct { + db *artDBWithContext + snapshotTruncateSeqNo int + lower []byte + upper []byte + reverse bool + + // current batch + kvs []KvPair + pos int + batchSize int + nextKey []byte +} + +func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + iter := &snapshotBatchedIter{ + db: db, + snapshotTruncateSeqNo: db.SnapshotSeqNo, + lower: lower, + upper: upper, + reverse: reverse, + batchSize: 4, + } + + // Position at first key immediately + iter.fillBatch() + return iter +} + +func (it *snapshotBatchedIter) fillBatch() error { + if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { + return errors.New(fmt.Sprintf("invalid iter: truncation happened, iter's=%d, db's=%d", + it.snapshotTruncateSeqNo, it.db.SnapshotSeqNo)) + } + + it.db.RLock() + defer it.db.RUnlock() + + if it.kvs == nil { + it.kvs = make([]KvPair, 0, it.batchSize) + } else { + it.kvs = it.kvs[:0] + } + + var snapshotIter Iterator + if it.reverse { + searchUpper := it.upper + if it.nextKey != nil { + searchUpper = it.nextKey + } + snapshotIter = it.db.SnapshotIterReverse(searchUpper, it.lower) + } else { + searchLower := it.lower + if it.nextKey != nil { + searchLower = it.nextKey + } + snapshotIter = it.db.SnapshotIter(searchLower, it.upper) + } + defer snapshotIter.Close() + + // fill current batch + for i := 0; i < it.batchSize && snapshotIter.Valid(); i++ { + it.kvs = append(it.kvs, KvPair{ + Key: snapshotIter.Key(), + Value: snapshotIter.Value(), + }) + if err := snapshotIter.Next(); err != nil { + return err + } + } + + // update state + it.pos = 0 + if len(it.kvs) > 0 { + lastKV := it.kvs[len(it.kvs)-1] + if it.reverse { + it.nextKey = append([]byte(nil), lastKV.Key...) 
+ } else { + it.nextKey = append(append([]byte(nil), lastKV.Key...), 0) + } + } else { + it.nextKey = nil + } + + it.batchSize = min(it.batchSize*2, 4096) + return nil +} + +func (it *snapshotBatchedIter) Valid() bool { + return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo && + it.pos < len(it.kvs) +} + +func (it *snapshotBatchedIter) Next() error { + if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { + return errors.New( + fmt.Sprintf( + "invalid snapshotBatchedIter: truncation happened, iter's=%d, db's=%d", + it.snapshotTruncateSeqNo, + it.db.SnapshotSeqNo, + ), + ) + } + + it.pos++ + if it.pos >= len(it.kvs) { + return it.fillBatch() + } + return nil +} + +func (it *snapshotBatchedIter) Key() []byte { + if !it.Valid() { + return nil + } + return it.kvs[it.pos].Key +} + +func (it *snapshotBatchedIter) Value() []byte { + if !it.Valid() { + return nil + } + return it.kvs[it.pos].Value +} + +func (it *snapshotBatchedIter) Close() { + it.kvs = nil + it.nextKey = nil +} diff --git a/internal/unionstore/memdb_bench_test.go b/internal/unionstore/memdb_bench_test.go index 8a2c3e5d4b..07fa6d3ce9 100644 --- a/internal/unionstore/memdb_bench_test.go +++ b/internal/unionstore/memdb_bench_test.go @@ -172,13 +172,36 @@ func BenchmarkMemDbBufferRandom(b *testing.B) { } func BenchmarkMemDbIter(b *testing.B) { - fn := func(b *testing.B, buffer MemBuffer) { + fnIter := func(b *testing.B, buffer MemBuffer) { benchIterator(b, buffer) b.ReportAllocs() } - b.Run("RBT", func(b *testing.B) { fn(b, newRbtDBWithContext()) }) - b.Run("ART", func(b *testing.B) { fn(b, newArtDBWithContext()) }) + b.Run("RBT", func(b *testing.B) { fnIter(b, newRbtDBWithContext()) }) + b.Run("ART", func(b *testing.B) { fnIter(b, newArtDBWithContext()) }) +} + +func BenchmarkSnapshotIter(b *testing.B) { + f := func(b *testing.B, buffer MemBuffer) { + benchSnapshotIter(b, buffer) + b.ReportAllocs() + } + + fBatched := func(b *testing.B, buffer MemBuffer) { + benchBatchedSnapshotIter(b, buffer) + b.ReportAllocs() + } + + fForEach := func(b *testing.B, buffer MemBuffer) { + benchForEachInSnapshot(b, buffer) + b.ReportAllocs() + } + + b.Run("RBT-SnapshotIter", func(b *testing.B) { f(b, newRbtDBWithContext()) }) + // unimplemented for RBT + b.Run("ART-SnapshotIter", func(b *testing.B) { f(b, newArtDBWithContext()) }) + b.Run("ART-BatchedSnapshotIter", func(b *testing.B) { fBatched(b, newArtDBWithContext()) }) + b.Run("ART-ForEachInSnapshot", func(b *testing.B) { fForEach(b, newArtDBWithContext()) }) } func BenchmarkMemDbCreation(b *testing.B) { @@ -224,6 +247,40 @@ func benchIterator(b *testing.B, buffer MemBuffer) { if err != nil { b.Error(err) } + for iter.Valid() { + _ = iter.Key() + _ = iter.Value() + iter.Next() + } + iter.Close() + } +} + +func benchSnapshotIter(b *testing.B, buffer MemBuffer) { + for k := 0; k < opCnt; k++ { + buffer.Set(encodeInt(k), encodeInt(k)) + } + buffer.Staging() + b.ResetTimer() + for i := 0; i < b.N; i++ { + iter := buffer.SnapshotIter(nil, nil) + for iter.Valid() { + _ = iter.Value() + _ = iter.Key() + iter.Next() + } + iter.Close() + } +} + +func benchBatchedSnapshotIter(b *testing.B, buffer MemBuffer) { + for k := 0; k < opCnt; k++ { + buffer.Set(encodeInt(k), encodeInt(k)) + } + buffer.Staging() + b.ResetTimer() + for i := 0; i < b.N; i++ { + iter := buffer.BatchedSnapshotIter(nil, nil, false) for iter.Valid() { iter.Next() } @@ -231,6 +288,23 @@ func benchIterator(b *testing.B, buffer MemBuffer) { } } +func benchForEachInSnapshot(b *testing.B, buffer MemBuffer) { + for k := 0; k < opCnt; k++ { + 
buffer.Set(encodeInt(k), encodeInt(k)) + } + buffer.Staging() + b.ResetTimer() + f := func(key, value []byte) (bool, error) { + return false, nil + } + for i := 0; i < b.N; i++ { + err := buffer.ForEachInSnapshotRange(nil, nil, f, false) + if err != nil { + b.Error(err) + } + } +} + func BenchmarkMemBufferCache(b *testing.B) { fn := func(b *testing.B, buffer MemBuffer) { buf := make([][keySize]byte, b.N) diff --git a/internal/unionstore/memdb_rbt.go b/internal/unionstore/memdb_rbt.go index c805f49935..f45f941c46 100644 --- a/internal/unionstore/memdb_rbt.go +++ b/internal/unionstore/memdb_rbt.go @@ -161,6 +161,32 @@ func (db *rbtDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) { return db.RBT.IterReverse(upper, lower) } +func (db *rbtDBWithContext) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error { + db.RLock() + defer db.RUnlock() + var iter Iterator + if reverse { + iter = db.SnapshotIterReverse(upper, lower) + } else { + iter = db.SnapshotIter(lower, upper) + } + defer iter.Close() + for iter.Valid() { + stop, err := f(iter.Key(), iter.Value()) + if err != nil { + return err + } + err = iter.Next() + if err != nil { + return err + } + if stop { + break + } + } + return nil +} + // SnapshotIter returns an Iterator for a snapshot of MemBuffer. func (db *rbtDBWithContext) SnapshotIter(lower, upper []byte) Iterator { return db.RBT.SnapshotIter(lower, upper) @@ -175,3 +201,12 @@ func (db *rbtDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator { func (db *rbtDBWithContext) SnapshotGetter() Getter { return db.RBT.SnapshotGetter() } + +func (db *rbtDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + // TODO: implement this + if reverse { + return db.SnapshotIterReverse(upper, lower) + } else { + return db.SnapshotIter(lower, upper) + } +} diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 4721f837e1..2481fb3373 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -1336,7 +1336,7 @@ func TestSnapshotReaderWithWrite(t *testing.T) { h := db.Staging() defer db.Release(h) - iter := db.SnapshotIter([]byte{0, 0}, []byte{0, 255}) + iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false) assert.Equal(t, iter.Key(), []byte{0, 0}) db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) // ART: node4/node16/node48 is freed and wait to be reused. 
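The ForEachInSnapshotRange method added above (to both the ART and RBT backends) is the closed, callback-driven counterpart of the batched iterator: it holds the RWLock for the whole scan and hands every pair to the callback. A usage sketch follows; it is not part of the patch, assumes a MemBuffer value named buf, and copies keys because the callback's slices are only valid during the scan:

// Sketch: collect every snapshot key in [lower, upper) into fresh slices.
func collectSnapshotKeys(buf MemBuffer, lower, upper []byte) ([][]byte, error) {
	var keys [][]byte
	err := buf.ForEachInSnapshotRange(lower, upper, func(k, v []byte) (stop bool, err error) {
		// k and v alias memdb-owned memory and are only valid inside the
		// callback, so copy before retaining.
		keys = append(keys, append([]byte(nil), k...))
		return false, nil // false keeps scanning; true would stop early
	}, false) // reverse=false: ascending key order
	return keys, err
}

Returning a non-nil error from the callback aborts the scan and propagates the error to the caller, exactly as the loop in the implementations above does.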
@@ -1364,3 +1364,230 @@ func TestSnapshotReaderWithWrite(t *testing.T) { check(newRbtDBWithContext(), 48) check(newArtDBWithContext(), 48) } + +func TestBatchedSnapshotIter(t *testing.T) { + check := func(db *artDBWithContext, num int) { + // Insert test data + for i := 0; i < num; i++ { + db.Set([]byte{0, byte(i)}, []byte{0, byte(i)}) + } + h := db.Staging() + defer db.Release(h) + + // Create iterator - should be positioned at first key + iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false) + defer iter.Close() + + // Should be able to read first key immediately + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, 0}, iter.Key()) + + // Write additional data + db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) + for i := 0; i < num; i++ { + db.Set([]byte{1, byte(i)}, []byte{1, byte(i)}) + } + + // Verify iteration + i := 0 + for ; i < num; i++ { + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, byte(i)}, iter.Key()) + require.Equal(t, []byte{0, byte(i)}, iter.Value()) + require.NoError(t, iter.Next()) + } + require.False(t, iter.Valid()) + } + + checkReverse := func(db *artDBWithContext, num int) { + for i := 0; i < num; i++ { + db.Set([]byte{0, byte(i)}, []byte{0, byte(i)}) + } + h := db.Staging() + defer db.Release(h) + + iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, true) + defer iter.Close() + + // Should be positioned at last key + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, byte(num - 1)}, iter.Key()) + + db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) + for i := 0; i < num; i++ { + db.Set([]byte{1, byte(i)}, []byte{1, byte(i)}) + } + + i := num - 1 + for ; i >= 0; i-- { + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, byte(i)}, iter.Key()) + require.Equal(t, []byte{0, byte(i)}, iter.Value()) + require.NoError(t, iter.Next()) + } + require.False(t, iter.Valid()) + } + + // Run size test cases + check(newArtDBWithContext(), 3) + check(newArtDBWithContext(), 17) + check(newArtDBWithContext(), 64) + + checkReverse(newArtDBWithContext(), 3) + checkReverse(newArtDBWithContext(), 17) + checkReverse(newArtDBWithContext(), 64) +} + +func TestBatchedSnapshotIterEdgeCase(t *testing.T) { + t.Run("EdgeCases", func(t *testing.T) { + db := newArtDBWithContext() + + // invalid range - should be invalid immediately + iter := db.BatchedSnapshotIter([]byte{1}, []byte{1}, false) + require.False(t, iter.Valid()) + iter.Close() + + // empty range - should be invalid immediately + iter = db.BatchedSnapshotIter([]byte{0}, []byte{1}, false) + require.False(t, iter.Valid()) + iter.Close() + + // Single element range + db.Set([]byte{1}, []byte{1}) + iter = db.BatchedSnapshotIter([]byte{1}, []byte{2}, false) + require.True(t, iter.Valid()) + require.Equal(t, []byte{1}, iter.Key()) + require.NoError(t, iter.Next()) + require.False(t, iter.Valid()) + iter.Close() + + // Multiple elements + db.Set([]byte{2}, []byte{2}) + db.Set([]byte{3}, []byte{3}) + db.Set([]byte{4}, []byte{4}) + + // Forward iteration [2,4) + iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false) + vals := []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[0]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{2, 3}, vals) + iter.Close() + + // Reverse iteration [2,4) + iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, true) + vals = []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[0]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{3, 2}, vals) + iter.Close() + }) + + t.Run("BoundaryTests", func(t 
*testing.T) { + db := newArtDBWithContext() + keys := [][]byte{ + {1, 0}, {1, 2}, {1, 4}, {1, 6}, {1, 8}, + } + for _, k := range keys { + db.Set(k, k) + } + + // lower bound included + iter := db.BatchedSnapshotIter([]byte{1, 2}, []byte{1, 9}, false) + vals := []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[1]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{2, 4, 6, 8}, vals) + iter.Close() + + // upper bound excluded + iter = db.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, false) + vals = []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[1]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{0, 2, 4}, vals) + iter.Close() + + // reverse + iter = db.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, true) + vals = []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[1]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{4, 2, 0}, vals) + iter.Close() + }) + + t.Run("AlphabeticalOrder", func(t *testing.T) { + db := newArtDBWithContext() + keys := [][]byte{ + {2}, + {2, 1}, + {2, 1, 1}, + {2, 1, 1, 1}, + } + for _, k := range keys { + db.Set(k, k) + } + + // forward + iter := db.BatchedSnapshotIter([]byte{2}, []byte{3}, false) + count := 0 + for iter.Valid() { + require.Equal(t, keys[count], iter.Key()) + require.NoError(t, iter.Next()) + count++ + } + require.Equal(t, len(keys), count) + iter.Close() + + // reverse + iter = db.BatchedSnapshotIter([]byte{2}, []byte{3}, true) + count = len(keys) - 1 + for iter.Valid() { + require.Equal(t, keys[count], iter.Key()) + require.NoError(t, iter.Next()) + count-- + } + require.Equal(t, -1, count) + iter.Close() + }) + + t.Run("BatchSizeGrowth", func(t *testing.T) { + db := newArtDBWithContext() + for i := 0; i < 100; i++ { + db.Set([]byte{3, byte(i)}, []byte{3, byte(i)}) + } + + // forward + iter := db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, false) + count := 0 + for iter.Valid() { + require.Equal(t, []byte{3, byte(count)}, iter.Key()) + require.NoError(t, iter.Next()) + count++ + } + require.Equal(t, 100, count) + iter.Close() + + // reverse + iter = db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, true) + count = 99 + for iter.Valid() { + require.Equal(t, []byte{3, byte(count)}, iter.Key()) + require.NoError(t, iter.Next()) + count-- + } + require.Equal(t, -1, count) + iter.Close() + }) +} diff --git a/internal/unionstore/pipelined_memdb.go b/internal/unionstore/pipelined_memdb.go index 163a289f4c..888f2ecffb 100644 --- a/internal/unionstore/pipelined_memdb.go +++ b/internal/unionstore/pipelined_memdb.go @@ -412,6 +412,10 @@ func (p *PipelinedMemDB) IterReverse([]byte, []byte) (Iterator, error) { return nil, errors.New("pipelined memdb does not support IterReverse") } +func (db *PipelinedMemDB) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (bool, error), reverse bool) error { + return errors.New("pipelined memdb does not support ForEachInSnapshotRange") +} + // SetEntrySizeLimit sets the size limit for each entry and total buffer. 
func (p *PipelinedMemDB) SetEntrySizeLimit(entryLimit, _ uint64) {
 	p.entryLimit = entryLimit
@@ -550,3 +554,7 @@ func (p *PipelinedMemDB) GetMetrics() Metrics {
 func (p *PipelinedMemDB) MemHookSet() bool {
 	return p.memChangeHook != nil
 }
+
+func (p *PipelinedMemDB) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
+	panic("BatchedSnapshotIter is not supported for PipelinedMemDB")
+}
diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go
index 1a5f1a36b9..b6beaad010 100644
--- a/internal/unionstore/union_store.go
+++ b/internal/unionstore/union_store.go
@@ -162,6 +162,11 @@ func (us *KVUnionStore) SetEntrySizeLimit(entryLimit, bufferLimit uint64) {
 	us.memBuffer.SetEntrySizeLimit(entryLimit, bufferLimit)
 }
 
+type KvPair struct {
+	Key []byte
+	Value []byte
+}
+
 // MemBuffer is an interface that stores mutations that written during transaction execution.
 // It now unifies MemDB and PipelinedMemDB.
 // The implementations should follow the transaction guarantees:
@@ -193,15 +198,39 @@ type MemBuffer interface {
 	Delete([]byte) error
 	// DeleteWithFlags deletes the key k in the MemBuffer with flags.
 	DeleteWithFlags([]byte, ...kv.FlagsOp) error
+	// Iter implements the Retriever interface.
 	Iter([]byte, []byte) (Iterator, error)
 	// IterReverse implements the Retriever interface.
 	IterReverse([]byte, []byte) (Iterator, error)
 	// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
+	// Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead.
 	SnapshotIter([]byte, []byte) Iterator
 	// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
+	// Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead.
 	SnapshotIterReverse([]byte, []byte) Iterator
-	// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
+
+	// ForEachInSnapshotRange scans the key-value pairs in the stage[0] snapshot if it exists,
+	// otherwise it uses the current checkpoint as snapshot.
+	//
+	// NOTE: returned kv-pairs are only valid during the iteration. If you want to use them after the iteration,
+	// you need to make a copy.
+	//
+	// The method is protected by a RWLock to prevent potential iterator invalidation, i.e.
+	// you cannot modify the MemBuffer during the iteration.
+	//
+	// Use it when you need to scan the whole range, otherwise consider using BatchedSnapshotIter.
+	ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error
+
+	// BatchedSnapshotIter iterates in batches to prevent iterator invalidation:
+	// It does not save any iterator state, instead it copies the keys and values to a buffer.
+	// It behaves like SnapshotIter, but it is safe to use the returned keys and values after the iteration.
+	// Use it when you need on-demand "next", otherwise consider using ForEachInSnapshotRange.
+	//
+	// The iterator becomes invalid after a membuffer vlog truncation operation.
+	BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator
+
+	// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
 	SnapshotGetter() Getter
 	// InspectStage iterates all buffered keys and values in MemBuffer.
 	InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte))
diff --git a/tikv/unionstore_export.go b/tikv/unionstore_export.go
index 80ee88f42a..efbaf93f2c 100644
--- a/tikv/unionstore_export.go
+++ b/tikv/unionstore_export.go
@@ -60,3 +60,5 @@ type MemDBCheckpoint = unionstore.MemDBCheckpoint
 
 // Metrics is the metrics of unionstore.
type Metrics = unionstore.Metrics + +type KvPair = unionstore.KvPair From c46d4903553c667c147888c30bcf809366f56f99 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 16:21:15 +0800 Subject: [PATCH 02/14] fix for snapshot iter Signed-off-by: ekexium --- internal/unionstore/art/art_iterator.go | 6 ++++-- internal/unionstore/art/art_snapshot.go | 1 + internal/unionstore/memdb_test.go | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index ba7ff0b855..05a23e62a4 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -80,9 +80,11 @@ type Iterator struct { // only when seqNo == art.seqNo, the iterator is valid. seqNo int + // ignoreSeqNo is used to ignore the seqNo check, used for snapshot iter before its full deprecation. + ignoreSeqNo bool } -func (it *Iterator) Valid() bool { return it.valid && it.seqNo == it.tree.SeqNo } +func (it *Iterator) Valid() bool { return it.valid && (it.seqNo == it.tree.SeqNo || it.ignoreSeqNo) } func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } func (it *Iterator) Value() []byte { @@ -106,7 +108,7 @@ func (it *Iterator) Next() error { // iterate is finished return errors.New("Art: iterator is finished") } - if it.seqNo != it.tree.SeqNo { + if !it.ignoreSeqNo && it.seqNo != it.tree.SeqNo { return errors.New(fmt.Sprintf("seqNo mismatch: iter=%d, art=%d", it.seqNo, it.tree.SeqNo)) } if it.currAddr == it.endAddr { diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go index 454634b234..6b240367ac 100644 --- a/internal/unionstore/art/art_snapshot.go +++ b/internal/unionstore/art/art_snapshot.go @@ -49,6 +49,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter { if err != nil { panic(err) } + inner.ignoreSeqNo = true it := &SnapIter{ Iterator: inner, cp: t.getSnapshot(), diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 2481fb3373..3261d12e1b 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -1336,7 +1336,7 @@ func TestSnapshotReaderWithWrite(t *testing.T) { h := db.Staging() defer db.Release(h) - iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false) + iter := db.SnapshotIter([]byte{0, 0}, []byte{0, 255}) assert.Equal(t, iter.Key(), []byte{0, 0}) db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) // ART: node4/node16/node48 is freed and wait to be reused. 
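Patches 01 and 02 together establish the invalidation protocol: every mutation bumps ART.SeqNo, every iterator captures the current value at creation, and Valid/Next require the two to still match, with snapshot iterators opting out via ignoreSeqNo. The protocol is small enough to model standalone; the toy sketch below is illustrative only (toyTree and toyIter are invented names, the real counterparts being ART.SeqNo and Iterator.seqNo):

package main

import "fmt"

// toyTree models the seqNo handshake: writes bump writeSeq, and an iterator
// is considered valid only while the value it captured still matches.
type toyTree struct {
	writeSeq int
	keys     []string
}

type toyIter struct {
	t           *toyTree
	seq         int
	pos         int
	ignoreSeqNo bool // snapshot iterators opt out, as in patch 02
}

func (t *toyTree) set(k string) {
	t.keys = append(t.keys, k)
	t.writeSeq++ // any write invalidates live iterators
}

func (t *toyTree) iter() *toyIter { return &toyIter{t: t, seq: t.writeSeq} }

func (it *toyIter) valid() bool {
	return (it.ignoreSeqNo || it.seq == it.t.writeSeq) && it.pos < len(it.t.keys)
}

func main() {
	tr := &toyTree{}
	tr.set("a")
	it := tr.iter()
	fmt.Println(it.valid()) // true: no writes since the iterator was created
	tr.set("b")             // a write after creation...
	fmt.Println(it.valid()) // false: the captured seq is stale
}

The same comparison is what patch 03 below hardens into a panic in checkSeqNo, so a stale iterator fails loudly instead of silently walking reused nodes.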
From 4c16a14e44ad1904520d24f5bd3bdc7d91e10179 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 19:30:05 +0800 Subject: [PATCH 03/14] fix initKeysAndMutations Signed-off-by: ekexium --- internal/unionstore/art/art_iterator.go | 23 +++++++++++++++++++---- txnkv/transaction/2pc.go | 7 ++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index 05a23e62a4..ee825f9dd6 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -19,6 +19,9 @@ import ( "fmt" "sort" + "github.com/tikv/client-go/v2/internal/logutil" + "go.uber.org/zap" + "github.com/pkg/errors" "github.com/tikv/client-go/v2/internal/unionstore/arena" "github.com/tikv/client-go/v2/kv" @@ -84,7 +87,21 @@ type Iterator struct { ignoreSeqNo bool } -func (it *Iterator) Valid() bool { return it.valid && (it.seqNo == it.tree.SeqNo || it.ignoreSeqNo) } +func (it *Iterator) checkSeqNo() { + if it.seqNo != it.tree.SeqNo && !it.ignoreSeqNo { + logutil.BgLogger().Panic( + "seqNo mismatch", + zap.Int("it seqNo", it.seqNo), + zap.Int("art seqNo", it.tree.SeqNo), + zap.Stack("stack"), + ) + } +} + +func (it *Iterator) Valid() bool { + it.checkSeqNo() + return it.valid +} func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } func (it *Iterator) Value() []byte { @@ -108,9 +125,7 @@ func (it *Iterator) Next() error { // iterate is finished return errors.New("Art: iterator is finished") } - if !it.ignoreSeqNo && it.seqNo != it.tree.SeqNo { - return errors.New(fmt.Sprintf("seqNo mismatch: iter=%d, art=%d", it.seqNo, it.tree.SeqNo)) - } + it.checkSeqNo() if it.currAddr == it.endAddr { it.valid = false return nil diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go index 6a739b3616..14307bf38e 100644 --- a/txnkv/transaction/2pc.go +++ b/txnkv/transaction/2pc.go @@ -559,6 +559,7 @@ func (c *twoPhaseCommitter) initKeysAndMutations(ctx context.Context) error { var err error var assertionError error + toUpdatePrewriteOnly := make([][]byte, 0) for it := memBuf.IterWithFlags(nil, nil); it.Valid(); err = it.Next() { _ = err key := it.Key() @@ -607,7 +608,7 @@ func (c *twoPhaseCommitter) initKeysAndMutations(ctx context.Context) error { // due to `Op_CheckNotExists` doesn't prewrite lock, so mark those keys should not be used in commit-phase. 
op = kvrpcpb.Op_CheckNotExists checkCnt++ - memBuf.UpdateFlags(key, kv.SetPrewriteOnly) + toUpdatePrewriteOnly = append(toUpdatePrewriteOnly, key) } else { if flags.HasNewlyInserted() { // The delete-your-write keys in pessimistic transactions, only lock needed keys and skip @@ -682,6 +683,10 @@ func (c *twoPhaseCommitter) initKeysAndMutations(ctx context.Context) error { } } + for _, key := range toUpdatePrewriteOnly { + memBuf.UpdateFlags(key, kv.SetPrewriteOnly) + } + if c.mutations.Len() == 0 { return nil } From e1a3b5aa286619b52831ac365025518b2836a2db Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 20:21:30 +0800 Subject: [PATCH 04/14] more checks Signed-off-by: ekexium --- internal/unionstore/art/art_iterator.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index ee825f9dd6..381d13753c 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -102,9 +102,19 @@ func (it *Iterator) Valid() bool { it.checkSeqNo() return it.valid } -func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } -func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } + +func (it *Iterator) Key() []byte { + it.checkSeqNo() + return it.currLeaf.GetKey() +} + +func (it *Iterator) Flags() kv.KeyFlags { + it.checkSeqNo() + return it.currLeaf.GetKeyFlags() +} + func (it *Iterator) Value() []byte { + it.checkSeqNo() if it.currLeaf.vLogAddr.IsNull() { return nil } From 29fc98ea5f26a83902c15bdadacbd293536dd161 Mon Sep 17 00:00:00 2001 From: ekexium Date: Fri, 24 Jan 2025 12:50:35 +0800 Subject: [PATCH 05/14] optimize batched iter Signed-off-by: ekexium --- internal/unionstore/memdb_art.go | 55 +++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index ebd63edce8..4598a99564 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -203,7 +203,8 @@ type snapshotBatchedIter struct { reverse bool // current batch - kvs []KvPair + keys [][]byte + values [][]byte pos int batchSize int nextKey []byte @@ -216,10 +217,9 @@ func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo lower: lower, upper: upper, reverse: reverse, - batchSize: 4, + batchSize: 32, } - // Position at first key immediately iter.fillBatch() return iter } @@ -233,10 +233,12 @@ func (it *snapshotBatchedIter) fillBatch() error { it.db.RLock() defer it.db.RUnlock() - if it.kvs == nil { - it.kvs = make([]KvPair, 0, it.batchSize) + if it.keys == nil || it.values == nil || cap(it.keys) < it.batchSize || cap(it.values) < it.batchSize { + it.keys = make([][]byte, 0, it.batchSize) + it.values = make([][]byte, 0, it.batchSize) } else { - it.kvs = it.kvs[:0] + it.keys = it.keys[:0] + it.values = it.values[:0] } var snapshotIter Iterator @@ -256,11 +258,12 @@ func (it *snapshotBatchedIter) fillBatch() error { defer snapshotIter.Close() // fill current batch + // Further optimization: let the underlying memdb support batch iter. 
for i := 0; i < it.batchSize && snapshotIter.Valid(); i++ { - it.kvs = append(it.kvs, KvPair{ - Key: snapshotIter.Key(), - Value: snapshotIter.Value(), - }) + it.keys = it.keys[:i+1] + it.values = it.values[:i+1] + it.keys[i] = snapshotIter.Key() + it.values[i] = snapshotIter.Value() if err := snapshotIter.Next(); err != nil { return err } @@ -268,12 +271,25 @@ func (it *snapshotBatchedIter) fillBatch() error { // update state it.pos = 0 - if len(it.kvs) > 0 { - lastKV := it.kvs[len(it.kvs)-1] + if len(it.keys) > 0 { + lastKey := it.keys[len(it.keys)-1] + keyLen := len(lastKey) + if it.reverse { - it.nextKey = append([]byte(nil), lastKV.Key...) + if cap(it.nextKey) >= keyLen { + it.nextKey = it.nextKey[:keyLen] + } else { + it.nextKey = make([]byte, keyLen) + } + copy(it.nextKey, lastKey) } else { - it.nextKey = append(append([]byte(nil), lastKV.Key...), 0) + if cap(it.nextKey) >= keyLen+1 { + it.nextKey = it.nextKey[:keyLen+1] + } else { + it.nextKey = make([]byte, keyLen+1) + } + copy(it.nextKey, lastKey) + it.nextKey[keyLen] = 0 } } else { it.nextKey = nil @@ -285,7 +301,7 @@ func (it *snapshotBatchedIter) fillBatch() error { func (it *snapshotBatchedIter) Valid() bool { return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo && - it.pos < len(it.kvs) + it.pos < len(it.keys) } func (it *snapshotBatchedIter) Next() error { @@ -300,7 +316,7 @@ func (it *snapshotBatchedIter) Next() error { } it.pos++ - if it.pos >= len(it.kvs) { + if it.pos >= len(it.keys) { return it.fillBatch() } return nil @@ -310,17 +326,18 @@ func (it *snapshotBatchedIter) Key() []byte { if !it.Valid() { return nil } - return it.kvs[it.pos].Key + return it.keys[it.pos] } func (it *snapshotBatchedIter) Value() []byte { if !it.Valid() { return nil } - return it.kvs[it.pos].Value + return it.values[it.pos] } func (it *snapshotBatchedIter) Close() { - it.kvs = nil + it.keys = nil + it.values = nil it.nextKey = nil } From e906cd537b58476bc4544c941a9c164124347875 Mon Sep 17 00:00:00 2001 From: ekexium Date: Fri, 24 Jan 2025 14:14:21 +0800 Subject: [PATCH 06/14] refine comment Signed-off-by: ekexium --- internal/unionstore/memdb_art.go | 5 +++++ internal/unionstore/union_store.go | 14 ++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index 4598a99564..306a4d6d34 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -19,6 +19,8 @@ import ( "fmt" "sync" + "github.com/tikv/client-go/v2/internal/logutil" + "github.com/pingcap/errors" tikverr "github.com/tikv/client-go/v2/error" @@ -211,6 +213,9 @@ type snapshotBatchedIter struct { } func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + if len(db.Stages()) == 0 { + logutil.BgLogger().Error("should not use BatchedSnapshotIter for a memdb without any staging buffer") + } iter := &snapshotBatchedIter{ db: db, snapshotTruncateSeqNo: db.SnapshotSeqNo, diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go index b6beaad010..6127101863 100644 --- a/internal/unionstore/union_store.go +++ b/internal/unionstore/union_store.go @@ -200,8 +200,12 @@ type MemBuffer interface { DeleteWithFlags([]byte, ...kv.FlagsOp) error // Iter implements the Retriever interface. + // Any write operation to the memdb invalidates this iterator immediately after its creation. + // Attempting to use such an invalidated iterator will result in a panic. 
 	Iter([]byte, []byte) (Iterator, error)
 	// IterReverse implements the Retriever interface.
+	// Any write operation to the memdb invalidates this iterator immediately after its creation.
+	// Attempting to use such an invalidated iterator will result in a panic.
 	IterReverse([]byte, []byte) (Iterator, error)
 	// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
 	// Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead.
 	SnapshotIter([]byte, []byte) Iterator
 	// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
 	// Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead.
 	SnapshotIterReverse([]byte, []byte) Iterator
@@ -222,12 +226,14 @@ type MemBuffer interface {
 	// Use it when you need to scan the whole range, otherwise consider using BatchedSnapshotIter.
 	ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error
 
-	// BatchedSnapshotIter iterates in batches to prevent iterator invalidation:
-	// It does not save any iterator state, instead it copies the keys and values to a buffer.
-	// It behaves like SnapshotIter, but it is safe to use the returned keys and values after the iteration.
+	// BatchedSnapshotIter returns an iterator of the "snapshot", namely stage[0].
+	// It iterates in batches and prevents iterator invalidation.
+	//
 	// Use it when you need on-demand "next", otherwise consider using ForEachInSnapshotRange.
+	// NOTE: you should never use it when there are no stages.
+	//
-	// The iterator becomes invalid after a membuffer vlog truncation operation.
+	// The iterator becomes invalid when any operation that may modify the "snapshot" happens,
+	// e.g. RevertToCheckpoint or releasing stage[0].
 	BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator
 
 	// SnapshotGetter returns a Getter for a snapshot of MemBuffer.

From e672ef4789169ea3cf7410ce277f15ca3ec33df4 Mon Sep 17 00:00:00 2001
From: ekexium
Date: Mon, 10 Feb 2025 14:36:58 +0800
Subject: [PATCH 07/14] minor improvements from comments

Signed-off-by: ekexium
---
 internal/unionstore/arena/arena.go      |  1 +
 internal/unionstore/art/art.go          | 14 +++++++-------
 internal/unionstore/art/art_iterator.go |  6 +++---
 internal/unionstore/art/art_snapshot.go |  8 +++++---
 internal/unionstore/memdb_art.go        | 25 ++++++++++++++++++++++---
 5 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/internal/unionstore/arena/arena.go b/internal/unionstore/arena/arena.go
index 9671ccb492..f5e7b83855 100644
--- a/internal/unionstore/arena/arena.go
+++ b/internal/unionstore/arena/arena.go
@@ -226,6 +226,7 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool {
 	return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock
 }
 
+// LessThan compares two checkpoints.
 func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool {
 	if cp == nil || cp2 == nil {
 		logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", cp2))
 	}
diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go
index ee98bf1e64..36513663e3 100644
--- a/internal/unionstore/art/art.go
+++ b/internal/unionstore/art/art.go
@@ -54,8 +54,8 @@ type ART struct {
 	missCount atomic.Uint64
 
 	// The counter of every write operation, used to invalidate iterators that were created before the write operation.
-	SeqNo int
-	// increased by 1 when an operation that may affect the content returned by "snapshot iter" (i.e. stage[0]) happens.
+	WriteSeqNo int
+	// Increased by 1 when an operation that may affect the content returned by "snapshot iter" (i.e. stage[0]) happens.
 	// It's used to invalidate snapshot iterators.
// invariant: no concurrent access to it SnapshotSeqNo int @@ -122,7 +122,7 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { } } - t.SeqNo++ + t.WriteSeqNo++ if len(t.stages) == 0 { t.dirty = true } @@ -487,7 +487,7 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) { t.allocator.vlogAllocator.RevertToCheckpoint(t, cp) t.allocator.vlogAllocator.Truncate(cp) t.allocator.vlogAllocator.OnMemChange() - t.SeqNo++ + t.WriteSeqNo++ if len(t.stages) == 0 || t.stages[0].LessThan(cp) { t.SnapshotSeqNo++ } @@ -510,7 +510,7 @@ func (t *ART) Release(h int) { if h != len(t.stages) { panic("cannot release staging buffer") } - t.SeqNo++ + t.WriteSeqNo++ if h == 1 { t.SnapshotSeqNo++ tail := t.checkpoint() @@ -533,7 +533,7 @@ func (t *ART) Cleanup(h int) { panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages))) } - t.SeqNo++ + t.WriteSeqNo++ if h == 1 { t.SnapshotSeqNo++ } @@ -562,7 +562,7 @@ func (t *ART) Reset() { t.allocator.vlogAllocator.Reset() t.lastTraversedNode.Store(arena.NullU64Addr) t.SnapshotSeqNo++ - t.SeqNo++ + t.WriteSeqNo++ } // DiscardValues releases the memory used by all values. diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index 381d13753c..d0ea980bc3 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -59,7 +59,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (* // this avoids the initial value of currAddr equals to endAddr. currAddr: arena.BadAddr, endAddr: arena.NullAddr, - seqNo: t.SeqNo, + seqNo: t.WriteSeqNo, } it.init(lowerBound, upperBound) if !it.valid { @@ -88,11 +88,11 @@ type Iterator struct { } func (it *Iterator) checkSeqNo() { - if it.seqNo != it.tree.SeqNo && !it.ignoreSeqNo { + if it.seqNo != it.tree.WriteSeqNo && !it.ignoreSeqNo { logutil.BgLogger().Panic( "seqNo mismatch", zap.Int("it seqNo", it.seqNo), - zap.Int("art seqNo", it.tree.SeqNo), + zap.Int("art seqNo", it.tree.WriteSeqNo), zap.Stack("stack"), ) } diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go index 6b240367ac..89a9fa1993 100644 --- a/internal/unionstore/art/art_snapshot.go +++ b/internal/unionstore/art/art_snapshot.go @@ -21,7 +21,9 @@ import ( "github.com/tikv/client-go/v2/internal/unionstore/arena" ) -func (t *ART) getSnapshot() arena.MemDBCheckpoint { +// GetSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot +// of stage[0] +func (t *ART) GetSnapshot() arena.MemDBCheckpoint { if len(t.stages) > 0 { return t.stages[0] } @@ -32,7 +34,7 @@ func (t *ART) getSnapshot() arena.MemDBCheckpoint { func (t *ART) SnapshotGetter() *SnapGetter { return &SnapGetter{ tree: t, - cp: t.getSnapshot(), + cp: t.GetSnapshot(), } } @@ -52,7 +54,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter { inner.ignoreSeqNo = true it := &SnapIter{ Iterator: inner, - cp: t.getSnapshot(), + cp: t.GetSnapshot(), } it.tree.allocator.snapshotInc() for !it.setValue() && it.Valid() { diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index 306a4d6d34..144c75b8f4 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -17,6 +17,7 @@ package unionstore import ( "context" "fmt" + "go.uber.org/zap" "sync" "github.com/tikv/client-go/v2/internal/logutil" @@ -210,6 +211,9 @@ type snapshotBatchedIter struct { pos int batchSize int nextKey []byte + + // only 
used to check if the snapshot ever changes between batches. It is not supposed to change. + snapshot MemDBCheckpoint } func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { @@ -225,14 +229,29 @@ func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo batchSize: 32, } - iter.fillBatch() + iter.snapshot = db.GetSnapshot() + err := iter.fillBatch() + if err != nil { + logutil.BgLogger().Error("failed to fill batch for snapshotBatchedIter", zap.Error(err)) + } return iter } func (it *snapshotBatchedIter) fillBatch() error { if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { - return errors.New(fmt.Sprintf("invalid iter: truncation happened, iter's=%d, db's=%d", - it.snapshotTruncateSeqNo, it.db.SnapshotSeqNo)) + return errors.Errorf( + "invalid iter: truncation happened, iter's=%d, db's=%d", + it.snapshotTruncateSeqNo, + it.db.SnapshotSeqNo, + ) + } + + if it.db.GetSnapshot() != it.snapshot { + return errors.Errorf( + "snapshot changed between batches, expected=%v, actual=%v", + it.snapshot, + it.db.GetSnapshot(), + ) } it.db.RLock() From 941d94b94335b9b0f67feda42590e9d3444bfd87 Mon Sep 17 00:00:00 2001 From: ekexium Date: Mon, 10 Feb 2025 14:54:53 +0800 Subject: [PATCH 08/14] fix: return error if the first fillBatch fails Signed-off-by: ekexium --- internal/unionstore/arena/arena.go | 3 +-- internal/unionstore/memdb_art.go | 13 +++++++------ internal/unionstore/union_store.go | 5 ----- tikv/unionstore_export.go | 2 -- 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/internal/unionstore/arena/arena.go b/internal/unionstore/arena/arena.go index f5e7b83855..6ae3e890d9 100644 --- a/internal/unionstore/arena/arena.go +++ b/internal/unionstore/arena/arena.go @@ -39,10 +39,9 @@ import ( "math" "github.com/tikv/client-go/v2/internal/logutil" - "go.uber.org/zap" - "github.com/tikv/client-go/v2/kv" "go.uber.org/atomic" + "go.uber.org/zap" ) const ( diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index 144c75b8f4..0f0f17338b 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -17,7 +17,6 @@ package unionstore import ( "context" "fmt" - "go.uber.org/zap" "sync" "github.com/tikv/client-go/v2/internal/logutil" @@ -204,6 +203,7 @@ type snapshotBatchedIter struct { lower []byte upper []byte reverse bool + err error // current batch keys [][]byte @@ -230,10 +230,7 @@ func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo } iter.snapshot = db.GetSnapshot() - err := iter.fillBatch() - if err != nil { - logutil.BgLogger().Error("failed to fill batch for snapshotBatchedIter", zap.Error(err)) - } + iter.err = iter.fillBatch() return iter } @@ -325,10 +322,14 @@ func (it *snapshotBatchedIter) fillBatch() error { func (it *snapshotBatchedIter) Valid() bool { return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo && - it.pos < len(it.keys) + it.pos < len(it.keys) && + it.err == nil } func (it *snapshotBatchedIter) Next() error { + if it.err != nil { + return it.err + } if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { return errors.New( fmt.Sprintf( diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go index 6127101863..c6d091f592 100644 --- a/internal/unionstore/union_store.go +++ b/internal/unionstore/union_store.go @@ -162,11 +162,6 @@ func (us *KVUnionStore) SetEntrySizeLimit(entryLimit, bufferLimit uint64) { us.memBuffer.SetEntrySizeLimit(entryLimit, bufferLimit) } -type KvPair struct { - Key []byte - 
Value []byte -} - // MemBuffer is an interface that stores mutations that written during transaction execution. // It now unifies MemDB and PipelinedMemDB. // The implementations should follow the transaction guarantees: diff --git a/tikv/unionstore_export.go b/tikv/unionstore_export.go index efbaf93f2c..80ee88f42a 100644 --- a/tikv/unionstore_export.go +++ b/tikv/unionstore_export.go @@ -60,5 +60,3 @@ type MemDBCheckpoint = unionstore.MemDBCheckpoint // Metrics is the metrics of unionstore. type Metrics = unionstore.Metrics - -type KvPair = unionstore.KvPair From de43ee43d1a01f55aced9930b0c971810a8a6617 Mon Sep 17 00:00:00 2001 From: ekexium Date: Tue, 11 Feb 2025 15:02:03 +0800 Subject: [PATCH 09/14] revert snapshot check and minor improvements Signed-off-by: ekexium --- internal/unionstore/art/art.go | 2 +- internal/unionstore/art/art_snapshot.go | 8 ++-- internal/unionstore/memdb_art.go | 50 ++++++++++--------------- internal/unionstore/memdb_rbt.go | 2 +- internal/unionstore/memdb_test.go | 35 +++++++++++++---- 5 files changed, 53 insertions(+), 44 deletions(-) diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index 36513663e3..dccfc50064 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -55,7 +55,7 @@ type ART struct { // The counter of every write operation, used to invalidate iterators that were created before the write operation. WriteSeqNo int - // Increased by 1 when an operation that may affect the content returned by "snapshot iter" (i.e. stage[0]) happens. + // Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens. // It's used to invalidate snapshot iterators. // invariant: no concurrent access to it SnapshotSeqNo int diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go index 89a9fa1993..b899d7c49c 100644 --- a/internal/unionstore/art/art_snapshot.go +++ b/internal/unionstore/art/art_snapshot.go @@ -21,9 +21,9 @@ import ( "github.com/tikv/client-go/v2/internal/unionstore/arena" ) -// GetSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot +// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot // of stage[0] -func (t *ART) GetSnapshot() arena.MemDBCheckpoint { +func (t *ART) getSnapshot() arena.MemDBCheckpoint { if len(t.stages) > 0 { return t.stages[0] } @@ -34,7 +34,7 @@ func (t *ART) GetSnapshot() arena.MemDBCheckpoint { func (t *ART) SnapshotGetter() *SnapGetter { return &SnapGetter{ tree: t, - cp: t.GetSnapshot(), + cp: t.getSnapshot(), } } @@ -54,7 +54,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter { inner.ignoreSeqNo = true it := &SnapIter{ Iterator: inner, - cp: t.GetSnapshot(), + cp: t.getSnapshot(), } it.tree.allocator.snapshotInc() for !it.setValue() && it.Valid() { diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index 0f0f17338b..49d22a40c7 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -198,12 +198,12 @@ func (db *artDBWithContext) SnapshotGetter() Getter { } type snapshotBatchedIter struct { - db *artDBWithContext - snapshotTruncateSeqNo int - lower []byte - upper []byte - reverse bool - err error + db *artDBWithContext + snapshotSeqNo int + lower []byte + upper []byte + reverse bool + err error // current batch keys [][]byte @@ -211,9 +211,6 @@ type snapshotBatchedIter struct { pos int batchSize int 
nextKey []byte - - // only used to check if the snapshot ever changes between batches. It is not supposed to change. - snapshot MemDBCheckpoint } func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { @@ -221,36 +218,27 @@ func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo logutil.BgLogger().Error("should not use BatchedSnapshotIter for a memdb without any staging buffer") } iter := &snapshotBatchedIter{ - db: db, - snapshotTruncateSeqNo: db.SnapshotSeqNo, - lower: lower, - upper: upper, - reverse: reverse, - batchSize: 32, + db: db, + snapshotSeqNo: db.SnapshotSeqNo, + lower: lower, + upper: upper, + reverse: reverse, + batchSize: 32, } - iter.snapshot = db.GetSnapshot() iter.err = iter.fillBatch() return iter } func (it *snapshotBatchedIter) fillBatch() error { - if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { + if it.snapshotSeqNo != it.db.SnapshotSeqNo { return errors.Errorf( - "invalid iter: truncation happened, iter's=%d, db's=%d", - it.snapshotTruncateSeqNo, + "invalid iter: snapshotSeqNo changed, iter's=%d, db's=%d", + it.snapshotSeqNo, it.db.SnapshotSeqNo, ) } - if it.db.GetSnapshot() != it.snapshot { - return errors.Errorf( - "snapshot changed between batches, expected=%v, actual=%v", - it.snapshot, - it.db.GetSnapshot(), - ) - } - it.db.RLock() defer it.db.RUnlock() @@ -321,7 +309,7 @@ func (it *snapshotBatchedIter) fillBatch() error { } func (it *snapshotBatchedIter) Valid() bool { - return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo && + return it.snapshotSeqNo == it.db.SnapshotSeqNo && it.pos < len(it.keys) && it.err == nil } @@ -330,11 +318,11 @@ func (it *snapshotBatchedIter) Next() error { if it.err != nil { return it.err } - if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { + if it.snapshotSeqNo != it.db.SnapshotSeqNo { return errors.New( fmt.Sprintf( - "invalid snapshotBatchedIter: truncation happened, iter's=%d, db's=%d", - it.snapshotTruncateSeqNo, + "invalid snapshotBatchedIter: snapshotSeqNo changed, iter's=%d, db's=%d", + it.snapshotSeqNo, it.db.SnapshotSeqNo, ), ) diff --git a/internal/unionstore/memdb_rbt.go b/internal/unionstore/memdb_rbt.go index f45f941c46..3524150474 100644 --- a/internal/unionstore/memdb_rbt.go +++ b/internal/unionstore/memdb_rbt.go @@ -203,7 +203,7 @@ func (db *rbtDBWithContext) SnapshotGetter() Getter { } func (db *rbtDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { - // TODO: implement this + // TODO: implement *batched* iter if reverse { return db.SnapshotIterReverse(upper, lower) } else { diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 3261d12e1b..3321564b6b 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -1441,7 +1441,8 @@ func TestBatchedSnapshotIter(t *testing.T) { func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("EdgeCases", func(t *testing.T) { db := newArtDBWithContext() - + h := db.Staging() + defer db.Release(h) // invalid range - should be invalid immediately iter := db.BatchedSnapshotIter([]byte{1}, []byte{1}, false) require.False(t, iter.Valid()) @@ -1453,7 +1454,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { iter.Close() // Single element range - db.Set([]byte{1}, []byte{1}) + _ = db.Set([]byte{1}, []byte{1}) iter = db.BatchedSnapshotIter([]byte{1}, []byte{2}, false) require.True(t, iter.Valid()) require.Equal(t, []byte{1}, iter.Key()) @@ -1462,9 +1463,9 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { 
iter.Close() // Multiple elements - db.Set([]byte{2}, []byte{2}) - db.Set([]byte{3}, []byte{3}) - db.Set([]byte{4}, []byte{4}) + _ = db.Set([]byte{2}, []byte{2}) + _ = db.Set([]byte{3}, []byte{3}) + _ = db.Set([]byte{4}, []byte{4}) // Forward iteration [2,4) iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false) @@ -1489,11 +1490,13 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("BoundaryTests", func(t *testing.T) { db := newArtDBWithContext() + h := db.Staging() + defer db.Release(h) keys := [][]byte{ {1, 0}, {1, 2}, {1, 4}, {1, 6}, {1, 8}, } for _, k := range keys { - db.Set(k, k) + _ = db.Set(k, k) } // lower bound included @@ -1529,6 +1532,8 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("AlphabeticalOrder", func(t *testing.T) { db := newArtDBWithContext() + h := db.Staging() + defer db.Release(h) keys := [][]byte{ {2}, {2, 1}, @@ -1564,8 +1569,10 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("BatchSizeGrowth", func(t *testing.T) { db := newArtDBWithContext() + h := db.Staging() + defer db.Release(h) for i := 0; i < 100; i++ { - db.Set([]byte{3, byte(i)}, []byte{3, byte(i)}) + _ = db.Set([]byte{3, byte(i)}, []byte{3, byte(i)}) } // forward @@ -1590,4 +1597,18 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { require.Equal(t, -1, count) iter.Close() }) + + t.Run("SnapshotChange", func(t *testing.T) { + db := newArtDBWithContext() + _ = db.Set([]byte{0}, []byte{0}) + h := db.Staging() + _ = db.Set([]byte{byte(1)}, []byte{byte(1)}) + iter := db.BatchedSnapshotIter([]byte{0}, []byte{255}, false) + require.True(t, iter.Valid()) + require.NoError(t, iter.Next()) + db.Release(h) + db.Staging() + require.False(t, iter.Valid()) + require.Error(t, iter.Next()) + }) } From b15c57a5d46e2ec902808d581cf705b0b192fd74 Mon Sep 17 00:00:00 2001 From: ekexium Date: Wed, 12 Feb 2025 16:14:23 +0800 Subject: [PATCH 10/14] comment: explain the sequence number and data race Signed-off-by: ekexium --- internal/unionstore/art/art.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index dccfc50064..734bcc1a1c 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -54,6 +54,9 @@ type ART struct { missCount atomic.Uint64 // The counter of every write operation, used to invalidate iterators that were created before the write operation. + // The purpose of the counter is to check interleaving of write and read operations (via iterator). + // It does not protect against data race. If it happens, there must be a bug in the caller code. + // invariant: no concurrent access to it WriteSeqNo int // Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens. // It's used to invalidate snapshot iterators. 
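
Taken together, the two counters the series ends up with divide the work cleanly: every mutation bumps WriteSeqNo and invalidates plain iterators, while only operations that may change the stage[0] snapshot (Release or Cleanup of the bottom staging buffer, RevertToCheckpoint past it, Reset) bump SnapshotSeqNo and invalidate batched snapshot iterators, which re-check the counter each time they refill a batch under the read lock. The self-contained Go sketch below condenses that protocol; memdb, plainIter, and batchedIter are illustrative stand-ins for the client-go types, the batch size is fixed at 2 instead of starting at 32 and growing, and the plain iterator returns an error where the real one panics, so treat it as an illustration rather than the actual implementation.

package main

import (
	"bytes"
	"fmt"
	"sync"
)

// memdb is an illustrative stand-in for the ART memdb, not the client-go
// implementation: sorted keys plus the two counters the patches introduce.
type memdb struct {
	mu            sync.RWMutex
	keys          [][]byte // sorted; pretend this is the memdb content
	writeSeqNo    int      // bumped by every mutation
	snapshotSeqNo int      // bumped only by operations that may change the stage[0] snapshot
}

func (db *memdb) set(k []byte) {
	db.keys = append(db.keys, k) // assume keys arrive in order
	db.writeSeqNo++              // invalidates every live plain iterator
}

// plainIter mirrors art.Iterator: it records writeSeqNo at creation and
// re-checks it on every step, so any interleaved write invalidates it.
type plainIter struct {
	db    *memdb
	seqNo int
}

func (it *plainIter) next() error {
	// The real iterator panics here (a stale iterator is a caller bug);
	// returning an error keeps the sketch runnable as a plain program.
	if it.seqNo != it.db.writeSeqNo {
		return fmt.Errorf("iter invalidated: seqNo %d != %d", it.seqNo, it.db.writeSeqNo)
	}
	return nil
}

// batchedIter mirrors snapshotBatchedIter: it copies a bounded batch out of
// the snapshot under the read lock, releases the lock between batches, and
// revalidates snapshotSeqNo on every refill.
type batchedIter struct {
	db            *memdb
	snapshotSeqNo int
	nextKey       []byte
	batch         [][]byte
	pos           int
	err           error // the first fillBatch error is cached, not logged and swallowed
}

func (it *batchedIter) fillBatch() error {
	// No lock for this check: by invariant nothing writes the counters
	// concurrently with a reader.
	if it.snapshotSeqNo != it.db.snapshotSeqNo {
		return fmt.Errorf("invalid iter: snapshotSeqNo changed, iter's=%d, db's=%d",
			it.snapshotSeqNo, it.db.snapshotSeqNo)
	}
	it.db.mu.RLock()
	defer it.db.mu.RUnlock() // the lock is held per batch, not per element
	it.batch, it.pos = nil, 0
	for _, k := range it.db.keys {
		if it.nextKey != nil && bytes.Compare(k, it.nextKey) < 0 {
			continue // already consumed in an earlier batch
		}
		if len(it.batch) == 2 { // tiny fixed batch; the real one starts at 32 and grows
			it.nextKey = k
			return nil
		}
		it.batch = append(it.batch, append([]byte(nil), k...)) // copies outlive the lock
	}
	it.nextKey = nil
	return nil
}

func (it *batchedIter) valid() bool { return it.err == nil && it.pos < len(it.batch) }

func (it *batchedIter) next() error {
	if it.err != nil {
		return it.err // a failed first fill surfaces here, not as a silently empty iterator
	}
	if it.pos++; it.pos >= len(it.batch) && it.nextKey != nil {
		it.err = it.fillBatch()
	}
	return it.err
}

func main() {
	db := &memdb{}
	db.set([]byte{1})
	db.set([]byte{2})
	db.set([]byte{3})

	pit := &plainIter{db: db, seqNo: db.writeSeqNo}
	fmt.Println(pit.next()) // <nil>
	db.set([]byte{4})       // interleaved write
	fmt.Println(pit.next()) // iter invalidated: seqNo 3 != 4

	bit := &batchedIter{db: db, snapshotSeqNo: db.snapshotSeqNo}
	bit.err = bit.fillBatch()
	for bit.valid() {
		fmt.Println(bit.batch[bit.pos]) // [1] [2] [3] [4]
		_ = bit.next()
	}
	db.snapshotSeqNo++           // e.g. Release of the bottom staging buffer
	fmt.Println(bit.fillBatch()) // invalid iter: snapshotSeqNo changed, iter's=0, db's=1
}

Caching the first fillBatch error on the iterator, as patch 08 does, keeps the constructor's signature unchanged while still surfacing the failure: a caller following the usual Valid/Next loop gets the error from its first Next call instead of a silently empty iterator.
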
From 9b452306e11077540aad1e0024f091fd2eeccf20 Mon Sep 17 00:00:00 2001 From: ekexium Date: Wed, 12 Feb 2025 16:30:20 +0800 Subject: [PATCH 11/14] fix test Signed-off-by: ekexium --- internal/unionstore/memdb_test.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 3321564b6b..824f9b0c57 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -1442,7 +1442,6 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("EdgeCases", func(t *testing.T) { db := newArtDBWithContext() h := db.Staging() - defer db.Release(h) // invalid range - should be invalid immediately iter := db.BatchedSnapshotIter([]byte{1}, []byte{1}, false) require.False(t, iter.Valid()) @@ -1455,6 +1454,8 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { // Single element range _ = db.Set([]byte{1}, []byte{1}) + db.Release(h) + h = db.Staging() iter = db.BatchedSnapshotIter([]byte{1}, []byte{2}, false) require.True(t, iter.Valid()) require.Equal(t, []byte{1}, iter.Key()) @@ -1466,6 +1467,8 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { _ = db.Set([]byte{2}, []byte{2}) _ = db.Set([]byte{3}, []byte{3}) _ = db.Set([]byte{4}, []byte{4}) + db.Release(h) + h = db.Staging() // Forward iteration [2,4) iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false) @@ -1490,8 +1493,6 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("BoundaryTests", func(t *testing.T) { db := newArtDBWithContext() - h := db.Staging() - defer db.Release(h) keys := [][]byte{ {1, 0}, {1, 2}, {1, 4}, {1, 6}, {1, 8}, } @@ -1500,6 +1501,8 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { } // lower bound included + h := db.Staging() + defer db.Release(h) iter := db.BatchedSnapshotIter([]byte{1, 2}, []byte{1, 9}, false) vals := []byte{} for iter.Valid() { @@ -1532,8 +1535,6 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("AlphabeticalOrder", func(t *testing.T) { db := newArtDBWithContext() - h := db.Staging() - defer db.Release(h) keys := [][]byte{ {2}, {2, 1}, @@ -1541,9 +1542,10 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { {2, 1, 1, 1}, } for _, k := range keys { - db.Set(k, k) + _ = db.Set(k, k) } - + h := db.Staging() + defer db.Release(h) // forward iter := db.BatchedSnapshotIter([]byte{2}, []byte{3}, false) count := 0 @@ -1569,12 +1571,12 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) { t.Run("BatchSizeGrowth", func(t *testing.T) { db := newArtDBWithContext() - h := db.Staging() - defer db.Release(h) + for i := 0; i < 100; i++ { _ = db.Set([]byte{3, byte(i)}, []byte{3, byte(i)}) } - + h := db.Staging() + defer db.Release(h) // forward iter := db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, false) count := 0 From 8e6e654a5d4497736d811d88b8a0ecb509faad90 Mon Sep 17 00:00:00 2001 From: ekexium Date: Tue, 18 Feb 2025 20:34:52 +0800 Subject: [PATCH 12/14] address comments: minor improvements Signed-off-by: ekexium --- internal/unionstore/art/art.go | 3 +++ internal/unionstore/art/art_iterator.go | 7 +++---- internal/unionstore/memdb_art.go | 4 +--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index 734bcc1a1c..26f4bbfaf4 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -128,6 +128,9 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { t.WriteSeqNo++ if len(t.stages) == 0 { t.dirty = true + + // note: 
writes without a staging buffer also change the snapshot content; there is no such usage in TiDB, but bump SnapshotSeqNo anyway to invalidate snapshot iterators
+		t.SnapshotSeqNo++
 	}
 	// 1. create or search the existing leaf in the tree.
 	addr, leaf := t.traverse(key, true)
diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go
index d0ea980bc3..fd5f32839d 100644
--- a/internal/unionstore/art/art_iterator.go
+++ b/internal/unionstore/art/art_iterator.go
@@ -19,12 +19,11 @@ import (
 	"fmt"
 	"sort"
 
-	"github.com/tikv/client-go/v2/internal/logutil"
-	"go.uber.org/zap"
-
 	"github.com/pkg/errors"
+	"github.com/tikv/client-go/v2/internal/logutil"
 	"github.com/tikv/client-go/v2/internal/unionstore/arena"
 	"github.com/tikv/client-go/v2/kv"
+	"go.uber.org/zap"
 )
 
 func (t *ART) Iter(lowerBound, upperBound []byte) (*Iterator, error) {
@@ -81,7 +80,7 @@ type Iterator struct {
 	currAddr arena.MemdbArenaAddr
 	endAddr  arena.MemdbArenaAddr
 
-	// only when seqNo == art.seqNo, the iterator is valid.
+	// only when seqNo == art.WriteSeqNo, the iterator is valid.
 	seqNo int
 	// ignoreSeqNo is used to ignore the seqNo check, used for snapshot iter before its full deprecation.
 	ignoreSeqNo bool
diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go
index 49d22a40c7..f6b3bcafe9 100644
--- a/internal/unionstore/memdb_art.go
+++ b/internal/unionstore/memdb_art.go
@@ -19,11 +19,9 @@ import (
 	"fmt"
 	"sync"
 
-	"github.com/tikv/client-go/v2/internal/logutil"
-
 	"github.com/pingcap/errors"
+
 	tikverr "github.com/tikv/client-go/v2/error"
+	"github.com/tikv/client-go/v2/internal/logutil"
 	"github.com/tikv/client-go/v2/internal/unionstore/arena"
 	"github.com/tikv/client-go/v2/internal/unionstore/art"
 	"github.com/tikv/client-go/v2/kv"

From 830ba071aa1b0870fd5cf9944a2244d803e463f3 Mon Sep 17 00:00:00 2001
From: ekexium
Date: Wed, 19 Feb 2025 14:29:11 +0800
Subject: [PATCH 13/14] fix lint

Signed-off-by: ekexium
---
 internal/unionstore/memdb_test.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go
index 824f9b0c57..ccd0699e0b 100644
--- a/internal/unionstore/memdb_test.go
+++ b/internal/unionstore/memdb_test.go
@@ -1468,7 +1468,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
 		_ = db.Set([]byte{3}, []byte{3})
 		_ = db.Set([]byte{4}, []byte{4})
 		db.Release(h)
-		h = db.Staging()
+		_ = db.Staging()
 
 		// Forward iteration [2,4)
 		iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false)

From 30a8a00df96e72ae1662ce28c7c24d5bc12f457d Mon Sep 17 00:00:00 2001
From: ekexium
Date: Wed, 19 Feb 2025 14:39:58 +0800
Subject: [PATCH 14/14] comment: better explain concurrency concern about
 sequence numbers

Signed-off-by: ekexium
---
 internal/unionstore/art/art.go   | 8 +++++---
 internal/unionstore/memdb_art.go | 2 ++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go
index 26f4bbfaf4..4df6a86a77 100644
--- a/internal/unionstore/art/art.go
+++ b/internal/unionstore/art/art.go
@@ -47,8 +47,10 @@ type ART struct {
 	len  int
 	size int
 
+	// These fields serve the caching mechanism: they can be updated concurrently by read operations, so they
+	// are protected by atomic operations.
 	// The lastTraversedNode stores addr in uint64 of the last traversed node, includes search and recursiveInsert.
-	// Compare to atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
+	// Compared with atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient.
 	lastTraversedNode atomic.Uint64
 	hitCount          atomic.Uint64
 	missCount         atomic.Uint64
@@ -56,11 +58,11 @@ type ART struct {
 	// The counter of every write operation, used to invalidate iterators that were created before the write operation.
 	// The purpose of the counter is to check interleaving of write and read operations (via iterator).
 	// It does not protect against data race. If it happens, there must be a bug in the caller code.
-	// invariant: no concurrent access to it
+	// invariant: no concurrent write/write or read/write access to it
 	WriteSeqNo int
 	// Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens.
 	// It's used to invalidate snapshot iterators.
-	// invariant: no concurrent access to it
+	// invariant: no concurrent write/write or read/write access to it
 	SnapshotSeqNo int
 }
 
diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go
index f6b3bcafe9..16438d5e58 100644
--- a/internal/unionstore/memdb_art.go
+++ b/internal/unionstore/memdb_art.go
@@ -229,6 +229,8 @@ func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo
 }
 
 func (it *snapshotBatchedIter) fillBatch() error {
+	// The check of the sequence numbers doesn't have to be protected by the rwlock: the invariant is that
+	// there are no concurrent writes to the seqNo variables.
 	if it.snapshotSeqNo != it.db.SnapshotSeqNo {
 		return errors.Errorf(
 			"invalid iter: snapshotSeqNo changed, iter's=%d, db's=%d",
 			it.snapshotSeqNo,