diff --git a/lib/config/label.go b/lib/config/label.go index 2a7a858f..b34615c2 100644 --- a/lib/config/label.go +++ b/lib/config/label.go @@ -6,9 +6,10 @@ package config const ( // LocationLabelName indicates the label name that decides the location of TiProxy and backends. // We use `zone` because the follower read in TiDB also uses `zone` to decide location. - LocationLabelName = "zone" - KeyspaceLabelName = "keyspace" - CidrLabelName = "cidr" + LocationLabelName = "zone" + KeyspaceLabelName = "keyspace" + CidrLabelName = "cidr" + TiProxyPortLabelName = "tiproxy-port" ) func (cfg *Config) GetLocation() string { diff --git a/pkg/balance/router/backend_selector.go b/pkg/balance/router/backend_selector.go index 25225cfb..357b19ec 100644 --- a/pkg/balance/router/backend_selector.go +++ b/pkg/balance/router/backend_selector.go @@ -8,6 +8,8 @@ import "net" type ClientInfo struct { ClientAddr net.Addr ProxyAddr net.Addr + // ListenerPort is the SQL listener port that accepted the connection. + ListenerPort string // TODO: username, database, etc. } diff --git a/pkg/balance/router/group.go b/pkg/balance/router/group.go index fbda1a9f..924fe598 100644 --- a/pkg/balance/router/group.go +++ b/pkg/balance/router/group.go @@ -31,6 +31,8 @@ const ( MatchClientCIDR // Match connections based on proxy CIDR. If proxy-protocol is disabled, route by the client CIDR. MatchProxyCIDR + // Match connections based on the local SQL listener port. + MatchPort ) var _ ConnEventReceiver = (*Group)(nil) @@ -105,7 +107,7 @@ func (g *Group) Match(clientInfo ClientInfo) bool { func (g *Group) EqualValues(values []string) bool { switch g.matchType { - case MatchClientCIDR, MatchProxyCIDR: + case MatchClientCIDR, MatchProxyCIDR, MatchPort: if len(g.values) != len(values) { return false } @@ -124,7 +126,7 @@ func (g *Group) EqualValues(values []string) bool { // E.g. enable public endpoint (3 cidrs) -> enable private endpoint (6 cidrs) -> disable public endpoint (3 cidrs). func (g *Group) Intersect(values []string) bool { switch g.matchType { - case MatchClientCIDR, MatchProxyCIDR: + case MatchClientCIDR, MatchProxyCIDR, MatchPort: for _, v := range g.values { if slices.Contains(values, v) { return true diff --git a/pkg/balance/router/mock_test.go b/pkg/balance/router/mock_test.go index e2918a59..b701657e 100644 --- a/pkg/balance/router/mock_test.go +++ b/pkg/balance/router/mock_test.go @@ -133,17 +133,28 @@ func (mbo *mockBackendObserver) toggleBackendHealth(addr string) { } func (mbo *mockBackendObserver) addBackend(addr string, labels map[string]string) { + mbo.addBackendWithCluster(addr, "", labels) +} + +func (mbo *mockBackendObserver) addBackendWithCluster(addr, clusterName string, labels map[string]string) { mbo.healthLock.Lock() defer mbo.healthLock.Unlock() mbo.healths[addr] = &observer.BackendHealth{ Healthy: true, BackendInfo: observer.BackendInfo{ - Addr: addr, - Labels: labels, + Addr: addr, + ClusterName: clusterName, + Labels: labels, }, } } +func (mbo *mockBackendObserver) setLabels(addr string, labels map[string]string) { + mbo.healthLock.Lock() + defer mbo.healthLock.Unlock() + mbo.healths[addr].Labels = labels +} + func (mbo *mockBackendObserver) Start(ctx context.Context) { } diff --git a/pkg/balance/router/port_conflict_detector.go b/pkg/balance/router/port_conflict_detector.go new file mode 100644 index 00000000..2922a848 --- /dev/null +++ b/pkg/balance/router/port_conflict_detector.go @@ -0,0 +1,49 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package router + +import "github.com/pingcap/tiproxy/lib/util/errors" + +type portConflictDetector struct { + routes map[string]*Group + blocked map[string]error + owners map[string]string +} + +func newPortConflictDetector() *portConflictDetector { + return &portConflictDetector{ + routes: make(map[string]*Group), + blocked: make(map[string]error), + owners: make(map[string]string), + } +} + +func (v *portConflictDetector) bind(port, clusterName string, group *Group) { + if port == "" { + return + } + if _, blocked := v.blocked[port]; blocked { + return + } + if owner, ok := v.owners[port]; !ok { + v.owners[port] = clusterName + v.routes[port] = group + return + } else if owner != clusterName { + v.blocked[port] = errors.Wrapf(ErrPortConflict, "listener port %s is claimed by multiple backend clusters", port) + delete(v.routes, port) + return + } + v.routes[port] = group +} + +func (v *portConflictDetector) groupFor(port string) (*Group, error) { + if port == "" { + return nil, nil + } + if err, ok := v.blocked[port]; ok { + return nil, err + } + return v.routes[port], nil +} diff --git a/pkg/balance/router/router.go b/pkg/balance/router/router.go index bca70fbb..e8347511 100644 --- a/pkg/balance/router/router.go +++ b/pkg/balance/router/router.go @@ -16,7 +16,8 @@ import ( ) var ( - ErrNoBackend = errors.New("no available backend") + ErrNoBackend = errors.New("no available backend") + ErrPortConflict = errors.New("port routing conflict") ) // ConnEventReceiver receives connection events. @@ -188,6 +189,7 @@ func (b *backendWrapper) ClusterName() string { defer b.mu.RUnlock() return b.mu.BackendHealth.ClusterName } + func (b *backendWrapper) Cidr() []string { labels := b.getHealth().Labels if len(labels) == 0 { @@ -209,6 +211,14 @@ func (b *backendWrapper) Cidr() []string { return cidrs } +func (b *backendWrapper) TiProxyPort() string { + labels := b.getHealth().Labels + if len(labels) == 0 { + return "" + } + return strings.TrimSpace(labels[config.TiProxyPortLabelName]) +} + func (b *backendWrapper) String() string { b.mu.RLock() str := b.mu.String() diff --git a/pkg/balance/router/router_score.go b/pkg/balance/router/router_score.go index fc52d303..97b22da6 100644 --- a/pkg/balance/router/router_score.go +++ b/pkg/balance/router/router_score.go @@ -5,6 +5,7 @@ package router import ( "context" + "fmt" "slices" "strings" "sync" @@ -41,6 +42,8 @@ type ScoreBasedRouter struct { backends map[string]*backendWrapper // TODO: sort the groups to leverage binary search. groups []*Group + // portConflictDetector dispatches listener ports to cluster-scoped backend groups. + portConflictDetector *portConflictDetector // The routing rule for categorizing backends to groups. matchType MatchType observeError error @@ -74,6 +77,8 @@ func (r *ScoreBasedRouter) Init(ctx context.Context, ob observer.BackendObserver r.matchType = MatchClientCIDR case config.MatchProxyCIDRStr: r.matchType = MatchProxyCIDR + case config.MatchPortStr: + r.matchType = MatchPort case "": default: r.logger.Error("unsupported routing rule, use the default rule", zap.String("rule", cfg.Balance.RoutingRule)) @@ -110,7 +115,10 @@ func (router *ScoreBasedRouter) GetBackendSelector(clientInfo ClientInfo) Backen return } // The group may change from round to round because the backends are updated. - group = router.routeToGroup(clientInfo) + group, err = router.routeToGroup(clientInfo) + if err != nil { + return + } if group == nil { err = ErrNoBackend return @@ -146,14 +154,20 @@ func (router *ScoreBasedRouter) HealthyBackendCount() int { } // called in the lock -func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) *Group { +func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) (*Group, error) { + if router.matchType == MatchPort { + if router.portConflictDetector == nil { + return nil, nil + } + return router.portConflictDetector.groupFor(clientInfo.ListenerPort) + } // TODO: binary search for _, group := range router.groups { if group.Match(clientInfo) { - return group + return group, nil } } - return nil + return nil, nil } // RefreshBackend implements Router.GetBackendSelector interface. @@ -233,6 +247,45 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt } } +func matchPortValue(clusterName, port string) string { + if clusterName == "" { + return port + } + return fmt.Sprintf("%s:%s", clusterName, port) +} + +func (router *ScoreBasedRouter) backendGroupValues(backend *backendWrapper) []string { + switch router.matchType { + case MatchClientCIDR, MatchProxyCIDR: + return backend.Cidr() + case MatchPort: + port := backend.TiProxyPort() + if port != "" { + return []string{matchPortValue(backend.ClusterName(), port)} + } + } + return nil +} + +func (router *ScoreBasedRouter) rebuildPortConflictDetector() { + if router.matchType != MatchPort { + router.portConflictDetector = nil + return + } + detector := newPortConflictDetector() + for _, group := range router.groups { + for _, value := range group.values { + clusterName, port, ok := strings.Cut(value, ":") + if !ok { + port = value + clusterName = "" + } + detector.bind(port, clusterName, group) + } + } + router.portConflictDetector = detector +} + // Update the groups after the backend list is updated. // called in the lock. func (router *ScoreBasedRouter) updateGroups() { @@ -254,6 +307,17 @@ func (router *ScoreBasedRouter) updateGroups() { } // If the labels were correctly set, we won't update its group even if the labels change. if backend.group != nil { + switch router.matchType { + case MatchClientCIDR, MatchProxyCIDR, MatchPort: + values := router.backendGroupValues(backend) + if !backend.group.EqualValues(values) { + router.logger.Warn("backend routing values changed, keep the existing group until it is removed", + zap.String("backend_id", backend.id), + zap.String("addr", backend.Addr()), + zap.Strings("current_values", values), + zap.Strings("group_values", backend.group.values)) + } + } continue } @@ -267,19 +331,19 @@ func (router *ScoreBasedRouter) updateGroups() { router.groups = append(router.groups, group) } group = router.groups[0] - case MatchClientCIDR, MatchProxyCIDR: - cidrs := backend.Cidr() - if len(cidrs) == 0 { + case MatchClientCIDR, MatchProxyCIDR, MatchPort: + values := router.backendGroupValues(backend) + if len(values) == 0 { break } for _, g := range router.groups { - if g.Intersect(cidrs) { + if g.Intersect(values) { group = g break } } if group == nil { - g, err := NewGroup(cidrs, router.bpCreator, router.matchType, router.logger) + g, err := NewGroup(values, router.bpCreator, router.matchType, router.logger) if err == nil { group = g router.groups = append(router.groups, group) @@ -287,13 +351,15 @@ func (router *ScoreBasedRouter) updateGroups() { // maybe too many logs, ignore the error now } } - if group != nil { - group.AddBackend(backend.id, backend) + if group == nil { + continue } + group.AddBackend(backend.id, backend) } for _, group := range router.groups { group.RefreshCidr() } + router.rebuildPortConflictDetector() } func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { diff --git a/pkg/balance/router/router_score_test.go b/pkg/balance/router/router_score_test.go index 9cb0b1e9..fd2a0877 100644 --- a/pkg/balance/router/router_score_test.go +++ b/pkg/balance/router/router_score_test.go @@ -8,7 +8,9 @@ import ( "math" "math/rand" "reflect" + "slices" "strconv" + "strings" "testing" "time" @@ -142,7 +144,11 @@ func (tester *routerTester) getBackendByIndex(index int) *backendWrapper { } func (tester *routerTester) simpleRoute(conn RedirectableConn) BackendInst { - selector := tester.router.GetBackendSelector(ClientInfo{}) + return tester.route(conn, ClientInfo{}) +} + +func (tester *routerTester) route(conn RedirectableConn, ci ClientInfo) BackendInst { + selector := tester.router.GetBackendSelector(ci) backend, err := selector.Next() if err != ErrNoBackend { require.NoError(tester.t, err) @@ -1184,6 +1190,446 @@ func TestGroupBackends(t *testing.T) { } } +func TestGroupBackendsByPort(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + tests := []struct { + addr string + labels map[string]string + groupCount int + backendCount int + port string + }{ + { + addr: "0", + labels: nil, + groupCount: 0, + backendCount: 1, + }, + { + addr: "1", + labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + groupCount: 1, + backendCount: 2, + port: "10080", + }, + { + addr: "2", + labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + groupCount: 1, + backendCount: 3, + port: "10080", + }, + { + addr: "3", + labels: map[string]string{config.TiProxyPortLabelName: "10081"}, + groupCount: 2, + backendCount: 4, + port: "10081", + }, + } + + for i, test := range tests { + bo.addBackend(test.addr, test.labels) + bo.notify(nil) + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + if len(router.groups) != test.groupCount { + return false + } + if len(router.backends) != test.backendCount { + return false + } + group := router.backends[test.addr].group + if test.port == "" { + return group == nil + } + return group != nil && slices.Equal(group.values, []string{test.port}) + }, 3*time.Second, 10*time.Millisecond, "test %d", i) + } +} + +func TestRouteAndRebalanceByPort(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + bp := &mockBalancePolicy{} + tester := newRouterTester(t, bp) + tester.router.matchType = MatchPort + bp.backendToRoute = func(backends []policy.BackendCtx) policy.BackendCtx { + if len(backends) == 0 { + return nil + } + return backends[0] + } + bp.backendsToBalance = func(backends []policy.BackendCtx) (from policy.BackendCtx, to policy.BackendCtx, balanceCount float64, reason string, logFields []zapcore.Field) { + if len(backends) < 2 { + return nil, nil, 0, "", nil + } + var busiest, idlest policy.BackendCtx + for _, backend := range backends { + if busiest == nil || backend.ConnCount() > busiest.ConnCount() { + busiest = backend + } + if idlest == nil || backend.ConnCount() < idlest.ConnCount() { + idlest = backend + } + } + if busiest == nil || idlest == nil || busiest == idlest { + return nil, nil, 0, "", nil + } + return busiest, idlest, 100, "conn", nil + } + tester.router.cfgGetter = newMockConfigGetter(cfg) + + tester.backends["1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "1", + Labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + }, + } + tester.backends["2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "2", + Labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + }, + } + tester.backends["3"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "3", + Labels: map[string]string{config.TiProxyPortLabelName: "10081"}, + }, + } + tester.notifyHealth() + + for range 10 { + conn := tester.createConn() + backend := tester.route(conn, ClientInfo{ListenerPort: "10080"}) + require.NotNil(t, backend) + conn.from = backend + tester.conns[conn.connID] = conn + } + for _, conn := range tester.conns { + require.Equal(t, "10080", tester.router.backends[conn.from.ID()].TiProxyPort()) + } + + tester.rebalance(10) + redirecting := 0 + for _, conn := range tester.conns { + if conn.to == nil || reflect.ValueOf(conn.to).IsNil() { + continue + } + redirecting++ + require.Equal(t, "10080", tester.router.backends[conn.to.ID()].TiProxyPort()) + require.NotEqual(t, "3", conn.to.Addr()) + } + require.Greater(t, redirecting, 0) +} + +func TestRouteByPortBlocksConflictingClusters(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + router := NewScoreBasedRouter(zap.NewNop()) + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackendWithCluster("a1", "cluster-a", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.addBackendWithCluster("b1", "cluster-b", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.notify(nil) + + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + return len(router.groups) == 2 && router.portConflictDetector != nil + }, 3*time.Second, 10*time.Millisecond) + + selector := router.GetBackendSelector(ClientInfo{ListenerPort: "10080"}) + _, err := selector.Next() + require.Error(t, err) + require.True(t, errors.Is(err, ErrPortConflict)) +} + +func TestRouteByPortRecoversAfterConflictIsRemoved(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + router := NewScoreBasedRouter(zap.NewNop()) + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackendWithCluster("a1", "cluster-a", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.addBackendWithCluster("b1", "cluster-b", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.notify(nil) + + require.Eventually(t, func() bool { + selector := router.GetBackendSelector(ClientInfo{ListenerPort: "10080"}) + _, err := selector.Next() + return errors.Is(err, ErrPortConflict) + }, 3*time.Second, 10*time.Millisecond) + + bo.healthLock.Lock() + delete(bo.healths, "b1") + bo.healthLock.Unlock() + bo.notify(nil) + + require.Eventually(t, func() bool { + selector := router.GetBackendSelector(ClientInfo{ListenerPort: "10080"}) + backend, err := selector.Next() + return err == nil && backend != nil && backend.ID() == "a1" + }, 3*time.Second, 10*time.Millisecond) +} + +func TestKeepExistingPortGroupWhenPortLabelChanges(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + lg, text := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackendWithCluster("backend-1", "cluster-a", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.notify(nil) + + var oldGroup *Group + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + oldGroup = router.backends["backend-1"].group + return oldGroup != nil && slices.Equal(oldGroup.values, []string{"cluster-a:10080"}) + }, 3*time.Second, 10*time.Millisecond) + + conn := newMockRedirectableConn(t, 1) + selector := router.GetBackendSelector(ClientInfo{ListenerPort: "10080"}) + backend, err := selector.Next() + require.NoError(t, err) + selector.Finish(conn, true) + conn.from = backend + + bo.healthLock.Lock() + bo.healths["backend-1"].ClusterName = "cluster-a" + bo.healthLock.Unlock() + bo.setLabels("backend-1", map[string]string{ + config.TiProxyPortLabelName: "10081", + }) + bo.notify(nil) + + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + return router.backends["backend-1"].group == oldGroup + }, 3*time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { + return strings.Contains(text.String(), "backend routing values changed, keep the existing group until it is removed") + }, 3*time.Second, 10*time.Millisecond) + + conn.Lock() + require.Equal(t, oldGroup, conn.receiver) + conn.Unlock() + + oldSelector := router.GetBackendSelector(ClientInfo{ListenerPort: "10080"}) + backend, err = oldSelector.Next() + require.NoError(t, err) + require.Equal(t, "backend-1", backend.ID()) + + newSelector := router.GetBackendSelector(ClientInfo{ListenerPort: "10081"}) + _, err = newSelector.Next() + require.ErrorIs(t, err, ErrNoBackend) +} + +func TestPortConflictGroupsStayClusterScoped(t *testing.T) { + tester := newRouterTester(t, nil) + tester.router.matchType = MatchPort + tester.backends["a1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-1:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["a2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-2:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["b1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-b-1:4000", + ClusterName: "cluster-b", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["b2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-b-2:4000", + ClusterName: "cluster-b", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.notifyHealth() + + groupA := findGroupByValues(t, tester.router, []string{"cluster-a:10080"}) + groupB := findGroupByValues(t, tester.router, []string{"cluster-b:10080"}) + require.NotSame(t, groupA, groupB) + for _, backend := range groupA.backends { + require.Equal(t, "cluster-a", backend.ClusterName()) + } + for _, backend := range groupB.backends { + require.Equal(t, "cluster-b", backend.ClusterName()) + } +} + +func TestPortConflictBlocksRoutingButAllowsIntraClusterRebalance(t *testing.T) { + bp := &mockBalancePolicy{} + tester := newRouterTester(t, bp) + tester.router.matchType = MatchPort + bp.backendsToBalance = func(backends []policy.BackendCtx) (from policy.BackendCtx, to policy.BackendCtx, balanceCount float64, reason string, logFields []zapcore.Field) { + if len(backends) < 2 { + return nil, nil, 0, "", nil + } + var busiest, idlest policy.BackendCtx + for _, backend := range backends { + if busiest == nil || backend.ConnCount() > busiest.ConnCount() { + busiest = backend + } + if idlest == nil || backend.ConnCount() < idlest.ConnCount() { + idlest = backend + } + } + if busiest == nil || idlest == nil || busiest == idlest { + return nil, nil, 0, "", nil + } + return busiest, idlest, 100, "conn", nil + } + + tester.backends["a1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-1:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["a2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-2:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["b1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-b-1:4000", + ClusterName: "cluster-b", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.notifyHealth() + + selector := tester.router.GetBackendSelector(ClientInfo{ListenerPort: "10080"}) + _, err := selector.Next() + require.Error(t, err) + require.True(t, errors.Is(err, ErrPortConflict)) + + groupA := findGroupByValues(t, tester.router, []string{"cluster-a:10080"}) + backendA1 := tester.router.backends["a1"] + for range 6 { + conn := tester.createConn() + groupA.onCreateConn(backendA1, conn, true) + conn.from = backendA1 + tester.conns[conn.connID] = conn + } + + groupA.lastRedirectTime = time.Time{} + groupA.Balance(context.Background()) + + redirecting := 0 + for _, conn := range tester.conns { + if conn.to == nil || reflect.ValueOf(conn.to).IsNil() { + continue + } + redirecting++ + require.Equal(t, "cluster-a", tester.router.backends[conn.to.ID()].ClusterName()) + require.Equal(t, "a2", conn.to.ID()) + } + require.Greater(t, redirecting, 0) +} + func TestRouteBackendsWithSameAddrDifferentIDs(t *testing.T) { tester := newRouterTester(t, nil) tester.router.matchType = MatchAll @@ -1215,3 +1661,22 @@ func TestRouteBackendsWithSameAddrDifferentIDs(t *testing.T) { require.Equal(t, "shared:4000", second.Addr()) require.NotEqual(t, first.ID(), second.ID()) } + +func findGroupByValues(t *testing.T, router *ScoreBasedRouter, values []string) *Group { + t.Helper() + router.Lock() + defer router.Unlock() + for _, group := range router.groups { + if group.matchType == MatchPort { + if slices.Equal(group.values, values) { + return group + } + continue + } + if group.EqualValues(values) { + return group + } + } + require.FailNow(t, "group not found", "values=%v", values) + return nil +} diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index faeb3318..7eb005ab 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -287,6 +287,14 @@ func (mgr *BackendConnManager) getBackendIO(ctx context.Context, cctx ConnContex ci.ClientAddr = mgr.clientIO.RemoteAddr() ci.ProxyAddr = mgr.clientIO.ProxyAddr() } + if addr, ok := cctx.Value(ConnContextKeyConnAddr).(string); ok { + _, port, splitErr := net.SplitHostPort(addr) + if splitErr != nil { + mgr.logger.Error("checking port failed", zap.String("listener_addr", addr), zap.Error(splitErr)) + } else { + ci.ListenerPort = port + } + } selector := r.GetBackendSelector(ci) startTime := time.Now() var addr string