diff --git a/README.md b/README.md index 1284523..4ed2592 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,19 @@ $ cat ~/.aws/config region = eu-central-1 ``` -Alternatively, you can use the following environment variables when starting postgres to configure the S3 client: +Alternatively, you can configure AWS credentials using session variables (GUCs) or environment variables: + +#### Session Variables (GUCs) - Highest Priority +You can set these within a PostgreSQL session: +```sql +SET pg_parquet.aws_access_key_id = 'AKIA...'; +SET pg_parquet.aws_secret_access_key = '...'; +SET pg_parquet.aws_session_token = '...'; -- Optional, for temporary credentials +SET pg_parquet.aws_region = 'us-east-1'; +SET pg_parquet.aws_endpoint_url = 'https://s3.amazonaws.com'; +``` + +#### Environment Variables - Second Priority - `AWS_ACCESS_KEY_ID`: the access key ID of the AWS account - `AWS_SECRET_ACCESS_KEY`: the secret access key of the AWS account - `AWS_SESSION_TOKEN`: the session token for the AWS account @@ -234,8 +246,9 @@ Alternatively, you can use the following environment variables when starting pos - `AWS_ALLOW_HTTP`: allows http endpoints **(only via environment variables)** Config source priority order is shown below: -1. Environment variables, -2. Config file. +1. Session variables (GUCs), +2. Environment variables, +3. Config file. 
 Supported S3 uri formats are shown below:
 - s3:// \<bucket\> / \<path\>
@@ -304,10 +317,36 @@ $ cat ~/.config/gcloud/application_default_credentials.json
 }
 ```

-Alternatively, you can use the following environment variables when starting postgres to configure the Google Cloud Storage client:
+Alternatively, you can configure Google Cloud Storage credentials using session variables (GUCs) or environment variables:
+
+#### Session Variables (GUCs) - Highest Priority
+You can set these within a PostgreSQL session:
+
+**For service account key (JSON string):**
+```sql
+-- Simple JSON (escape single quotes by doubling them)
+SET pg_parquet.google_service_account_key = '{"type": "service_account", "project_id": "my-project"}';
+
+-- Complex JSON with private key (escape single quotes)
+SET pg_parquet.google_service_account_key = '{"type": "service_account", "project_id": "my-project", "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n", "client_email": "service@my-project.iam.gserviceaccount.com"}';
+```
+
+**For service account path:**
+```sql
+SET pg_parquet.google_service_account_path = '/path/to/service-account-key.json';
+```
+
+**Note:** When setting JSON service account keys via GUC, make sure to escape any single quotes (`'`) by doubling them (`''`).
+
+#### Environment Variables - Second Priority
-- `GOOGLE_SERVICE_ACCOUNT_KEY`: json serialized service account key **(only via environment variables)**
-- `GOOGLE_SERVICE_ACCOUNT_PATH`: an alternative location for the config file **(only via environment variables)**
+- `GOOGLE_SERVICE_ACCOUNT_KEY`: JSON-serialized service account key **(cannot be set via the credentials file)**
+- `GOOGLE_SERVICE_ACCOUNT_PATH`: an alternative location for the config file **(cannot be set via the credentials file)**

+Config source priority order is shown below:
+1. Session variables (GUCs),
+2. Environment variables,
+3. Default credentials file (`~/.config/gcloud/application_default_credentials.json`).
+
 Supported Google Cloud Storage uri formats are shown below:
 - gs:// \<bucket\> / \<path\>

diff --git a/src/lib.rs b/src/lib.rs
index 88eff21..fbfea34 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,12 +1,31 @@
+use pgrx::pg_sys::AsPgCStr;
 use std::ffi::CStr;
+use std::ffi::CString;
 use std::sync::LazyLock;

 use parquet_copy_hook::hook::{init_parquet_copy_hook, ENABLE_PARQUET_COPY_HOOK};
 use parquet_copy_hook::pg_compat::MarkGUCPrefixReserved;
-use pgrx::pg_sys::AsPgCStr;
-use pgrx::{prelude::*, GucContext, GucFlags, GucRegistry};
+use pgrx::{prelude::*, GucContext, GucFlags, GucRegistry, GucSetting};
 use tokio::runtime::Runtime;

+// AWS Configuration GUCs
+pub(crate) static AWS_ACCESS_KEY_ID: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+pub(crate) static AWS_SECRET_ACCESS_KEY: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+pub(crate) static AWS_SESSION_TOKEN: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+pub(crate) static AWS_ENDPOINT_URL: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+pub(crate) static AWS_REGION: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+
+// Google Cloud Storage Configuration GUCs
+pub(crate) static GOOGLE_SERVICE_ACCOUNT_KEY: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+pub(crate) static GOOGLE_SERVICE_ACCOUNT_PATH: GucSetting<Option<CString>> =
+    GucSetting::<Option<CString>>::new(None);
+
 mod arrow_parquet;
 mod object_store;
 mod parquet_copy_hook;
@@ -49,7 +68,80 @@ pub extern "C-unwind" fn _PG_init() {
             &ENABLE_PARQUET_COPY_HOOK,
             GucContext::Userset,
             GucFlags::default(),
-        )
+        );
+
+        // AWS Configuration GUCs
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.aws_access_key_id".as_pg_cstr()),
+            CStr::from_ptr("AWS Access Key ID for S3 authentication".as_pg_cstr()),
+            CStr::from_ptr(
+                "AWS Access Key ID used for authenticating with S3-compatible storage".as_pg_cstr(),
+            ),
+            &AWS_ACCESS_KEY_ID,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
+
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.aws_secret_access_key".as_pg_cstr()),
+            CStr::from_ptr("AWS Secret Access Key for S3 authentication".as_pg_cstr()),
+            CStr::from_ptr(
+                "AWS Secret Access Key used for authenticating with S3-compatible storage"
+                    .as_pg_cstr(),
+            ),
+            &AWS_SECRET_ACCESS_KEY,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
+
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.aws_session_token".as_pg_cstr()),
+            CStr::from_ptr("AWS Session Token for S3 authentication".as_pg_cstr()),
+            CStr::from_ptr(
+                "AWS Session Token used for temporary credentials with S3-compatible storage"
+                    .as_pg_cstr(),
+            ),
+            &AWS_SESSION_TOKEN,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
+
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.aws_endpoint_url".as_pg_cstr()),
+            CStr::from_ptr("AWS S3 Endpoint URL".as_pg_cstr()),
+            CStr::from_ptr("Custom endpoint URL for S3-compatible storage services".as_pg_cstr()),
+            &AWS_ENDPOINT_URL,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
+
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.aws_region".as_pg_cstr()),
+            CStr::from_ptr("AWS Region for S3 operations".as_pg_cstr()),
+            CStr::from_ptr("AWS region for S3 bucket operations".as_pg_cstr()),
+            &AWS_REGION,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
+
+        // Google Cloud Storage Configuration GUCs
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.google_service_account_key".as_pg_cstr()),
+            CStr::from_ptr("Google Service Account Key JSON".as_pg_cstr()),
+            CStr::from_ptr("Google Cloud service account key used for authentication".as_pg_cstr()),
+            &GOOGLE_SERVICE_ACCOUNT_KEY,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
+
+        GucRegistry::define_string_guc(
+            CStr::from_ptr("pg_parquet.google_service_account_path".as_pg_cstr()),
+            CStr::from_ptr("Google Service Account Key Path".as_pg_cstr()),
+            CStr::from_ptr("Path to Google Cloud service account key file".as_pg_cstr()),
+            &GOOGLE_SERVICE_ACCOUNT_PATH,
+            GucContext::Userset,
+            GucFlags::default(),
+        );
     };

     MarkGUCPrefixReserved("pg_parquet");
diff --git a/src/object_store/aws.rs b/src/object_store/aws.rs
index 9a1541a..c8ddacb 100644
---
a/src/object_store/aws.rs
+++ b/src/object_store/aws.rs
@@ -5,7 +5,10 @@ use aws_credential_types::provider::ProvideCredentials;
 use object_store::aws::AmazonS3Builder;
 use url::Url;

-use crate::PG_BACKEND_TOKIO_RUNTIME;
+use crate::{
+    AWS_ACCESS_KEY_ID, AWS_ENDPOINT_URL, AWS_REGION, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN,
+    PG_BACKEND_TOKIO_RUNTIME,
+};

 use super::object_store_cache::ObjectStoreWithExpiration;

@@ -115,7 +118,7 @@ struct AwsS3Config {
 }

 impl AwsS3Config {
-    // load reads the s3 config from the environment variables first and config files as fallback.
+    // load reads the s3 config from GUCs first, then environment variables, then config files as fallback.
     fn load() -> Self {
         let allow_http = if let Ok(allow_http) = std::env::var("AWS_ALLOW_HTTP") {
             allow_http.parse().unwrap_or(false)
@@ -123,35 +126,67 @@ impl AwsS3Config {
             false
         };

-        // first tries environment variables and then the config files
-        let sdk_config = PG_BACKEND_TOKIO_RUNTIME
-            .block_on(async { aws_config::defaults(BehaviorVersion::latest()).load().await });
-
-        let mut access_key_id = None;
-        let mut secret_access_key = None;
-        let mut session_token = None;
+        // Session variables (GUCs) have the highest priority.
+        let mut access_key_id = AWS_ACCESS_KEY_ID
+            .get()
+            .map(|s| s.to_string_lossy().into_owned());
+        let mut secret_access_key = AWS_SECRET_ACCESS_KEY
+            .get()
+            .map(|s| s.to_string_lossy().into_owned());
+        let mut session_token = AWS_SESSION_TOKEN
+            .get()
+            .map(|s| s.to_string_lossy().into_owned());
+        let mut endpoint_url = AWS_ENDPOINT_URL
+            .get()
+            .map(|s| s.to_string_lossy().into_owned());
+        let mut region = AWS_REGION.get().map(|s| s.to_string_lossy().into_owned());
+
+        // TODO: add credential expiry handling when using session variables.
         let mut expire_at = None;

-        if let Some(credential_provider) = sdk_config.credentials_provider() {
-            let cred_res = PG_BACKEND_TOKIO_RUNTIME
-                .block_on(async { credential_provider.provide_credentials().await });
-
-            if let Ok(credentials) = cred_res {
-                access_key_id = Some(credentials.access_key_id().to_string());
-                secret_access_key = Some(credentials.secret_access_key().to_string());
-                session_token = credentials.session_token().map(|t| t.to_string());
-                expire_at = credentials.expiry();
-            } else {
-                pgrx::error!(
-                    "failed to load aws credentials: {:?}",
-                    cred_res.unwrap_err()
-                );
-            }
-        }
-
-        let endpoint_url = sdk_config.endpoint_url().map(|u| u.to_string());
-
-        let region = sdk_config.region().map(|r| r.as_ref().to_string());
+        let missing_credentials = access_key_id.is_none() || secret_access_key.is_none();
+
+        // Anything not provided via GUCs falls back to the AWS SDK config, which resolves
+        // environment variables first and config files second. It is loaded at most once.
+        if missing_credentials || region.is_none() || endpoint_url.is_none() {
+            let sdk_config = PG_BACKEND_TOKIO_RUNTIME
+                .block_on(async { aws_config::defaults(BehaviorVersion::latest()).load().await });
+
+            if missing_credentials {
+                if let Some(credential_provider) = sdk_config.credentials_provider() {
+                    let cred_res = PG_BACKEND_TOKIO_RUNTIME
+                        .block_on(async { credential_provider.provide_credentials().await });
+
+                    match cred_res {
+                        Ok(credentials) => {
+                            if access_key_id.is_none() {
+                                access_key_id = Some(credentials.access_key_id().to_string());
+                            }
+                            if secret_access_key.is_none() {
+                                secret_access_key =
+                                    Some(credentials.secret_access_key().to_string());
+                            }
+                            if session_token.is_none() {
+                                session_token =
+                                    credentials.session_token().map(|t| t.to_string());
+                            }
+                            expire_at = credentials.expiry();
+                        }
+                        Err(err) => pgrx::error!("failed to load aws credentials: {:?}", err),
+                    }
+                }
+            }
+
+            if region.is_none() {
+                region = sdk_config.region().map(|r| r.to_string());
+            }
+
+            // Use the SDK config rather than only AWS_ENDPOINT_URL so that an endpoint set in
+            // the config file keeps working.
+            if endpoint_url.is_none() {
+                endpoint_url = sdk_config.endpoint_url().map(|u| u.to_string());
+            }
+        }

         Self {
             region,
diff --git a/src/object_store/gcs.rs b/src/object_store/gcs.rs
index e68b41a..8f91d34 100644
--- a/src/object_store/gcs.rs
+++ b/src/object_store/gcs.rs
@@ -3,6 +3,8 @@ use std::sync::Arc;
 use object_store::gcp::GoogleCloudStorageBuilder;
 use url::Url;

+use crate::{GOOGLE_SERVICE_ACCOUNT_KEY, GOOGLE_SERVICE_ACCOUNT_PATH};
+
 use super::object_store_cache::ObjectStoreWithExpiration;

 // create_gcs_object_store a GoogleCloudStorage object store from given uri.
@@ -61,8 +63,19 @@ struct GoogleStorageConfig {
 impl GoogleStorageConfig {
-    // load loads the Google Storage configuration from the environment.
+    // load reads the Google Storage configuration from GUCs first and falls back to the
+    // environment variables for anything that is not set.
     fn load() -> Self {
+        let service_account_key = GOOGLE_SERVICE_ACCOUNT_KEY
+            .get()
+            .map(|s| s.to_string_lossy().into_owned())
+            .or_else(|| std::env::var("GOOGLE_SERVICE_ACCOUNT_KEY").ok());
+
+        let service_account_path = GOOGLE_SERVICE_ACCOUNT_PATH
+            .get()
+            .map(|s| s.to_string_lossy().into_owned())
+            .or_else(|| std::env::var("GOOGLE_SERVICE_ACCOUNT_PATH").ok());
+
         Self {
-            service_account_key: std::env::var("GOOGLE_SERVICE_ACCOUNT_KEY").ok(),
-            service_account_path: std::env::var("GOOGLE_SERVICE_ACCOUNT_PATH").ok(),
+            service_account_key,
+            service_account_path,
         }
     }
diff --git a/src/pgrx_tests/object_store.rs b/src/pgrx_tests/object_store.rs
index 61d91cd..79098a7 100644
--- a/src/pgrx_tests/object_store.rs
+++ b/src/pgrx_tests/object_store.rs
@@ -717,6 +717,257 @@ mod tests {
         Spi::run("copy test_table from 's3://testbucket/dummy.csv';").unwrap();
     }

+    #[pg_test]
+    fn test_s3_from_guc_credentials() {
+        object_store_cache_clear();
+
+        let test_bucket_name: String =
+            std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found");
+
+        // Get credentials from environment variables
+        let access_key_id =
+            std::env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID not found");
+        let secret_access_key =
+            std::env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY not found");
+        let region = std::env::var("AWS_REGION").expect("AWS_REGION not found");
+        let endpoint_url = std::env::var("AWS_ENDPOINT_URL").expect("AWS_ENDPOINT_URL not found");
+
+        // Remove environment
variables to ensure GUCs are used
+        std::env::remove_var("AWS_ACCESS_KEY_ID");
+        std::env::remove_var("AWS_SECRET_ACCESS_KEY");
+        std::env::remove_var("AWS_REGION");
+        std::env::remove_var("AWS_ENDPOINT_URL");
+
+        // Set credentials via GUCs
+        Spi::run(&format!(
+            "SET pg_parquet.aws_access_key_id = '{access_key_id}';"
+        ))
+        .unwrap();
+        Spi::run(&format!(
+            "SET pg_parquet.aws_secret_access_key = '{secret_access_key}';"
+        ))
+        .unwrap();
+        Spi::run(&format!("SET pg_parquet.aws_region = '{region}';")).unwrap();
+        Spi::run(&format!(
+            "SET pg_parquet.aws_endpoint_url = '{endpoint_url}';"
+        ))
+        .unwrap();
+
+        let s3_uri = format!("s3://{test_bucket_name}/pg_parquet_test.parquet");
+
+        let test_table = TestTable::<i32>::new("int4".into()).with_uri(s3_uri);
+
+        test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);");
+        test_table.assert_expected_and_result_rows();
+
+        // Restore environment variables
+        std::env::set_var("AWS_ACCESS_KEY_ID", access_key_id);
+        std::env::set_var("AWS_SECRET_ACCESS_KEY", secret_access_key);
+        std::env::set_var("AWS_REGION", region);
+        std::env::set_var("AWS_ENDPOINT_URL", endpoint_url);
+    }
+
+    #[pg_test]
+    fn test_s3_guc_priority_over_env() {
+        object_store_cache_clear();
+
+        let test_bucket_name: String =
+            std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found");
+
+        // Get real credentials
+        let real_access_key_id =
+            std::env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID not found");
+        let real_secret_access_key =
+            std::env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY not found");
+        let real_region = std::env::var("AWS_REGION").expect("AWS_REGION not found");
+        let real_endpoint_url =
+            std::env::var("AWS_ENDPOINT_URL").expect("AWS_ENDPOINT_URL not found");
+
+        // Set wrong credentials in environment variables
+        std::env::set_var("AWS_ACCESS_KEY_ID", "wrong_access_key");
+        std::env::set_var("AWS_SECRET_ACCESS_KEY", "wrong_secret_key");
+        std::env::set_var("AWS_REGION", "wrong-region");
+        std::env::set_var("AWS_ENDPOINT_URL", "http://wrong-endpoint");
+
+        // Set correct credentials via GUCs (should take priority)
+        Spi::run(&format!(
+            "SET pg_parquet.aws_access_key_id = '{real_access_key_id}';"
+        ))
+        .unwrap();
+        Spi::run(&format!(
+            "SET pg_parquet.aws_secret_access_key = '{real_secret_access_key}';"
+        ))
+        .unwrap();
+        Spi::run(&format!("SET pg_parquet.aws_region = '{real_region}';")).unwrap();
+        Spi::run(&format!(
+            "SET pg_parquet.aws_endpoint_url = '{real_endpoint_url}';"
+        ))
+        .unwrap();
+
+        let s3_uri = format!("s3://{test_bucket_name}/pg_parquet_test.parquet");
+
+        let test_table = TestTable::<i32>::new("int4".into()).with_uri(s3_uri);
+
+        test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);");
+        test_table.assert_expected_and_result_rows();
+
+        // Restore original environment variables
+        std::env::set_var("AWS_ACCESS_KEY_ID", real_access_key_id);
+        std::env::set_var("AWS_SECRET_ACCESS_KEY", real_secret_access_key);
+        std::env::set_var("AWS_REGION", real_region);
+        std::env::set_var("AWS_ENDPOINT_URL", real_endpoint_url);
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "403 Forbidden")]
+    fn test_s3_guc_wrong_credentials() {
+        object_store_cache_clear();
+
+        let test_bucket_name: String =
+            std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found");
+
+        // Remove environment variables
+        std::env::remove_var("AWS_ACCESS_KEY_ID");
+        std::env::remove_var("AWS_SECRET_ACCESS_KEY");
+        std::env::remove_var("AWS_REGION");
+        std::env::remove_var("AWS_ENDPOINT_URL");
+
+        // Set wrong credentials via GUCs
+        Spi::run("SET pg_parquet.aws_access_key_id = 'wrong_access_key';").unwrap();
+        Spi::run("SET pg_parquet.aws_secret_access_key = 'wrong_secret_key';").unwrap();
+        Spi::run("SET pg_parquet.aws_region = 'us-east-1';").unwrap();
+        Spi::run("SET pg_parquet.aws_endpoint_url = 'https://s3.amazonaws.com';").unwrap();
+
+        let s3_uri = format!("s3://{test_bucket_name}/pg_parquet_test.parquet");
+
+        let copy_to_command =
+            format!("COPY (SELECT i FROM generate_series(1,10) i) TO '{s3_uri}';");
+        Spi::run(copy_to_command.as_str()).unwrap();
+    }
+
+    #[pg_test]
+    fn test_gcs_from_guc_service_account_key() {
+        object_store_cache_clear();
+
+        let test_bucket_name: String =
+            std::env::var("GOOGLE_TEST_BUCKET").expect("GOOGLE_TEST_BUCKET not found");
+
+        // Get service account key from environment variable
+        let service_account_key = std::env::var("GOOGLE_SERVICE_ACCOUNT_KEY")
+            .expect("GOOGLE_SERVICE_ACCOUNT_KEY not found");
+
+        // Remove environment variable to ensure GUC is used
+        std::env::remove_var("GOOGLE_SERVICE_ACCOUNT_KEY");
+
+        // Set service account key via GUC
+        Spi::run(&format!(
+            "SET pg_parquet.google_service_account_key = '{service_account_key}';"
+        ))
+        .unwrap();
+
+        let gcs_uri = format!("gs://{test_bucket_name}/pg_parquet_test.parquet");
+
+        let test_table = TestTable::<i32>::new("int4".into()).with_uri(gcs_uri);
+
+        test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);");
+        test_table.assert_expected_and_result_rows();
+
+        // Restore environment variable
+        std::env::set_var("GOOGLE_SERVICE_ACCOUNT_KEY", service_account_key);
+    }
+
+    #[pg_test]
+    fn test_gcs_guc_priority_over_env() {
+        object_store_cache_clear();
+
+        let test_bucket_name: String =
+            std::env::var("GOOGLE_TEST_BUCKET").expect("GOOGLE_TEST_BUCKET not found");
+
+        // Get real service account key
+        let real_service_account_key = std::env::var("GOOGLE_SERVICE_ACCOUNT_KEY")
+            .expect("GOOGLE_SERVICE_ACCOUNT_KEY not found");
+
+        // Set wrong service account key in environment variable
+        std::env::set_var("GOOGLE_SERVICE_ACCOUNT_KEY", "wrong_service_account_key");
+
+        // Set correct service account key via GUC (should take priority)
+        Spi::run(&format!(
+            "SET pg_parquet.google_service_account_key = '{real_service_account_key}';"
+        ))
+        .unwrap();
+
+        let gcs_uri = format!("gs://{test_bucket_name}/pg_parquet_test.parquet");
+
+        let test_table = TestTable::<i32>::new("int4".into()).with_uri(gcs_uri);
+
+        test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);");
+        test_table.assert_expected_and_result_rows();
+
+        // Restore original environment variable
+        std::env::set_var("GOOGLE_SERVICE_ACCOUNT_KEY", real_service_account_key);
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "404 Not Found")]
+    fn test_gcs_guc_wrong_credentials() {
+        object_store_cache_clear();
+
+        // Remove environment variables
+        std::env::remove_var("GOOGLE_SERVICE_ACCOUNT_KEY");
+        std::env::remove_var("GOOGLE_SERVICE_ACCOUNT_PATH");
+
+        // Set wrong service account key via GUC
+        Spi::run("SET pg_parquet.google_service_account_key = '{\"type\": \"service_account\", \"project_id\": \"wrong-project\"}';").unwrap();
+
+        let gcs_uri = "gs://randombucketwhichdoesnotexist/pg_parquet_test.parquet";
+
+        let copy_to_command =
+            format!("COPY (SELECT i FROM generate_series(1,10) i) TO '{gcs_uri}';");
+        Spi::run(copy_to_command.as_str()).unwrap();
+    }
+
+    #[pg_test]
+    fn test_gcs_guc_debug() {
+        // Test that GUCs are being read correctly
+        let test_key = "{\"type\": \"service_account\", \"project_id\": \"test-project\"}";
+        let test_path = "/tmp/test-service-account.json";
+
+        // Set GUCs
+        Spi::run(&format!(
+            "SET pg_parquet.google_service_account_key = '{test_key}';"
+        ))
+        .unwrap();
+        Spi::run(&format!(
+            "SET pg_parquet.google_service_account_path = '{test_path}';"
+        ))
+        .unwrap();
+
+        // Remove environment variables (remembering their values) to ensure GUCs are used
+        let env_names = ["GOOGLE_SERVICE_ACCOUNT_KEY", "GOOGLE_SERVICE_ACCOUNT_PATH"];
+        let saved_env = env_names.map(|name| (name, std::env::var(name).ok()));
+        env_names.into_iter().for_each(std::env::remove_var);
+
+        // Test that we can read the GUCs
+        let result = Spi::get_one::<String>(
+            "SELECT current_setting('pg_parquet.google_service_account_key');",
+        )
+        .unwrap();
+        assert_eq!(result, Some(test_key.to_string()));
+
+        let result = Spi::get_one::<String>(
+            "SELECT current_setting('pg_parquet.google_service_account_path');",
+        )
+        .unwrap();
+        assert_eq!(result, Some(test_path.to_string()));
+
+        // Restore the environment variables that were set before the test
+        for (name, value) in saved_env {
+            if let Some(value) = value {
+                std::env::set_var(name, value);
+            }
+        }
+    }
+
     #[pg_test]
     #[cfg(not(rhel8))]
     fn test_object_store_cache() {