Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,37 @@ const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
/// A factory object for Postgres clients. This abstracts
/// details of client creation such as pooling.
#[async_trait]
pub trait ClientFactory: Default + Send + Sync + 'static {
pub trait ClientFactory: Send + Sync + 'static {
/// The type of client produced by `get_client`.
type Client: Client;
fn new(root_certificates: Vec<Vec<u8>>) -> Self;
/// Gets a client from the factory.
async fn get_client(&self, address: &str) -> Result<Self::Client>;
}

/// A `ClientFactory` that uses a connection pool per address.
pub struct PooledTokioClientFactory {
pools: moka::sync::Cache<String, deadpool_postgres::Pool>,
root_certificates: Vec<Vec<u8>>,
}

impl Default for PooledTokioClientFactory {
fn default() -> Self {
#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;

fn new(root_certificates: Vec<Vec<u8>>) -> Self {
Self {
pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
root_certificates,
}
}
}

#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;

async fn get_client(&self, address: &str) -> Result<Self::Client> {
let pool = self
.pools
.try_get_with_by_ref(address, || create_connection_pool(address))
.try_get_with_by_ref(address, || {
create_connection_pool(address, &self.root_certificates)
})
.map_err(ArcError)
.context("establishing PostgreSQL connection pool")?;

Expand All @@ -54,7 +57,10 @@ impl ClientFactory for PooledTokioClientFactory {
}

/// Creates a Postgres connection pool for the given address.
fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
fn create_connection_pool(
address: &str,
root_certificates: &[Vec<u8>],
) -> Result<deadpool_postgres::Pool> {
let config = address
.parse::<tokio_postgres::Config>()
.context("parsing Postgres connection string")?;
Expand All @@ -68,7 +74,11 @@ fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
let mgr = if config.get_ssl_mode() == SslMode::Disable {
deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
} else {
let builder = TlsConnector::builder();
let mut builder = TlsConnector::builder();
for cert_bytes in root_certificates {
builder.add_root_certificate(native_tls::Certificate::from_pem(cert_bytes)?);
}

let connector = MakeTlsConnector::new(builder.build()?);
deadpool_postgres::Manager::from_config(config, connector, mgr_config)
};
Expand Down
12 changes: 9 additions & 3 deletions crates/factor-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod client;
mod host;
pub mod runtime_config;
mod types;

use std::sync::Arc;
Expand All @@ -18,7 +19,7 @@ pub struct OutboundPgFactor<CF = crate::client::PooledTokioClientFactory> {
}

impl<CF: ClientFactory> Factor for OutboundPgFactor<CF> {
type RuntimeConfig = ();
type RuntimeConfig = runtime_config::RuntimeConfig;
type AppState = Arc<CF>;
type InstanceBuilder = InstanceState<CF>;

Expand All @@ -36,9 +37,14 @@ impl<CF: ClientFactory> Factor for OutboundPgFactor<CF> {

fn configure_app<T: RuntimeFactors>(
&self,
_ctx: ConfigureAppContext<T, Self>,
ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
Ok(Arc::new(CF::default()))
let certificates = match ctx.runtime_config() {
Some(rc) => rc.certificates.clone(),
None => vec![],
};
// let certificates = certificate_paths.iter().map(std::fs::read).collect::<Result<Vec<_>, _>>()?;
Ok(Arc::new(CF::new(certificates)))
}

fn prepare<T: RuntimeFactors>(
Expand Down
4 changes: 4 additions & 0 deletions crates/factor-outbound-pg/src/runtime_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[derive(Default)]
pub struct RuntimeConfig {
pub certificates: Vec<Vec<u8>>,
}
4 changes: 3 additions & 1 deletion crates/factor-outbound-pg/tests/factor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,15 @@ async fn exercise_query() -> anyhow::Result<()> {
}

// TODO: We can expand this mock to track calls and simulate return values
#[derive(Default)]
pub struct MockClientFactory {}
pub struct MockClient {}

#[async_trait]
impl ClientFactory for MockClientFactory {
type Client = MockClient;
fn new(_: Vec<Vec<u8>>) -> Self {
Self {}
}
async fn get_client(&self, _address: &str) -> Result<Self::Client> {
Ok(MockClient {})
}
Expand Down
20 changes: 17 additions & 3 deletions crates/runtime-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use spin_sqlite as sqlite;
use spin_trigger::cli::UserProvidedPath;
use toml::Value;

mod pg;
pub mod variables;

/// The default state directory for the trigger.
Expand Down Expand Up @@ -137,9 +138,13 @@ where
let outbound_networking = runtime_config_dir
.clone()
.map(OutboundNetworkingSpinRuntimeConfig::new);
let key_value_resolver = key_value_config_resolver(runtime_config_dir, state_dir.clone());
let key_value_resolver =
key_value_config_resolver(runtime_config_dir.clone(), state_dir.clone());
let sqlite_resolver = sqlite_config_resolver(state_dir.clone())
.context("failed to resolve sqlite runtime config")?;
let pg_resolver = pg::PgConfigResolver {
base_dir: runtime_config_dir.clone(),
};

let toml = toml_resolver.toml();
let log_dir = toml_resolver.log_dir()?;
Expand All @@ -150,6 +155,7 @@ where
&key_value_resolver,
outbound_networking.as_ref(),
&sqlite_resolver,
&pg_resolver,
);

// Note: all valid fields in the runtime config must have been referenced at
Expand Down Expand Up @@ -302,6 +308,7 @@ pub struct TomlRuntimeConfigSource<'a, 'b> {
key_value: &'a key_value::RuntimeConfigResolver,
outbound_networking: Option<&'a OutboundNetworkingSpinRuntimeConfig>,
sqlite: &'a sqlite::RuntimeConfigResolver,
pg_resolver: &'a pg::PgConfigResolver,
}

impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> {
Expand All @@ -310,12 +317,14 @@ impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> {
key_value: &'a key_value::RuntimeConfigResolver,
outbound_networking: Option<&'a OutboundNetworkingSpinRuntimeConfig>,
sqlite: &'a sqlite::RuntimeConfigResolver,
pg_resolver: &'a pg::PgConfigResolver,
) -> Self {
Self {
toml: toml_resolver,
key_value,
outbound_networking,
sqlite,
pg_resolver,
}
}
}
Expand Down Expand Up @@ -349,8 +358,13 @@ impl FactorRuntimeConfigSource<VariablesFactor> for TomlRuntimeConfigSource<'_,
}

impl FactorRuntimeConfigSource<OutboundPgFactor> for TomlRuntimeConfigSource<'_, '_> {
fn get_runtime_config(&mut self) -> anyhow::Result<Option<()>> {
Ok(None)
fn get_runtime_config(
&mut self,
) -> anyhow::Result<Option<<OutboundPgFactor as spin_factors::Factor>::RuntimeConfig>> {
Ok(Some(
self.pg_resolver
.runtime_config_from_toml(&self.toml.table)?,
))
}
}

Expand Down
49 changes: 49 additions & 0 deletions crates/runtime-config/src/pg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::path::PathBuf;

use serde::Deserialize;
use spin_factor_outbound_pg::runtime_config::RuntimeConfig;
use spin_factors::runtime_config::toml::GetTomlValue;

pub struct PgConfigResolver {
pub(crate) base_dir: Option<PathBuf>, // must have a value if any certs, but we need to deref it lazily
}

impl PgConfigResolver {
pub fn runtime_config_from_toml(
&self,
table: &impl GetTomlValue,
) -> anyhow::Result<RuntimeConfig> {
let Some(table) = table.get("postgres").and_then(|t| t.as_table()) else {
return Ok(Default::default());
};

let table: RuntimeConfigTable = RuntimeConfigTable::deserialize(table.clone())?;

let certificate_paths = table
.root_certificates
.iter()
.map(PathBuf::from)
.collect::<Vec<_>>();

let has_relative = certificate_paths.iter().any(|p| p.is_relative());

let certificate_paths = match (has_relative, self.base_dir.as_ref()) {
(false, _) => certificate_paths,
(true, None) => anyhow::bail!("the runtime config file contains relative certificate paths, but we could not determine the runtime config directory for them to be relative to"),
(true, Some(base)) => certificate_paths.into_iter().map(|p| base.join(p)).collect::<Vec<_>>(),
};

let certificates = certificate_paths
.iter()
.map(std::fs::read)
.collect::<Result<Vec<_>, _>>()?;

Ok(RuntimeConfig { certificates })
}
}

#[derive(Deserialize)]
struct RuntimeConfigTable {
#[serde(default)]
root_certificates: Vec<String>,
}
4 changes: 4 additions & 0 deletions tests/manual/pg-ssl-root-certs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
target/
.spin/
pg
postgres-ssl
Loading