Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 208 additions & 7 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket"
notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
oid-registry = "0.7.1"
once_cell = "1.13"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
Expand Down Expand Up @@ -173,6 +174,7 @@ rustc-hash = "2.1.1"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
rustls-split = "0.3"
scopeguard = "1.1"
sysinfo = "0.29.2"
sd-notify = "0.4.1"
Expand Down Expand Up @@ -235,6 +237,7 @@ whoami = "1.5.1"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
x509-parser = "0.16"

## TODO replace this with tracing
env_logger = "0.11"
Expand Down
12 changes: 12 additions & 0 deletions compute_tools/src/http/middleware/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ pub(in crate::http) struct Authorize {
impl Authorize {
pub fn new(compute_id: String, jwks: JwkSet) -> Self {
let mut validation = Validation::new(Algorithm::EdDSA);

// BEGIN HADRON
let use_rsa = jwks.keys.iter().any(|jwk| {
jwk.common
.key_algorithm
.is_some_and(|alg| alg == jsonwebtoken::jwk::KeyAlgorithm::RS256)
});
if use_rsa {
validation = Validation::new(Algorithm::RS256);
}
// END HADRON

validation.validate_exp = true;
// Unused by the control plane
validation.validate_nbf = false;
Expand Down
2 changes: 2 additions & 0 deletions control_plane/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@ endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true
x509-parser.workspace = true
rsa = "0.9"
6 changes: 5 additions & 1 deletion control_plane/src/bin/neon_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
// User (likely interactive) did not provide a description of the environment, give them the default
NeonLocalInitConf {
control_plane_api: Some(DEFAULT_PAGESERVER_CONTROL_PLANE_API.parse().unwrap()),
auth_token_type: AuthType::NeonJWT,
broker: NeonBroker {
listen_addr: Some(DEFAULT_BROKER_ADDR.parse().unwrap()),
listen_https_addr: None,
Expand Down Expand Up @@ -1585,7 +1586,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
assert!(!pageservers.is_empty());

let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
let auth_token = if matches!(
ps_conf.pg_auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
) {
let claims = Claims::new(Some(endpoint.tenant_id), Scope::Tenant);

Some(env.generate_auth_token(&claims)?)
Expand Down
87 changes: 70 additions & 17 deletions control_plane/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,8 @@
//! <other PostgreSQL files>
//! ```
//!
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
Expand All @@ -62,21 +52,31 @@ use compute_api::spec::{
};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, RSAKeyParameters, RSAKeyType,
};
use nix::sys::signal::{Signal, kill};
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey, traits::PublicKeyParts};
use safekeeper_api::PgMajorVersion;
use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::debug;
use url::Host;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::ShardStripeSize;
use x509_parser::parse_x509_certificate;

use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
Expand Down Expand Up @@ -155,23 +155,76 @@ impl ComputeControlPlane {
.unwrap_or(self.base_port)
}

// BEGIN HADRON

/// Extract SubjectPublicKeyInfo from a PEM that can be either a X509 certificate or a public key
fn extract_spki_from_pem(pem: &Pem) -> Result<Vec<u8>> {
if pem.tag() == "CERTIFICATE" {
// Handle X509 certificate
let (_, cert) = parse_x509_certificate(pem.contents())?;
let public_key = cert.public_key();
Ok(public_key.subject_public_key.data.to_vec())
} else {
// Handle public key directly
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
Ok(spki.subject_public_key.raw_bytes().to_vec())
}
}

/// Create RSA JWK from certificate PEM
fn create_rsa_jwk_from_cert(pem: &Pem, key_hash: &[u8]) -> Result<Jwk> {
let public_key = Self::extract_spki_from_pem(pem)?;

// Extract RSA parameters (n, e) from RSA public key DER data
let rsa_key = RsaPublicKey::from_pkcs1_der(&public_key)?;
let n = rsa_key.n().to_bytes_be();
let e = rsa_key.e().to_bytes_be();

Ok(Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::RS256),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
x509_sha256_fingerprint: None::<String>,
},
algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
key_type: RSAKeyType::RSA,
n: URL_SAFE_NO_PAD.encode(n),
e: URL_SAFE_NO_PAD.encode(e),
}),
})
}

// END HADRON

/// Create a JSON Web Key Set. This ideally matches the way we create a JWKS
/// from the production control plane.
fn create_jwks_from_pem(pem: &Pem) -> Result<JwkSet> {
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
let public_key = spki.subject_public_key.raw_bytes();
let public_key = Self::extract_spki_from_pem(pem)?;

let mut hasher = Sha256::new();
hasher.update(public_key);
hasher.update(&public_key);
let key_hash = hasher.finalize();

// BEGIN HADRON
if pem.tag() == "CERTIFICATE" {
// Assume RSA if we are parsing keys from a certificate.
let jwk = Self::create_rsa_jwk_from_cert(pem, &key_hash)?;
return Ok(JwkSet { keys: vec![jwk] });
}
// END HADRON

Ok(JwkSet {
keys: vec![Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
Expand All @@ -180,7 +233,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
x: URL_SAFE_NO_PAD.encode(public_key),
}),
}],
})
Expand Down
10 changes: 10 additions & 0 deletions control_plane/src/endpoint_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::background_process::{self, start_process, stop_process};
use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
Expand All @@ -16,15 +17,22 @@ pub struct EndpointStorage {
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub auth_type: AuthType,
}

impl EndpointStorage {
pub fn from_env(env: &LocalEnv) -> EndpointStorage {
let auth_type = match env.token_auth_type {
AuthType::HadronJWT => AuthType::HadronJWT,
AuthType::NeonJWT | AuthType::Trust => AuthType::NeonJWT,
};

EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
auth_type,
}
}

Expand All @@ -46,12 +54,14 @@ impl EndpointStorage {
pemfile: Utf8PathBuf,
local_path: Utf8PathBuf,
r#type: String,
auth_type: AuthType,
}
let cfg = Cfg {
listen: self.listen_addr(),
pemfile: parent.join(self.pemfile.clone()),
local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR),
r#type: "LocalFs".to_string(),
auth_type: self.auth_type,
};
std::fs::create_dir_all(self.config_path().parent().unwrap())?;
std::fs::write(self.config_path(), serde_json::to_string(&cfg)?)
Expand Down
Loading
Loading