Skip to content

Commit 341ee6d

Browse files
committed
source-stripe-native: speed up fetching connected account ids with concurrent worker system
The previous `_fetch_connected_account_ids` implementation was a single sequential paginator through GET `/v1/accounts`. For platforms with 36k+ connected accounts, this was taking over an hour. That's too slow since all connected account ids are fetched each time the capture starts up; we'd rather not spend an hour fetching connected account ids before capturing any data. This commit replaces that sequential paginator with a concurrent worker system modeled after `source-klaviyo-native`'s events backfill that partitions the time range into chunks using Stripe's `created[gte]`/`created[lte]` query parameters and has multiple workers paginate through their respective chunks in parallel. Workers detect dense time windows (chunks that take >30s to paginate) and, when idle workers are available, they subdivide the remaining unprocessed range into smaller chunks for other workers to pick up.
1 parent 44f3261 commit 341ee6d

File tree

3 files changed

+653
-27
lines changed

3 files changed

+653
-27
lines changed
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
import asyncio
2+
import time
3+
import traceback
4+
from asyncio import CancelledError
5+
from dataclasses import dataclass
6+
from datetime import UTC, datetime
7+
from logging import Logger
8+
9+
from estuary_cdk.http import HTTPSession
10+
from estuary_cdk.utils import format_error_message
11+
12+
from .api import API
13+
from .models import Accounts, ListResult
14+
15+
16+
# Stripe launched in 2011. No connected accounts can exist before this date,
# so it is a safe lower bound when partitioning the `created` timestamp range.
STRIPE_EPOCH = int(datetime(2011, 2, 1, tzinfo=UTC).timestamp())
18+
19+
20+
@dataclass
class TimestampChunk:
    """An inclusive [start, end] range of unix-epoch seconds for a worker to paginate."""

    # Unix timestamps; used as `created[gte]` / `created[lte]` query params,
    # so both endpoints are inclusive.
    start: int
    end: int
24+
25+
26+
@dataclass
class AccountFetchConfig:
    """Tuning knobs for the concurrent connected-account fetch."""

    # Number of chunks the full timestamp range is initially split into.
    initial_chunks: int = 50
    # Number of concurrent worker tasks paginating chunks.
    num_workers: int = 10

    # How many seconds a worker processes a chunk before it is considered
    # a dense time window and should be divided into smaller chunks.
    dense_chunk_threshold_seconds: int = 30

    # Minimum chunk size in seconds (1 minute) to prevent over-subdivision.
    minimum_chunk_seconds: int = 60

    # Maximum number of pending chunks in the shared work queue.
    work_queue_size: int = 100

    # Number of accounts requested per page of the accounts list endpoint.
    page_limit: int = 100


# Shared default configuration used when callers don't provide their own.
DEFAULT_CONFIG = AccountFetchConfig()
44+
45+
46+
def split_timestamp_range(
    start: int,
    end: int,
    num_chunks: int,
    minimum_chunk_seconds: int,
) -> list[TimestampChunk]:
    """Split a unix timestamp range into roughly equal-sized chunks.

    If minimum_chunk_seconds would be violated, fewer chunks are produced.
    Returns a list of TimestampChunks covering the full range.

    Raises ValueError if num_chunks is not positive or start is not
    strictly before end.
    """
    if num_chunks <= 0:
        raise ValueError("num_chunks must be positive")

    if start >= end:
        raise ValueError("start must be before end")

    span = end - start
    # Never produce chunks smaller than minimum_chunk_seconds; always
    # produce at least one chunk.
    effective_chunks = min(num_chunks, span // minimum_chunk_seconds)
    if effective_chunks < 1:
        effective_chunks = 1
    width = span / effective_chunks

    # Build the boundary cut points, forcing the final boundary to land
    # exactly on `end` so the chunks cover the full range.
    boundaries = [start]
    for _ in range(effective_chunks - 1):
        boundaries.append(int(boundaries[-1] + width))
    boundaries.append(end)

    return [
        TimestampChunk(start=lo, end=hi)
        for lo, hi in zip(boundaries, boundaries[1:])
    ]
86+
87+
88+
class AccountWorkManager:
    """Coordinates the pool of workers that fetch connected account ids.

    Owns the shared work queue of timestamp chunks, the running set of
    discovered account ids, and the active-worker accounting that workers
    consult to decide whether a dense chunk should be subdivided.
    """

    def __init__(
        self,
        http: HTTPSession,
        log: Logger,
        config: AccountFetchConfig = DEFAULT_CONFIG,
    ):
        self.http = http
        self.log = log
        self.config = config

        # Bounded queue of pending chunks; None is the worker shutdown sentinel.
        self.work_queue: asyncio.Queue[TimestampChunk | None] = asyncio.Queue(
            maxsize=config.work_queue_size
        )

        self._active_worker_count = 0
        self.first_worker_error: str | None = None

        # A set, so ids seen in overlapping chunk boundaries are de-duplicated.
        self.account_ids: set[str] = set()

    def mark_worker_active(self) -> None:
        # Guard against accounting bugs: the count can never exceed the pool size.
        if self._active_worker_count >= self.config.num_workers:
            raise Exception(f"A worker attempted to mark itself active when the active worker count is {self._active_worker_count}.")
        self._active_worker_count += 1

    def mark_worker_inactive(self) -> None:
        # Guard against accounting bugs: the count can never go negative.
        if self._active_worker_count <= 0:
            raise Exception(f"A worker attempted to mark itself inactive when the active worker count is {self._active_worker_count}.")
        self._active_worker_count -= 1

    def are_active_workers(self) -> bool:
        # True while at least one worker is processing a chunk.
        return self._active_worker_count > 0

    def has_idle_workers(self) -> bool:
        # True when some workers are waiting on the queue rather than paginating.
        return self._active_worker_count < self.config.num_workers

    async def fetch_account_ids(self, start: int, end: int) -> set[str]:
        """Run the worker pool over [start, end] and return the ids found."""
        for piece in self._create_initial_chunks(start, end):
            self.work_queue.put_nowait(piece)

        # Purely diagnostic: logs task outcomes to aid debugging when a task
        # fails or is cancelled unexpectedly. Not used for control flow —
        # the TaskGroup handles exception propagation and cancellation.
        def log_task_outcome(finished: asyncio.Task) -> None:
            trace: str | None = None

            if finished.cancelled():
                outcome = "cancelled"
            else:
                exc = finished.exception()
                if exc:
                    outcome = f"failed with exception {format_error_message(exc)}"
                    if exc.__traceback__:
                        trace = "\nStack trace:\n" + "".join(
                            traceback.format_list(traceback.extract_tb(exc.__traceback__))
                        )
                else:
                    outcome = "completed"

            self.log.debug(
                f"Task {finished.get_name()} {outcome}.",
                {
                    "first_worker_error": self.first_worker_error,
                    "active_worker_count": self._active_worker_count,
                    "stack_trace": trace,
                },
            )

        self.log.debug("Starting concurrent account fetch workers.")
        async with asyncio.TaskGroup() as tg:
            for worker_id in range(1, self.config.num_workers + 1):
                worker_task = tg.create_task(
                    account_chunk_worker(
                        worker_id=worker_id,
                        work_queue=self.work_queue,
                        work_manager=self,
                        http=self.http,
                        log=self.log,
                        config=self.config,
                    ),
                    name=f"account_chunk_worker_{worker_id}"
                )
                worker_task.add_done_callback(log_task_outcome)

            coordinator_task = tg.create_task(
                self._shutdown_coordinator(),
                name="account_shutdown_coordinator"
            )
            coordinator_task.add_done_callback(log_task_outcome)

        self.log.debug(f"Concurrent account fetch complete. Found {len(self.account_ids)} account IDs.")
        return self.account_ids

    def _create_initial_chunks(self, start: int, end: int) -> list[TimestampChunk]:
        # Partition the full range up front so all workers can start immediately.
        return split_timestamp_range(
            start=start,
            end=end,
            num_chunks=self.config.initial_chunks,
            minimum_chunk_seconds=self.config.minimum_chunk_seconds,
        )

    async def _shutdown_coordinator(self) -> None:
        """Wait for all work items to be processed, then signal workers to exit."""
        await self.work_queue.join()
        # One sentinel per worker so every worker's queue.get() unblocks.
        for _ in range(self.config.num_workers):
            self.work_queue.put_nowait(None)
189+
190+
191+
async def account_chunk_worker(
    worker_id: int,
    work_queue: asyncio.Queue[TimestampChunk | None],
    work_manager: AccountWorkManager,
    http: HTTPSession,
    log: Logger,
    config: AccountFetchConfig,
) -> None:
    """Consume timestamp chunks from work_queue and collect account ids.

    Pages through GET /v1/accounts for each chunk, restricted to accounts
    created within [chunk.start, chunk.end], adding every account id to
    work_manager.account_ids. Exits when it receives a None sentinel.

    Dense chunks: if paginating a chunk takes longer than
    config.dense_chunk_threshold_seconds and other workers are idle, the
    remaining unprocessed portion of the chunk is subdivided and re-enqueued
    so idle workers can share the load.
    """
    try:
        log.debug(f"Account worker {worker_id} started.")

        while True:
            chunk = await work_queue.get()
            if chunk is None:
                # Shutdown sentinel from the coordinator: acknowledge and exit.
                work_queue.task_done()
                break

            log.debug(f"Account worker {worker_id} working on chunk [{chunk.start}, {chunk.end}]")

            work_manager.mark_worker_active()

            try:
                url = f"{API}/accounts"
                params: dict[str, str | int] = {
                    "limit": config.page_limit,
                    "created[gte]": chunk.start,
                    "created[lte]": chunk.end,
                }
                start_time = time.time()
                page_count = 0
                # Oldest `created` timestamp seen so far within this chunk.
                last_created: int | None = None
                # Only consider subdividing if the chunk is large enough that the resulting
                # sub-chunks would each still be meaningfully sized (above minimum_chunk_seconds).
                is_divisible = (chunk.end - chunk.start) >= (config.minimum_chunk_seconds * 3)
                is_dense_chunk = False

                while True:
                    response = ListResult[Accounts].model_validate_json(
                        await http.request(log, url, params=params)
                    )

                    for account in response.data:
                        work_manager.account_ids.add(account.id)
                        last_created = account.created

                    page_count += 1

                    # Stop on the last page. An empty page is also treated as
                    # terminal, defensively guarding the `starting_after` cursor
                    # update below from indexing into an empty list.
                    if not response.has_more or not response.data:
                        break

                    # Stripe returns results in descending created order. After paginating
                    # partway through chunk [S, E], the worker has processed accounts from E
                    # down to last_created. The unprocessed range is [S, last_created].
                    if (
                        is_dense_chunk
                        and last_created is not None
                        and last_created > chunk.start
                        and work_manager.has_idle_workers()
                    ):
                        sub_chunks = split_timestamp_range(
                            start=chunk.start,
                            end=last_created,
                            num_chunks=config.num_workers,
                            minimum_chunk_seconds=config.minimum_chunk_seconds,
                        )
                        for sub_chunk in sub_chunks:
                            # Use a blocking put rather than put_nowait: the queue is
                            # bounded (config.work_queue_size), and a burst of
                            # subdivisions could otherwise overflow it and crash the
                            # worker with asyncio.QueueFull. This cannot deadlock:
                            # subdivision only happens while has_idle_workers() is
                            # true, so at least one worker is parked on get() and
                            # will drain the queue.
                            await work_queue.put(sub_chunk)
                        break

                    # Check if this chunk is dense after each page.
                    if (
                        is_divisible
                        and not is_dense_chunk
                    ):
                        elapsed = time.time() - start_time
                        if elapsed > config.dense_chunk_threshold_seconds:
                            is_dense_chunk = True

                    params["starting_after"] = response.data[-1].id

                log.debug(f"Account worker {worker_id} finished chunk. Pages: {page_count}, dense: {is_dense_chunk}")

                # task_done() is called after any sub-chunks are enqueued to prevent
                # join() from returning before the sub-chunks are processed.
                work_queue.task_done()
            finally:
                work_manager.mark_worker_inactive()

        log.debug(f"Account worker {worker_id} exited.")
    except CancelledError as e:
        if not work_manager.first_worker_error:
            # No prior worker failure means this cancellation was not caused by
            # sibling-error propagation — surface it as an unexpected failure.
            msg = format_error_message(e)
            work_manager.first_worker_error = msg
            raise Exception(f"Account worker {worker_id} was unexpectedly cancelled: {msg}")
        else:
            # Another worker already failed; this is the TaskGroup cancelling us.
            raise e
    except BaseException as e:
        msg = format_error_message(e)
        # Record only the first error so secondary failures don't mask the root cause.
        if not work_manager.first_worker_error:
            work_manager.first_worker_error = msg

        log.error(f"Account worker {worker_id} encountered an error.", {
            "exception": msg,
        })
        raise e
296+
297+
298+
async def fetch_connected_account_ids(
    http: HTTPSession,
    log: Logger,
    config: AccountFetchConfig = DEFAULT_CONFIG,
) -> list[str]:
    """Fetch all connected account IDs using concurrent workers.

    Uses time-range partitioning with created[gte]/created[lte] to parallelize
    paginating through the accounts list endpoint. Covers the full range from
    Stripe's epoch up to the current time.
    """
    manager = AccountWorkManager(http=http, log=log, config=config)
    await manager.fetch_account_ids(STRIPE_EPOCH, int(time.time()))

    return list(manager.account_ids)

source-stripe-native/source_stripe_native/resources.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from estuary_cdk.http import HTTPError, HTTPMixin, HTTPSession, TokenSource
2222

23+
from .account_fetcher import fetch_connected_account_ids
2324
from .api import (
2425
API,
2526
fetch_backfill,
@@ -39,7 +40,6 @@
3940
Accounts,
4041
ConnectorState,
4142
EndpointConfig,
42-
ListResult,
4343
)
4444
from .priority_capture import (
4545
open_binding_with_priority_queue,
@@ -77,31 +77,6 @@ async def check_accessibility(
7777
return is_accessible
7878

7979

80-
async def _fetch_connected_account_ids(
81-
http: HTTPSession,
82-
log: Logger,
83-
) -> list[str]:
84-
account_ids: set[str] = set()
85-
86-
url = f"{API}/accounts"
87-
params: dict[str, str | int] = {"limit": 100}
88-
89-
while True:
90-
response = ListResult[Accounts].model_validate_json(
91-
await http.request(log, url, params=params)
92-
)
93-
94-
for account in response.data:
95-
account_ids.add(account.id)
96-
97-
if not response.has_more:
98-
break
99-
100-
params["starting_after"] = response.data[-1].id
101-
102-
return list(account_ids)
103-
104-
10580
async def _fetch_platform_account_id(
10681
http: HTTPSession,
10782
log: Logger,
@@ -189,7 +164,7 @@ async def all_resources(
189164
log.info(
190165
"Fetching connected account IDs. This may take multiple minutes if there are many connected accounts."
191166
)
192-
connected_account_ids = await _fetch_connected_account_ids(http, log)
167+
connected_account_ids = await fetch_connected_account_ids(http, log)
193168
log.info(
194169
f"Found {len(connected_account_ids)} connected account IDs.",
195170
{

0 commit comments

Comments
 (0)