Skip to content

Commit 41d9ab3

Browse files
authored
Fix Dead lock in shard manager (#23446)
Signed-off-by: xiaofan-luan <xiaofan.luan@zilliz.com>
1 parent 52e8460 commit 41d9ab3

File tree

6 files changed

+122
-126
lines changed

6 files changed

+122
-126
lines changed

internal/querynode/distribution.go

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package querynode
1919
import (
2020
"context"
2121
"sync"
22+
"time"
2223

2324
"github.com/milvus-io/milvus/internal/log"
2425
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -42,9 +43,6 @@ type distribution struct {
4243
// version indicator
4344
version int64
4445

45-
// offline is the quick healthy check indicator for offline segments
46-
offlines *atomic.Int32
47-
4846
snapshots *typeutil.ConcurrentMap[int64, *snapshot]
4947
// current is the snapshot for quick usage for search/query
5048
// generated for each change of distribution
@@ -77,7 +75,6 @@ func NewDistribution(replicaID int64) *distribution {
7775
replicaID: replicaID,
7876
sealedSegments: make(map[UniqueID]SegmentEntry),
7977
snapshots: typeutil.NewConcurrentMap[int64, *snapshot](),
80-
offlines: atomic.NewInt32(0),
8178
current: atomic.NewPointer[snapshot](nil),
8279
}
8380

@@ -89,16 +86,29 @@ func (d *distribution) getLogger() *log.MLogger {
8986
return log.Ctx(context.Background()).With(zap.Int64("replicaID", d.replicaID))
9087
}
9188

92-
// Serviceable returns whether all segment recorded is in loaded state.
9389
func (d *distribution) Serviceable() bool {
94-
return d.offlines.Load() == 0
90+
d.mut.RLock()
91+
defer d.mut.RUnlock()
92+
return d.serviceableImpl()
93+
}
94+
95+
// serviceableImpl returns whether all recorded segments are in the loaded state; the caller must hold d.mut before calling it.
96+
func (d *distribution) serviceableImpl() bool {
97+
for _, entry := range d.sealedSegments {
98+
if entry.State != segmentStateLoaded {
99+
return false
100+
}
101+
}
102+
return true
95103
}
96104

97105
// GetCurrent returns current snapshot.
98106
func (d *distribution) GetCurrent(partitions ...int64) (sealed []SnapshotItem, version int64) {
99107
d.mut.RLock()
100108
defer d.mut.RUnlock()
101-
109+
if !d.serviceableImpl() {
110+
return nil, -1
111+
}
102112
current := d.current.Load()
103113
sealed = current.Get(partitions...)
104114
version = current.version
@@ -142,14 +152,15 @@ func (d *distribution) UpdateDistribution(entries ...SegmentEntry) {
142152

143153
for _, entry := range entries {
144154
old, ok := d.sealedSegments[entry.SegmentID]
155+
d.getLogger().Info("Update distribution", zap.Int64("segmentID", entry.SegmentID),
156+
zap.Int64("node", entry.NodeID),
157+
zap.Bool("segment exist", ok))
145158
if !ok {
146159
d.sealedSegments[entry.SegmentID] = entry
147-
if entry.State == segmentStateOffline {
148-
d.offlines.Add(1)
149-
}
150160
continue
151161
}
152-
d.updateSegment(old, entry)
162+
old.Update(entry)
163+
d.sealedSegments[old.SegmentID] = old
153164
}
154165

155166
d.genSnapshot()
@@ -160,76 +171,43 @@ func (d *distribution) NodeDown(nodeID int64) {
160171
d.mut.Lock()
161172
defer d.mut.Unlock()
162173

163-
var delta int32
164-
174+
d.getLogger().Info("handle node down", zap.Int64("node", nodeID))
165175
for _, entry := range d.sealedSegments {
166176
if entry.NodeID == nodeID && entry.State != segmentStateOffline {
167177
entry.State = segmentStateOffline
168178
d.sealedSegments[entry.SegmentID] = entry
169-
delta++
179+
d.getLogger().Info("update the segment to offline since nodeDown", zap.Int64("nodeID", nodeID), zap.Int64("segmentID", entry.SegmentID))
170180
}
171181
}
172-
173-
if delta != 0 {
174-
d.offlines.Add(delta)
175-
d.getLogger().Info("distribution updated since nodeDown", zap.Int32("delta", delta), zap.Int32("offlines", d.offlines.Load()), zap.Int64("nodeID", nodeID))
176-
}
177-
}
178-
179-
// updateSegment update segment entry value and offline segment number based on old/new state.
180-
func (d *distribution) updateSegment(old, new SegmentEntry) {
181-
delta := int32(0)
182-
switch {
183-
case old.State != segmentStateLoaded && new.State == segmentStateLoaded:
184-
delta = -1
185-
case old.State == segmentStateLoaded && new.State != segmentStateLoaded:
186-
delta = 1
187-
}
188-
189-
old.Update(new)
190-
d.sealedSegments[old.SegmentID] = old
191-
if delta != 0 {
192-
d.offlines.Add(delta)
193-
d.getLogger().Info("distribution updated since segment update",
194-
zap.Int32("delta", delta),
195-
zap.Int32("offlines", d.offlines.Load()),
196-
zap.Int64("segmentID", new.SegmentID),
197-
zap.Int32("state", int32(new.State)),
198-
)
199-
}
200182
}
201183

202184
// RemoveDistributions removes segment distributions, waits on the channel returned by genSnapshot, and then calls releaseFn,
203185
// requires the read lock of shard cluster mut held
204186
func (d *distribution) RemoveDistributions(releaseFn func(), sealedSegments ...SegmentEntry) {
205187
d.mut.Lock()
206188
defer d.mut.Unlock()
207-
208-
var delta int32
209189
for _, sealed := range sealedSegments {
210190
entry, ok := d.sealedSegments[sealed.SegmentID]
191+
d.getLogger().Info("Remove distribution", zap.Int64("segmentID", sealed.SegmentID),
192+
zap.Int64("node", sealed.NodeID),
193+
zap.Bool("segment exist", ok))
211194
if !ok {
212195
continue
213196
}
214197
if entry.NodeID == sealed.NodeID || sealed.NodeID == wildcardNodeID {
215-
if entry.State == segmentStateOffline {
216-
delta--
217-
}
218198
delete(d.sealedSegments, sealed.SegmentID)
219199
}
220200
}
221-
222-
d.offlines.Add(delta)
223-
201+
ts := time.Now()
224202
<-d.genSnapshot()
225203
releaseFn()
204+
d.getLogger().Info("successfully remove distribution", zap.Any("segments", sealedSegments), zap.Duration("time", time.Since(ts)))
226205
}
227206

228207
// genSnapshot converts current distribution to snapshot format.
229208
// in which the user can simply find the nodeID=>segmentID list.
230209
// mutex RLock is required before calling this method.
231210
func (d *distribution) genSnapshot() chan struct{} {
232-
233211
nodeSegments := make(map[int64][]SegmentEntry)
234212
for _, entry := range d.sealedSegments {
235213
nodeSegments[entry.NodeID] = append(nodeSegments[entry.NodeID], entry)
@@ -260,6 +238,7 @@ func (d *distribution) genSnapshot() chan struct{} {
260238
return ch
261239
}
262240

241+
d.getLogger().Info("gen snapshot for version", zap.Any("version", d.version), zap.Any("is serviceable", d.serviceableImpl()))
263242
last.Expire(d.getCleanup(last.version))
264243

265244
return last.cleared

0 commit comments

Comments
 (0)