Commit ee537af

feat: use create_concurrent_cursor_from_perpartition_cursor (#286)
Signed-off-by: Artem Inzhyyants <[email protected]>
1 parent dea2cc9

8 files changed (+55 -38 lines)

airbyte_cdk/sources/declarative/async_job/job_orchestrator.py (+4 -4)

@@ -482,16 +482,16 @@ def _is_breaking_exception(self, exception: Exception) -> bool:
             and exception.failure_type == FailureType.config_error
         )

-    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
+    def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]:
         """
-        Fetches records from the given partition's jobs.
+        Fetches records from the given jobs.

         Args:
-            partition (AsyncPartition): The partition containing the jobs.
+            async_jobs (Iterable[AsyncJob]): The list of AsyncJobs.

         Yields:
             Iterable[Mapping[str, Any]]: The fetched records from the jobs.
         """
-        for job in partition.jobs:
+        for job in async_jobs:
             yield from self._job_repository.fetch_records(job)
             self._job_repository.delete(job)
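
With the new signature, callers hand completed jobs directly to the orchestrator instead of wrapping them in an AsyncPartition. A minimal usage sketch, assuming an AsyncJobOrchestrator instance and two completed AsyncJob objects (the names are stand-ins, mirroring the updated unit test further down):

    # `orchestrator`, `first_job`, `second_job` are stand-ins for an
    # AsyncJobOrchestrator and two completed AsyncJob instances.
    records = list(orchestrator.fetch_records([first_job, second_job]))
    # Each job's records are yielded in turn, and the job is deleted afterwards.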

airbyte_cdk/sources/declarative/concurrent_declarative_source.py (+2 -1)

@@ -19,6 +19,7 @@
 from airbyte_cdk.sources.declarative.extractors.record_filter import (
     ClientSideIncrementalRecordFilterDecorator,
 )
+from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
 from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
 from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import (
     PerPartitionWithGlobalCursor,
@@ -231,7 +232,7 @@ def _group_streams(
             ):
                 cursor = declarative_stream.retriever.stream_slicer.stream_slicer

-                if not isinstance(cursor, ConcurrentCursor):
+                if not isinstance(cursor, ConcurrentCursor | ConcurrentPerPartitionCursor):
                     # This should never happen since we instantiate ConcurrentCursor in
                     # model_to_component_factory.py
                     raise ValueError(
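
The widened check works because isinstance() accepts PEP 604 union types (X | Y) on Python 3.10 and later, equivalently to the tuple form. A standalone illustration:

    class A: ...
    class B: ...

    x = B()
    assert isinstance(x, A | B)   # union syntax, Python 3.10+
    assert isinstance(x, (A, B))  # equivalent tuple form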

airbyte_cdk/sources/declarative/declarative_stream.py (+3 -1)

@@ -138,7 +138,9 @@ def read_records(
         """
         :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
         """
-        if stream_slice is None or stream_slice == {}:
+        if stream_slice is None or (
+            not isinstance(stream_slice, StreamSlice) and stream_slice == {}
+        ):
             # As the parameter is Optional, many would just call `read_records(sync_mode)` during testing without specifying the field
             # As part of the declarative model without custom components, this should never happen as the CDK would wire up a
             # SinglePartitionRouter that would create this StreamSlice properly
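
The isinstance() guard matters because StreamSlice behaves like a Mapping: a slice with an empty partition and cursor_slice compares equal to {}, yet it can still carry payload in extra_fields (the AsyncJobs introduced elsewhere in this commit). A sketch of the case the new condition preserves, assuming StreamSlice's Mapping-style equality and a stand-in job object:

    from airbyte_cdk.sources.types import StreamSlice

    s = StreamSlice(partition={}, cursor_slice={}, extra_fields={"jobs": [some_job]})
    assert s == {}                 # Mapping equality ignores extra_fields...
    assert s.extra_fields["jobs"]  # ...so an "empty" slice may still carry jobs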

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py (+26 -1)

@@ -1656,7 +1656,7 @@ def _build_stream_slicer_from_partition_router(
     ) -> Optional[PartitionRouter]:
         if (
             hasattr(model, "partition_router")
-            and isinstance(model, SimpleRetrieverModel)
+            and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel)
             and model.partition_router
         ):
             stream_slicer_model = model.partition_router
@@ -1690,6 +1690,31 @@ def _merge_stream_slicers(
         stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config)

         if model.incremental_sync and stream_slicer:
+            if model.retriever.type == "AsyncRetriever":
+                if model.incremental_sync.type != "DatetimeBasedCursor":
+                    # We are currently in a transition to the Concurrent CDK, and AsyncRetriever only works with support for unordered slices (for example, when we trigger reports for January and February, the February report can complete first). Once we support custom concurrent cursors or a new implementation is available in the CDK, we can enable more cursors here.
+                    raise ValueError(
+                        "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet"
+                    )
+                if stream_slicer:
+                    return self.create_concurrent_cursor_from_perpartition_cursor(  # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing
+                        state_manager=self._connector_state_manager,
+                        model_type=DatetimeBasedCursorModel,
+                        component_definition=model.incremental_sync.__dict__,
+                        stream_name=model.name or "",
+                        stream_namespace=None,
+                        config=config or {},
+                        stream_state={},
+                        partition_router=stream_slicer,
+                    )
+                return self.create_concurrent_cursor_from_datetime_based_cursor(  # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing
+                    model_type=DatetimeBasedCursorModel,
+                    component_definition=model.incremental_sync.__dict__,
+                    stream_name=model.name or "",
+                    stream_namespace=None,
+                    config=config or {},
+                )
+
         incremental_sync_model = model.incremental_sync
         if (
             hasattr(incremental_sync_model, "global_substream_cursor")
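
In outline, the new branch gives every AsyncRetriever stream with incremental sync a concurrent cursor. A hypothetical helper restating the decision tree above, for illustration only (neither this function nor its string return values exist in the CDK):

    def choose_async_cursor_factory(incremental_sync_type: str, has_partition_router: bool) -> str:
        # Restates the branch in _merge_stream_slicers, nothing more.
        if incremental_sync_type != "DatetimeBasedCursor":
            raise ValueError(
                "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet"
            )
        if has_partition_router:
            return "create_concurrent_cursor_from_perpartition_cursor"
        return "create_concurrent_cursor_from_datetime_based_cursor"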

airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py (+5 -5)

@@ -4,9 +4,9 @@
 from typing import Any, Callable, Iterable, Mapping, Optional

 from airbyte_cdk.models import FailureType
+from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
     AsyncJobOrchestrator,
-    AsyncPartition,
 )
 from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
     SinglePartitionRouter,
@@ -42,12 +42,12 @@ def stream_slices(self) -> Iterable[StreamSlice]:

         for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
             yield StreamSlice(
-                partition=dict(completed_partition.stream_slice.partition)
-                | {"partition": completed_partition},
+                partition=dict(completed_partition.stream_slice.partition),
                 cursor_slice=completed_partition.stream_slice.cursor_slice,
+                extra_fields={"jobs": list(completed_partition.jobs)},
             )

-    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
+    def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]:
         """
         This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should
         be responsible for. However, this was added in because the JobOrchestrator is required to
@@ -62,4 +62,4 @@ def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]
             failure_type=FailureType.system_error,
         )

-        return self._job_orchestrator.fetch_records(partition=partition)
+        return self._job_orchestrator.fetch_records(async_jobs=async_jobs)
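
Completed jobs now ride on the slice's extra_fields instead of being merged into the partition dict, which previously embedded a live AsyncPartition object in the partition's identity. A hedged sketch of a slice yielded under the new scheme (completed_job is a stand-in AsyncJob):

    from airbyte_cdk.sources.types import StreamSlice

    slice_ = StreamSlice(
        partition={"parent_id": "0"},            # partition keys only; no job object
        cursor_slice={},
        extra_fields={"jobs": [completed_job]},  # jobs travel out-of-band
    )
    jobs = slice_.extra_fields["jobs"]           # consumers read them back here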

airbyte_cdk/sources/declarative/retrievers/async_retriever.py (+6 -12)

@@ -6,7 +6,7 @@

 from typing_extensions import deprecated

-from airbyte_cdk.models import FailureType
+from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition
 from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
 from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
@@ -16,7 +16,6 @@
 from airbyte_cdk.sources.source import ExperimentalClassWarning
 from airbyte_cdk.sources.streams.core import StreamData
 from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
-from airbyte_cdk.utils.traced_exception import AirbyteTracedException


 @deprecated(
@@ -57,9 +56,9 @@ def _get_stream_state(self) -> StreamState:

         return self.state

-    def _validate_and_get_stream_slice_partition(
+    def _validate_and_get_stream_slice_jobs(
         self, stream_slice: Optional[StreamSlice] = None
-    ) -> AsyncPartition:
+    ) -> Iterable[AsyncJob]:
         """
         Validates the stream_slice argument and returns the partition from it.

@@ -73,12 +72,7 @@ def _validate_and_get_stream_slice_partition(
             AirbyteTracedException: If the stream_slice is not an instance of StreamSlice or if the partition is not present in the stream_slice.

         """
-        if not isinstance(stream_slice, StreamSlice) or "partition" not in stream_slice.partition:
-            raise AirbyteTracedException(
-                message="Invalid arguments to AsyncRetriever.read_records: stream_slice is not optional. Please contact Airbyte Support",
-                failure_type=FailureType.system_error,
-            )
-        return stream_slice["partition"]  # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices
+        return stream_slice.extra_fields.get("jobs", []) if stream_slice else []

     def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
         return self.stream_slicer.stream_slices()
@@ -89,8 +83,8 @@ def read_records(
         stream_slice: Optional[StreamSlice] = None,
     ) -> Iterable[StreamData]:
         stream_state: StreamState = self._get_stream_state()
-        partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
-        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(partition)
+        jobs: Iterable[AsyncJob] = self._validate_and_get_stream_slice_jobs(stream_slice)
+        records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(jobs)

         yield from self.record_selector.filter_and_transform(
             all_data=records,
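
Note that the validation no longer raises: a missing slice, or a slice without jobs, now yields an empty iterable instead of an AirbyteTracedException. A standalone restatement of the new helper's logic (not the real method):

    from typing import Any, Iterable, Optional

    from airbyte_cdk.sources.types import StreamSlice

    def jobs_from_slice(stream_slice: Optional[StreamSlice]) -> Iterable[Any]:
        # Absent slice or absent "jobs" key -> no records, no exception.
        return stream_slice.extra_fields.get("jobs", []) if stream_slice else []

    assert list(jobs_from_slice(None)) == []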

unit_tests/sources/declarative/async_job/test_job_orchestrator.py (+1 -2)

@@ -174,9 +174,8 @@ def test_when_fetch_records_then_yield_records_from_each_job(self) -> None:
         orchestrator = self._orchestrator([_A_STREAM_SLICE])
         first_job = _create_job()
         second_job = _create_job()
-        partition = AsyncPartition([first_job, second_job], _A_STREAM_SLICE)

-        records = list(orchestrator.fetch_records(partition))
+        records = list(orchestrator.fetch_records([first_job, second_job]))

         assert len(records) == 2
         assert self._job_repository.fetch_records.mock_calls == [call(first_job), call(second_job)]

unit_tests/sources/declarative/partition_routers/test_async_job_partition_router.py (+8 -12)

@@ -35,12 +35,12 @@ def test_stream_slices_with_single_partition_router():

     slices = list(partition_router.stream_slices())
     assert len(slices) == 1
-    partition = slices[0].partition.get("partition")
-    assert isinstance(partition, AsyncPartition)
-    assert partition.stream_slice == StreamSlice(partition={}, cursor_slice={})
-    assert partition.status == AsyncJobStatus.COMPLETED
+    partition = slices[0]
+    assert isinstance(partition, StreamSlice)
+    assert partition == StreamSlice(partition={}, cursor_slice={})
+    assert partition.extra_fields["jobs"][0].status() == AsyncJobStatus.COMPLETED

-    attempts_per_job = list(partition.jobs)
+    attempts_per_job = list(partition.extra_fields["jobs"])
     assert len(attempts_per_job) == 1
     assert attempts_per_job[0].api_job_id() == "a_job_id"
     assert attempts_per_job[0].job_parameters() == StreamSlice(partition={}, cursor_slice={})
@@ -68,14 +68,10 @@ def test_stream_slices_with_parent_slicer():
     slices = list(partition_router.stream_slices())
     assert len(slices) == 3
     for i, partition in enumerate(slices):
-        partition = partition.partition.get("partition")
-        assert isinstance(partition, AsyncPartition)
-        assert partition.stream_slice == StreamSlice(
-            partition={"parent_id": str(i)}, cursor_slice={}
-        )
-        assert partition.status == AsyncJobStatus.COMPLETED
+        assert isinstance(partition, StreamSlice)
+        assert partition == StreamSlice(partition={"parent_id": str(i)}, cursor_slice={})

-        attempts_per_job = list(partition.jobs)
+        attempts_per_job = list(partition.extra_fields["jobs"])
         assert len(attempts_per_job) == 1
         assert attempts_per_job[0].api_job_id() == "a_job_id"
         assert attempts_per_job[0].job_parameters() == StreamSlice(
