Skip to content

Commit 61a0f38

Browse files
committed
move sso logic to postgresé
1 parent 6c023a3 commit 61a0f38

35 files changed

Lines changed: 822 additions & 252 deletions

File tree

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""add sso access
2+
3+
Revision ID: f94a77198389
4+
Revises: 5138304df5f5
5+
Create Date: 2026-06-26 12:57:42.469421
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = 'f94a77198389'
16+
down_revision: Union[str, None] = '5138304df5f5'
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
"""Upgrade schema."""
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.create_table('sso_policy_rule',
25+
sa.Column('id', sa.Integer(), nullable=False),
26+
sa.Column('type', sa.Enum('EMAIL', 'ORGANIZATION', 'ROLE', name='ssoaccessruletype'), nullable=False),
27+
sa.Column('value', sa.String(), nullable=True),
28+
sa.Column('role_id', sa.Integer(), nullable=True),
29+
sa.Column('organization_id', sa.Integer(), nullable=True),
30+
sa.Column('created', sa.DateTime(timezone=True), nullable=False),
31+
sa.Column('updated', sa.DateTime(timezone=True), nullable=False),
32+
sa.ForeignKeyConstraint(['organization_id'], ['organization.id'], ondelete='CASCADE'),
33+
sa.ForeignKeyConstraint(['role_id'], ['role.id'], ondelete='CASCADE'),
34+
sa.PrimaryKeyConstraint('id'),
35+
sa.UniqueConstraint('type', 'value', name='unique_sso_policy_type_value')
36+
)
37+
# ### end Alembic commands ###
38+
39+
40+
def downgrade() -> None:
41+
"""Downgrade schema."""
42+
# ### commands auto generated by Alembic - please adjust! ###
43+
op.drop_table('sso_policy_rule')
44+
op.execute("DROP TYPE IF EXISTS ssoaccessruletype")
45+
# ### end Alembic commands ###

api/dependencies.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from redis.asyncio import Redis
88
from sqlalchemy.ext.asyncio import AsyncSession
99

10-
from api.domain.auth import AuthSsoSessionValidator
10+
from api.domain.auth import AuthSsoSessionValidator, SsoPolicyRepository
1111
from api.domain.key import KeyEncoder, KeyRepository
1212
from api.domain.model import ModelEnvironmentalImpactsComputer, ModelTokenizer
1313
from api.domain.organization import OrganizationRepository
@@ -37,6 +37,7 @@
3737
PostgresProviderRepository,
3838
PostgresRolesRepository,
3939
PostgresRouterRepository,
40+
PostgresSsoPolicyRepository,
4041
PostgresUserRepository,
4142
)
4243
from api.infrastructure.redis import RedisProviderLoadBalancer, RedisProviderMetricsLogger, RedisRouterRateLimiter
@@ -52,6 +53,7 @@
5253
)
5354
from api.use_cases.admin.roles import CreateRoleUseCase, DeleteRoleUseCase, GetRolesUseCase, GetRoleUseCase, UpdateRoleUseCase
5455
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase
56+
from api.use_cases.admin.sso import GetSsoPolicyUseCase, UpdateSsoPolicyUseCase
5557
from api.use_cases.admin.users import CreateUserUseCase, DeleteUserUseCase, GetOneUserUseCase, GetUsersUseCase
5658
from api.use_cases.auth import AuthLoginUseCase, AuthSsoLoginUseCase
5759
from api.use_cases.health import GetHealthModelsUseCase
@@ -163,6 +165,10 @@ def _router_repository(session: AsyncSession) -> PostgresRouterRepository:
163165
return PostgresRouterRepository(postgres_session=session, app_title=configuration.settings.app_title)
164166

165167

168+
def _sso_policy_repository(session: AsyncSession) -> SsoPolicyRepository:
169+
return PostgresSsoPolicyRepository(postgres_session=session)
170+
171+
166172
def _limit_repository(session: AsyncSession) -> LimitRepository:
167173
return PostgresLimitRepository(postgres_session=session)
168174

@@ -195,11 +201,10 @@ def auth_sso_login_use_case_factory(
195201
) -> AuthSsoLoginUseCase:
196202
return AuthSsoLoginUseCase(
197203
key_repository=_key_repository(key_encoder=key_encoder, session=postgres_session),
198-
organization_repository=_organization_repository(session=postgres_session),
199204
user_repository=_user_repository(session=postgres_session),
205+
sso_policy_repository=_sso_policy_repository(session=postgres_session),
200206
auth_sso_session_validator=_auth_sso_session_validator(),
201207
auth_login_type=configuration.settings.auth_login_type,
202-
auth_sso_default_role_id=configuration.settings.auth_sso_default_role_id,
203208
auth_login_session_duration=configuration.settings.auth_login_session_duration,
204209
)
205210

@@ -234,23 +239,39 @@ def get_model_use_case_factory(postgres_session: AsyncSession = Depends(get_post
234239
return GetModelUseCase(router_repository=_router_repository(postgres_session))
235240

236241

237-
# user use cases
238-
def create_user_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateUserUseCase:
239-
return CreateUserUseCase(user_repository=_user_repository(postgres_session), user_password_encoder=_user_password_encoder())
242+
# provider use cases
243+
def create_provider_use_case_factory(
244+
postgres_session: AsyncSession = Depends(get_postgres_session),
245+
provider_client: ProviderClient = Depends(_provider_client),
246+
) -> CreateProviderUseCase:
247+
return CreateProviderUseCase(
248+
router_repository=_router_repository(postgres_session),
249+
provider_repository=_provider_repository(postgres_session),
250+
provider_gateway=_provider_gateway(provider_client=provider_client, provider_adapter_builder=HttpProviderAdapterBuilder()),
251+
)
240252

241253

242-
def get_one_user_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneUserUseCase:
243-
return GetOneUserUseCase(user_repository=_user_repository(postgres_session))
254+
def update_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> UpdateProviderUseCase:
255+
return UpdateProviderUseCase(
256+
router_repository=_router_repository(postgres_session),
257+
provider_repository=_provider_repository(postgres_session),
258+
)
244259

245260

246-
def get_users_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetUsersUseCase:
247-
return GetUsersUseCase(user_repository=_user_repository(postgres_session))
261+
def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> DeleteProviderUseCase:
262+
return DeleteProviderUseCase(
263+
provider_repository=_provider_repository(postgres_session),
264+
)
248265

249266

250-
def delete_user_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> DeleteUserUseCase:
251-
return DeleteUserUseCase(
252-
user_repository=_user_repository(postgres_session),
253-
router_repository=_router_repository(postgres_session),
267+
def get_one_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneProviderUseCase:
268+
return GetOneProviderUseCase(
269+
provider_repository=_provider_repository(postgres_session),
270+
)
271+
272+
273+
def get_providers_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetProvidersUseCase:
274+
return GetProvidersUseCase(
254275
provider_repository=_provider_repository(postgres_session),
255276
)
256277

@@ -345,38 +366,33 @@ def update_router_use_case_factory(postgres_session: AsyncSession = Depends(get_
345366
)
346367

347368

348-
# provider use cases
349-
def create_provider_use_case_factory(
350-
postgres_session: AsyncSession = Depends(get_postgres_session),
351-
provider_client: ProviderClient = Depends(_provider_client),
352-
) -> CreateProviderUseCase:
353-
return CreateProviderUseCase(
354-
router_repository=_router_repository(postgres_session),
355-
provider_repository=_provider_repository(postgres_session),
356-
provider_gateway=_provider_gateway(provider_client=provider_client, provider_adapter_builder=HttpProviderAdapterBuilder()),
369+
# sso policy use cases
370+
def get_sso_policy_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetSsoPolicyUseCase:
371+
return GetSsoPolicyUseCase(
372+
sso_policy_repository=_sso_policy_repository(session=postgres_session),
357373
)
358374

359375

360-
def update_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> UpdateProviderUseCase:
361-
return UpdateProviderUseCase(
362-
router_repository=_router_repository(postgres_session),
363-
provider_repository=_provider_repository(postgres_session),
364-
)
376+
def update_sso_policy_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> UpdateSsoPolicyUseCase:
377+
return UpdateSsoPolicyUseCase(sso_policy_repository=_sso_policy_repository(session=postgres_session))
365378

366379

367-
def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> DeleteProviderUseCase:
368-
return DeleteProviderUseCase(
369-
provider_repository=_provider_repository(postgres_session),
370-
)
380+
# user use cases
381+
def create_user_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateUserUseCase:
382+
return CreateUserUseCase(user_repository=_user_repository(postgres_session), user_password_encoder=_user_password_encoder())
371383

372384

373-
def get_one_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneProviderUseCase:
374-
return GetOneProviderUseCase(
375-
provider_repository=_provider_repository(postgres_session),
376-
)
385+
def get_one_user_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneUserUseCase:
386+
return GetOneUserUseCase(user_repository=_user_repository(postgres_session))
377387

378388

379-
def get_providers_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetProvidersUseCase:
380-
return GetProvidersUseCase(
389+
def get_users_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetUsersUseCase:
390+
return GetUsersUseCase(user_repository=_user_repository(postgres_session))
391+
392+
393+
def delete_user_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> DeleteUserUseCase:
394+
return DeleteUserUseCase(
395+
user_repository=_user_repository(postgres_session),
396+
router_repository=_router_repository(postgres_session),
381397
provider_repository=_provider_repository(postgres_session),
382398
)

api/domain/auth/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from ._authssosessionvalidator import AuthSsoSessionValidator, SsoSessionClaims
1+
from ._authssosessionvalidator import AuthSsoSessionValidator
2+
from ._ssopolicyrepository import SsoPolicyRepository
3+
from .entities import SsoPolicy
24

3-
__all__ = ["AuthSsoSessionValidator", "SsoSessionClaims"]
5+
__all__ = ["AuthSsoSessionValidator", "SsoPolicy", "SsoPolicyRepository"]
Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from abc import ABC, abstractmethod
2-
from dataclasses import dataclass
32

4-
from api.domain.auth.errors import InvalidOidcTokenError, SsoProviderNotAvailableError
5-
6-
7-
@dataclass(frozen=True)
8-
class SsoSessionClaims:
9-
email: str
10-
user: str | None = None
3+
from api.domain.auth.errors import SsoInvalidSessionError, SsoProviderNotAvailableError
114

125

136
class AuthSsoSessionValidator(ABC):
147
@abstractmethod
15-
async def validate_session(self, session_cookie: str) -> SsoSessionClaims | InvalidOidcTokenError | SsoProviderNotAvailableError:
8+
async def validate_session(self, session_cookie: str) -> str | SsoInvalidSessionError | SsoProviderNotAvailableError:
9+
"""
10+
Validate the session cookie by calling the authentication service and return the email of the user.
11+
"""
1612
pass
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
3+
from api.domain.auth.entities import NewSsoPolicy, SsoPolicy
4+
from api.domain.auth.errors import SsoPolicyRuleAlreadyExistsError
5+
from api.domain.organization.errors import OrganizationNotFoundError
6+
from api.domain.role.errors import RoleNotFoundError
7+
8+
9+
class SsoPolicyRepository(ABC):
10+
@abstractmethod
11+
async def get_policy(self) -> SsoPolicy:
12+
pass
13+
14+
@abstractmethod
15+
async def replace_policy(
16+
self, policy: NewSsoPolicy
17+
) -> SsoPolicy | RoleNotFoundError | OrganizationNotFoundError | SsoPolicyRuleAlreadyExistsError:
18+
pass

api/domain/auth/entities.py

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from datetime import datetime
12
from enum import StrEnum
3+
from typing import Literal
24

3-
from pydantic import BaseModel, ConfigDict, Field
5+
from pydantic import BaseModel, Field
46

57

68
class SsoAccessRuleType(StrEnum):
@@ -9,25 +11,91 @@ class SsoAccessRuleType(StrEnum):
911
ROLE = "role"
1012

1113

12-
class SsoAccessRule(BaseModel):
13-
model_config = ConfigDict(from_attributes=True)
14+
class NewSsoPolicyEmailRule(BaseModel):
15+
type: Literal[SsoAccessRuleType.EMAIL]
16+
value: str
17+
18+
19+
class NewSsoPolicyOrganizationRule(BaseModel):
20+
type: Literal[SsoAccessRuleType.ORGANIZATION]
21+
value: str | None = None
22+
organization_id: int
23+
24+
25+
class NewSsoPolicyRoleRule(BaseModel):
26+
type: Literal[SsoAccessRuleType.ROLE]
27+
value: str | None = None
28+
role_id: int
1429

15-
id: int | None = None
30+
31+
type NewSsoPolicyRule = NewSsoPolicyEmailRule | NewSsoPolicyOrganizationRule | NewSsoPolicyRoleRule
32+
33+
34+
class NewSsoPolicy(BaseModel):
35+
rules: list[NewSsoPolicyRule] = Field(default_factory=list)
36+
37+
38+
class SsoPolicyRuleBase(BaseModel):
39+
id: int
1640
type: SsoAccessRuleType
41+
created: datetime
42+
updated: datetime
43+
44+
45+
class SsoPolicyEmailRule(SsoPolicyRuleBase):
46+
type: Literal[SsoAccessRuleType.EMAIL]
1747
value: str
1848

1949

20-
class SsoRoleMapping(BaseModel):
21-
model_config = ConfigDict(from_attributes=True)
50+
class SsoPolicyOrganizationRule(SsoPolicyRuleBase):
51+
type: Literal[SsoAccessRuleType.ORGANIZATION]
52+
value: str | None = None
53+
organization_id: int
54+
2255

23-
id: int | None = None
24-
organization_name: str
25-
oidc_role_name: str
56+
class SsoPolicyRoleRule(SsoPolicyRuleBase):
57+
type: Literal[SsoAccessRuleType.ROLE]
58+
value: str | None = None
2659
role_id: int
2760

2861

62+
type SsoPolicyRule = SsoPolicyEmailRule | SsoPolicyOrganizationRule | SsoPolicyRoleRule
63+
64+
2965
class SsoPolicy(BaseModel):
30-
allowed_emails: list[str] = Field(default_factory=list)
31-
allowed_organizations: list[str] = Field(default_factory=list)
32-
allowed_roles: list[str] = Field(default_factory=list)
33-
role_mappings: list[SsoRoleMapping] = Field(default_factory=list)
66+
rules: list[SsoPolicyRule] = Field(default_factory=list)
67+
68+
def is_allowed(self, email: str, organization: str | None, roles: list[str]) -> bool:
69+
checks: list[bool] = []
70+
allowed_emails = [rule.value for rule in self.rules if rule.type == SsoAccessRuleType.EMAIL]
71+
allowed_organizations = [rule.value for rule in self.rules if rule.type == SsoAccessRuleType.ORGANIZATION and rule.value is not None]
72+
allowed_roles = [rule.value for rule in self.rules if rule.type == SsoAccessRuleType.ROLE and rule.value is not None]
73+
74+
if allowed_roles:
75+
checks.append(any(role in allowed_roles for role in roles))
76+
77+
if allowed_emails:
78+
checks.append(any(email.endswith(allowed_email) for allowed_email in allowed_emails))
79+
80+
if allowed_organizations:
81+
checks.append(organization is not None and organization in allowed_organizations)
82+
83+
return True if not checks else any(checks)
84+
85+
def get_matching_organization_rule(self, organization: str | None) -> SsoPolicyOrganizationRule | None:
86+
for rule in self.rules:
87+
if rule.type != SsoAccessRuleType.ORGANIZATION:
88+
continue
89+
if organization is None and rule.value is None: # default organization rule
90+
return rule
91+
if organization is not None and organization == rule.value:
92+
return rule
93+
94+
def get_matching_role_rule(self, roles: list[str]) -> SsoPolicyRoleRule | None:
95+
for rule in self.rules:
96+
if rule.type != SsoAccessRuleType.ROLE:
97+
continue
98+
if not roles and rule.value is None: # default role rule
99+
return rule
100+
if roles and rule.value in roles:
101+
return rule

0 commit comments

Comments
 (0)