Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ pub struct RftClaims {
/// Allowed LoRA adapter name
#[serde(default)]
pub lora: Option<String>,
/// Alternate names that should resolve to the base `model`. Used when
/// the platform exposes a user-facing model identifier (e.g.
/// `sprints/Llama-3.2-1B-Instruct`) that differs from the canonical
/// HF path vLLM serves (`meta-llama/Llama-3.2-1B-Instruct`). A
/// request hitting an alias is authorized and the request body's
/// `model` field is rewritten to `self.model` before dispatch.
#[serde(default)]
pub model_aliases: Vec<String>,
}

/// Verifier for RS256 JWTs signed by the platform.
Expand Down Expand Up @@ -73,6 +81,10 @@ impl RftClaims {
}
}

if self.is_model_alias(requested) {
return true;
Comment on lines +84 to +85

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Require a base model before authorizing aliases

When a JWT contains model_aliases but omits model (or sets it empty, which this struct permits via #[serde(default)]), this branch authorizes the alias even though canonical_for_alias later returns None because there is no canonical base to rewrite to. In that misconfigured/lora-only token case the server forwards the alias unchanged on chat/completions/embeddings/responses, broadening the JWT scope instead of failing closed; alias authorization should be gated on a non-empty base model or handled as an error.

Useful? React with 👍 / 👎.

}

if let Some(lora) = self.lora.as_deref() {
// Empty lora claim must never authorize anything; an empty
// string is a prefix of every other string, which would let
Expand All @@ -96,19 +108,92 @@ impl RftClaims {

false
}

/// Whether `requested` is one of the alternate names declared in
/// `model_aliases`. Requires a non-empty base `model` claim: aliases
/// exist to rewrite to canonical and there is no canonical without
/// a base. Without this gate a JWT with `model_aliases` but no
/// `model` would authorize the alias and forward it unchanged to
/// vLLM, broadening scope rather than failing closed. Empty alias
/// entries are also ignored so a misconfigured claim never
/// authorizes the empty model.
fn is_model_alias(&self, requested: &str) -> bool {
if requested.is_empty() {
return false;
}
let Some(base) = self.model.as_deref() else {
return false;
};
if base.is_empty() {
return false;
}
self.model_aliases
.iter()
.any(|a| !a.is_empty() && a == requested)
}

/// If `requested` matched the JWT only via `model_aliases`, return
/// the canonical model name (`self.model`) so the request body can
/// be rewritten before forwarding to vLLM. Returns `None` if the
/// request already targets the base model or a LoRA — those need to
/// pass through unchanged so vLLM can dispatch the LoRA adapter.
///
/// LoRA-shadowing is the load-bearing case: a pathological JWT could
/// list the same string in both `lora` (or as a `<lora>-…` step
/// adapter) and `model_aliases`. Rewriting such a request to the
/// base model would silently swap a LoRA call for a base-model call,
/// so we treat LoRA matches as taking precedence over alias matches.
pub fn canonical_for_alias(&self, requested: &str) -> Option<String> {
if !self.is_model_alias(requested) {
return None;
}
let base = self.model.as_deref()?;
if base.is_empty() || base == requested {
return None;
}
if self.matches_lora(requested) {
return None;
}
Some(base.to_string())
Comment thread
cursor[bot] marked this conversation as resolved.
Comment on lines +146 to +157

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve LoRA names before alias rewrites

When a signed JWT contains a model_aliases entry that is also the allowed LoRA name (or a step-versioned LoRA alias), canonical_for_alias() returns the base model before checking whether requested is a LoRA match. The typed handlers then rewrite model to the base, so a request intended to use the run's adapter is forwarded as a base-model request instead, corrupting LoRA-scoped runs under that alias configuration; the helper's contract says LoRA targets must pass through unchanged.

Useful? React with 👍 / 👎.

}

/// Whether `requested` is authorized via the `lora` claim — either
/// an exact match or a `<lora>-<suffix>` step adapter. Mirrors the
/// LoRA branch of `allows_model` so `canonical_for_alias` can defer
/// to LoRA dispatch when both branches would otherwise authorize.
fn matches_lora(&self, requested: &str) -> bool {
let Some(lora) = self.lora.as_deref() else {
return false;
};
if lora.is_empty() {
return false;
}
if requested == lora {
return true;
}
if let Some(rest) = requested.strip_prefix(lora) {
return rest.starts_with('-');
}
false
}
}

#[cfg(test)]
mod tests {
use super::*;

fn claims(model: Option<&str>, lora: Option<&str>) -> RftClaims {
claims_with_aliases(model, lora, &[])
}

fn claims_with_aliases(model: Option<&str>, lora: Option<&str>, aliases: &[&str]) -> RftClaims {
RftClaims {
sub: "user".into(),
run_id: "abc".into(),
team_id: String::new(),
model: model.map(String::from),
lora: lora.map(String::from),
model_aliases: aliases.iter().map(|s| (*s).to_string()).collect(),
}
}

Expand Down Expand Up @@ -165,4 +250,91 @@ mod tests {
let c = claims(Some(""), Some("rft-abc"));
assert!(!c.allows_model(""));
}

#[test]
fn allows_model_alias() {
let c = claims_with_aliases(
Some("meta-llama/Llama-3.2-1B-Instruct"),
Some("rft-abc"),
&["sprints/Llama-3.2-1B-Instruct"],
);
assert!(c.allows_model("sprints/Llama-3.2-1B-Instruct"));
assert!(c.allows_model("meta-llama/Llama-3.2-1B-Instruct"));
assert!(!c.allows_model("other/model"));
}

#[test]
fn canonical_for_alias_rewrites_alias_only() {
let c = claims_with_aliases(
Some("meta-llama/Llama-3.2-1B-Instruct"),
Some("rft-abc"),
&["sprints/Llama-3.2-1B-Instruct"],
);
assert_eq!(
c.canonical_for_alias("sprints/Llama-3.2-1B-Instruct").as_deref(),
Some("meta-llama/Llama-3.2-1B-Instruct"),
);
// Base model and lora must NOT be rewritten — lora dispatch
// depends on the original name reaching vLLM.
assert_eq!(c.canonical_for_alias("meta-llama/Llama-3.2-1B-Instruct"), None);
assert_eq!(c.canonical_for_alias("rft-abc"), None);
assert_eq!(c.canonical_for_alias("rft-abc-step-42"), None);
assert_eq!(c.canonical_for_alias("unrelated"), None);
}

#[test]
fn empty_alias_entry_authorizes_nothing() {
let c = claims_with_aliases(Some("meta-llama/Llama-3.2-1B-Instruct"), Some("rft-abc"), &[""]);
assert!(!c.allows_model(""));
assert_eq!(c.canonical_for_alias(""), None);
}

#[test]
fn alias_matching_base_does_not_rewrite() {
// Pathological config: alias equals the base. Should still
// authorize, but not produce a rewrite (would be a no-op anyway).
let c = claims_with_aliases(Some("meta-llama/Llama-3.2-1B-Instruct"), None, &["meta-llama/Llama-3.2-1B-Instruct"]);
assert!(c.allows_model("meta-llama/Llama-3.2-1B-Instruct"));
assert_eq!(c.canonical_for_alias("meta-llama/Llama-3.2-1B-Instruct"), None);
}

#[test]
fn alias_matching_lora_does_not_rewrite() {
// Pathological config: an alias entry collides with the lora
// claim. Rewriting would silently swap the LoRA call for a
// base-model call, so LoRA matches take precedence.
let c = claims_with_aliases(
Some("meta-llama/Llama-3.2-1B-Instruct"),
Some("rft-abc"),
&["rft-abc"],
);
assert!(c.allows_model("rft-abc"));
assert_eq!(c.canonical_for_alias("rft-abc"), None);
}

#[test]
fn alias_matching_step_versioned_lora_does_not_rewrite() {
let c = claims_with_aliases(
Some("meta-llama/Llama-3.2-1B-Instruct"),
Some("rft-abc"),
&["rft-abc-step-42"],
);
assert!(c.allows_model("rft-abc-step-42"));
assert_eq!(c.canonical_for_alias("rft-abc-step-42"), None);
}

#[test]
fn alias_without_base_model_does_not_authorize() {
// Without a non-empty base model the router has no canonical
// name to rewrite to, so authorizing the alias would forward an
// arbitrary string to vLLM unchanged and broaden JWT scope.
// Fail closed instead.
let c_no_base = claims_with_aliases(None, Some("rft-abc"), &["sprints/Llama-3.2-1B-Instruct"]);
assert!(!c_no_base.allows_model("sprints/Llama-3.2-1B-Instruct"));
assert_eq!(c_no_base.canonical_for_alias("sprints/Llama-3.2-1B-Instruct"), None);

let c_empty_base = claims_with_aliases(Some(""), Some("rft-abc"), &["sprints/Llama-3.2-1B-Instruct"]);
assert!(!c_empty_base.allows_model("sprints/Llama-3.2-1B-Instruct"));
assert_eq!(c_empty_base.canonical_for_alias("sprints/Llama-3.2-1B-Instruct"), None);
}
}
27 changes: 25 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async fn transparent_proxy_handler(State(state): State<Arc<AppState>>, req: Requ
};

// Parse body as JSON
let body_json: serde_json::Value = if body_bytes.is_empty() {
let mut body_json: serde_json::Value = if body_bytes.is_empty() {
serde_json::Value::Null
} else {
match serde_json::from_slice(&body_bytes) {
Expand All @@ -193,8 +193,9 @@ async fn transparent_proxy_handler(State(state): State<Arc<AppState>>, req: Requ
.get("model")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
.map(str::to_string)
{
if !claims_ref.allows_model(model) {
if !claims_ref.allows_model(&model) {
warn!(
run_id = %claims_ref.run_id,
requested_model = %model,
Expand All @@ -209,6 +210,14 @@ async fn transparent_proxy_handler(State(state): State<Arc<AppState>>, req: Requ
)
.into_response();
}
if let Some(canonical) = claims_ref.canonical_for_alias(&model) {
if let Some(obj) = body_json.as_object_mut() {
obj.insert(
"model".to_string(),
serde_json::Value::String(canonical),
);
}
}
}
if let Err(response) = enforce_no_lora_path_override_json(&claims, &body_json) {
return response;
Expand Down Expand Up @@ -445,6 +454,9 @@ fn pin_and_check_model(

let resolved = model.as_deref().unwrap_or("");
if claims.allows_model(resolved) {
if let Some(canonical) = claims.canonical_for_alias(resolved) {
*model = Some(canonical);
}
return Ok(());
}

Expand Down Expand Up @@ -500,6 +512,9 @@ fn pin_and_check_model_string(
}

if claims.allows_model(model) {
if let Some(canonical) = claims.canonical_for_alias(model) {
*model = canonical;
}
return Ok(());
}

Expand Down Expand Up @@ -558,6 +573,14 @@ fn pin_and_check_model_json(
};

if claims.allows_model(&resolved) {
if let Some(canonical) = claims.canonical_for_alias(&resolved) {
if let Some(obj) = body.as_object_mut() {
obj.insert(
"model".to_string(),
serde_json::Value::String(canonical),
);
}
}
return Ok(());
}

Expand Down
Loading