Skip to content

Commit 64cec42

Browse files
committed
router: add port-based routing groups
1 parent c81e4dc commit 64cec42

File tree

9 files changed

+635
-26
lines changed

9 files changed

+635
-26
lines changed

lib/config/label.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ package config
66
const (
77
// LocationLabelName indicates the label name that decides the location of TiProxy and backends.
88
// We use `zone` because the follower read in TiDB also uses `zone` to decide location.
9-
LocationLabelName = "zone"
10-
KeyspaceLabelName = "keyspace"
11-
CidrLabelName = "cidr"
9+
LocationLabelName = "zone"
10+
KeyspaceLabelName = "keyspace"
11+
CidrLabelName = "cidr"
12+
TiProxyPortLabelName = "tiproxy-port"
1213
)
1314

1415
func (cfg *Config) GetLocation() string {

pkg/balance/router/backend_selector.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import "net"
88
type ClientInfo struct {
99
ClientAddr net.Addr
1010
ProxyAddr net.Addr
11+
// ListenerAddr is the SQL listener address that accepted the connection.
12+
ListenerAddr string
1113
// TODO: username, database, etc.
1214
}
1315

pkg/balance/router/group.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ const (
3030
MatchClientCIDR
3131
// Match connections based on proxy CIDR. If proxy-protocol is disabled, route by the client CIDR.
3232
MatchProxyCIDR
33+
// Match connections based on the local SQL listener port.
34+
MatchPort
3335
)
3436

3537
var _ ConnEventReceiver = (*Group)(nil)
@@ -104,7 +106,7 @@ func (g *Group) Match(clientInfo ClientInfo) bool {
104106

105107
func (g *Group) EqualValues(values []string) bool {
106108
switch g.matchType {
107-
case MatchClientCIDR, MatchProxyCIDR:
109+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
108110
if len(g.values) != len(values) {
109111
return false
110112
}
@@ -123,7 +125,7 @@ func (g *Group) EqualValues(values []string) bool {
123125
// E.g. enable public endpoint (3 cidrs) -> enable private endpoint (6 cidrs) -> disable public endpoint (3 cidrs).
124126
func (g *Group) Intersect(values []string) bool {
125127
switch g.matchType {
126-
case MatchClientCIDR, MatchProxyCIDR:
128+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
127129
for _, v := range g.values {
128130
if slices.Contains(values, v) {
129131
return true

pkg/balance/router/mock_test.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,28 @@ func (mbo *mockBackendObserver) toggleBackendHealth(addr string) {
133133
}
134134

135135
func (mbo *mockBackendObserver) addBackend(addr string, labels map[string]string) {
136+
mbo.addBackendWithCluster(addr, "", labels)
137+
}
138+
139+
func (mbo *mockBackendObserver) addBackendWithCluster(addr, clusterName string, labels map[string]string) {
136140
mbo.healthLock.Lock()
137141
defer mbo.healthLock.Unlock()
138142
mbo.healths[addr] = &observer.BackendHealth{
139143
Healthy: true,
140144
BackendInfo: observer.BackendInfo{
141-
Addr: addr,
142-
Labels: labels,
145+
Addr: addr,
146+
ClusterName: clusterName,
147+
Labels: labels,
143148
},
144149
}
145150
}
146151

152+
func (mbo *mockBackendObserver) setLabels(addr string, labels map[string]string) {
153+
mbo.healthLock.Lock()
154+
defer mbo.healthLock.Unlock()
155+
mbo.healths[addr].Labels = labels
156+
}
157+
147158
func (mbo *mockBackendObserver) Start(ctx context.Context) {
148159
}
149160

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2026 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package router
5+
6+
import "github.com/pingcap/tiproxy/lib/util/errors"
7+
8+
type portConflictDetector struct {
9+
routes map[string]*Group
10+
blocked map[string]error
11+
owners map[string]string
12+
}
13+
14+
func newPortConflictDetector() *portConflictDetector {
15+
return &portConflictDetector{
16+
routes: make(map[string]*Group),
17+
blocked: make(map[string]error),
18+
owners: make(map[string]string),
19+
}
20+
}
21+
22+
func (v *portConflictDetector) bind(port, clusterName string, group *Group) {
23+
if v == nil || port == "" || group == nil {
24+
return
25+
}
26+
if _, blocked := v.blocked[port]; blocked {
27+
return
28+
}
29+
if owner, ok := v.owners[port]; !ok {
30+
v.owners[port] = clusterName
31+
v.routes[port] = group
32+
return
33+
} else if owner != clusterName {
34+
v.blocked[port] = errors.Wrapf(ErrPortConflict, "listener port %s is claimed by multiple backend clusters", port)
35+
delete(v.routes, port)
36+
return
37+
}
38+
v.routes[port] = group
39+
}
40+
41+
func (v *portConflictDetector) groupFor(port string) (*Group, error) {
42+
if v == nil || port == "" {
43+
return nil, nil
44+
}
45+
if err, ok := v.blocked[port]; ok {
46+
return nil, err
47+
}
48+
return v.routes[port], nil
49+
}

pkg/balance/router/router.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import (
1616
)
1717

1818
var (
19-
ErrNoBackend = errors.New("no available backend")
19+
ErrNoBackend = errors.New("no available backend")
20+
ErrPortConflict = errors.New("port routing conflict")
2021
)
2122

2223
// ConnEventReceiver receives connection events.
@@ -188,6 +189,7 @@ func (b *backendWrapper) ClusterName() string {
188189
defer b.mu.RUnlock()
189190
return b.mu.BackendHealth.ClusterName
190191
}
192+
191193
func (b *backendWrapper) Cidr() []string {
192194
labels := b.getHealth().Labels
193195
if len(labels) == 0 {
@@ -209,6 +211,14 @@ func (b *backendWrapper) Cidr() []string {
209211
return cidrs
210212
}
211213

214+
func (b *backendWrapper) TiProxyPort() string {
215+
labels := b.getHealth().Labels
216+
if len(labels) == 0 {
217+
return ""
218+
}
219+
return strings.TrimSpace(labels[config.TiProxyPortLabelName])
220+
}
221+
212222
func (b *backendWrapper) String() string {
213223
b.mu.RLock()
214224
str := b.mu.String()

pkg/balance/router/router_score.go

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package router
55

66
import (
77
"context"
8+
"fmt"
9+
"net"
810
"slices"
911
"strings"
1012
"sync"
@@ -41,6 +43,8 @@ type ScoreBasedRouter struct {
4143
backends map[string]*backendWrapper
4244
// TODO: sort the groups to leverage binary search.
4345
groups []*Group
46+
// portConflictDetector dispatches listener ports to cluster-scoped backend groups.
47+
portConflictDetector *portConflictDetector
4448
// The routing rule for categorizing backends to groups.
4549
matchType MatchType
4650
observeError error
@@ -74,6 +78,8 @@ func (r *ScoreBasedRouter) Init(ctx context.Context, ob observer.BackendObserver
7478
r.matchType = MatchClientCIDR
7579
case config.MatchProxyCIDRStr:
7680
r.matchType = MatchProxyCIDR
81+
case config.MatchPortStr:
82+
r.matchType = MatchPort
7783
case "":
7884
default:
7985
r.logger.Error("unsupported routing rule, use the default rule", zap.String("rule", cfg.Balance.RoutingRule))
@@ -110,7 +116,10 @@ func (router *ScoreBasedRouter) GetBackendSelector(clientInfo ClientInfo) Backen
110116
return
111117
}
112118
// The group may change from round to round because the backends are updated.
113-
group = router.routeToGroup(clientInfo)
119+
group, err = router.routeToGroup(clientInfo)
120+
if err != nil {
121+
return
122+
}
114123
if group == nil {
115124
err = ErrNoBackend
116125
return
@@ -146,14 +155,22 @@ func (router *ScoreBasedRouter) HealthyBackendCount() int {
146155
}
147156

148157
// called in the lock
149-
func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) *Group {
158+
func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) (*Group, error) {
159+
if router.matchType == MatchPort {
160+
_, port, err := net.SplitHostPort(clientInfo.ListenerAddr)
161+
if err != nil {
162+
router.logger.Error("checking port failed", zap.String("listener_addr", clientInfo.ListenerAddr), zap.Error(err))
163+
return nil, nil
164+
}
165+
return router.portConflictDetector.groupFor(port)
166+
}
150167
// TODO: binary search
151168
for _, group := range router.groups {
152169
if group.Match(clientInfo) {
153-
return group
170+
return group, nil
154171
}
155172
}
156-
return nil
173+
return nil, nil
157174
}
158175

159176
// RefreshBackend implements Router.GetBackendSelector interface.
@@ -233,7 +250,32 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt
233250
}
234251
}
235252

236-
// Update the groups after the backend list is updated.
253+
func matchPortValue(clusterName, port string) string {
254+
if clusterName == "" {
255+
return port
256+
}
257+
return fmt.Sprintf("%s:%s", clusterName, port)
258+
}
259+
260+
func (router *ScoreBasedRouter) rebuildPortConflictDetector() {
261+
if router.matchType != MatchPort {
262+
router.portConflictDetector = nil
263+
return
264+
}
265+
detector := newPortConflictDetector()
266+
for _, group := range router.groups {
267+
for _, value := range group.values {
268+
clusterName, port, ok := strings.Cut(value, ":")
269+
if !ok {
270+
port = value
271+
clusterName = ""
272+
}
273+
detector.bind(port, clusterName, group)
274+
}
275+
}
276+
router.portConflictDetector = detector
277+
}
278+
237279
// called in the lock.
238280
func (router *ScoreBasedRouter) updateGroups() {
239281
for _, backend := range router.backends {
@@ -243,22 +285,39 @@ func (router *ScoreBasedRouter) updateGroups() {
243285
delete(router.backends, backend.id)
244286
if backend.group != nil {
245287
backend.group.RemoveBackend(backend.id)
246-
// remove empty groups
247288
if backend.group.Empty() {
248289
router.groups = slices.DeleteFunc(router.groups, func(g *Group) bool {
249290
return g == backend.group
250291
})
251292
}
293+
backend.group = nil
252294
}
253295
continue
254296
}
255-
// If the labels were correctly set, we won't update its group even if the labels change.
256297
if backend.group != nil {
298+
switch router.matchType {
299+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
300+
var values []string
301+
switch router.matchType {
302+
case MatchClientCIDR, MatchProxyCIDR:
303+
values = backend.Cidr()
304+
case MatchPort:
305+
port := backend.TiProxyPort()
306+
if port != "" {
307+
values = []string{matchPortValue(backend.ClusterName(), port)}
308+
}
309+
}
310+
if !backend.group.EqualValues(values) {
311+
router.logger.Warn("backend routing values changed, keep the existing group until it is removed",
312+
zap.String("backend_id", backend.id),
313+
zap.String("addr", backend.Addr()),
314+
zap.Strings("current_values", values),
315+
zap.Strings("group_values", backend.group.values))
316+
}
317+
}
257318
continue
258319
}
259320

260-
// If the backend is not in any group, add it to a new group if its label is set.
261-
// In operator deployment, the labels are set dynamically.
262321
var group *Group
263322
switch router.matchType {
264323
case MatchAll:
@@ -267,33 +326,43 @@ func (router *ScoreBasedRouter) updateGroups() {
267326
router.groups = append(router.groups, group)
268327
}
269328
group = router.groups[0]
270-
case MatchClientCIDR, MatchProxyCIDR:
271-
cidrs := backend.Cidr()
272-
if len(cidrs) == 0 {
329+
case MatchClientCIDR, MatchProxyCIDR, MatchPort:
330+
var values []string
331+
switch router.matchType {
332+
case MatchClientCIDR, MatchProxyCIDR:
333+
values = backend.Cidr()
334+
case MatchPort:
335+
port := backend.TiProxyPort()
336+
if port != "" {
337+
values = []string{matchPortValue(backend.ClusterName(), port)}
338+
}
339+
}
340+
if len(values) == 0 {
273341
break
274342
}
275343
for _, g := range router.groups {
276-
if g.Intersect(cidrs) {
344+
if g.Intersect(values) {
277345
group = g
278346
break
279347
}
280348
}
281349
if group == nil {
282-
g, err := NewGroup(cidrs, router.bpCreator, router.matchType, router.logger)
350+
g, err := NewGroup(values, router.bpCreator, router.matchType, router.logger)
283351
if err == nil {
284352
group = g
285353
router.groups = append(router.groups, group)
286354
}
287-
// maybe too many logs, ignore the error now
288355
}
289356
}
290-
if group != nil {
291-
group.AddBackend(backend.id, backend)
357+
if group == nil {
358+
continue
292359
}
360+
group.AddBackend(backend.id, backend)
293361
}
294362
for _, group := range router.groups {
295363
group.RefreshCidr()
296364
}
365+
router.rebuildPortConflictDetector()
297366
}
298367

299368
func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) {

0 commit comments

Comments
 (0)