Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions lib/config/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package config
const (
// LocationLabelName indicates the label name that decides the location of TiProxy and backends.
// We use `zone` because the follower read in TiDB also uses `zone` to decide location.
LocationLabelName = "zone"
KeyspaceLabelName = "keyspace"
CidrLabelName = "cidr"
LocationLabelName = "zone"
KeyspaceLabelName = "keyspace"
CidrLabelName = "cidr"
TiProxyPortLabelName = "tiproxy-port"
)

func (cfg *Config) GetLocation() string {
Expand Down
2 changes: 2 additions & 0 deletions pkg/balance/router/backend_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import "net"
type ClientInfo struct {
ClientAddr net.Addr
ProxyAddr net.Addr
// ListenerPort is the SQL listener port that accepted the connection.
ListenerPort string
// TODO: username, database, etc.
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/balance/router/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const (
MatchClientCIDR
// Match connections based on proxy CIDR. If proxy-protocol is disabled, route by the client CIDR.
MatchProxyCIDR
// Match connections based on the local SQL listener port.
MatchPort
)

var _ ConnEventReceiver = (*Group)(nil)
Expand Down Expand Up @@ -105,7 +107,7 @@ func (g *Group) Match(clientInfo ClientInfo) bool {

func (g *Group) EqualValues(values []string) bool {
switch g.matchType {
case MatchClientCIDR, MatchProxyCIDR:
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
if len(g.values) != len(values) {
return false
}
Expand All @@ -124,7 +126,7 @@ func (g *Group) EqualValues(values []string) bool {
// E.g. enable public endpoint (3 cidrs) -> enable private endpoint (6 cidrs) -> disable public endpoint (3 cidrs).
func (g *Group) Intersect(values []string) bool {
switch g.matchType {
case MatchClientCIDR, MatchProxyCIDR:
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
for _, v := range g.values {
if slices.Contains(values, v) {
return true
Expand Down
15 changes: 13 additions & 2 deletions pkg/balance/router/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,28 @@ func (mbo *mockBackendObserver) toggleBackendHealth(addr string) {
}

func (mbo *mockBackendObserver) addBackend(addr string, labels map[string]string) {
mbo.addBackendWithCluster(addr, "", labels)
}

func (mbo *mockBackendObserver) addBackendWithCluster(addr, clusterName string, labels map[string]string) {
mbo.healthLock.Lock()
defer mbo.healthLock.Unlock()
mbo.healths[addr] = &observer.BackendHealth{
Healthy: true,
BackendInfo: observer.BackendInfo{
Addr: addr,
Labels: labels,
Addr: addr,
ClusterName: clusterName,
Labels: labels,
},
}
}

func (mbo *mockBackendObserver) setLabels(addr string, labels map[string]string) {
mbo.healthLock.Lock()
defer mbo.healthLock.Unlock()
mbo.healths[addr].Labels = labels
}

func (mbo *mockBackendObserver) Start(ctx context.Context) {
}

Expand Down
49 changes: 49 additions & 0 deletions pkg/balance/router/port_conflict_detector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2026 PingCAP, Inc.
// SPDX-License-Identifier: Apache-2.0

package router

import "github.com/pingcap/tiproxy/lib/util/errors"

type portConflictDetector struct {
routes map[string]*Group
blocked map[string]error
owners map[string]string
}

func newPortConflictDetector() *portConflictDetector {
return &portConflictDetector{
routes: make(map[string]*Group),
blocked: make(map[string]error),
owners: make(map[string]string),
}
}

func (v *portConflictDetector) bind(port, clusterName string, group *Group) {
if port == "" {
return
}
if _, blocked := v.blocked[port]; blocked {
return
}
if owner, ok := v.owners[port]; !ok {
v.owners[port] = clusterName
v.routes[port] = group
return
} else if owner != clusterName {
v.blocked[port] = errors.Wrapf(ErrPortConflict, "listener port %s is claimed by multiple backend clusters", port)
delete(v.routes, port)
return
}
v.routes[port] = group
}

func (v *portConflictDetector) groupFor(port string) (*Group, error) {
if port == "" {
return nil, nil
}
if err, ok := v.blocked[port]; ok {
return nil, err
}
return v.routes[port], nil
}
12 changes: 11 additions & 1 deletion pkg/balance/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
)

var (
ErrNoBackend = errors.New("no available backend")
ErrNoBackend = errors.New("no available backend")
ErrPortConflict = errors.New("port routing conflict")
)

// ConnEventReceiver receives connection events.
Expand Down Expand Up @@ -188,6 +189,7 @@ func (b *backendWrapper) ClusterName() string {
defer b.mu.RUnlock()
return b.mu.BackendHealth.ClusterName
}

func (b *backendWrapper) Cidr() []string {
labels := b.getHealth().Labels
if len(labels) == 0 {
Expand All @@ -209,6 +211,14 @@ func (b *backendWrapper) Cidr() []string {
return cidrs
}

func (b *backendWrapper) TiProxyPort() string {
labels := b.getHealth().Labels
if len(labels) == 0 {
return ""
}
return strings.TrimSpace(labels[config.TiProxyPortLabelName])
}

func (b *backendWrapper) String() string {
b.mu.RLock()
str := b.mu.String()
Expand Down
88 changes: 77 additions & 11 deletions pkg/balance/router/router_score.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package router

import (
"context"
"fmt"
"slices"
"strings"
"sync"
Expand Down Expand Up @@ -41,6 +42,8 @@ type ScoreBasedRouter struct {
backends map[string]*backendWrapper
// TODO: sort the groups to leverage binary search.
groups []*Group
// portConflictDetector dispatches listener ports to cluster-scoped backend groups.
portConflictDetector *portConflictDetector
// The routing rule for categorizing backends to groups.
matchType MatchType
observeError error
Expand Down Expand Up @@ -74,6 +77,8 @@ func (r *ScoreBasedRouter) Init(ctx context.Context, ob observer.BackendObserver
r.matchType = MatchClientCIDR
case config.MatchProxyCIDRStr:
r.matchType = MatchProxyCIDR
case config.MatchPortStr:
r.matchType = MatchPort
case "":
default:
r.logger.Error("unsupported routing rule, use the default rule", zap.String("rule", cfg.Balance.RoutingRule))
Expand Down Expand Up @@ -110,7 +115,10 @@ func (router *ScoreBasedRouter) GetBackendSelector(clientInfo ClientInfo) Backen
return
}
// The group may change from round to round because the backends are updated.
group = router.routeToGroup(clientInfo)
group, err = router.routeToGroup(clientInfo)
if err != nil {
return
}
if group == nil {
err = ErrNoBackend
return
Expand Down Expand Up @@ -146,14 +154,20 @@ func (router *ScoreBasedRouter) HealthyBackendCount() int {
}

// called in the lock
func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) *Group {
func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) (*Group, error) {
if router.matchType == MatchPort {
if router.portConflictDetector == nil {
return nil, nil
}
return router.portConflictDetector.groupFor(clientInfo.ListenerPort)
}
// TODO: binary search
for _, group := range router.groups {
if group.Match(clientInfo) {
return group
return group, nil
}
}
return nil
return nil, nil
}

// RefreshBackend implements Router.GetBackendSelector interface.
Expand Down Expand Up @@ -233,6 +247,45 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt
}
}

func matchPortValue(clusterName, port string) string {
if clusterName == "" {
return port
}
return fmt.Sprintf("%s:%s", clusterName, port)
}

func (router *ScoreBasedRouter) backendGroupValues(backend *backendWrapper) []string {
switch router.matchType {
case MatchClientCIDR, MatchProxyCIDR:
return backend.Cidr()
case MatchPort:
port := backend.TiProxyPort()
if port != "" {
return []string{matchPortValue(backend.ClusterName(), port)}
}
}
return nil
}

func (router *ScoreBasedRouter) rebuildPortConflictDetector() {
if router.matchType != MatchPort {
router.portConflictDetector = nil
return
}
detector := newPortConflictDetector()
for _, group := range router.groups {
for _, value := range group.values {
clusterName, port, ok := strings.Cut(value, ":")
if !ok {
port = value
clusterName = ""
}
detector.bind(port, clusterName, group)
}
}
router.portConflictDetector = detector
}

// Update the groups after the backend list is updated.
// called in the lock.
func (router *ScoreBasedRouter) updateGroups() {
Expand All @@ -254,6 +307,17 @@ func (router *ScoreBasedRouter) updateGroups() {
}
// If the labels were correctly set, we won't update its group even if the labels change.
if backend.group != nil {
switch router.matchType {
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
values := router.backendGroupValues(backend)
if !backend.group.EqualValues(values) {
router.logger.Warn("backend routing values changed, keep the existing group until it is removed",
zap.String("backend_id", backend.id),
zap.String("addr", backend.Addr()),
zap.Strings("current_values", values),
zap.Strings("group_values", backend.group.values))
}
}
continue
}

Expand All @@ -267,33 +331,35 @@ func (router *ScoreBasedRouter) updateGroups() {
router.groups = append(router.groups, group)
}
group = router.groups[0]
case MatchClientCIDR, MatchProxyCIDR:
cidrs := backend.Cidr()
if len(cidrs) == 0 {
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
values := router.backendGroupValues(backend)
if len(values) == 0 {
break
}
for _, g := range router.groups {
if g.Intersect(cidrs) {
if g.Intersect(values) {
group = g
break
}
}
if group == nil {
g, err := NewGroup(cidrs, router.bpCreator, router.matchType, router.logger)
g, err := NewGroup(values, router.bpCreator, router.matchType, router.logger)
if err == nil {
group = g
router.groups = append(router.groups, group)
}
// maybe too many logs, ignore the error now
}
}
if group != nil {
group.AddBackend(backend.id, backend)
if group == nil {
continue
}
group.AddBackend(backend.id, backend)
}
for _, group := range router.groups {
group.RefreshCidr()
}
router.rebuildPortConflictDetector()
}

func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) {
Expand Down
Loading