Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/corro-pg/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@ impl codec::Decoder for PgWireMessageServerCodec {
api::PgWireConnectionState::AwaitingSslRequest => {}
api::PgWireConnectionState::AwaitingStartup => {
self.decode_context.awaiting_ssl = false;
self.decode_context.awaiting_startup = true;
}
_ => {
self.decode_context.awaiting_ssl = false;
self.decode_context.awaiting_startup = false;
}
}
Expand Down
59 changes: 34 additions & 25 deletions crates/corro-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,41 +598,50 @@ pub async fn start(
let negotiation =
ssl::negotiate_ssl(&mut tcp_socket, tls_acceptor.is_some()).await?;

let (mut framed, secured) = if matches!(negotiation, ssl::SslNegotiationType::None)
{
if ssl_required {
debug!("rejecting non-ssl connection");
return Ok(());
}
let (mut framed, secured, maybe_next_msg) =
if matches!(negotiation, ssl::SslNegotiationType::None(_)) {
if ssl_required {
debug!("rejecting non-ssl connection");
return Ok(());
}

(Either::Left(tcp_socket), false)
} else if let Some(tls) = tls_acceptor {
let tls_socket = tls.accept(tcp_socket.into_inner()).await?;
let maybe_next_msg = match negotiation {
ssl::SslNegotiationType::None(Some(msg)) => Some(msg),
_ => None,
};

if matches!(negotiation, ssl::SslNegotiationType::Direct) {
ssl::check_alpn_for_direct_ssl(&tls_socket)?;
}
(Either::Left(tcp_socket), false, maybe_next_msg)
} else if let Some(tls) = tls_acceptor {
let tls_socket = tls.accept(tcp_socket.into_inner()).await?;

let framed = Framed::new(
tokio::io::BufStream::new(tls_socket),
PgWireMessageServerCodec::new(codec::Client::new(local_addr, true)),
);
if matches!(negotiation, ssl::SslNegotiationType::Direct) {
ssl::check_alpn_for_direct_ssl(&tls_socket)?;
}

(Either::Right(framed), true)
} else {
trace!("received SSL connection attempt without a TLS acceptor configured");
return Ok(());
};
let framed = Framed::new(
tokio::io::BufStream::new(tls_socket),
PgWireMessageServerCodec::new(codec::Client::new(local_addr, true)),
);

(Either::Right(framed), true, None)
} else {
trace!("received SSL connection attempt without a TLS acceptor configured");
return Ok(());
};

trace!("SSL ? {secured}");

use crate::codec::SetState;
framed.set_state(pgwire::api::PgWireConnectionState::AwaitingStartup);

let msg = match framed.next().await {
Some(msg) => msg?,
trace!("maybe_next_msg: {maybe_next_msg:?}");
let msg = match maybe_next_msg {
Some(msg) => msg,
None => {
return Ok(());
framed.set_state(pgwire::api::PgWireConnectionState::AwaitingStartup);
match framed.next().await {
Some(msg) => msg?,
None => return Ok(()),
}
}
};

Expand Down
89 changes: 43 additions & 46 deletions crates/corro-pg/src/ssl.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use futures::{SinkExt, StreamExt};
use pgwire::messages::{
response::{GssEncResponse, SslResponse},
startup::{GssEncRequest, SslRequest},
PgWireBackendMessage,
PgWireBackendMessage, PgWireFrontendMessage,
};
use tokio::io::{AsyncBufReadExt, BufStream};
use tokio_util::codec::Framed;
Expand All @@ -12,7 +11,7 @@ use crate::utils::CountedTcpStream;
pub(super) enum SslNegotiationType {
Postgres,
Direct,
None,
None(Option<PgWireFrontendMessage>),
}

pub(super) async fn negotiate_ssl(
Expand All @@ -28,53 +27,51 @@ pub(super) async fn negotiate_ssl(
let mut gss_done = false;

loop {
let buf = socket.get_mut().fill_buf().await?;
let n = buf.len();
match socket.next().await {
Some(msg) => {
let msg = msg?;
match msg {
PgWireFrontendMessage::SslRequest(_) => {
if ssl_supported {
socket
.send(PgWireBackendMessage::SslResponse(SslResponse::Accept))
.await?;
return Ok(SslNegotiationType::Postgres);
} else {
socket
.send(PgWireBackendMessage::SslResponse(SslResponse::Refuse))
.await?;
ssl_done = true;

// already EOF
if n == 0 {
return Ok(SslNegotiationType::None);
}

if n >= 8 {
if SslRequest::is_ssl_request_packet(buf) {
// consume SslRequest
let _ = socket.next().await;
// ssl request
if ssl_supported {
socket
.send(PgWireBackendMessage::SslResponse(SslResponse::Accept))
.await?;
return Ok(SslNegotiationType::Postgres);
} else {
socket
.send(PgWireBackendMessage::SslResponse(SslResponse::Refuse))
.await?;
ssl_done = true;

if gss_done {
return Ok(SslNegotiationType::None);
} else {
// Continue to check for more requests (e.g., GssEncRequest after SSL refuse)
continue;
if gss_done {
return Ok(SslNegotiationType::None(None));
} else {
// Continue to check for more requests (e.g., GssEncRequest after SSL refuse)
continue;
}
}
}
}
} else if GssEncRequest::is_gss_enc_request_packet(buf) {
let _ = socket.next().await;
socket
.send(PgWireBackendMessage::GssEncResponse(GssEncResponse::Refuse))
.await?;
gss_done = true;
PgWireFrontendMessage::GssEncRequest(_) => {
let _ = socket.next().await;
socket
.send(PgWireBackendMessage::GssEncResponse(GssEncResponse::Refuse))
.await?;
gss_done = true;

if ssl_done {
return Ok(SslNegotiationType::None);
} else {
// Continue to check for more requests (e.g., SSL request after GSSAPI refuse)
continue;
if ssl_done {
return Ok(SslNegotiationType::None(None));
} else {
// Continue to check for more requests (e.g., SSL request after GSSAPI refuse)
continue;
}
}
msg => {
return Ok(SslNegotiationType::None(Some(msg)));
}
}
} else {
// startup or cancel
return Ok(SslNegotiationType::None);
}
None => {
return Ok(SslNegotiationType::None(None));
}
}
}
Expand Down
Loading