
Commit 5574ecc

enhance: support balancing multiple collections in single trigger
- Optimize balance_checker to support balancing multiple collections simultaneously
- Add new parameters for segment and channel balancing batch sizes
- Add enableBalanceOnMultipleCollections parameter
- Update tests for balance checker

This change improves resource utilization by allowing the system to balance multiple collections in a single trigger with configurable batch sizes.

Signed-off-by: Wei Liu <[email protected]>
1 parent f094d02 commit 5574ecc
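The knobs introduced or renamed by this commit (BalanceSegmentBatchSize, BalanceChannelBatchSize, EnableBalanceOnMultipleCollections under QueryCoordCfg) can be overridden at runtime through paramtable, mirroring the Save/Reset pattern the updated tests use. A minimal sketch follows; the import path, the chosen values, and the concrete YAML keys/defaults are assumptions and are not part of this diff.

package example

// Assumed import path; it may differ by Milvus version.
import "github.com/milvus-io/milvus/pkg/util/paramtable"

// tuneBalanceKnobs is an illustrative helper, not part of this commit.
func tuneBalanceKnobs() {
	pt := paramtable.Get()

	// Cap how many segment/channel tasks a single Check() round may generate.
	pt.Save(pt.QueryCoordCfg.BalanceSegmentBatchSize.Key, "5")
	pt.Save(pt.QueryCoordCfg.BalanceChannelBatchSize.Key, "2")

	// Let one trigger produce balance tasks for more than one collection.
	pt.Save(pt.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "true")

	// Reset restores the defaults when the overrides are no longer wanted,
	// as the updated tests do with defer.
	defer pt.Reset(pt.QueryCoordCfg.BalanceSegmentBatchSize.Key)
	defer pt.Reset(pt.QueryCoordCfg.BalanceChannelBatchSize.Key)
	defer pt.Reset(pt.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key)
}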

File tree

8 files changed, +265 -30 lines changed

internal/querycoordv2/balance/balance.go

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int
 		return cnt1+delta1 < cnt2+delta2
 	})
 
-	balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+	balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	ret := make([]SegmentAssignPlan, 0, len(segments))
 	for i, s := range segments {
 		plan := SegmentAssignPlan{
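The hunk above only shows where the renamed BalanceSegmentBatchSize is read; the loop body that consumes it falls outside the context lines. Presumably the assign loop stops once that many plans have been produced. A sketch of that pattern with stand-in types, not the balancer's actual code:

package example

// SegmentAssignPlan here is a simplified stand-in, not the real type.
type SegmentAssignPlan struct {
	SegmentID int64
	To        int64
}

// assignWithBatchLimit caps the number of plans produced per call at
// balanceBatchSize (assumed behavior, for illustration only).
func assignWithBatchLimit(segmentIDs []int64, targetNode int64, balanceBatchSize int) []SegmentAssignPlan {
	ret := make([]SegmentAssignPlan, 0, len(segmentIDs))
	for _, id := range segmentIDs {
		if len(ret) >= balanceBatchSize {
			break // stop once the per-call batch budget is spent
		}
		ret = append(ret, SegmentAssignPlan{SegmentID: id, To: targetNode})
	}
	return ret
}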

internal/querycoordv2/balance/rowcount_based_balancer.go

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID
 		return segments[i].GetNumOfRows() > segments[j].GetNumOfRows()
 	})
 
-	balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+	balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	plans := make([]SegmentAssignPlan, 0, len(segments))
 	for _, s := range segments {
 		// pick the node with the least row count and allocate to it.

internal/querycoordv2/balance/score_based_balancer.go

Lines changed: 3 additions & 3 deletions
@@ -69,7 +69,7 @@ func (b *ScoreBasedBalancer) assignSegment(br *balanceReport, collectionID int64
 			}
 			return normalNode
 		})
-		balanceBatchSize = paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+		balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	}
 
 	// calculate each node's score
@@ -163,7 +163,7 @@ func (b *ScoreBasedBalancer) assignChannel(br *balanceReport, collectionID int64
 			}
 			return normalNode
 		})
-		balanceBatchSize = paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.GetAsInt()
+		balanceBatchSize = paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
 	}
 
 	// calculate each node's score
@@ -653,7 +653,7 @@ func (b *ScoreBasedBalancer) genChannelPlan(ctx context.Context, br *balanceRepo
 		channelDist[node] = b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
 	}
 
-	balanceBatchSize := paramtable.Get().QueryCoordCfg.CollectionBalanceSegmentBatchSize.GetAsInt()
+	balanceBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
 	// find the segment from the node which has more score than the average
 	channelsToMove := make([]*meta.DmChannel, 0)
 	for node, channels := range channelDist {

internal/querycoordv2/balance/score_based_balancer_test.go

Lines changed: 2 additions & 2 deletions
@@ -1371,8 +1371,8 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceChannelOnDifferentQN() {
 	suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, 2)
 	utils.RecoverAllCollection(balancer.meta)
 
-	paramtable.Get().Save(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key, "10")
-	defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.CollectionBalanceChannelBatchSize.Key)
+	paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.Key, "10")
+	defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.Key)
 
 	// test balance channel on same query node
 	_, channelPlans = suite.getCollectionBalancePlans(balancer, collectionID)

internal/querycoordv2/checkers/balance_checker.go

Lines changed: 54 additions & 22 deletions
@@ -108,10 +108,9 @@ func (b *BalanceChecker) getReplicaForStoppingBalance(ctx context.Context) []int
 			continue
 		}
 		if b.stoppingBalanceCollectionsCurrentRound.Contain(cid) {
-			log.RatedDebug(10, "BalanceChecker is balancing this collection, skip balancing in this round",
-				zap.Int64("collectionID", cid))
 			continue
 		}
+
 		replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid)
 		stoppingReplicas := make([]int64, 0)
 		for _, replica := range replicas {
@@ -208,42 +207,70 @@ func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64
 	return segmentPlans, channelPlans
 }
 
+// Notice: balance checker will generate tasks for multiple collections in one round,
+// so generated tasks will be submitted to scheduler directly, and return nil
 func (b *BalanceChecker) Check(ctx context.Context) []task.Task {
-	var segmentPlans []balance.SegmentAssignPlan
-	var channelPlans []balance.ChannelAssignPlan
+	segmentBatchSize := paramtable.Get().QueryCoordCfg.BalanceSegmentBatchSize.GetAsInt()
+	channelBatchSize := paramtable.Get().QueryCoordCfg.BalanceChannelBatchSize.GetAsInt()
+	balanceOnMultipleCollections := paramtable.Get().QueryCoordCfg.EnableBalanceOnMultipleCollections.GetAsBool()
+
+	segmentTasks := make([]task.Task, 0)
+	channelTasks := make([]task.Task, 0)
+
+	generateBalanceTaskForReplicas := func(replicas []int64) {
+		segmentPlans, channelPlans := b.balanceReplicas(ctx, replicas)
+		tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
+		task.SetPriority(task.TaskPriorityLow, tasks...)
+		task.SetReason("segment unbalanced", tasks...)
+		segmentTasks = append(segmentTasks, tasks...)
+
+		tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans)
+		task.SetReason("channel unbalanced", tasks...)
+		channelTasks = append(channelTasks, tasks...)
+	}
+
 	stoppingReplicas := b.getReplicaForStoppingBalance(ctx)
 	if len(stoppingReplicas) > 0 {
 		// check for stopping balance first
-		segmentPlans, channelPlans = b.balanceReplicas(ctx, stoppingReplicas)
+		generateBalanceTaskForReplicas(stoppingReplicas)
 		// iterate all collection to find a collection to balance
-		for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 {
-			replicasToBalance := b.getReplicaForStoppingBalance(ctx)
-			segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
+		for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.stoppingBalanceCollectionsCurrentRound.Len() > 0 {
+			if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
+				// if balance on multiple collections is disabled, and there are already some tasks, break
+				break
+			}
+			if len(channelTasks) < channelBatchSize {
+				replicasToBalance := b.getReplicaForStoppingBalance(ctx)
+				generateBalanceTaskForReplicas(replicasToBalance)
+			}
 		}
 	} else {
 		// then check for auto balance
 		if time.Since(b.autoBalanceTs) > paramtable.Get().QueryCoordCfg.AutoBalanceInterval.GetAsDuration(time.Millisecond) {
 			b.autoBalanceTs = time.Now()
 			replicasToBalance := b.getReplicaForNormalBalance(ctx)
-			segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
+			generateBalanceTaskForReplicas(replicasToBalance)
 			// iterate all collection to find a collection to balance
-			for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
+			for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize && b.normalBalanceCollectionsCurrentRound.Len() > 0 {
+				if !balanceOnMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
+					// if balance on multiple collections is disabled, and there are already some tasks, break
+					break
+				}
 				replicasToBalance := b.getReplicaForNormalBalance(ctx)
-				segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance)
+				generateBalanceTaskForReplicas(replicasToBalance)
 			}
 		}
 	}
 
-	ret := make([]task.Task, 0)
-	tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans)
-	task.SetPriority(task.TaskPriorityLow, tasks...)
-	task.SetReason("segment unbalanced", tasks...)
-	ret = append(ret, tasks...)
+	for _, task := range segmentTasks {
+		b.scheduler.Add(task)
+	}
+
+	for _, task := range channelTasks {
+		b.scheduler.Add(task)
+	}
 
-	tasks = balance.CreateChannelTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), channelPlans)
-	task.SetReason("channel unbalanced", tasks...)
-	ret = append(ret, tasks...)
-	return ret
+	return nil
 }
 
 func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int64) []int64 {
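In summary, the new Check flow accumulates segment and channel tasks collection by collection until either batch budget is exhausted, stops after the first collection that produced tasks when EnableBalanceOnMultipleCollections is off, and hands everything to the scheduler instead of returning it. A simplified, self-contained sketch of that control flow with stand-in types (not the checker's real code):

package example

// checkOnce models the new Check() loop: tasks accumulate until a batch
// budget is hit, then everything is pushed to a scheduler and nothing is
// returned. balanceNextCollection and schedule are stand-ins for
// balanceReplicas/getReplicaFor*Balance and scheduler.Add.
func checkOnce(
	balanceNextCollection func() (segTasks, chTasks []string, ok bool),
	schedule func(string),
	segmentBatchSize, channelBatchSize int,
	enableMultipleCollections bool,
) {
	segmentTasks := make([]string, 0)
	channelTasks := make([]string, 0)

	for len(segmentTasks) < segmentBatchSize && len(channelTasks) < channelBatchSize {
		// Single-collection mode: stop as soon as one collection produced tasks.
		if !enableMultipleCollections && (len(segmentTasks) > 0 || len(channelTasks) > 0) {
			break
		}
		seg, ch, ok := balanceNextCollection()
		if !ok {
			break // no more candidate collections in this round
		}
		segmentTasks = append(segmentTasks, seg...)
		channelTasks = append(channelTasks, ch...)
	}

	// Submit directly to the scheduler; the real Check() now returns nil.
	for _, t := range segmentTasks {
		schedule(t)
	}
	for _, t := range channelTasks {
		schedule(t)
	}
}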
@@ -252,10 +279,15 @@ func (b *BalanceChecker) sortCollections(ctx context.Context, collections []int6
 		sortOrder = "byrowcount" // Default to ByRowCount
 	}
 
+	collectionRowCountMap := make(map[int64]int64)
+	for _, cid := range collections {
+		collectionRowCountMap[cid] = b.targetMgr.GetCollectionRowCount(ctx, cid, meta.CurrentTargetFirst)
+	}
+
 	// Define sorting functions
 	sortByRowCount := func(i, j int) bool {
-		rowCount1 := b.targetMgr.GetCollectionRowCount(ctx, collections[i], meta.CurrentTargetFirst)
-		rowCount2 := b.targetMgr.GetCollectionRowCount(ctx, collections[j], meta.CurrentTargetFirst)
+		rowCount1 := collectionRowCountMap[collections[i]]
+		rowCount2 := collectionRowCountMap[collections[j]]
 		return rowCount1 > rowCount2 || (rowCount1 == rowCount2 && collections[i] < collections[j])
 	}

internal/querycoordv2/checkers/balance_checker_test.go

Lines changed: 166 additions & 1 deletion
@@ -77,6 +77,7 @@ func (suite *BalanceCheckerTestSuite) SetupTest() {
 	suite.meta = meta.NewMeta(idAllocator, store, suite.nodeMgr)
 	suite.broker = meta.NewMockBroker(suite.T())
 	suite.scheduler = task.NewMockScheduler(suite.T())
+	suite.scheduler.EXPECT().Add(mock.Anything).Return(nil).Maybe()
 	suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
 
 	suite.balancer = balance.NewMockBalancer(suite.T())
@@ -326,8 +327,16 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() {
 	}
 	segPlans = append(segPlans, mockPlan)
 	suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans)
-	tasks := suite.checker.Check(context.TODO())
+
+	tasks := make([]task.Task, 0)
+	suite.scheduler.ExpectedCalls = nil
+	suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(task task.Task) error {
+		tasks = append(tasks, task)
+		return nil
+	})
+	suite.checker.Check(context.TODO())
 	suite.Len(tasks, 1)
+	suite.Equal(tasks[0], mockPlan)
 }
 
 func (suite *BalanceCheckerTestSuite) TestTargetNotReady() {
@@ -850,6 +859,162 @@ func (suite *BalanceCheckerTestSuite) TestHasUnbalancedCollectionFlag() {
 		"stoppingBalanceCollectionsCurrentRound should contain the collection when it has RO nodes")
 }
 
+func (suite *BalanceCheckerTestSuite) TestCheckBatchSizesAndMultiCollection() {
+	ctx := context.Background()
+
+	// Set up nodes
+	nodeID1, nodeID2 := int64(1), int64(2)
+	suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
+		NodeID:   nodeID1,
+		Address:  "localhost",
+		Hostname: "localhost",
+	}))
+	suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
+		NodeID:   nodeID2,
+		Address:  "localhost",
+		Hostname: "localhost",
+	}))
+	suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1)
+	suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2)
+
+	// Create 3 collections
+	collections := make([]int64, 0)
+	replicas := make([]int64, 0)
+
+	for i := 1; i <= 3; i++ {
+		cid := int64(i)
+		replicaID := int64(100 + i)
+
+		collection := utils.CreateTestCollection(cid, int32(replicaID))
+		collection.Status = querypb.LoadStatus_Loaded
+		replica := utils.CreateTestReplica(replicaID, cid, []int64{})
+		mutableReplica := replica.CopyForWrite()
+		mutableReplica.AddRWNode(nodeID1)
+		mutableReplica.AddRONode(nodeID2)
+
+		suite.checker.meta.CollectionManager.PutCollection(ctx, collection)
+		suite.checker.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica())
+
+		collections = append(collections, cid)
+		replicas = append(replicas, replicaID)
+	}
+
+	// Mock target manager
+	mockTargetManager := meta.NewMockTargetManager(suite.T())
+	suite.checker.targetMgr = mockTargetManager
+
+	// All collections have same row count for simplicity
+	mockTargetManager.EXPECT().GetCollectionRowCount(mock.Anything, mock.Anything, mock.Anything).Return(int64(100)).Maybe()
+	mockTargetManager.EXPECT().IsCurrentTargetReady(mock.Anything, mock.Anything).Return(true).Maybe()
+	mockTargetManager.EXPECT().IsNextTargetExist(mock.Anything, mock.Anything).Return(true).Maybe()
+	mockTargetManager.EXPECT().IsCurrentTargetExist(mock.Anything, mock.Anything, mock.Anything).Return(true).Maybe()
+
+	// For each collection, return different segment plans
+	suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.AnythingOfType("*meta.Replica")).RunAndReturn(
+		func(ctx context.Context, replica *meta.Replica) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) {
+			// Create 2 segment plans and 1 channel plan per replica
+			collID := replica.GetCollectionID()
+			segPlans := make([]balance.SegmentAssignPlan, 0)
+			chanPlans := make([]balance.ChannelAssignPlan, 0)
+
+			// Create 2 segment plans
+			for j := 1; j <= 2; j++ {
+				segID := collID*100 + int64(j)
+				segPlan := balance.SegmentAssignPlan{
+					Segment: utils.CreateTestSegment(segID, collID, 1, 1, 1, "test-channel"),
+					Replica: replica,
+					From:    nodeID1,
+					To:      nodeID2,
+				}
+				segPlans = append(segPlans, segPlan)
+			}
+
+			// Create 1 channel plan
+			chanPlan := balance.ChannelAssignPlan{
+				Channel: &meta.DmChannel{
+					VchannelInfo: &datapb.VchannelInfo{
+						CollectionID: collID,
+						ChannelName:  "test-channel",
+					},
+				},
+				Replica: replica,
+				From:    nodeID1,
+				To:      nodeID2,
+			}
+			chanPlans = append(chanPlans, chanPlan)
+
+			return segPlans, chanPlans
+		}).Maybe()
+
+	// Add tasks to check batch size limits
+	var addedTasks []task.Task
+	suite.scheduler.ExpectedCalls = nil
+	suite.scheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error {
+		addedTasks = append(addedTasks, t)
+		return nil
+	}).Maybe()
+
+	// Test 1: Balance with multiple collections disabled
+	paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true")
+	paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "false")
+	// Set batch sizes to large values to test single-collection case
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "10")
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "10")
+
+	// Reset test state
+	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
+	suite.checker.autoBalanceTs = time.Time{} // Reset to trigger auto balance
+	addedTasks = nil
+
+	// Run the Check method
+	suite.checker.Check(ctx)
+
+	// Should have tasks for a single collection (2 segment tasks + 1 channel task)
+	suite.Equal(3, len(addedTasks), "Should have tasks for a single collection when multiple collections balance is disabled")
+
+	// Test 2: Balance with multiple collections enabled
+	paramtable.Get().Save(Params.QueryCoordCfg.EnableBalanceOnMultipleCollections.Key, "true")
+
+	// Reset test state
+	suite.checker.autoBalanceTs = time.Time{}
+	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
+	addedTasks = nil
+
+	// Run the Check method
+	suite.checker.Check(ctx)
+
+	// Should have tasks for all collections (3 collections * (2 segment tasks + 1 channel task) = 9 tasks)
+	suite.Equal(9, len(addedTasks), "Should have tasks for all collections when multiple collections balance is enabled")
+
+	// Test 3: Batch size limits
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceSegmentBatchSize.Key, "2")
+	paramtable.Get().Save(Params.QueryCoordCfg.BalanceChannelBatchSize.Key, "1")
+
+	// Reset test state
+	suite.checker.stoppingBalanceCollectionsCurrentRound.Clear()
+	addedTasks = nil
+
+	// Run the Check method
+	suite.checker.Check(ctx)
+
+	// Should respect batch size limits: 2 segment tasks + 1 channel task = 3 tasks
+	suite.Equal(3, len(addedTasks), "Should respect batch size limits")
+
+	// Count segment tasks and channel tasks
+	segmentTaskCount := 0
+	channelTaskCount := 0
+	for _, t := range addedTasks {
+		if _, ok := t.(*task.SegmentTask); ok {
+			segmentTaskCount++
+		} else {
+			channelTaskCount++
+		}
+	}
+
+	suite.LessOrEqual(segmentTaskCount, 2, "Should have at most 2 segment tasks due to batch size limit")
+	suite.LessOrEqual(channelTaskCount, 1, "Should have at most 1 channel task due to batch size limit")
+}
+
 func TestBalanceCheckerSuite(t *testing.T) {
 	suite.Run(t, new(BalanceCheckerTestSuite))
 }
