From 0ed92a692c690c96d2d129ebcc342d18af58fe05 Mon Sep 17 00:00:00 2001 From: Alex Bair Date: Mon, 12 Jan 2026 11:40:10 -0500 Subject: [PATCH] estuary-cdk: improve capture concurrency by making Task.checkpoint async When multiple CDK-based captures run tasks concurrently, the synchronous Task.checkpoint() method blocks the entire asyncio event loop while writing buffered documents to stdout. This means when one task flushes its buffer, all other tasks are blocked from fetching data, processing records, or building up their own document buffers. We can make `Task.checkpoint` async in order to make flushing a buffer non-blocking to other tasks. Blocking I/O when flushing a buffer is now run with asyncio.to_thread, freeing the main event loop to run other coroutines while a buffer is being flushed. A module-level asyncio.Lock is used to serialize emissions to stdout. This ensures only a single task flushes its buffer at a time, preventing data from multiple buffers from being interleaved or corrupted while being written to stdout. With this change, while one task holds the lock and emits its buffer, other tasks can continue fetching data from the API, processing/parsing responses, and capturing documents to build up their own buffers. When the emitting task releases the lock, another waiting task can begin flushing its buffer. This commit makes `Task.checkpoint()` async, which is a breaking change. All connectors calling `Task.checkpoint` must now await it. The connectors in the `connectors` repo affected by this are updated in this commit as well. 
--- estuary-cdk/estuary_cdk/capture/_emit.py | 26 ++++++++++++++++++ .../capture/base_capture_connector.py | 24 ++++++++--------- estuary-cdk/estuary_cdk/capture/common.py | 27 +++++++++++-------- estuary-cdk/estuary_cdk/capture/task.py | 14 +++++----- estuary-cdk/estuary_cdk/shim_airbyte_cdk.py | 6 ++--- .../source_apple_app_store/resources.py | 12 ++++----- .../resources.py | 12 ++++----- .../source_jira_native/resources.py | 8 +++--- .../source_sage_intacct/resources.py | 8 +++--- .../source_salesforce_native/resources.py | 8 +++--- .../source_stripe_native/priority_capture.py | 8 +++--- .../source_stripe_native/resources.py | 24 ++++++++--------- 12 files changed, 102 insertions(+), 75 deletions(-) create mode 100644 estuary-cdk/estuary_cdk/capture/_emit.py diff --git a/estuary-cdk/estuary_cdk/capture/_emit.py b/estuary-cdk/estuary_cdk/capture/_emit.py new file mode 100644 index 0000000000..34dc2fb3c5 --- /dev/null +++ b/estuary-cdk/estuary_cdk/capture/_emit.py @@ -0,0 +1,26 @@ +import asyncio +import shutil +from typing import BinaryIO + +# Global lock for serializing all emissions to stdout and prevent interleaving output. 
+_emit_lock = asyncio.Lock() + + +async def emit_bytes(data: bytes, output: BinaryIO) -> None: + async with _emit_lock: + await asyncio.to_thread(_write_bytes, data, output) + + +async def emit_from_buffer(buffer: BinaryIO, output: BinaryIO) -> None: + async with _emit_lock: + await asyncio.to_thread(_copy_buffer, buffer, output) + + +def _write_bytes(data: bytes, output: BinaryIO) -> None: + output.write(data) + output.flush() + + +def _copy_buffer(buffer: BinaryIO, output: BinaryIO) -> None: + shutil.copyfileobj(buffer, output) + output.flush() diff --git a/estuary-cdk/estuary_cdk/capture/base_capture_connector.py b/estuary-cdk/estuary_cdk/capture/base_capture_connector.py index 9050c1c24f..f7c46b2f46 100644 --- a/estuary-cdk/estuary_cdk/capture/base_capture_connector.py +++ b/estuary-cdk/estuary_cdk/capture/base_capture_connector.py @@ -24,6 +24,7 @@ from ..logger import FlowLogger from ..utils import format_error_message, sort_dict from . import Request, Response, Task, request, response +from ._emit import emit_bytes from .common import _ConnectorState @@ -81,20 +82,20 @@ async def handle( if spec := request.spec: response = await self.spec(log, spec) response.protocol = 3032023 - self._emit(Response(spec=response)) + await self._emit(Response(spec=response)) elif discover := request.discover: - self._emit(Response(discovered=await self.discover(log, discover))) + await self._emit(Response(discovered=await self.discover(log, discover))) elif validate := request.validate_: - self._emit(Response(validated=await self.validate(log, validate))) + await self._emit(Response(validated=await self.validate(log, validate))) elif apply := request.apply: - self._emit(Response(applied=await self.apply(log, apply))) + await self._emit(Response(applied=await self.apply(log, apply))) elif open := request.open: opened, capture = await self.open(log, open) - self._emit(Response(opened=opened)) + await self._emit(Response(opened=opened)) stopping = Task.Stopping(asyncio.Event()) 
@@ -140,23 +141,20 @@ async def periodic_stop() -> None: else: raise RuntimeError("malformed request", request) - def _emit( + async def _emit( self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState] ): - self.output.write( - response.model_dump_json(by_alias=True, exclude_unset=True).encode() - ) - self.output.write(b"\n") - self.output.flush() + data = response.model_dump_json(by_alias=True, exclude_unset=True).encode() + await emit_bytes(data + b"\n", self.output) - def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True): + async def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True): r = Response[Any, Any, GeneralConnectorState]( checkpoint=response.Checkpoint( state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch) ) ) - self._emit(r) + await self._emit(r) async def _encrypt_config( self, diff --git a/estuary-cdk/estuary_cdk/capture/common.py b/estuary-cdk/estuary_cdk/capture/common.py index b14b5befe6..b367a757cc 100644 --- a/estuary-cdk/estuary_cdk/capture/common.py +++ b/estuary-cdk/estuary_cdk/capture/common.py @@ -451,6 +451,8 @@ class FixedSchema: name: str key: list[str] model: type[_BaseDocument] | FixedSchema + # The open callback can be async or sync. + # Async is required when the callback needs to call task.checkpoint(). 
open: Callable[ [ CaptureBinding[_BaseResourceConfig], @@ -464,7 +466,7 @@ class FixedSchema: ] ], ], - None, + None | Awaitable[None], ] initial_state: _BaseResourceState initial_config: _BaseResourceConfig @@ -593,7 +595,7 @@ async def _run(task: Task): {"stateKey": stateKey}, ) backfill_requests.append(stateKey) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={stateKey: None}, backfillRequests={stateKey: None}, @@ -614,7 +616,7 @@ async def _run(task: Task): if state: if state.last_initialized is None: state.last_initialized = NOW - task.checkpoint( + await task.checkpoint( ConnectorState(bindingStateV1={binding.stateKey: state}) ) @@ -665,19 +667,22 @@ async def _run(task: Task): state.last_initialized = NOW # Checkpoint the binding's initialized state prior to any processing. - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) ) - resource.open( + result = resource.open( binding, index, state, task, resolved_bindings, ) + # Support both sync and async open callbacks + if inspect.iscoroutine(result): + await result if soonest_future_scheduled_initialization: @@ -910,7 +915,7 @@ async def _binding_snapshot_task( # Suppress all captured documents, as they're unchanged. 
task.reset() - task.checkpoint(connector_state) + await task.checkpoint(connector_state) async def _binding_backfill_task( @@ -1004,14 +1009,14 @@ def _initialize_connector_state(state: ResourceState.Backfill) -> ConnectorState state.next_page = item state_to_checkpoint = connector_state - task.checkpoint(state_to_checkpoint) + await task.checkpoint(state_to_checkpoint) done = False if done: break if subtask_id is not None: - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={ binding.stateKey: ResourceState(backfill={subtask_id: None}) @@ -1019,7 +1024,7 @@ def _initialize_connector_state(state: ResourceState.Backfill) -> ConnectorState ) ) else: - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: ResourceState(backfill=None)} ) @@ -1108,7 +1113,7 @@ async def _binding_incremental_task( "incremental task triggered backfill", {"subtask_id": subtask_id} ) task.stopping.event.set() - task.checkpoint( + await task.checkpoint( ConnectorState(backfillRequests={binding.stateKey: True}) ) return @@ -1137,7 +1142,7 @@ async def _binding_incremental_task( ) state.cursor = item - task.checkpoint(connector_state) + await task.checkpoint(connector_state) checkpoints += 1 pending = False diff --git a/estuary-cdk/estuary_cdk/capture/task.py b/estuary-cdk/estuary_cdk/capture/task.py index 7214c7e643..c67f85bcec 100644 --- a/estuary-cdk/estuary_cdk/capture/task.py +++ b/estuary-cdk/estuary_cdk/capture/task.py @@ -7,12 +7,12 @@ import base64 from logging import Logger import asyncio -import shutil import tempfile import traceback import xxhash from . 
import request, response +from ._emit import emit_from_buffer from ..flow import ( ConnectorSpec, ConnectorState, @@ -159,10 +159,9 @@ def sourced_schema(self, binding_index: int, schema: dict[str, Any]): self._buffer.write(b) self._buffer.write(b"\n") - def checkpoint(self, state: ConnectorState, merge_patch: bool = True): - """Emit previously-queued, captured documents follows by a checkpoint""" - - self._emit( + async def checkpoint(self, state: ConnectorState, merge_patch: bool = True): + """Emit previously-queued, captured documents followed by a checkpoint.""" + await self._emit( Response[Any, Any, ConnectorState]( checkpoint=response.Checkpoint( state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch) @@ -212,12 +211,11 @@ async def run_task(parent: Task): task.set_name(child_name) return task - def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]): + async def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]): self._buffer.write( response.model_dump_json(by_alias=True, exclude_unset=True).encode() ) self._buffer.write(b"\n") self._buffer.seek(0) - shutil.copyfileobj(self._buffer, self._output) - self._output.flush() + await emit_from_buffer(self._buffer, self._output) self.reset() diff --git a/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py b/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py index 9074bb0ecb..4e851a4b79 100644 --- a/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py +++ b/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py @@ -375,7 +375,7 @@ async def _run( binding_index=binding_idx, schema=schema, ) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) airbyte_catalog = ConfiguredAirbyteCatalog(streams=airbyte_streams) @@ -462,7 +462,7 @@ async def _run( entry[1].state = state_msg.dict() - task.checkpoint(connector_state, merge_patch=False) + await task.checkpoint(connector_state, merge_patch=False) elif trace := message.trace: if error := trace.error: @@ 
-517,5 +517,5 @@ async def _run( # Emit a final checkpoint before exiting. task.log.info("Emitting final checkpoint for sweep.") - task.checkpoint(connector_state, merge_patch=False) + await task.checkpoint(connector_state, merge_patch=False) return None diff --git a/source-apple-app-store/source_apple_app_store/resources.py b/source-apple-app-store/source_apple_app_store/resources.py index f98400727a..9253ac4514 100644 --- a/source-apple-app-store/source_apple_app_store/resources.py +++ b/source-apple-app-store/source_apple_app_store/resources.py @@ -76,7 +76,7 @@ def _create_initial_state(app_ids: list[str]) -> ResourceState: return initial_state -def _reconcile_connector_state( +async def _reconcile_connector_state( app_ids: list[str], binding: CaptureBinding[ResourceConfig], state: ResourceState, @@ -109,7 +109,7 @@ def _reconcile_connector_state( task.log.info( f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}." ) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) @@ -130,7 +130,7 @@ async def analytics_resources( initial_state = _create_initial_state(app_ids) - def open( + async def open( model: type[AppleAnalyticsRow], binding: CaptureBinding[ResourceConfig], binding_index: int, @@ -155,7 +155,7 @@ def open( model, ) - _reconcile_connector_state(app_ids, binding, state, initial_state, task) + await _reconcile_connector_state(app_ids, binding, state, initial_state, task) open_binding( binding, @@ -194,7 +194,7 @@ async def api_resources( initial_state = _create_initial_state(app_ids) - def open( + async def open( app_ids: list[str], fetch_changes_fn: ApiFetchChangesFn, fetch_page_fn: ApiFetchPageFn, @@ -219,7 +219,7 @@ def open( app_id, ) - _reconcile_connector_state(app_ids, binding, state, initial_state, task) + await _reconcile_connector_state(app_ids, binding, state, initial_state, task) open_binding( binding, diff --git 
a/source-facebook-marketing-native/source_facebook_marketing_native/resources.py b/source-facebook-marketing-native/source_facebook_marketing_native/resources.py index 3d0e30b31f..2c045ae800 100644 --- a/source-facebook-marketing-native/source_facebook_marketing_native/resources.py +++ b/source-facebook-marketing-native/source_facebook_marketing_native/resources.py @@ -143,7 +143,7 @@ def _create_initial_state(account_ids: str | list[str]) -> ResourceState: ) -def _reconcile_connector_state( +async def _reconcile_connector_state( account_ids: list[str], binding: CaptureBinding[ResourceConfig], state: ResourceState, @@ -176,7 +176,7 @@ def _reconcile_connector_state( task.log.info( f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}." ) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) @@ -280,7 +280,7 @@ def full_refresh_resource( snapshot_fn: FullRefreshFetchFn, use_sourced_schemas: bool = False, ) -> Resource: - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -291,7 +291,7 @@ def open( ): if use_sourced_schemas and hasattr(model, 'sourced_schema'): task.sourced_schema(binding_index, model.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) open_binding( binding, @@ -366,7 +366,7 @@ def incremental_resource( create_fetch_page_fn: Callable[[str], Callable], create_fetch_changes_fn: Callable[[str], Callable], ) -> Resource: - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -375,7 +375,7 @@ def open( ): assert len(accounts) > 0, "At least one account ID is required" - _reconcile_connector_state(accounts, binding, state, initial_state, task) + await _reconcile_connector_state(accounts, binding, state, initial_state, task) fetch_page_fns = {} fetch_changes_fns = {} diff --git 
a/source-jira-native/source_jira_native/resources.py b/source-jira-native/source_jira_native/resources.py index ddb50c6678..6116b684e3 100644 --- a/source-jira-native/source_jira_native/resources.py +++ b/source-jira-native/source_jira_native/resources.py @@ -338,7 +338,7 @@ def issues( log: Logger, http: HTTPMixin, config: EndpointConfig, timezone: ZoneInfo ) -> common.Resource: - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -355,7 +355,7 @@ def open( } state.inc = migrated_inc_state - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state} ), @@ -421,7 +421,7 @@ def issue_child_resources( log: Logger, http: HTTPMixin, config: EndpointConfig, timezone: ZoneInfo ) -> list[common.Resource]: - def open( + async def open( stream: type[IssueChildStream], binding: CaptureBinding[ResourceConfig], binding_index: int, @@ -439,7 +439,7 @@ def open( } state.inc = migrated_inc_state - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state} ), diff --git a/source-sage-intacct/source_sage_intacct/resources.py b/source-sage-intacct/source_sage_intacct/resources.py index 2d3088671f..eea5eb026d 100644 --- a/source-sage-intacct/source_sage_intacct/resources.py +++ b/source-sage-intacct/source_sage_intacct/resources.py @@ -55,7 +55,7 @@ async def incremental_resource( ) -> common.Resource: model = await sage.get_model(obj) - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -63,7 +63,7 @@ def open( all_bindings, ): task.sourced_schema(binding_index, model.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, @@ -125,7 +125,7 @@ def open( async def snapshot_resource(sage: Sage, obj: str) -> common.Resource: model = await sage.get_model(obj) - def open( + async def open( obj: str, 
binding: CaptureBinding[ResourceConfig], binding_index: int, @@ -134,7 +134,7 @@ def open( all_bindings, ): task.sourced_schema(binding_index, model.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, diff --git a/source-salesforce-native/source_salesforce_native/resources.py b/source-salesforce-native/source_salesforce_native/resources.py index 0440fd331c..f39fc5ee87 100644 --- a/source-salesforce-native/source_salesforce_native/resources.py +++ b/source-salesforce-native/source_salesforce_native/resources.py @@ -74,7 +74,7 @@ def full_refresh_resource( enable: bool, ) -> common.Resource: - def open( + async def open( binding: CaptureBinding[SalesforceResourceConfigWithSchedule], binding_index: int, state: ResourceState, @@ -89,7 +89,7 @@ def open( model_cls = create_salesforce_model(name, fields) task.sourced_schema(binding_index, model_cls.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, @@ -137,7 +137,7 @@ def incremental_resource( enable: bool, ) -> common.Resource: - def open( + async def open( binding: CaptureBinding[SalesforceResourceConfigWithSchedule], binding_index: int, state: ResourceState, @@ -152,7 +152,7 @@ def open( model_cls = create_salesforce_model(name, fields) task.sourced_schema(binding_index, model_cls.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, diff --git a/source-stripe-native/source_stripe_native/priority_capture.py b/source-stripe-native/source_stripe_native/priority_capture.py index cc8fc647d7..56a4287661 100644 --- a/source-stripe-native/source_stripe_native/priority_capture.py +++ b/source-stripe-native/source_stripe_native/priority_capture.py @@ -564,7 +564,7 @@ async def _binding_incremental_task_with_work_item( "incremental task triggered backfill", {"subtask_id": 
work_item.account_id} ) task.stopping.event.set() - task.checkpoint( + await task.checkpoint( ConnectorState(backfillRequests={binding.stateKey: True}) ) return @@ -593,7 +593,7 @@ async def _binding_incremental_task_with_work_item( ) state.cursor = item - task.checkpoint(connector_state) + await task.checkpoint(connector_state) checkpoints += 1 pending = False @@ -684,13 +684,13 @@ async def _binding_backfill_task_with_work_item( ) else: state.next_page = item - task.checkpoint(connector_state) + await task.checkpoint(connector_state) done = False if done: break - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={ binding.stateKey: ResourceState(backfill={work_item.account_id: None}) diff --git a/source-stripe-native/source_stripe_native/resources.py b/source-stripe-native/source_stripe_native/resources.py index 208c658956..30674570d6 100644 --- a/source-stripe-native/source_stripe_native/resources.py +++ b/source-stripe-native/source_stripe_native/resources.py @@ -105,7 +105,7 @@ async def _fetch_platform_account_id( return platform_account.id -def _reconcile_connector_state( +async def _reconcile_connector_state( account_ids: list[str], binding: CaptureBinding[ResourceConfig], state: ResourceState, @@ -146,7 +146,7 @@ def _reconcile_connector_state( task.log.info( f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}." 
) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) @@ -325,7 +325,7 @@ def base_object( It requires a single, parent stream with a valid Event API Type """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -357,7 +357,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -416,7 +416,7 @@ def child_object( a valid Event API Type """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -450,7 +450,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -515,7 +515,7 @@ def split_child_object( in the API response. Meaning, the stream behaves like a non-chid stream incrementally. """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -548,7 +548,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -611,7 +611,7 @@ def usage_records( and requires special processing. """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -645,7 +645,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -708,7 +708,7 @@ def no_events_object( It works very similar to the base object, but without the use of the Events APi. 
""" - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -740,7 +740,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task )