diff --git a/go.mod b/go.mod
index d83a4fb7b..1f889b242 100644
--- a/go.mod
+++ b/go.mod
@@ -38,6 +38,7 @@ require (
 	go.uber.org/mock v0.5.2
 	go.uber.org/ratelimit v0.2.0
 	go.uber.org/zap v1.27.0
+	golang.org/x/net v0.48.0
 	google.golang.org/grpc v1.63.2
 )
 
@@ -272,7 +273,6 @@ require (
 	golang.org/x/crypto v0.47.0 // indirect
 	golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect
 	golang.org/x/mod v0.31.0 // indirect
-	golang.org/x/net v0.48.0 // indirect
 	golang.org/x/oauth2 v0.30.0 // indirect
 	golang.org/x/sync v0.19.0 // indirect
 	golang.org/x/sys v0.40.0 // indirect
diff --git a/pkg/balance/observer/health_check.go b/pkg/balance/observer/health_check.go
index 3988e375b..fa4ecf8ab 100644
--- a/pkg/balance/observer/health_check.go
+++ b/pkg/balance/observer/health_check.go
@@ -19,6 +19,11 @@ import (
 	"go.uber.org/zap"
 )
 
+type BackendNetwork interface {
+	HTTPClient(clusterName string) *http.Client
+	DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error)
+}
+
 // HealthCheck is used to check the health of one backend. One can pass a customized health check function to the observer.
 type HealthCheck interface {
 	Check(ctx context.Context, info *BackendInfo, lastHealth *BackendHealth) *BackendHealth
@@ -48,20 +53,44 @@ type security struct {
 type DefaultHealthCheck struct {
 	cfg     *config.HealthCheck
 	logger  *zap.Logger
-	httpCli *http.Client
+	network BackendNetwork
 }
 
 func NewDefaultHealthCheck(httpCli *http.Client, cfg *config.HealthCheck, logger *zap.Logger) *DefaultHealthCheck {
-	if httpCli == nil {
-		httpCli = http.NewHTTPClient(func() *tls.Config { return nil })
+	return NewDefaultHealthCheckWithNetwork(newDefaultBackendNetwork(httpCli), cfg, logger)
+}
+
+func NewDefaultHealthCheckWithNetwork(network BackendNetwork, cfg *config.HealthCheck, logger *zap.Logger) *DefaultHealthCheck {
+	if network == nil {
+		network = newDefaultBackendNetwork(nil)
 	}
 	return &DefaultHealthCheck{
-		httpCli: httpCli,
+		network: network,
 		cfg:     cfg,
 		logger:  logger,
 	}
 }
 
+type defaultBackendNetwork struct {
+	httpCli *http.Client
+}
+
+func newDefaultBackendNetwork(httpCli *http.Client) *defaultBackendNetwork {
+	if httpCli == nil {
+		httpCli = http.NewHTTPClient(func() *tls.Config { return nil })
+	}
+	return &defaultBackendNetwork{httpCli: httpCli}
+}
+
+func (n *defaultBackendNetwork) HTTPClient(string) *http.Client {
+	return n.httpCli
+}
+
+func (n *defaultBackendNetwork) DialContext(ctx context.Context, network, addr, _ string) (net.Conn, error) {
+	var dialer net.Dialer
+	return dialer.DialContext(ctx, network, addr)
+}
+
 func (dhc *DefaultHealthCheck) Check(ctx context.Context, info *BackendInfo, lastBh *BackendHealth) *BackendHealth {
 	bh := &BackendHealth{
 		BackendInfo: *info,
@@ -96,10 +125,13 @@ func (dhc *DefaultHealthCheck) checkSqlPort(ctx context.Context, info *BackendIn
 		return
 	}
 	addr := info.Addr
+	clusterName := info.ClusterName
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx)
 	err := http.ConnectWithRetry(func() error {
 		startTime := time.Now()
-		conn, err := net.DialTimeout("tcp", addr, dhc.cfg.DialTimeout)
+		dialCtx, cancel := context.WithTimeout(ctx, dhc.cfg.DialTimeout)
+		conn, err := dhc.network.DialContext(dialCtx, "tcp", addr, clusterName)
+		cancel()
 		setPingBackendMetrics(addr, startTime)
 		if err != nil {
 			return err
@@ -134,7 +166,8 @@ func (dhc *DefaultHealthCheck) checkStatusPort(ctx context.Context, info *Backen
 	addr := net.JoinHostPort(info.IP, strconv.Itoa(int(info.StatusPort)))
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx)
-	resp, err := dhc.httpCli.Get(addr, statusPathSuffix, b, dhc.cfg.DialTimeout)
+	clusterName := info.ClusterName
+	resp, err := dhc.network.HTTPClient(clusterName).Get(addr, statusPathSuffix, b, dhc.cfg.DialTimeout)
 	if err == nil {
 		var respBody backendHttpStatusRespBody
 		err = json.Unmarshal(resp, &respBody)
@@ -176,7 +209,8 @@ func (dhc *DefaultHealthCheck) queryConfig(ctx context.Context, info *BackendInf
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx)
 	var resp []byte
-	if resp, err = dhc.httpCli.Get(addr, configPathSuffix, b, dhc.cfg.DialTimeout); err != nil {
+	clusterName := info.ClusterName
+	if resp, err = dhc.network.HTTPClient(clusterName).Get(addr, configPathSuffix, b, dhc.cfg.DialTimeout); err != nil {
 		return
 	}
 	var respBody backendHttpConfigRespBody
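
The BackendNetwork interface above decouples the health checker from any concrete transport. A minimal sketch, assuming an illustrative staticNetwork type that ignores the cluster name; the nil TLS callback and the backend address are placeholders, not part of this patch:

package main

import (
	"context"
	"crypto/tls"
	"net"

	"github.com/pingcap/tiproxy/lib/config"
	"github.com/pingcap/tiproxy/pkg/balance/observer"
	httputil "github.com/pingcap/tiproxy/pkg/util/http"
	"go.uber.org/zap"
)

// staticNetwork is a hypothetical BackendNetwork that reuses one shared HTTP
// client and the default dialer for every cluster.
type staticNetwork struct {
	httpCli *httputil.Client
}

func (n *staticNetwork) HTTPClient(string) *httputil.Client { return n.httpCli }

func (n *staticNetwork) DialContext(ctx context.Context, network, addr, _ string) (net.Conn, error) {
	var d net.Dialer
	return d.DialContext(ctx, network, addr)
}

func main() {
	network := &staticNetwork{httpCli: httputil.NewHTTPClient(func() *tls.Config { return nil })}
	hc := observer.NewDefaultHealthCheckWithNetwork(network, config.NewDefaultHealthCheckConfig(), zap.NewNop())
	// Check a single (illustrative) backend address once.
	_ = hc.Check(context.Background(), &observer.BackendInfo{Addr: "127.0.0.1:4000"}, nil)
}
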
diff --git a/pkg/balance/observer/health_check_test.go b/pkg/balance/observer/health_check_test.go
index 27e6fa394..054bef652 100644
--- a/pkg/balance/observer/health_check_test.go
+++ b/pkg/balance/observer/health_check_test.go
@@ -5,10 +5,12 @@ package observer
 
 import (
 	"context"
+	"crypto/tls"
 	"encoding/json"
 	"net"
 	"net/http"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -17,6 +19,7 @@ import (
 	"github.com/pingcap/tiproxy/lib/util/logger"
 	"github.com/pingcap/tiproxy/lib/util/waitgroup"
 	"github.com/pingcap/tiproxy/pkg/testkit"
+	httputil "github.com/pingcap/tiproxy/pkg/util/http"
 	"github.com/stretchr/testify/require"
 )
 
@@ -120,6 +123,59 @@ func TestSupportRedirection(t *testing.T) {
 	require.False(t, health.SupportRedirection)
 }
 
+func TestHealthCheckUsesClusterNetwork(t *testing.T) {
+	lg, _ := logger.CreateLoggerForTest(t)
+	cfg := newHealthCheckConfigForTest()
+	backend, info := newBackendServer(t)
+	defer backend.close()
+	backend.setServerVersion("1.0")
+	backend.setHasSigningCert(true)
+	info.ClusterName = "cluster-a"
+
+	network := &mockBackendNetwork{
+		httpCli: httputil.NewHTTPClient(func() *tls.Config { return nil }),
+	}
+	hc := NewDefaultHealthCheckWithNetwork(network, cfg, lg)
+	health := hc.Check(context.Background(), info, nil)
+	require.True(t, health.Healthy)
+	require.Contains(t, network.httpClusters(), "cluster-a")
+	require.Contains(t, network.dialClusters(), "cluster-a")
+}
+
+type mockBackendNetwork struct {
+	httpCli *httputil.Client
+	mu      sync.Mutex
+	https   []string
+	dials   []string
+}
+
+func (n *mockBackendNetwork) HTTPClient(clusterName string) *httputil.Client {
+	n.mu.Lock()
+	n.https = append(n.https, clusterName)
+	n.mu.Unlock()
+	return n.httpCli
+}
+
+func (n *mockBackendNetwork) DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error) {
+	n.mu.Lock()
+	n.dials = append(n.dials, clusterName)
+	n.mu.Unlock()
+	var dialer net.Dialer
+	return dialer.DialContext(ctx, network, addr)
+}
+
+func (n *mockBackendNetwork) httpClusters() []string {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	return append([]string(nil), n.https...)
+}
+
+func (n *mockBackendNetwork) dialClusters() []string {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	return append([]string(nil), n.dials...)
+}
+
 type backendServer struct {
 	t           *testing.T
 	sqlListener net.Listener
diff --git a/pkg/balance/router/router.go b/pkg/balance/router/router.go
index e83475115..9f62833ad 100644
--- a/pkg/balance/router/router.go
+++ b/pkg/balance/router/router.go
@@ -80,6 +80,7 @@ type BackendInst interface {
 	Healthy() bool
 	Local() bool
 	Keyspace() string
+	ClusterName() string
 }
 
 // backendWrapper contains the connections on the backend.
diff --git a/pkg/balance/router/router_static.go b/pkg/balance/router/router_static.go
index 9eddddd4d..00385fdd7 100644
--- a/pkg/balance/router/router_static.go
+++ b/pkg/balance/router/router_static.go
@@ -82,6 +82,7 @@ func (r *StaticRouter) OnConnClosed(backendID, redirectingBackendID string, conn
 
 type StaticBackend struct {
 	addr     string
 	keyspace string
+	cluster  string
 	healthy  atomic.Bool
 }
 
@@ -120,3 +121,7 @@ func (b *StaticBackend) Keyspace() string {
 func (b *StaticBackend) SetKeyspace(k string) {
 	b.keyspace = k
 }
+
+func (b *StaticBackend) ClusterName() string {
+	return b.cluster
+}
diff --git a/pkg/manager/backendcluster/cluster.go b/pkg/manager/backendcluster/cluster.go
index 611a967a0..748fd7484 100644
--- a/pkg/manager/backendcluster/cluster.go
+++ b/pkg/manager/backendcluster/cluster.go
@@ -6,13 +6,15 @@ package backendcluster
 import (
 	"context"
 	"crypto/tls"
+	"net"
 
 	"github.com/pingcap/tiproxy/lib/config"
 	"github.com/pingcap/tiproxy/lib/util/errors"
 	"github.com/pingcap/tiproxy/pkg/balance/metricsreader"
 	"github.com/pingcap/tiproxy/pkg/manager/infosync"
 	"github.com/pingcap/tiproxy/pkg/util/etcd"
-	"github.com/pingcap/tiproxy/pkg/util/http"
+	httputil "github.com/pingcap/tiproxy/pkg/util/http"
+	"github.com/pingcap/tiproxy/pkg/util/netutil"
 	clientv3 "go.etcd.io/etcd/client/v3"
 	"go.uber.org/zap"
 )
@@ -23,6 +25,8 @@ type Cluster struct {
 	etcdCli    *clientv3.Client
 	infoSyncer *infosync.InfoSyncer
 	metrics    *metricsreader.ClusterReader
+	httpCli    *httputil.Client
+	dialer     *netutil.DNSDialer
 }
 
 func (c *Cluster) Config() config.BackendCluster {
@@ -41,6 +45,14 @@ func (c *Cluster) GetPromInfo(ctx context.Context) (*infosync.PrometheusInfo, er
 	return c.infoSyncer.GetPromInfo(ctx)
 }
 
+func (c *Cluster) HTTPClient() *httputil.Client {
+	return c.httpCli
+}
+
+func (c *Cluster) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+	return c.dialer.DialContext(ctx, network, addr)
+}
+
 func (c *Cluster) PreClose() {
 	if c.metrics != nil {
 		c.metrics.PreClose()
@@ -69,10 +81,18 @@ func NewCluster(
 	metricsQuerier *MetricsQuerier,
 ) (*Cluster, error) {
 	clusterCfg = normalizeCluster(clusterCfg)
-	etcdCli, err := etcd.InitEtcdClientWithAddrs(
+	nameServers, err := config.ParseNSServers(clusterCfg.NSServers)
+	if err != nil {
+		return nil, err
+	}
+	dialer := netutil.NewDNSDialer(nameServers)
+	httpCli := httputil.NewHTTPClientWithDialContext(clusterTLS, dialer.DialContext)
+
+	etcdCli, err := etcd.InitEtcdClientWithAddrsAndDialer(
 		logger.With(zap.String("cluster", clusterCfg.Name)).Named("etcd"),
 		clusterCfg.PDAddrs,
 		clusterTLS(),
+		dialer,
 	)
 	if err != nil {
 		return nil, err
@@ -91,13 +111,15 @@ func NewCluster(
 		cfg:        clusterCfg,
 		etcdCli:    etcdCli,
 		infoSyncer: infoSyncer,
+		httpCli:    httpCli,
+		dialer:     dialer,
 	}
 	cluster.metrics = metricsreader.NewClusterReader(
 		logger.With(zap.String("cluster", clusterCfg.Name)).Named("metrics"),
 		clusterCfg.Name,
 		cluster,
 		cluster,
-		http.NewHTTPClient(clusterTLS),
+		httpCli,
 		etcdCli,
 		config.NewDefaultHealthCheckConfig(),
 		cfgGetter,
diff --git a/pkg/manager/backendcluster/manager.go b/pkg/manager/backendcluster/manager.go
index 3ef69a537..f47a45c31 100644
--- a/pkg/manager/backendcluster/manager.go
+++ b/pkg/manager/backendcluster/manager.go
@@ -27,6 +27,7 @@ type Manager struct {
 	wg      waitgroup.WaitGroup
 	cancel  context.CancelFunc
 	metrics *MetricsQuerier
+	network *NetworkRouter
 
 	mu struct {
 		sync.RWMutex
@@ -41,6 +42,7 @@ func NewManager(lg *zap.Logger, clusterTLS func() *tls.Config) *Manager {
 	}
 	mgr.mu.clusters = make(map[string]*Cluster)
 	mgr.metrics = NewMetricsQuerier(mgr)
+	mgr.network = NewNetworkRouter(mgr, clusterTLS)
 	return mgr
 }
 
@@ -164,6 +166,7 @@ func clusterReusable(cluster *Cluster, cfg config.BackendCluster) bool {
 	left.PDAddrs == right.PDAddrs &&
 		slices.Equal(left.NSServers, right.NSServers)
 }
+
 func (m *Manager) Snapshot() map[string]*Cluster {
 	m.mu.RLock()
 	snapshot := make(map[string]*Cluster, len(m.mu.clusters))
@@ -182,6 +185,10 @@ func (m *Manager) MetricsQuerier() *MetricsQuerier {
 	return m.metrics
 }
 
+func (m *Manager) NetworkRouter() *NetworkRouter {
+	return m.network
+}
+
 // PrimaryCluster returns the only configured cluster when the cluster count is exactly one.
 // It exists for features that are only well-defined in the single-cluster case, such as VIP.
 func (m *Manager) PrimaryCluster() *Cluster {
diff --git a/pkg/manager/backendcluster/manager_test.go b/pkg/manager/backendcluster/manager_test.go
index 504460d77..1799bf13f 100644
--- a/pkg/manager/backendcluster/manager_test.go
+++ b/pkg/manager/backendcluster/manager_test.go
@@ -7,6 +7,7 @@ import (
 	"context"
 	"crypto/tls"
 	"encoding/json"
+	"net"
 	"path"
 	"sync"
 	"testing"
@@ -15,6 +16,7 @@ import (
 	"github.com/pingcap/tiproxy/lib/config"
 	"github.com/pingcap/tiproxy/lib/util/logger"
 	"github.com/pingcap/tiproxy/pkg/manager/infosync"
+	"github.com/pingcap/tiproxy/pkg/testkit"
 	"github.com/pingcap/tiproxy/pkg/util/etcd"
 	"github.com/stretchr/testify/require"
 	clientv3 "go.etcd.io/etcd/client/v3"
@@ -122,6 +124,49 @@ func TestManagerDynamicClusterUpdate(t *testing.T) {
 	}, 5*time.Second, 100*time.Millisecond)
 }
 
+func TestManagerUsesClusterNameServersForPD(t *testing.T) {
+	clusterA := newManagerTestEtcdCluster(t)
+	clusterB := newManagerTestEtcdCluster(t)
+	t.Cleanup(func() { clusterA.close(t) })
+	t.Cleanup(func() { clusterB.close(t) })
+
+	clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080})
+	clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080})
+
+	dnsA := testkit.StartDNSServer(t, map[string][]string{"pd-a.test": {"127.0.0.1"}})
+	dnsB := testkit.StartDNSServer(t, map[string][]string{"pd-b.test": {"127.0.0.1"}})
+	_, portA, err := net.SplitHostPort(clusterA.addr)
+	require.NoError(t, err)
+	_, portB, err := net.SplitHostPort(clusterB.addr)
+	require.NoError(t, err)
+
+	cfg := newManagerTestConfig()
+	cfg.Proxy.BackendClusters = []config.BackendCluster{
+		{Name: "cluster-a", PDAddrs: net.JoinHostPort("pd-a.test", portA), NSServers: []string{dnsA.Addr()}},
+		{Name: "cluster-b", PDAddrs: net.JoinHostPort("pd-b.test", portB), NSServers: []string{dnsB.Addr()}},
+	}
+	cfgGetter := newManagerTestConfigGetter(cfg)
+	cfgCh := make(chan *config.Config, 1)
+
+	mgr := NewManager(zapLoggerForTest(t), nilClusterTLS)
+	require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh))
+	t.Cleanup(func() {
+		close(cfgCh)
+		require.NoError(t, mgr.Close())
+	})
+
+	require.Eventually(t, func() bool {
+		topology, err := mgr.GetTiDBTopology(context.Background())
+		if err != nil || len(topology) != 2 {
+			return false
+		}
+		return topology[backendID("cluster-a", "10.0.0.1:4000")].ClusterName == "cluster-a" &&
+			topology[backendID("cluster-b", "10.0.0.2:4000")].ClusterName == "cluster-b"
+	}, 5*time.Second, 100*time.Millisecond)
+	require.Greater(t, dnsA.QueryCount("pd-a.test"), 0)
+	require.Greater(t, dnsB.QueryCount("pd-b.test"), 0)
+}
+
 func TestManagerKeepsOldClusterWhenUpdateFails(t *testing.T) {
 	clusterA := newManagerTestEtcdCluster(t)
 	clusterB := newManagerTestEtcdCluster(t)
@@ -172,6 +217,51 @@
 	require.Contains(t, topology, backendID("cluster-a", "10.0.0.1:4000"))
 	require.NotContains(t, topology, backendID("cluster-a", "10.0.0.2:4000"))
 }
+
+func TestManagerUpdatesClusterNameServersForPD(t *testing.T) {
+	cluster := newManagerTestEtcdCluster(t)
+	t.Cleanup(func() { cluster.close(t) })
+
+	cluster.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080})
+
+	dnsA := testkit.StartDNSServer(t, map[string][]string{"pd.test": {"127.0.0.1"}})
+	dnsB := testkit.StartDNSServer(t, map[string][]string{"pd.test": {"127.0.0.1"}})
+	_, port, err := net.SplitHostPort(cluster.addr)
+	require.NoError(t, err)
+
+	cfg := newManagerTestConfig()
+	cfg.Proxy.BackendClusters = []config.BackendCluster{
+		{Name: "cluster-a", PDAddrs: net.JoinHostPort("pd.test", port), NSServers: []string{dnsA.Addr()}},
+	}
+	cfgGetter := newManagerTestConfigGetter(cfg)
+	cfgCh := make(chan *config.Config, 1)
+
+	mgr := NewManager(zapLoggerForTest(t), nilClusterTLS)
+	require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh))
+	t.Cleanup(func() {
+		require.NoError(t, mgr.Close())
+	})
+
+	require.Eventually(t, func() bool {
+		return dnsA.QueryCount("pd.test") > 0
+	}, 5*time.Second, 100*time.Millisecond)
+
+	originalCluster := mgr.Snapshot()["cluster-a"]
+	require.NotNil(t, originalCluster)
+
+	nextCfg := cfg.Clone()
+	nextCfg.Proxy.BackendClusters = []config.BackendCluster{
+		{Name: "cluster-a", PDAddrs: net.JoinHostPort("pd.test", port), NSServers: []string{dnsB.Addr()}},
+	}
+	cfgGetter.setConfig(nextCfg)
+	cfgCh <- nextCfg.Clone()
+
+	require.Eventually(t, func() bool {
+		currentCluster := mgr.Snapshot()["cluster-a"]
+		return currentCluster != nil &&
+			currentCluster != originalCluster &&
+			dnsB.QueryCount("pd.test") > 0
+	}, 5*time.Second, 100*time.Millisecond)
+}
+
 func TestManagerKeepsDuplicateBackendAddrsAcrossClusters(t *testing.T) {
 	clusterA := newManagerTestEtcdCluster(t)
 	clusterB := newManagerTestEtcdCluster(t)
@@ -271,7 +361,7 @@ type managerTestEtcdCluster struct {
 
 func newManagerTestEtcdCluster(t *testing.T) *managerTestEtcdCluster {
 	lg, _ := logger.CreateLoggerForTest(t)
-	etcdSrv, err := etcd.CreateEtcdServer("0.0.0.0:0", t.TempDir(), lg)
+	etcdSrv, err := etcd.CreateEtcdServer("127.0.0.1:0", t.TempDir(), lg)
 	require.NoError(t, err)
 	addr := etcdSrv.Clients[0].Addr().String()
 	cli, err := etcd.InitEtcdClientWithAddrs(lg, addr, nil)
"10.0.0.1:4000")].ClusterName == "cluster-a" && + topology[backendID("cluster-b", "10.0.0.2:4000")].ClusterName == "cluster-b" + }, 5*time.Second, 100*time.Millisecond) + require.Greater(t, dnsA.QueryCount("pd-a.test"), 0) + require.Greater(t, dnsB.QueryCount("pd-b.test"), 0) +} + func TestManagerKeepsOldClusterWhenUpdateFails(t *testing.T) { clusterA := newManagerTestEtcdCluster(t) clusterB := newManagerTestEtcdCluster(t) @@ -172,6 +217,51 @@ func TestManagerKeepsOldClusterWhenUpdateFails(t *testing.T) { require.Contains(t, topology, backendID("cluster-a", "10.0.0.1:4000")) require.NotContains(t, topology, backendID("cluster-a", "10.0.0.2:4000")) } +func TestManagerUpdatesClusterNameServersForPD(t *testing.T) { + cluster := newManagerTestEtcdCluster(t) + t.Cleanup(func() { cluster.close(t) }) + + cluster.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + + dnsA := testkit.StartDNSServer(t, map[string][]string{"pd.test": {"127.0.0.1"}}) + dnsB := testkit.StartDNSServer(t, map[string][]string{"pd.test": {"127.0.0.1"}}) + _, port, err := net.SplitHostPort(cluster.addr) + require.NoError(t, err) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: net.JoinHostPort("pd.test", port), NSServers: []string{dnsA.Addr()}}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + cfgCh := make(chan *config.Config, 1) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + return dnsA.QueryCount("pd.test") > 0 + }, 5*time.Second, 100*time.Millisecond) + + originalCluster := mgr.Snapshot()["cluster-a"] + require.NotNil(t, originalCluster) + + nextCfg := cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: net.JoinHostPort("pd.test", port), NSServers: []string{dnsB.Addr()}}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- nextCfg.Clone() + + require.Eventually(t, func() bool { + currentCluster := mgr.Snapshot()["cluster-a"] + return currentCluster != nil && + currentCluster != originalCluster && + dnsB.QueryCount("pd.test") > 0 + }, 5*time.Second, 100*time.Millisecond) +} func TestManagerKeepsDuplicateBackendAddrsAcrossClusters(t *testing.T) { clusterA := newManagerTestEtcdCluster(t) clusterB := newManagerTestEtcdCluster(t) @@ -271,7 +361,7 @@ type managerTestEtcdCluster struct { func newManagerTestEtcdCluster(t *testing.T) *managerTestEtcdCluster { lg, _ := logger.CreateLoggerForTest(t) - etcdSrv, err := etcd.CreateEtcdServer("0.0.0.0:0", t.TempDir(), lg) + etcdSrv, err := etcd.CreateEtcdServer("127.0.0.1:0", t.TempDir(), lg) require.NoError(t, err) addr := etcdSrv.Clients[0].Addr().String() cli, err := etcd.InitEtcdClientWithAddrs(lg, addr, nil) diff --git a/pkg/manager/backendcluster/network_router.go b/pkg/manager/backendcluster/network_router.go new file mode 100644 index 000000000..4dc173f4a --- /dev/null +++ b/pkg/manager/backendcluster/network_router.go @@ -0,0 +1,60 @@ +// Copyright 2026 PingCAP, Inc. 
diff --git a/pkg/manager/backendcluster/network_router_test.go b/pkg/manager/backendcluster/network_router_test.go
new file mode 100644
index 000000000..593ef9bbb
--- /dev/null
+++ b/pkg/manager/backendcluster/network_router_test.go
@@ -0,0 +1,58 @@
+// Copyright 2026 PingCAP, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package backendcluster
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/cenkalti/backoff/v4"
+	"github.com/pingcap/tiproxy/lib/util/errors"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNetworkRouterDialContextRejectsMissingCluster(t *testing.T) {
+	router := NewNetworkRouter(&Manager{}, nilClusterTLS)
+	_, err := router.DialContext(context.Background(), "tcp", "127.0.0.1:80", "missing")
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrBackendClusterNotFound))
+}
+
+func TestNetworkRouterHTTPClientRejectsMissingCluster(t *testing.T) {
+	router := NewNetworkRouter(&Manager{}, nilClusterTLS)
+	b := backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Millisecond), 0)
+	_, err := router.HTTPClient("missing").Get("127.0.0.1:80", "/status", b, time.Second)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrBackendClusterNotFound))
+}
+
+func TestNetworkRouterDialContextFallsBackWithoutClusterName(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, ln.Close())
+	})
+
+	accepted := make(chan struct{}, 1)
+	go func() {
+		conn, err := ln.Accept()
+		if err == nil {
+			accepted <- struct{}{}
+			_ = conn.Close()
+		}
+	}()
+
+	router := NewNetworkRouter(&Manager{}, nilClusterTLS)
+	conn, err := router.DialContext(context.Background(), "tcp", ln.Addr().String(), "")
+	require.NoError(t, err)
+	require.NoError(t, conn.Close())
+
+	select {
+	case <-accepted:
+	case <-time.After(time.Second):
+		t.Fatal("listener was not reached through default dialer")
+	}
+}
diff --git a/pkg/manager/namespace/manager.go b/pkg/manager/namespace/manager.go
index 8423daa2e..f5e8e467b 100644
--- a/pkg/manager/namespace/manager.go
+++ b/pkg/manager/namespace/manager.go
@@ -24,6 +24,7 @@ import (
 )
 
 type NamespaceManager interface {
+	SetBackendNetwork(backendNetwork observer.BackendNetwork)
 	Init(logger *zap.Logger, nscs []*config.Namespace, tpFetcher observer.TopologyFetcher, promFetcher metricsreader.PromInfoFetcher,
 		httpCli *http.Client, cfgMgr *mconfig.ConfigManager, metricsReader metricsreader.MetricsQuerier) error
@@ -37,13 +38,14 @@ type NamespaceManager interface {
 
 type namespaceManager struct {
 	sync.RWMutex
-	nsm           map[string]*Namespace
-	tpFetcher     observer.TopologyFetcher
-	promFetcher   metricsreader.PromInfoFetcher
-	metricsReader metricsreader.MetricsQuerier
-	httpCli       *http.Client
-	logger        *zap.Logger
-	cfgMgr        *mconfig.ConfigManager
+	nsm            map[string]*Namespace
+	tpFetcher      observer.TopologyFetcher
+	promFetcher    metricsreader.PromInfoFetcher
+	metricsReader  metricsreader.MetricsQuerier
+	httpCli        *http.Client
+	backendNetwork observer.BackendNetwork
+	logger         *zap.Logger
+	cfgMgr         *mconfig.ConfigManager
 }
 
 func NewNamespaceManager() *namespaceManager {
@@ -60,7 +62,7 @@ func (mgr *namespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace,
 
 	// init Router
 	rt := router.NewScoreBasedRouter(logger.Named("router"))
-	hc := observer.NewDefaultHealthCheck(mgr.httpCli, healthCheckCfg, logger.Named("hc"))
+	hc := observer.NewDefaultHealthCheckWithNetwork(mgr.backendNetwork, healthCheckCfg, logger.Named("hc"))
 	bo := observer.NewDefaultBackendObserver(logger.Named("observer"), healthCheckCfg, fetcher, hc, mgr.cfgMgr)
 	bo.Start(context.Background())
 	bpCreator := func(lg *zap.Logger) policy.BalancePolicy {
@@ -117,6 +119,12 @@ func (mgr *namespaceManager) Init(logger *zap.Logger, nscs []*config.Namespace,
 	return mgr.CommitNamespaces(nscs, nil)
 }
 
+func (mgr *namespaceManager) SetBackendNetwork(backendNetwork observer.BackendNetwork) {
+	mgr.Lock()
+	mgr.backendNetwork = backendNetwork
+	mgr.Unlock()
+}
+
 func (mgr *namespaceManager) GetNamespace(nm string) (*Namespace, bool) {
 	mgr.RLock()
 	defer mgr.RUnlock()
diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go
index 7eb005abe..c490817aa 100644
--- a/pkg/proxy/backend/backend_conn_mgr.go
+++ b/pkg/proxy/backend/backend_conn_mgr.go
@@ -91,6 +91,7 @@ type BCConfig struct {
 	HealthyKeepAlive     config.KeepAlive
 	UnhealthyKeepAlive   config.KeepAlive
 	FromPublicEndpoints  func(addr net.Addr) bool
+	DialContext          func(ctx context.Context, backend router.BackendInst, addr string) (net.Conn, error)
 	TickerInterval       time.Duration
 	CheckBackendInterval time.Duration
 	DialTimeout          time.Duration
@@ -314,7 +315,9 @@ func (mgr *BackendConnManager) getBackendIO(ctx context.Context, cctx ConnContex
 
 		var cn net.Conn
 		addr = backend.Addr()
-		cn, err = net.DialTimeout("tcp", addr, mgr.config.DialTimeout)
+		dialCtx, cancel := context.WithTimeout(bctx, mgr.config.DialTimeout)
+		cn, err = mgr.dialBackend(dialCtx, backend, addr)
+		cancel()
 		selector.Finish(mgr, err == nil)
 		if err != nil {
 			metrics.DialBackendFailCounter.WithLabelValues(addr).Inc()
@@ -647,7 +650,9 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
 	}
 
 	var cn net.Conn
-	cn, rs.err = net.DialTimeout("tcp", (*backendInst).Addr(), mgr.config.DialTimeout)
+	dialCtx, cancel := context.WithTimeout(ctx, mgr.config.DialTimeout)
+	cn, rs.err = mgr.dialBackend(dialCtx, *backendInst, (*backendInst).Addr())
+	cancel()
 	if rs.err != nil {
 		mgr.handshakeHandler.OnHandshake(mgr, (*backendInst).Addr(), rs.err, SrcBackendNetwork)
 		return
@@ -818,6 +823,14 @@ func (mgr *BackendConnManager) Value(key any) any {
 	return v
 }
 
+func (mgr *BackendConnManager) dialBackend(ctx context.Context, backend router.BackendInst, addr string) (net.Conn, error) {
+	if mgr.config.DialContext != nil {
+		return mgr.config.DialContext(ctx, backend, addr)
+	}
+	var dialer net.Dialer
+	return dialer.DialContext(ctx, "tcp", addr)
+}
+
 // Close releases all resources.
 func (mgr *BackendConnManager) Close() error {
 	// BackendConnMgr may close even before connecting, so protect the members with a lock.
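
The two dial sites above swap net.DialTimeout for a context-scoped dial, so cancelling the surrounding operation also aborts an in-flight dial. When no DialContext hook is configured, dialBackend reduces to roughly this sketch (the address is illustrative):

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// dialWithDeadline is bounded by both the caller's context and the timeout.
func dialWithDeadline(parent context.Context, addr string, timeout time.Duration) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()
	var d net.Dialer
	return d.DialContext(ctx, "tcp", addr)
}

func main() {
	if conn, err := dialWithDeadline(context.Background(), "127.0.0.1:4000", time.Second); err == nil {
		_ = conn.Close()
	} else {
		fmt.Println("dial failed:", err)
	}
}
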
diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go
index 2532bd83f..c9c7a0581 100644
--- a/pkg/proxy/backend/backend_conn_mgr_test.go
+++ b/pkg/proxy/backend/backend_conn_mgr_test.go
@@ -80,6 +80,7 @@ func (mer *mockEventReceiver) checkEvent(t *testing.T, eventName int) event {
 type mockBackendInst struct {
 	addr     string
 	keyspace string
+	cluster  string
 	healthy  atomic.Bool
 	local    atomic.Bool
 }
@@ -125,6 +126,10 @@ func (mbi *mockBackendInst) setKeyspace(k string) {
 	mbi.keyspace = k
 }
 
+func (mbi *mockBackendInst) ClusterName() string {
+	return mbi.cluster
+}
+
 type runner struct {
 	client func(packetIO pnet.PacketIO) error
 	proxy  func(clientIO, backendIO pnet.PacketIO) error
@@ -1022,11 +1027,14 @@ func TestGetBackendIO(t *testing.T) {
 	mgr := NewBackendConnManager(lg, handler, &mockCapture{}, 0, &BCConfig{ConnectTimeout: time.Second}, nil)
 	var wg waitgroup.WaitGroup
 	for i := 0; i <= len(listeners); i++ {
+		acceptedCh := make(chan error, 1)
 		wg.Run(func() {
 			if i < len(listeners) {
 				cn, err := listeners[i].Accept()
-				require.NoError(t, err)
-				require.NoError(t, cn.Close())
+				if err == nil {
+					err = cn.Close()
+				}
+				acceptedCh <- err
 			}
 		})
 		io, err := mgr.getBackendIO(context.Background(), mgr, nil)
@@ -1036,6 +1044,7 @@ func TestGetBackendIO(t *testing.T) {
 		message := fmt.Sprintf("%d: %s, %+v\n", i, badAddrs, err)
 		if i < len(listeners) {
 			require.NoError(t, err, message)
+			require.NoError(t, <-acceptedCh, message)
 			err = listeners[i].Close()
 			require.NoError(t, err, message)
 		} else {
@@ -1047,6 +1056,45 @@ func TestGetBackendIO(t *testing.T) {
 	}
 }
 
+func TestGetBackendIOUsesBackendDialContext(t *testing.T) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer func() { require.NoError(t, listener.Close()) }()
+
+	rt := router.NewStaticRouter([]string{"tidb-a.test:4000"})
+	handler := &CustomHandshakeHandler{
+		getRouter: func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
+			return rt, nil
+		},
+	}
+	lg, _ := logger.CreateLoggerForTest(t)
+	var gotCluster, gotAddr string
+	mgr := NewBackendConnManager(lg, handler, &mockCapture{}, 0, &BCConfig{
+		ConnectTimeout: time.Second,
+		DialContext: func(ctx context.Context, backendInst router.BackendInst, addr string) (net.Conn, error) {
+			gotCluster = backendInst.ClusterName()
+			gotAddr = addr
+			var dialer net.Dialer
+			return dialer.DialContext(ctx, "tcp", listener.Addr().String())
+		},
+	}, nil)
+
+	acceptedCh := make(chan error, 1)
+	go func() {
+		conn, err := listener.Accept()
+		if err == nil {
+			err = conn.Close()
+		}
+		acceptedCh <- err
+	}()
+	io, err := mgr.getBackendIO(context.Background(), mgr, nil)
+	require.NoError(t, err)
+	require.NoError(t, io.Close())
+	require.NoError(t, <-acceptedCh)
+	require.Empty(t, gotCluster)
+	require.Equal(t, "tidb-a.test:4000", gotAddr)
+}
+
 func TestBackendInactive(t *testing.T) {
 	ts := newBackendMgrTester(t, func(config *testConfig) {
 		config.proxyConfig.bcConfig.TickerInterval = time.Millisecond
diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go
index 59c89fcde..bafc3c3a4 100644
--- a/pkg/proxy/proxy.go
+++ b/pkg/proxy/proxy.go
@@ -12,6 +12,7 @@ import (
 
 	"github.com/pingcap/tiproxy/lib/config"
 	"github.com/pingcap/tiproxy/lib/util/errors"
+	"github.com/pingcap/tiproxy/pkg/balance/router"
 	"github.com/pingcap/tiproxy/pkg/manager/cert"
 	"github.com/pingcap/tiproxy/pkg/manager/id"
 	"github.com/pingcap/tiproxy/pkg/metrics"
@@ -40,6 +41,10 @@ type serverState struct {
 	gracefulClose int // graceful-close-conn-timeout
 }
 
+type BackendDialer interface {
+	DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error)
+}
+
 type SQLServer struct {
 	listeners []net.Listener
 	addrs     []string
@@ -49,6 +54,7 @@ type SQLServer struct {
 	hsHandler backend.HandshakeHandler
 	cpt       capture.Capture
 	meter     backend.Meter
+	dialer    BackendDialer
 	wg        waitgroup.WaitGroup
 
 	cancelFunc context.CancelFunc
@@ -108,6 +114,10 @@ func (s *SQLServer) reset(cfg *config.Config) {
 	s.mu.Unlock()
 }
 
+func (s *SQLServer) SetBackendDialer(dialer BackendDialer) {
+	s.dialer = dialer
+}
+
 func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) {
 	// Create another context because it still needs to run after graceful shutdown.
 	ctx, s.cancelFunc = context.WithCancel(context.Background())
@@ -176,6 +186,13 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
 		UnhealthyKeepAlive:  s.mu.unhealthyKeepAlive,
 		ConnBufferSize:      s.mu.connBufferSize,
 		FromPublicEndpoints: s.fromPublicEndpoint,
+		DialContext: func(ctx context.Context, backendInst router.BackendInst, addr string) (net.Conn, error) {
+			if s.dialer != nil {
+				return s.dialer.DialContext(ctx, "tcp", addr, backendInst.ClusterName())
+			}
+			var dialer net.Dialer
+			return dialer.DialContext(ctx, "tcp", addr)
+		},
 	}, s.meter)
 	s.mu.clients[connID] = clientConn
 	logger.Debug("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.Bool("require_backend_tls", s.mu.requireBackendTLS))
diff --git a/pkg/server/api/mock_test.go b/pkg/server/api/mock_test.go
index d9081094e..9009c9a9d 100644
--- a/pkg/server/api/mock_test.go
+++ b/pkg/server/api/mock_test.go
@@ -29,6 +29,9 @@ func newMockNamespaceManager() *mockNamespaceManager {
 	return mgr
 }
 
+func (m *mockNamespaceManager) SetBackendNetwork(_ observer.BackendNetwork) {
+}
+
 func (m *mockNamespaceManager) Init(_ *zap.Logger, _ []*config.Namespace, _ observer.TopologyFetcher, _ metricsreader.PromInfoFetcher,
 	_ *http.Client, _ *mconfig.ConfigManager, _ metricsreader.MetricsQuerier) error {
 	return nil
diff --git a/pkg/server/server.go b/pkg/server/server.go
index c44dff580..0a783e086 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -119,6 +119,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error)
 
 	// setup namespace manager
 	{
+		srv.namespaceManager.SetBackendNetwork(srv.clusterManager.NetworkRouter())
 		nscs, nerr := srv.configManager.ListAllNamespace(ctx)
 		if nerr != nil {
 			err = nerr
@@ -174,6 +175,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error)
 		if err != nil {
 			return
 		}
+		srv.proxy.SetBackendDialer(srv.clusterManager.NetworkRouter())
 		srv.proxy.Run(ctx, srv.configManager.WatchConfig())
 	}
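
The server wiring assumes *backendcluster.NetworkRouter satisfies both consumer-side interfaces; a compile-time check of that assumption could look like:

package main

import (
	"github.com/pingcap/tiproxy/pkg/balance/observer"
	"github.com/pingcap/tiproxy/pkg/manager/backendcluster"
	"github.com/pingcap/tiproxy/pkg/proxy"
)

// The router is both the health checker's BackendNetwork and the SQL
// server's BackendDialer.
var (
	_ observer.BackendNetwork = (*backendcluster.NetworkRouter)(nil)
	_ proxy.BackendDialer     = (*backendcluster.NetworkRouter)(nil)
)

func main() {}
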
diff --git a/pkg/testkit/dns_server.go b/pkg/testkit/dns_server.go
new file mode 100644
index 000000000..686dd15db
--- /dev/null
+++ b/pkg/testkit/dns_server.go
@@ -0,0 +1,137 @@
+// Copyright 2026 PingCAP, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package testkit
+
+import (
+	"net"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/pingcap/tiproxy/lib/util/waitgroup"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+type DNSServer struct {
+	conn    *net.UDPConn
+	records map[string][]net.IP
+	mu      sync.Mutex
+	queries map[string]int
+	wg      waitgroup.WaitGroup
+}
+
+func StartDNSServer(t *testing.T, records map[string][]string) *DNSServer {
+	t.Helper()
+	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
+	require.NoError(t, err)
+
+	server := &DNSServer{
+		conn:    conn,
+		records: make(map[string][]net.IP, len(records)),
+		queries: make(map[string]int),
+	}
+	for name, ips := range records {
+		key := normalizeDNSName(name)
+		server.records[key] = make([]net.IP, 0, len(ips))
+		for _, ip := range ips {
+			server.records[key] = append(server.records[key], net.ParseIP(ip))
+		}
+	}
+	server.wg.Run(func() {
+		server.serve()
+	})
+	t.Cleanup(func() {
+		require.NoError(t, server.Close())
+	})
+	return server
+}
+
+func (s *DNSServer) Addr() string {
+	return s.conn.LocalAddr().String()
+}
+
+func (s *DNSServer) QueryCount(name string) int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.queries[normalizeDNSName(name)]
+}
+
+func (s *DNSServer) Close() error {
+	if s.conn != nil {
+		err := s.conn.Close()
+		s.wg.Wait()
+		return err
+	}
+	return nil
+}
+
+func (s *DNSServer) serve() {
+	buf := make([]byte, 1500)
+	for {
+		n, addr, err := s.conn.ReadFromUDP(buf)
+		if err != nil {
+			return
+		}
+		resp, err := s.handleQuery(buf[:n])
+		if err != nil {
+			continue
+		}
+		_, _ = s.conn.WriteToUDP(resp, addr)
+	}
+}
+
+func (s *DNSServer) handleQuery(pkt []byte) ([]byte, error) {
+	var parser dnsmessage.Parser
+	header, err := parser.Start(pkt)
+	if err != nil {
+		return nil, err
+	}
+	question, err := parser.Question()
+	if err != nil {
+		return nil, err
+	}
+	name := normalizeDNSName(question.Name.String())
+	s.mu.Lock()
+	s.queries[name]++
+	s.mu.Unlock()
+
+	respHeader := dnsmessage.Header{
+		ID:                 header.ID,
+		Response:           true,
+		RecursionAvailable: true,
+	}
+	builder := dnsmessage.NewBuilder(nil, respHeader)
+	builder.EnableCompression()
+	if err := builder.StartQuestions(); err != nil {
+		return nil, err
+	}
+	if err := builder.Question(question); err != nil {
+		return nil, err
+	}
+	if err := builder.StartAnswers(); err != nil {
+		return nil, err
+	}
+	for _, ip := range s.records[name] {
+		if ipv4 := ip.To4(); ipv4 != nil && question.Type == dnsmessage.TypeA {
+			resHeader := dnsmessage.ResourceHeader{
+				Name:  question.Name,
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
+				TTL:   60,
+			}
+			if err := builder.AResource(resHeader, dnsmessage.AResource{A: [4]byte(ipv4)}); err != nil {
+				return nil, err
+			}
+		}
+	}
+	return builder.Finish()
+}
+
+func normalizeDNSName(name string) string {
+	return strings.TrimSuffix(strings.ToLower(name), ".")
+}
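
The same test server can back a plain net.Resolver outside of DNSDialer; a sketch, where dnsAddr would be DNSServer.Addr() and the host name is illustrative:

package main

import (
	"context"
	"net"
	"net/netip"
)

// resolveVia sends lookups to a specific DNS server address instead of the
// system-configured one.
func resolveVia(ctx context.Context, dnsAddr, host string) ([]netip.Addr, error) {
	r := &net.Resolver{
		PreferGo: true,
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, network, dnsAddr)
		},
	}
	return r.LookupNetIP(ctx, "ip4", host)
}

func main() {}
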
clientv3 "go.etcd.io/etcd/client/v3" @@ -36,31 +38,48 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertMa // InitEtcdClientWithAddrs initializes an etcd client that connects to PD ETCD servers. func InitEtcdClientWithAddrs(logger *zap.Logger, pdAddrs string, tlsConfig *tls.Config) (*clientv3.Client, error) { + return InitEtcdClientWithAddrsAndDialer(logger, pdAddrs, tlsConfig, nil) +} + +func InitEtcdClientWithAddrsAndDialer(logger *zap.Logger, pdAddrs string, tlsConfig *tls.Config, + dnsDialer *netutil.DNSDialer) (*clientv3.Client, error) { pdEndpoints := config.SplitAddrList(pdAddrs) logger.Info("connect ETCD servers", zap.Strings("addrs", pdEndpoints)) + dialOptions := []grpc.DialOption{ + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 10 * time.Second, + Timeout: 3 * time.Second, + }), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{ + BaseDelay: time.Second, + Multiplier: 1.1, + Jitter: 0.1, + MaxDelay: 3 * time.Second, + }, + MinConnectTimeout: 3 * time.Second, + }), + } + if dnsDialer != nil { + dialOptions = append(dialOptions, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return dnsDialer.DialContext(ctx, "tcp", addr) + })) + } etcdClient, err := clientv3.New(clientv3.Config{ Endpoints: pdEndpoints, TLS: tlsConfig, Logger: logger.Named("etcdcli"), AutoSyncInterval: 30 * time.Second, DialTimeout: 5 * time.Second, - DialOptions: []grpc.DialOption{ - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 10 * time.Second, - Timeout: 3 * time.Second, - }), - grpc.WithConnectParams(grpc.ConnectParams{ - Backoff: backoff.Config{ - BaseDelay: time.Second, - Multiplier: 1.1, - Jitter: 0.1, - MaxDelay: 3 * time.Second, - }, - MinConnectTimeout: 3 * time.Second, - }), - }, + DialOptions: dialOptions, }) - return etcdClient, errors.Wrapf(err, "init etcd client failed") + if err != nil { + return nil, errors.Wrapf(err, "init etcd client failed") + } + if err := syncEtcdClient(context.Background(), etcdClient); err != nil { + logger.Warn("sync ETCD member endpoints after init failed", zap.Error(err)) + } + return etcdClient, nil } func GetKVs(ctx context.Context, etcdCli *clientv3.Client, key string, opts []clientv3.OpOption, timeout, retryIntvl time.Duration, retryCnt uint64) ([]*mvccpb.KeyValue, error) { @@ -80,7 +99,15 @@ func GetKVs(ctx context.Context, etcdCli *clientv3.Client, key string, opts []cl // CreateEtcdServer creates an etcd server and is only used for testing. func CreateEtcdServer(addr, dir string, lg *zap.Logger) (*embed.Etcd, error) { - serverURL, err := url.Parse(fmt.Sprintf("http://%s", addr)) + listenAddr, advertiseAddr, err := allocEtcdServerAddr(addr) + if err != nil { + return nil, err + } + serverURL, err := url.Parse(fmt.Sprintf("http://%s", listenAddr)) + if err != nil { + return nil, err + } + advertiseURL, err := url.Parse(fmt.Sprintf("http://%s", advertiseAddr)) if err != nil { return nil, err } @@ -88,6 +115,9 @@ func CreateEtcdServer(addr, dir string, lg *zap.Logger) (*embed.Etcd, error) { cfg.Dir = dir cfg.ListenClientUrls = []url.URL{*serverURL} cfg.ListenPeerUrls = []url.URL{*serverURL} + cfg.AdvertiseClientUrls = []url.URL{*advertiseURL} + cfg.AdvertisePeerUrls = []url.URL{*advertiseURL} + cfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, advertiseURL.String()) cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(lg) cfg.LogLevel = "fatal" // Reuse port so that it can reboot with the same port immediately. 
diff --git a/pkg/util/etcd/etcd_test.go b/pkg/util/etcd/etcd_test.go
index 1024abe5a..225efe8c6 100644
--- a/pkg/util/etcd/etcd_test.go
+++ b/pkg/util/etcd/etcd_test.go
@@ -58,3 +58,26 @@ func TestSplitAddrList(t *testing.T) {
 	require.Equal(t, []string{"pd1:2379", "pd2:2379"}, config.SplitAddrList("pd1:2379, pd2:2379"))
 	require.Equal(t, []string{"pd1:2379", "pd2:2379"}, config.SplitAddrList(" pd1:2379 , , pd2:2379 "))
 }
+
+func TestSyncEtcdClient(t *testing.T) {
+	err := syncEtcdClient(context.Background(), &mockEtcdSyncer{})
+	require.NoError(t, err)
+}
+
+func TestSyncEtcdClientTimeout(t *testing.T) {
+	err := syncEtcdClient(context.Background(), &mockEtcdSyncer{block: true})
+	require.Error(t, err)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+}
+
+type mockEtcdSyncer struct {
+	block bool
+}
+
+func (m *mockEtcdSyncer) Sync(ctx context.Context) error {
+	if !m.block {
+		return nil
+	}
+	<-ctx.Done()
+	return ctx.Err()
+}
diff --git a/pkg/util/http/http.go b/pkg/util/http/http.go
index c4f5890f7..d9a2e150b 100644
--- a/pkg/util/http/http.go
+++ b/pkg/util/http/http.go
@@ -4,9 +4,11 @@
 package http
 
 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"time"
 
@@ -21,11 +23,18 @@ type Client struct {
 }
 
 func NewHTTPClient(getTLSConfig func() *tls.Config) *Client {
+	return NewHTTPClientWithDialContext(getTLSConfig, nil)
+}
+
+func NewHTTPClientWithDialContext(getTLSConfig func() *tls.Config, dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) *Client {
	// Since the TLS config hot reloads, `TLSClientConfig` needs to be updated
	// via `getTLSConfig()` to obtain the latest TLS config.
 	return &Client{
 		cli: &http.Client{
-			Transport: &http.Transport{TLSClientConfig: getTLSConfig()},
+			Transport: &http.Transport{
+				TLSClientConfig: getTLSConfig(),
+				DialContext:     dialContext,
+			},
 		},
 		getTLSConfig: getTLSConfig,
 	}
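
A sketch of the new constructor in use; the nil TLS callback (plain HTTP) and the pass-through dial function are illustrative, and passing a nil dial function keeps the transport's default dialer:

package main

import (
	"context"
	"crypto/tls"
	"net"

	httputil "github.com/pingcap/tiproxy/pkg/util/http"
)

func main() {
	var d net.Dialer
	cli := httputil.NewHTTPClientWithDialContext(
		func() *tls.Config { return nil }, // plain HTTP in this sketch
		func(ctx context.Context, network, addr string) (net.Conn, error) {
			return d.DialContext(ctx, network, addr)
		},
	)
	_ = cli
}
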
diff --git a/pkg/util/netutil/dns.go b/pkg/util/netutil/dns.go
new file mode 100644
index 000000000..9131c2d1b
--- /dev/null
+++ b/pkg/util/netutil/dns.go
@@ -0,0 +1,135 @@
+// Copyright 2026 PingCAP, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package netutil
+
+import (
+	"context"
+	"net"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/sync/singleflight"
+)
+
+const defaultDNSCacheTTL = 5 * time.Second
+
+type dnsCacheEntry struct {
+	ips      []net.IP
+	deadline time.Time
+}
+
+// DNSDialer routes DNS lookups to configured name servers and caches lookup results briefly.
+// If no name servers are configured, it falls back to the system resolver and dialer.
+type DNSDialer struct {
+	cacheTTL    time.Duration
+	nameServer  []string
+	resolver    *net.Resolver
+	dialer      net.Dialer
+	nextServer  atomic.Uint64
+	lookupGroup singleflight.Group
+	mu          struct {
+		sync.Mutex
+		cacheMap map[string]dnsCacheEntry
+	}
+}
+
+func NewDNSDialer(nameServers []string) *DNSDialer {
+	d := &DNSDialer{
+		cacheTTL:   defaultDNSCacheTTL,
+		nameServer: append([]string(nil), nameServers...),
+		mu: struct {
+			sync.Mutex
+			cacheMap map[string]dnsCacheEntry
+		}{
+			cacheMap: make(map[string]dnsCacheEntry),
+		},
+	}
+	if len(nameServers) == 0 {
+		return d
+	}
+	d.resolver = &net.Resolver{
+		PreferGo: true,
+		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
+			// Rotate through the configured name servers round-robin.
+			server := d.nameServer[int(d.nextServer.Add(1)-1)%len(d.nameServer)]
+			return d.dialer.DialContext(ctx, network, server)
+		},
+	}
+	return d
+}
+
+func (d *DNSDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+	if d.resolver == nil {
+		return d.dialer.DialContext(ctx, network, addr)
+	}
+	host, port, err := net.SplitHostPort(addr)
+	if err != nil {
+		return nil, err
+	}
+	if ip := net.ParseIP(host); ip != nil {
+		return d.dialer.DialContext(ctx, network, addr)
+	}
+	ips, err := d.lookupNetIP(ctx, host)
+	if err != nil {
+		return nil, err
+	}
+	var dialErr error
+	for _, ip := range ips {
+		conn, err := d.dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
+		if err == nil {
+			return conn, nil
+		}
+		dialErr = err
+	}
+	return nil, dialErr
+}
+
+func (d *DNSDialer) lookupNetIP(ctx context.Context, host string) ([]net.IP, error) {
+	key := strings.TrimSuffix(strings.ToLower(host), ".")
+	if ips, ok := d.cachedIPs(key, time.Now()); ok {
+		return ips, nil
+	}
+
+	// Coalesced callers share the first caller's lookup, so its ctx bounds the query.
+	resultCh := d.lookupGroup.DoChan(key, func() (any, error) {
+		now := time.Now()
+		if ips, ok := d.cachedIPs(key, now); ok {
+			return ips, nil
+		}
+		ips, err := d.resolver.LookupNetIP(ctx, "ip", host)
+		if err != nil {
+			return nil, err
+		}
+		ipList := make([]net.IP, 0, len(ips))
+		for _, ip := range ips {
+			ipList = append(ipList, append(net.IP(nil), ip.AsSlice()...))
+		}
+		d.mu.Lock()
+		d.mu.cacheMap[key] = dnsCacheEntry{
+			ips:      ipList,
+			deadline: now.Add(d.cacheTTL),
+		}
+		d.mu.Unlock()
+		return ipList, nil
+	})
+	select {
+	case result := <-resultCh:
+		if result.Err != nil {
+			return nil, result.Err
+		}
+		return result.Val.([]net.IP), nil
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
+func (d *DNSDialer) cachedIPs(key string, now time.Time) ([]net.IP, bool) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	entry, ok := d.mu.cacheMap[key]
+	if !ok || !now.Before(entry.deadline) {
+		return nil, false
+	}
+	return entry.ips, true
+}
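
A usage sketch of the dialer on its own; the name-server address and host below are illustrative, and passing nil name servers falls back to the system resolver:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/pingcap/tiproxy/pkg/util/netutil"
)

func main() {
	// Route lookups for this dial through 10.0.0.53:53.
	d := netutil.NewDNSDialer([]string{"10.0.0.53:53"})
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	conn, err := d.DialContext(ctx, "tcp", "tidb-0.internal:4000")
	if err != nil {
		fmt.Println("dial failed:", err)
		return
	}
	defer conn.Close()
	fmt.Println("connected to", conn.RemoteAddr())
}
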
diff --git a/pkg/util/netutil/dns_test.go b/pkg/util/netutil/dns_test.go
new file mode 100644
index 000000000..e7dd96305
--- /dev/null
+++ b/pkg/util/netutil/dns_test.go
@@ -0,0 +1,181 @@
+// Copyright 2026 PingCAP, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package netutil
+
+import (
+	"context"
+	"net"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/pingcap/tiproxy/pkg/testkit"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDNSDialerUsesConfiguredNameServerAndCache(t *testing.T) {
+	listener, addr := testkit.StartListener(t, "127.0.0.1:0")
+	t.Cleanup(func() { require.NoError(t, listener.Close()) })
+	_, port := testkit.ParseHostPort(t, addr)
+	dns := testkit.StartDNSServer(t, map[string][]string{
+		"tidb.test": {"127.0.0.1"},
+	})
+
+	accepted := make(chan error, 2)
+	for range 2 {
+		go func() {
+			conn, err := listener.Accept()
+			if err != nil {
+				accepted <- err
+				return
+			}
+			accepted <- conn.Close()
+		}()
+	}
+
+	dialer := NewDNSDialer([]string{dns.Addr()})
+	dialer.dialer.Timeout = 100 * time.Millisecond
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("tidb.test", strconv.Itoa(int(port))))
+	require.NoError(t, err)
+	require.NoError(t, conn.Close())
+	queryCount := dns.QueryCount("tidb.test")
+	require.Greater(t, queryCount, 0)
+
+	conn, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort("tidb.test", strconv.Itoa(int(port))))
+	require.NoError(t, err)
+	require.NoError(t, conn.Close())
+	require.Equal(t, queryCount, dns.QueryCount("tidb.test"))
+	require.NoError(t, <-accepted)
+	require.NoError(t, <-accepted)
+}
+
+func TestDNSDialerFallbackToSystemResolver(t *testing.T) {
+	listener, addr := testkit.StartListener(t, "127.0.0.1:0")
+	t.Cleanup(func() { require.NoError(t, listener.Close()) })
+	_, port := testkit.ParseHostPort(t, addr)
+	accepted := make(chan error, 1)
+	go func() {
+		conn, err := listener.Accept()
+		if err != nil {
+			accepted <- err
+			return
+		}
+		accepted <- conn.Close()
+	}()
+
+	dialer := NewDNSDialer(nil)
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+	conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("localhost", strconv.Itoa(int(port))))
+	require.NoError(t, err)
+	require.NoError(t, conn.Close())
+	require.NoError(t, <-accepted)
+}
+
+func TestDNSDialerTriesAllResolvedIPs(t *testing.T) {
+	listener, addr := testkit.StartListener(t, "127.0.0.1:0")
+	t.Cleanup(func() { require.NoError(t, listener.Close()) })
+	_, port := testkit.ParseHostPort(t, addr)
+	dns := testkit.StartDNSServer(t, map[string][]string{
+		"tidb.test": {"127.0.0.2", "127.0.0.1"},
+	})
+
+	accepted := make(chan struct{}, 1)
+	acceptErr := make(chan error, 1)
+	go func() {
+		conn, err := listener.Accept()
+		if err != nil {
+			acceptErr <- err
+			return
+		}
+		if err := conn.Close(); err != nil {
+			acceptErr <- err
+			return
+		}
+		accepted <- struct{}{}
+	}()
+
+	dialer := NewDNSDialer([]string{dns.Addr()})
+	dialer.dialer.Timeout = 100 * time.Millisecond
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("tidb.test", strconv.Itoa(int(port))))
+	require.NoError(t, err)
+	require.NoError(t, conn.Close())
+
+	select {
+	case <-accepted:
+	case <-time.After(time.Second):
+		t.Fatal("listener was not reached through resolved fallback IP")
+	}
+	select {
+	case err := <-acceptErr:
+		require.NoError(t, err)
+	default:
+	}
+}
+
+func TestDNSDialerCoalescesConcurrentLookupsAfterCacheExpiry(t *testing.T) {
+	listener, addr := testkit.StartListener(t, "127.0.0.1:0")
+	t.Cleanup(func() { require.NoError(t, listener.Close()) })
+	_, port := testkit.ParseHostPort(t, addr)
+	dns := testkit.StartDNSServer(t, map[string][]string{
+		"tidb.test": {"127.0.0.1"},
+	})
+
+	dialer := NewDNSDialer([]string{dns.Addr()})
+	dialer.cacheTTL = 20 * time.Millisecond
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	accepted := make(chan error, 9)
+	for range cap(accepted) {
+		go func() {
+			conn, err := listener.Accept()
+			if err == nil {
+				err = conn.Close()
+			}
+			accepted <- err
+		}()
+	}
+
+	targetAddr := net.JoinHostPort("tidb.test", strconv.Itoa(int(port)))
+	conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
+	require.NoError(t, err)
+	require.NoError(t, conn.Close())
+	require.NoError(t, <-accepted)
+	initialQueries := dns.QueryCount("tidb.test")
+	require.Greater(t, initialQueries, 0)
+
+	time.Sleep(dialer.cacheTTL + 10*time.Millisecond)
+
+	var wg sync.WaitGroup
+	errCh := make(chan error, 8)
+	for range cap(errCh) {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
+			if err == nil {
+				err = conn.Close()
+			}
+			errCh <- err
+		}()
+	}
+	wg.Wait()
+	close(errCh)
+
+	for err := range errCh {
+		require.NoError(t, err)
+	}
+	for range cap(errCh) {
+		require.NoError(t, <-accepted)
+	}
+	require.Equal(t, initialQueries*2, dns.QueryCount("tidb.test"))
+}