Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions estuary-cdk/estuary_cdk/capture/_emit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio
import shutil
from typing import BinaryIO

# Global lock serializing all emissions to stdout, preventing interleaved output.
_emit_lock = asyncio.Lock()


async def emit_bytes(data: bytes, output: BinaryIO) -> None:
    """Write *data* to *output* as a single, uninterrupted emission.

    The module-wide emit lock keeps concurrent emissions from interleaving,
    and the blocking write is handed to a worker thread so the event loop
    is never stalled.
    """
    await _emit_lock.acquire()
    try:
        await asyncio.to_thread(_write_bytes, data, output)
    finally:
        _emit_lock.release()


async def emit_from_buffer(buffer: BinaryIO, output: BinaryIO) -> None:
    """Copy *buffer*'s contents to *output* as a single, uninterrupted emission.

    The module-wide emit lock keeps the copied bytes contiguous in the output
    stream, and the blocking copy is handed to a worker thread so the event
    loop is never stalled.
    """
    await _emit_lock.acquire()
    try:
        await asyncio.to_thread(_copy_buffer, buffer, output)
    finally:
        _emit_lock.release()


def _write_bytes(data: bytes, output: BinaryIO) -> None:
output.write(data)
output.flush()


def _copy_buffer(buffer: BinaryIO, output: BinaryIO) -> None:
shutil.copyfileobj(buffer, output)
output.flush()
24 changes: 11 additions & 13 deletions estuary-cdk/estuary_cdk/capture/base_capture_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..logger import FlowLogger
from ..utils import format_error_message, sort_dict
from . import Request, Response, Task, request, response
from ._emit import emit_bytes
from .common import _ConnectorState


Expand Down Expand Up @@ -81,20 +82,20 @@ async def handle(
if spec := request.spec:
response = await self.spec(log, spec)
response.protocol = 3032023
self._emit(Response(spec=response))
await self._emit(Response(spec=response))

elif discover := request.discover:
self._emit(Response(discovered=await self.discover(log, discover)))
await self._emit(Response(discovered=await self.discover(log, discover)))

elif validate := request.validate_:
self._emit(Response(validated=await self.validate(log, validate)))
await self._emit(Response(validated=await self.validate(log, validate)))

elif apply := request.apply:
self._emit(Response(applied=await self.apply(log, apply)))
await self._emit(Response(applied=await self.apply(log, apply)))

elif open := request.open:
opened, capture = await self.open(log, open)
self._emit(Response(opened=opened))
await self._emit(Response(opened=opened))

stopping = Task.Stopping(asyncio.Event())

Expand Down Expand Up @@ -140,23 +141,20 @@ async def periodic_stop() -> None:
else:
raise RuntimeError("malformed request", request)

def _emit(
async def _emit(
self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState]
):
self.output.write(
response.model_dump_json(by_alias=True, exclude_unset=True).encode()
)
self.output.write(b"\n")
self.output.flush()
data = response.model_dump_json(by_alias=True, exclude_unset=True).encode()
await emit_bytes(data + b"\n", self.output)

def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True):
async def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True):
r = Response[Any, Any, GeneralConnectorState](
checkpoint=response.Checkpoint(
state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch)
)
)

self._emit(r)
await self._emit(r)

async def _encrypt_config(
self,
Expand Down
27 changes: 16 additions & 11 deletions estuary-cdk/estuary_cdk/capture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ class FixedSchema:
name: str
key: list[str]
model: type[_BaseDocument] | FixedSchema
# The open callback can be async or sync.
# Async is required when the callback needs to call task.checkpoint().
open: Callable[
[
CaptureBinding[_BaseResourceConfig],
Expand All @@ -464,7 +466,7 @@ class FixedSchema:
]
],
],
None,
None | Awaitable[None],
]
initial_state: _BaseResourceState
initial_config: _BaseResourceConfig
Expand Down Expand Up @@ -593,7 +595,7 @@ async def _run(task: Task):
{"stateKey": stateKey},
)
backfill_requests.append(stateKey)
task.checkpoint(
await task.checkpoint(
ConnectorState(
bindingStateV1={stateKey: None},
backfillRequests={stateKey: None},
Expand All @@ -614,7 +616,7 @@ async def _run(task: Task):
if state:
if state.last_initialized is None:
state.last_initialized = NOW
task.checkpoint(
await task.checkpoint(
ConnectorState(bindingStateV1={binding.stateKey: state})
)

Expand Down Expand Up @@ -665,19 +667,22 @@ async def _run(task: Task):
state.last_initialized = NOW

# Checkpoint the binding's initialized state prior to any processing.
task.checkpoint(
await task.checkpoint(
ConnectorState(
bindingStateV1={binding.stateKey: state},
)
)

resource.open(
result = resource.open(
binding,
index,
state,
task,
resolved_bindings,
)
# Support both sync and async open callbacks
if inspect.iscoroutine(result):
await result


if soonest_future_scheduled_initialization:
Expand Down Expand Up @@ -910,7 +915,7 @@ async def _binding_snapshot_task(
# Suppress all captured documents, as they're unchanged.
task.reset()

task.checkpoint(connector_state)
await task.checkpoint(connector_state)


async def _binding_backfill_task(
Expand Down Expand Up @@ -1004,22 +1009,22 @@ def _initialize_connector_state(state: ResourceState.Backfill) -> ConnectorState
state.next_page = item
state_to_checkpoint = connector_state

task.checkpoint(state_to_checkpoint)
await task.checkpoint(state_to_checkpoint)
done = False

if done:
break

if subtask_id is not None:
task.checkpoint(
await task.checkpoint(
ConnectorState(
bindingStateV1={
binding.stateKey: ResourceState(backfill={subtask_id: None})
}
)
)
else:
task.checkpoint(
await task.checkpoint(
ConnectorState(
bindingStateV1={binding.stateKey: ResourceState(backfill=None)}
)
Expand Down Expand Up @@ -1108,7 +1113,7 @@ async def _binding_incremental_task(
"incremental task triggered backfill", {"subtask_id": subtask_id}
)
task.stopping.event.set()
task.checkpoint(
await task.checkpoint(
ConnectorState(backfillRequests={binding.stateKey: True})
)
return
Expand Down Expand Up @@ -1137,7 +1142,7 @@ async def _binding_incremental_task(
)

state.cursor = item
task.checkpoint(connector_state)
await task.checkpoint(connector_state)
checkpoints += 1
pending = False

Expand Down
14 changes: 6 additions & 8 deletions estuary-cdk/estuary_cdk/capture/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import base64
from logging import Logger
import asyncio
import shutil
import tempfile
import traceback
import xxhash

from . import request, response
from ._emit import emit_from_buffer
from ..flow import (
ConnectorSpec,
ConnectorState,
Expand Down Expand Up @@ -159,10 +159,9 @@ def sourced_schema(self, binding_index: int, schema: dict[str, Any]):
self._buffer.write(b)
self._buffer.write(b"\n")

def checkpoint(self, state: ConnectorState, merge_patch: bool = True):
"""Emit previously-queued, captured documents follows by a checkpoint"""

self._emit(
async def checkpoint(self, state: ConnectorState, merge_patch: bool = True):
"""Emit previously-queued, captured documents followed by a checkpoint."""
await self._emit(
Response[Any, Any, ConnectorState](
checkpoint=response.Checkpoint(
state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch)
Expand Down Expand Up @@ -212,12 +211,11 @@ async def run_task(parent: Task):
task.set_name(child_name)
return task

def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]):
async def _emit(self, response: Response[EndpointConfig, ResourceConfig, ConnectorState]):
self._buffer.write(
response.model_dump_json(by_alias=True, exclude_unset=True).encode()
)
self._buffer.write(b"\n")
self._buffer.seek(0)
shutil.copyfileobj(self._buffer, self._output)
self._output.flush()
await emit_from_buffer(self._buffer, self._output)
self.reset()
6 changes: 3 additions & 3 deletions estuary-cdk/estuary_cdk/shim_airbyte_cdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ async def _run(
binding_index=binding_idx,
schema=schema,
)
task.checkpoint(state=ConnectorState())
await task.checkpoint(state=ConnectorState())

airbyte_catalog = ConfiguredAirbyteCatalog(streams=airbyte_streams)

Expand Down Expand Up @@ -462,7 +462,7 @@ async def _run(

entry[1].state = state_msg.dict()

task.checkpoint(connector_state, merge_patch=False)
await task.checkpoint(connector_state, merge_patch=False)

elif trace := message.trace:
if error := trace.error:
Expand Down Expand Up @@ -517,5 +517,5 @@ async def _run(

# Emit a final checkpoint before exiting.
task.log.info("Emitting final checkpoint for sweep.")
task.checkpoint(connector_state, merge_patch=False)
await task.checkpoint(connector_state, merge_patch=False)
return None
12 changes: 6 additions & 6 deletions source-apple-app-store/source_apple_app_store/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _create_initial_state(app_ids: list[str]) -> ResourceState:
return initial_state


def _reconcile_connector_state(
async def _reconcile_connector_state(
app_ids: list[str],
binding: CaptureBinding[ResourceConfig],
state: ResourceState,
Expand Down Expand Up @@ -109,7 +109,7 @@ def _reconcile_connector_state(
task.log.info(
f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}."
)
task.checkpoint(
await task.checkpoint(
ConnectorState(
bindingStateV1={binding.stateKey: state},
)
Expand All @@ -130,7 +130,7 @@ async def analytics_resources(

initial_state = _create_initial_state(app_ids)

def open(
async def open(
model: type[AppleAnalyticsRow],
binding: CaptureBinding[ResourceConfig],
binding_index: int,
Expand All @@ -155,7 +155,7 @@ def open(
model,
)

_reconcile_connector_state(app_ids, binding, state, initial_state, task)
await _reconcile_connector_state(app_ids, binding, state, initial_state, task)

open_binding(
binding,
Expand Down Expand Up @@ -194,7 +194,7 @@ async def api_resources(

initial_state = _create_initial_state(app_ids)

def open(
async def open(
app_ids: list[str],
fetch_changes_fn: ApiFetchChangesFn,
fetch_page_fn: ApiFetchPageFn,
Expand All @@ -219,7 +219,7 @@ def open(
app_id,
)

_reconcile_connector_state(app_ids, binding, state, initial_state, task)
await _reconcile_connector_state(app_ids, binding, state, initial_state, task)

open_binding(
binding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _create_initial_state(account_ids: str | list[str]) -> ResourceState:
)


def _reconcile_connector_state(
async def _reconcile_connector_state(
account_ids: list[str],
binding: CaptureBinding[ResourceConfig],
state: ResourceState,
Expand Down Expand Up @@ -176,7 +176,7 @@ def _reconcile_connector_state(
task.log.info(
f"Checkpointing state to ensure any new state is persisted for {binding.stateKey}."
)
task.checkpoint(
await task.checkpoint(
ConnectorState(
bindingStateV1={binding.stateKey: state},
)
Expand Down Expand Up @@ -280,7 +280,7 @@ def full_refresh_resource(
snapshot_fn: FullRefreshFetchFn,
use_sourced_schemas: bool = False,
) -> Resource:
def open(
async def open(
binding: CaptureBinding[ResourceConfig],
binding_index: int,
state: ResourceState,
Expand All @@ -291,7 +291,7 @@ def open(
):
if use_sourced_schemas and hasattr(model, 'sourced_schema'):
task.sourced_schema(binding_index, model.sourced_schema())
task.checkpoint(state=ConnectorState())
await task.checkpoint(state=ConnectorState())

open_binding(
binding,
Expand Down Expand Up @@ -366,7 +366,7 @@ def incremental_resource(
create_fetch_page_fn: Callable[[str], Callable],
create_fetch_changes_fn: Callable[[str], Callable],
) -> Resource:
def open(
async def open(
binding: CaptureBinding[ResourceConfig],
binding_index: int,
state: ResourceState,
Expand All @@ -375,7 +375,7 @@ def open(
):
assert len(accounts) > 0, "At least one account ID is required"

_reconcile_connector_state(accounts, binding, state, initial_state, task)
await _reconcile_connector_state(accounts, binding, state, initial_state, task)

fetch_page_fns = {}
fetch_changes_fns = {}
Expand Down
Loading
Loading