From 21dc68c98e5ed2d1b9122b880cc3e931074a6b2c Mon Sep 17 00:00:00 2001 From: eric-wang-1990 Date: Thu, 6 Mar 2025 00:49:21 +0000 Subject: [PATCH] add oauth access token --- .../src/Drivers/Apache/Spark/SparkAuthType.cs | 4 ++++ .../Drivers/Apache/Spark/SparkHttpConnection.cs | 17 +++++++++++++++-- .../src/Drivers/Apache/Spark/SparkParameters.cs | 4 ++++ csharp/test/Drivers/Apache/Spark/DriverTests.cs | 7 ++++++- .../Apache/Spark/SparkTestConfiguration.cs | 3 +++ .../Apache/Spark/SparkTestEnvironment.cs | 4 ++++ 6 files changed, 36 insertions(+), 3 deletions(-) diff --git a/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs index 83a78a788b..93c8706224 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs @@ -24,6 +24,7 @@ internal enum SparkAuthType UsernameOnly, Basic, Token, + OAuth, Empty = int.MaxValue, } @@ -49,6 +50,9 @@ internal static bool TryParse(string? authType, out SparkAuthType authTypeValue) case SparkAuthTypeConstants.Token: authTypeValue = SparkAuthType.Token; return true; + case SparkAuthTypeConstants.OAuth: + authTypeValue = SparkAuthType.OAuth; + return true; default: authTypeValue = SparkAuthType.Invalid; return false; diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs index dc0f4b44a4..b95262fe63 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -51,6 +51,7 @@ protected override void ValidateAuthentication() Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + Properties.TryGetValue(SparkParameters.AccessToken, out string? access_token); bool isValidAuthType = SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); switch (authTypeValue) { @@ -80,6 +81,13 @@ protected override void ValidateAuthentication() $"Parameters must include valid authentiation settings. Please provide either '{SparkParameters.Token}'; or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", nameof(Properties)); break; + + case SparkAuthType.OAuth: + if (string.IsNullOrWhiteSpace(access_token)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.OAuth}' but parameter '{SparkParameters.AccessToken}' is not set. Please provide a value for '{SparkParameters.AccessToken}'.", + nameof(Properties)); + break; default: throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); } @@ -140,12 +148,13 @@ protected override TTransport CreateTransport() Properties.TryGetValue(SparkParameters.AuthType, out string? authType); bool isValidAuthType = SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); Properties.TryGetValue(SparkParameters.Token, out string? token); + Properties.TryGetValue(SparkParameters.AccessToken, out string? access_token); Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(AdbcOptions.Uri, out string? uri); Uri baseAddress = GetBaseAddress(uri, hostName, path, port); - AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password, access_token); HttpClientHandler httpClientHandler = NewHttpClientHandler(); HttpClient httpClient = new(httpClientHandler); @@ -185,7 +194,7 @@ private HttpClientHandler NewHttpClientHandler() return httpClientHandler; } - private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(SparkAuthType authType, string? token, string? username, string? password) + private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(SparkAuthType authType, string? token, string? username, string? password, string? access_token) { if (!string.IsNullOrEmpty(token) && (authType == SparkAuthType.Empty || authType == SparkAuthType.Token)) { @@ -199,6 +208,10 @@ private HttpClientHandler NewHttpClientHandler() { return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:"))); } + else if (!string.IsNullOrEmpty(access_token) && authType == SparkAuthType.OAuth) + { + return new AuthenticationHeaderValue(BearerAuthenticationScheme, access_token); + } else if (authType == SparkAuthType.None) { return null; diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs index 6cb96dd5f1..d3bd9be6d5 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs @@ -26,6 +26,9 @@ public static class SparkParameters public const string Port = "adbc.spark.port"; public const string Path = "adbc.spark.path"; public const string Token = "adbc.spark.token"; + + // access_token is required when authType is oauth + public const string AccessToken = "adbc.spark.access_token"; public const string AuthType = "adbc.spark.auth_type"; public const string Type = "adbc.spark.type"; public const string DataTypeConv = "adbc.spark.data_type_conv"; @@ -39,6 +42,7 @@ public static class SparkAuthTypeConstants public const string UsernameOnly = "username_only"; public const string Basic = "basic"; public const string Token = "token"; + public const string OAuth = "oauth"; } public static class SparkServerTypeConstants diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs b/csharp/test/Drivers/Apache/Spark/DriverTests.cs index 03699c7a70..ec49f38a0d 100644 --- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs +++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs @@ -84,19 +84,24 @@ public override void CanDetectInvalidAuthentication() Dictionary parameters = GetDriverParameters(TestConfiguration); bool hasToken = parameters.TryGetValue(SparkParameters.Token, out var token) && !string.IsNullOrEmpty(token); + bool hasAccessToken = parameters.TryGetValue(SparkParameters.Token, out var access_token) && !string.IsNullOrEmpty(access_token); bool hasUsername = parameters.TryGetValue(AdbcOptions.Username, out var username) && !string.IsNullOrEmpty(username); bool hasPassword = parameters.TryGetValue(AdbcOptions.Password, out var password) && !string.IsNullOrEmpty(password); if (hasToken) { parameters[SparkParameters.Token] = "invalid-token"; } + else if (hasAccessToken) + { + parameters[SparkParameters.AccessToken] = "invalid-access-token"; + } else if (hasUsername && hasPassword) { parameters[AdbcOptions.Password] = "invalid-password"; } else { - Assert.Fail($"Unexpected configuration. Must provide '{SparkParameters.Token}' or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); + Assert.Fail($"Unexpected configuration. Must provide '{SparkParameters.Token}' or '{SparkParameters.AccessToken}' or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); } AdbcDatabase database = driver.Open(parameters); diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs b/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs index 5ada5abeb3..abf4400bc1 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs @@ -25,5 +25,8 @@ public class SparkTestConfiguration : ApacheTestConfiguration [JsonPropertyName("token"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string Token { get; set; } = string.Empty; + [JsonPropertyName("access_token"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string AccessToken { get; set; } = string.Empty; + } } diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs index 16a5501118..7fcddae9ce 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs @@ -77,6 +77,10 @@ public override Dictionary GetDriverParameters(SparkTestConfigur { parameters.Add(SparkParameters.Token, testConfiguration.Token!); } + if (!string.IsNullOrEmpty(testConfiguration.AccessToken)) + { + parameters.Add(SparkParameters.AccessToken, testConfiguration.AccessToken); + } if (!string.IsNullOrEmpty(testConfiguration.Username)) { parameters.Add(AdbcOptions.Username, testConfiguration.Username!);