diff --git a/estuary-cdk/estuary_cdk/capture/_emit.py b/estuary-cdk/estuary_cdk/capture/_emit.py new file mode 100644 index 0000000000..34dc2fb3c5 --- /dev/null +++ b/estuary-cdk/estuary_cdk/capture/_emit.py @@ -0,0 +1,26 @@ +import asyncio +import shutil +from typing import BinaryIO + +# Global lock for serializing all emissions to stdout, to prevent interleaved output. +_emit_lock = asyncio.Lock() + + +async def emit_bytes(data: bytes, output: BinaryIO) -> None: + async with _emit_lock: + await asyncio.to_thread(_write_bytes, data, output) + + +async def emit_from_buffer(buffer: BinaryIO, output: BinaryIO) -> None: + async with _emit_lock: + await asyncio.to_thread(_copy_buffer, buffer, output) + + +def _write_bytes(data: bytes, output: BinaryIO) -> None: + output.write(data) + output.flush() + + +def _copy_buffer(buffer: BinaryIO, output: BinaryIO) -> None: + shutil.copyfileobj(buffer, output) + output.flush() diff --git a/estuary-cdk/estuary_cdk/capture/base_capture_connector.py b/estuary-cdk/estuary_cdk/capture/base_capture_connector.py index 9050c1c24f..f7c46b2f46 100644 --- a/estuary-cdk/estuary_cdk/capture/base_capture_connector.py +++ b/estuary-cdk/estuary_cdk/capture/base_capture_connector.py @@ -24,6 +24,7 @@ from ..logger import FlowLogger from ..utils import format_error_message, sort_dict from . 
import Request, Response, Task, request, response +from ._emit import emit_bytes from .common import _ConnectorState @@ -81,20 +82,20 @@ async def handle( if spec := request.spec: response = await self.spec(log, spec) response.protocol = 3032023 - self._emit(Response(spec=response)) + await self._emit(Response(spec=response)) elif discover := request.discover: - self._emit(Response(discovered=await self.discover(log, discover))) + await self._emit(Response(discovered=await self.discover(log, discover))) elif validate := request.validate_: - self._emit(Response(validated=await self.validate(log, validate))) + await self._emit(Response(validated=await self.validate(log, validate))) elif apply := request.apply: - self._emit(Response(applied=await self.apply(log, apply))) + await self._emit(Response(applied=await self.apply(log, apply))) elif open := request.open: opened, capture = await self.open(log, open) - self._emit(Response(opened=opened)) + await self._emit(Response(opened=opened)) stopping = Task.Stopping(asyncio.Event()) @@ -140,23 +141,20 @@ async def periodic_stop() -> None: else: raise RuntimeError("malformed request", request) - def _emit( + async def _emit( self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState] ): - self.output.write( - response.model_dump_json(by_alias=True, exclude_unset=True).encode() - ) - self.output.write(b"\n") - self.output.flush() + data = response.model_dump_json(by_alias=True, exclude_unset=True).encode() + await emit_bytes(data + b"\n", self.output) - def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True): + async def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True): r = Response[Any, Any, GeneralConnectorState]( checkpoint=response.Checkpoint( state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch) ) ) - self._emit(r) + await self._emit(r) async def _encrypt_config( self, diff --git a/estuary-cdk/estuary_cdk/capture/common.py 
b/estuary-cdk/estuary_cdk/capture/common.py index b14b5befe6..b367a757cc 100644 --- a/estuary-cdk/estuary_cdk/capture/common.py +++ b/estuary-cdk/estuary_cdk/capture/common.py @@ -451,6 +451,8 @@ class FixedSchema: name: str key: list[str] model: type[_BaseDocument] | FixedSchema + # The open callback can be async or sync. + # Async is required when the callback needs to call task.checkpoint(). open: Callable[ [ CaptureBinding[_BaseResourceConfig], @@ -464,7 +466,7 @@ class FixedSchema: ] ], ], - None, + None | Awaitable[None], ] initial_state: _BaseResourceState initial_config: _BaseResourceConfig @@ -593,7 +595,7 @@ async def _run(task: Task): {"stateKey": stateKey}, ) backfill_requests.append(stateKey) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={stateKey: None}, backfillRequests={stateKey: None}, @@ -614,7 +616,7 @@ async def _run(task: Task): if state: if state.last_initialized is None: state.last_initialized = NOW - task.checkpoint( + await task.checkpoint( ConnectorState(bindingStateV1={binding.stateKey: state}) ) @@ -665,19 +667,22 @@ async def _run(task: Task): state.last_initialized = NOW # Checkpoint the binding's initialized state prior to any processing. - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) ) - resource.open( + result = resource.open( binding, index, state, task, resolved_bindings, ) + # Support both sync and async open callbacks + if inspect.iscoroutine(result): + await result if soonest_future_scheduled_initialization: @@ -910,7 +915,7 @@ async def _binding_snapshot_task( # Suppress all captured documents, as they're unchanged. 
task.reset() - task.checkpoint(connector_state) + await task.checkpoint(connector_state) async def _binding_backfill_task( @@ -1004,14 +1009,14 @@ def _initialize_connector_state(state: ResourceState.Backfill) -> ConnectorState state.next_page = item state_to_checkpoint = connector_state - task.checkpoint(state_to_checkpoint) + await task.checkpoint(state_to_checkpoint) done = False if done: break if subtask_id is not None: - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={ binding.stateKey: ResourceState(backfill={subtask_id: None}) @@ -1019,7 +1024,7 @@ def _initialize_connector_state(state: ResourceState.Backfill) -> ConnectorState ) ) else: - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: ResourceState(backfill=None)} ) @@ -1108,7 +1113,7 @@ async def _binding_incremental_task( "incremental task triggered backfill", {"subtask_id": subtask_id} ) task.stopping.event.set() - task.checkpoint( + await task.checkpoint( ConnectorState(backfillRequests={binding.stateKey: True}) ) return @@ -1137,7 +1142,7 @@ async def _binding_incremental_task( ) state.cursor = item - task.checkpoint(connector_state) + await task.checkpoint(connector_state) checkpoints += 1 pending = False diff --git a/estuary-cdk/estuary_cdk/capture/task.py b/estuary-cdk/estuary_cdk/capture/task.py index 7214c7e643..c67f85bcec 100644 --- a/estuary-cdk/estuary_cdk/capture/task.py +++ b/estuary-cdk/estuary_cdk/capture/task.py @@ -7,12 +7,12 @@ import base64 from logging import Logger import asyncio -import shutil import tempfile import traceback import xxhash from . 
import request, response +from ._emit import emit_from_buffer from ..flow import ( ConnectorSpec, ConnectorState, @@ -159,10 +159,9 @@ def sourced_schema(self, binding_index: int, schema: dict[str, Any]): self._buffer.write(b) self._buffer.write(b"\n") - def checkpoint(self, state: ConnectorState, merge_patch: bool = True): - """Emit previously-queued, captured documents follows by a checkpoint""" - - self._emit( + async def checkpoint(self, state: ConnectorState, merge_patch: bool = True): + """Emit previously-queued, captured documents followed by a checkpoint.""" + await self._emit( Response[Any, Any, ConnectorState]( checkpoint=response.Checkpoint( state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch) @@ -212,12 +211,11 @@ async def run_task(parent: Task): task.set_name(child_name) return task - def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]): + async def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]): self._buffer.write( response.model_dump_json(by_alias=True, exclude_unset=True).encode() ) self._buffer.write(b"\n") self._buffer.seek(0) - shutil.copyfileobj(self._buffer, self._output) - self._output.flush() + await emit_from_buffer(self._buffer, self._output) self.reset() diff --git a/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py b/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py index 9074bb0ecb..4e851a4b79 100644 --- a/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py +++ b/estuary-cdk/estuary_cdk/shim_airbyte_cdk.py @@ -375,7 +375,7 @@ async def _run( binding_index=binding_idx, schema=schema, ) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) airbyte_catalog = ConfiguredAirbyteCatalog(streams=airbyte_streams) @@ -462,7 +462,7 @@ async def _run( entry[1].state = state_msg.dict() - task.checkpoint(connector_state, merge_patch=False) + await task.checkpoint(connector_state, merge_patch=False) elif trace := message.trace: if error := trace.error: @@ 
-517,5 +517,5 @@ async def _run( # Emit a final checkpoint before exiting. task.log.info("Emitting final checkpoint for sweep.") - task.checkpoint(connector_state, merge_patch=False) + await task.checkpoint(connector_state, merge_patch=False) return None diff --git a/source-apple-app-store/source_apple_app_store/resources.py b/source-apple-app-store/source_apple_app_store/resources.py index f98400727a..9253ac4514 100644 --- a/source-apple-app-store/source_apple_app_store/resources.py +++ b/source-apple-app-store/source_apple_app_store/resources.py @@ -76,7 +76,7 @@ def _create_initial_state(app_ids: list[str]) -> ResourceState: return initial_state -def _reconcile_connector_state( +async def _reconcile_connector_state( app_ids: list[str], binding: CaptureBinding[ResourceConfig], state: ResourceState, @@ -109,7 +109,7 @@ def _reconcile_connector_state( task.log.info( f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}." ) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) @@ -130,7 +130,7 @@ async def analytics_resources( initial_state = _create_initial_state(app_ids) - def open( + async def open( model: type[AppleAnalyticsRow], binding: CaptureBinding[ResourceConfig], binding_index: int, @@ -155,7 +155,7 @@ def open( model, ) - _reconcile_connector_state(app_ids, binding, state, initial_state, task) + await _reconcile_connector_state(app_ids, binding, state, initial_state, task) open_binding( binding, @@ -194,7 +194,7 @@ async def api_resources( initial_state = _create_initial_state(app_ids) - def open( + async def open( app_ids: list[str], fetch_changes_fn: ApiFetchChangesFn, fetch_page_fn: ApiFetchPageFn, @@ -219,7 +219,7 @@ def open( app_id, ) - _reconcile_connector_state(app_ids, binding, state, initial_state, task) + await _reconcile_connector_state(app_ids, binding, state, initial_state, task) open_binding( binding, diff --git 
a/source-facebook-marketing-native/source_facebook_marketing_native/resources.py b/source-facebook-marketing-native/source_facebook_marketing_native/resources.py index 3d0e30b31f..2c045ae800 100644 --- a/source-facebook-marketing-native/source_facebook_marketing_native/resources.py +++ b/source-facebook-marketing-native/source_facebook_marketing_native/resources.py @@ -143,7 +143,7 @@ def _create_initial_state(account_ids: str | list[str]) -> ResourceState: ) -def _reconcile_connector_state( +async def _reconcile_connector_state( account_ids: list[str], binding: CaptureBinding[ResourceConfig], state: ResourceState, @@ -176,7 +176,7 @@ def _reconcile_connector_state( task.log.info( f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}." ) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) @@ -280,7 +280,7 @@ def full_refresh_resource( snapshot_fn: FullRefreshFetchFn, use_sourced_schemas: bool = False, ) -> Resource: - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -291,7 +291,7 @@ def open( ): if use_sourced_schemas and hasattr(model, 'sourced_schema'): task.sourced_schema(binding_index, model.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) open_binding( binding, @@ -366,7 +366,7 @@ def incremental_resource( create_fetch_page_fn: Callable[[str], Callable], create_fetch_changes_fn: Callable[[str], Callable], ) -> Resource: - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -375,7 +375,7 @@ def open( ): assert len(accounts) > 0, "At least one account ID is required" - _reconcile_connector_state(accounts, binding, state, initial_state, task) + await _reconcile_connector_state(accounts, binding, state, initial_state, task) fetch_page_fns = {} fetch_changes_fns = {} diff --git 
a/source-jira-native/source_jira_native/resources.py b/source-jira-native/source_jira_native/resources.py index ddb50c6678..6116b684e3 100644 --- a/source-jira-native/source_jira_native/resources.py +++ b/source-jira-native/source_jira_native/resources.py @@ -338,7 +338,7 @@ def issues( log: Logger, http: HTTPMixin, config: EndpointConfig, timezone: ZoneInfo ) -> common.Resource: - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -355,7 +355,7 @@ def open( } state.inc = migrated_inc_state - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state} ), @@ -421,7 +421,7 @@ def issue_child_resources( log: Logger, http: HTTPMixin, config: EndpointConfig, timezone: ZoneInfo ) -> list[common.Resource]: - def open( + async def open( stream: type[IssueChildStream], binding: CaptureBinding[ResourceConfig], binding_index: int, @@ -439,7 +439,7 @@ def open( } state.inc = migrated_inc_state - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state} ), diff --git a/source-sage-intacct/source_sage_intacct/resources.py b/source-sage-intacct/source_sage_intacct/resources.py index 2d3088671f..eea5eb026d 100644 --- a/source-sage-intacct/source_sage_intacct/resources.py +++ b/source-sage-intacct/source_sage_intacct/resources.py @@ -55,7 +55,7 @@ async def incremental_resource( ) -> common.Resource: model = await sage.get_model(obj) - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -63,7 +63,7 @@ def open( all_bindings, ): task.sourced_schema(binding_index, model.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, @@ -125,7 +125,7 @@ def open( async def snapshot_resource(sage: Sage, obj: str) -> common.Resource: model = await sage.get_model(obj) - def open( + async def open( obj: str, 
binding: CaptureBinding[ResourceConfig], binding_index: int, @@ -134,7 +134,7 @@ def open( all_bindings, ): task.sourced_schema(binding_index, model.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, diff --git a/source-salesforce-native/source_salesforce_native/resources.py b/source-salesforce-native/source_salesforce_native/resources.py index 0440fd331c..f39fc5ee87 100644 --- a/source-salesforce-native/source_salesforce_native/resources.py +++ b/source-salesforce-native/source_salesforce_native/resources.py @@ -74,7 +74,7 @@ def full_refresh_resource( enable: bool, ) -> common.Resource: - def open( + async def open( binding: CaptureBinding[SalesforceResourceConfigWithSchedule], binding_index: int, state: ResourceState, @@ -89,7 +89,7 @@ def open( model_cls = create_salesforce_model(name, fields) task.sourced_schema(binding_index, model_cls.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, @@ -137,7 +137,7 @@ def incremental_resource( enable: bool, ) -> common.Resource: - def open( + async def open( binding: CaptureBinding[SalesforceResourceConfigWithSchedule], binding_index: int, state: ResourceState, @@ -152,7 +152,7 @@ def open( model_cls = create_salesforce_model(name, fields) task.sourced_schema(binding_index, model_cls.sourced_schema()) - task.checkpoint(state=ConnectorState()) + await task.checkpoint(state=ConnectorState()) common.open_binding( binding, diff --git a/source-stripe-native/source_stripe_native/priority_capture.py b/source-stripe-native/source_stripe_native/priority_capture.py index cc8fc647d7..56a4287661 100644 --- a/source-stripe-native/source_stripe_native/priority_capture.py +++ b/source-stripe-native/source_stripe_native/priority_capture.py @@ -564,7 +564,7 @@ async def _binding_incremental_task_with_work_item( "incremental task triggered backfill", {"subtask_id": 
work_item.account_id} ) task.stopping.event.set() - task.checkpoint( + await task.checkpoint( ConnectorState(backfillRequests={binding.stateKey: True}) ) return @@ -593,7 +593,7 @@ async def _binding_incremental_task_with_work_item( ) state.cursor = item - task.checkpoint(connector_state) + await task.checkpoint(connector_state) checkpoints += 1 pending = False @@ -684,13 +684,13 @@ async def _binding_backfill_task_with_work_item( ) else: state.next_page = item - task.checkpoint(connector_state) + await task.checkpoint(connector_state) done = False if done: break - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={ binding.stateKey: ResourceState(backfill={work_item.account_id: None}) diff --git a/source-stripe-native/source_stripe_native/resources.py b/source-stripe-native/source_stripe_native/resources.py index 208c658956..30674570d6 100644 --- a/source-stripe-native/source_stripe_native/resources.py +++ b/source-stripe-native/source_stripe_native/resources.py @@ -105,7 +105,7 @@ async def _fetch_platform_account_id( return platform_account.id -def _reconcile_connector_state( +async def _reconcile_connector_state( account_ids: list[str], binding: CaptureBinding[ResourceConfig], state: ResourceState, @@ -146,7 +146,7 @@ def _reconcile_connector_state( task.log.info( f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}." 
) - task.checkpoint( + await task.checkpoint( ConnectorState( bindingStateV1={binding.stateKey: state}, ) @@ -325,7 +325,7 @@ def base_object( It requires a single, parent stream with a valid Event API Type """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -357,7 +357,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -416,7 +416,7 @@ def child_object( a valid Event API Type """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -450,7 +450,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -515,7 +515,7 @@ def split_child_object( in the API response. Meaning, the stream behaves like a non-chid stream incrementally. """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -548,7 +548,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -611,7 +611,7 @@ def usage_records( and requires special processing. """ - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -645,7 +645,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task ) @@ -708,7 +708,7 @@ def no_events_object( It works very similar to the base object, but without the use of the Events APi. 
""" - def open( + async def open( binding: CaptureBinding[ResourceConfig], binding_index: int, state: ResourceState, @@ -740,7 +740,7 @@ def open( fetch_page=fetch_page_fns, ) else: - _reconcile_connector_state( + await _reconcile_connector_state( all_account_ids, binding, state, initial_state, task )