diff --git a/openleadr-vtn/src/jwt.rs b/openleadr-vtn/src/jwt.rs index a287ea8..e3d8638 100644 --- a/openleadr-vtn/src/jwt.rs +++ b/openleadr-vtn/src/jwt.rs @@ -19,7 +19,7 @@ use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation}; use openleadr_wire::ven::VenId; use tracing::trace; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use std::env; pub struct JwtManager { @@ -75,6 +75,21 @@ pub enum AlgorithmDef { EdDSA, } +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[cfg_attr(test, derive(PartialOrd, Ord))] +pub enum Scope { + #[serde(rename = "read_all")] + ReadAll, + #[serde(rename = "write_programs")] + WritePrograms, + #[serde(rename = "write_reports")] + WriteReports, + #[serde(rename = "write_events")] + WriteEvents, + #[serde(rename = "write_vens")] + WriteVens, +} + #[derive(Debug, PartialEq, Eq, Clone, Deserialize)] struct RsaKey { kty: OAuthKeyType, @@ -129,6 +144,116 @@ pub(crate) struct Claims { pub(crate) roles: Vec, } +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +struct InitialClaims { + exp: usize, + nbf: usize, + sub: String, + roles: Option>, + scope: Option, +} + +#[derive(Clone, Debug, serde::Serialize)] +struct Scopes { + scopes: Vec, +} + +impl<'de> Deserialize<'de> for Scopes { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s: &str = Deserialize::deserialize(deserializer)?; + let parts = s.split(" "); + + let mut scopes: Vec = Vec::new(); + for part in parts { + match part { + "read_all" => scopes.push(Scope::ReadAll), + "write_vens" => scopes.push(Scope::WriteVens), + "write_programs" => scopes.push(Scope::WritePrograms), + "write_events" => scopes.push(Scope::WriteEvents), + "write_reports" => scopes.push(Scope::WriteReports), + _ => { + trace!("Unknown scope encountered: {:?}", part); + } + } + } + + Ok(Scopes { scopes }) + } +} + +impl InitialClaims { + fn map_scope_to_roles(&self) -> Vec { + let mut roles: Vec = Vec::new(); + match &self.scope { + None => roles, + Some(s) => { + // If ReadAll && WriteVens -> VenManager + if s.scopes.iter().any(|r| r == &Scope::ReadAll) + && s.scopes.iter().any(|r| r == &Scope::WriteVens) + { + roles.push(AuthRole::VenManager); + } + + // If ReadAll && WriteReports -> Ven("anonymous") + if s.scopes.iter().any(|r| r == &Scope::ReadAll) + && s.scopes.iter().any(|r| r == &Scope::WriteReports) + { + roles.push(AuthRole::VEN(VenId::new("anonymous").unwrap())); + } + + // If ReadAll && WritePrograms && WriteEvents -> AnyBusiness + if s.scopes.iter().any(|r| r == &Scope::ReadAll) + && s.scopes.iter().any(|r| r == &Scope::WritePrograms) + && s.scopes.iter().any(|r| r == &Scope::WriteEvents) + { + roles.push(AuthRole::AnyBusiness); + } + + roles + } + } + } +} + +impl TryFrom for Claims { + type Error = ResponseOAuthError; + + fn try_from(initial: InitialClaims) -> Result { + match initial.roles { + // when roles are empty, check scopes + // and map these to our roles + None => { + if initial.scope.is_none() { + return Err(OAuthError::new(OAuthErrorType::InvalidGrant) + .with_description( + "Token must contain valid roles or a valid scope".to_string(), + ) + .into()); + } + + Ok(Claims { + roles: initial.map_scope_to_roles(), + exp: initial.exp, + nbf: initial.nbf, + sub: initial.sub, + }) + } + + // otherwise ignore scope and use + // the given roles + Some(roles) => Ok(Claims { + roles, + exp: initial.exp, + nbf: initial.nbf, + sub: initial.sub, + }), + } + } +} + #[cfg(test)] #[cfg(feature = "live-db-test")] impl Claims { @@ -278,15 +403,15 @@ impl JwtManager { match &self.decoding_key { Some(key) => { - let token_data = jsonwebtoken::decode::(token, key, &self.validation)?; - Ok(token_data.claims) + let token_data = + jsonwebtoken::decode::(token, key, &self.validation)?; + token_data.claims.try_into() } None => { // Fetch server keys let keys = self.fetch_keys().await; - // Try multiple keys; if fail then try to fetch new keys if keys.is_empty() { return Err(OAuthError::new(OAuthErrorType::NoAvailableKeys) .with_description( @@ -296,8 +421,11 @@ impl JwtManager { } for decoding_key in keys { - let signature_data = - jsonwebtoken::decode::(token, &decoding_key, &signature_validation); + let signature_data = jsonwebtoken::decode::( + token, + &decoding_key, + &signature_validation, + ); match signature_data { // If signature is correct, validate claims @@ -500,3 +628,164 @@ where Ok(VenManagerUser(user)) } } + +#[cfg(test)] +mod tests { + + use crate::jwt::{AuthRole, Claims, InitialClaims, Scope, Scopes, VenId}; + + #[test] + fn test_no_roles_no_scope_into_claims() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: None, + scope: None, + }; + + let claims: Result = initial.try_into(); + assert!(claims.is_err()); + } + + #[test] + fn test_initial_roles_into_claims() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: Some(vec![AuthRole::AnyBusiness, AuthRole::VenManager]), + scope: None, + }; + + let claims: Result = initial.clone().try_into(); + assert!(claims.is_ok()); + + let values = claims.unwrap(); + assert_eq!(values.exp, initial.exp); + assert_eq!(values.nbf, initial.nbf); + assert_eq!(values.sub, initial.sub); + assert_eq!( + values.roles, + vec![AuthRole::AnyBusiness, AuthRole::VenManager] + ); + } + + #[test] + fn test_scope_ignored_if_roles_present() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: Some(vec![AuthRole::AnyBusiness]), + scope: Some(Scopes { + scopes: vec![Scope::ReadAll, Scope::WriteVens], + }), + }; + + let claims: Result = initial.clone().try_into(); + assert!(claims.is_ok()); + + let values = claims.unwrap(); + assert_eq!(values.roles, vec![AuthRole::AnyBusiness]); + } + + #[test] + fn test_scope_into_any_business_role() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: None, + scope: Some(Scopes { + scopes: vec![Scope::ReadAll, Scope::WritePrograms, Scope::WriteEvents], + }), + }; + + let claims: Result = initial.clone().try_into(); + assert!(claims.is_ok()); + + let values = claims.unwrap(); + assert_eq!(values.exp, initial.exp); + assert_eq!(values.nbf, initial.nbf); + assert_eq!(values.sub, initial.sub); + assert_eq!(values.roles, vec![AuthRole::AnyBusiness]); + } + + #[test] + fn test_scope_into_ven_manager_role() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: None, + scope: Some(Scopes { + scopes: vec![Scope::ReadAll, Scope::WriteVens], + }), + }; + + let claims: Result = initial.clone().try_into(); + assert!(claims.is_ok()); + + let values = claims.unwrap(); + assert_eq!(values.exp, initial.exp); + assert_eq!(values.nbf, initial.nbf); + assert_eq!(values.sub, initial.sub); + assert_eq!(values.roles, vec![AuthRole::VenManager]); + } + + #[test] + fn test_scope_into_anonymous_ven_role() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: None, + scope: Some(Scopes { + scopes: vec![Scope::ReadAll, Scope::WriteReports], + }), + }; + + let claims: Result = initial.clone().try_into(); + assert!(claims.is_ok()); + + let values = claims.unwrap(); + assert_eq!(values.exp, initial.exp); + assert_eq!(values.nbf, initial.nbf); + assert_eq!(values.sub, initial.sub); + assert_eq!( + values.roles, + vec![AuthRole::VEN(VenId::new("anonymous").unwrap())] + ); + } + + #[test] + fn test_scope_into_multiple_roles() { + let initial = InitialClaims { + exp: 10, + nbf: 10, + sub: "test".to_string(), + roles: None, + scope: Some(Scopes { + scopes: vec![ + Scope::ReadAll, + Scope::WriteVens, + Scope::WritePrograms, + Scope::WriteEvents, + ], + }), + }; + + let claims: Result = initial.clone().try_into(); + assert!(claims.is_ok()); + + let values = claims.unwrap(); + assert_eq!(values.exp, initial.exp); + assert_eq!(values.nbf, initial.nbf); + assert_eq!(values.sub, initial.sub); + assert_eq!( + values.roles, + vec![AuthRole::VenManager, AuthRole::AnyBusiness] + ); + } +} diff --git a/openleadr-wire/src/oauth.rs b/openleadr-wire/src/oauth.rs index e5b4c9d..9480bec 100644 --- a/openleadr-wire/src/oauth.rs +++ b/openleadr-wire/src/oauth.rs @@ -4,7 +4,7 @@ pub enum OAuthErrorType { OAuthNotEnabled, InvalidRequest, InvalidClient, - // InvalidGrant, + InvalidGrant, // UnauthorizedClient, UnsupportedGrantType, // InvalidScope,