Skip to content

Commit 8d8501d

Browse files
authored
router: add port-based routing groups (#1106)
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
1 parent 17250b1 commit 8d8501d

File tree

9 files changed

+634
-20
lines changed

9 files changed

+634
-20
lines changed

lib/config/label.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ package config
66
const (
77
// LocationLabelName indicates the label name that decides the location of TiProxy and backends.
88
// We use `zone` because the follower read in TiDB also uses `zone` to decide location.
9-
LocationLabelName = "zone"
10-
KeyspaceLabelName = "keyspace"
11-
CidrLabelName = "cidr"
9+
LocationLabelName = "zone"
10+
KeyspaceLabelName = "keyspace"
11+
CidrLabelName = "cidr"
12+
TiProxyPortLabelName = "tiproxy-port"
1213
)
1314

1415
func (cfg *Config) GetLocation() string {

pkg/balance/router/backend_selector.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import "net"
88
type ClientInfo struct {
99
ClientAddr net.Addr
1010
ProxyAddr net.Addr
11+
// ListenerPort is the SQL listener port that accepted the connection.
12+
ListenerPort string
1113
// TODO: username, database, etc.
1214
}
1315

pkg/balance/router/group.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ const (
3131
MatchClientCIDR
3232
// Match connections based on proxy CIDR. If proxy-protocol is disabled, route by the client CIDR.
3333
MatchProxyCIDR
34+
// Match connections based on the local SQL listener port.
35+
MatchPort
3436
)
3537

3638
var _ ConnEventReceiver = (*Group)(nil)
@@ -105,7 +107,7 @@ func (g *Group) Match(clientInfo ClientInfo) bool {
105107

106108
func (g *Group) EqualValues(values []string) bool {
107109
switch g.matchType {
108-
case MatchClientCIDR, MatchProxyCIDR:
110+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
109111
if len(g.values) != len(values) {
110112
return false
111113
}
@@ -124,7 +126,7 @@ func (g *Group) EqualValues(values []string) bool {
124126
// E.g. enable public endpoint (3 cidrs) -> enable private endpoint (6 cidrs) -> disable public endpoint (3 cidrs).
125127
func (g *Group) Intersect(values []string) bool {
126128
switch g.matchType {
127-
case MatchClientCIDR, MatchProxyCIDR:
129+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
128130
for _, v := range g.values {
129131
if slices.Contains(values, v) {
130132
return true

pkg/balance/router/mock_test.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,28 @@ func (mbo *mockBackendObserver) toggleBackendHealth(addr string) {
133133
}
134134

135135
func (mbo *mockBackendObserver) addBackend(addr string, labels map[string]string) {
136+
mbo.addBackendWithCluster(addr, "", labels)
137+
}
138+
139+
func (mbo *mockBackendObserver) addBackendWithCluster(addr, clusterName string, labels map[string]string) {
136140
mbo.healthLock.Lock()
137141
defer mbo.healthLock.Unlock()
138142
mbo.healths[addr] = &observer.BackendHealth{
139143
Healthy: true,
140144
BackendInfo: observer.BackendInfo{
141-
Addr: addr,
142-
Labels: labels,
145+
Addr: addr,
146+
ClusterName: clusterName,
147+
Labels: labels,
143148
},
144149
}
145150
}
146151

152+
func (mbo *mockBackendObserver) setLabels(addr string, labels map[string]string) {
153+
mbo.healthLock.Lock()
154+
defer mbo.healthLock.Unlock()
155+
mbo.healths[addr].Labels = labels
156+
}
157+
147158
func (mbo *mockBackendObserver) Start(ctx context.Context) {
148159
}
149160

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2026 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package router
5+
6+
import "github.com/pingcap/tiproxy/lib/util/errors"
7+
8+
type portConflictDetector struct {
9+
routes map[string]*Group
10+
blocked map[string]error
11+
owners map[string]string
12+
}
13+
14+
func newPortConflictDetector() *portConflictDetector {
15+
return &portConflictDetector{
16+
routes: make(map[string]*Group),
17+
blocked: make(map[string]error),
18+
owners: make(map[string]string),
19+
}
20+
}
21+
22+
func (v *portConflictDetector) bind(port, clusterName string, group *Group) {
23+
if port == "" {
24+
return
25+
}
26+
if _, blocked := v.blocked[port]; blocked {
27+
return
28+
}
29+
if owner, ok := v.owners[port]; !ok {
30+
v.owners[port] = clusterName
31+
v.routes[port] = group
32+
return
33+
} else if owner != clusterName {
34+
v.blocked[port] = errors.Wrapf(ErrPortConflict, "listener port %s is claimed by multiple backend clusters", port)
35+
delete(v.routes, port)
36+
return
37+
}
38+
v.routes[port] = group
39+
}
40+
41+
func (v *portConflictDetector) groupFor(port string) (*Group, error) {
42+
if port == "" {
43+
return nil, nil
44+
}
45+
if err, ok := v.blocked[port]; ok {
46+
return nil, err
47+
}
48+
return v.routes[port], nil
49+
}

pkg/balance/router/router.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import (
1616
)
1717

1818
var (
19-
ErrNoBackend = errors.New("no available backend")
19+
ErrNoBackend = errors.New("no available backend")
20+
ErrPortConflict = errors.New("port routing conflict")
2021
)
2122

2223
// ConnEventReceiver receives connection events.
@@ -188,6 +189,7 @@ func (b *backendWrapper) ClusterName() string {
188189
defer b.mu.RUnlock()
189190
return b.mu.BackendHealth.ClusterName
190191
}
192+
191193
func (b *backendWrapper) Cidr() []string {
192194
labels := b.getHealth().Labels
193195
if len(labels) == 0 {
@@ -209,6 +211,14 @@ func (b *backendWrapper) Cidr() []string {
209211
return cidrs
210212
}
211213

214+
func (b *backendWrapper) TiProxyPort() string {
215+
labels := b.getHealth().Labels
216+
if len(labels) == 0 {
217+
return ""
218+
}
219+
return strings.TrimSpace(labels[config.TiProxyPortLabelName])
220+
}
221+
212222
func (b *backendWrapper) String() string {
213223
b.mu.RLock()
214224
str := b.mu.String()

pkg/balance/router/router_score.go

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package router
55

66
import (
77
"context"
8+
"fmt"
89
"slices"
910
"strings"
1011
"sync"
@@ -41,6 +42,8 @@ type ScoreBasedRouter struct {
4142
backends map[string]*backendWrapper
4243
// TODO: sort the groups to leverage binary search.
4344
groups []*Group
45+
// portConflictDetector dispatches listener ports to cluster-scoped backend groups.
46+
portConflictDetector *portConflictDetector
4447
// The routing rule for categorizing backends to groups.
4548
matchType MatchType
4649
observeError error
@@ -74,6 +77,8 @@ func (r *ScoreBasedRouter) Init(ctx context.Context, ob observer.BackendObserver
7477
r.matchType = MatchClientCIDR
7578
case config.MatchProxyCIDRStr:
7679
r.matchType = MatchProxyCIDR
80+
case config.MatchPortStr:
81+
r.matchType = MatchPort
7782
case "":
7883
default:
7984
r.logger.Error("unsupported routing rule, use the default rule", zap.String("rule", cfg.Balance.RoutingRule))
@@ -110,7 +115,10 @@ func (router *ScoreBasedRouter) GetBackendSelector(clientInfo ClientInfo) Backen
110115
return
111116
}
112117
// The group may change from round to round because the backends are updated.
113-
group = router.routeToGroup(clientInfo)
118+
group, err = router.routeToGroup(clientInfo)
119+
if err != nil {
120+
return
121+
}
114122
if group == nil {
115123
err = ErrNoBackend
116124
return
@@ -146,14 +154,20 @@ func (router *ScoreBasedRouter) HealthyBackendCount() int {
146154
}
147155

148156
// called in the lock
149-
func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) *Group {
157+
func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) (*Group, error) {
158+
if router.matchType == MatchPort {
159+
if router.portConflictDetector == nil {
160+
return nil, nil
161+
}
162+
return router.portConflictDetector.groupFor(clientInfo.ListenerPort)
163+
}
150164
// TODO: binary search
151165
for _, group := range router.groups {
152166
if group.Match(clientInfo) {
153-
return group
167+
return group, nil
154168
}
155169
}
156-
return nil
170+
return nil, nil
157171
}
158172

159173
// RefreshBackend implements Router.GetBackendSelector interface.
@@ -233,6 +247,45 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt
233247
}
234248
}
235249

250+
func matchPortValue(clusterName, port string) string {
251+
if clusterName == "" {
252+
return port
253+
}
254+
return fmt.Sprintf("%s:%s", clusterName, port)
255+
}
256+
257+
func (router *ScoreBasedRouter) backendGroupValues(backend *backendWrapper) []string {
258+
switch router.matchType {
259+
case MatchClientCIDR, MatchProxyCIDR:
260+
return backend.Cidr()
261+
case MatchPort:
262+
port := backend.TiProxyPort()
263+
if port != "" {
264+
return []string{matchPortValue(backend.ClusterName(), port)}
265+
}
266+
}
267+
return nil
268+
}
269+
270+
func (router *ScoreBasedRouter) rebuildPortConflictDetector() {
271+
if router.matchType != MatchPort {
272+
router.portConflictDetector = nil
273+
return
274+
}
275+
detector := newPortConflictDetector()
276+
for _, group := range router.groups {
277+
for _, value := range group.values {
278+
clusterName, port, ok := strings.Cut(value, ":")
279+
if !ok {
280+
port = value
281+
clusterName = ""
282+
}
283+
detector.bind(port, clusterName, group)
284+
}
285+
}
286+
router.portConflictDetector = detector
287+
}
288+
236289
// Update the groups after the backend list is updated.
237290
// called in the lock.
238291
func (router *ScoreBasedRouter) updateGroups() {
@@ -254,6 +307,17 @@ func (router *ScoreBasedRouter) updateGroups() {
254307
}
255308
// If the labels were correctly set, we won't update its group even if the labels change.
256309
if backend.group != nil {
310+
switch router.matchType {
311+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
312+
values := router.backendGroupValues(backend)
313+
if !backend.group.EqualValues(values) {
314+
router.logger.Warn("backend routing values changed, keep the existing group until it is removed",
315+
zap.String("backend_id", backend.id),
316+
zap.String("addr", backend.Addr()),
317+
zap.Strings("current_values", values),
318+
zap.Strings("group_values", backend.group.values))
319+
}
320+
}
257321
continue
258322
}
259323

@@ -267,33 +331,35 @@ func (router *ScoreBasedRouter) updateGroups() {
267331
router.groups = append(router.groups, group)
268332
}
269333
group = router.groups[0]
270-
case MatchClientCIDR, MatchProxyCIDR:
271-
cidrs := backend.Cidr()
272-
if len(cidrs) == 0 {
334+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
335+
values := router.backendGroupValues(backend)
336+
if len(values) == 0 {
273337
break
274338
}
275339
for _, g := range router.groups {
276-
if g.Intersect(cidrs) {
340+
if g.Intersect(values) {
277341
group = g
278342
break
279343
}
280344
}
281345
if group == nil {
282-
g, err := NewGroup(cidrs, router.bpCreator, router.matchType, router.logger)
346+
g, err := NewGroup(values, router.bpCreator, router.matchType, router.logger)
283347
if err == nil {
284348
group = g
285349
router.groups = append(router.groups, group)
286350
}
287351
// maybe too many logs, ignore the error now
288352
}
289353
}
290-
if group != nil {
291-
group.AddBackend(backend.id, backend)
354+
if group == nil {
355+
continue
292356
}
357+
group.AddBackend(backend.id, backend)
293358
}
294359
for _, group := range router.groups {
295360
group.RefreshCidr()
296361
}
362+
router.rebuildPortConflictDetector()
297363
}
298364

299365
func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) {

0 commit comments

Comments
 (0)