diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst index 983a6162c7..d3a588be57 100644 --- a/docs/source/driver/flight_sql.rst +++ b/docs/source/driver/flight_sql.rst @@ -159,6 +159,12 @@ few optional authentication schemes: header will then be sent back as the ``authorization`` header on all future requests. +- OAuth 2.0 authentication flows. + + The client provides :ref:`configurations ` to allow client application to obtain access + tokens from an authorization server. The obtained token is then used + on the ``authorization`` header on all future requests. + Bulk Ingestion -------------- @@ -246,10 +252,67 @@ to :c:struct:`AdbcDatabase`, :c:struct:`AdbcConnection`, and Add the header ``
`` to outgoing requests with the given value. - Python: :attr:`adbc_driver_flightsql.ConnectionOptions.RPC_CALL_HEADER_PREFIX` + Python: :attr:`adbc_driver_flightsql.ConnectionOptions.RPC_CALL_HEADER_PREFIX` .. warning:: Header names must be in all lowercase. + +OAuth 2.0 Options +----------------------- +.. _oauth-configurations: + +Supported configurations to obtain tokens using OAuth 2.0 authentication flows. + +``adbc.flight.sql.oauth.flow`` + Specifies the OAuth 2.0 flow type to use. Possible values: ``client_credentials``, ``token_exchange`` + +``adbc.flight.sql.oauth.client_id`` + Unique identifier issued to the client application by the authorization server + +``adbc.flight.sql.oauth.client_secret`` + Secret associated to the client_id. Used to authenticate the client application to the authorization server + +``adbc.flight.sql.oauth.token_uri`` + The endpoint URL where the client application requests tokens from the authorization server + +``adbc.flight.sql.oauth.scope`` + Space-separated list of permissions that the client is requesting access to (e.g ``"read.all offline_access"``) + +``adbc.flight.sql.oauth.exchange.subject_token`` + The security token that the client application wants to exchange + +``adbc.flight.sql.oauth.exchange.subject_token_type`` + Identifier for the type of the subject token. + Check list below for supported token types. + +``adbc.flight.sql.oauth.exchange.actor_token`` + A security token that represents the identity of the acting party + +``adbc.flight.sql.oauth.exchange.actor_token_type`` + Identifier for the type of the actor token. + Check list below for supported token types. +``adbc.flight.sql.oauth.exchange.aud`` + The intended audience for the requested security token + +``adbc.flight.sql.oauth.exchange.resource`` + The resource server where the client intends to use the requested security token + +``adbc.flight.sql.oauth.exchange.scope`` + Specific permissions requested for the new token + +``adbc.flight.sql.oauth.exchange.requested_token_type`` + The type of token the client wants to receive in exchange. + Check list below for supported token types. + + +Supported token types: + - ``urn:ietf:params:oauth:token-type:access_token`` + - ``urn:ietf:params:oauth:token-type:refresh_token`` + - ``urn:ietf:params:oauth:token-type:id_token`` + - ``urn:ietf:params:oauth:token-type:saml1`` + - ``urn:ietf:params:oauth:token-type:saml2`` + - ``urn:ietf:params:oauth:token-type:jwt`` + Distributed Result Sets ----------------------- diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index c8a59d72ae..f2cd0060ce 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -20,11 +20,21 @@ package flightsql_test import ( + "bytes" "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/json" + "encoding/pem" "errors" "fmt" + "math/big" "net" + "net/http" + "net/http/httptest" "net/textproto" "os" "strconv" @@ -50,6 +60,7 @@ import ( "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" @@ -69,16 +80,14 @@ type ServerBasedTests struct { } func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, dbArgs map[string]string, dialOpts ...grpc.DialOption) { - suite.s = flight.NewServerWithMiddleware(srvMiddleware) - suite.s.RegisterFlightService(flightsql.NewFlightServer(srv)) - suite.Require().NoError(suite.s.Init("localhost:0")) - suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill) - go func() { - _ = suite.s.Serve() - }() + suite.setupFlightServer(srv, srvMiddleware) - uri := "grpc+tcp://" + suite.s.Addr().String() + suite.setupDatabase(dbArgs, dialOpts...) +} + +func (suite *ServerBasedTests) setupDatabase(dbArgs map[string]string, dialOpts ...grpc.DialOption) { var err error + uri := "grpc+tcp://" + suite.s.Addr().String() args := map[string]string{ "uri": uri, @@ -88,6 +97,16 @@ func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware suite.Require().NoError(err) } +func (suite *ServerBasedTests) setupFlightServer(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, srvOpts ...grpc.ServerOption) { + suite.s = flight.NewServerWithMiddleware(srvMiddleware, srvOpts...) + suite.s.RegisterFlightService(flightsql.NewFlightServer(srv)) + suite.Require().NoError(suite.s.Init("localhost:0")) + suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill) + go func() { + _ = suite.s.Serve() + }() +} + func (suite *ServerBasedTests) SetupTest() { var err error suite.cnxn, err = suite.db.Open(context.Background()) @@ -104,6 +123,59 @@ func (suite *ServerBasedTests) TearDownSuite() { suite.s.Shutdown() } +func (suite *ServerBasedTests) generateCertOption() grpc.ServerOption { + // Generate a self-signed certificate in-process for testing + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + suite.Require().NoError(err) + certTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Unit Tests Incorporated"}, + }, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + certDer, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &privKey.PublicKey, privKey) + suite.Require().NoError(err) + buffer := &bytes.Buffer{} + suite.Require().NoError(pem.Encode(buffer, &pem.Block{Type: "CERTIFICATE", Bytes: certDer})) + certBytes := make([]byte, buffer.Len()) + copy(certBytes, buffer.Bytes()) + buffer.Reset() + suite.Require().NoError(pem.Encode(buffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)})) + keyBytes := make([]byte, buffer.Len()) + copy(keyBytes, buffer.Bytes()) + + cert, err := tls.X509KeyPair(certBytes, keyBytes) + suite.Require().NoError(err) + + suite.Require().NoError(err) + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + tlsCreds := credentials.NewTLS(tlsConfig) + + return grpc.Creds(tlsCreds) +} + +func (suite *ServerBasedTests) openAndExecuteQuery(query string) { + var err error + suite.cnxn, err = suite.db.Open(context.Background()) + suite.Require().NoError(err) + defer suite.cnxn.Close() + + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer stmt.Close() + + suite.Require().NoError(stmt.SetSqlQuery(query)) + reader, _, err := stmt.ExecuteQuery(context.Background()) + suite.NoError(err) + defer reader.Release() +} + // ---- Tests -------------------- func TestAuthn(t *testing.T) { @@ -150,6 +222,10 @@ func TestGetObjects(t *testing.T) { suite.Run(t, &GetObjectsTests{}) } +func TestOauth(t *testing.T) { + suite.Run(t, &OAuthTests{}) +} + // ---- AuthN Tests -------------------- type AuthnTestServer struct { @@ -230,23 +306,288 @@ type AuthnTests struct { } func (suite *AuthnTests) SetupSuite() { - suite.DoSetupSuite(&AuthnTestServer{}, []flight.ServerMiddleware{ + suite.setupFlightServer(&AuthnTestServer{}, []flight.ServerMiddleware{ {Stream: authnTestStream, Unary: authnTestUnary}, - }, map[string]string{ - driver.OptionAuthorizationHeader: "Bearer initial", }) } +func (suite *AuthnTests) SetupTest() { + suite.setupDatabase(map[string]string{ + "uri": "grpc+tcp://" + suite.s.Addr().String(), + }) +} + +func (suite *AuthnTests) TearDownTest() { + suite.NoError(suite.db.Close()) + suite.db = nil +} + +func (suite *AuthnTests) TearDownSuite() { + suite.s.Shutdown() +} + func (suite *AuthnTests) TestBearerTokenUpdated() { + err := suite.db.SetOptions(map[string]string{ + driver.OptionAuthorizationHeader: "Bearer initial", + }) + suite.Require().NoError(err) + // apache/arrow-adbc#584: when setting the auth header directly, the client should use any updated token value from the server if given - stmt, err := suite.cnxn.NewStatement() + + suite.openAndExecuteQuery("a-query") +} + +type OAuthTests struct { + ServerBasedTests + + oauthServer *httptest.Server + mockOAuthServer *MockOAuthServer +} + +// MockOAuthServer simulates an OAuth 2.0 server for testing +type MockOAuthServer struct { + // Track calls to validate server behavior + clientCredentialsCalls int + tokenExchangeCalls int +} + +func (m *MockOAuthServer) handleTokenRequest(w http.ResponseWriter, r *http.Request) { + // Parse the form to get the request parameters + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + grantType := r.FormValue("grant_type") + + switch grantType { + case "client_credentials": + m.clientCredentialsCalls++ + // Validate client credentials + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + + if clientID == "test-client" && clientSecret == "test-secret" { + // Return a valid token response + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "access_token": "test-client-token", + "token_type": "bearer", + "expires_in": 3600 + }`)) + + return + } + + case "urn:ietf:params:oauth:grant-type:token-exchange": + m.tokenExchangeCalls++ + // Validate token exchange parameters + subjectToken := r.FormValue("subject_token") + subjectTokenType := r.FormValue("subject_token_type") + + if subjectToken == "test-subject-token" && + subjectTokenType == "urn:ietf:params:oauth:token-type:jwt" { + // Return a valid token response + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "access_token": "test-exchanged-token", + "token_type": "bearer", + "expires_in": 3600 + }`)) + return + } + } + + // Default: return error for invalid request + http.Error(w, "Invalid request", http.StatusBadRequest) +} + +func oauthTestUnary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.InvalidArgument, "Could not get metadata") + } + auth := md.Get("authorization") + if len(auth) == 0 { + return nil, status.Error(codes.Unauthenticated, "No token") + } else if auth[0] != "Bearer test-exchanged-token" && auth[0] != "Bearer test-client-token" { + return nil, status.Error(codes.Unauthenticated, "Invalid token for unary call: "+auth[0]) + } + + md.Set("authorization", "Bearer final") + ctx = metadata.NewOutgoingContext(ctx, md) + return handler(ctx, req) +} + +func (suite *OAuthTests) SetupSuite() { + suite.mockOAuthServer = &MockOAuthServer{} + suite.oauthServer = httptest.NewServer(http.HandlerFunc(suite.mockOAuthServer.handleTokenRequest)) + + suite.setupFlightServer(&AuthnTestServer{}, []flight.ServerMiddleware{ + {Unary: oauthTestUnary}, + }, suite.generateCertOption()) +} + +func (suite *OAuthTests) TearDownSuite() { + suite.oauthServer.Close() + suite.s.Shutdown() +} + +func (suite *OAuthTests) SetupTest() { + suite.setupDatabase(map[string]string{ + "uri": "grpc+tls://" + suite.s.Addr().String(), + }) +} + +func (suite *OAuthTests) TearDownTest() { + suite.NoError(suite.db.Close()) + suite.db = nil +} + +func (suite *OAuthTests) TestTokenExchangeFlow() { + err := suite.db.SetOptions(map[string]string{ + driver.OptionKeyOauthFlow: driver.TokenExchange, + driver.OptionKeySubjectToken: "test-subject-token", + driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + driver.OptionSSLSkipVerify: adbc.OptionValueEnabled, + }) suite.Require().NoError(err) - defer stmt.Close() - suite.Require().NoError(stmt.SetSqlQuery("timeout")) - reader, _, err := stmt.ExecuteQuery(context.Background()) - suite.NoError(err) - defer reader.Release() + suite.openAndExecuteQuery("a-query") + suite.Equal(1, suite.mockOAuthServer.tokenExchangeCalls, "Token exchange flow should be called once") +} + +func (suite *OAuthTests) TestClientCredentialsFlow() { + err := suite.db.SetOptions(map[string]string{ + driver.OptionKeyOauthFlow: driver.ClientCredentials, + driver.OptionKeyClientId: "test-client", + driver.OptionKeyClientSecret: "test-secret", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + driver.OptionSSLSkipVerify: adbc.OptionValueEnabled, + }) + suite.Require().NoError(err) + + suite.cnxn, err = suite.db.Open(context.Background()) + suite.Require().NoError(err) + defer suite.cnxn.Close() + + suite.openAndExecuteQuery("a-query") + // golang/oauth2 tries to call the token endpoint sending the client credentials in the authentication header, + // if it fails, it retries sending the client credentials in the request body. + // See https://code.google.com/p/goauth2/issues/detail?id=31 for background. + suite.Equal(2, suite.mockOAuthServer.clientCredentialsCalls, "Client credentials flow should be called once") +} + +func (suite *OAuthTests) TestFailOauthWithTokenSet() { + err := suite.db.SetOptions(map[string]string{ + driver.OptionAuthorizationHeader: "Bearer test-client-token", + driver.OptionKeyOauthFlow: driver.ClientCredentials, + driver.OptionKeyClientId: "test-client", + driver.OptionKeyClientSecret: "test-secret", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + }) + suite.Error(err, "Expected error for missing parameters") + suite.Contains(err.Error(), "Authentication conflict: Use either Authorization header OR username/password parameter") +} + +func (suite *OAuthTests) TestMissingRequiredParamsTokenExchange() { + testCases := []struct { + name string + options map[string]string + expectedErrorMsg string + }{ + { + name: "Missing token", + options: map[string]string{ + driver.OptionKeyOauthFlow: driver.TokenExchange, + driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + }, + expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.exchange.subject_token", + }, + { + name: "Missing subject token type", + options: map[string]string{ + driver.OptionKeyOauthFlow: driver.TokenExchange, + driver.OptionKeySubjectToken: "test-subject-token", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + }, + expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.exchange.subject_token_type", + }, + { + name: "Missing token URI", + options: map[string]string{ + driver.OptionKeyOauthFlow: driver.TokenExchange, + driver.OptionKeySubjectToken: "test-subject-token", + driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", + }, + expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.token_uri", + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + // We need to set options with the driver's SetOptions method + err := suite.db.SetOptions(tc.options) + suite.Error(err, "Expected error for missing parameters") + suite.Contains(err.Error(), tc.expectedErrorMsg) + }) + } +} +func (suite *OAuthTests) TestMissingRequiredParamsClientCredentials() { + testCases := []struct { + name string + options map[string]string + expectedErrorMsg string + }{ + { + name: "Missing client ID", + options: map[string]string{ + driver.OptionKeyOauthFlow: driver.ClientCredentials, + driver.OptionKeyClientSecret: "test-secret", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + }, + expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.client_id", + }, + { + name: "Missing client secret", + options: map[string]string{ + driver.OptionKeyOauthFlow: driver.ClientCredentials, + driver.OptionKeyClientId: "test-client", + driver.OptionKeyTokenURI: suite.oauthServer.URL, + }, + expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.client_secret", + }, + { + name: "Missing token URI", + options: map[string]string{ + driver.OptionKeyOauthFlow: driver.ClientCredentials, + driver.OptionKeyClientId: "test-client", + driver.OptionKeyClientSecret: "test-secret", + }, + expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.token_uri", + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + // We need to set options with the driver's SetOptions method + err := suite.db.SetOptions(tc.options) + suite.Error(err, "Expected error for missing parameters") + suite.Contains(err.Error(), tc.expectedErrorMsg) + }) + } +} + +func (suite *OAuthTests) TestInvalidOAuthFlow() { + err := suite.db.SetOptions(map[string]string{ + driver.OptionKeyOauthFlow: "invalid-flow", + driver.OptionKeySubjectToken: "test-token", + }) + + suite.Error(err, "Expected error for invalid OAuth flow") + suite.Contains(err.Error(), "Not Implemented: oauth flow not implemented: invalid-flow") } // ---- Grpc Dialer Options Tests -------------- diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index bbbcbbf061..e45eb4d5da 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -68,6 +68,7 @@ type databaseImpl struct { enableCookies bool options map[string]string userDialOpts []grpc.DialOption + oauthToken credentials.PerRPCCredentials } func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { @@ -146,10 +147,12 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { delete(cnOptions, OptionAuthorizationHeader) } + const authConflictError = "Authentication conflict: Use either Authorization header OR username/password parameter" + if u, ok := cnOptions[adbc.OptionKeyUsername]; ok { if d.hdrs.Len() > 0 { return adbc.Error{ - Msg: "Authorization header already provided, do not provide user/pass also", + Msg: authConflictError, Code: adbc.StatusInvalidArgument, } } @@ -160,7 +163,7 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { if p, ok := cnOptions[adbc.OptionKeyPassword]; ok { if d.hdrs.Len() > 0 { return adbc.Error{ - Msg: "Authorization header already provided, do not provide user/pass also", + Msg: authConflictError, Code: adbc.StatusInvalidArgument, } } @@ -168,6 +171,33 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { delete(cnOptions, adbc.OptionKeyPassword) } + if flow, ok := cnOptions[OptionKeyOauthFlow]; ok { + if d.hdrs.Len() > 0 { + return adbc.Error{ + Msg: authConflictError, + Code: adbc.StatusInvalidArgument, + } + } + + var err error + switch flow { + case ClientCredentials: + d.oauthToken, err = newClientCredentials(cnOptions) + case TokenExchange: + d.oauthToken, err = newTokenExchangeFlow(cnOptions) + default: + return adbc.Error{ + Msg: fmt.Sprintf("oauth flow not implemented: %s", flow), + Code: adbc.StatusNotImplemented, + } + } + + if err != nil { + return err + } + delete(cnOptions, OptionKeyOauthFlow) + } + var err error if tv, ok := cnOptions[OptionTimeoutFetch]; ok { if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil { @@ -374,6 +404,10 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion)) dialOpts = append(dialOpts, d.userDialOpts...) + if d.oauthToken != nil { + dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(d.oauthToken)) + } + d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) if err != nil { @@ -384,24 +418,30 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl } cl.Alloc = d.Alloc + // Authorization header is already set, continue if len(authMiddle.hdrs.Get("authorization")) > 0 { d.Logger.DebugContext(ctx, "reusing auth token", "location", loc) - } else { - if d.user != "" || d.pass != "" { - var header, trailer metadata.MD - ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout) - if err != nil { - return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken") - } + return cl, nil + } - if md, ok := metadata.FromOutgoingContext(ctx); ok { - authMiddle.mutex.Lock() - defer authMiddle.mutex.Unlock() - authMiddle.hdrs.Set("authorization", md.Get("Authorization")[0]) - } + var authValue string + + if d.user != "" || d.pass != "" { + var header, trailer metadata.MD + ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout) + if err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken") + } + + if md, ok := metadata.FromOutgoingContext(ctx); ok { + authValue = md.Get("Authorization")[0] } } + if authValue != "" { + authMiddle.SetHeader(authValue) + } + return cl, nil } @@ -526,3 +566,9 @@ func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md metadata. b.hdrs.Set("authorization", headers...) } } + +func (b *bearerAuthMiddleware) SetHeader(authValue string) { + b.mutex.Lock() + defer b.mutex.Unlock() + b.hdrs.Set("authorization", authValue) +} diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index 9e517c7167..ff1e74bb5f 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -66,6 +66,23 @@ const ( OptionStringListSessionOptionPrefix = "adbc.flight.sql.session.optionstringlist." OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info" infoDriverName = "ADBC Flight SQL Driver - Go" + + // Oauth2 options + OptionKeyOauthFlow = "adbc.flight.sql.oauth.flow" + OptionKeyAuthURI = "adbc.flight.sql.oauth.auth_uri" + OptionKeyTokenURI = "adbc.flight.sql.oauth.token_uri" + OptionKeyRedirectURI = "adbc.flight.sql.oauth.redirect_uri" + OptionKeyScope = "adbc.flight.sql.oauth.scope" + OptionKeyClientId = "adbc.flight.sql.oauth.client_id" + OptionKeyClientSecret = "adbc.flight.sql.oauth.client_secret" + OptionKeySubjectToken = "adbc.flight.sql.oauth.exchange.subject_token" + OptionKeySubjectTokenType = "adbc.flight.sql.oauth.exchange.subject_token_type" + OptionKeyActorToken = "adbc.flight.sql.oauth.exchange.actor_token" + OptionKeyActorTokenType = "adbc.flight.sql.oauth.exchange.actor_token_type" + OptionKeyReqTokenType = "adbc.flight.sql.oauth.exchange.requested_token_type" + OptionKeyExchangeScope = "adbc.flight.sql.oauth.exchange.scope" + OptionKeyExchangeAud = "adbc.flight.sql.oauth.exchange.aud" + OptionKeyExchangeResource = "adbc.flight.sql.oauth.exchange.resource" ) var errNoTransactionSupport = adbc.Error{ diff --git a/go/adbc/driver/flightsql/flightsql_oauth.go b/go/adbc/driver/flightsql/flightsql_oauth.go new file mode 100644 index 0000000000..707590a0df --- /dev/null +++ b/go/adbc/driver/flightsql/flightsql_oauth.go @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package flightsql + +import ( + "context" + "fmt" + + "golang.org/x/oauth2" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/oauth" +) + +const ( + ClientCredentials = "client_credentials" + TokenExchange = "token_exchange" +) + +type oAuthOption struct { + isRequired bool + oAuthKey string +} + +var ( + clientCredentialsParams = map[string]oAuthOption{ + OptionKeyClientId: {true, "client_id"}, + OptionKeyClientSecret: {true, "client_secret"}, + OptionKeyTokenURI: {true, "token_uri"}, + OptionKeyScope: {false, "scope"}, + } + + tokenExchangParams = map[string]oAuthOption{ + OptionKeySubjectToken: {true, "subject_token"}, + OptionKeySubjectTokenType: {true, "subject_token_type"}, + OptionKeyReqTokenType: {false, "requested_token_type"}, + OptionKeyExchangeAud: {false, "audience"}, + OptionKeyExchangeResource: {false, "resource"}, + OptionKeyExchangeScope: {false, "scope"}, + } +) + +func parseOAuthOptions(options map[string]string, paramMap map[string]oAuthOption, flowName string) (map[string]string, error) { + params := map[string]string{} + + for key, param := range paramMap { + if value, ok := options[key]; ok { + params[key] = value + delete(options, key) + } else if param.isRequired { + return nil, fmt.Errorf("%s grant requires %s", flowName, key) + } + } + + return params, nil +} + +func exchangeToken(conf *oauth2.Config, codeOptions []oauth2.AuthCodeOption) (credentials.PerRPCCredentials, error) { + ctx := context.Background() + tok, err := conf.Exchange(ctx, "", codeOptions...) + if err != nil { + return nil, err + } + return &oauth.TokenSource{TokenSource: conf.TokenSource(ctx, tok)}, nil +} + +func newClientCredentials(options map[string]string) (credentials.PerRPCCredentials, error) { + codeOptions := []oauth2.AuthCodeOption{ + // Required value for client credentials requests as specified in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2 + oauth2.SetAuthURLParam("grant_type", "client_credentials"), + } + + params, err := parseOAuthOptions(options, clientCredentialsParams, "client credentials") + if err != nil { + return nil, err + } + + conf := &oauth2.Config{ + ClientID: params[OptionKeyClientId], + ClientSecret: params[OptionKeyClientSecret], + Endpoint: oauth2.Endpoint{ + TokenURL: params[OptionKeyTokenURI], + }, + } + + if scopes, ok := params[OptionKeyScope]; ok { + conf.Scopes = []string{scopes} + } + + return exchangeToken(conf, codeOptions) +} + +func newTokenExchangeFlow(options map[string]string) (credentials.PerRPCCredentials, error) { + tokenURI, ok := options[OptionKeyTokenURI] + if !ok { + return nil, fmt.Errorf("token exchange grant requires %s", OptionKeyTokenURI) + } + delete(options, OptionKeyTokenURI) + + conf := &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: tokenURI, + }, + } + + codeOptions := []oauth2.AuthCodeOption{ + // Required value for token exchange requests as specified in https://datatracker.ietf.org/doc/html/rfc8693#name-request + oauth2.SetAuthURLParam("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"), + } + + params, err := parseOAuthOptions(options, tokenExchangParams, "token exchange") + if err != nil { + return nil, err + } + + for key, param := range tokenExchangParams { + if value, ok := params[key]; ok { + codeOptions = append(codeOptions, oauth2.SetAuthURLParam(param.oAuthKey, value)) + } + } + + // actor token and actor token type are optional + // but if one is present, the other must be present + if actor, ok := options[OptionKeyActorToken]; ok { + codeOptions = append(codeOptions, oauth2.SetAuthURLParam("actor_token", actor)) + delete(options, OptionKeyActorToken) + if actorTokenType, ok := options[OptionKeyActorTokenType]; ok { + codeOptions = append(codeOptions, oauth2.SetAuthURLParam("actor_token_type", actorTokenType)) + delete(options, OptionKeyActorTokenType) + } else { + return nil, fmt.Errorf("token exchange grant requires %s when %s is provided", + OptionKeyActorTokenType, OptionKeyActorToken) + } + } + + return exchangeToken(conf, codeOptions) +}