Skip to content

Commit 218e3c9

Browse files
committed
add global connection pool for mysql
Signed-off-by: Rob Pickerill <[email protected]>
1 parent 5c52d03 commit 218e3c9

File tree

4 files changed

+499
-7
lines changed

4 files changed

+499
-7
lines changed

pkg/scalers/mysql_scaler.go

+105-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net"
88
"strings"
9+
"time"
910

1011
"github.com/go-logr/logr"
1112
"github.com/go-sql-driver/mysql"
@@ -16,6 +17,16 @@ import (
1617
kedautil "github.com/kedacore/keda/v2/pkg/util"
1718
)
1819

20+
var (
21+
// A map that holds MySQL connection pools, keyed by connection string
22+
connectionPools *kedautil.RefMap[string, *sql.DB]
23+
)
24+
25+
func init() {
26+
// Initialize the global connectionPools map
27+
connectionPools = kedautil.NewRefMap[string, *sql.DB]()
28+
}
29+
1930
type mySQLScaler struct {
2031
metricType v2.MetricTargetType
2132
metadata *mySQLMetadata
@@ -34,6 +45,12 @@ type mySQLMetadata struct {
3445
QueryValue float64 `keda:"name=queryValue, order=triggerMetadata"`
3546
ActivationQueryValue float64 `keda:"name=activationQueryValue, order=triggerMetadata, default=0"`
3647
MetricName string `keda:"name=metricName, order=triggerMetadata, optional"`
48+
49+
// Connection pool settings
50+
UseGlobalConnPools bool `keda:"name=useGlobalConnPools, order=triggerMetadata, optional"`
51+
MaxOpenConns int `keda:"name=maxOpenConns, order=triggerMetadata, optional"`
52+
MaxIdleConns int `keda:"name=maxIdleConns, order=triggerMetadata, optional"`
53+
ConnMaxIdleTime int `keda:"name=connMaxIdleTime, order=triggerMetadata, optional"` // seconds
3754
}
3855

3956
// NewMySQLScaler creates a new MySQL scaler
@@ -50,10 +67,19 @@ func NewMySQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
5067
return nil, fmt.Errorf("error parsing MySQL metadata: %w", err)
5168
}
5269

53-
conn, err := newMySQLConnection(meta, logger)
70+
// Create MySQL connection, if useGlobalConnPools is set to true, it will use
71+
// the global connection pool for the given connection string, otherwise it
72+
// will create a new local connection pool for the given connection string
73+
var conn *sql.DB
74+
if meta.UseGlobalConnPools {
75+
conn, err = getConnectionPool(meta, logger)
76+
} else {
77+
conn, err = newMySQLConnection(meta, logger)
78+
}
5479
if err != nil {
55-
return nil, fmt.Errorf("error establishing MySQL connection: %w", err)
80+
return nil, fmt.Errorf("error creating MySQL connection: %w", err)
5681
}
82+
5783
return &mySQLScaler{
5884
metricType: metricType,
5985
metadata: meta,
@@ -96,6 +122,40 @@ func metadataToConnectionStr(meta *mySQLMetadata) string {
96122
return connStr
97123
}
98124

125+
// getConnectionPool will check if the connection pool has already been
126+
// created for the given connection string and return it. If it has not
127+
// been created, it will create a new connection pool and store it in the
128+
// connectionPools map.
129+
func getConnectionPool(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) {
130+
connStr := metadataToConnectionStr(meta)
131+
// Try to load an existing pool and increment its reference count if found
132+
if pool, ok := connectionPools.Load(connStr); ok {
133+
err := connectionPools.AddRef(connStr)
134+
if err != nil {
135+
logger.Error(err, "Error increasing connection pool reference count")
136+
return nil, err
137+
}
138+
139+
return pool, nil
140+
}
141+
142+
// If pool does not exist, create a new one and store it in RefMap
143+
newPool, err := newMySQLConnection(meta, logger)
144+
if err != nil {
145+
return nil, err
146+
}
147+
err = connectionPools.Store(connStr, newPool, func(db *sql.DB) error {
148+
logger.Info("Closing MySQL connection pool", "connectionString", connStr)
149+
return db.Close()
150+
})
151+
if err != nil {
152+
logger.Error(err, "Error storing connection pool in RefMap")
153+
return nil, err
154+
}
155+
156+
return newPool, nil
157+
}
158+
99159
// newMySQLConnection creates MySQL db connection
100160
func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) {
101161
connStr := metadataToConnectionStr(meta)
@@ -104,14 +164,35 @@ func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error
104164
logger.Error(err, fmt.Sprintf("Found error when opening connection: %s", err))
105165
return nil, err
106166
}
167+
107168
err = db.Ping()
108169
if err != nil {
109170
logger.Error(err, fmt.Sprintf("Found error when pinging database: %s", err))
110171
return nil, err
111172
}
173+
174+
setConnectionPoolConfiguration(meta, db)
175+
112176
return db, nil
113177
}
114178

179+
// setConnectionPoolConfiguration configures the MySQL connection pool settings
180+
// based on the parameters provided in mySQLMetadata. If a setting is zero, it
181+
// is left at its default value.
182+
func setConnectionPoolConfiguration(meta *mySQLMetadata, db *sql.DB) {
183+
if meta.MaxOpenConns > 0 {
184+
db.SetMaxOpenConns(meta.MaxOpenConns)
185+
}
186+
187+
if meta.MaxIdleConns > 0 {
188+
db.SetMaxIdleConns(meta.MaxIdleConns)
189+
}
190+
191+
if meta.ConnMaxIdleTime > 0 {
192+
db.SetConnMaxIdleTime(time.Duration(meta.ConnMaxIdleTime) * time.Second)
193+
}
194+
}
195+
115196
// parseMySQLDbNameFromConnectionStr returns dbname from connection string
116197
// in it is not able to parse it, it returns "dbname" string
117198
func parseMySQLDbNameFromConnectionStr(connectionString string) string {
@@ -123,13 +204,30 @@ func parseMySQLDbNameFromConnectionStr(connectionString string) string {
123204
return "dbname"
124205
}
125206

126-
// Close disposes of MySQL connections
127-
func (s *mySQLScaler) Close(context.Context) error {
128-
err := s.connection.Close()
129-
if err != nil {
130-
s.logger.Error(err, "Error closing MySQL connection")
207+
// Close disposes of MySQL connections, closing either the global pool if used
208+
// or the local connection pool
209+
func (s *mySQLScaler) Close(ctx context.Context) error {
210+
if s.metadata.UseGlobalConnPools {
211+
if err := s.closeGlobalPool(ctx); err != nil {
212+
return fmt.Errorf("error closing MySQL connection: %w", err)
213+
}
214+
} else {
215+
if err := s.connection.Close(); err != nil {
216+
return fmt.Errorf("error closing MySQL connection: %w", err)
217+
}
218+
}
219+
220+
return nil
221+
}
222+
223+
// closeGlobalPool closes all MySQL connections in the global pool
224+
func (s *mySQLScaler) closeGlobalPool(_ context.Context) error {
225+
connStr := metadataToConnectionStr(s.metadata)
226+
if err := connectionPools.RemoveRef(connStr); err != nil {
227+
s.logger.Error(err, "Error decreasing connection pool reference count")
131228
return err
132229
}
230+
133231
return nil
134232
}
135233

pkg/scalers/mysql_scaler_test.go

+21
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ var testMySQLMetadata = []parseMySQLMetadataTestData{
7575
resolvedEnv: map[string]string{},
7676
raisesError: true,
7777
},
78+
// use global pool
79+
{
80+
metadata: map[string]string{"query": "query", "queryValue": "12", "useGlobalConnPools": "true"},
81+
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
82+
resolvedEnv: testMySQLResolvedEnv,
83+
raisesError: false,
84+
},
85+
// use connection pool settings
86+
{
87+
metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10"},
88+
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
89+
resolvedEnv: testMySQLResolvedEnv,
90+
raisesError: false,
91+
},
92+
// use connection pool settings and global pool
93+
{
94+
metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10", "useGlobalConnPools": "true"},
95+
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
96+
resolvedEnv: testMySQLResolvedEnv,
97+
raisesError: false,
98+
},
7899
}
79100

80101
var mySQLMetricIdentifiers = []mySQLMetricIdentifier{

pkg/util/refmap.go

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package util
2+
3+
//nolint:depguard // sync/atomic
4+
import (
5+
"fmt"
6+
"sync"
7+
"sync/atomic"
8+
)
9+
10+
// refCountedValue manages a reference-counted value with a cleanup function.
11+
type refCountedValue[V any] struct {
12+
value V
13+
refCount atomic.Int64
14+
closeFunc func(V) error // Cleanup function to call when count reaches zero
15+
}
16+
17+
// Add increments the reference count.
18+
func (r *refCountedValue[V]) Add() {
19+
r.refCount.Add(1)
20+
}
21+
22+
// Remove decrements the reference count and invokes closeFunc if the count
23+
// reaches zero.
24+
func (r *refCountedValue[V]) Remove() error {
25+
if r.refCount.Add(-1) == 0 {
26+
return r.closeFunc(r.value)
27+
}
28+
29+
return nil
30+
}
31+
32+
// Value returns the underlying value.
33+
func (r *refCountedValue[V]) Value() V {
34+
return r.value
35+
}
36+
37+
// RefMap manages reference-counted items in a concurrent-safe map.
38+
type RefMap[K comparable, V any] struct {
39+
data map[K]*refCountedValue[V]
40+
mu sync.RWMutex
41+
}
42+
43+
// NewRefMap initializes a new RefMap. A RefMap is an atomic reference-counted
44+
// concurrent hashmap. The general usage pattern is to Store a value with a
45+
// close function, once the value is contained within the RefMap, it can be
46+
// accessed via the Load method. The AddRef method signals ownership of the
47+
// value and increments the reference count. The RemoveRef method decrements
48+
// the reference count. When the reference count reaches zero, the close
49+
// function is called and the value is removed from the map.
50+
func NewRefMap[K comparable, V any]() *RefMap[K, V] {
51+
return &RefMap[K, V]{
52+
data: make(map[K]*refCountedValue[V]),
53+
}
54+
}
55+
56+
// Store adds a new item with an initial reference count of 1 and a close
57+
// function.
58+
func (r *RefMap[K, V]) Store(key K, value V, closeFunc func(V) error) error {
59+
r.mu.Lock()
60+
defer r.mu.Unlock()
61+
62+
if _, exists := r.data[key]; exists {
63+
return fmt.Errorf("key already exists: %v", key)
64+
}
65+
66+
r.data[key] = &refCountedValue[V]{value: value, refCount: atomic.Int64{}, closeFunc: closeFunc}
67+
r.data[key].Add() // Set initial reference count to 1
68+
69+
return nil
70+
}
71+
72+
// Load retrieves a value by key without modifying the reference count,
73+
// returning the value and a boolean indicating if it was found. The reference
74+
// count not being modified means that a check for the existence of a key
75+
// can be performed with signalling ownership of the value. If the value is used
76+
// after this method, it is recommended to call AddRef to increment the
77+
// reference
78+
func (r *RefMap[K, V]) Load(key K) (V, bool) {
79+
r.mu.RLock()
80+
defer r.mu.RUnlock()
81+
82+
if refValue, found := r.data[key]; found {
83+
return refValue.Value(), true
84+
}
85+
var zero V
86+
87+
return zero, false
88+
}
89+
90+
// AddRef increments the reference count for a key if it exists. Ensure
91+
// to call RemoveRef when done with the value to prevent memory leaks.
92+
func (r *RefMap[K, V]) AddRef(key K) error {
93+
r.mu.RLock()
94+
defer r.mu.RUnlock()
95+
96+
refValue, found := r.data[key]
97+
if !found {
98+
return fmt.Errorf("key not found: %v", key)
99+
}
100+
101+
refValue.Add()
102+
return nil
103+
}
104+
105+
// RemoveRef decrements the reference count and deletes the entry if count
106+
// reaches zero.
107+
func (r *RefMap[K, V]) RemoveRef(key K) error {
108+
r.mu.Lock()
109+
defer r.mu.Unlock()
110+
111+
refValue, found := r.data[key]
112+
if !found {
113+
return fmt.Errorf("key not found: %v", key)
114+
}
115+
116+
err := refValue.Remove()
117+
118+
if refValue.refCount.Load() == 0 {
119+
delete(r.data, key)
120+
}
121+
122+
return err // returns the error from closeFunc
123+
}

0 commit comments

Comments
 (0)