feat(file-mode-api): move file uploader to record selector level. #449

Merged
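In short: this PR moves the file_uploader component definition from SimpleRetriever up to DeclarativeStream in the declarative schema, and moves the upload call itself from DeclarativePartition.read() into RecordSelector.filter_and_transform(). A minimal sketch of the new stream-level placement, written as a manifest fragment in Python dict form; the stream name, requester, and extractor values are illustrative assumptions, not taken from this PR:

# Hypothetical manifest fragment: file_uploader now sits on the stream,
# not on the retriever. All field values below are invented for illustration.
stream_definition = {
    "type": "DeclarativeStream",
    "name": "attachments",  # illustrative stream name
    "retriever": {"type": "SimpleRetriever"},  # requester/record_selector elided
    "file_uploader": {
        "type": "FileUploader",
        "requester": {"type": "HttpRequester", "url_base": "https://api.example.com"},
        "download_target_extractor": {"type": "DpathExtractor", "field_path": ["download_url"]},
        # file_extractor omitted: the whole response body is treated as the file content
    },
}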
26 changes: 6 additions & 20 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -207,19 +207,9 @@ def _group_streams(
             # these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible,
             # so we need to treat them as synchronous

-            file_uploader = None
-            if isinstance(declarative_stream, DeclarativeStream):
-                file_uploader = (
-                    self._constructor.create_component(
-                        model_type=FileUploader,
-                        component_definition=name_to_stream_mapping[declarative_stream.name][
-                            "file_uploader"
-                        ],
-                        config=config,
-                    )
-                    if "file_uploader" in name_to_stream_mapping[declarative_stream.name]
-                    else None
-                )
+            supports_file_transfer = (
+                "file_uploader" in name_to_stream_mapping[declarative_stream.name]
+            )

             if (
                 isinstance(declarative_stream, DeclarativeStream)
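Note: streams no longer instantiate a FileUploader at grouping time. Whether a stream supports file transfer is derived from the mere presence of a file_uploader key in its definition; the component itself is now built by the model-to-component factory and attached to the stream's RecordSelector (see the factory and record-selector hunks below).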
@@ -288,7 +278,6 @@ def _group_streams(
                         declarative_stream.get_json_schema(),
                         retriever,
                         self.message_repository,
-                        file_uploader,
                     ),
                     stream_slicer=declarative_stream.retriever.stream_slicer,
                 )
Expand Down Expand Up @@ -319,7 +308,6 @@ def _group_streams(
declarative_stream.get_json_schema(),
retriever,
self.message_repository,
file_uploader,
),
stream_slicer=cursor,
)
@@ -339,7 +327,7 @@ def _group_streams(
                         else None,
                         logger=self.logger,
                         cursor=cursor,
-                        supports_file_transfer=bool(file_uploader),
+                        supports_file_transfer=supports_file_transfer,
                     )
                 )
             elif (
@@ -351,7 +339,6 @@ def _group_streams(
                         declarative_stream.get_json_schema(),
                         declarative_stream.retriever,
                         self.message_repository,
-                        file_uploader,
                     ),
                     declarative_stream.retriever.stream_slicer,
                 )
@@ -372,7 +359,7 @@ def _group_streams(
                         cursor_field=None,
                         logger=self.logger,
                         cursor=final_state_cursor,
-                        supports_file_transfer=bool(file_uploader),
+                        supports_file_transfer=supports_file_transfer,
                     )
                 )
             elif (
@@ -412,7 +399,6 @@ def _group_streams(
                         declarative_stream.get_json_schema(),
                         retriever,
                         self.message_repository,
-                        file_uploader,
                     ),
                     perpartition_cursor,
                 )
@@ -427,7 +413,7 @@ def _group_streams(
                         cursor_field=perpartition_cursor.cursor_field.cursor_field_key,
                         logger=self.logger,
                         cursor=perpartition_cursor,
-                        supports_file_transfer=bool(file_uploader),
+                        supports_file_transfer=supports_file_transfer,
                     )
                 )
             else:
7 changes: 6 additions & 1 deletion airbyte_cdk/sources/declarative/extractors/record_selector.py
@@ -18,6 +18,7 @@
 from airbyte_cdk.sources.declarative.transformations import RecordTransformation
 from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
 from airbyte_cdk.sources.utils.transform import TypeTransformer
+from airbyte_cdk.sources.declarative.retrievers.file_uploader import FileUploader


 @dataclass
@@ -42,6 +43,7 @@ class RecordSelector(HttpSelector):
     record_filter: Optional[RecordFilter] = None
     transformations: List[RecordTransformation] = field(default_factory=lambda: [])
     transform_before_filtering: bool = False
+    file_uploader: Optional[FileUploader] = None

     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._parameters = parameters
@@ -117,7 +119,10 @@ def filter_and_transform(
             transformed_filtered_data, schema=records_schema
         )
         for data in normalized_data:
-            yield Record(data=data, stream_name=self.name, associated_slice=stream_slice)
+            record = Record(data=data, stream_name=self.name, associated_slice=stream_slice)
+            if self.file_uploader:
+                self.file_uploader.upload(record)
+            yield record

     def _normalize_by_schema(
         self, records: Iterable[Mapping[str, Any]], schema: Optional[Mapping[str, Any]]
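To make the new behavior concrete, here is a self-contained sketch of upload-at-selection-time. Everything below is an illustrative stand-in (StubRecord, StubFileUploader, and StubRecordSelector are not CDK classes); it only mirrors the if self.file_uploader: upload(record) hook added above:

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional


@dataclass
class StubRecord:
    data: Dict[str, Any]


class StubFileUploader:
    """Stand-in uploader: records each call instead of fetching anything."""

    def __init__(self) -> None:
        self.uploaded: List[Dict[str, Any]] = []

    def upload(self, record: StubRecord) -> None:
        self.uploaded.append(record.data)


class StubRecordSelector:
    """Mirrors the hunk above: upload happens as each record is selected."""

    def __init__(self, file_uploader: Optional[StubFileUploader] = None) -> None:
        self.file_uploader = file_uploader

    def filter_and_transform(self, raw: Iterable[Dict[str, Any]]) -> Iterable[StubRecord]:
        for data in raw:  # filtering/transformation/normalization elided
            record = StubRecord(data=data)
            if self.file_uploader:
                self.file_uploader.upload(record)
            yield record


uploader = StubFileUploader()
selector = StubRecordSelector(file_uploader=uploader)
records = list(selector.filter_and_transform([{"id": 1, "download_url": "https://example.com/f1"}]))
assert uploader.uploaded == [{"id": 1, "download_url": "https://example.com/f1"}]

Downstream consumers therefore never see a record whose file has not yet been fetched, regardless of which execution path yielded it.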
43 changes: 22 additions & 21 deletions airbyte_cdk/sources/declarative/models/declarative_component_schema.py
@@ -1989,6 +1989,22 @@ class Config:
     parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


+class FileUploader(BaseModel):
+    type: Literal["FileUploader"]
+    requester: Union[CustomRequester, HttpRequester] = Field(
+        ...,
+        description="Requester component that describes how to prepare HTTP requests to send to the source API.",
+    )
+    download_target_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field(
+        ...,
+        description="Responsible for fetching the url where the file is located. This is applied to each record, not to the HTTP response.",
+    )
+    file_extractor: Optional[Union[CustomRecordExtractor, DpathExtractor]] = Field(
+        None,
+        description="Responsible for fetching the content of the file. If not defined, the assumption is that the whole response body is the file content.",
+    )
+
+
 class DeclarativeStream(BaseModel):
     class Config:
         extra = Extra.allow
@@ -2047,6 +2063,11 @@ class Config:
         description="Array of state migrations to be applied on the input state",
         title="State Migrations",
     )
+    file_uploader: Optional[FileUploader] = Field(
+        None,
+        description="(experimental) Describes how to fetch a file",
+        title="File Uploader",
+    )
     parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")

@@ -2278,22 +2299,6 @@ class StateDelegatingStream(BaseModel):
     parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


-class FileUploader(BaseModel):
-    type: Literal["FileUploader"]
-    requester: Union[CustomRequester, HttpRequester] = Field(
-        ...,
-        description="Requester component that describes how to prepare HTTP requests to send to the source API.",
-    )
-    download_target_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field(
-        ...,
-        description="Responsible for fetching the url where the file is located. This is applied to each record, not to the HTTP response.",
-    )
-    file_extractor: Optional[Union[CustomRecordExtractor, DpathExtractor]] = Field(
-        None,
-        description="Responsible for fetching the content of the file. If not defined, the assumption is that the whole response body is the file content.",
-    )
-
-
 class SimpleRetriever(BaseModel):
     type: Literal["SimpleRetriever"]
     record_selector: RecordSelector = Field(
@@ -2324,11 +2329,6 @@ class SimpleRetriever(BaseModel):
         description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.",
         title="Partition Router",
     )
-    file_uploader: Optional[FileUploader] = Field(
-        None,
-        description="(experimental) Describes how to fetch a file",
-        title="File Uploader",
-    )
     decoder: Optional[
         Union[
             CustomDecoder,
@@ -2485,6 +2485,7 @@ class DynamicDeclarativeStream(BaseModel):
 DeclarativeSource1.update_forward_refs()
 DeclarativeSource2.update_forward_refs()
 SelectiveAuthenticator.update_forward_refs()
+FileUploader.update_forward_refs()
 DeclarativeStream.update_forward_refs()
 SessionTokenAuthenticator.update_forward_refs()
 DynamicSchemaLoader.update_forward_refs()
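Reading note: the FileUploader model body appears twice in this file's diff, added near line 1989 and removed near line 2278, because the class moved rather than changed: this file is generated, classes are ordered by first reference, and referencing file_uploader from DeclarativeStream instead of SimpleRetriever pulls the definition earlier. The new FileUploader.update_forward_refs() call is presumably part of the same regeneration.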
10 changes: 10 additions & 0 deletions airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -1755,6 +1755,11 @@ def create_declarative_stream(
             transformations.append(
                 self._create_component_from_model(model=transformation_model, config=config)
             )
+        file_uploader = None
+        if model.file_uploader:
+            file_uploader = self._create_component_from_model(
+                model=model.file_uploader, config=config
+            )

         retriever = self._create_component_from_model(
             model=model.retriever,
@@ -1766,6 +1771,7 @@
             stop_condition_on_cursor=stop_condition_on_cursor,
             client_side_incremental_sync=client_side_incremental_sync,
             transformations=transformations,
+            file_uploader=file_uploader,
             incremental_sync=model.incremental_sync,
         )
         cursor_field = model.incremental_sync.cursor_field if model.incremental_sync else None
@@ -2607,6 +2613,7 @@ def create_record_selector(
         transformations: List[RecordTransformation] | None = None,
         decoder: Decoder | None = None,
         client_side_incremental_sync: Dict[str, Any] | None = None,
+        file_uploader: Optional[FileUploader] = None,
         **kwargs: Any,
     ) -> RecordSelector:
         extractor = self._create_component_from_model(
@@ -2644,6 +2651,7 @@
             config=config,
             record_filter=record_filter,
             transformations=transformations or [],
+            file_uploader=file_uploader,
             schema_normalization=schema_normalization,
             parameters=model.parameters or {},
             transform_before_filtering=transform_before_filtering,
@@ -2701,6 +2709,7 @@ def create_simple_retriever(
         stop_condition_on_cursor: bool = False,
         client_side_incremental_sync: Optional[Dict[str, Any]] = None,
         transformations: List[RecordTransformation],
+        file_uploader: Optional[FileUploader] = None,
         incremental_sync: Optional[
             Union[
                 IncrementingCountCursorModel, DatetimeBasedCursorModel, CustomIncrementalSyncModel
@@ -2723,6 +2732,7 @@
             decoder=decoder,
             transformations=transformations,
             client_side_incremental_sync=client_side_incremental_sync,
+            file_uploader=file_uploader,
         )
         url_base = (
             model.requester.url_base
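Taken together, the factory changes thread the uploader top-down: create_declarative_stream builds a FileUploader from the new stream-level model field, passes it to create_simple_retriever, which forwards it to create_record_selector, where it lands on the RecordSelector.file_uploader dataclass field added earlier in this PR.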
1 change: 0 additions & 1 deletion airbyte_cdk/sources/declarative/retrievers/file_uploader.py
@@ -24,7 +24,6 @@ def __init__(
         self._content_extractor = content_extractor

     def upload(self, record: Record) -> None:
-        # TODO validate record shape - is the transformation applied at this point?
         mocked_response = SafeResponse()
         mocked_response.content = json.dumps(record.data).encode("utf-8")
         download_target = list(self._download_target_extractor.extract_records(mocked_response))[0]
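The deleted TODO asked whether transformations are applied by the time upload() runs. After this PR, upload is invoked from RecordSelector.filter_and_transform() on records that have already been filtered, transformed, and schema-normalized, so the question is resolved by construction and the comment can go.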
8 changes: 0 additions & 8 deletions airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -3,7 +3,6 @@
 from typing import Any, Iterable, Mapping, Optional

 from airbyte_cdk.sources.declarative.retrievers import Retriever
-from airbyte_cdk.sources.declarative.retrievers.file_uploader import FileUploader
 from airbyte_cdk.sources.message import MessageRepository
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
@@ -19,7 +18,6 @@ def __init__(
         json_schema: Mapping[str, Any],
         retriever: Retriever,
         message_repository: MessageRepository,
-        file_uploader: Optional[FileUploader] = None,
     ) -> None:
         """
         The DeclarativePartitionFactory takes a retriever_factory and not a retriever directly. The reason is that our components are not
@@ -30,15 +28,13 @@ def __init__(
         self._json_schema = json_schema
         self._retriever = retriever
         self._message_repository = message_repository
-        self._file_uploader = file_uploader

     def create(self, stream_slice: StreamSlice) -> Partition:
         return DeclarativePartition(
             self._stream_name,
             self._json_schema,
             self._retriever,
             self._message_repository,
-            self._file_uploader,
             stream_slice,
         )

@@ -50,14 +46,12 @@ def __init__(
         json_schema: Mapping[str, Any],
         retriever: Retriever,
         message_repository: MessageRepository,
-        file_uploader: Optional[FileUploader],
         stream_slice: StreamSlice,
     ):
         self._stream_name = stream_name
         self._json_schema = json_schema
         self._retriever = retriever
         self._message_repository = message_repository
-        self._file_uploader = file_uploader
         self._stream_slice = stream_slice
         self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)

@@ -73,8 +67,6 @@ def read(self) -> Iterable[Record]:
                         associated_slice=self._stream_slice,
                     )
                 )
-                if self._file_uploader:
-                    self._file_uploader.upload(record)
                 yield record
             else:
                 self._message_repository.emit_message(stream_data)
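Net effect on this file: DeclarativePartition goes back to being a plain reader with no knowledge of file uploads. Since the upload side effect now lives in record selection, both the concurrent and the synchronous paths get identical file-transfer behavior without the partition layer having to carry an optional collaborator.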