Skip to content

Commit 456b0eb

Browse files
committed
fix(authorization): Fix authorization middleware to retry on 403
1 parent f4f7a48 commit 456b0eb

File tree

3 files changed

+76
-16
lines changed

3 files changed

+76
-16
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "bi"
3-
version = "0.0.28"
3+
version = "0.0.29"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

src/beyond_identity/api/common/middleware/authorization.rs

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ use crate::common::database::Database;
66
use crate::common::error::BiError;
77

88
use http::Extensions;
9+
use http::StatusCode;
910
use reqwest::{Request, Response};
1011
use reqwest_middleware::ClientWithMiddleware as Client;
11-
use reqwest_middleware::{ClientWithMiddleware, Middleware, Next, Result as MiddlewareResult};
12+
use reqwest_middleware::{
13+
ClientWithMiddleware, Error, Middleware, Next, Result as MiddlewareResult,
14+
};
1215
use serde::{Deserialize, Serialize};
1316
use std::time::{SystemTime, UNIX_EPOCH};
1417

@@ -43,16 +46,61 @@ impl Middleware for AuthorizationMiddleware {
4346
extensions: &mut Extensions,
4447
next: Next<'_>,
4548
) -> MiddlewareResult<Response> {
46-
let token = token(&self.db, &self.client, &self.tenant, &self.realm)
49+
let fetched_token = token(&self.db, &self.client, &self.tenant, &self.realm)
4750
.await
4851
.map_err(|e| reqwest_middleware::Error::Middleware(e.into()))?;
4952

5053
req.headers_mut().insert(
5154
reqwest::header::AUTHORIZATION,
52-
format!("Bearer {}", token).parse().unwrap(),
55+
format!("Bearer {}", fetched_token).parse().unwrap(),
5356
);
5457

55-
next.run(req, extensions).await
58+
let mut response = next
59+
.clone()
60+
.run(req.try_clone().unwrap(), extensions)
61+
.await?;
62+
63+
if response.status() == StatusCode::FORBIDDEN {
64+
log::debug!("Received 403 Forbidden, attempting to refresh token and retry request.");
65+
66+
// Invalidate the current token
67+
if let (Some(tenant), Some(realm)) = (&self.tenant, &self.realm) {
68+
self.db
69+
.delete_token(&tenant.id, &realm.id)
70+
.await
71+
.map_err(|e| {
72+
reqwest_middleware::Error::Middleware(
73+
BiError::StringError(e.to_string()).into(),
74+
)
75+
})?;
76+
}
77+
78+
// Fetch a new token
79+
let new_token = token(&self.db, &self.client, &self.tenant, &self.realm)
80+
.await
81+
.map_err(|e| reqwest_middleware::Error::Middleware(e.into()))?;
82+
83+
// Retry the request with the new token
84+
let mut new_req = req.try_clone().ok_or_else(|| {
85+
Error::Middleware(anyhow::anyhow!(
86+
"Request object is not clonable. Are you passing a streaming body?".to_string()
87+
))
88+
})?;
89+
new_req.headers_mut().insert(
90+
reqwest::header::AUTHORIZATION,
91+
format!("Bearer {}", new_token).parse().unwrap(),
92+
);
93+
94+
response = next.run(new_req, extensions).await?;
95+
96+
if response.status() == StatusCode::FORBIDDEN {
97+
log::error!(
98+
"Received 403 Forbidden after refreshing the token. This may indicate invalid credentials, insufficient permissions, or a server-side issue. Check the token, request headers, and server configuration."
99+
);
100+
}
101+
}
102+
103+
Ok(response)
56104
}
57105
}
58106

@@ -84,17 +132,11 @@ async fn token(
84132
.unwrap()
85133
.as_secs();
86134

87-
if token.expires_at >= 0 && (token.expires_at as u64) > current_time {
88-
log::debug!("Using stored bearer token for all requests");
89-
return Ok(token.access_token);
90-
}
91-
}
92-
93-
if let Some(token) = db.get_token(&tenant.id, &realm.id).await? {
94-
let current_time = SystemTime::now()
95-
.duration_since(UNIX_EPOCH)
96-
.unwrap()
97-
.as_secs();
135+
log::debug!(
136+
"Current time: {}, stored token expires at: {}",
137+
current_time,
138+
token.expires_at
139+
);
98140

99141
if token.expires_at >= 0 && (token.expires_at as u64) > current_time {
100142
log::debug!("Using stored bearer token for all requests");
@@ -143,6 +185,12 @@ async fn token(
143185
.as_secs();
144186
let expires_at = current_time + token_response.expires_in;
145187

188+
log::debug!(
189+
"Token expires in: {} seconds, setting expires_at to: {}",
190+
token_response.expires_in,
191+
expires_at
192+
);
193+
146194
let token = Token {
147195
access_token: token_response.access_token,
148196
expires_at: expires_at as i64,

src/common/database/database.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,18 @@ impl Database {
258258
Ok(())
259259
}
260260

261+
// Delete a token by tenant_id and realm_id
262+
pub async fn delete_token(&self, tenant_id: &str, realm_id: &str) -> Result<(), BiError> {
263+
query("DELETE FROM tokens WHERE tenant_id = ? AND realm_id = ?")
264+
.bind(tenant_id)
265+
.bind(realm_id)
266+
.execute(&self.pool)
267+
.await
268+
.map_err(|e| BiError::StringError(e.to_string()))?;
269+
270+
Ok(())
271+
}
272+
261273
// Get okta config from db
262274
pub async fn get_okta_config(&self) -> Result<Option<OktaConfig>, BiError> {
263275
self.get_config(OKTA_CONFIG_KEY).await

0 commit comments

Comments
 (0)