Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion drivers/mysql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ Add MySql credentials in following format in `config.json` file. [More details.]
"update_method": {
"initial_wait_time": 10
},
"tls_skip_verify": true,
"jdbc_url_params": {
"timeout": "10s"
},
"ssl": {
"mode": "disable"
},
"max_threads":10,
"backoff_retry_count": 2,
"ssh_config":{
Expand Down
89 changes: 79 additions & 10 deletions drivers/mysql/internal/config.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
package driver

import (
"crypto/tls"
"crypto/x509"
"fmt"
"maps"
"strings"

"github.com/go-sql-driver/mysql"

"github.com/datazip-inc/olake/constants"
"github.com/datazip-inc/olake/utils"
"github.com/datazip-inc/olake/utils/logger"
)

// Config represents the configuration for connecting to a MySQL database
type Config struct {
Host string `json:"hosts"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database"`
Port int `json:"port"`
TLSSkipVerify bool `json:"tls_skip_verify"` // Add this field
UpdateMethod interface{} `json:"update_method"`
MaxThreads int `json:"max_threads"`
RetryCount int `json:"backoff_retry_count"`
SSHConfig *utils.SSHConfig `json:"ssh_config"`
Host string `json:"hosts"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database"`
Port int `json:"port"`
JDBCURLParams map[string]string `json:"jdbc_url_params"`
SSLConfiguration *utils.SSLConfig `json:"ssl"`
UpdateMethod interface{} `json:"update_method"`
MaxThreads int `json:"max_threads"`
RetryCount int `json:"backoff_retry_count"`
SSHConfig *utils.SSHConfig `json:"ssh_config"`
}

type CDC struct {
Expand Down Expand Up @@ -49,9 +54,66 @@ func (c *Config) URI() string {
AllowNativePasswords: true,
}

if c.SSLConfiguration != nil {
switch c.SSLConfiguration.Mode {
case utils.SSLModeDisable:
cfg.TLSConfig = "false"
case utils.SSLModeRequire:
cfg.TLSConfig = "true"
case utils.SSLModeVerifyCA, utils.SSLModeVerifyFull:
tlsConfig, err := c.buildTLSConfig()
if err != nil {
logger.Warnf("Failed to build TLS config, falling back to skip-verify: %v", err)
cfg.TLSConfig = "skip-verify"
} else if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
logger.Warnf("Failed to register TLS config, falling back to skip-verify: %v", err)
cfg.TLSConfig = "skip-verify"
} else {
cfg.TLSConfig = "custom"
}
}
}

if len(c.JDBCURLParams) > 0 {
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
maps.Copy(cfg.Params, c.JDBCURLParams)
}

return cfg.FormatDSN()
}

// buildTLSConfig builds a custom TLS configuration for certificate-based SSL
func (c *Config) buildTLSConfig() (*tls.Config, error) {
rootCertPool := x509.NewCertPool()

if c.SSLConfiguration.ServerCA != "" {
if ok := rootCertPool.AppendCertsFromPEM([]byte(c.SSLConfiguration.ServerCA)); !ok {
return nil, fmt.Errorf("failed to append CA certificate")
}
}

tlsConfig := &tls.Config{
RootCAs: rootCertPool,
MinVersion: tls.VersionTLS12,
}

if c.SSLConfiguration.Mode == utils.SSLModeVerifyCA {
tlsConfig.InsecureSkipVerify = true
}

if c.SSLConfiguration.ClientCert != "" && c.SSLConfiguration.ClientKey != "" {
cert, err := tls.X509KeyPair([]byte(c.SSLConfiguration.ClientCert), []byte(c.SSLConfiguration.ClientKey))
if err != nil {
return nil, fmt.Errorf("failed to load client certificate and key: %s", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}

return tlsConfig, nil
}

// Validate checks the configuration for any missing or invalid fields
func (c *Config) Validate() error {
if c.Host == "" {
Expand Down Expand Up @@ -88,5 +150,12 @@ func (c *Config) Validate() error {
c.RetryCount = constants.DefaultRetryCount // Reasonable default for retries
}

// Validate SSL configuration if provided
if c.SSLConfiguration != nil {
if err := c.SSLConfiguration.Validate(); err != nil {
return fmt.Errorf("failed to validate SSL config: %s", err)
}
}

return utils.Validate(c)
}
175 changes: 175 additions & 0 deletions drivers/mysql/internal/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package driver

import (
"strings"
"testing"

"github.com/datazip-inc/olake/utils"
)

func TestConfig_URI_WithJDBCParams(t *testing.T) {
config := &Config{
Host: "localhost",
Port: 3306,
Username: "testuser",
Password: "testpass",
Database: "testdb",
JDBCURLParams: map[string]string{
"charset": "utf8mb4",
"parseTime": "true",
"loc": "Local",
},
}

uri := config.URI()

// Check that JDBC params are included in the URI
if !strings.Contains(uri, "charset=utf8mb4") {
t.Errorf("Expected charset parameter in URI, got: %s", uri)
}
if !strings.Contains(uri, "parseTime=true") {
t.Errorf("Expected parseTime parameter in URI, got: %s", uri)
}
if !strings.Contains(uri, "loc=Local") {
t.Errorf("Expected loc parameter in URI, got: %s", uri)
}
}

func TestConfig_URI_WithSSLDisabled(t *testing.T) {
config := &Config{
Host: "localhost",
Port: 3306,
Username: "testuser",
Password: "testpass",
Database: "testdb",
SSLConfiguration: &utils.SSLConfig{
Mode: utils.SSLModeDisable,
},
}

uri := config.URI()

// Check that TLS is disabled
if !strings.Contains(uri, "tls=false") {
t.Errorf("Expected tls=false in URI, got: %s", uri)
}
}

func TestConfig_URI_WithSSLRequired(t *testing.T) {
config := &Config{
Host: "localhost",
Port: 3306,
Username: "testuser",
Password: "testpass",
Database: "testdb",
SSLConfiguration: &utils.SSLConfig{
Mode: utils.SSLModeRequire,
},
}

uri := config.URI()

// Check that TLS is enabled
if !strings.Contains(uri, "tls=true") {
t.Errorf("Expected tls=true in URI, got: %s", uri)
}
}

func TestConfig_Validate_WithSSLConfig(t *testing.T) {
tests := []struct {
name string
config *Config
expectErr bool
}{
{
name: "Valid SSL config with disable mode",
config: &Config{
Host: "localhost",
Port: 3306,
Username: "testuser",
Password: "testpass",
Database: "testdb",
SSLConfiguration: &utils.SSLConfig{
Mode: utils.SSLModeDisable,
},
},
expectErr: false,
},
{
name: "Valid SSL config with require mode",
config: &Config{
Host: "localhost",
Port: 3306,
Username: "testuser",
Password: "testpass",
Database: "testdb",
SSLConfiguration: &utils.SSLConfig{
Mode: utils.SSLModeRequire,
},
},
expectErr: false,
},
{
name: "Invalid SSL config - verify-ca without certificates",
config: &Config{
Host: "localhost",
Port: 3306,
Username: "testuser",
Password: "testpass",
Database: "testdb",
SSLConfiguration: &utils.SSLConfig{
Mode: utils.SSLModeVerifyCA,
},
},
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.expectErr && err == nil {
t.Errorf("Expected error but got none")
}
if !tt.expectErr && err != nil {
t.Errorf("Expected no error but got: %v", err)
}
})
}
}

func TestConfig_URI_CombinedParams(t *testing.T) {
config := &Config{
Host: "mysql.example.com",
Port: 3306,
Username: "appuser",
Password: "securepass",
Database: "appdb",
JDBCURLParams: map[string]string{
"charset": "utf8mb4",
"parseTime": "true",
"timeout": "10s",
"readTimeout": "30s",
"writeTimeout": "30s",
},
SSLConfiguration: &utils.SSLConfig{
Mode: utils.SSLModeRequire,
},
}

uri := config.URI()

// Verify both JDBC params and SSL config are present
if !strings.Contains(uri, "charset=utf8mb4") {
t.Errorf("Expected charset parameter in URI")
}
if !strings.Contains(uri, "tls=true") {
t.Errorf("Expected TLS enabled in URI")
}
if !strings.Contains(uri, "mysql.example.com:3306") {
t.Errorf("Expected correct host and port in URI")
}
if !strings.Contains(uri, "appdb") {
t.Errorf("Expected database name in URI")
}
}
Loading
Loading