diff --git a/packages/cubejs-api-gateway/src/sql-server.ts b/packages/cubejs-api-gateway/src/sql-server.ts index a8faa00cd8ae9..d09cf7437d5c9 100644 --- a/packages/cubejs-api-gateway/src/sql-server.ts +++ b/packages/cubejs-api-gateway/src/sql-server.ts @@ -31,6 +31,11 @@ export type SQLServerConstructorOptions = { gatewayPort?: number, }; +export type SqlAuthServiceAuthenticateRequest = { + protocol: string; + method: string; +}; + export class SQLServer { protected sqlInterfaceInstance: SqlInterfaceInstance | null = null; @@ -88,10 +93,14 @@ export class SQLServer { let { securityContext } = session; if (request.meta.changeUser && request.meta.changeUser !== session.user) { + const sqlAuthRequest: SqlAuthServiceAuthenticateRequest = { + protocol: request.meta.protocol, + method: 'password', + }; const canSwitch = session.superuser || await canSwitchSqlUser(session.user, request.meta.changeUser); if (canSwitch) { userForContext = request.meta.changeUser; - const current = await checkSqlAuth(request, userForContext, null); + const current = await checkSqlAuth({ ...request, ...sqlAuthRequest }, userForContext, null); securityContext = current.securityContext; } else { throw new Error( diff --git a/packages/cubejs-backend-native/src/auth.rs b/packages/cubejs-backend-native/src/auth.rs index ff297850bcda1..5a495588c6125 100644 --- a/packages/cubejs-backend-native/src/auth.rs +++ b/packages/cubejs-backend-native/src/auth.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use cubesql::{ di_service, - sql::{AuthContext, AuthenticateResponse, SqlAuthService}, + sql::{AuthContext, AuthenticateResponse, SqlAuthService, SqlAuthServiceAuthenticateRequest}, transport::LoadRequestMeta, CubeError, }; @@ -50,9 +50,28 @@ pub struct TransportRequest { pub meta: Option, } +#[derive(Debug, Serialize)] +pub struct TransportAuthRequest { + pub id: String, + pub meta: Option, + pub protocol: String, + pub method: String, +} + +impl From<(TransportRequest, SqlAuthServiceAuthenticateRequest)> for TransportAuthRequest { + fn from((t, a): (TransportRequest, SqlAuthServiceAuthenticateRequest)) -> Self { + Self { + id: t.id, + meta: t.meta, + protocol: a.protocol, + method: a.method, + } + } +} + #[derive(Debug, Serialize)] struct CheckSQLAuthTransportRequest { - request: TransportRequest, + request: TransportAuthRequest, user: Option, password: Option, } @@ -92,6 +111,7 @@ impl AuthContext for NativeSQLAuthContext { impl SqlAuthService for NodeBridgeAuthService { async fn authenticate( &self, + request: SqlAuthServiceAuthenticateRequest, user: Option, password: Option, ) -> Result { @@ -100,9 +120,11 @@ impl SqlAuthService for NodeBridgeAuthService { let request_id = Uuid::new_v4().to_string(); let extra = serde_json::to_string(&CheckSQLAuthTransportRequest { - request: TransportRequest { + request: TransportAuthRequest { id: format!("{}-span-1", request_id), meta: None, + protocol: request.protocol, + method: request.method, }, user: user.clone(), password: password.clone(), diff --git a/packages/cubejs-backend-native/test/sql.test.ts b/packages/cubejs-backend-native/test/sql.test.ts index 3a94f8713509b..4d99c43dedc61 100644 --- a/packages/cubejs-backend-native/test/sql.test.ts +++ b/packages/cubejs-backend-native/test/sql.test.ts @@ -200,6 +200,8 @@ describe('SQLInterface', () => { request: { id: expect.any(String), meta: null, + method: expect.any(String), + protocol: expect.any(String), }, user: user || null, password: @@ -257,6 +259,8 @@ describe('SQLInterface', () => { request: { id: expect.any(String), meta: null, + method: expect.any(String), + protocol: expect.any(String), }, user: 'allowed_user', password: 'password_for_allowed_user', diff --git a/packages/cubejs-testing/birdbox-fixtures/postgresql/single/sqlapi.js b/packages/cubejs-testing/birdbox-fixtures/postgresql/single/sqlapi.js index 28e9a2139ae09..adb41eb3bca06 100644 --- a/packages/cubejs-testing/birdbox-fixtures/postgresql/single/sqlapi.js +++ b/packages/cubejs-testing/birdbox-fixtures/postgresql/single/sqlapi.js @@ -13,6 +13,15 @@ module.exports = { return query; }, checkSqlAuth: async (req, user, password) => { + if (!req) { + throw new Error('Request is not defined'); + } + + const missing = ['protocol', 'method'].filter(key => !(key in req)); + if (missing.length) { + throw new Error(`Request object is missing required field(s): ${missing.join(', ')}`); + } + if (user === 'admin') { if (password && password !== 'admin_password') { throw new Error(`Password doesn't match for ${user}`); diff --git a/rust/cubesql/cubesql/src/compile/router.rs b/rust/cubesql/cubesql/src/compile/router.rs index 5aba3b8447967..9e013c202dbaf 100644 --- a/rust/cubesql/cubesql/src/compile/router.rs +++ b/rust/cubesql/cubesql/src/compile/router.rs @@ -12,6 +12,7 @@ use crate::{ DatabaseVariable, DatabaseVariablesToUpdate, }, sql::{ + auth_service::SqlAuthServiceAuthenticateRequest, dataframe, statement::{ ApproximateCountDistinctVisitor, CastReplacer, DateTokenNormalizeReplacer, @@ -447,12 +448,16 @@ impl QueryRouter { })? { self.state.set_user(Some(to_user.clone())); + let sql_auth_request = SqlAuthServiceAuthenticateRequest { + protocol: "postgres".to_string(), + method: "password".to_string(), + }; let authenticate_response = self .session_manager .server .auth // TODO do we want to send actual password here? - .authenticate(Some(to_user.clone()), None) + .authenticate(sql_auth_request, Some(to_user.clone()), None) .await .map_err(|e| { CompilationError::internal(format!("Error calling authenticate: {}", e)) @@ -562,11 +567,15 @@ impl QueryRouter { async fn reauthenticate_if_needed(&self) -> CompilationResult<()> { if self.state.is_auth_context_expired() { + let sql_auth_request = SqlAuthServiceAuthenticateRequest { + protocol: "postgres".to_string(), + method: "password".to_string(), + }; let authenticate_response = self .session_manager .server .auth - .authenticate(self.state.user(), None) + .authenticate(sql_auth_request, self.state.user(), None) .await .map_err(|e| { CompilationError::fatal(format!( diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 9a9315dc54223..0536ba812f1a1 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -8,9 +8,10 @@ use crate::{ }, config::{ConfigObj, ConfigObjImpl}, sql::{ - compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe, - pg_auth_service::PostgresAuthServiceDefaultImpl, AuthContextRef, AuthenticateResponse, - HttpAuthContext, ServerManager, Session, SessionManager, SqlAuthService, + auth_service::SqlAuthServiceAuthenticateRequest, compiler_cache::CompilerCacheImpl, + dataframe::batches_to_dataframe, pg_auth_service::PostgresAuthServiceDefaultImpl, + AuthContextRef, AuthenticateResponse, HttpAuthContext, ServerManager, Session, + SessionManager, SqlAuthService, }, transport::{ CubeMeta, CubeMetaDimension, CubeMetaJoin, CubeMetaMeasure, CubeMetaSegment, @@ -750,6 +751,7 @@ pub fn get_test_auth() -> Arc { impl SqlAuthService for TestSqlAuth { async fn authenticate( &self, + _request: SqlAuthServiceAuthenticateRequest, _user: Option, password: Option, ) -> Result { diff --git a/rust/cubesql/cubesql/src/sql/auth_service.rs b/rust/cubesql/cubesql/src/sql/auth_service.rs index f29002f862d24..3550716db9759 100644 --- a/rust/cubesql/cubesql/src/sql/auth_service.rs +++ b/rust/cubesql/cubesql/src/sql/auth_service.rs @@ -2,6 +2,7 @@ use std::{any::Any, env, fmt::Debug, sync::Arc}; use crate::CubeError; use async_trait::async_trait; +use serde::{Deserialize, Serialize}; use serde_json::Value; // We cannot use generic here. It's why there is this trait @@ -43,10 +44,17 @@ pub struct AuthenticateResponse { pub skip_password_check: bool, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SqlAuthServiceAuthenticateRequest { + pub protocol: String, + pub method: String, +} + #[async_trait] pub trait SqlAuthService: Send + Sync + Debug { async fn authenticate( &self, + request: SqlAuthServiceAuthenticateRequest, user: Option, password: Option, ) -> Result; @@ -61,6 +69,7 @@ crate::di_service!(SqlAuthDefaultImpl, [SqlAuthService]); impl SqlAuthService for SqlAuthDefaultImpl { async fn authenticate( &self, + _request: SqlAuthServiceAuthenticateRequest, _user: Option, password: Option, ) -> Result { diff --git a/rust/cubesql/cubesql/src/sql/mod.rs b/rust/cubesql/cubesql/src/sql/mod.rs index 776b13db15e78..6cd813d0f90f4 100644 --- a/rust/cubesql/cubesql/src/sql/mod.rs +++ b/rust/cubesql/cubesql/src/sql/mod.rs @@ -13,7 +13,7 @@ pub(crate) mod types; // Public API pub use auth_service::{ AuthContext, AuthContextRef, AuthenticateResponse, HttpAuthContext, SqlAuthDefaultImpl, - SqlAuthService, + SqlAuthService, SqlAuthServiceAuthenticateRequest, }; pub use database_variables::postgres::session_vars::CUBESQL_PENALIZE_POST_PROCESSING_VAR; pub use postgres::*; diff --git a/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs b/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs index 50698f18bb938..240046fd50da1 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, fmt::Debug, sync::Arc}; use async_trait::async_trait; use crate::{ + sql::auth_service::SqlAuthServiceAuthenticateRequest, sql::{AuthContextRef, SqlAuthService}, CubeError, }; @@ -74,8 +75,16 @@ impl PostgresAuthService for PostgresAuthServiceDefaultImpl { } let user = parameters.get("user").unwrap().clone(); + let sql_auth_request = SqlAuthServiceAuthenticateRequest { + protocol: "postgres".to_string(), + method: "password".to_string(), + }; let authenticate_response = service - .authenticate(Some(user.clone()), Some(password_message.password.clone())) + .authenticate( + sql_auth_request, + Some(user.clone()), + Some(password_message.password.clone()), + ) .await; let auth_fail = || {