diff --git a/internal/sources/oracle/oracle.go b/internal/sources/oracle/oracle.go index 80730362e69d..68a92b5616a4 100644 --- a/internal/sources/oracle/oracle.go +++ b/internal/sources/oracle/oracle.go @@ -6,6 +6,7 @@ import ( "database/sql" "encoding/json" "fmt" + "net/url" "os" "strings" @@ -253,6 +254,72 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any, rea return out, nil } +func buildConnectStringBase(config Config) string { + if config.TnsAlias != "" { + return strings.TrimSpace(config.TnsAlias) + } + if config.ConnectionString != "" { + return strings.TrimSpace(config.ConnectionString) + } + if config.Port > 0 { + return fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) + } + return fmt.Sprintf("%s/%s", config.Host, config.ServiceName) +} + +func buildGoOraConnString(config Config) (string, string) { + connectStringBase := buildConnectStringBase(config) + // Decode any pre-encoded credentials to avoid double-encoding, then encode safely for URL userinfo. + userInfo := url.UserPassword( + decodePercentEncodedUserInfo(config.User), + decodePercentEncodedUserInfo(config.Password), + ).String() + + baseConnStr, existingQuery, hasQuery := strings.Cut(connectStringBase, "?") + dsnBase := fmt.Sprintf("oracle://%s@%s", userInfo, baseConnStr) + + q := url.Values{} + rawExistingQuery := "" + if hasQuery { + if parsedExisting, err := url.ParseQuery(existingQuery); err == nil { + q = parsedExisting + } else { + rawExistingQuery = existingQuery + } + } + + trimmedWalletLocation := strings.TrimSpace(config.WalletLocation) + if rawExistingQuery != "" { + if trimmedWalletLocation == "" { + return fmt.Sprintf("%s?%s", dsnBase, rawExistingQuery), connectStringBase + } + + walletQuery := url.Values{} + walletQuery.Set("ssl", "true") + walletQuery.Set("wallet", trimmedWalletLocation) + return fmt.Sprintf("%s?%s&%s", dsnBase, rawExistingQuery, walletQuery.Encode()), connectStringBase + } + + if trimmedWalletLocation != "" { + q.Set("ssl", "true") + q.Set("wallet", trimmedWalletLocation) + } + + if len(q) == 0 { + return dsnBase, connectStringBase + } + + return fmt.Sprintf("%s?%s", dsnBase, q.Encode()), connectStringBase +} + +func decodePercentEncodedUserInfo(value string) string { + decoded, err := url.PathUnescape(value) + if err != nil { + return value + } + return decoded +} + func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name) @@ -279,25 +346,14 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi }() } - var connectStringBase string - if config.TnsAlias != "" { - connectStringBase = strings.TrimSpace(config.TnsAlias) - } else if config.ConnectionString != "" { - connectStringBase = strings.TrimSpace(config.ConnectionString) - } else { - if config.Port > 0 { - connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) - } else { - connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) - } - } - var driverName string var finalConnStr string + var connectStringBase string if config.UseOCI { // Use godror driver (requires OCI) driverName = "godror" + connectStringBase = buildConnectStringBase(config) finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, config.User, config.Password, connectStringBase) logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase)) @@ -305,16 +361,11 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi // Use go-ora driver (pure Go) driverName = "oracle" - user := config.User - password := config.Password + finalConnStr, connectStringBase = buildGoOraConnString(config) if hasWallet { - finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s", - user, password, connectStringBase, config.WalletLocation) + logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with wallet and serverString: %s\n", connectStringBase)) } else { - // Standard go-ora connection - finalConnStr = fmt.Sprintf("oracle://%s:%s@%s", - config.User, config.Password, connectStringBase) logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase)) } } diff --git a/internal/sources/oracle/oracle_test.go b/internal/sources/oracle/oracle_test.go index 68b3dfbd6c59..bebbe2485d62 100644 --- a/internal/sources/oracle/oracle_test.go +++ b/internal/sources/oracle/oracle_test.go @@ -1,6 +1,6 @@ // Copyright © 2025, Oracle and/or its affiliates. -package oracle_test +package oracle import ( "context" @@ -11,7 +11,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oracle" "github.com/googleapis/genai-toolbox/internal/testutils" ) @@ -33,9 +32,9 @@ func TestParseFromYamlOracle(t *testing.T) { useOCI: true `, want: map[string]sources.SourceConfig{ - "my-oracle-cs": oracle.Config{ + "my-oracle-cs": Config{ Name: "my-oracle-cs", - Type: oracle.SourceType, + Type: SourceType, ConnectionString: "my-host:1521/XEPDB1", User: "my_user", Password: "my_pass", @@ -56,9 +55,9 @@ func TestParseFromYamlOracle(t *testing.T) { password: my_pass `, want: map[string]sources.SourceConfig{ - "my-oracle-host": oracle.Config{ + "my-oracle-host": Config{ Name: "my-oracle-host", - Type: oracle.SourceType, + Type: SourceType, Host: "my-host", Port: 1521, ServiceName: "ORCLPDB", @@ -81,9 +80,9 @@ func TestParseFromYamlOracle(t *testing.T) { useOCI: true `, want: map[string]sources.SourceConfig{ - "my-oracle-tns-oci": oracle.Config{ + "my-oracle-tns-oci": Config{ Name: "my-oracle-tns-oci", - Type: oracle.SourceType, + Type: SourceType, TnsAlias: "FINANCE_DB", TnsAdmin: "/opt/oracle/network/admin", User: "my_user", @@ -205,10 +204,10 @@ func TestRunSQLExecutesDML(t *testing.T) { } defer db.Close() - src := &oracle.Source{ - Config: oracle.Config{ + src := &Source{ + Config: Config{ Name: "test-dml-source", - Type: oracle.SourceType, + Type: SourceType, User: "test-user", }, DB: db, @@ -225,3 +224,90 @@ func TestRunSQLExecutesDML(t *testing.T) { "DML path may not have been executed") } } + +func TestBuildGoOraConnString(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + user string + password string + connBase string + walletLocation string + want string + }{ + { + name: "encodes credentials and wallet", + user: "user[client]", + password: "pa:ss@word", + connBase: "dbhost:1521/XEPDB1", + walletLocation: "/tmp/my wallet", + want: "oracle://user%5Bclient%5D:pa%3Ass%40word@dbhost:1521/XEPDB1?ssl=true&wallet=%2Ftmp%2Fmy+wallet", + }, + { + name: "no wallet", + user: "scott", + password: "tiger", + connBase: "dbhost:1521/ORCL", + walletLocation: "", + want: "oracle://scott:tiger@dbhost:1521/ORCL", + }, + { + name: "does not double encode percent encoded user", + user: "app_user%5BCLIENT_A%5D", + password: "secret", + connBase: "dbhost:1521/ORCL", + walletLocation: "", + want: "oracle://app_user%5BCLIENT_A%5D:secret@dbhost:1521/ORCL", + }, + { + name: "uses trimmed wallet location", + user: "scott", + password: "tiger", + connBase: "dbhost:1521/ORCL", + walletLocation: " /tmp/wallet ", + want: "oracle://scott:tiger@dbhost:1521/ORCL?ssl=true&wallet=%2Ftmp%2Fwallet", + }, + { + name: "preserves existing query without wallet", + user: "scott", + password: "tiger", + connBase: "dbhost:1521/ORCL?transport_connect_timeout=30", + walletLocation: "", + want: "oracle://scott:tiger@dbhost:1521/ORCL?transport_connect_timeout=30", + }, + { + name: "merges existing query with wallet", + user: "scott", + password: "tiger", + connBase: "dbhost:1521/ORCL?foo=bar", + walletLocation: "/tmp/wallet", + want: "oracle://scott:tiger@dbhost:1521/ORCL?foo=bar&ssl=true&wallet=%2Ftmp%2Fwallet", + }, + { + name: "preserves malformed existing query when appending wallet", + user: "scott", + password: "tiger", + connBase: "dbhost:1521/ORCL?already=%ZZ", + walletLocation: "/tmp/wallet", + want: "oracle://scott:tiger@dbhost:1521/ORCL?already=%ZZ&ssl=true&wallet=%2Ftmp%2Fwallet", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + config := Config{ + User: tc.user, + Password: tc.password, + ConnectionString: tc.connBase, + WalletLocation: tc.walletLocation, + } + got, _ := buildGoOraConnString(config) + if got != tc.want { + t.Fatalf("buildGoOraConnString() = %q, want %q", got, tc.want) + } + }) + } +}