Skip to content

Commit 9cced6f

Browse files
manzil-infinity180jkjell
authored andcommitted
fix: mysql conn string requirement
Signed-off-by: Rahul Vishwakarma <rahulvs2809@gmail.com>
1 parent ee5b523 commit 9cced6f

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

pkg/metadatastorage/sqlstore/client.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package sqlstore
1616

1717
import (
1818
"fmt"
19+
"net/url"
1920
"strings"
2021
"time"
2122

@@ -57,6 +58,33 @@ func ClientWithConnMaxLifetime(connMaxLifetime time.Duration) ClientOption {
5758
}
5859
}
5960

61+
// ensureMySQLConnectionString ensures the connection string has the tcp protocol as required by the go-sql-driver
62+
func ensureMySQLConnectionString(connStr string) (string, error) {
63+
schema := "mysql://"
64+
65+
if strings.Contains(connStr, "@tcp(") {
66+
return connStr, nil
67+
}
68+
69+
// Add mysql:// prefix if not present. URL Parse will fail silently if a schema is not present
70+
if !strings.HasPrefix(connStr, schema) {
71+
connStr = schema + connStr
72+
}
73+
74+
// Parse the connection string as a URL
75+
u, err := url.Parse(connStr)
76+
if err != nil {
77+
return "", fmt.Errorf("invalid mysql connection string: %w", err)
78+
}
79+
80+
// Modify the host to include tcp
81+
u.Host = "tcp(" + u.Host + ")"
82+
83+
// Remove the mysql:// prefix from the final string
84+
result := strings.TrimPrefix(u.String(), schema)
85+
return result, nil
86+
}
87+
6088
// NewEntClient creates an ent client for use in the sqlmetadata store.
6189
// Valid backends are MYSQL and PSQL.
6290
func NewEntClient(sqlBackend string, connectionString string, opts ...ClientOption) (*ent.Client, error) {
@@ -73,6 +101,12 @@ func NewEntClient(sqlBackend string, connectionString string, opts ...ClientOpti
73101
var entDialect string
74102
upperSqlBackend := strings.ToUpper(sqlBackend)
75103
if strings.HasPrefix(upperSqlBackend, "MYSQL") {
104+
// Ensure the connection string has the tcp protocol as required by the go-sql-driver
105+
var err error
106+
connectionString, err = ensureMySQLConnectionString(connectionString)
107+
if err != nil {
108+
return nil, fmt.Errorf("could not ensure mysql connection string: %w", err)
109+
}
76110
dbConfig, err := mysql.ParseDSN(connectionString)
77111
if err != nil {
78112
return nil, fmt.Errorf("could not parse mysql connection string: %w", err)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright 2025 The Archivista Contributors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package sqlstore
16+
17+
import (
18+
"testing"
19+
20+
"github.com/stretchr/testify/assert"
21+
"github.com/stretchr/testify/require"
22+
)
23+
24+
func TestNewEntClient_MySQLConnectionStringError(t *testing.T) {
25+
tests := []struct {
26+
name string
27+
sqlBackend string
28+
connectionString string
29+
}{
30+
{
31+
name: "mysql with invalid URL that breaks url.Parse",
32+
sqlBackend: "MYSQL",
33+
connectionString: "user:pa%zzss@localhost:3306/dbname", // Invalid percent encoding
34+
},
35+
{
36+
name: "mysql with control characters",
37+
sqlBackend: "mysql",
38+
connectionString: "user:pass@local\x00host:3306/dbname", // Null byte in URL
39+
},
40+
{
41+
name: "mysql with invalid hex escape",
42+
sqlBackend: "MYSQL",
43+
connectionString: "user:pass%ZZword@localhost:3306/db", // Invalid hex in percent encoding
44+
},
45+
}
46+
47+
for _, tt := range tests {
48+
t.Run(tt.name, func(t *testing.T) {
49+
// This should trigger an error in ensureMySQLConnectionString
50+
// which will cover lines 104-109
51+
client, err := NewEntClient(tt.sqlBackend, tt.connectionString)
52+
53+
require.Error(t, err)
54+
assert.Nil(t, client)
55+
assert.Contains(t, err.Error(), "could not ensure mysql connection string")
56+
})
57+
}
58+
}
59+
60+
func TestEnsureMySQLConnectionString(t *testing.T) {
61+
tests := []struct {
62+
name string
63+
input string
64+
expected string
65+
expectError bool
66+
}{
67+
{
68+
name: "already has tcp protocol",
69+
input: "user:pass@tcp(localhost:3306)/dbname",
70+
expected: "user:pass@tcp(localhost:3306)/dbname",
71+
expectError: false,
72+
},
73+
{
74+
name: "needs tcp protocol",
75+
input: "user:pass@localhost:3306/dbname",
76+
expected: "user:pass@tcp(localhost:3306)/dbname",
77+
expectError: false,
78+
},
79+
{
80+
name: "with mysql:// prefix",
81+
input: "mysql://user:pass@localhost:3306/dbname",
82+
expected: "user:pass@tcp(localhost:3306)/dbname",
83+
expectError: false,
84+
},
85+
{
86+
name: "invalid url format",
87+
input: "invalid:url:format",
88+
expected: "",
89+
expectError: true,
90+
},
91+
{
92+
name: "with query parameters",
93+
input: "user:pass@localhost:3306/dbname?param=value",
94+
expected: "user:pass@tcp(localhost:3306)/dbname?param=value",
95+
expectError: false,
96+
},
97+
}
98+
99+
for _, tt := range tests {
100+
t.Run(tt.name, func(t *testing.T) {
101+
result, err := ensureMySQLConnectionString(tt.input)
102+
if tt.expectError {
103+
require.Error(t, err)
104+
assert.Empty(t, result)
105+
assert.Contains(t, err.Error(), "invalid mysql connection string")
106+
} else {
107+
require.NoError(t, err)
108+
assert.Equal(t, tt.expected, result)
109+
}
110+
})
111+
}
112+
}

0 commit comments

Comments
 (0)