diff --git a/api/v1/api.go b/api/v1/api.go index 7b2981d4b0..a0d66a9001 100644 --- a/api/v1/api.go +++ b/api/v1/api.go @@ -19,7 +19,6 @@ import ( "io" "net/http" "strconv" - "sync/atomic" "github.com/gin-gonic/gin" "github.com/pingcap/log" @@ -27,6 +26,7 @@ import ( v2 "github.com/pingcap/ticdc/api/v2" "github.com/pingcap/ticdc/pkg/config" "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/node" "github.com/pingcap/ticdc/pkg/server" "github.com/pingcap/ticdc/pkg/util" "go.uber.org/zap" @@ -192,25 +192,41 @@ func (o *OpenAPIV1) rebalanceTables(c *gin.Context) { // drainCapture drains all tables from a capture. // Usage: // curl -X PUT http://127.0.0.1:8300/api/v1/captures/drain -// TODO: Implement this API in the future, currently it is a no-op. func (o *OpenAPIV1) drainCapture(c *gin.Context) { var req drainCaptureRequest if err := c.ShouldBindJSON(&req); err != nil { _ = c.Error(errors.ErrAPIInvalidParam.Wrap(err)) return } - drainCaptureCounter.Add(1) - if drainCaptureCounter.Load()%10 == 0 { - log.Info("api v1 drainCapture", zap.Any("captureID", req.CaptureID), zap.Int64("currentTableCount", drainCaptureCounter.Load())) - c.JSON(http.StatusAccepted, &drainCaptureResp{ - CurrentTableCount: 10, - }) - } else { - log.Info("api v1 drainCapture done", zap.Any("captureID", req.CaptureID), zap.Int64("currentTableCount", drainCaptureCounter.Load())) - c.JSON(http.StatusAccepted, &drainCaptureResp{ - CurrentTableCount: 0, - }) + + target := node.ID(req.CaptureID) + self, err := o.server.SelfInfo() + if err != nil { + _ = c.Error(err) + return + } + // For compatibility with old arch TiCDC, draining the current owner is not allowed. + if target == self.ID { + _ = c.Error(errors.ErrSchedulerRequestFailed.GenWithStackByArgs("cannot drain the owner")) + return } + + co, err := o.server.GetCoordinator() + if err != nil { + _ = c.Error(err) + return + } + + remaining, err := co.DrainNode(c.Request.Context(), target) + if err != nil { + _ = c.Error(err) + return + } + + log.Info("api v1 drain capture", + zap.String("captureID", req.CaptureID), + zap.Int("remaining", remaining)) + c.JSON(http.StatusAccepted, &drainCaptureResp{CurrentTableCount: remaining}) } func getV2ChangefeedConfig(changefeedConfig changefeedConfig) *v2.ChangefeedConfig { @@ -258,8 +274,6 @@ type drainCaptureRequest struct { CaptureID string `json:"capture_id"` } -var drainCaptureCounter atomic.Int64 - // drainCaptureResp is response for manual `DrainCapture` type drainCaptureResp struct { CurrentTableCount int `json:"current_table_count"` diff --git a/coordinator/controller.go b/coordinator/controller.go index 945dab0715..4598890d0f 100644 --- a/coordinator/controller.go +++ b/coordinator/controller.go @@ -89,6 +89,17 @@ type Controller struct { apiLock sync.RWMutex drainController *drain.Controller + + // drainSession is the in-memory drain state machine for v1 drain API. + // Only one drain session is allowed at a time. + drainSessionMu sync.Mutex + drainSession *drainSession + // drainClearState keeps a clearing tombstone after the active drain session + // is closed. It lets coordinator resend the clear request until all nodes + // confirm they have dropped the stale drain target for that epoch. + drainClearState *drainClearState + + dispatcherDrainEpoch uint64 } type changefeedChange struct { @@ -120,9 +131,9 @@ func NewController( pdClient pd.Client, ) *Controller { changefeedDB := changefeed.NewChangefeedDB(version) - messageCenter := appcontext.GetService[messaging.MessageCenter](appcontext.MessageCenter) oc := operator.NewOperatorController(selfNode, changefeedDB, backend, batchSize) + messageCenter := appcontext.GetService[messaging.MessageCenter](appcontext.MessageCenter) drainController := drain.NewController(messageCenter) c := &Controller{ version: version, @@ -152,17 +163,18 @@ func NewController( drainController, ), }), - eventCh: eventCh, - operatorController: oc, - messageCenter: messageCenter, - changefeedDB: changefeedDB, - nodeManager: appcontext.GetService[*watcher.NodeManager](watcher.NodeManagerName), - taskScheduler: threadpool.NewThreadPoolDefault(), - backend: backend, - changefeedChangeCh: changefeedChangeCh, - pdClient: pdClient, - pdClock: appcontext.GetService[pdutil.Clock](appcontext.DefaultPDClock), - drainController: drainController, + eventCh: eventCh, + operatorController: oc, + messageCenter: messageCenter, + changefeedDB: changefeedDB, + nodeManager: appcontext.GetService[*watcher.NodeManager](watcher.NodeManagerName), + taskScheduler: threadpool.NewThreadPoolDefault(), + backend: backend, + changefeedChangeCh: changefeedChangeCh, + pdClient: pdClient, + pdClock: appcontext.GetService[pdutil.Clock](appcontext.DefaultPDClock), + drainController: drainController, + dispatcherDrainEpoch: newDispatcherDrainEpochSeed(), } c.nodeChanged.changed = false @@ -309,13 +321,13 @@ func (c *Controller) onPeriodTask() { _ = c.messageCenter.SendCommand(req) } - if !c.initialized.Load() { - return - } - + // Drain liveness transitions and drain-target broadcasts are retry-based + // control loops. Drive them from the periodic task so they keep progressing + // even when no fresh heartbeat or node-change event arrives. c.drainController.AdvanceLiveness(func(id node.ID) bool { return len(c.changefeedDB.GetByNodeID(id)) == 0 && !c.operatorController.HasOperatorInvolvingNode(id) }) + c.maybeBroadcastDispatcherDrainTarget(false) } func (c *Controller) onMessage(ctx context.Context, msg *messaging.TargetMessage) { @@ -330,6 +342,7 @@ func (c *Controller) onMessage(ctx context.Context, msg *messaging.TargetMessage case messaging.TypeNodeHeartbeatRequest: req := msg.Message[0].(*heartbeatpb.NodeHeartbeat) c.drainController.ObserveHeartbeat(msg.From, req) + c.observeDispatcherDrainTargetHeartbeat(msg.From, req) case messaging.TypeSetNodeLivenessResponse: req := msg.Message[0].(*heartbeatpb.SetNodeLivenessResponse) c.drainController.ObserveSetNodeLivenessResponse(msg.From, req) @@ -405,6 +418,7 @@ func (c *Controller) onNodeChanged(ctx context.Context) { zap.Any("targetNode", req.To), zap.Error(err)) } } + c.maybeBroadcastDispatcherDrainTarget(true) c.handleBootstrapResponses(ctx, responses) } @@ -897,6 +911,14 @@ func (c *Controller) getChangefeed(id common.ChangeFeedID) *changefeed.Changefee // RemoveNode is called when a node is removed func (c *Controller) RemoveNode(id node.ID) { c.operatorController.OnNodeRemoved(id) + // Membership removal is the only authoritative signal that this node will + // never acknowledge the current drain epoch again. Clear every drain-side + // in-memory reference immediately to avoid leaking a stuck drain session. + target, epoch, ok := c.getDispatcherDrainTarget() + if ok && target == id { + c.clearDispatcherDrainTarget(id, epoch) + } + c.observeDispatcherDrainTargetClearNodeRemoved(id) c.drainController.RemoveNode(id) } diff --git a/coordinator/controller_drain.go b/coordinator/controller_drain.go new file mode 100644 index 0000000000..0425adb0d4 --- /dev/null +++ b/coordinator/controller_drain.go @@ -0,0 +1,547 @@ +package coordinator + +import ( + "context" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/ticdc/coordinator/changefeed" + "github.com/pingcap/ticdc/coordinator/drain" + "github.com/pingcap/ticdc/heartbeatpb" + "github.com/pingcap/ticdc/pkg/common" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/messaging" + "github.com/pingcap/ticdc/pkg/node" + "go.uber.org/zap" +) + +const dispatcherDrainTargetResendIntvl = 5 * time.Second + +type drainSession struct { + target node.ID + epoch uint64 + + // trackedChangefeeds is the frozen set of running changefeeds that were + // relevant to this drain session when it started. Drain-aware scheduling must + // prevent new work from landing on the target, so repeated API polls can + // reuse this set instead of rescanning all replicating changefeeds. + trackedChangefeeds []common.ChangeFeedID + + // pendingStatus is the frozen baseline of running changefeeds that had not + // yet acknowledged this drain epoch when the session was created. + // The set only shrinks over time during one drain session. + pendingStatus map[common.ChangeFeedID]struct{} + + dirty bool + lastSent time.Time +} + +type drainClearState struct { + // target is the node whose drain target is being cleared at this epoch. + target node.ID + epoch uint64 + + // pendingNodes tracks manager-level clear acknowledgements from nodes that + // were alive when the clear was issued. + pendingNodes map[node.ID]struct{} + + // dirty/lastSent follow the same resend contract as an active drain target. + // Clear is not fire-and-forget because losing the empty-target broadcast + // would leave some nodes with a stale local drain target indefinitely. + dirty bool + lastSent time.Time +} + +// newDispatcherDrainEpochSeed creates a non-zero epoch seed for this process lifetime. +// It prevents immediate epoch reuse after coordinator restarts. +func newDispatcherDrainEpochSeed() uint64 { + epoch := uint64(time.Now().UnixNano()) + if epoch == 0 { + return 1 + } + return epoch +} + +// DrainNode starts or continues draining one target node for the v1 drain API. +// It ensures an active target epoch, broadcasts the drain target, requests liveness +// transition, then evaluates a one-shot drain observation. +// +// Drain completion requires the target to reach STOPPING after DRAINING and for +// all maintainer-side drain work to converge to zero. +// The returned remaining is guaranteed to be non-zero until completion is proven. +func (c *Controller) DrainNode(_ context.Context, target node.ID) (int, error) { + if c.nodeManager.GetNodeInfo(target) == nil { + return 0, errors.ErrCaptureNotExist.GenWithStackByArgs(target) + } + // Drain completion relies on in-memory changefeed state built by coordinator bootstrap. + // Before bootstrap is complete, always return non-zero remaining to avoid premature zero. + if c.initialized == nil || !c.initialized.Load() { + log.Info("drain waiting for coordinator bootstrap", + zap.Stringer("targetNodeID", target)) + return 1, nil + } + targetEpoch, err := c.ensureDispatcherDrainTarget(target) + if err != nil { + return 0, err + } + c.maybeBroadcastDispatcherDrainTarget(true) + + c.drainController.RequestDrain(target) + + observation := c.observeDrainNode(target, targetEpoch) + completionProven := isDrainCompletionProven( + observation.nodeState, + observation.drainingObserved, + observation.stoppingObserved, + observation.remaining, + ) + + // drain API must not return 0 until drain completion is proven. + if completionProven { + c.clearDispatcherDrainTarget(target, targetEpoch) + return 0, nil + } + + log.Info("drain completion not yet proven", + zap.Stringer("targetNodeID", target), + zap.Uint64("targetEpoch", targetEpoch), + zap.String("nodeState", drainStateString(observation.nodeState)), + zap.Bool("drainingObserved", observation.drainingObserved), + zap.Bool("stoppingObserved", observation.stoppingObserved), + zap.Int("maintainersOnTarget", observation.maintainersOnTarget), + zap.Int("inflightOpsInvolvingTarget", observation.inflightOpsInvolvingTarget), + zap.Int("dispatcherCountOnTarget", observation.dispatcherCountOnTarget), + zap.Int("targetInflightDrainMoveCount", observation.targetInflightDrainMoveCount), + zap.Int("pendingStatusCount", observation.pendingStatusCount), + zap.Int("remaining", observation.remaining)) + return ensureDrainRemainingNonZero(observation.remaining), nil +} + +type drainNodeObservation struct { + // drainNodeObservation captures all one-shot completion signals used by DrainNode. + // maintainersOnTarget is the number of maintainers still hosted on the target node. + maintainersOnTarget int + // inflightOpsInvolvingTarget is the number of operators that still involve the target node. + inflightOpsInvolvingTarget int + // dispatcherCountOnTarget is the sum of maintainer-reported dispatchers still on target. + dispatcherCountOnTarget int + // targetInflightDrainMoveCount is the sum of maintainer-reported dispatcher + // move operators still draining work away from the target node. + targetInflightDrainMoveCount int + // pendingStatusCount is the number of running changefeeds not converged to the active target epoch. + pendingStatusCount int + // remaining is the max of all workload dimensions used by drain completion gating. + remaining int + // nodeState is the drain controller state of the target node. + nodeState drain.State + // drainingObserved indicates DRAINING has been observed for this target. + drainingObserved bool + // stoppingObserved indicates STOPPING has been observed for this target. + stoppingObserved bool +} + +func (c *Controller) observeDrainNode(target node.ID, epoch uint64) drainNodeObservation { + observation := drainNodeObservation{ + maintainersOnTarget: len(c.changefeedDB.GetByNodeID(target)), + inflightOpsInvolvingTarget: c.operatorController.CountOperatorsInvolvingNode(target), + } + observation.dispatcherCountOnTarget, observation.targetInflightDrainMoveCount = c.aggregateDrainTargetProgress(target, epoch) + observation.pendingStatusCount = c.collectDrainPendingStatus(target, epoch) + observation.remaining = drainRemainingEstimate( + observation.maintainersOnTarget, + observation.inflightOpsInvolvingTarget, + observation.dispatcherCountOnTarget, + observation.targetInflightDrainMoveCount, + observation.pendingStatusCount, + ) + + _, observation.drainingObserved, observation.stoppingObserved = c.drainController.GetStatus(target) + observation.nodeState = c.drainController.GetState(target) + return observation +} + +// collectDrainPendingStatus advances the frozen pending baseline for the active +// drain session and returns how many changefeeds have not yet reported the +// active drain target epoch. +func (c *Controller) collectDrainPendingStatus(target node.ID, epoch uint64) int { + if target.IsEmpty() || epoch == 0 { + return 0 + } + + c.drainSessionMu.Lock() + defer c.drainSessionMu.Unlock() + + session := c.drainSession + if session == nil || session.target != target || session.epoch != epoch { + return 0 + } + + if len(session.pendingStatus) == 0 { + return 0 + } + + for id := range session.pendingStatus { + cf := c.changefeedDB.GetByID(id) + if !isDrainStatusConvergenceRelevant(cf) { + // Removed or non-running changefeeds should not block drain status convergence. + delete(session.pendingStatus, id) + continue + } + status := cf.GetStatus() + if status != nil { + progress := status.GetDrainProgress() + if progress != nil && progress.GetTargetNodeId() == target.String() && progress.GetTargetEpoch() == epoch { + delete(session.pendingStatus, id) + } + } + } + + return len(session.pendingStatus) +} + +// snapshotDrainTrackedChangefeeds captures the running changefeeds that are +// relevant to this drain session. The returned slice is frozen at session start +// so repeated API polls do not need to rescan the full replicating set. +func (c *Controller) snapshotDrainTrackedChangefeeds() []common.ChangeFeedID { + cfs := c.changefeedDB.GetReplicating() + snapshot := make([]common.ChangeFeedID, 0, len(cfs)) + for _, cf := range cfs { + if !isDrainStatusConvergenceRelevant(cf) { + continue + } + snapshot = append(snapshot, cf.ID) + } + return snapshot +} + +// isDrainStatusConvergenceRelevant returns whether a changefeed should +// participate in drain status convergence checks. +func isDrainStatusConvergenceRelevant(cf *changefeed.Changefeed) bool { + if cf == nil { + return false + } + info := cf.GetInfo() + return info != nil && shouldRunChangefeed(info.State) +} + +// ensureDispatcherDrainTarget creates or reuses the single active drain +// session. It rejects requests for a different target while one is active. +func (c *Controller) ensureDispatcherDrainTarget(target node.ID) (uint64, error) { + c.drainSessionMu.Lock() + defer c.drainSessionMu.Unlock() + + if c.drainSession != nil { + if c.drainSession.target == target { + return c.drainSession.epoch, nil + } + return 0, errors.ErrSchedulerRequestFailed.GenWithStackByArgs( + "drain already in progress on capture " + c.drainSession.target.String()) + } + + c.dispatcherDrainEpoch++ + if c.dispatcherDrainEpoch == 0 { + c.dispatcherDrainEpoch = 1 + } + + trackedChangefeeds := c.snapshotDrainTrackedChangefeeds() + pendingStatus := make(map[common.ChangeFeedID]struct{}, len(trackedChangefeeds)) + for _, id := range trackedChangefeeds { + pendingStatus[id] = struct{}{} + } + c.drainClearState = nil + c.drainSession = &drainSession{ + target: target, + epoch: c.dispatcherDrainEpoch, + trackedChangefeeds: trackedChangefeeds, + pendingStatus: pendingStatus, + dirty: true, + } + log.Info("dispatcher drain target activated", + zap.Stringer("targetNodeID", target), + zap.Uint64("targetEpoch", c.dispatcherDrainEpoch)) + return c.dispatcherDrainEpoch, nil +} + +// getDispatcherDrainTarget returns the current active drain target and epoch. +// The boolean return value indicates whether a session exists. +func (c *Controller) getDispatcherDrainTarget() (node.ID, uint64, bool) { + c.drainSessionMu.Lock() + defer c.drainSessionMu.Unlock() + if c.drainSession == nil { + return "", 0, false + } + return c.drainSession.target, c.drainSession.epoch, true +} + +// clearDispatcherDrainTarget closes the matching active drain session and +// broadcasts an empty target at the same epoch to clear stale local targets. +func (c *Controller) clearDispatcherDrainTarget(target node.ID, epoch uint64) { + c.drainSessionMu.Lock() + if c.drainSession == nil || c.drainSession.target != target || c.drainSession.epoch != epoch { + c.drainSessionMu.Unlock() + return + } + pendingNodes := make(map[node.ID]struct{}, len(c.nodeManager.GetAliveNodeIDs())) + for _, id := range c.nodeManager.GetAliveNodeIDs() { + pendingNodes[id] = struct{}{} + } + c.drainSession = nil + if len(pendingNodes) == 0 { + c.drainClearState = nil + } else { + // Freeze the nodes that must observe this clear. New nodes do not need + // to ack an old clear because they bootstrap from the current coordinator state. + c.drainClearState = &drainClearState{ + target: target, + epoch: epoch, + pendingNodes: pendingNodes, + dirty: true, + } + } + c.drainSessionMu.Unlock() + + log.Info("dispatcher drain target cleared", + zap.Stringer("targetNodeID", target), + zap.Uint64("targetEpoch", epoch)) + c.maybeBroadcastDispatcherDrainTarget(true) +} + +// maybeBroadcastDispatcherDrainTarget sends the active drain target or pending +// clear tombstone when forced, dirty, or periodic resend is due. +func (c *Controller) maybeBroadcastDispatcherDrainTarget(force bool) { + c.drainSessionMu.Lock() + var ( + target node.ID + epoch uint64 + needSend bool + sendingClear bool + ) + switch { + case c.drainSession != nil: + target = c.drainSession.target + epoch = c.drainSession.epoch + needSend = force || + c.drainSession.dirty || + time.Since(c.drainSession.lastSent) >= dispatcherDrainTargetResendIntvl + case c.drainClearState != nil: + epoch = c.drainClearState.epoch + sendingClear = true + needSend = force || + c.drainClearState.dirty || + time.Since(c.drainClearState.lastSent) >= dispatcherDrainTargetResendIntvl + default: + c.drainSessionMu.Unlock() + return + } + if !needSend { + c.drainSessionMu.Unlock() + return + } + c.drainSessionMu.Unlock() + + c.broadcastDispatcherDrainTarget(target, epoch) + + c.drainSessionMu.Lock() + if sendingClear { + if c.drainClearState != nil && c.drainClearState.epoch == epoch { + c.drainClearState.dirty = false + c.drainClearState.lastSent = time.Now() + } + c.drainSessionMu.Unlock() + return + } + if c.drainSession != nil && c.drainSession.target == target && c.drainSession.epoch == epoch { + c.drainSession.dirty = false + c.drainSession.lastSent = time.Now() + } + c.drainSessionMu.Unlock() +} + +// broadcastDispatcherDrainTarget sends SetDispatcherDrainTargetRequest to all +// currently alive nodes as a best-effort broadcast. +func (c *Controller) broadcastDispatcherDrainTarget(target node.ID, epoch uint64) { + if epoch == 0 || c.messageCenter == nil || c.nodeManager == nil { + return + } + + req := &heartbeatpb.SetDispatcherDrainTargetRequest{ + TargetNodeId: target.String(), + TargetEpoch: epoch, + } + for _, id := range c.nodeManager.GetAliveNodeIDs() { + msg := messaging.NewSingleTargetMessage(id, messaging.MaintainerManagerTopic, req) + if err := c.messageCenter.SendCommand(msg); err != nil { + log.Warn("send set dispatcher drain target command failed", + zap.Stringer("nodeID", id), + zap.Stringer("targetNodeID", target), + zap.Uint64("targetEpoch", epoch), + zap.Error(err)) + } + } +} + +func (c *Controller) observeDispatcherDrainTargetHeartbeat(from node.ID, hb *heartbeatpb.NodeHeartbeat) { + if hb == nil { + return + } + + c.drainSessionMu.Lock() + defer c.drainSessionMu.Unlock() + + clearState := c.drainClearState + if clearState == nil { + return + } + if _, ok := clearState.pendingNodes[from]; !ok { + return + } + + hbEpoch := hb.GetDispatcherDrainTargetEpoch() + hbTarget := node.ID(hb.GetDispatcherDrainTargetNodeId()) + if hbEpoch < clearState.epoch { + // Older heartbeat cannot prove this node has observed the clear. + return + } + if hbEpoch == clearState.epoch && !hbTarget.IsEmpty() { + // Same epoch still carrying a target means the old target is still cached locally. + return + } + + // Ack is accepted when the node reports either: + // 1) the same epoch with an empty target, or + // 2) any newer epoch, which necessarily supersedes the old clear. + delete(clearState.pendingNodes, from) + if len(clearState.pendingNodes) != 0 { + return + } + + log.Info("dispatcher drain clear acknowledged by all nodes", + zap.Stringer("targetNodeID", clearState.target), + zap.Uint64("targetEpoch", clearState.epoch)) + c.drainClearState = nil +} + +func (c *Controller) observeDispatcherDrainTargetClearNodeRemoved(id node.ID) { + c.drainSessionMu.Lock() + defer c.drainSessionMu.Unlock() + + clearState := c.drainClearState + if clearState == nil { + return + } + if _, ok := clearState.pendingNodes[id]; !ok { + return + } + + // A removed node can no longer ack. Dropping it from the pending set keeps + // the clear tombstone from leaking when membership changes mid-clear. + delete(clearState.pendingNodes, id) + if len(clearState.pendingNodes) != 0 { + return + } + + log.Info("dispatcher drain clear completed after pending nodes were removed", + zap.Stringer("targetNodeID", clearState.target), + zap.Uint64("targetEpoch", clearState.epoch)) + c.drainClearState = nil +} + +// aggregateDrainTargetProgress sums per-changefeed drain progress counters +// reported for the given drain target epoch. +func (c *Controller) aggregateDrainTargetProgress(target node.ID, epoch uint64) (dispatcherCount int, inflightMoveCount int) { + if target.IsEmpty() || epoch == 0 { + return 0, 0 + } + + targetID := target.String() + c.drainSessionMu.Lock() + session := c.drainSession + if session == nil || session.target != target || session.epoch != epoch { + c.drainSessionMu.Unlock() + return 0, 0 + } + tracked := session.trackedChangefeeds + c.drainSessionMu.Unlock() + + for _, id := range tracked { + cf := c.changefeedDB.GetByID(id) + if cf == nil { + continue + } + if !c.changefeedDB.IsReplicating(cf) { + continue + } + status := cf.GetStatus() + if status == nil { + continue + } + progress := status.GetDrainProgress() + if progress == nil || progress.GetTargetNodeId() != targetID || progress.GetTargetEpoch() != epoch { + continue + } + dispatcherCount += int(progress.GetTargetDispatcherCount()) + inflightMoveCount += int(progress.GetTargetInflightDrainMoveCount()) + } + return dispatcherCount, inflightMoveCount +} + +func drainStateString(state drain.State) string { + switch state { + case drain.StateAlive: + return "alive" + case drain.StateDraining: + return "draining" + case drain.StateStopping: + return "stopping" + case drain.StateUnknown: + return "unknown" + default: + return "unspecified" + } +} + +// isDrainCompletionProven checks whether drain completion can be safely concluded. +// Returning true is intentionally strict because v1 drain API must avoid premature zero remaining. +func isDrainCompletionProven( + nodeState drain.State, + drainingObserved bool, + stoppingObserved bool, + remaining int, +) bool { + if nodeState == drain.StateUnknown || !drainingObserved { + return false + } + return stoppingObserved && remaining == 0 +} + +// drainRemainingEstimate uses the larger workload dimension to avoid obvious double counting. +func drainRemainingEstimate( + maintainersOnTarget int, + inflightOpsInvolvingTarget int, + dispatcherCountOnTarget int, + targetInflightDrainMoveCount int, + pendingStatusCount int, +) int { + remaining := maintainersOnTarget + if inflightOpsInvolvingTarget > remaining { + remaining = inflightOpsInvolvingTarget + } + if dispatcherCountOnTarget > remaining { + remaining = dispatcherCountOnTarget + } + if targetInflightDrainMoveCount > remaining { + remaining = targetInflightDrainMoveCount + } + if pendingStatusCount > remaining { + remaining = pendingStatusCount + } + return remaining +} + +// ensureDrainRemainingNonZero keeps v1 compatibility before completion is proven. +func ensureDrainRemainingNonZero(remaining int) int { + if remaining == 0 { + return 1 + } + return remaining +} diff --git a/coordinator/controller_drain_test.go b/coordinator/controller_drain_test.go new file mode 100644 index 0000000000..868719cd09 --- /dev/null +++ b/coordinator/controller_drain_test.go @@ -0,0 +1,340 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. +package coordinator + +import ( + "context" + "testing" + "time" + + "github.com/pingcap/ticdc/coordinator/changefeed" + "github.com/pingcap/ticdc/coordinator/drain" + "github.com/pingcap/ticdc/coordinator/operator" + "github.com/pingcap/ticdc/heartbeatpb" + "github.com/pingcap/ticdc/pkg/common" + appcontext "github.com/pingcap/ticdc/pkg/common/context" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/messaging" + "github.com/pingcap/ticdc/pkg/node" + "github.com/pingcap/ticdc/server/watcher" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +func newDrainTestController(t *testing.T) (*Controller, *drain.Controller, node.ID) { + t.Helper() + + mc := messaging.NewMockMessageCenter() + appcontext.SetService(appcontext.MessageCenter, mc) + + nodeManager := watcher.NewNodeManager(nil, nil) + appcontext.SetService(watcher.NodeManagerName, nodeManager) + + drainController := drain.NewController(mc) + db := changefeed.NewChangefeedDB(1) + selfNode := &node.Info{ID: node.ID("coordinator")} + oc := operator.NewOperatorController(selfNode, db, nil, 10) + + target := node.ID("target") + nodeManager.GetAliveNodes()[target] = &node.Info{ID: target} + + c := &Controller{ + nodeManager: nodeManager, + changefeedDB: db, + operatorController: oc, + drainController: drainController, + messageCenter: mc, + initialized: atomic.NewBool(true), + } + return c, drainController, target +} + +func TestDrainNodeReturnsNonZeroBeforeCoordinatorBootstrap(t *testing.T) { + c, _, target := newDrainTestController(t) + c.initialized.Store(false) + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + drainTarget, epoch, ok := c.getDispatcherDrainTarget() + require.False(t, ok) + require.Equal(t, node.ID(""), drainTarget) + require.Equal(t, uint64(0), epoch) +} + +func TestDrainNodeReturnsNonZeroBeforeStoppingObserved(t *testing.T) { + c, _, target := newDrainTestController(t) + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) +} + +func TestDrainNodeCompletesAfterCompletionObserved(t *testing.T) { + c, drainController, target := newDrainTestController(t) + cf := addRunningChangefeed(c, "cf1", node.ID("other"), 100) + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + _, epoch, ok := c.getDispatcherDrainTarget() + require.True(t, ok) + setChangefeedDrainStatus(cf, target, epoch, 0, 0) + setTargetStoppingObserved(drainController, target) + + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 0, remaining) +} + +func TestDrainNodeDispatcherCountBlocksCompletion(t *testing.T) { + c, drainController, target := newDrainTestController(t) + cf := addRunningChangefeed(c, "cf1", node.ID("other"), 100) + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + _, epoch, ok := c.getDispatcherDrainTarget() + require.True(t, ok) + setChangefeedDrainStatus(cf, target, epoch, 2, 0) + setTargetStoppingObserved(drainController, target) + + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 2, remaining) + + setChangefeedDrainStatus(cf, target, epoch, 0, 0) + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 0, remaining) +} + +func TestDrainNodePendingStatusConvergenceBlocksCompletion(t *testing.T) { + c, drainController, target := newDrainTestController(t) + cf := addRunningChangefeed(c, "cf1", node.ID("other"), 100) + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + _, epoch, ok := c.getDispatcherDrainTarget() + require.True(t, ok) + setTargetStoppingObserved(drainController, target) + + // Status convergence must finish before drain can complete. + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + setChangefeedDrainStatus(cf, target, epoch, 0, 0) + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 0, remaining) +} + +func TestDrainNodeInflightDrainMovesBlockCompletion(t *testing.T) { + c, drainController, target := newDrainTestController(t) + cf := addRunningChangefeed(c, "cf1", node.ID("other"), 100) + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + _, epoch, ok := c.getDispatcherDrainTarget() + require.True(t, ok) + setChangefeedDrainStatus(cf, target, epoch, 0, 1) + setTargetStoppingObserved(drainController, target) + + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + setChangefeedDrainStatus(cf, target, epoch, 0, 0) + remaining, err = c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 0, remaining) +} + +func TestDrainNodeRejectConcurrentDifferentDrainTarget(t *testing.T) { + c, _, target := newDrainTestController(t) + other := node.ID("other") + c.nodeManager.GetAliveNodes()[other] = &node.Info{ID: other} + + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + _, err = c.DrainNode(context.Background(), other) + require.Error(t, err) + require.Contains(t, err.Error(), "drain already in progress") +} + +func TestRemoveNodeClearsActiveDrainTarget(t *testing.T) { + c, _, target := newDrainTestController(t) + remaining, err := c.DrainNode(context.Background(), target) + require.NoError(t, err) + require.Equal(t, 1, remaining) + + _, _, ok := c.getDispatcherDrainTarget() + require.True(t, ok) + + c.RemoveNode(target) + _, _, ok = c.getDispatcherDrainTarget() + require.False(t, ok) +} + +func TestClearDispatcherDrainTargetTracksNodeHeartbeatAck(t *testing.T) { + c, _, target := newDrainTestController(t) + other := node.ID("other") + c.nodeManager.GetAliveNodes()[other] = &node.Info{ID: other} + + epoch, err := c.ensureDispatcherDrainTarget(target) + require.NoError(t, err) + + mc := c.messageCenter.(interface { + GetMessageChannel() chan *messaging.TargetMessage + }) + drainMessageChannel(mc.GetMessageChannel()) + + c.clearDispatcherDrainTarget(target, epoch) + require.Nil(t, c.drainSession) + require.NotNil(t, c.drainClearState) + require.Len(t, c.drainClearState.pendingNodes, 2) + + c.observeDispatcherDrainTargetHeartbeat(target, &heartbeatpb.NodeHeartbeat{ + DispatcherDrainTargetEpoch: epoch, + }) + require.NotNil(t, c.drainClearState) + require.Len(t, c.drainClearState.pendingNodes, 1) + + c.observeDispatcherDrainTargetHeartbeat(other, &heartbeatpb.NodeHeartbeat{ + DispatcherDrainTargetEpoch: epoch, + }) + require.Nil(t, c.drainClearState) +} + +func TestClearDispatcherDrainTargetResendsUntilAck(t *testing.T) { + c, _, target := newDrainTestController(t) + epoch, err := c.ensureDispatcherDrainTarget(target) + require.NoError(t, err) + + mc := c.messageCenter.(interface { + GetMessageChannel() chan *messaging.TargetMessage + }) + drainMessageChannel(mc.GetMessageChannel()) + + c.clearDispatcherDrainTarget(target, epoch) + drainMessageChannel(mc.GetMessageChannel()) + + require.NotNil(t, c.drainClearState) + c.drainClearState.lastSent = time.Now().Add(-dispatcherDrainTargetResendIntvl - time.Second) + c.maybeBroadcastDispatcherDrainTarget(false) + + msg := <-mc.GetMessageChannel() + require.Equal(t, messaging.TypeSetDispatcherDrainTargetRequest, msg.Type) + req := msg.Message[0].(*heartbeatpb.SetDispatcherDrainTargetRequest) + require.Equal(t, "", req.TargetNodeId) + require.Equal(t, epoch, req.TargetEpoch) +} + +func TestHigherEpochHeartbeatAcknowledgesPendingClear(t *testing.T) { + c, _, target := newDrainTestController(t) + epoch, err := c.ensureDispatcherDrainTarget(target) + require.NoError(t, err) + + c.clearDispatcherDrainTarget(target, epoch) + require.NotNil(t, c.drainClearState) + + c.observeDispatcherDrainTargetHeartbeat(target, &heartbeatpb.NodeHeartbeat{ + DispatcherDrainTargetNodeId: "next-target", + DispatcherDrainTargetEpoch: epoch + 1, + }) + require.Nil(t, c.drainClearState) +} + +func TestRemoveNodeAcknowledgesPendingClear(t *testing.T) { + c, _, target := newDrainTestController(t) + other := node.ID("other") + c.nodeManager.GetAliveNodes()[other] = &node.Info{ID: other} + + epoch, err := c.ensureDispatcherDrainTarget(target) + require.NoError(t, err) + + c.clearDispatcherDrainTarget(target, epoch) + require.NotNil(t, c.drainClearState) + require.Len(t, c.drainClearState.pendingNodes, 2) + + c.RemoveNode(other) + require.NotNil(t, c.drainClearState) + require.Len(t, c.drainClearState.pendingNodes, 1) + + c.RemoveNode(target) + require.Nil(t, c.drainClearState) +} + +func setTargetStoppingObserved( + drainController *drain.Controller, + target node.ID, +) { + resp := &heartbeatpb.SetNodeLivenessResponse{ + Applied: heartbeatpb.NodeLiveness_STOPPING, + NodeEpoch: 1, + } + drainController.ObserveSetNodeLivenessResponse(target, resp) +} + +func drainMessageChannel(ch chan *messaging.TargetMessage) { + for { + select { + case <-ch: + default: + return + } + } +} + +func addRunningChangefeed(c *Controller, name string, nodeID node.ID, checkpointTs uint64) *changefeed.Changefeed { + cfID := common.NewChangeFeedIDWithName(name, common.DefaultKeyspaceName) + info := &config.ChangeFeedInfo{ + ChangefeedID: cfID, + SinkURI: "blackhole://", + Config: config.GetDefaultReplicaConfig(), + State: config.StateNormal, + } + cf := changefeed.NewChangefeed(cfID, info, checkpointTs, false) + c.changefeedDB.AddReplicatingMaintainer(cf, nodeID) + return cf +} + +func setChangefeedDrainStatus( + cf *changefeed.Changefeed, + target node.ID, + epoch uint64, + dispatcherCount uint32, + inflightDrainMoveCount uint32, +) { + status := cf.GetStatus() + _, _, _ = cf.ForceUpdateStatus(&heartbeatpb.MaintainerStatus{ + ChangefeedID: cf.ID.ToPB(), + CheckpointTs: status.CheckpointTs, + DrainProgress: &heartbeatpb.DrainProgress{ + TargetNodeId: target.String(), + TargetEpoch: epoch, + TargetDispatcherCount: dispatcherCount, + TargetInflightDrainMoveCount: inflightDrainMoveCount, + }, + }) +} diff --git a/coordinator/coordinator.go b/coordinator/coordinator.go index eb7fb67c9c..e22c57043a 100644 --- a/coordinator/coordinator.go +++ b/coordinator/coordinator.go @@ -410,6 +410,10 @@ func (c *coordinator) GetChangefeed(ctx context.Context, changefeedDisplayName c return c.controller.GetChangefeed(ctx, changefeedDisplayName) } +func (c *coordinator) DrainNode(ctx context.Context, target node.ID) (int, error) { + return c.controller.DrainNode(ctx, target) +} + func (c *coordinator) Initialized() bool { return c.controller.initialized.Load() } diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index 1c4dc0c173..c0e337c460 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -178,6 +178,9 @@ func (m *mockMaintainerManager) handleMessage(msg *messaging.TargetMessage) { m.sendMessages(response) } } + case messaging.TypeSetDispatcherDrainTargetRequest: + // mock maintainer manager does not emulate dispatcher scheduling. + // this command is accepted and ignored for coordinator tests. } } @@ -198,7 +201,8 @@ func (m *mockMaintainerManager) recvMessages(ctx context.Context, msg *messaging // receive message from coordinator case messaging.TypeAddMaintainerRequest, messaging.TypeRemoveMaintainerRequest: fallthrough - case messaging.TypeCoordinatorBootstrapRequest: + case messaging.TypeCoordinatorBootstrapRequest, + messaging.TypeSetDispatcherDrainTargetRequest: select { case <-ctx.Done(): return ctx.Err() diff --git a/coordinator/operator/operator_controller.go b/coordinator/operator/operator_controller.go index d6fd7a3095..a8b5051d32 100644 --- a/coordinator/operator/operator_controller.go +++ b/coordinator/operator/operator_controller.go @@ -232,19 +232,27 @@ func (oc *Controller) OperatorSize() int { return len(oc.operators) } -// HasOperatorInvolvingNode returns true if any in-flight operator affects n. -func (oc *Controller) HasOperatorInvolvingNode(n node.ID) bool { +// CountOperatorsInvolvingNode returns the number of in-flight operators whose +// affected nodes include n. +func (oc *Controller) CountOperatorsInvolvingNode(n node.ID) int { oc.mu.RLock() defer oc.mu.RUnlock() + count := 0 for _, op := range oc.operators { for _, affected := range op.OP.AffectedNodes() { if affected == n { - return true + count++ + break } } } - return false + return count +} + +// HasOperatorInvolvingNode returns true if any in-flight operator affects n. +func (oc *Controller) HasOperatorInvolvingNode(n node.ID) bool { + return oc.CountOperatorsInvolvingNode(n) > 0 } // pollQueueingOperator returns the operator need to be executed, diff --git a/coordinator/operator/operator_controller_test.go b/coordinator/operator/operator_controller_test.go index 5afc45a924..18337cda84 100644 --- a/coordinator/operator/operator_controller_test.go +++ b/coordinator/operator/operator_controller_test.go @@ -123,6 +123,8 @@ func TestController_HasOperatorInvolvingNode(t *testing.T) { require.True(t, oc.AddOperator(NewAddMaintainerOperator(changefeedDB, cf, target.ID))) + require.Equal(t, 1, oc.CountOperatorsInvolvingNode(target.ID)) + require.Equal(t, 0, oc.CountOperatorsInvolvingNode("n3")) require.True(t, oc.HasOperatorInvolvingNode(target.ID)) require.False(t, oc.HasOperatorInvolvingNode("n3")) } diff --git a/pkg/server/coordinator.go b/pkg/server/coordinator.go index 299dfe2938..50da4df06c 100644 --- a/pkg/server/coordinator.go +++ b/pkg/server/coordinator.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/ticdc/pkg/common" "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/node" ) // Coordinator is the master of the ticdc cluster, @@ -48,5 +49,8 @@ type Coordinator interface { // and coordinator will update the changefeed status after receiving the resolved ts from log coordinator. RequestResolvedTsFromLogCoordinator(ctx context.Context, changefeedDisplayName common.ChangeFeedDisplayName) + // DrainNode requests draining on the target node and returns remaining work for drain completion. + DrainNode(ctx context.Context, target node.ID) (int, error) + Initialized() bool }