diff --git a/.github/workflows/style-and-test-check.yml b/.github/workflows/style-and-test-check.yml index 0e84527..227155c 100644 --- a/.github/workflows/style-and-test-check.yml +++ b/.github/workflows/style-and-test-check.yml @@ -20,6 +20,11 @@ jobs: run: make test env: JINA_AUTH_TOKEN: ${{secrets.JINA_AUTH_TOKEN}} + - name: Run payment tests + run: JINA_HUBBLE_REGISTRY=https://apihubble.staging.jina.ai make test-payment + env: + STRIPE_SECRET_KEY: ${{secrets.STRIPE_SECRET_KEY}} + M2M_TOKEN: ${{secrets.M2M_TOKEN}} - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: diff --git a/Makefile b/Makefile index ffa0840..b77f50f 100644 --- a/Makefile +++ b/Makefile @@ -61,12 +61,14 @@ init: # ---------------------------------------------------------------- Test related targets -PYTEST_ARGS = --show-capture no --full-trace --verbose --cov hubble/ --cov-report term-missing --cov-report xml +PYTEST_ARGS = --show-capture no --full-trace --verbose --cov hubble/ --cov-report term-missing --cov-report xml ## Run tests test: - pytest $(PYTEST_ARGS) $(TESTS_PATH) + pytest $(PYTEST_ARGS) --ignore-glob=**/payment/* $(TESTS_PATH) +test-payment: + pytest $(PYTEST_ARGS) tests/unit/payment/ tests/integration/payment # ---------------------------------------------------------- Code style related targets diff --git a/hubble/payment/client.py b/hubble/payment/client.py index 6d4d193..99a54d4 100644 --- a/hubble/payment/client.py +++ b/hubble/payment/client.py @@ -6,29 +6,35 @@ class PaymentClient(PaymentBaseClient): def get_user_token(self, user_id) -> dict: + """Create a user session and token for a user. + + :param user_id: The _id of the user + :returns: Object + """ + return self.handle_request( url=self._base_url + PaymentEndpoints.get_user_token, data={'userId': user_id}, ) def get_authorized_jwt( - self, user_token: str, expiration_seconds: int = 15 * 60 * 1000 + self, token: str, expiration_seconds: int = 15 * 60 * 1000 ) -> dict: """Create a payment authorized JWT for user. - :param user_id: The _id of the user. - :param expiration_seconds: Number of seconds until the JWT expires. - :returns: Object. + :param user_id: The _id of the user + :param expiration_seconds: Number of seconds until the JWT expires + :returns: Object """ return self.handle_request( url=self._base_url + PaymentEndpoints.get_authorized_jwt, - data={'token': user_token, 'ttl': expiration_seconds}, + data={'token': token, 'ttl': expiration_seconds}, ) def verify_authorized_jwt(self, token: str) -> bool: """Verify if a token is a payment authorized JWT - :param token: User token. + :param token: User token :returns: Boolean (true if payment authorized, false otherwise) """ @@ -42,8 +48,8 @@ def verify_authorized_jwt(self, token: str) -> bool: def get_summary(self, token: str, app_id: str) -> object: """Get a list of a user's subscriptions and consumption for a given app. - :param token: User token. - :param app_id: ID of the application. + :param token: User token + :param app_id: ID of the application :returns: Object """ @@ -58,9 +64,9 @@ def report_usage( """Report usage for a given app. - :param token: User token. - :param app_id: ID of the application. - :param product_id: ID of the product. + :param token: User token + :param app_id: ID of the application + :param product_id: ID of the product :returns: Object """ diff --git a/hubble/payment/jwks.py b/hubble/payment/jwks.py index c43d0cc..5fd72f2 100644 --- a/hubble/payment/jwks.py +++ b/hubble/payment/jwks.py @@ -19,7 +19,7 @@ def get_keys(kid: str): def get_keys_from_config(): """Get cached JWK list from config file.""" jwks = config.get('jwks') - return jwks + return jwks if jwks is not None else [] @staticmethod def get_keys_from_hubble(): diff --git a/requirements-dev.txt b/requirements-dev.txt index 221fa67..ea4526b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,6 @@ pytest==7.0.0 pytest-asyncio==0.19.0 pytest-cov==3.0.0 pytest-mock==3.7.0 -mock==4.0.3 \ No newline at end of file +mock==4.0.3 +stripe==5.0.0 +python-dateutil==2.8.2 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 0897fdd..0d010f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,39 @@ +import os import tempfile import pytest +from hubble.payment.client import PaymentClient + +from .utils.stripe import StripeClient @pytest.fixture(autouse=True) def tmpfile(tmpdir): tmpfile = f'jina_test_{next(tempfile._get_candidate_names())}.db' return tmpdir / tmpfile + + +@pytest.fixture(scope='session') +def m2m_token(): + return os.environ.get('M2M_TOKEN', None) + + +# fixture for acquiring a 'cached' instance of StripeClient +@pytest.fixture(scope='session') +def stripe_client(): + api_key = os.environ.get('STRIPE_SECRET_KEY', None) + client = StripeClient(api_key) + yield client + client.cleanup() + + +@pytest.fixture() +def payment_client(m2m_token): + payment_client = PaymentClient(m2m_token=m2m_token) + yield payment_client + + +@pytest.fixture() +def user_token(payment_client, request): + user_token = payment_client.get_user_token(user_id=request.param)['data'] + yield user_token diff --git a/tests/integration/payment/test_non_paying_user.py b/tests/integration/payment/test_non_paying_user.py new file mode 100644 index 0000000..3672137 --- /dev/null +++ b/tests/integration/payment/test_non_paying_user.py @@ -0,0 +1,105 @@ +import time +from datetime import datetime + +import pytest +from dateutil.relativedelta import relativedelta +from mock import patch # noqa: F401 + +INTERNAL_APP_ID = 'hubble-sdk' +INTERNAL_PRODUCT_ID = 'hubble-sdk' + +PRICE_STRIPE_ID = 'price_1MTl37AkuPxeor9kLZxJ5lfd' + +NON_PAYING_USER_ID_1 = '63d75509234f12b36dbd8b36' +NON_PAYING_USER_EMAIL_1 = 'hubble_sdk_user_3@jina.ai' + +NON_PAYING_USER_ID_2 = '63d7551c234f12b36dbd8dca' +NON_PAYING_USER_EMAIL_2 = 'hubble_sdk_user_4@jina.ai' + + +@pytest.mark.parametrize('user_token', [NON_PAYING_USER_ID_1], indirect=True) +def test_get_summary(stripe_client, payment_client, user_token): + + # creating stripe customer for user + customer = stripe_client.get_customer(email=NON_PAYING_USER_EMAIL_1) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + expected_result = {'subscriptionItems': [], 'hasPaymentMethod': False} + assert summary['data'] == expected_result + + # creating subscription + stripe_client.create_subscription( + customer_id=customer['customer_id'], items=[PRICE_STRIPE_ID] + ) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + + expected_result = { + 'subscriptionItems': [ + { + 'internalAppId': INTERNAL_APP_ID, + 'internalProductId': INTERNAL_PRODUCT_ID, + 'usageQuantity': 0, + } + ], + 'hasPaymentMethod': False, + } + + assert summary['data'] == expected_result + + +@pytest.mark.parametrize('user_token', [NON_PAYING_USER_ID_2], indirect=True) +def test_submit_usage_report(stripe_client, payment_client, user_token): + + # try to submit a usage report + customer = stripe_client.get_customer(email=NON_PAYING_USER_EMAIL_2) + + stripe_client.create_subscription( + customer_id=customer['customer_id'], items=[PRICE_STRIPE_ID] + ) + + payment_client.report_usage( + token=user_token, + app_id=INTERNAL_APP_ID, + product_id=INTERNAL_PRODUCT_ID, + quantity=100, + ) + + # NOTE: sleeping to wait for the usage report to be processed + time.sleep(75) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + + expected_result = { + 'subscriptionItems': [ + { + 'internalAppId': INTERNAL_APP_ID, + 'internalProductId': INTERNAL_PRODUCT_ID, + 'usageQuantity': 100, + } + ], + 'hasPaymentMethod': False, + } + + assert summary['data'] == expected_result + + # advancing test clock by one month + # this will trigger a new subscription period + now = datetime.now() + later = now + relativedelta(days=+35) + stripe_client.advance_clock(test_clock_id=customer['test_clock_id'], date=later) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + + expected_result = {'subscriptionItems': [], 'hasPaymentMethod': False} + + assert summary['data'] == expected_result + + +@pytest.mark.parametrize( + 'user_token', [NON_PAYING_USER_ID_1, NON_PAYING_USER_ID_2], indirect=True +) +def test_get_authorized_jwt(payment_client, user_token): + jwt = payment_client.get_authorized_jwt(token=user_token)['data'] + is_authorized = payment_client.verify_authorized_jwt(token=jwt) + assert is_authorized is True diff --git a/tests/integration/payment/test_paying_user.py b/tests/integration/payment/test_paying_user.py new file mode 100644 index 0000000..96eee82 --- /dev/null +++ b/tests/integration/payment/test_paying_user.py @@ -0,0 +1,118 @@ +import time +from datetime import datetime + +import pytest +from dateutil.relativedelta import relativedelta +from mock import patch # noqa: F401 + +INTERNAL_APP_ID = 'hubble-sdk' +INTERNAL_PRODUCT_ID = 'hubble-sdk' + +PRICE_STRIPE_ID = 'price_1MTl37AkuPxeor9kLZxJ5lfd' + +PAYING_USER_ID_1 = '63d754c6234f12b36dbd827f' +PAYING_USER_EMAIL_1 = 'hubble_sdk_user_1@jina.ai' + +PAYING_USER_ID_2 = '63d754de234f12b36dbd8580' +PAYING_USER_EMAIL_2 = 'hubble_sdk_user_2@jina.ai' + + +@pytest.mark.parametrize('user_token', [PAYING_USER_ID_1], indirect=True) +def test_get_summary(stripe_client, payment_client, user_token): + + # creating stripe customer for user + customer = stripe_client.get_customer( + email=PAYING_USER_EMAIL_1, payment_method='pm_card_visa' + ) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + expected_result = {'subscriptionItems': [], 'hasPaymentMethod': True} + assert summary['data'] == expected_result + + # creating subscription + stripe_client.create_subscription( + customer_id=customer['customer_id'], items=[PRICE_STRIPE_ID] + ) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + + expected_result = { + 'subscriptionItems': [ + { + 'internalAppId': INTERNAL_APP_ID, + 'internalProductId': INTERNAL_PRODUCT_ID, + 'usageQuantity': 0, + } + ], + 'hasPaymentMethod': True, + } + + assert summary['data'] == expected_result + + +@pytest.mark.parametrize('user_token', [PAYING_USER_ID_2], indirect=True) +def test_submit_usage_report(stripe_client, payment_client, user_token): + + # try to submit a usage report + customer = stripe_client.get_customer( + email=PAYING_USER_EMAIL_2, payment_method='pm_card_visa' + ) + + stripe_client.create_subscription( + customer_id=customer['customer_id'], items=[PRICE_STRIPE_ID] + ) + + payment_client.report_usage( + token=user_token, + app_id=INTERNAL_APP_ID, + product_id=INTERNAL_PRODUCT_ID, + quantity=100, + ) + + # NOTE: sleeping to wait for the usage report to be processed + time.sleep(75) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + + expected_result = { + 'subscriptionItems': [ + { + 'internalAppId': INTERNAL_APP_ID, + 'internalProductId': INTERNAL_PRODUCT_ID, + 'usageQuantity': 100, + } + ], + 'hasPaymentMethod': True, + } + + assert summary['data'] == expected_result + + # advancing test clock by one month + # this will trigger a new subscription period + now = datetime.now() + later = now + relativedelta(days=+35) + stripe_client.advance_clock(test_clock_id=customer['test_clock_id'], date=later) + + summary = payment_client.get_summary(token=user_token, app_id=INTERNAL_APP_ID) + + expected_result = { + 'subscriptionItems': [ + { + 'internalAppId': INTERNAL_APP_ID, + 'internalProductId': INTERNAL_PRODUCT_ID, + 'usageQuantity': 0, + } + ], + 'hasPaymentMethod': True, + } + + assert summary['data'] == expected_result + + +@pytest.mark.parametrize( + 'user_token', [PAYING_USER_ID_1, PAYING_USER_ID_2], indirect=True +) +def test_get_authorized_jwt(payment_client, user_token): + jwt = payment_client.get_authorized_jwt(token=user_token)['data'] + is_authorized = payment_client.verify_authorized_jwt(token=jwt) + assert is_authorized is True diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/stripe.py b/tests/utils/stripe.py new file mode 100644 index 0000000..202f08f --- /dev/null +++ b/tests/utils/stripe.py @@ -0,0 +1,99 @@ +import time +from datetime import datetime +from typing import List + +DEFAULT_SLEEP_TIME = 10 + + +def wait(secs): + def decorator(func): + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + time.sleep(secs) + return result + + return wrapper + + return decorator + + +class StripeClient(object): + """Stripe utility functions for testing""" + + def __init__(self, stripe_key: str): + try: + import stripe + + self._stripe = stripe + self._stripe.api_key = stripe_key + except Exception: + raise Exception('Failed to initialize stripe, make sure it is installed.') + + self.cache = {} + + @wait(DEFAULT_SLEEP_TIME) + def create_clock(self): + now = datetime.now() + unix_time = now.timestamp() + unix_time = int(unix_time) + + return self._stripe.test_helpers.TestClock.create(frozen_time=unix_time) + + @wait(DEFAULT_SLEEP_TIME) + def advance_clock(self, test_clock_id: str, date: datetime): + # https://stripe.com/docs/api/test_clocks/advance?lang=python + unix_time = date.timestamp() + unix_time = int(unix_time) + + test_clock = self._stripe.test_helpers.TestClock.advance( + test_clock_id, frozen_time=unix_time + ) + + # waiting for test_clock to finish advancing + while test_clock['status'] == 'advancing': + time.sleep(0.5) + test_clock = self._stripe.test_helpers.TestClock.retrieve(test_clock_id) + + return test_clock + + @wait(DEFAULT_SLEEP_TIME) + def delete_clock(self, test_clock_id: str): + return self._stripe.test_helpers.TestClock.delete(test_clock_id) + + @wait(DEFAULT_SLEEP_TIME) + def create_customer(self, email: str, test_clock_id: str, payment_method=None): + return self._stripe.Customer.create( + email=email, test_clock=test_clock_id, payment_method=payment_method + ) + + @wait(DEFAULT_SLEEP_TIME) + def create_subscription(self, customer_id: str, items: List[str]): + # https://stripe.com/docs/api/subscriptions/create?lang=python + subscription_items = [{'price': item} for item in items] + + return self._stripe.Subscription.create( + customer=customer_id, items=subscription_items + ) + + def get_customer(self, email: str, payment_method: str = None): + + if email in self.cache: + return self.cache[email] + + test_clock = self.create_clock() + test_clock_id = test_clock['id'] + + customer = self.create_customer( + email=email, test_clock_id=test_clock_id, payment_method=payment_method + ) + customer_id = customer['id'] + + self.cache[email] = {'customer_id': customer_id, 'test_clock_id': test_clock_id} + + return self.cache[email] + + # TODO: finish this function + def cleanup(self): + for _, customer in self.cache.items(): + self.delete_clock(test_clock_id=customer['test_clock_id']) + self.cache = {}