Skip to content

Commit 0ed92a6

Browse files
committed
estuary-cdk: improve capture concurrency by making Task.checkpoint async
When multiple CDK-basked captures run tasks concurrently, the synchronous Task.checkpoint() method blocks the entire asyncio event loop while writing buffered documents to stdout. This means when one task flushes its buffer, all other tasks are blocked from fetching data, processing records, or building up their own document buffers. We can make `Task.checkpoint` async in order to make flushing a buffer non-blocking to other tasks. Blocking I/O when flushing a buffer is now run with asyncio.to_thread, freeing the main event loop to run other coroutines while a buffer is being flushed. A module-level asyncio.Lock is used to serialize emissions to stdout. This ensures only a single task flushes its buffer at a time, preveing data from multiple buffers from being interleaved or corrupted while being written to stdout. With this change, while one task holds the lock and emits its buffer, other tasks can continue fetching data from the API, process/parse responses, and capture documents to build up their own buffers. When the emitting task releases the lock, another waiting task can begin flushing its buffer. This commit makes `Task.checkpoint()` async, which is a breaking change. All connectors calling `Task.checkpoint` must now await it. The connectors in the `connectors` repo affected by this are updated in this commit as well.
1 parent ecfd938 commit 0ed92a6

File tree

12 files changed

+102
-75
lines changed

12 files changed

+102
-75
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import asyncio
2+
import shutil
3+
from typing import BinaryIO
4+
5+
# Global lock for serializing all emissions to stdout and prevent interleaving output.
6+
_emit_lock = asyncio.Lock()
7+
8+
9+
async def emit_bytes(data: bytes, output: BinaryIO) -> None:
10+
async with _emit_lock:
11+
await asyncio.to_thread(_write_bytes, data, output)
12+
13+
14+
async def emit_from_buffer(buffer: BinaryIO, output: BinaryIO) -> None:
15+
async with _emit_lock:
16+
await asyncio.to_thread(_copy_buffer, buffer, output)
17+
18+
19+
def _write_bytes(data: bytes, output: BinaryIO) -> None:
20+
output.write(data)
21+
output.flush()
22+
23+
24+
def _copy_buffer(buffer: BinaryIO, output: BinaryIO) -> None:
25+
shutil.copyfileobj(buffer, output)
26+
output.flush()

estuary-cdk/estuary_cdk/capture/base_capture_connector.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..logger import FlowLogger
2525
from ..utils import format_error_message, sort_dict
2626
from . import Request, Response, Task, request, response
27+
from ._emit import emit_bytes
2728
from .common import _ConnectorState
2829

2930

@@ -81,20 +82,20 @@ async def handle(
8182
if spec := request.spec:
8283
response = await self.spec(log, spec)
8384
response.protocol = 3032023
84-
self._emit(Response(spec=response))
85+
await self._emit(Response(spec=response))
8586

8687
elif discover := request.discover:
87-
self._emit(Response(discovered=await self.discover(log, discover)))
88+
await self._emit(Response(discovered=await self.discover(log, discover)))
8889

8990
elif validate := request.validate_:
90-
self._emit(Response(validated=await self.validate(log, validate)))
91+
await self._emit(Response(validated=await self.validate(log, validate)))
9192

9293
elif apply := request.apply:
93-
self._emit(Response(applied=await self.apply(log, apply)))
94+
await self._emit(Response(applied=await self.apply(log, apply)))
9495

9596
elif open := request.open:
9697
opened, capture = await self.open(log, open)
97-
self._emit(Response(opened=opened))
98+
await self._emit(Response(opened=opened))
9899

99100
stopping = Task.Stopping(asyncio.Event())
100101

@@ -140,23 +141,20 @@ async def periodic_stop() -> None:
140141
else:
141142
raise RuntimeError("malformed request", request)
142143

143-
def _emit(
144+
async def _emit(
144145
self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState]
145146
):
146-
self.output.write(
147-
response.model_dump_json(by_alias=True, exclude_unset=True).encode()
148-
)
149-
self.output.write(b"\n")
150-
self.output.flush()
147+
data = response.model_dump_json(by_alias=True, exclude_unset=True).encode()
148+
await emit_bytes(data + b"\n", self.output)
151149

152-
def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True):
150+
async def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True):
153151
r = Response[Any, Any, GeneralConnectorState](
154152
checkpoint=response.Checkpoint(
155153
state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch)
156154
)
157155
)
158156

159-
self._emit(r)
157+
await self._emit(r)
160158

161159
async def _encrypt_config(
162160
self,

estuary-cdk/estuary_cdk/capture/common.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,8 @@ class FixedSchema:
451451
name: str
452452
key: list[str]
453453
model: type[_BaseDocument] | FixedSchema
454+
# The open callback can be async or sync.
455+
# Async is required when the callback needs to call task.checkpoint().
454456
open: Callable[
455457
[
456458
CaptureBinding[_BaseResourceConfig],
@@ -464,7 +466,7 @@ class FixedSchema:
464466
]
465467
],
466468
],
467-
None,
469+
None | Awaitable[None],
468470
]
469471
initial_state: _BaseResourceState
470472
initial_config: _BaseResourceConfig
@@ -593,7 +595,7 @@ async def _run(task: Task):
593595
{"stateKey": stateKey},
594596
)
595597
backfill_requests.append(stateKey)
596-
task.checkpoint(
598+
await task.checkpoint(
597599
ConnectorState(
598600
bindingStateV1={stateKey: None},
599601
backfillRequests={stateKey: None},
@@ -614,7 +616,7 @@ async def _run(task: Task):
614616
if state:
615617
if state.last_initialized is None:
616618
state.last_initialized = NOW
617-
task.checkpoint(
619+
await task.checkpoint(
618620
ConnectorState(bindingStateV1={binding.stateKey: state})
619621
)
620622

@@ -665,19 +667,22 @@ async def _run(task: Task):
665667
state.last_initialized = NOW
666668

667669
# Checkpoint the binding's initialized state prior to any processing.
668-
task.checkpoint(
670+
await task.checkpoint(
669671
ConnectorState(
670672
bindingStateV1={binding.stateKey: state},
671673
)
672674
)
673675

674-
resource.open(
676+
result = resource.open(
675677
binding,
676678
index,
677679
state,
678680
task,
679681
resolved_bindings,
680682
)
683+
# Support both sync and async open callbacks
684+
if inspect.iscoroutine(result):
685+
await result
681686

682687

683688
if soonest_future_scheduled_initialization:
@@ -910,7 +915,7 @@ async def _binding_snapshot_task(
910915
# Suppress all captured documents, as they're unchanged.
911916
task.reset()
912917

913-
task.checkpoint(connector_state)
918+
await task.checkpoint(connector_state)
914919

915920

916921
async def _binding_backfill_task(
@@ -1004,22 +1009,22 @@ def _initialize_connector_state(state: ResourceState.Backfill) -> ConnectorState
10041009
state.next_page = item
10051010
state_to_checkpoint = connector_state
10061011

1007-
task.checkpoint(state_to_checkpoint)
1012+
await task.checkpoint(state_to_checkpoint)
10081013
done = False
10091014

10101015
if done:
10111016
break
10121017

10131018
if subtask_id is not None:
1014-
task.checkpoint(
1019+
await task.checkpoint(
10151020
ConnectorState(
10161021
bindingStateV1={
10171022
binding.stateKey: ResourceState(backfill={subtask_id: None})
10181023
}
10191024
)
10201025
)
10211026
else:
1022-
task.checkpoint(
1027+
await task.checkpoint(
10231028
ConnectorState(
10241029
bindingStateV1={binding.stateKey: ResourceState(backfill=None)}
10251030
)
@@ -1108,7 +1113,7 @@ async def _binding_incremental_task(
11081113
"incremental task triggered backfill", {"subtask_id": subtask_id}
11091114
)
11101115
task.stopping.event.set()
1111-
task.checkpoint(
1116+
await task.checkpoint(
11121117
ConnectorState(backfillRequests={binding.stateKey: True})
11131118
)
11141119
return
@@ -1137,7 +1142,7 @@ async def _binding_incremental_task(
11371142
)
11381143

11391144
state.cursor = item
1140-
task.checkpoint(connector_state)
1145+
await task.checkpoint(connector_state)
11411146
checkpoints += 1
11421147
pending = False
11431148

estuary-cdk/estuary_cdk/capture/task.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import base64
88
from logging import Logger
99
import asyncio
10-
import shutil
1110
import tempfile
1211
import traceback
1312
import xxhash
1413

1514
from . import request, response
15+
from ._emit import emit_from_buffer
1616
from ..flow import (
1717
ConnectorSpec,
1818
ConnectorState,
@@ -159,10 +159,9 @@ def sourced_schema(self, binding_index: int, schema: dict[str, Any]):
159159
self._buffer.write(b)
160160
self._buffer.write(b"\n")
161161

162-
def checkpoint(self, state: ConnectorState, merge_patch: bool = True):
163-
"""Emit previously-queued, captured documents follows by a checkpoint"""
164-
165-
self._emit(
162+
async def checkpoint(self, state: ConnectorState, merge_patch: bool = True):
163+
"""Emit previously-queued, captured documents followed by a checkpoint."""
164+
await self._emit(
166165
Response[Any, Any, ConnectorState](
167166
checkpoint=response.Checkpoint(
168167
state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch)
@@ -212,12 +211,11 @@ async def run_task(parent: Task):
212211
task.set_name(child_name)
213212
return task
214213

215-
def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]):
214+
async def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]):
216215
self._buffer.write(
217216
response.model_dump_json(by_alias=True, exclude_unset=True).encode()
218217
)
219218
self._buffer.write(b"\n")
220219
self._buffer.seek(0)
221-
shutil.copyfileobj(self._buffer, self._output)
222-
self._output.flush()
220+
await emit_from_buffer(self._buffer, self._output)
223221
self.reset()

estuary-cdk/estuary_cdk/shim_airbyte_cdk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ async def _run(
375375
binding_index=binding_idx,
376376
schema=schema,
377377
)
378-
task.checkpoint(state=ConnectorState())
378+
await task.checkpoint(state=ConnectorState())
379379

380380
airbyte_catalog = ConfiguredAirbyteCatalog(streams=airbyte_streams)
381381

@@ -462,7 +462,7 @@ async def _run(
462462

463463
entry[1].state = state_msg.dict()
464464

465-
task.checkpoint(connector_state, merge_patch=False)
465+
await task.checkpoint(connector_state, merge_patch=False)
466466

467467
elif trace := message.trace:
468468
if error := trace.error:
@@ -517,5 +517,5 @@ async def _run(
517517

518518
# Emit a final checkpoint before exiting.
519519
task.log.info("Emitting final checkpoint for sweep.")
520-
task.checkpoint(connector_state, merge_patch=False)
520+
await task.checkpoint(connector_state, merge_patch=False)
521521
return None

source-apple-app-store/source_apple_app_store/resources.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _create_initial_state(app_ids: list[str]) -> ResourceState:
7676
return initial_state
7777

7878

79-
def _reconcile_connector_state(
79+
async def _reconcile_connector_state(
8080
app_ids: list[str],
8181
binding: CaptureBinding[ResourceConfig],
8282
state: ResourceState,
@@ -109,7 +109,7 @@ def _reconcile_connector_state(
109109
task.log.info(
110110
f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}."
111111
)
112-
task.checkpoint(
112+
await task.checkpoint(
113113
ConnectorState(
114114
bindingStateV1={binding.stateKey: state},
115115
)
@@ -130,7 +130,7 @@ async def analytics_resources(
130130

131131
initial_state = _create_initial_state(app_ids)
132132

133-
def open(
133+
async def open(
134134
model: type[AppleAnalyticsRow],
135135
binding: CaptureBinding[ResourceConfig],
136136
binding_index: int,
@@ -155,7 +155,7 @@ def open(
155155
model,
156156
)
157157

158-
_reconcile_connector_state(app_ids, binding, state, initial_state, task)
158+
await _reconcile_connector_state(app_ids, binding, state, initial_state, task)
159159

160160
open_binding(
161161
binding,
@@ -194,7 +194,7 @@ async def api_resources(
194194

195195
initial_state = _create_initial_state(app_ids)
196196

197-
def open(
197+
async def open(
198198
app_ids: list[str],
199199
fetch_changes_fn: ApiFetchChangesFn,
200200
fetch_page_fn: ApiFetchPageFn,
@@ -219,7 +219,7 @@ def open(
219219
app_id,
220220
)
221221

222-
_reconcile_connector_state(app_ids, binding, state, initial_state, task)
222+
await _reconcile_connector_state(app_ids, binding, state, initial_state, task)
223223

224224
open_binding(
225225
binding,

source-facebook-marketing-native/source_facebook_marketing_native/resources.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _create_initial_state(account_ids: str | list[str]) -> ResourceState:
143143
)
144144

145145

146-
def _reconcile_connector_state(
146+
async def _reconcile_connector_state(
147147
account_ids: list[str],
148148
binding: CaptureBinding[ResourceConfig],
149149
state: ResourceState,
@@ -176,7 +176,7 @@ def _reconcile_connector_state(
176176
task.log.info(
177177
f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}."
178178
)
179-
task.checkpoint(
179+
await task.checkpoint(
180180
ConnectorState(
181181
bindingStateV1={binding.stateKey: state},
182182
)
@@ -280,7 +280,7 @@ def full_refresh_resource(
280280
snapshot_fn: FullRefreshFetchFn,
281281
use_sourced_schemas: bool = False,
282282
) -> Resource:
283-
def open(
283+
async def open(
284284
binding: CaptureBinding[ResourceConfig],
285285
binding_index: int,
286286
state: ResourceState,
@@ -291,7 +291,7 @@ def open(
291291
):
292292
if use_sourced_schemas and hasattr(model, 'sourced_schema'):
293293
task.sourced_schema(binding_index, model.sourced_schema())
294-
task.checkpoint(state=ConnectorState())
294+
await task.checkpoint(state=ConnectorState())
295295

296296
open_binding(
297297
binding,
@@ -366,7 +366,7 @@ def incremental_resource(
366366
create_fetch_page_fn: Callable[[str], Callable],
367367
create_fetch_changes_fn: Callable[[str], Callable],
368368
) -> Resource:
369-
def open(
369+
async def open(
370370
binding: CaptureBinding[ResourceConfig],
371371
binding_index: int,
372372
state: ResourceState,
@@ -375,7 +375,7 @@ def open(
375375
):
376376
assert len(accounts) > 0, "At least one account ID is required"
377377

378-
_reconcile_connector_state(accounts, binding, state, initial_state, task)
378+
await _reconcile_connector_state(accounts, binding, state, initial_state, task)
379379

380380
fetch_page_fns = {}
381381
fetch_changes_fns = {}

0 commit comments

Comments
 (0)