Skip to content

Commit ef9d5b0

Browse files
Fix using auto_paging_iter() with expand: [...] (#1434)
* deduplicate querystring using a pre-made url * fix tests
1 parent 43d0937 commit ef9d5b0

File tree

3 files changed

+94
-11
lines changed

3 files changed

+94
-11
lines changed

stripe/_api_requestor.py

+36-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
Unpack,
2121
)
2222
import uuid
23-
from urllib.parse import urlsplit, urlunsplit
23+
from urllib.parse import urlsplit, urlunsplit, parse_qs
2424

2525
# breaking circular dependency
2626
import stripe # noqa: IMP101
@@ -556,6 +556,35 @@ def _args_for_request_with_retries(
556556
url,
557557
)
558558

559+
params = params or {}
560+
if params and (method == "get" or method == "delete"):
561+
# if we're sending params in the querystring, then we have to make sure we're not
562+
# duplicating anything we got back from the server already (like in a list iterator)
563+
# so, we parse the querystring the server sends back so we can merge with what we (or the user) are trying to send
564+
existing_params = {}
565+
for k, v in parse_qs(urlsplit(url).query).items():
566+
# note: server sends back "expand[]" but users supply "expand", so we strip the brackets from the key name
567+
if k.endswith("[]"):
568+
existing_params[k[:-2]] = v
569+
else:
570+
# all querystrings are pulled out as lists.
571+
# We want to keep the querystrings that actually are lists, but flatten the ones that are single values
572+
existing_params[k] = v[0] if len(v) == 1 else v
573+
574+
# if a user is expanding something that wasn't expanded before, add (and deduplicate) it
575+
# this could theoretically work for other lists that we want to merge too, but that doesn't seem to be a use case
576+
# it never would have worked before, so I think we can start with `expand` and go from there
577+
if "expand" in existing_params and "expand" in params:
578+
params["expand"] = list( # type:ignore - this is a dict
579+
set([*existing_params["expand"], *params["expand"]])
580+
)
581+
582+
params = {
583+
**existing_params,
584+
# user_supplied params take precedence over server params
585+
**params,
586+
}
587+
559588
encoded_params = urlencode(list(_api_encode(params or {}, api_mode)))
560589

561590
# Don't use strict form encoding by changing the square bracket control
@@ -586,13 +615,13 @@ def _args_for_request_with_retries(
586615

587616
if method == "get" or method == "delete":
588617
if params:
589-
query = encoded_params
590-
scheme, netloc, path, base_query, fragment = urlsplit(abs_url)
618+
# if we're sending query params, we've already merged the incoming ones with the server's "url"
619+
# so we can overwrite the whole thing
620+
scheme, netloc, path, _, fragment = urlsplit(abs_url)
591621

592-
if base_query:
593-
query = "%s&%s" % (base_query, query)
594-
595-
abs_url = urlunsplit((scheme, netloc, path, query, fragment))
622+
abs_url = urlunsplit(
623+
(scheme, netloc, path, encoded_params, fragment)
624+
)
596625
post_data = None
597626
elif method == "post":
598627
if (

tests/api_resources/test_list_object.py

+53
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
import stripe
6+
from tests.http_client_mock import HTTPClientMock
67

78

89
class TestListObject(object):
@@ -439,6 +440,58 @@ def test_forwards_api_key_to_nested_resources(self, http_client_mock):
439440
)
440441
assert lo.data[0].api_key == "sk_test_iter_forwards_options"
441442

443+
def test_iter_with_params(self, http_client_mock: HTTPClientMock):
444+
http_client_mock.stub_request(
445+
"get",
446+
path="/v1/invoices/upcoming/lines",
447+
query_string="customer=cus_123&expand[0]=data.price&limit=1",
448+
rbody=json.dumps(
449+
{
450+
"object": "list",
451+
"data": [
452+
{
453+
"id": "prod_001",
454+
"object": "product",
455+
"price": {"object": "price", "id": "price_123"},
456+
}
457+
],
458+
"url": "/v1/invoices/upcoming/lines?customer=cus_123&expand%5B%5D=data.price",
459+
"has_more": True,
460+
}
461+
),
462+
)
463+
# second page
464+
http_client_mock.stub_request(
465+
"get",
466+
path="/v1/invoices/upcoming/lines",
467+
query_string="customer=cus_123&expand[0]=data.price&limit=1&starting_after=prod_001",
468+
rbody=json.dumps(
469+
{
470+
"object": "list",
471+
"data": [
472+
{
473+
"id": "prod_002",
474+
"object": "product",
475+
"price": {"object": "price", "id": "price_123"},
476+
}
477+
],
478+
"url": "/v1/invoices/upcoming/lines?customer=cus_123&expand%5B%5D=data.price",
479+
"has_more": False,
480+
}
481+
),
482+
)
483+
484+
lo = stripe.Invoice.upcoming_lines(
485+
api_key="sk_test_invoice_lines",
486+
customer="cus_123",
487+
expand=["data.price"],
488+
limit=1,
489+
)
490+
491+
seen = [item["id"] for item in lo.auto_paging_iter()]
492+
493+
assert seen == ["prod_001", "prod_002"]
494+
442495

443496
class TestAutoPagingAsync:
444497
@staticmethod

tests/test_api_requestor.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,17 @@ def test_ordereddict_encoding(self):
245245

246246
def test_url_construction(self, requestor, http_client_mock):
247247
CASES = (
248-
("%s?foo=bar" % stripe.api_base, "", {"foo": "bar"}),
249-
("%s?foo=bar" % stripe.api_base, "?", {"foo": "bar"}),
248+
(f"{stripe.api_base}?foo=bar", "", {"foo": "bar"}),
249+
(f"{stripe.api_base}?foo=bar", "?", {"foo": "bar"}),
250250
(stripe.api_base, "", {}),
251251
(
252-
"%s/%%20spaced?foo=bar%%24&baz=5" % stripe.api_base,
252+
f"{stripe.api_base}/%20spaced?baz=5&foo=bar%24",
253253
"/%20spaced?foo=bar%24",
254254
{"baz": "5"},
255255
),
256+
# duplicate query params keys should be deduped
256257
(
257-
"%s?foo=bar&foo=bar" % stripe.api_base,
258+
f"{stripe.api_base}?foo=bar",
258259
"?foo=bar",
259260
{"foo": "bar"},
260261
),

0 commit comments

Comments
 (0)