Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
433 changes: 0 additions & 433 deletions helium_iceberg/src/catalog.rs

This file was deleted.

171 changes: 171 additions & 0 deletions helium_iceberg/src/catalog/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
//! OAuth2 authentication for the Iceberg REST catalog.

use crate::{option_hashmap, Error, Result, Settings};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

/// Fallback token lifetime, in seconds, used when the OAuth2 server's response
/// carries no `expires_in` value (one hour).
const DEFAULT_TOKEN_LIFETIME_SECS: u64 = 60 * 60;

/// How long before its recorded expiry a cached token stops being reused, so
/// that a replacement is fetched ahead of time.
const TOKEN_REFRESH_MARGIN: Duration = Duration::from_secs(60);

/// OAuth2 token response from the token endpoint.
///
/// Only the two fields declared here are captured from the response body.
#[derive(serde::Deserialize)]
struct TokenResponse {
/// The issued access token string.
access_token: String,
/// Token lifetime in seconds (standard OAuth2 response field).
///
/// `None` when the server leaves it out; `fetch_token` then substitutes
/// `DEFAULT_TOKEN_LIFETIME_SECS`.
expires_in: Option<u64>,
}

/// A cached OAuth2 token with expiration tracking.
pub(super) struct CachedToken {
/// The access token string handed back by `EndpointAuth::get_token`.
pub(super) token: String,
/// Instant at which the token is considered expired: the time it was stored
/// plus the lifetime reported for it.
pub(super) expires_at: Instant,
}

/// Authentication strategy for direct REST API calls.
///
/// The `OAuth2` variant keeps its token cache behind an `Arc`, so cloning an
/// `EndpointAuth` yields a value that shares the same cached token.
#[derive(Clone)]
pub(crate) enum EndpointAuth {
/// No token is supplied; `get_token` yields `Ok(None)`.
None,
/// A fixed token taken from settings, returned unchanged by `get_token`.
Token(String),
/// Client-credentials grant against a token endpoint, with the fetched token cached.
OAuth2 {
/// URL the token request is POSTed to.
token_endpoint: String,
/// Raw credential string; split on its first `:` when a token is requested.
credential: String,
/// Additional form fields sent with every token request.
extra_params: HashMap<String, String>,
/// Most recently fetched token; `None` until the first fetch and again
/// after `invalidate`.
cached_token: Arc<Mutex<Option<CachedToken>>>,
},
}

impl EndpointAuth {
/// Build from settings, determining the auth strategy.
pub(crate) fn from_settings(settings: &Settings) -> Self {
if let Some(ref credential) = settings.auth.credential {
let token_endpoint = settings
.auth
.oauth2_server_uri
.clone()
.unwrap_or_else(|| format!("{}/v1/oauth/tokens", settings.catalog_uri));

let extra_params = option_hashmap! {
"scope" => settings.auth.scope,
"audience" => settings.auth.audience,
"resource" => settings.auth.resource,
};

Self::OAuth2 {
token_endpoint,
credential: credential.clone(),
extra_params,
cached_token: Arc::new(Mutex::new(None)),
}
} else if let Some(ref token) = settings.auth.token {
Self::Token(token.clone())
} else {
Self::None
}
}

/// Fetch a fresh OAuth2 token from the token endpoint.
///
/// Sends a form-encoded `client_credentials` grant to `token_endpoint`, adding
/// every entry of `extra_params` as a further form field.
///
/// `credential` is split on its first `:` into client id and client secret. A
/// credential without a `:` is sent as the client secret alone, with no
/// `client_id` field, which is how the Iceberg REST clients interpret the
/// `credential` property.
///
/// Returns the access token and its lifetime in seconds. When the response has
/// no `expires_in`, `DEFAULT_TOKEN_LIFETIME_SECS` is reported and a warning is
/// logged.
///
/// # Errors
///
/// Returns `Error::Catalog` when the request cannot be sent, when the endpoint
/// answers with a non-success status (the response body is included in the
/// message), or when the body cannot be parsed as a token response.
async fn fetch_token(
    client: &reqwest::Client,
    token_endpoint: &str,
    credential: &str,
    extra_params: &HashMap<String, String>,
) -> Result<(String, u64)> {
    // "id:secret" carries both parts; a bare value is the secret only.
    let (client_id, client_secret) = match credential.split_once(':') {
        Some((id, secret)) => (Some(id), secret),
        None => (None, credential),
    };

    let mut params: Vec<(&str, &str)> = Vec::with_capacity(3 + extra_params.len());
    params.push(("grant_type", "client_credentials"));
    if let Some(client_id) = client_id {
        params.push(("client_id", client_id));
    }
    params.push(("client_secret", client_secret));
    params.extend(extra_params.iter().map(|(k, v)| (k.as_str(), v.as_str())));

    let response = client
        .post(token_endpoint)
        .form(&params)
        .send()
        .await
        .map_err(|e| Error::Catalog(format!("OAuth2 token request failed: {e}")))?;

    if !response.status().is_success() {
        let status = response.status();
        // Best effort: an unreadable body still yields an error carrying the status.
        let body = response.text().await.unwrap_or_default();
        return Err(Error::Catalog(format!(
            "OAuth2 token request returned {status}: {body}"
        )));
    }

    let token_response: TokenResponse = response
        .json()
        .await
        .map_err(|e| Error::Catalog(format!("failed to parse OAuth2 token response: {e}")))?;

    let expires_in = token_response.expires_in.unwrap_or_else(|| {
        tracing::warn!(
            default_secs = DEFAULT_TOKEN_LIFETIME_SECS,
            "no expires_in in token response, using default"
        );
        DEFAULT_TOKEN_LIFETIME_SECS
    });

    Ok((token_response.access_token, expires_in))
}

/// Get a valid token, using the cache if available and not expired.
///
/// Proactively refreshes the token if it's within the refresh margin of expiration.
pub(crate) async fn get_token(&self, client: &reqwest::Client) -> Result<Option<String>> {
match self {
Self::None => Ok(None),
Self::Token(token) => Ok(Some(token.clone())),
Self::OAuth2 {
token_endpoint,
credential,
extra_params,
cached_token,
} => {
let mut guard = cached_token.lock().await;

// Return cached token if it's still valid (with safety margin)
if let Some(ref cached) = *guard {
let now = Instant::now();
if now + TOKEN_REFRESH_MARGIN < cached.expires_at {
return Ok(Some(cached.token.clone()));
}
tracing::debug!("token expiring soon, proactively refreshing");
}

// Fetch new token
let (token, expires_in) =
Self::fetch_token(client, token_endpoint, credential, extra_params).await?;

let expires_at = Instant::now() + Duration::from_secs(expires_in);
*guard = Some(CachedToken {
token: token.clone(),
expires_at,
});

tracing::debug!(expires_in_secs = expires_in, "fetched new OAuth2 token");
Ok(Some(token))
}
}
}

/// Drop the cached OAuth2 token so the next `get_token` call fetches a new one.
///
/// Intended for use after a 401 response. Does nothing for the `None` and
/// `Token` strategies, which have no cache.
pub(crate) async fn invalidate(&self) {
    match self {
        Self::OAuth2 { cached_token, .. } => {
            *cached_token.lock().await = None;
        }
        Self::None | Self::Token(_) => {}
    }
}
}
212 changes: 212 additions & 0 deletions helium_iceberg/src/catalog/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
//! Iceberg REST catalog client with automatic authentication and retry handling.
//!
//! This module provides the [`Catalog`] type, a shareable wrapper around the
//! Iceberg REST catalog that handles:
//!
//! - OAuth2 token acquisition and caching
//! - Proactive token refresh before expiration
//! - Automatic retry on 401 authentication errors
//! - Direct REST API access for operations not supported by the iceberg crate

mod auth;
mod rest_endpoint;

use rest_endpoint::RestEndpoint;

use crate::error::IntoHeliumIcebergError;
use crate::{Error, IcebergTable, Result, Settings, TableCreator, TableDefinition};
use iceberg::table::Table;
use iceberg::{Catalog as IcebergCatalog, CatalogBuilder, NamespaceIdent, TableIdent};
use iceberg_catalog_rest::{
CommitTableRequest, RestCatalog, RestCatalogBuilder, REST_CATALOG_PROP_URI,
REST_CATALOG_PROP_WAREHOUSE,
};
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;

/// Decide whether an iceberg error looks like an authentication failure
/// (401 Unauthorized).
///
/// This is a text heuristic: the error's `Debug` rendering is searched for
/// `401`, `Unauthorized`, or `unauthorized`, so any error whose rendering
/// happens to contain one of those substrings is also reported as an auth error.
fn is_iceberg_auth_error(error: &iceberg::Error) -> bool {
    const AUTH_MARKERS: [&str; 3] = ["401", "Unauthorized", "unauthorized"];

    let rendered = format!("{error:?}");
    AUTH_MARKERS.iter().any(|marker| rendered.contains(*marker))
}

/// A shareable Iceberg REST catalog client with automatic retry on auth errors.
///
/// Since `RestCatalog` from the iceberg crate is not `Clone`, this struct wraps it
/// in an `Arc` to enable sharing a single catalog connection across multiple
/// components like `IcebergTable` and `TableCreator`.
///
/// Implements `iceberg::Catalog` with automatic 401 retry logic - when an operation
/// fails with an authentication error, the cached token is invalidated and the
/// operation is retried once.
///
/// Also provides direct REST API access for branch operations that cannot go
/// through the `Transaction` API due to `TableCommit::builder().build()` being
/// `pub(crate)` in iceberg 0.8.
#[derive(Clone)]
pub struct Catalog {
/// The iceberg REST catalog; held in an `Arc` so clones of `Catalog` share it.
inner: Arc<RestCatalog>,
/// Direct REST access used by `commit_table_request`; its `auth` holds the
/// token cache that is invalidated when an operation hits an auth error.
endpoint: RestEndpoint,
}

impl AsRef<RestCatalog> for Catalog {
fn as_ref(&self) -> &RestCatalog {
&self.inner
}
}

/// Formats as a `Catalog` struct showing both the `inner` and `endpoint` fields.
impl std::fmt::Debug for Catalog {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self { inner, endpoint } = self;

        let mut out = f.debug_struct("Catalog");
        out.field("inner", inner);
        out.field("endpoint", endpoint);
        out.finish()
    }
}

impl Catalog {
async fn with_auth<F, Fut, T>(&self, mut f: F) -> iceberg::Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = iceberg::Result<T>>,
{
match f().await {
Ok(result) => Ok(result),
Err(e) if is_iceberg_auth_error(&e) => {
tracing::warn!("auth error, invaliding token and retrying");
self.endpoint.auth.invalidate().await;
f().await
}
Err(e) => Err(e),
}
}

/// Connect to an Iceberg REST catalog using the provided settings.
pub async fn connect(settings: &Settings) -> Result<Self> {
let mut config = HashMap::new();
config.insert(
REST_CATALOG_PROP_URI.to_string(),
settings.catalog_uri.clone(),
);
if let Some(ref warehouse) = settings.warehouse {
config.insert(REST_CATALOG_PROP_WAREHOUSE.to_string(), warehouse.clone());
}
config.extend(settings.auth.props());
config.extend(settings.s3.props());
config.extend(settings.properties.clone());

let rest_catalog = RestCatalogBuilder::default()
.load(&settings.catalog_name, config)
.await
.map_err(Error::Iceberg)?;

let endpoint = RestEndpoint::resolve(settings).await?;

Ok(Self {
inner: Arc::new(rest_catalog),
endpoint,
})
}

/// Report whether `table_name` exists in `namespace`.
///
/// The lookup goes through `with_auth`, so it is retried once after an auth
/// error.
pub async fn table_exists(
    &self,
    namespace: impl Into<String>,
    table_name: impl Into<String>,
) -> Result<bool> {
    let ident = TableIdent::new(NamespaceIdent::new(namespace.into()), table_name.into());

    let outcome = self.with_auth(|| self.inner.table_exists(&ident)).await;
    outcome.err_into()
}

/// Load a table from the catalog.
pub async fn load_table(
&self,
namespace: impl Into<String>,
table_name: impl Into<String>,
) -> Result<Table> {
let namespace = namespace.into();
let table_name = table_name.into();
let namespace_ident = NamespaceIdent::new(namespace.clone());
let table_ident = TableIdent::new(namespace_ident, table_name.clone());

self.with_auth(|| self.inner.load_table(&table_ident))
.await
.map_err(|e: iceberg::Error| match e.kind() {
iceberg::ErrorKind::DataInvalid => Error::TableNotFound {
namespace,
table: table_name,
},
_ => Error::Iceberg(e),
})
}

/// Ensure `namespace` exists, creating it with no properties when it is absent.
///
/// The existence check and the creation are two separate catalog calls, each
/// retried once after an auth error. If the namespace appears between the two
/// calls, the error from the create call is returned.
pub async fn create_namespace_if_not_exists(&self, namespace: impl Into<String>) -> Result<()> {
    let ident = NamespaceIdent::new(namespace.into());

    let already_present = self
        .with_auth(|| self.inner.namespace_exists(&ident))
        .await?;
    if already_present {
        return Ok(());
    }

    self.with_auth(|| self.inner.create_namespace(&ident, HashMap::new()))
        .await?;
    Ok(())
}

/// Create a table if it does not already exist.
///
/// Delegates to a `TableCreator` built from a clone of this catalog and
/// returns whatever its `create_table_if_not_exists` produces for `table_def`.
pub async fn create_table_if_not_exists<T>(
&self,
table_def: TableDefinition,
) -> Result<IcebergTable<T>> {
TableCreator::new(self.clone())
.create_table_if_not_exists(table_def)
.await
}

/// Send a commit table request directly to the REST catalog API.
///
/// This bypasses `TableCommit` (whose builder is `pub(crate)` in iceberg 0.8)
/// by constructing and sending a `CommitTableRequest` via HTTP POST.
///
/// The call is forwarded to the REST endpoint as-is. Unlike the other catalog
/// operations in this impl, it does not go through `with_auth`.
pub(crate) async fn commit_table_request(&self, request: &CommitTableRequest) -> Result<()> {
self.endpoint.commit_table(request).await
}

/// List the catalog's namespaces, with no parent filter applied.
///
/// The listing goes through `with_auth`, so it is retried once after an auth
/// error.
pub async fn list_namespaces(&self) -> Result<Vec<NamespaceIdent>> {
    let listed = self.with_auth(|| self.inner.list_namespaces(None)).await;
    listed.err_into()
}

/// List the namespaces nested under `parent`.
///
/// The listing goes through `with_auth`, so it is retried once after an auth
/// error.
pub async fn list_namespaces_under(
    &self,
    parent: impl Into<String>,
) -> Result<Vec<NamespaceIdent>> {
    let parent = NamespaceIdent::new(parent.into());

    let listed = self
        .with_auth(|| self.inner.list_namespaces(Some(&parent)))
        .await;
    listed.err_into()
}

/// List the tables contained in `namespace`.
///
/// The listing goes through `with_auth`, so it is retried once after an auth
/// error.
pub async fn list_tables(&self, namespace: impl Into<String>) -> Result<Vec<TableIdent>> {
    let ident = NamespaceIdent::new(namespace.into());

    let listed = self.with_auth(|| self.inner.list_tables(&ident)).await;
    listed.err_into()
}
}
Loading
Loading