diff --git a/requirements-test.txt b/requirements-test.txt index 285accd388..d85a5429a0 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,7 +5,6 @@ time-machine==2.13.0 # for typing mypy==1.1.1 types-python-dateutil==2.8.19.14 -types-python-jose==3.3.0 types-pyyaml==6.0.12.20240808 types-requests==2.32.0.20240907 types-setuptools==74.1.0.20240907 diff --git a/requirements.txt b/requirements.txt index f10d7be28e..542e63984f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ google-cloud-storage==2.18.0 googleapis-common-protos==1.63.2 google-api-core==2.19.1 honcho==1.1.0 -python-jose[cryptography]==3.3.0 jsonschema==4.23.0 fastjsonschema==2.16.2 packaging==24.1 @@ -21,6 +20,7 @@ parsimonious==0.10.0 progressbar2==4.2.0 protobuf==5.27.3 proto-plus==1.24.0 +pyjwt[crypto]==2.10.1 pytest==7.1.3 pytest-cov==4.1.0 pytest-watch==4.2.0 diff --git a/snuba/admin/jwt.py b/snuba/admin/jwt.py index d5d61b7b43..7ef602dbf8 100644 --- a/snuba/admin/jwt.py +++ b/snuba/admin/jwt.py @@ -1,5 +1,6 @@ from typing import Any, Optional +import jwt import requests from snuba import settings @@ -32,7 +33,11 @@ def validate_assertion(assertion: str) -> AdminUser: If not, an exception will be raised """ - from jose import jwt - - info = jwt.decode(assertion, _certs(), algorithms=["ES256"], audience=_audience()) + info = jwt.decode( + assertion, + _certs(), + algorithms=["ES256"], + audience=_audience(), + options={"verify_aud": True}, + ) return AdminUser(email=info["email"], id=info["sub"]) diff --git a/tests/admin/test_authorization.py b/tests/admin/test_authorization.py index ea73fdd769..89e7d75341 100644 --- a/tests/admin/test_authorization.py +++ b/tests/admin/test_authorization.py @@ -9,13 +9,6 @@ from snuba.admin.auth_roles import ROLES -@pytest.fixture -def admin_api() -> FlaskClient: - from snuba.admin.views import application - - return application.test_client() - - @pytest.mark.redis_db def test_tools(admin_api: FlaskClient) -> None: response = admin_api.get("/tools") @@ -28,9 +21,7 @@ def test_tools(admin_api: FlaskClient) -> None: @pytest.mark.redis_db @patch("snuba.admin.auth.DEFAULT_ROLES", [ROLES["ProductTools"]]) -def test_product_tools_role( - admin_api: FlaskClient, -) -> None: +def test_product_tools_role(admin_api: FlaskClient) -> None: response = admin_api.get("/tools") assert response.status_code == 200 data = json.loads(response.data) diff --git a/tests/admin/test_jwt.py b/tests/admin/test_jwt.py new file mode 100644 index 0000000000..f0507d10ab --- /dev/null +++ b/tests/admin/test_jwt.py @@ -0,0 +1,104 @@ +from datetime import datetime, timedelta, timezone +from typing import Tuple +from unittest.mock import MagicMock + +import jwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec + +from snuba.admin.jwt import validate_assertion +from snuba.admin.user import AdminUser +from tests.conftest import TEST_AUDIENCE, TEST_EMAIL, TEST_SUB, create_test_token + + +def test_validate_assertion_success( + mock_settings: MagicMock, + mock_certs: MagicMock, + key_pair: Tuple[bytes, bytes], +) -> None: + """Test that a valid JWT token is properly validated and returns an AdminUser""" + private_key, _ = key_pair + token = create_test_token(private_key) + user = validate_assertion(token) + assert isinstance(user, AdminUser) + assert user.email == TEST_EMAIL + assert user.id == TEST_SUB + + +def test_validate_assertion_invalid_signature( + mock_settings: MagicMock, + mock_certs: MagicMock, +) -> None: + """Test that an invalid signature raises an exception""" + # Create a different key pair for signing + different_private_key = ec.generate_private_key(ec.SECP256K1()).private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + token = create_test_token(different_private_key) + with pytest.raises(jwt.InvalidSignatureError): + validate_assertion(token) + + +def test_validate_assertion_expired( + mock_settings: MagicMock, + mock_certs: MagicMock, + key_pair: Tuple[bytes, bytes], +) -> None: + """Test that an expired token raises an exception""" + private_key, _ = key_pair + token = create_test_token(private_key, exp_delta=timedelta(hours=-1)) + with pytest.raises(jwt.ExpiredSignatureError): + validate_assertion(token) + + +def test_validate_assertion_invalid_audience( + mock_settings: MagicMock, + mock_certs: MagicMock, + key_pair: Tuple[bytes, bytes], +) -> None: + """Test that a token with wrong audience raises an exception""" + private_key, _ = key_pair + token = create_test_token(private_key, audience="wrong-audience") + with pytest.raises(jwt.InvalidAudienceError): + validate_assertion(token) + + +def test_validate_assertion_missing_email( + mock_settings: MagicMock, + mock_certs: MagicMock, + key_pair: Tuple[bytes, bytes], +) -> None: + """Test that a token without email claim raises an exception""" + private_key, _ = key_pair + now = datetime.now(timezone.utc) + payload = { + "sub": TEST_SUB, + "aud": TEST_AUDIENCE, + "exp": now + timedelta(hours=1), + "iat": now, + } + token = jwt.encode(payload, private_key.decode("utf-8"), algorithm="ES256") + with pytest.raises(KeyError): + validate_assertion(token) + + +def test_validate_assertion_missing_sub( + mock_settings: MagicMock, + mock_certs: MagicMock, + key_pair: Tuple[bytes, bytes], +) -> None: + """Test that a token without sub claim raises an exception""" + private_key, _ = key_pair + now = datetime.now(timezone.utc) + payload = { + "email": TEST_EMAIL, + "aud": TEST_AUDIENCE, + "exp": now + timedelta(hours=1), + "iat": now, + } + token = jwt.encode(payload, private_key.decode("utf-8"), algorithm="ES256") + with pytest.raises(KeyError): + validate_assertion(token) diff --git a/tests/conftest.py b/tests/conftest.py index a913e7376a..4313f37106 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,14 @@ import json import traceback +from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, Generator, List, Sequence, Tuple, Union +from unittest.mock import MagicMock, patch +import jwt import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec +from flask.testing import FlaskClient from snuba_sdk.legacy import json_to_snql from snuba import settings, state @@ -48,9 +54,11 @@ def create_databases() -> None: storage_sets=cluster["storage_sets"], single_node=cluster["single_node"], cluster_name=cluster["cluster_name"] if "cluster_name" in cluster else None, - distributed_cluster_name=cluster["distributed_cluster_name"] - if "distributed_cluster_name" in cluster - else None, + distributed_cluster_name=( + cluster["distributed_cluster_name"] + if "distributed_cluster_name" in cluster + else None + ), ) database_name = cluster["database"] @@ -299,3 +307,79 @@ def set_config(key: str, value: Any) -> None: def disable_query_cache(snuba_set_config: SnubaSetConfig, redis_db: None) -> None: snuba_set_config("use_cache", False) snuba_set_config("use_readthrough_query_cache", 0) + + +# JWT Test constants +TEST_AUDIENCE = "test-audience" +TEST_EMAIL = "test@example.com" +TEST_SUB = "12345" + + +@pytest.fixture +def key_pair() -> tuple[bytes, bytes]: + """Generate a fresh ES256 key pair for each test run""" + private_key = ec.generate_private_key(ec.SECP256K1()) + public_key = private_key.public_key() + + # Convert to PEM format + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return private_pem, public_pem + + +@pytest.fixture +def mock_settings() -> Generator[MagicMock, None, None]: + with patch("snuba.admin.jwt.settings") as mock_settings: + mock_settings.ADMIN_AUTH_JWT_AUDIENCE = TEST_AUDIENCE + yield mock_settings + + +@pytest.fixture +def mock_certs(key_pair: tuple[bytes, bytes]) -> Generator[MagicMock, None, None]: + _, public_key = key_pair + with patch("snuba.admin.jwt._certs") as mock_certs: + mock_certs.return_value = public_key.decode("utf-8") + yield mock_certs + + +@pytest.fixture +def admin_api( + mock_settings: MagicMock, mock_certs: MagicMock, key_pair: tuple[bytes, bytes] +) -> FlaskClient: + """Returns a Flask test client with valid JWT authentication""" + with patch("snuba.admin.auth.settings.ADMIN_AUTH_PROVIDER", "IAP"): + from snuba.admin.views import application + + private_key, _ = key_pair + token = create_test_token(private_key) + client = application.test_client() + client.environ_base["HTTP_X_GOOG_IAP_JWT_ASSERTION"] = token + return client + + +def create_test_token( + private_key: bytes, + *, + email: str = TEST_EMAIL, + sub: str = TEST_SUB, + audience: str = TEST_AUDIENCE, + exp_delta: timedelta = timedelta(hours=1), +) -> str: + """Helper function to create a test JWT token""" + now = datetime.now(timezone.utc) + payload = { + "email": email, + "sub": sub, + "aud": audience, + "exp": now + exp_delta, + "iat": now, + } + return jwt.encode(payload, private_key.decode("utf-8"), algorithm="ES256")