Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 183 additions & 4 deletions crates/catalog-unity/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,12 @@ impl UnityCatalogBuilder {
}
}

/// Returns the storage location and temporary token to be used with the
/// Unity Catalog table.
/// Returns the storage location and temporary token for the Unity Catalog table.
///
/// If storage options are provided, they override environment variables for authentication.
pub async fn get_uc_location_and_token(
table_uri: &str,
storage_options: Option<&HashMap<String, String>>,
) -> Result<(String, HashMap<String, String>), UnityCatalogError> {
let uri_parts: Vec<&str> = table_uri[5..].split('.').collect();
if uri_parts.len() != 3 {
Expand All @@ -551,7 +553,15 @@ impl UnityCatalogBuilder {
let database_name = uri_parts[1];
let table_name = uri_parts[2];

let unity_catalog = UnityCatalogBuilder::from_env().build()?;
let unity_catalog = if let Some(options) = storage_options {
let mut builder = UnityCatalogBuilder::from_env();
builder =
builder.try_with_options(options.iter().map(|(k, v)| (k.as_str(), v.as_str())))?;
builder.build()?
} else {
UnityCatalogBuilder::from_env().build()?
};

let storage_location = unity_catalog
.get_table_storage_location(Some(catalog_id.to_string()), database_name, table_name)
.await?;
Expand Down Expand Up @@ -845,7 +855,7 @@ impl ObjectStoreFactory for UnityCatalogFactory {
config: &StorageConfig,
) -> DeltaResult<(ObjectStoreRef, Path)> {
let (table_path, temp_creds) = UnityCatalogBuilder::execute_uc_future(
UnityCatalogBuilder::get_uc_location_and_token(table_uri.as_str()),
UnityCatalogBuilder::get_uc_location_and_token(table_uri.as_str(), Some(&config.raw)),
)??;

let mut storage_options = config.raw.clone();
Expand Down Expand Up @@ -936,6 +946,7 @@ mod tests {
use crate::UnityCatalogBuilder;
use deltalake_core::DataCatalog;
use httpmock::prelude::*;
use std::collections::HashMap;

#[tokio::test]
async fn test_unity_client() {
Expand Down Expand Up @@ -1003,4 +1014,172 @@ mod tests {
.unwrap();
assert!(storage_location.eq_ignore_ascii_case("string"));
}

#[test]
fn test_unitycatalogbuilder_with_storage_options() {
let mut storage_options = HashMap::new();
storage_options.insert(
"databricks_host".to_string(),
"https://test.databricks.com".to_string(),
);
storage_options.insert("databricks_token".to_string(), "test_token".to_string());

let builder = UnityCatalogBuilder::new()
.try_with_options(&storage_options)
.unwrap();

assert_eq!(
builder.workspace_url,
Some("https://test.databricks.com".to_string())
);
assert_eq!(builder.bearer_token, Some("test_token".to_string()));
}

#[test]
fn test_unitycatalogbuilder_client_credentials() {
let mut storage_options = HashMap::new();
storage_options.insert(
"databricks_host".to_string(),
"https://test.databricks.com".to_string(),
);
storage_options.insert("unity_client_id".to_string(), "test_client_id".to_string());
storage_options.insert("unity_client_secret".to_string(), "test_secret".to_string());
storage_options.insert("unity_authority_id".to_string(), "test_tenant".to_string());

let builder = UnityCatalogBuilder::new()
.try_with_options(&storage_options)
.unwrap();

assert_eq!(
builder.workspace_url,
Some("https://test.databricks.com".to_string())
);
assert_eq!(builder.client_id, Some("test_client_id".to_string()));
assert_eq!(builder.client_secret, Some("test_secret".to_string()));
assert_eq!(builder.authority_id, Some("test_tenant".to_string()));
}

#[test]
fn test_env_with_storage_options_override() {
std::env::set_var("DATABRICKS_HOST", "https://env.databricks.com");
std::env::set_var("DATABRICKS_TOKEN", "env_token");

let mut storage_options = HashMap::new();
storage_options.insert(
"databricks_host".to_string(),
"https://override.databricks.com".to_string(),
);

let builder = UnityCatalogBuilder::from_env()
.try_with_options(&storage_options)
.unwrap();

assert_eq!(
builder.workspace_url,
Some("https://override.databricks.com".to_string())
);
assert_eq!(builder.bearer_token, Some("env_token".to_string()));

std::env::remove_var("DATABRICKS_HOST");
std::env::remove_var("DATABRICKS_TOKEN");
}

#[test]
fn test_storage_options_key_variations() {
let test_cases = vec![
("databricks_host", "workspace_url"),
("unity_workspace_url", "workspace_url"),
("databricks_workspace_url", "workspace_url"),
("databricks_token", "bearer_token"),
("token", "bearer_token"),
("unity_client_id", "client_id"),
("databricks_client_id", "client_id"),
("client_id", "client_id"),
];

for (key, field) in test_cases {
let mut storage_options = HashMap::new();
let test_value = format!("test_value_for_{}", key);
storage_options.insert(key.to_string(), test_value.clone());

let result = UnityCatalogBuilder::new().try_with_options(&storage_options);
assert!(result.is_ok(), "Failed to parse key: {}", key);

let builder = result.unwrap();
match field {
"workspace_url" => assert_eq!(builder.workspace_url, Some(test_value)),
"bearer_token" => assert_eq!(builder.bearer_token, Some(test_value)),
"client_id" => assert_eq!(builder.client_id, Some(test_value)),
_ => {}
}
}
}

#[test]
fn test_invalid_config_key() {
let mut storage_options = HashMap::new();
storage_options.insert("invalid_key".to_string(), "test_value".to_string());

let result = UnityCatalogBuilder::new().try_with_options(&storage_options);
assert!(result.is_err());
}

#[test]
fn test_boolean_options() {
let test_cases = vec![
("true", true),
("false", false),
("1", true),
("0", false),
("yes", true),
("no", false),
];

for (value, expected) in test_cases {
let mut storage_options = HashMap::new();
storage_options.insert("unity_allow_http_url".to_string(), value.to_string());
storage_options.insert("unity_use_azure_cli".to_string(), value.to_string());

let builder = UnityCatalogBuilder::new()
.try_with_options(&storage_options)
.unwrap();

assert_eq!(
builder.allow_http_url, expected,
"Failed for value: {}",
value
);
assert_eq!(
builder.use_azure_cli, expected,
"Failed for value: {}",
value
);
}
}

#[tokio::test]
async fn test_invalid_table_uri() {
let test_cases = vec![
"uc://invalid",
"uc://",
"uc://catalog",
"uc://catalog.schema",
"uc://catalog.schema.table.extra",
"invalid://catalog.schema.table",
];

for uri in test_cases {
let result = UnityCatalogBuilder::get_uc_location_and_token(uri, None).await;
assert!(result.is_err(), "Expected error for URI: {}", uri);

if let Err(e) = result {
if uri.starts_with("uc://") && uri.len() > 5 {
assert!(matches!(
e,
crate::UnityCatalogError::InvalidTableURI { .. }
));
}
}
}
}
}
Loading