Skip to content

Commit fc7714a

Browse files
update test account flow for stripe integration (#521)
1 parent a56dc94 commit fc7714a

13 files changed

+130
-61
lines changed

.github/workflows/django-postgres.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ jobs:
4949
DJANGO_SETTINGS_MODULE: "lotus.settings"
5050
PYTHONPATH: "."
5151
SECRET_KEY: ${{ secrets.SECRET_KEY }}
52-
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
52+
STRIPE_LIVE_SECRET_KEY: ${{ secrets.STRIPE_LIVE_SECRET_KEY }}
53+
STRIPE_TEST_SECRET_KEY: ${{ secrets.STRIPE_TEST_SECRET_KEY }}
5354
DEBUG: False
5455
KAFKA_URL: "localhost:9092"
5556
PYTHONDONTWRITEBYTECODE: 1

.github/workflows/postman_workflow.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ jobs:
4444
DJANGO_SETTINGS_MODULE: "lotus.settings"
4545
PYTHONPATH: "."
4646
SECRET_KEY: ${{ secrets.SECRET_KEY }}
47-
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
47+
STRIPE_LIVE_SECRET_KEY: ${{ secrets.STRIPE_LIVE_SECRET_KEY }}
48+
STRIPE_TEST_SECRET_KEY: ${{ secrets.STRIPE_TEST_SECRET_KEY }}
4849
DEBUG: False
4950
KAFKA_URL: "localhost:9092"
5051
PYTHONDONTWRITEBYTECODE: 1

backend/lotus/settings.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@
6363
PRODUCT_ANALYTICS_OPT_IN = config("PRODUCT_ANALYTICS_OPT_IN", default=True, cast=bool)
6464
PRODUCT_ANALYTICS_OPT_IN = True if not SELF_HOSTED else PRODUCT_ANALYTICS_OPT_IN
6565
# Stripe required
66-
STRIPE_SECRET_KEY = config("STRIPE_SECRET_KEY", default="")
66+
STRIPE_LIVE_SECRET_KEY = config("STRIPE_LIVE_SECRET_KEY", default=None)
67+
if STRIPE_LIVE_SECRET_KEY is None:
68+
STRIPE_LIVE_SECRET_KEY = config("STRIPE_SECRET_KEY", default=None)
69+
STRIPE_TEST_SECRET_KEY = config("STRIPE_TEST_SECRET_KEY", default=None)
6770
STRIPE_WEBHOOK_SECRET = config("STRIPE_WEBHOOK_SECRET", default="whsec_")
6871
# Webhooks for Svix
6972
SVIX_API_KEY = config("SVIX_API_KEY", default="")

backend/metering_billing/models.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,6 @@
1717
from django.db.models.constraints import CheckConstraint, UniqueConstraint
1818
from django.db.models.functions import Cast, Coalesce
1919
from django.utils.translation import gettext_lazy as _
20-
from rest_framework_api_key.models import AbstractAPIKey
21-
from simple_history.models import HistoricalRecords
22-
from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate
23-
from svix.internal.openapi_client.models.http_error import HttpError
24-
from svix.internal.openapi_client.models.http_validation_error import (
25-
HTTPValidationError,
26-
)
27-
2820
from metering_billing.exceptions.exceptions import (
2921
ExternalConnectionFailure,
3022
ExternalConnectionInvalid,
@@ -77,6 +69,13 @@
7769
WEBHOOK_TRIGGER_EVENTS,
7870
)
7971
from metering_billing.webhooks import invoice_paid_webhook, usage_alert_webhook
72+
from rest_framework_api_key.models import AbstractAPIKey
73+
from simple_history.models import HistoricalRecords
74+
from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate
75+
from svix.internal.openapi_client.models.http_error import HttpError
76+
from svix.internal.openapi_client.models.http_validation_error import (
77+
HTTPValidationError,
78+
)
8079

8180
logger = logging.getLogger("django.server")
8281
META = settings.META

backend/metering_billing/payment_providers.py

+77-30
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
logger = logging.getLogger("django.server")
2727

2828
SELF_HOSTED = settings.SELF_HOSTED
29-
STRIPE_SECRET_KEY = settings.STRIPE_SECRET_KEY
29+
STRIPE_LIVE_SECRET_KEY = settings.STRIPE_LIVE_SECRET_KEY
30+
STRIPE_TEST_SECRET_KEY = settings.STRIPE_TEST_SECRET_KEY
3031
VITE_STRIPE_CLIENT = settings.VITE_STRIPE_CLIENT
3132
VITE_API_URL = settings.VITE_API_URL
3233

@@ -53,7 +54,7 @@ def working(self) -> bool:
5354
pass
5455

5556
@abc.abstractmethod
56-
def update_payment_object_status(self, payment_object_id: str):
57+
def update_payment_object_status(self, organization, payment_object_id: str):
5758
"""This method will be called periodically when the status of a payment object needs to be updated. It should return the status of the payment object, which should be either paid or unpaid."""
5859
pass
5960

@@ -108,7 +109,8 @@ def initialize_settings(self, organization) -> None:
108109

109110
class StripeConnector(PaymentProvider):
110111
def __init__(self):
111-
self.secret_key = STRIPE_SECRET_KEY
112+
self.live_secret_key = STRIPE_LIVE_SECRET_KEY
113+
self.test_secret_key = STRIPE_TEST_SECRET_KEY
112114
self.self_hosted = SELF_HOSTED
113115
redirect_dict = {
114116
"response_type": "code",
@@ -123,7 +125,7 @@ def __init__(self):
123125
self.redirect_url = ""
124126

125127
def working(self) -> bool:
126-
return self.secret_key != "" and self.secret_key is not None
128+
return self.live_secret_key is not None or self.test_secret_key is not None
127129

128130
def customer_connected(self, customer) -> bool:
129131
pp_ids = customer.integrations
@@ -133,18 +135,26 @@ def customer_connected(self, customer) -> bool:
133135

134136
def organization_connected(self, organization) -> bool:
135137
if self.self_hosted:
136-
return self.secret_key != "" and self.secret_key is not None
138+
return self.live_secret_key is not None or self.test_secret_key is not None
137139
else:
138140
return (
139-
organization.payment_provider_ids.get(PAYMENT_PROVIDERS.STRIPE, "")
140-
!= ""
141+
organization.payment_provider_ids.get(PAYMENT_PROVIDERS.STRIPE, None)
142+
is not None
141143
)
142144

143-
def update_payment_object_status(self, payment_object_id):
144-
from metering_billing.models import Invoice
145+
def update_payment_object_status(self, organization, payment_object_id):
146+
from metering_billing.models import Invoice, Organization
145147

146-
stripe.api_key = self.secret_key
147-
invoice = stripe.Invoice.retrieve(payment_object_id)
148+
invoice_payload = {}
149+
if not self.self_hosted:
150+
invoice_payload["stripe_account"] = organization.payment_provider_ids.get(
151+
PAYMENT_PROVIDERS.STRIPE
152+
)
153+
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
154+
stripe.api_key = self.live_secret_key
155+
else:
156+
stripe.api_key = self.test_secret_key
157+
invoice = stripe.Invoice.retrieve(payment_object_id, **invoice_payload)
148158
if invoice.status == "paid":
149159
return Invoice.PaymentStatus.PAID
150160
else:
@@ -154,15 +164,18 @@ def import_customers(self, organization):
154164
"""
155165
Imports customers from Stripe. If they already exist (by checking that either they already have their Stripe ID in our system, or seeing that they have the same email address), then we update the Stripe section of payment_providers dict to reflect new information. If they don't exist, we create them (not as a Lotus customer yet, just as a Stripe customer).
156166
"""
157-
from metering_billing.models import Customer
167+
from metering_billing.models import Customer, Organization
158168

159-
stripe.api_key = self.secret_key
169+
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
170+
stripe.api_key = self.live_secret_key
171+
else:
172+
stripe.api_key = self.test_secret_key
160173

161174
num_cust_added = 0
162175
org_ppis = organization.payment_provider_ids
163176

164177
stripe_cust_kwargs = {}
165-
if org_ppis.get(PAYMENT_PROVIDERS.STRIPE) not in ["", None]:
178+
if not self.self_hosted:
166179
# this is to get "on behalf" of someone
167180
stripe_cust_kwargs["stripe_account"] = org_ppis.get(
168181
PAYMENT_PROVIDERS.STRIPE
@@ -233,7 +246,12 @@ def import_customers(self, organization):
233246
return num_cust_added
234247

235248
def import_payment_objects(self, organization):
236-
stripe.api_key = self.secret_key
249+
from metering_billing.models import Organization
250+
251+
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
252+
stripe.api_key = self.live_secret_key
253+
else:
254+
stripe.api_key = self.test_secret_key
237255
imported_invoices = {}
238256
for customer in organization.customers.all():
239257
if PAYMENT_PROVIDERS.STRIPE in customer.integrations:
@@ -244,9 +262,13 @@ def import_payment_objects(self, organization):
244262
def _import_payment_objects_for_customer(self, customer):
245263
from metering_billing.models import Invoice
246264

247-
stripe.api_key = self.secret_key
265+
payload = {}
266+
if not self.self_hosted:
267+
payload["stripe_account"] = customer.organization.payment_provider_ids.get(
268+
PAYMENT_PROVIDERS.STRIPE
269+
)
248270
invoices = stripe.Invoice.list(
249-
customer=customer.integrations[PAYMENT_PROVIDERS.STRIPE]["id"]
271+
customer=customer.integrations[PAYMENT_PROVIDERS.STRIPE]["id"], **payload
250272
)
251273
lotus_invoices = []
252274
for stripe_invoice in invoices.auto_paging_iter():
@@ -273,8 +295,15 @@ def _import_payment_objects_for_customer(self, customer):
273295
return lotus_invoices
274296

275297
def create_customer(self, customer):
276-
stripe.api_key = self.secret_key
277-
from metering_billing.models import OrganizationSetting
298+
from metering_billing.models import Organization, OrganizationSetting
299+
300+
if (
301+
customer.organization.organization_type
302+
== Organization.OrganizationType.PRODUCTION
303+
):
304+
stripe.api_key = self.live_secret_key
305+
else:
306+
stripe.api_key = self.test_secret_key
278307

279308
setting = OrganizationSetting.objects.get(
280309
setting_name=ORGANIZATION_SETTING_NAMES.GENERATE_CUSTOMER_IN_STRIPE_AFTER_LOTUS,
@@ -293,10 +322,10 @@ def create_customer(self, customer):
293322
}
294323
if not self.self_hosted:
295324
org_stripe_acct = customer.organization.payment_provider_ids.get(
296-
PAYMENT_PROVIDERS.STRIPE, ""
325+
PAYMENT_PROVIDERS.STRIPE, None
297326
)
298327
assert (
299-
org_stripe_acct != ""
328+
org_stripe_acct is not None
300329
), "Organization does not have a Stripe account ID"
301330
customer_kwargs["stripe_account"] = org_stripe_acct
302331
try:
@@ -318,7 +347,15 @@ def create_customer(self, customer):
318347
)
319348

320349
def create_payment_object(self, invoice) -> str:
321-
stripe.api_key = self.secret_key
350+
from metering_billing.models import Organization
351+
352+
if (
353+
invoice.organization.organization_type
354+
== Organization.OrganizationType.PRODUCTION
355+
):
356+
stripe.api_key = self.live_secret_key
357+
else:
358+
stripe.api_key = self.test_secret_key
322359
# check everything works as expected + build invoice item
323360
assert invoice.external_payment_obj_id is None
324361
customer = invoice.customer
@@ -329,20 +366,17 @@ def create_payment_object(self, invoice) -> str:
329366
invoice_kwargs = {
330367
"auto_advance": True,
331368
"customer": stripe_customer_id,
332-
# "automatic_tax": {
333-
# "enabled": True,
334-
# },
335369
"description": "Invoice from {}".format(
336370
customer.organization.organization_name
337371
),
338372
"currency": invoice.currency.code.lower(),
339373
}
340374
if not self.self_hosted:
341375
org_stripe_acct = customer.organization.payment_provider_ids.get(
342-
PAYMENT_PROVIDERS.STRIPE, ""
376+
PAYMENT_PROVIDERS.STRIPE, None
343377
)
344378
assert (
345-
org_stripe_acct != ""
379+
org_stripe_acct is not None
346380
), "Organization does not have a Stripe account ID"
347381
invoice_kwargs["stripe_account"] = org_stripe_acct
348382

@@ -374,6 +408,8 @@ def create_payment_object(self, invoice) -> str:
374408
"tax_behavior": tax_behavior,
375409
"metadata": metadata,
376410
}
411+
if not self.self_hosted:
412+
inv_dict["stripe_account"] = org_stripe_acct
377413
stripe.InvoiceItem.create(**inv_dict)
378414
stripe_invoice = stripe.Invoice.create(**invoice_kwargs)
379415
return stripe_invoice.id
@@ -385,7 +421,12 @@ class StripePostRequestDataSerializer(serializers.Serializer):
385421
return StripePostRequestDataSerializer
386422

387423
def handle_post(self, data, organization) -> PaymentProviderPostResponseSerializer:
388-
stripe.api_key = self.secret_key
424+
from metering_billing.models import Organization
425+
426+
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
427+
stripe.api_key = self.live_secret_key
428+
else:
429+
stripe.api_key = self.test_secret_key
389430
response = stripe.OAuth.token(
390431
grant_type="authorization_code",
391432
code=data["authorization_code"],
@@ -426,11 +467,15 @@ def transfer_subscriptions(
426467
from metering_billing.models import (
427468
Customer,
428469
ExternalPlanLink,
470+
Organization,
429471
Plan,
430472
SubscriptionRecord,
431473
)
432474

433-
stripe.api_key = self.secret_key
475+
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
476+
stripe.api_key = self.live_secret_key
477+
else:
478+
stripe.api_key = self.test_secret_key
434479

435480
org_ppis = organization.payment_provider_ids
436481
stripe_cust_kwargs = {}
@@ -445,7 +490,7 @@ def transfer_subscriptions(
445490
)
446491

447492
stripe_subscriptions = stripe.Subscription.search(
448-
query="status:'active'",
493+
query="status:'active'", **stripe_cust_kwargs
449494
)
450495
plans_with_links = (
451496
Plan.objects.filter(organization=organization, status=PLAN_STATUS.ACTIVE)
@@ -507,6 +552,7 @@ def transfer_subscriptions(
507552
subscription.id,
508553
prorate=True,
509554
invoice_now=True,
555+
**stripe_cust_kwargs,
510556
)
511557
else:
512558
validated_data["start_date"] = datetime.datetime.utcfromtimestamp(
@@ -515,6 +561,7 @@ def transfer_subscriptions(
515561
sub = stripe.Subscription.modify(
516562
subscription.id,
517563
cancel_at_period_end=True,
564+
**stripe_cust_kwargs,
518565
)
519566
ret_subs.append(sub)
520567
SubscriptionRecord.objects.create(**validated_data)

backend/metering_billing/tasks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from dateutil.relativedelta import relativedelta
77
from django.conf import settings
88
from django.db.models import Q
9-
109
from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP
1110
from metering_billing.serializers.backtest_serializers import (
1211
AllSubstitutionResultsSerializer,
@@ -143,8 +142,9 @@ def update_invoice_status():
143142
for incomplete_invoice in incomplete_invoices:
144143
pp = incomplete_invoice.external_payment_obj_type
145144
if pp in PAYMENT_PROVIDER_MAP and PAYMENT_PROVIDER_MAP[pp].working():
145+
organization = incomplete_invoice.organization
146146
status = PAYMENT_PROVIDER_MAP[pp].update_payment_object_status(
147-
incomplete_invoice.external_payment_obj_id
147+
organization, incomplete_invoice.external_payment_obj_id
148148
)
149149
if status == Invoice.PaymentStatus.PAID:
150150
incomplete_invoice.payment_status = Invoice.PaymentStatus.PAID

backend/metering_billing/tests/conftest.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import posthog
44
import pytest
5-
from model_bakery import baker
6-
75
from metering_billing.utils import now_utc
86
from metering_billing.utils.enums import (
97
FLAT_FEE_BILLING_TYPE,
@@ -13,6 +11,7 @@
1311
PRODUCT_STATUS,
1412
USAGE_BILLING_FREQUENCY,
1513
)
14+
from model_bakery import baker
1615

1716

1817
@pytest.fixture(autouse=True)
@@ -39,12 +38,12 @@ def use_dummy_cache_backend(settings):
3938
def turn_off_stripe_connection():
4039
from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP
4140

42-
sk = PAYMENT_PROVIDER_MAP["stripe"].secret_key
43-
PAYMENT_PROVIDER_MAP["stripe"].secret_key = None
41+
sk = PAYMENT_PROVIDER_MAP["stripe"].test_secret_key
42+
PAYMENT_PROVIDER_MAP["stripe"].test_secret_key = None
4443

4544
yield
4645

47-
PAYMENT_PROVIDER_MAP["stripe"].secret_key = sk
46+
PAYMENT_PROVIDER_MAP["stripe"].test_secret_key = sk
4847

4948

5049
@pytest.fixture

0 commit comments

Comments
 (0)