Skip to content
Open
35 changes: 29 additions & 6 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ default = ["tracing"]

async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"]
async-stream = [] # unused, remove before the next breaking-change release
file-stream = ["dep:tokio-util", "tokio-util?/io", "dep:tokio", "tokio?/fs", "tokio?/io-util"]
file-stream = [
"dep:tokio-util",
"tokio-util?/io",
"dep:tokio",
"tokio?/fs",
"tokio?/io-util",
]
attachment = ["dep:tracing"]
error-response = ["dep:tracing", "tracing/std"]
cookie = ["dep:cookie"]
Expand All @@ -37,10 +43,19 @@ json-lines = [
multipart = ["dep:multer", "dep:fastrand"]
protobuf = ["dep:prost"]
scheme = []
query = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"]
query = [
"dep:form_urlencoded",
"dep:serde_html_form",
"dep:serde_path_to_error",
]
tracing = ["axum-core/tracing", "axum/tracing"]
typed-header = ["dep:headers"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
typed-routing = [
"dep:axum-macros",
"dep:percent-encoding",
"dep:serde_html_form",
"dep:form_urlencoded",
]

# Enabled by docs.rs because it uses all-features
__private_docs = [
Expand All @@ -52,7 +67,9 @@ __private_docs = [
axum = { path = "../axum", version = "0.8.3", default-features = false, features = ["original-uri"] }
axum-core = { path = "../axum-core", version = "0.5.2" }
bytes = "1.1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
futures-util = { version = "0.3", default-features = false, features = [
"alloc",
] }
http = "1.0.0"
http-body = "1.0.0"
http-body-util = "0.1.0"
Expand All @@ -66,7 +83,9 @@ tower-service = "0.3"

# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.5.0", optional = true }
cookie = { package = "cookie", version = "0.18.0", features = ["percent-encode"], optional = true }
cookie = { package = "cookie", version = "0.18.0", features = [
"percent-encode",
], optional = true }
fastrand = { version = "2.1.0", optional = true }
form_urlencoded = { version = "1.1.0", optional = true }
headers = { version = "0.4.0", optional = true }
Expand All @@ -86,7 +105,11 @@ typed-json = { version = "0.1.1", optional = true }
axum = { path = "../axum", features = ["macros", "__private"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = "1.0.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
reqwest = { version = "0.12", default-features = false, features = [
"json",
"stream",
"multipart",
] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.71"
tokio = { version = "1.14", features = ["full"] }
Expand Down
188 changes: 188 additions & 0 deletions axum-extra/src/extract/cookie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,82 @@ impl CookieJar {
pub fn iter(&self) -> impl Iterator<Item = &'_ Cookie<'static>> {
self.jar.iter()
}

/// Add a cookie with the specified prefix to the jar.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::{CookieJar, Cookie};
/// use cookie::prefix::{Host, Secure};
///
/// async fn handler(jar: CookieJar) -> CookieJar {
/// // Add a cookie with the "__Host-" prefix
/// let with_host = jar.clone().add_prefixed(Host, Cookie::new("session_id", "value"));
///
/// // Add a cookie with the "__Secure-" prefix
/// let _with_secure = jar.add_prefixed(Secure, Cookie::new("auth", "token"));
///
/// with_host
/// }
/// ```
#[must_use]
pub fn add_prefixed<P: cookie::prefix::Prefix>(
mut self,
prefix: P,
cookie: Cookie<'static>,
) -> Self {
let mut prefixed_jar = self.jar.prefixed_mut(prefix);
prefixed_jar.add(cookie);
self
}

/// Get a signed cookie with the specified prefix from the jar.
///
/// If the cookie exists and its signature is valid, it is returned with its original name
/// (without the prefix) and plaintext value.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::{CookieJar, Cookie};
/// use cookie::prefix::{Host, Secure};
///
/// async fn handler(jar: CookieJar) {
/// if let Some(cookie) = jar.get_prefixed(cookie::prefix::Host, "session_id") {
/// let value = cookie.value();
/// }
/// }
/// ```
pub fn get_prefixed<P: cookie::prefix::Prefix>(
&self,
prefix: P,
name: &str,
) -> Option<Cookie<'static>> {
let prefixed_jar = self.jar.prefixed(prefix);
prefixed_jar.get(name)
}

/// Remove a cookie with the specified prefix from the jar.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::CookieJar;
/// use cookie::prefix::{Host, Secure};
///
/// async fn handler(jar: CookieJar) -> CookieJar {
/// // Remove a cookie with the "__Host-" prefix
/// jar.remove_prefixed(Host, "session_id")
/// }
/// ```
#[must_use]
pub fn remove_prefixed<P, S>(mut self, prefix: P, name: S) -> Self
where
P: cookie::prefix::Prefix,
S: Into<String>,
{
let mut prefixed_jar = self.jar.prefixed_mut(prefix);
prefixed_jar.remove(name.into());
self
}
}

impl IntoResponseParts for CookieJar {
Expand Down Expand Up @@ -232,6 +308,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
mod tests {
use super::*;
use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use cookie::prefix::Host;
use http_body_util::BodyExt;
use tower::ServiceExt;

Expand Down Expand Up @@ -268,6 +345,19 @@ mod tests {
.await
.unwrap();
let cookie_value = res.headers()["set-cookie"].to_str().unwrap();
println!("Set Cookie value: {}", cookie_value);

assert!(cookie_value.starts_with("key="));

// For signed/private cookies, verify that the plaintext value is not directly visible
// (only for signed and private jars, not for the regular CookieJar)
if std::any::type_name::<$jar>().contains("Private")
|| std::any::type_name::<$jar>().contains("Signed")
{
assert!(!cookie_value.contains("key=value"));
} else {
assert!(cookie_value.contains("key=value"));
}

let res = app
.clone()
Expand Down Expand Up @@ -302,17 +392,115 @@ mod tests {
};
}

macro_rules! cookie_prefixed_test {
($name:ident, $jar:ty) => {
#[tokio::test]
async fn $name() {
async fn set_cookie_prefixed(jar: $jar) -> impl IntoResponse {
jar.add_prefixed(Host, Cookie::new("key", "value"))
}

async fn get_cookie_prefixed(jar: $jar) -> impl IntoResponse {
jar.get_prefixed(Host, "key").unwrap().value().to_owned()
}

async fn remove_cookie_prefixed(jar: $jar) -> impl IntoResponse {
jar.remove_prefixed(Host, "key")
}

let state = AppState {
key: Key::generate(),
custom_key: CustomKey(Key::generate()),
};

let app = Router::new()
.route("/set", get(set_cookie_prefixed))
.route("/get", get(get_cookie_prefixed))
.route("/remove", get(remove_cookie_prefixed))
.with_state(state);

let res = app
.clone()
.oneshot(Request::builder().uri("/set").body(Body::empty()).unwrap())
.await
.unwrap();
let cookie_value = res.headers()["set-cookie"].to_str().unwrap();
println!("Set Cookie value: {}", cookie_value);
assert!(cookie_value.contains("__Host-key"));

// For signed/private cookies, verify that the plaintext value is not directly visible
// (only for signed and private jars, not for the regular CookieJar)
if std::any::type_name::<$jar>().contains("Private")
|| std::any::type_name::<$jar>().contains("Signed")
{
assert!(!cookie_value.contains("key=value"));
} else {
assert!(cookie_value.contains("key=value"));
}

// Extract just the cookie part (before the first semicolon)
// Set-Cookie: __Host-key=value; Secure; Path=/ -> __Host-key=value
let cookie_header_value = cookie_value.split(';').next().unwrap().trim();
println!("Using Cookie header value: {}", cookie_header_value);

let res = app
.clone()
.oneshot(
Request::builder()
.uri("/get")
.header("cookie", cookie_header_value)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let body = body_text(res).await;
assert_eq!(body, "value");

let res = app
.clone()
.oneshot(
Request::builder()
.uri("/remove")
.header("cookie", cookie_value)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert!(res.headers()["set-cookie"]
.to_str()
.unwrap()
.contains("__Host-key=;"));
}
};
}

cookie_test!(plaintext_cookies, CookieJar);

#[cfg(feature = "cookie-signed")]
cookie_test!(signed_cookies, SignedCookieJar);
#[cfg(feature = "cookie-signed")]
cookie_prefixed_test!(signed_cookies_prefixed, SignedCookieJar);
#[cfg(feature = "cookie-signed")]
cookie_test!(signed_cookies_with_custom_key, SignedCookieJar<CustomKey>);
#[cfg(feature = "cookie-signed")]
cookie_prefixed_test!(
signed_cookies_prefixed_with_custom_key,
SignedCookieJar<CustomKey>
);

#[cfg(feature = "cookie-private")]
cookie_test!(private_cookies, PrivateCookieJar);
#[cfg(feature = "cookie-private")]
cookie_prefixed_test!(private_cookies_prefixed, PrivateCookieJar);
#[cfg(feature = "cookie-private")]
cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);
#[cfg(feature = "cookie-private")]
cookie_prefixed_test!(
private_cookies_prefixed_with_custom_key,
PrivateCookieJar<CustomKey>
);

#[derive(Clone)]
struct AppState {
Expand Down
82 changes: 81 additions & 1 deletion axum-extra/src/extract/cookie/private.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl<K> PrivateCookieJar<K> {
/// Authenticates and decrypts `cookie`, returning the plaintext version if decryption succeeds
/// or `None` otherwise.
pub fn decrypt(&self, cookie: Cookie<'static>) -> Option<Cookie<'static>> {
self.private_jar().decrypt(cookie)
self.private_jar().decrypt(cookie.clone())
}

/// Get an iterator over all cookies in the jar.
Expand All @@ -267,6 +267,86 @@ impl<K> PrivateCookieJar<K> {
fn private_jar_mut(&mut self) -> PrivateJar<&'_ mut cookie::CookieJar> {
self.jar.private_mut(&self.key)
}
/// Add a signed cookie with the specified prefix to the jar.
///
/// The cookie's value will be signed using the jar's key, and the prefix will determine the
/// cookie's name and attributes (e.g., `Secure`, `Path=/` for `__Host-`).
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie};
/// use cookie::prefix::Host;
///
/// async fn handler(jar: PrivateCookieJar) -> PrivateCookieJar {
/// jar.add_prefixed(Host, Cookie::new("session_id", "value"))
/// }
/// ```
#[must_use]
pub fn add_prefixed<P: cookie::prefix::Prefix>(
self,
_prefix: P,
cookie: Cookie<'static>,
) -> Self {
let mut jar = self.jar;
jar.remove(Cookie::new(cookie.name().to_owned(), ""));

let prefixed_name = format!("{}{}", P::PREFIX, cookie.name());
let mut new_cookie = cookie;
new_cookie.set_name(prefixed_name);
jar.private_mut(&self.key).add(new_cookie);

Self {
jar,
key: self.key,
_marker: self._marker,
}
}
/// Get a signed cookie with the specified prefix from the jar.
///
/// If the cookie exists and its signature is valid, it is returned with its original name
/// (without the prefix) and plaintext value.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::PrivateCookieJar;
///
/// async fn handler(jar: PrivateCookieJar) {
/// if let Some(cookie) = jar.get_prefixed(cookie::prefix::Host, "session_id") {
/// let value = cookie.value();
/// }
/// }
/// ```
pub fn get_prefixed<P: cookie::prefix::Prefix>(
&self,
_prefix: P,
name: &str,
) -> Option<Cookie<'static>> {
let prefixed_name = format!("{}{name}", P::PREFIX);
self.jar
.get(&prefixed_name)
.and_then(|c| self.decrypt(c.clone()))
}
/// Remove a signed cookie with the specified prefix from the jar.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::PrivateCookieJar;
/// use cookie::prefix::Host;
///
/// async fn handler(jar: PrivateCookieJar) -> PrivateCookieJar {
/// jar.remove_prefixed(Host, "session_id")
/// }
/// ```
#[must_use]
pub fn remove_prefixed<P, S>(mut self, prefix: P, name: S) -> Self
where
P: cookie::prefix::Prefix,
S: Into<String>,
{
let mut prefixed_jar = self.jar.prefixed_mut(prefix);
prefixed_jar.remove(name.into());
self
}
}

impl<K> IntoResponseParts for PrivateCookieJar<K> {
Expand Down
Loading
Loading