Skip to content

Commit 443b31c

Browse files
committed
add a http -> https redirect and better port constants
1 parent daaadfb commit 443b31c

File tree

3 files changed

+100
-8
lines changed

3 files changed

+100
-8
lines changed

Cargo.lock

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ version = "0.1.0"
44
edition = "2024"
55

66
[features]
7-
https = ["dep:rustls"]
7+
https = ["dep:rustls", "dep:axum-extra"]
88

99
[dependencies]
1010
axum = "0.8.4"
11+
axum-extra = { version = "0.10.1", optional = true }
1112
axum-server = { version = "0.7.2", features = ["tls-rustls"] }
1213
dotenvy = "0.15.7"
1314
rustls = { version = "0.23.31", optional = true }

src/main.rs

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::net::{Ipv4Addr, SocketAddrV4};
1+
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
22

33
use crate::app::new_app;
44

@@ -8,15 +8,27 @@ compile_error!("Feature `https` must be enabled on release.");
88

99
mod app;
1010

11+
#[allow(dead_code)]
12+
struct Ports {
13+
http: u16,
14+
https: u16,
15+
}
16+
1117
#[cfg(not(debug_assertions))]
12-
const PORT: u16 = 443;
18+
const PORTS: Ports = Ports {
19+
http: 80,
20+
https: 443,
21+
};
1322
#[cfg(debug_assertions)]
14-
const PORT: u16 = 8080;
23+
const PORTS: Ports = Ports {
24+
http: 8080,
25+
https: 4430,
26+
};
1527

1628
#[cfg(debug_assertions)]
17-
const ADDR: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, PORT);
29+
const ADDR: Ipv4Addr = Ipv4Addr::LOCALHOST;
1830
#[cfg(not(debug_assertions))]
19-
const ADDR: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, PORT);
31+
const ADDR: Ipv4Addr = Ipv4Addr::UNSPECIFIED;
2032

2133
#[tokio::main]
2234
async fn main() {
@@ -31,7 +43,7 @@ async fn main() {
3143
async fn http_main() {
3244
let app = new_app();
3345

34-
axum_server::bind(std::net::SocketAddr::V4(ADDR))
46+
axum_server::bind(SocketAddr::V4(SocketAddrV4::new(ADDR, PORTS.http)))
3547
.serve(app.into_make_service())
3648
.await
3749
.unwrap();
@@ -49,11 +61,67 @@ async fn https_main() {
4961
.await
5062
.unwrap();
5163

64+
tokio::spawn(redirect_http_to_https(PORTS));
65+
5266
let app = new_app();
5367

5468
// run https server
55-
axum_server::bind_rustls(std::net::SocketAddr::V4(ADDR), config)
69+
axum_server::bind_rustls(SocketAddr::V4(SocketAddrV4::new(ADDR, PORTS.https)), config)
5670
.serve(app.into_make_service())
5771
.await
5872
.unwrap();
5973
}
74+
75+
#[cfg(feature = "https")]
76+
async fn redirect_http_to_https(ports: Ports) {
77+
use axum::{
78+
BoxError,
79+
handler::HandlerWithoutStateExt,
80+
http::{Uri, uri::Authority},
81+
};
82+
use axum_extra::extract::Host;
83+
84+
fn make_https(host: &str, uri: Uri, https_port: u16) -> Result<Uri, BoxError> {
85+
let mut parts = uri.into_parts();
86+
87+
parts.scheme = Some(axum::http::uri::Scheme::HTTPS);
88+
89+
if parts.path_and_query.is_none() {
90+
parts.path_and_query = Some("/".parse().unwrap());
91+
}
92+
93+
let authority: Authority = host.parse()?;
94+
let bare_host = match authority.port() {
95+
Some(port_struct) => authority
96+
.as_str()
97+
.strip_suffix(port_struct.as_str())
98+
.unwrap()
99+
.strip_suffix(':')
100+
.unwrap(), // if authority.port() is Some(port) then we can be sure authority ends with :{port}
101+
None => authority.as_str(),
102+
};
103+
104+
parts.authority = Some(format!("{bare_host}:{https_port}").parse()?);
105+
106+
Ok(Uri::from_parts(parts)?)
107+
}
108+
109+
let redirect = move |Host(host): Host, uri: Uri| async move {
110+
use axum::response::Redirect;
111+
112+
match make_https(&host, uri, ports.https) {
113+
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
114+
Err(_) => {
115+
use axum::http::StatusCode;
116+
117+
Err(StatusCode::BAD_REQUEST)
118+
}
119+
}
120+
};
121+
122+
let addr = SocketAddr::V4(SocketAddrV4::new(ADDR, ports.http));
123+
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
124+
axum::serve(listener, redirect.into_make_service())
125+
.await
126+
.unwrap();
127+
}

0 commit comments

Comments
 (0)