From 9c91ee2d72c91e9f48be31779cb8f179b488813f Mon Sep 17 00:00:00 2001 From: JerrySentry <142266253+JerrySentry@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:39:21 -0500 Subject: [PATCH] Clean up user's expired login sessions (#1113) --- codecov_auth/tests/unit/views/test_base.py | 126 ++++++++++++++++++++- codecov_auth/views/base.py | 64 ++++++++--- 2 files changed, 172 insertions(+), 18 deletions(-) diff --git a/codecov_auth/tests/unit/views/test_base.py b/codecov_auth/tests/unit/views/test_base.py index 60e2333b3e..42f9dc8e08 100644 --- a/codecov_auth/tests/unit/views/test_base.py +++ b/codecov_auth/tests/unit/views/test_base.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -8,10 +8,15 @@ from django.http import HttpResponse from django.test import RequestFactory, TestCase, override_settings from freezegun import freeze_time -from shared.django_apps.codecov_auth.tests.factories import OwnerFactory, UserFactory +from shared.django_apps.codecov_auth.tests.factories import ( + OwnerFactory, + SessionFactory, + UserFactory, +) from shared.license import LicenseInformation -from codecov_auth.models import Owner, OwnerProfile +from codecov_auth.models import DjangoSession, Owner, OwnerProfile, Session +from codecov_auth.tests.factories import DjangoSessionFactory from codecov_auth.views.base import LoginMixin, StateMixin @@ -729,3 +734,118 @@ def test_login_authenticated_with_claimed_owner(self): # does not re-claim owner assert owner.user is not None assert owner.user != user + + @patch("services.refresh.RefreshService.trigger_refresh", lambda *args: None) + def test_login_owner_with_expired_login_session(self): + user = UserFactory() + owner = OwnerFactory(service="github", user=user) + + another_user = UserFactory() + another_owner = OwnerFactory(service="github", user=another_user) + + now = datetime.now(timezone.utc) + + # Create a session that will be deleted + to_be_deleted_1 = SessionFactory( + owner=owner, + type="login", + name="to_be_deleted", + lastseen="2021-01-01T00:00:00+00:00", + login_session=DjangoSessionFactory(expire_date=now - timedelta(days=1)), + ) + to_be_deleted_1_session_key = to_be_deleted_1.login_session.session_key + + # Create a session that will not be deleted because its not a login session + to_be_kept_1 = SessionFactory( + owner=owner, + type="api", + name="to_be_kept", + lastseen="2021-01-01T00:00:00+00:00", + login_session=DjangoSessionFactory(expire_date=now + timedelta(days=1)), + ) + + # Create a session that will not be deleted because it's not expired + to_be_kept_2 = SessionFactory( + owner=owner, + type="login", + name="to_be_kept", + lastseen="2021-01-01T00:00:00+00:00", + login_session=DjangoSessionFactory(expire_date=now + timedelta(days=1)), + ) + + # Create a session that will not be deleted because it's not the owner's session + to_be_kept_3 = SessionFactory( + owner=another_owner, + type="login", + name="to_be_kept", + lastseen="2021-01-01T00:00:00+00:00", + login_session=DjangoSessionFactory(expire_date=now - timedelta(seconds=1)), + ) + + assert ( + len(DjangoSession.objects.filter(session_key=to_be_deleted_1_session_key)) + == 1 + ) + assert ( + len( + DjangoSession.objects.filter( + session_key=to_be_kept_1.login_session.session_key + ) + ) + == 1 + ) + assert ( + len( + DjangoSession.objects.filter( + session_key=to_be_kept_2.login_session.session_key + ) + ) + == 1 + ) + assert ( + len( + DjangoSession.objects.filter( + session_key=to_be_kept_3.login_session.session_key + ) + ) + == 1 + ) + + self.request.user = user + self.mixin_instance.login_owner(owner, self.request, HttpResponse()) + owner.refresh_from_db() + + new_login_session = Session.objects.filter(name=None) + + assert len(new_login_session) == 1 + assert len(Session.objects.filter(name="to_be_deleted").all()) == 0 + assert len(Session.objects.filter(name="to_be_kept").all()) == 3 + + assert ( + len(DjangoSession.objects.filter(session_key=to_be_deleted_1_session_key)) + == 0 + ) + assert ( + len( + DjangoSession.objects.filter( + session_key=to_be_kept_1.login_session.session_key + ) + ) + == 1 + ) + assert ( + len( + DjangoSession.objects.filter( + session_key=to_be_kept_2.login_session.session_key + ) + ) + == 1 + ) + assert ( + len( + DjangoSession.objects.filter( + session_key=to_be_kept_3.login_session.session_key + ) + ) + == 1 + ) diff --git a/codecov_auth/views/base.py b/codecov_auth/views/base.py index 49497e0572..753f74f1ca 100644 --- a/codecov_auth/views/base.py +++ b/codecov_auth/views/base.py @@ -2,15 +2,18 @@ import re import uuid from functools import reduce +from typing import Any from urllib.parse import parse_qs, urlencode, urlparse from django.conf import settings from django.contrib.auth import login, logout from django.contrib.sessions.models import Session as DjangoSession from django.core.exceptions import PermissionDenied +from django.db import transaction from django.http.request import HttpRequest from django.http.response import HttpResponse from django.utils import timezone +from django.utils.timezone import now from shared.encryption.token import encode_token from shared.license import LICENSE_ERRORS_MESSAGES, get_current_license @@ -59,7 +62,7 @@ class StateMixin(object): """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.redis = get_redis_connection() super().__init__(*args, **kwargs) @@ -69,7 +72,7 @@ def _session_key(self) -> str: def _get_key_redis(self, state: str) -> str: return f"oauth-state-{state}" - def _is_matching_cors_domains(self, url_domain) -> bool: + def _is_matching_cors_domains(self, url_domain: str) -> bool: # make sure the domain is part of the CORS so that's a safe domain to # redirect to. if url_domain in settings.CORS_ALLOWED_ORIGINS: @@ -79,7 +82,7 @@ def _is_matching_cors_domains(self, url_domain) -> bool: return True return False - def _is_valid_redirection(self, to) -> bool: + def _is_valid_redirection(self, to: str) -> bool: # make sure the redirect url is from a domain we own try: url = urlparse(to) @@ -115,11 +118,11 @@ def generate_state(self) -> str: return state - def verify_state(self, state) -> bool: + def verify_state(self, state: str) -> bool: state_from_session = self.request.session.get(self._session_key(), None) return state_from_session and state == state_from_session - def get_redirection_url_from_state(self, state) -> (str, bool): + def get_redirection_url_from_state(self, state: str) -> tuple[str, bool]: cached_url = self.redis.get(self._get_key_redis(state)) if not cached_url: @@ -149,7 +152,7 @@ def get_redirection_url_from_state(self, state) -> (str, bool): # Return the final redirect URL to complete the login. return (cached_url.decode("utf-8"), True) - def remove_state(self, state, delay=0) -> None: + def remove_state(self, state: str, delay: int = 0) -> None: redirection_url, _ = self.get_redirection_url_from_state(state) if delay == 0: self.redis.delete(self._get_key_redis(state)) @@ -182,7 +185,7 @@ def modify_redirection_url_based_on_default_user_org( url += f"/{owner_profile.default_org.username}" return url - def get_or_create_org(self, single_organization): + def get_or_create_org(self, single_organization: dict) -> Owner: owner, was_created = Owner.objects.get_or_create( service=self.service, service_id=single_organization["id"], @@ -190,7 +193,9 @@ def get_or_create_org(self, single_organization): ) return owner - def login_owner(self, owner: Owner, request: HttpRequest, response: HttpResponse): + def login_owner( + self, owner: Owner, request: HttpRequest, response: HttpResponse + ) -> None: # if there's a currently authenticated user if request.user is not None and not request.user.is_anonymous: if owner.user is None: @@ -253,9 +258,11 @@ def login_owner(self, owner: Owner, request: HttpRequest, response: HttpResponse request.session["current_owner_id"] = owner.pk RefreshService().trigger_refresh(owner.ownerid, owner.username) + + self.delete_expired_sessions_and_django_sessions(owner) self.store_login_session(owner) - def get_and_modify_owner(self, user_dict, request) -> Owner: + def get_and_modify_owner(self, user_dict: dict, request: HttpRequest) -> Owner: user_orgs = user_dict["orgs"] formatted_orgs = [ dict(username=org["username"], id=str(org["id"])) for org in user_orgs @@ -298,7 +305,9 @@ def get_and_modify_owner(self, user_dict, request) -> Owner: return owner - def _check_enterprise_organizations_membership(self, user_dict, orgs): + def _check_enterprise_organizations_membership( + self, user_dict: dict, orgs: list[dict] + ) -> None: """Checks if a user belongs to the restricted organizations (or teams if GitHub) allowed in settings.""" if settings.IS_ENTERPRISE and get_config(self.service, "organizations"): orgs_in_settings = set(get_config(self.service, "organizations")) @@ -315,7 +324,7 @@ def _check_enterprise_organizations_membership(self, user_dict, orgs): "You must be a member of an allowed team in your organization." ) - def _check_user_count_limitations(self, login_data): + def _check_user_count_limitations(self, login_data: dict) -> None: if not settings.IS_ENTERPRISE: return license = get_current_license() @@ -339,7 +348,7 @@ def _check_user_count_limitations(self, login_data): owners_with_activated_users = Owner.objects.exclude( plan_activated_users__len=0 ).exclude(plan_activated_users__isnull=True) - all_distinct_actiaved_users = reduce( + all_distinct_actiaved_users: set[str] = reduce( lambda acc, curr: set(curr.plan_activated_users) | acc, owners_with_activated_users, set(), @@ -357,7 +366,9 @@ def _check_user_count_limitations(self, login_data): if users_on_service_count > license.number_allowed_users: raise PermissionDenied(LICENSE_ERRORS_MESSAGES["users-exceeded"]) - def _get_or_create_owner(self, user_dict, request): + def _get_or_create_owner( + self, user_dict: dict, request: HttpRequest + ) -> tuple[Owner, bool]: fields_to_update = ["oauth_token", "private_access", "updatestamp"] login_data = user_dict["user"] owner, was_created = Owner.objects.get_or_create( @@ -403,7 +414,7 @@ def _get_utm_params(self, params: dict) -> dict: # remove None values from the dict return {k: v for k, v in filtered_params.items() if v is not None} - def store_to_cookie_utm_tags(self, response) -> None: + def store_to_cookie_utm_tags(self, response: HttpResponse) -> None: if not settings.IS_ENTERPRISE: data = urlencode(self._get_utm_params(self.request.GET)) response.set_cookie( @@ -423,7 +434,7 @@ def retrieve_marketing_tags_from_cookie(self) -> dict: else: return {} - def store_login_session(self, owner: Owner): + def store_login_session(self, owner: Owner) -> None: # Store user's login session info after logging in http_x_forwarded_for = self.request.META.get("HTTP_X_FORWARDED_FOR") if http_x_forwarded_for: @@ -443,3 +454,26 @@ def store_login_session(self, owner: Owner): type=Session.SessionType.LOGIN, owner=owner, ) + + def delete_expired_sessions_and_django_sessions(self, owner: Owner) -> None: + """ + This function deletes expired login sessions for a given owner + """ + with transaction.atomic(): + # Get the primary keys of expired DjangoSessions for the given owner + expired_sessions = Session.objects.filter( + owner=owner, + type="login", + login_session__isnull=False, + login_session__expire_date__lt=now(), + ) + + # Delete the rows in the Session table using sessionid + Session.objects.filter( + sessionid__in=[es.sessionid for es in expired_sessions] + ).delete() + + # Delete the rows in the DjangoSession table using the extracted keys + DjangoSession.objects.filter( + session_key__in=[es.login_session for es in expired_sessions] + ).delete()