Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions onadata/libs/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jwt
from django_digest import HttpDigestAuthenticator
from multidb.pinning import use_master
from oauth2_provider.contrib.rest_framework import OAuth2Authentication
from oauth2_provider.models import AccessToken
from oauth2_provider.oauth2_validators import OAuth2Validator
from oauth2_provider.settings import oauth2_settings
Expand Down Expand Up @@ -430,3 +431,32 @@ def validate_bearer_token(self, token, scopes, request):
self._set_oauth2_error_on_request(request, access_token, scopes)

return False


class StrictOAuth2Authentication(OAuth2Authentication):
"""
OAuth2 authentication that raises AuthenticationFailed when a Bearer token
is provided but is invalid or expired.

The default OAuth2Authentication returns None for invalid tokens, which
causes DRF to silently fall through to other authentication classes or
allow anonymous access. This class ensures that if a Bearer token is
explicitly provided, it must be valid.
"""

def authenticate(self, request):
auth_header = get_authorization_header(request).split()

if not auth_header or auth_header[0].lower() != b"bearer":
return None

result = super().authenticate(request)

if result is None:
oauth2_error = getattr(request, "oauth2_error", {})
error_description = oauth2_error.get(
"error_description", "Invalid or expired token"
)
raise exceptions.AuthenticationFailed(_(error_description))

return result
68 changes: 68 additions & 0 deletions onadata/libs/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from onadata.libs.authentication import (
DigestAuthentication,
MasterReplicaOAuth2Validator,
StrictOAuth2Authentication,
TempTokenAuthentication,
TempTokenURLParameterAuthentication,
check_lockout,
Expand Down Expand Up @@ -201,3 +202,70 @@ def is_valid_mock(*args, **kwargs):
)
self.assertEqual(req.access_token, token)
self.assertEqual(req.user, token.user)


class TestStrictOAuth2Authentication(TestCase):
"""Test StrictOAuth2Authentication class."""

def setUp(self):
self.factory = APIRequestFactory()
self.auth = StrictOAuth2Authentication()

def test_returns_none_when_no_bearer_header(self):
"""Test returns None when no Bearer token is provided."""
request = self.factory.get("/")
result = self.auth.authenticate(request)
self.assertIsNone(result)

def test_returns_none_when_different_auth_scheme(self):
"""Test returns None when Authorization header uses different scheme."""
request = self.factory.get("/", HTTP_AUTHORIZATION="Token abc123")
result = self.auth.authenticate(request)
self.assertIsNone(result)

@patch("onadata.libs.authentication.OAuth2Authentication.authenticate")
def test_raises_auth_failed_when_bearer_token_invalid(self, mock_authenticate):
"""Test raises AuthenticationFailed when Bearer token is invalid."""
mock_authenticate.return_value = None
request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer invalid_token")
request.oauth2_error = {"error_description": "The access token is invalid"}

with self.assertRaises(AuthenticationFailed) as context:
self.auth.authenticate(request)

self.assertIn("access token is invalid", str(context.exception))

@patch("onadata.libs.authentication.OAuth2Authentication.authenticate")
def test_raises_auth_failed_when_bearer_token_expired(self, mock_authenticate):
"""Test raises AuthenticationFailed when Bearer token is expired."""
mock_authenticate.return_value = None
request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer expired_token")
request.oauth2_error = {"error_description": "The access token has expired"}

with self.assertRaises(AuthenticationFailed) as context:
self.auth.authenticate(request)

self.assertIn("access token has expired", str(context.exception))

@patch("onadata.libs.authentication.OAuth2Authentication.authenticate")
def test_raises_auth_failed_with_default_message(self, mock_authenticate):
"""Test raises AuthenticationFailed with default message when no error_description."""
mock_authenticate.return_value = None
request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer some_token")

with self.assertRaises(AuthenticationFailed) as context:
self.auth.authenticate(request)

self.assertIn("Invalid or expired token", str(context.exception))

@patch("onadata.libs.authentication.OAuth2Authentication.authenticate")
def test_returns_user_and_token_when_valid(self, mock_authenticate):
"""Test returns (user, token) tuple when Bearer token is valid."""
user = MagicMock()
token = MagicMock()
mock_authenticate.return_value = (user, token)
request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer valid_token")

result = self.auth.authenticate(request)

self.assertEqual(result, (user, token))
5 changes: 3 additions & 2 deletions onadata/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import sys
from importlib import reload

from celery.signals import after_setup_logger
from django.core.exceptions import SuspiciousOperation
from django.utils.log import AdminEmailHandler

from celery.signals import after_setup_logger

# setting default encoding to utf-8
if sys.version[0] == "2":
reload(sys)
Expand Down Expand Up @@ -296,7 +297,7 @@
"onadata.libs.authentication.DigestAuthentication",
"onadata.libs.authentication.TempTokenAuthentication",
"onadata.libs.authentication.EnketoTokenAuthentication",
"oauth2_provider.contrib.rest_framework.OAuth2Authentication",
"onadata.libs.authentication.StrictOAuth2Authentication",
"rest_framework.authentication.SessionAuthentication",
"rest_framework.authentication.TokenAuthentication",
),
Expand Down