Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 71 additions & 20 deletions internal/sources/oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"net/url"
"os"
"strings"

Expand Down Expand Up @@ -253,6 +254,72 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any, rea
return out, nil
}

func buildConnectStringBase(config Config) string {
if config.TnsAlias != "" {
return strings.TrimSpace(config.TnsAlias)
}
if config.ConnectionString != "" {
return strings.TrimSpace(config.ConnectionString)
}
if config.Port > 0 {
return fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
}
return fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
}

func buildGoOraConnString(config Config) (string, string) {
connectStringBase := buildConnectStringBase(config)
// Decode any pre-encoded credentials to avoid double-encoding, then encode safely for URL userinfo.
userInfo := url.UserPassword(
decodePercentEncodedUserInfo(config.User),
decodePercentEncodedUserInfo(config.Password),
).String()

baseConnStr, existingQuery, hasQuery := strings.Cut(connectStringBase, "?")
Comment thread
Deeven-Seru marked this conversation as resolved.
dsnBase := fmt.Sprintf("oracle://%s@%s", userInfo, baseConnStr)

q := url.Values{}
rawExistingQuery := ""
if hasQuery {
if parsedExisting, err := url.ParseQuery(existingQuery); err == nil {
q = parsedExisting
} else {
rawExistingQuery = existingQuery
}
}

trimmedWalletLocation := strings.TrimSpace(config.WalletLocation)
if rawExistingQuery != "" {
if trimmedWalletLocation == "" {
return fmt.Sprintf("%s?%s", dsnBase, rawExistingQuery), connectStringBase
}

walletQuery := url.Values{}
Comment thread
Deeven-Seru marked this conversation as resolved.
walletQuery.Set("ssl", "true")
walletQuery.Set("wallet", trimmedWalletLocation)
return fmt.Sprintf("%s?%s&%s", dsnBase, rawExistingQuery, walletQuery.Encode()), connectStringBase
}

if trimmedWalletLocation != "" {
q.Set("ssl", "true")
q.Set("wallet", trimmedWalletLocation)
}

if len(q) == 0 {
return dsnBase, connectStringBase
}

return fmt.Sprintf("%s?%s", dsnBase, q.Encode()), connectStringBase
}

func decodePercentEncodedUserInfo(value string) string {
decoded, err := url.PathUnescape(value)
if err != nil {
return value
}
return decoded
Comment thread
Deeven-Seru marked this conversation as resolved.
}

func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name)
Expand All @@ -279,42 +346,26 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
}()
}

var connectStringBase string
if config.TnsAlias != "" {
connectStringBase = strings.TrimSpace(config.TnsAlias)
} else if config.ConnectionString != "" {
connectStringBase = strings.TrimSpace(config.ConnectionString)
} else {
if config.Port > 0 {
connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
} else {
connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
}
}

var driverName string
var finalConnStr string
var connectStringBase string

if config.UseOCI {
// Use godror driver (requires OCI)
driverName = "godror"
connectStringBase = buildConnectStringBase(config)
finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
config.User, config.Password, connectStringBase)
logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase))
} else {
// Use go-ora driver (pure Go)
driverName = "oracle"

user := config.User
password := config.Password
finalConnStr, connectStringBase = buildGoOraConnString(config)

if hasWallet {
finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s",
user, password, connectStringBase, config.WalletLocation)
logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with wallet and serverString: %s\n", connectStringBase))
Comment thread
Deeven-Seru marked this conversation as resolved.
} else {
// Standard go-ora connection
finalConnStr = fmt.Sprintf("oracle://%s:%s@%s",
config.User, config.Password, connectStringBase)
logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase))
}
}
Expand Down
108 changes: 97 additions & 11 deletions internal/sources/oracle/oracle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright © 2025, Oracle and/or its affiliates.

package oracle_test
package oracle

import (
"context"
Expand All @@ -11,7 +11,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/testutils"
)

Expand All @@ -33,9 +32,9 @@ func TestParseFromYamlOracle(t *testing.T) {
useOCI: true
`,
want: map[string]sources.SourceConfig{
"my-oracle-cs": oracle.Config{
"my-oracle-cs": Config{
Name: "my-oracle-cs",
Type: oracle.SourceType,
Type: SourceType,
ConnectionString: "my-host:1521/XEPDB1",
User: "my_user",
Password: "my_pass",
Expand All @@ -56,9 +55,9 @@ func TestParseFromYamlOracle(t *testing.T) {
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-oracle-host": oracle.Config{
"my-oracle-host": Config{
Name: "my-oracle-host",
Type: oracle.SourceType,
Type: SourceType,
Host: "my-host",
Port: 1521,
ServiceName: "ORCLPDB",
Expand All @@ -81,9 +80,9 @@ func TestParseFromYamlOracle(t *testing.T) {
useOCI: true
`,
want: map[string]sources.SourceConfig{
"my-oracle-tns-oci": oracle.Config{
"my-oracle-tns-oci": Config{
Name: "my-oracle-tns-oci",
Type: oracle.SourceType,
Type: SourceType,
TnsAlias: "FINANCE_DB",
TnsAdmin: "/opt/oracle/network/admin",
User: "my_user",
Expand Down Expand Up @@ -205,10 +204,10 @@ func TestRunSQLExecutesDML(t *testing.T) {
}
defer db.Close()

src := &oracle.Source{
Config: oracle.Config{
src := &Source{
Config: Config{
Name: "test-dml-source",
Type: oracle.SourceType,
Type: SourceType,
User: "test-user",
},
DB: db,
Expand All @@ -225,3 +224,90 @@ func TestRunSQLExecutesDML(t *testing.T) {
"DML path may not have been executed")
}
}

func TestBuildGoOraConnString(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
user string
password string
connBase string
walletLocation string
want string
}{
{
name: "encodes credentials and wallet",
user: "user[client]",
password: "pa:ss@word",
connBase: "dbhost:1521/XEPDB1",
walletLocation: "/tmp/my wallet",
want: "oracle://user%5Bclient%5D:pa%3Ass%40word@dbhost:1521/XEPDB1?ssl=true&wallet=%2Ftmp%2Fmy+wallet",
},
{
name: "no wallet",
user: "scott",
password: "tiger",
connBase: "dbhost:1521/ORCL",
walletLocation: "",
want: "oracle://scott:tiger@dbhost:1521/ORCL",
},
{
name: "does not double encode percent encoded user",
user: "app_user%5BCLIENT_A%5D",
password: "secret",
connBase: "dbhost:1521/ORCL",
walletLocation: "",
want: "oracle://app_user%5BCLIENT_A%5D:secret@dbhost:1521/ORCL",
},
{
name: "uses trimmed wallet location",
user: "scott",
password: "tiger",
connBase: "dbhost:1521/ORCL",
walletLocation: " /tmp/wallet ",
want: "oracle://scott:tiger@dbhost:1521/ORCL?ssl=true&wallet=%2Ftmp%2Fwallet",
},
{
name: "preserves existing query without wallet",
user: "scott",
password: "tiger",
connBase: "dbhost:1521/ORCL?transport_connect_timeout=30",
walletLocation: "",
want: "oracle://scott:tiger@dbhost:1521/ORCL?transport_connect_timeout=30",
},
{
name: "merges existing query with wallet",
user: "scott",
password: "tiger",
connBase: "dbhost:1521/ORCL?foo=bar",
walletLocation: "/tmp/wallet",
want: "oracle://scott:tiger@dbhost:1521/ORCL?foo=bar&ssl=true&wallet=%2Ftmp%2Fwallet",
},
{
name: "preserves malformed existing query when appending wallet",
user: "scott",
password: "tiger",
connBase: "dbhost:1521/ORCL?already=%ZZ",
walletLocation: "/tmp/wallet",
want: "oracle://scott:tiger@dbhost:1521/ORCL?already=%ZZ&ssl=true&wallet=%2Ftmp%2Fwallet",
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
config := Config{
User: tc.user,
Password: tc.password,
ConnectionString: tc.connBase,
WalletLocation: tc.walletLocation,
}
got, _ := buildGoOraConnString(config)
if got != tc.want {
t.Fatalf("buildGoOraConnString() = %q, want %q", got, tc.want)
}
})
}
}