26
26
logger = logging .getLogger ("django.server" )
27
27
28
28
SELF_HOSTED = settings .SELF_HOSTED
29
- STRIPE_SECRET_KEY = settings .STRIPE_SECRET_KEY
29
+ STRIPE_LIVE_SECRET_KEY = settings .STRIPE_LIVE_SECRET_KEY
30
+ STRIPE_TEST_SECRET_KEY = settings .STRIPE_TEST_SECRET_KEY
30
31
VITE_STRIPE_CLIENT = settings .VITE_STRIPE_CLIENT
31
32
VITE_API_URL = settings .VITE_API_URL
32
33
@@ -53,7 +54,7 @@ def working(self) -> bool:
53
54
pass
54
55
55
56
@abc .abstractmethod
56
- def update_payment_object_status (self , payment_object_id : str ):
57
+ def update_payment_object_status (self , organization , payment_object_id : str ):
57
58
"""This method will be called periodically when the status of a payment object needs to be updated. It should return the status of the payment object, which should be either paid or unpaid."""
58
59
pass
59
60
@@ -108,7 +109,8 @@ def initialize_settings(self, organization) -> None:
108
109
109
110
class StripeConnector (PaymentProvider ):
110
111
def __init__ (self ):
111
- self .secret_key = STRIPE_SECRET_KEY
112
+ self .live_secret_key = STRIPE_LIVE_SECRET_KEY
113
+ self .test_secret_key = STRIPE_TEST_SECRET_KEY
112
114
self .self_hosted = SELF_HOSTED
113
115
redirect_dict = {
114
116
"response_type" : "code" ,
@@ -123,7 +125,7 @@ def __init__(self):
123
125
self .redirect_url = ""
124
126
125
127
def working (self ) -> bool :
126
- return self .secret_key != "" and self .secret_key is not None
128
+ return self .live_secret_key is not None or self .test_secret_key is not None
127
129
128
130
def customer_connected (self , customer ) -> bool :
129
131
pp_ids = customer .integrations
@@ -133,18 +135,26 @@ def customer_connected(self, customer) -> bool:
133
135
134
136
def organization_connected (self , organization ) -> bool :
135
137
if self .self_hosted :
136
- return self .secret_key != "" and self .secret_key is not None
138
+ return self .live_secret_key is not None or self .test_secret_key is not None
137
139
else :
138
140
return (
139
- organization .payment_provider_ids .get (PAYMENT_PROVIDERS .STRIPE , "" )
140
- != ""
141
+ organization .payment_provider_ids .get (PAYMENT_PROVIDERS .STRIPE , None )
142
+ is not None
141
143
)
142
144
143
- def update_payment_object_status (self , payment_object_id ):
144
- from metering_billing .models import Invoice
145
+ def update_payment_object_status (self , organization , payment_object_id ):
146
+ from metering_billing .models import Invoice , Organization
145
147
146
- stripe .api_key = self .secret_key
147
- invoice = stripe .Invoice .retrieve (payment_object_id )
148
+ invoice_payload = {}
149
+ if not self .self_hosted :
150
+ invoice_payload ["stripe_account" ] = organization .payment_provider_ids .get (
151
+ PAYMENT_PROVIDERS .STRIPE
152
+ )
153
+ if organization .organization_type == Organization .OrganizationType .PRODUCTION :
154
+ stripe .api_key = self .live_secret_key
155
+ else :
156
+ stripe .api_key = self .test_secret_key
157
+ invoice = stripe .Invoice .retrieve (payment_object_id , ** invoice_payload )
148
158
if invoice .status == "paid" :
149
159
return Invoice .PaymentStatus .PAID
150
160
else :
@@ -154,15 +164,18 @@ def import_customers(self, organization):
154
164
"""
155
165
Imports customers from Stripe. If they already exist (by checking that either they already have their Stripe ID in our system, or seeing that they have the same email address), then we update the Stripe section of payment_providers dict to reflect new information. If they don't exist, we create them (not as a Lotus customer yet, just as a Stripe customer).
156
166
"""
157
- from metering_billing .models import Customer
167
+ from metering_billing .models import Customer , Organization
158
168
159
- stripe .api_key = self .secret_key
169
+ if organization .organization_type == Organization .OrganizationType .PRODUCTION :
170
+ stripe .api_key = self .live_secret_key
171
+ else :
172
+ stripe .api_key = self .test_secret_key
160
173
161
174
num_cust_added = 0
162
175
org_ppis = organization .payment_provider_ids
163
176
164
177
stripe_cust_kwargs = {}
165
- if org_ppis . get ( PAYMENT_PROVIDERS . STRIPE ) not in [ "" , None ] :
178
+ if not self . self_hosted :
166
179
# this is to get "on behalf" of someone
167
180
stripe_cust_kwargs ["stripe_account" ] = org_ppis .get (
168
181
PAYMENT_PROVIDERS .STRIPE
@@ -233,7 +246,12 @@ def import_customers(self, organization):
233
246
return num_cust_added
234
247
235
248
def import_payment_objects (self , organization ):
236
- stripe .api_key = self .secret_key
249
+ from metering_billing .models import Organization
250
+
251
+ if organization .organization_type == Organization .OrganizationType .PRODUCTION :
252
+ stripe .api_key = self .live_secret_key
253
+ else :
254
+ stripe .api_key = self .test_secret_key
237
255
imported_invoices = {}
238
256
for customer in organization .customers .all ():
239
257
if PAYMENT_PROVIDERS .STRIPE in customer .integrations :
@@ -244,9 +262,13 @@ def import_payment_objects(self, organization):
244
262
def _import_payment_objects_for_customer (self , customer ):
245
263
from metering_billing .models import Invoice
246
264
247
- stripe .api_key = self .secret_key
265
+ payload = {}
266
+ if not self .self_hosted :
267
+ payload ["stripe_account" ] = customer .organization .payment_provider_ids .get (
268
+ PAYMENT_PROVIDERS .STRIPE
269
+ )
248
270
invoices = stripe .Invoice .list (
249
- customer = customer .integrations [PAYMENT_PROVIDERS .STRIPE ]["id" ]
271
+ customer = customer .integrations [PAYMENT_PROVIDERS .STRIPE ]["id" ], ** payload
250
272
)
251
273
lotus_invoices = []
252
274
for stripe_invoice in invoices .auto_paging_iter ():
@@ -273,8 +295,15 @@ def _import_payment_objects_for_customer(self, customer):
273
295
return lotus_invoices
274
296
275
297
def create_customer (self , customer ):
276
- stripe .api_key = self .secret_key
277
- from metering_billing .models import OrganizationSetting
298
+ from metering_billing .models import Organization , OrganizationSetting
299
+
300
+ if (
301
+ customer .organization .organization_type
302
+ == Organization .OrganizationType .PRODUCTION
303
+ ):
304
+ stripe .api_key = self .live_secret_key
305
+ else :
306
+ stripe .api_key = self .test_secret_key
278
307
279
308
setting = OrganizationSetting .objects .get (
280
309
setting_name = ORGANIZATION_SETTING_NAMES .GENERATE_CUSTOMER_IN_STRIPE_AFTER_LOTUS ,
@@ -293,10 +322,10 @@ def create_customer(self, customer):
293
322
}
294
323
if not self .self_hosted :
295
324
org_stripe_acct = customer .organization .payment_provider_ids .get (
296
- PAYMENT_PROVIDERS .STRIPE , ""
325
+ PAYMENT_PROVIDERS .STRIPE , None
297
326
)
298
327
assert (
299
- org_stripe_acct != ""
328
+ org_stripe_acct is not None
300
329
), "Organization does not have a Stripe account ID"
301
330
customer_kwargs ["stripe_account" ] = org_stripe_acct
302
331
try :
@@ -318,7 +347,15 @@ def create_customer(self, customer):
318
347
)
319
348
320
349
def create_payment_object (self , invoice ) -> str :
321
- stripe .api_key = self .secret_key
350
+ from metering_billing .models import Organization
351
+
352
+ if (
353
+ invoice .organization .organization_type
354
+ == Organization .OrganizationType .PRODUCTION
355
+ ):
356
+ stripe .api_key = self .live_secret_key
357
+ else :
358
+ stripe .api_key = self .test_secret_key
322
359
# check everything works as expected + build invoice item
323
360
assert invoice .external_payment_obj_id is None
324
361
customer = invoice .customer
@@ -329,20 +366,17 @@ def create_payment_object(self, invoice) -> str:
329
366
invoice_kwargs = {
330
367
"auto_advance" : True ,
331
368
"customer" : stripe_customer_id ,
332
- # "automatic_tax": {
333
- # "enabled": True,
334
- # },
335
369
"description" : "Invoice from {}" .format (
336
370
customer .organization .organization_name
337
371
),
338
372
"currency" : invoice .currency .code .lower (),
339
373
}
340
374
if not self .self_hosted :
341
375
org_stripe_acct = customer .organization .payment_provider_ids .get (
342
- PAYMENT_PROVIDERS .STRIPE , ""
376
+ PAYMENT_PROVIDERS .STRIPE , None
343
377
)
344
378
assert (
345
- org_stripe_acct != ""
379
+ org_stripe_acct is not None
346
380
), "Organization does not have a Stripe account ID"
347
381
invoice_kwargs ["stripe_account" ] = org_stripe_acct
348
382
@@ -374,6 +408,8 @@ def create_payment_object(self, invoice) -> str:
374
408
"tax_behavior" : tax_behavior ,
375
409
"metadata" : metadata ,
376
410
}
411
+ if not self .self_hosted :
412
+ inv_dict ["stripe_account" ] = org_stripe_acct
377
413
stripe .InvoiceItem .create (** inv_dict )
378
414
stripe_invoice = stripe .Invoice .create (** invoice_kwargs )
379
415
return stripe_invoice .id
@@ -385,7 +421,12 @@ class StripePostRequestDataSerializer(serializers.Serializer):
385
421
return StripePostRequestDataSerializer
386
422
387
423
def handle_post (self , data , organization ) -> PaymentProviderPostResponseSerializer :
388
- stripe .api_key = self .secret_key
424
+ from metering_billing .models import Organization
425
+
426
+ if organization .organization_type == Organization .OrganizationType .PRODUCTION :
427
+ stripe .api_key = self .live_secret_key
428
+ else :
429
+ stripe .api_key = self .test_secret_key
389
430
response = stripe .OAuth .token (
390
431
grant_type = "authorization_code" ,
391
432
code = data ["authorization_code" ],
@@ -426,11 +467,15 @@ def transfer_subscriptions(
426
467
from metering_billing .models import (
427
468
Customer ,
428
469
ExternalPlanLink ,
470
+ Organization ,
429
471
Plan ,
430
472
SubscriptionRecord ,
431
473
)
432
474
433
- stripe .api_key = self .secret_key
475
+ if organization .organization_type == Organization .OrganizationType .PRODUCTION :
476
+ stripe .api_key = self .live_secret_key
477
+ else :
478
+ stripe .api_key = self .test_secret_key
434
479
435
480
org_ppis = organization .payment_provider_ids
436
481
stripe_cust_kwargs = {}
@@ -445,7 +490,7 @@ def transfer_subscriptions(
445
490
)
446
491
447
492
stripe_subscriptions = stripe .Subscription .search (
448
- query = "status:'active'" ,
493
+ query = "status:'active'" , ** stripe_cust_kwargs
449
494
)
450
495
plans_with_links = (
451
496
Plan .objects .filter (organization = organization , status = PLAN_STATUS .ACTIVE )
@@ -507,6 +552,7 @@ def transfer_subscriptions(
507
552
subscription .id ,
508
553
prorate = True ,
509
554
invoice_now = True ,
555
+ ** stripe_cust_kwargs ,
510
556
)
511
557
else :
512
558
validated_data ["start_date" ] = datetime .datetime .utcfromtimestamp (
@@ -515,6 +561,7 @@ def transfer_subscriptions(
515
561
sub = stripe .Subscription .modify (
516
562
subscription .id ,
517
563
cancel_at_period_end = True ,
564
+ ** stripe_cust_kwargs ,
518
565
)
519
566
ret_subs .append (sub )
520
567
SubscriptionRecord .objects .create (** validated_data )
0 commit comments