Skip to content

Commit 68610cb

Browse files
committed
*: add TLS config and cert reload for starter mode
1 parent c6056f8 commit 68610cb

6 files changed

Lines changed: 203 additions & 6 deletions

File tree

cmd/tidb-server/main.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ const (
127127
nmRepairMode = "repair-mode"
128128
nmRepairList = "repair-list"
129129
nmTempDir = "temp-dir"
130+
nmClusterCa = "cluster-ca"
131+
nmClusterCert = "cluster-cert"
132+
nmClusterKey = "cluster-key"
133+
nmSQLCA = "sql-ca"
134+
nmSQLCert = "sql-cert"
135+
nmSQLKey = "sql-key"
130136

131137
nmRedact = "redact"
132138

@@ -177,6 +183,12 @@ var (
177183
repairMode *bool
178184
repairList *string
179185
tempDir *string
186+
clusterCA *string
187+
clusterCert *string
188+
clusterKey *string
189+
sqlCA *string
190+
sqlCert *string
191+
sqlKey *string
180192

181193
// Log
182194
logLevel *string
@@ -238,6 +250,12 @@ func initFlagSet() *flag.FlagSet {
238250
repairMode = flagBoolean(fset, nmRepairMode, false, "enable admin repair mode")
239251
repairList = fset.String(nmRepairList, "", "admin repair table list")
240252
tempDir = fset.String(nmTempDir, config.DefTempDir, "tidb temporary directory")
253+
clusterCA = fset.String(nmClusterCa, "", "cluster CA file path")
254+
clusterCert = fset.String(nmClusterCert, "", "cluster cert file path")
255+
clusterKey = fset.String(nmClusterKey, "", "cluster key file path")
256+
sqlCA = fset.String(nmSQLCA, "", "SQL CA file path")
257+
sqlCert = fset.String(nmSQLCert, "", "SQL cert file path")
258+
sqlKey = fset.String(nmSQLKey, "", "SQL key file path")
241259

242260
// Log
243261
logLevel = fset.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
@@ -304,6 +322,8 @@ func main() {
304322
config.InitializeConfig(*configPath, *configCheck, *configStrict, overrideConfig, fset)
305323
if kerneltype.IsNextGen() {
306324
terror.MustNil(initDeployMode(config.GetGlobalConfig()))
325+
terror.MustNil(config.GetGlobalConfig().AdjustStarterConfig(deploymode.IsStarter()))
326+
config.StoreGlobalConfig(config.GetGlobalConfig())
307327
}
308328
if *version {
309329
mustInitVersions()
@@ -683,6 +703,29 @@ func overrideConfig(cfg *config.Config, fset *flag.FlagSet) {
683703
if actualFlags[nmTempDir] {
684704
cfg.TempDir = *tempDir
685705
}
706+
if cfg.DeployMode == deploymode.Starter {
707+
if actualFlags[nmClusterCa] {
708+
if *clusterCA != "" && (*clusterCert == "" || *clusterKey == "") {
709+
err = fmt.Errorf("cluster-ca requires both cluster-cert and cluster-key")
710+
terror.MustNil(err)
711+
}
712+
713+
cfg.Security.ClusterSSLCA = *clusterCA
714+
cfg.Security.ClusterSSLCert = *clusterCert
715+
cfg.Security.ClusterSSLKey = *clusterKey
716+
}
717+
718+
if actualFlags[nmSQLCA] {
719+
if *sqlCA != "" && (*sqlCert == "" || *sqlKey == "") {
720+
err = fmt.Errorf("sql-ca requires both sql-cert and sql-key")
721+
terror.MustNil(err)
722+
}
723+
724+
cfg.Security.SSLCA = *sqlCA
725+
cfg.Security.SSLCert = *sqlCert
726+
cfg.Security.SSLKey = *sqlKey
727+
}
728+
}
686729

687730
// Log
688731
if actualFlags[nmLogLevel] {

pkg/config/config.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ const (
107107
DefAuthTokenRefreshInterval = time.Hour
108108
// EnvVarKeyspaceName is the system env name for keyspace name.
109109
EnvVarKeyspaceName = "KEYSPACE_NAME"
110+
// EnvClusterCA is the system env name for cluster CA path.
111+
EnvClusterCA = "CLUSTER_CA"
112+
// EnvClusterCert is the system env name for cluster cert path.
113+
EnvClusterCert = "CLUSTER_CERT"
114+
// EnvClusterKey is the system env name for cluster key path.
115+
EnvClusterKey = "CLUSTER_KEY"
116+
// EnvSQLCA is the system env name for SQL CA path.
117+
EnvSQLCA = "SQL_CA"
118+
// EnvSQLCert is the system env name for SQL cert path.
119+
EnvSQLCert = "SQL_CERT"
120+
// EnvSQLKey is the system env name for SQL key path.
121+
EnvSQLKey = "SQL_KEY"
122+
// EnvPodIP is the system env name for pod IP.
123+
EnvPodIP = "POD_IP"
124+
// EnvNamespace is the system env name for namespace.
125+
EnvNamespace = "NAMESPACE"
110126
// MaxTokenLimit is the max token limit value.
111127
MaxTokenLimit = 1024 * 1024
112128
DefSchemaLease = 45 * time.Second
@@ -203,6 +219,8 @@ type Config struct {
203219
DeployMode deploymode.Mode `toml:"deploy-mode" json:"deploy-mode"`
204220
KeyspaceName string `toml:"keyspace-name" json:"keyspace-name"`
205221
TiKVWorkerURL string `toml:"tikv-worker-url" json:"tikv-worker-url"`
222+
TiKVAPIServiceAddr string `toml:"tikv-api-service-addr" json:"tikv-api-service-addr"`
223+
TiDBWorker tidbWorkerConfig `toml:"tidb-worker" json:"tidb-worker"`
206224
Log Log `toml:"log" json:"log"`
207225
Instance Instance `toml:"instance" json:"instance"`
208226
Security Security `toml:"security" json:"security"`
@@ -333,6 +351,10 @@ type Config struct {
333351
MeteringStorageURI string `toml:"metering-storage-uri" json:"metering-storage-uri"`
334352
}
335353

354+
// tidbWorkerConfig is the `tidb-worker` section of the TiDB configuration.
type tidbWorkerConfig struct {
355+
APIServerAddr string `toml:"api-server-addr" json:"api-server-addr"` // API server address; in starter mode adjustServiceAddr rewrites it with an explicit http:// or https:// scheme.
356+
}
357+
336358
// RUV2Config is the configuration for RU v2 weight calculation.
337359
// The default values are experimentally fitted so they stay stable under the
338360
// same workload while remaining numerically aligned with RU v1.
@@ -1344,13 +1366,90 @@ func InitializeConfig(confPath string, configCheck, configStrict bool, enforceCm
13441366
fmt.Fprintln(os.Stderr, "invalid config", err)
13451367
os.Exit(1)
13461368
}
1369+
if err := cfg.AdjustStarterConfig(cfg.DeployMode == deploymode.Starter); err != nil {
1370+
fmt.Fprintln(os.Stderr, "invalid security env vars", err)
1371+
os.Exit(1)
1372+
}
13471373
if configCheck {
13481374
fmt.Println("config check successful")
13491375
os.Exit(0)
13501376
}
13511377
StoreGlobalConfig(cfg)
13521378
}
13531379

1380+
// AdjustStarterConfig applies starter-only security and service-address overrides.
1381+
func (c *Config) AdjustStarterConfig(isStarter bool) error {
1382+
if !isStarter {
1383+
return nil
1384+
}
1385+
if err := c.adjustSecurityConfig(); err != nil {
1386+
return err
1387+
}
1388+
c.adjustServiceAddr()
1389+
return nil
1390+
}
1391+
1392+
// trimScheme removes a leading "http://" and then a leading "https://" from
// addr and returns the remainder. An address without either prefix is
// returned unchanged.
//
// The scheme is matched case-insensitively, because URI schemes are
// case-insensitive: "HTTPS://host" must be stripped just like "https://host",
// otherwise callers that prepend their own scheme would build a malformed
// address such as "https://HTTPS://host".
func trimScheme(addr string) string {
	// Keep the historical order: "http://" is tried first, then "https://".
	for _, scheme := range [...]string{"http://", "https://"} {
		if len(addr) >= len(scheme) && strings.EqualFold(addr[:len(scheme)], scheme) {
			addr = addr[len(scheme):]
		}
	}
	return addr
}
1397+
1398+
func (c *Config) adjustServiceAddr() {
1399+
scheme := "http://"
1400+
if len(c.Security.ClusterSSLCA) > 0 {
1401+
scheme = "https://"
1402+
}
1403+
if len(c.TiKVAPIServiceAddr) > 0 {
1404+
c.TiKVAPIServiceAddr = scheme + trimScheme(c.TiKVAPIServiceAddr)
1405+
}
1406+
if len(c.TiDBWorker.APIServerAddr) > 0 {
1407+
c.TiDBWorker.APIServerAddr = scheme + trimScheme(c.TiDBWorker.APIServerAddr)
1408+
}
1409+
}
1410+
1411+
func (c *Config) adjustSecurityConfig() error {
1412+
clusterCAPath := os.Getenv(EnvClusterCA)
1413+
clusterCertPath := os.Getenv(EnvClusterCert)
1414+
clusterKeyPath := os.Getenv(EnvClusterKey)
1415+
if len(clusterCAPath) > 0 && (len(clusterCertPath) == 0 || len(clusterKeyPath) == 0) {
1416+
return errors.New("both CLUSTER_CERT and CLUSTER_KEY must be set when CLUSTER_CA is set")
1417+
}
1418+
if len(clusterCAPath) > 0 {
1419+
c.Security.ClusterSSLCA = clusterCAPath
1420+
c.Security.ClusterSSLCert = clusterCertPath
1421+
c.Security.ClusterSSLKey = clusterKeyPath
1422+
}
1423+
1424+
sqlCAPath := os.Getenv(EnvSQLCA)
1425+
sqlCertPath := os.Getenv(EnvSQLCert)
1426+
sqlKeyPath := os.Getenv(EnvSQLKey)
1427+
if len(sqlCAPath) > 0 && (len(sqlCertPath) == 0 || len(sqlKeyPath) == 0) {
1428+
return errors.New("both SQL_CERT and SQL_KEY must be set when SQL_CA is set")
1429+
}
1430+
if len(sqlCAPath) > 0 {
1431+
c.Security.SSLCA = sqlCAPath
1432+
c.Security.SSLCert = sqlCertPath
1433+
c.Security.SSLKey = sqlKeyPath
1434+
}
1435+
1436+
podIP := os.Getenv(EnvPodIP)
1437+
namespace := os.Getenv(EnvNamespace)
1438+
if len(podIP) > 0 && len(namespace) > 0 {
1439+
c.AdvertiseAddress = podDNSName(podIP, namespace)
1440+
}
1441+
1442+
return nil
1443+
}
1444+
1445+
// podDNSName builds the in-cluster DNS name for a pod from its IP address and
// namespace, in the form "<ip-with-dashes>.<namespace>.pod.cluster.local".
//
// Every "." (IPv4) and ":" (IPv6) in podIP is replaced by "-" so the address
// forms a single valid DNS label; for example "10.10.1.2" becomes "10-10-1-2"
// and "fd00::1" becomes "fd00--1". Neither argument is validated.
//
// NOTE(review): the cluster domain is hard-coded to "cluster.local"; confirm
// this matches every target cluster.
func podDNSName(podIP string, namespace string) string {
	// A raw ":" is not allowed in a DNS label, so IPv6 separators are mapped
	// to dashes in the same way as IPv4 dots.
	label := strings.ReplaceAll(podIP, ".", "-")
	label = strings.ReplaceAll(label, ":", "-")
	return fmt.Sprintf("%s.%s.pod.cluster.local", label, namespace)
}
1452+
13541453
// RemovedVariableCheck checks if the config file contains any items
13551454
// which have been removed. These will not take effect any more.
13561455
func (c *Config) RemovedVariableCheck(confFile string) error {

pkg/config/config_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,9 @@ func TestDeployModeConfig(t *testing.T) {
10531053
conf := NewConfig()
10541054
require.Equal(t, deploymode.Premium, conf.DeployMode)
10551055
require.NoError(t, conf.Valid())
1056+
conf.DeployMode = deploymode.Mode(100)
1057+
require.ErrorContains(t, conf.Valid(), "invalid deploy-mode")
1058+
conf.DeployMode = deploymode.Premium
10561059

10571060
storeDir := t.TempDir()
10581061
configFile := filepath.Join(storeDir, "config.toml")
@@ -1080,6 +1083,26 @@ func TestDeployModeConfig(t *testing.T) {
10801083
require.NoError(t, conf.Load(configFile))
10811084
require.Equal(t, deploymode.Starter, conf.DeployMode)
10821085
require.NoError(t, conf.Valid())
1086+
conf.TiKVAPIServiceAddr = "tikv-api.internal:20170"
1087+
conf.TiDBWorker.APIServerAddr = "tidb-worker.internal:10280"
1088+
t.Setenv(EnvClusterCA, "/tmp/cluster-ca.pem")
1089+
t.Setenv(EnvClusterCert, "/tmp/cluster-cert.pem")
1090+
t.Setenv(EnvClusterKey, "/tmp/cluster-key.pem")
1091+
t.Setenv(EnvSQLCA, "/tmp/sql-ca.pem")
1092+
t.Setenv(EnvSQLCert, "/tmp/sql-cert.pem")
1093+
t.Setenv(EnvSQLKey, "/tmp/sql-key.pem")
1094+
t.Setenv(EnvPodIP, "10.10.1.2")
1095+
t.Setenv(EnvNamespace, "tidb")
1096+
require.NoError(t, conf.AdjustStarterConfig(true))
1097+
require.Equal(t, "10-10-1-2.tidb.pod.cluster.local", conf.AdvertiseAddress)
1098+
require.Equal(t, "/tmp/cluster-ca.pem", conf.Security.ClusterSSLCA)
1099+
require.Equal(t, "/tmp/cluster-cert.pem", conf.Security.ClusterSSLCert)
1100+
require.Equal(t, "/tmp/cluster-key.pem", conf.Security.ClusterSSLKey)
1101+
require.Equal(t, "/tmp/sql-ca.pem", conf.Security.SSLCA)
1102+
require.Equal(t, "/tmp/sql-cert.pem", conf.Security.SSLCert)
1103+
require.Equal(t, "/tmp/sql-key.pem", conf.Security.SSLKey)
1104+
require.Equal(t, "https://tikv-api.internal:20170", conf.TiKVAPIServiceAddr)
1105+
require.Equal(t, "https://tidb-worker.internal:10280", conf.TiDBWorker.APIServerAddr)
10831106

10841107
require.NoError(t, os.WriteFile(configFile, []byte(`deploy-mode = "unknown"`), 0644))
10851108
conf = NewConfig()

pkg/server/stat.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,21 @@ func (s *Server) Stats(_ *variable.SessionVars) (map[string]any, error) {
5353

5454
tlsConfig := s.GetTLSConfig()
5555
if tlsConfig != nil {
56-
if len(tlsConfig.Certificates) == 1 {
56+
if tlsConfig.GetCertificate != nil {
57+
certs, err := tlsConfig.GetCertificate(nil)
58+
if err != nil {
59+
logutil.BgLogger().Error("Failed to get TLS certificates while acquiring server status", zap.Error(err))
60+
}
61+
if certs != nil && len(certs.Certificate) > 0 {
62+
pc, err := x509.ParseCertificate(certs.Certificate[0])
63+
if err != nil {
64+
logutil.BgLogger().Error("Failed to parse TLS certificates to get server status", zap.Error(err))
65+
} else {
66+
m[serverNotAfter] = pc.NotAfter.Format("Jan _2 15:04:05 2006 MST")
67+
m[serverNotBefore] = pc.NotBefore.Format("Jan _2 15:04:05 2006 MST")
68+
}
69+
}
70+
} else if len(tlsConfig.Certificates) == 1 {
5771
pc, err := x509.ParseCertificate(tlsConfig.Certificates[0].Certificate[0])
5872
if err != nil {
5973
logutil.BgLogger().Error("Failed to parse TLS certficates to get server status", zap.Error(err))

pkg/server/tests/tls/tls_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,9 @@ func TestReloadTLS(t *testing.T) {
461461

462462
// try reload a valid cert.
463463
tlsCfg := server.GetTLSConfig()
464-
cert, err := x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0])
464+
tlsCert, err := tlsCfg.GetCertificate(nil)
465+
require.NoError(t, err)
466+
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
465467
require.NoError(t, err)
466468
oldExpireTime := cert.NotAfter
467469
_, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload2.pem", "/tmp/server-cert-reload2.pem", func(c *x509.Certificate) {
@@ -485,7 +487,9 @@ func TestReloadTLS(t *testing.T) {
485487
require.NoError(t, err)
486488

487489
tlsCfg = server.GetTLSConfig()
488-
cert, err = x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0])
490+
tlsCert, err = tlsCfg.GetCertificate(nil)
491+
require.NoError(t, err)
492+
cert, err = x509.ParseCertificate(tlsCert.Certificate[0])
489493
require.NoError(t, err)
490494
newExpireTime := cert.NotAfter
491495
require.True(t, newExpireTime.After(oldExpireTime))

pkg/util/misc.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"strconv"
3636
"strings"
3737
"sync"
38+
"sync/atomic"
3839
"time"
3940

4041
"github.com/pingcap/errors"
@@ -399,13 +400,14 @@ func LoadTLSCertificates(ca, key, cert string, autoTLS bool, rsaKeySize int) (tl
399400
}
400401
}
401402

402-
var tlsCert tls.Certificate
403-
tlsCert, err = tls.LoadX509KeyPair(cert, key)
403+
certs, err := tls.LoadX509KeyPair(cert, key)
404404
if err != nil {
405405
logutil.BgLogger().Warn("load x509 failed", zap.Error(err))
406406
err = errors.Trace(err)
407407
return
408408
}
409+
cs := &atomic.Pointer[tls.Certificate]{}
410+
cs.Store(&certs)
409411

410412
requireTLS := tlsutil.RequireSecureTransport.Load()
411413

@@ -467,12 +469,24 @@ func LoadTLSCertificates(ca, key, cert string, autoTLS bool, rsaKeySize int) (tl
467469

468470
/* #nosec G402 */
469471
tlsConfig = &tls.Config{
470-
Certificates: []tls.Certificate{tlsCert},
471472
ClientCAs: certPool,
472473
ClientAuth: clientAuthPolicy,
473474
MinVersion: minTLSVersion,
474475
CipherSuites: cipherSuites,
475476
}
477+
tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
478+
certs, err := tls.LoadX509KeyPair(cert, key)
479+
if err != nil {
480+
logutil.BgLogger().Warn("could not load server certificate, using the old one", zap.Error(err))
481+
if old := cs.Load(); old != nil {
482+
return old, nil
483+
}
484+
return nil, nil
485+
}
486+
newCerts := &certs
487+
cs.Store(newCerts)
488+
return newCerts, nil
489+
}
476490
return
477491
}
478492

0 commit comments

Comments
 (0)