diff --git a/airbyte/__init__.py b/airbyte/__init__.py index 4d453732..3a1e1b01 100644 --- a/airbyte/__init__.py +++ b/airbyte/__init__.py @@ -131,6 +131,7 @@ from airbyte.datasets import CachedDataset from airbyte.destinations.base import Destination from airbyte.destinations.util import get_destination +from airbyte.lakes import GCSLakeStorage, LakeStorage, S3LakeStorage from airbyte.records import StreamRecord from airbyte.results import ReadResult, WriteResult from airbyte.secrets import SecretSourceEnum, get_secret @@ -154,6 +155,7 @@ documents, exceptions, # noqa: ICN001 # No 'exc' alias for top-level module experimental, + lakes, logs, mcp, records, @@ -175,6 +177,7 @@ "documents", "exceptions", "experimental", + "lakes", "logs", "mcp", "records", @@ -195,7 +198,10 @@ "CachedDataset", "Destination", "DuckDBCache", + "GCSLakeStorage", + "LakeStorage", "ReadResult", + "S3LakeStorage", "SecretSourceEnum", "Source", "StreamRecord", diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index 12d4a3ad..4e70ebc4 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -12,9 +12,8 @@ from pydantic import Field, PrivateAttr from sqlalchemy import text -from airbyte_protocol.models import ConfiguredAirbyteCatalog - from airbyte import constants +from airbyte._util.text_util import generate_ulid from airbyte._writers.base import AirbyteWriterInterface from airbyte.caches._catalog_backend import CatalogBackendBase, SqlCatalogBackend from airbyte.caches._state_backend import SqlStateBackend @@ -23,6 +22,7 @@ from airbyte.shared.catalog_providers import CatalogProvider from airbyte.shared.sql_processor import SqlConfig from airbyte.shared.state_writers import StdOutStateWriter +from airbyte_protocol.models import ConfiguredAirbyteCatalog if TYPE_CHECKING: @@ -30,6 +30,7 @@ from airbyte._message_iterators import AirbyteMessageIterator from airbyte.caches._state_backend_base import StateBackendBase + from airbyte.lakes import FastLoadResult, FastUnloadResult, LakeStorage from airbyte.progress import ProgressTracker from airbyte.shared.sql_processor import SqlProcessorBase from airbyte.shared.state_providers import StateProviderBase @@ -38,7 +39,10 @@ from airbyte.strategies import WriteStrategy -class CacheBase(SqlConfig, AirbyteWriterInterface): +DEFAULT_LAKE_STORE_OUTPUT_PREFIX: str = "airbyte/lake/output/{stream_name}/batch-{batch_id}/" + + +class CacheBase(SqlConfig, AirbyteWriterInterface): # noqa: PLR0904 """Base configuration for a cache. Caches inherit from the matching `SqlConfig` class, which provides the SQL config settings @@ -74,6 +78,7 @@ def paired_destination_config(self) -> Any | dict[str, Any]: # noqa: ANN401 # "configuration." ) + @final def __init__(self, **data: Any) -> None: # noqa: ANN401 """Initialize the cache and backends.""" super().__init__(**data) @@ -107,6 +112,7 @@ def __init__(self, **data: Any) -> None: # noqa: ANN401 temp_file_cleanup=self.cleanup, ) + @final @property def config_hash(self) -> str | None: """Return a hash of the cache configuration. @@ -115,6 +121,7 @@ def config_hash(self) -> str | None: """ return super(SqlConfig, self).config_hash + @final def execute_sql(self, sql: str | list[str]) -> None: """Execute one or more SQL statements against the cache's SQL backend. 
@@ -145,6 +152,7 @@ def processor(self) -> SqlProcessorBase: """Return the SQL processor instance.""" return self._read_processor + @final def get_record_processor( self, source_name: str, @@ -178,6 +186,7 @@ def get_record_processor( # Read methods: + @final def get_records( self, stream_name: str, @@ -251,6 +260,7 @@ def __bool__(self) -> bool: """ return True + @final def get_state_provider( self, source_name: str, @@ -266,6 +276,7 @@ def get_state_provider( destination_name=destination_name, ) + @final def get_state_writer( self, source_name: str, @@ -281,6 +292,7 @@ def get_state_writer( destination_name=destination_name, ) + @final def register_source( self, source_name: str, @@ -294,6 +306,7 @@ def register_source( incoming_stream_names=stream_names, ) + @final def create_source_tables( self, source: Source, @@ -330,20 +343,24 @@ def create_source_tables( create_if_missing=True, ) + @final def __getitem__(self, stream: str) -> CachedDataset: """Return a dataset by stream name.""" return self.streams[stream] + @final def __contains__(self, stream: str) -> bool: """Return whether a stream is in the cache.""" return stream in (self._catalog_backend.stream_names) + @final def __iter__( # type: ignore [override] # Overriding Pydantic model method self, ) -> Iterator[tuple[str, Any]]: """Iterate over the streams in the cache.""" return ((name, dataset) for name, dataset in self.streams.items()) + @final def _write_airbyte_message_stream( self, stdin: IO[str] | AirbyteMessageIterator, @@ -365,3 +382,232 @@ def _write_airbyte_message_stream( progress_tracker=progress_tracker, ) progress_tracker.log_cache_processing_complete() + + @final + def _resolve_lake_store_path( + self, + lake_store_prefix: str, + stream_name: str | None = None, + batch_id: str | None = None, + ) -> str: + """Resolve the lake path prefix. + + The string is interpolated with "{stream_name}" and "{batch_id}" if requested. + + If `stream_name` is requested but not provided, it raises a ValueError. + If `batch_id` is requested but not provided, it defaults to a generated ULID. + """ + if lake_store_prefix is None: + raise ValueError( + "lake_store_prefix must be provided. Use DEFAULT_LAKE_STORE_OUTPUT_PREFIX if needed." + ) + + if "{stream_name}" in lake_store_prefix: + if stream_name is not None: + lake_store_prefix = lake_store_prefix.format(stream_name=stream_name) + else: + raise ValueError( + "stream_name must be provided when lake_store_prefix contains {stream_name}." + ) + + if "{batch_id}" in lake_store_prefix: + batch_id = batch_id or generate_ulid() + lake_store_prefix = lake_store_prefix.format( + batch_id=batch_id, + ) + + return lake_store_prefix + + @final + def fast_unload_streams( + self, + lake_store: LakeStorage, + *, + lake_store_prefix: str = DEFAULT_LAKE_STORE_OUTPUT_PREFIX, + streams: list[str] | Literal["*"] | None = None, + ) -> list[FastUnloadResult]: + """Unload the cache to a lake store. + + We dump data directly to parquet files in the lake store. + + Args: + streams: The streams to unload. If None, unload all streams. + lake_store: The lake store to unload to. If None, use the default lake store. + """ + stream_names: list[str] + if streams == "*" or streams is None: + stream_names = self._catalog_backend.stream_names + elif isinstance(streams, list): + stream_names = streams + else: + raise ValueError( + f"Invalid streams argument: {streams}. Must be '*' or a list of stream names." 
+            )
+
+        return [
+            self.fast_unload_stream(
+                stream_name=stream_name,
+                lake_store=lake_store,
+                lake_store_prefix=lake_store_prefix,
+            )
+            for stream_name in stream_names
+        ]
+
+    @final
+    def fast_unload_stream(
+        self,
+        lake_store: LakeStorage,
+        *,
+        lake_store_prefix: str = DEFAULT_LAKE_STORE_OUTPUT_PREFIX,
+        stream_name: str,
+        **kwargs,
+    ) -> FastUnloadResult:
+        """Unload a single stream to the lake store.
+
+        This generic implementation delegates to `fast_unload_table()`,
+        which subclasses should override for database-specific fast operations.
+
+        The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path
+        for each stream.
+        """
+        sql_table = self.streams[stream_name].to_sql_table()
+        table_name = sql_table.name
+
+        # Raises NotImplementedError if subclass does not implement this method:
+        return self.fast_unload_table(
+            lake_store=lake_store,
+            lake_store_prefix=lake_store_prefix,
+            stream_name=stream_name,
+            table_name=table_name,
+            **kwargs,
+        )
+
+    def fast_unload_table(
+        self,
+        table_name: str,
+        lake_store: LakeStorage,
+        *,
+        lake_store_prefix: str = DEFAULT_LAKE_STORE_OUTPUT_PREFIX,
+        db_name: str | None = None,
+        schema_name: str | None = None,
+        stream_name: str | None = None,
+    ) -> FastUnloadResult:
+        """Fast-unload a specific table to the designated lake storage.
+
+        Subclasses should override this method to implement fast unloads.
+
+        Subclasses should also ensure that the `lake_store_prefix` is resolved
+        using the `_resolve_lake_store_path` method. E.g.:
+        ```python
+        lake_store_prefix = self._resolve_lake_store_path(
+            lake_store_prefix=lake_store_prefix,
+            stream_name=stream_name,
+        )
+        ```
+
+        The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path
+        for each stream.
+        """
+        raise NotImplementedError
+
+    @final
+    def fast_load_streams(
+        self,
+        lake_store: LakeStorage,
+        *,
+        lake_store_prefix: str,
+        streams: list[str],
+        zero_copy: bool = False,
+    ) -> None:
+        """Load the given streams into the cache from a lake store.
+
+        Data is read directly from Parquet files in the lake store.
+
+        The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path
+        for each stream.
+        """
+        for stream_name in streams:
+            self.fast_load_stream(
+                stream_name=stream_name,
+                lake_store=lake_store,
+                lake_store_prefix=lake_store_prefix or stream_name,
+                zero_copy=zero_copy,
+            )
+
+    @final
+    def fast_load_stream(
+        self,
+        lake_store: LakeStorage,
+        *,
+        stream_name: str,
+        lake_store_prefix: str,
+        zero_copy: bool = False,
+    ) -> FastLoadResult:
+        """Load a single stream from the lake store using fast native LOAD operations.
+
+        The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path
+        for each stream.
+        """
+        sql_table = self.streams[stream_name].to_sql_table()
+        table_name = sql_table.name
+
+        if zero_copy:
+            raise NotImplementedError("Zero-copy loading is not yet supported.")
+
+        return self.fast_load_table(
+            table_name=table_name,
+            lake_store=lake_store,
+            lake_store_prefix=lake_store_prefix,
+            zero_copy=zero_copy,
+        )
+
+    def fast_load_table(
+        self,
+        table_name: str,
+        lake_store: LakeStorage,
+        lake_store_prefix: str,
+        *,
+        db_name: str | None = None,
+        schema_name: str | None = None,
+        zero_copy: bool = False,
+    ) -> FastLoadResult:
+        """Fast-load a specific table from the designated lake storage.
+
+        Subclasses should override this method to implement fast loads.
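+
+        If the incoming `lake_store_prefix` still contains a `{stream_name}` placeholder,
+        subclasses can resolve it the same way as in `fast_unload_table()`, e.g. (illustrative
+        sketch only):
+        ```python
+        lake_store_prefix = self._resolve_lake_store_path(
+            lake_store_prefix=lake_store_prefix,
+            stream_name=table_name,
+        )
+        ```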
+ + The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path + for each stream. + """ + raise NotImplementedError + + @final + def fast_load_stream_from_unload_result( + self, + stream_name: str, + unload_result: FastUnloadResult, + *, + zero_copy: bool = False, + ) -> FastLoadResult: + """Load the result of a fast unload operation.""" + return self.fast_load_stream( + stream_name=stream_name, + lake_store=unload_result.lake_store, + lake_store_prefix=unload_result.lake_store_prefix, + zero_copy=zero_copy, + ) + + @final + def fast_load_table_from_unload_result( + self, + table_name: str, + unload_result: FastUnloadResult, + *, + zero_copy: bool = False, + ) -> FastLoadResult: + """Load the result of a fast unload operation.""" + return self.fast_load_table( + table_name=table_name, + lake_store=unload_result.lake_store, + lake_store_prefix=unload_result.lake_store_prefix, + zero_copy=zero_copy, + ) diff --git a/airbyte/caches/bigquery.py b/airbyte/caches/bigquery.py index a6aaf71e..8b9d8282 100644 --- a/airbyte/caches/bigquery.py +++ b/airbyte/caches/bigquery.py @@ -20,18 +20,22 @@ from typing import TYPE_CHECKING, ClassVar, NoReturn from airbyte_api.models import DestinationBigquery +from typing_extensions import override from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor from airbyte.caches.base import ( + DEFAULT_LAKE_STORE_OUTPUT_PREFIX, CacheBase, ) from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE from airbyte.destinations._translate_cache_to_dest import ( bigquery_cache_to_destination_configuration, ) +from airbyte.lakes import FastLoadResult, FastUnloadResult, GCSLakeStorage if TYPE_CHECKING: + from airbyte.lakes import LakeStorage from airbyte.shared.sql_processor import SqlProcessorBase @@ -63,6 +67,97 @@ def get_arrow_dataset( "Please consider using a different cache implementation for these functionalities." ) + @override + def fast_unload_table( + self, + table_name: str, + lake_store: LakeStorage, + *, + lake_store_prefix: str = DEFAULT_LAKE_STORE_OUTPUT_PREFIX, + db_name: str | None = None, + schema_name: str | None = None, + stream_name: str | None = None, + **_kwargs, + ) -> FastUnloadResult: + """Unload an arbitrary table to the lake store using BigQuery EXPORT DATA. + + This implementation uses BigQuery's native EXPORT DATA functionality + to write directly to GCS, bypassing the Arrow dataset limitation. + + The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path + for each stream. 
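+
+        Example (illustrative sketch; `my-export-bucket` and the prefix are placeholder values,
+        and `cache` is an existing `BigQueryCache` instance):
+        ```python
+        gcs_lake = GCSLakeStorage(bucket_name="my-export-bucket")
+        result = cache.fast_unload_table(
+            table_name="purchases",
+            lake_store=gcs_lake,
+            lake_store_prefix="exports/{stream_name}/",
+        )
+        print(result.lake_store_prefix)  # The resolved prefix the files were written under.
+        ```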
+ """ + if db_name is not None and schema_name is None: + raise ValueError("If db_name is provided, schema_name must also be provided.") + + if not isinstance(lake_store, GCSLakeStorage): + raise NotImplementedError("BigQuery unload currently only supports GCS lake storage") + + resolved_lake_store_prefix = self._resolve_lake_store_path( + lake_store_prefix=lake_store_prefix, + stream_name=stream_name or table_name, + ) + + if db_name is not None and schema_name is not None: + qualified_table_name = f"{db_name}.{schema_name}.{table_name}" + elif schema_name is not None: + qualified_table_name = f"{schema_name}.{table_name}" + else: + qualified_table_name = f"{self._read_processor.sql_config.schema_name}.{table_name}" + + export_uri = f"{lake_store.root_storage_uri}{resolved_lake_store_prefix}/*.parquet" + + export_statement = f""" + EXPORT DATA OPTIONS( + uri='{export_uri}', + format='PARQUET', + overwrite=true + ) AS + SELECT * FROM {qualified_table_name} + """ + + self.execute_sql(export_statement) + return FastUnloadResult( + lake_store=lake_store, + lake_store_prefix=resolved_lake_store_prefix, + table_name=table_name, + stream_name=stream_name, + ) + + @override + def fast_load_table( + self, + table_name: str, + lake_store: LakeStorage, + lake_store_prefix: str, + *, + db_name: str | None = None, + schema_name: str | None = None, + zero_copy: bool = False, + ) -> FastLoadResult: + """Load a single stream from the lake store using BigQuery LOAD DATA. + + This implementation uses BigQuery's native LOAD DATA functionality + to read directly from GCS, bypassing the Arrow dataset limitation. + """ + sql_table = self.streams[stream_name].to_sql_table() + table_name = sql_table.name + + if not hasattr(lake_store, "bucket_name"): + raise NotImplementedError("BigQuery load currently only supports GCS lake storage") + + source_uri = f"{lake_store.get_stream_root_uri(stream_name)}*.parquet" + + load_statement = f""" + LOAD DATA INTO {self._read_processor.sql_config.schema_name}.{table_name} + FROM FILES ( + format = 'PARQUET', + uris = ['{source_uri}'] + ) + """ + + self.execute_sql(load_statement) + # Expose the Cache class and also the Config class. __all__ = [ diff --git a/airbyte/caches/snowflake.py b/airbyte/caches/snowflake.py index 2bf5485c..b1164f9b 100644 --- a/airbyte/caches/snowflake.py +++ b/airbyte/caches/snowflake.py @@ -1,4 +1,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from airbyte.lakes import S3LakeStorage + + """A Snowflake implementation of the PyAirbyte cache. 
## Usage Example @@ -59,18 +62,25 @@ from __future__ import annotations -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar from airbyte_api.models import DestinationSnowflake +from sqlalchemy import text +from typing_extensions import override from airbyte._processors.sql.snowflake import SnowflakeConfig, SnowflakeSqlProcessor -from airbyte.caches.base import CacheBase +from airbyte.caches.base import DEFAULT_LAKE_STORE_OUTPUT_PREFIX, CacheBase from airbyte.destinations._translate_cache_to_dest import ( snowflake_cache_to_destination_configuration, ) +from airbyte.lakes import FastLoadResult, FastUnloadResult from airbyte.shared.sql_processor import RecordDedupeMode, SqlProcessorBase +if TYPE_CHECKING: + from airbyte.lakes import LakeStorage + + class SnowflakeCache(SnowflakeConfig, CacheBase): """Configuration for the Snowflake cache.""" @@ -86,6 +96,229 @@ def paired_destination_config(self) -> DestinationSnowflake: """Return a dictionary of destination configuration values.""" return snowflake_cache_to_destination_configuration(cache=self) + def _get_lake_artifact_prefix(self, lake_store: LakeStorage) -> str: + """Get the artifact prefix for this lake storage.""" + return f"AIRBYTE_LAKE_{lake_store.short_name.upper()}_" + + def _get_lake_file_format_name(self, lake_store: LakeStorage) -> str: + """Get the file_format name.""" + artifact_prefix = self._get_lake_artifact_prefix(lake_store) + return f"{artifact_prefix}PARQUET_FORMAT" + + def _get_lake_stage_name(self, lake_store: LakeStorage) -> str: + """Get the stage name.""" + artifact_prefix = self._get_lake_artifact_prefix(lake_store) + return f"{artifact_prefix}STAGE" + + def _setup_lake_artifacts( + self, + lake_store: LakeStorage, + ) -> None: + if not isinstance(lake_store, S3LakeStorage): + raise NotImplementedError( + "Snowflake lake operations currently only support S3 lake storage" + ) + + qualified_prefix = ( + f"{self.database}.{self.schema_name}" if self.database else self.schema_name + ) + file_format_name = self._get_lake_file_format_name(lake_store) + stage_name = self._get_lake_stage_name(lake_store) + + create_format_sql = f""" + CREATE FILE FORMAT IF NOT EXISTS {qualified_prefix}.{file_format_name} + TYPE = PARQUET + COMPRESSION = SNAPPY + """ + self.execute_sql(create_format_sql) + + create_stage_sql = f""" + CREATE STAGE IF NOT EXISTS {qualified_prefix}.{stage_name} + URL = '{lake_store.root_storage_uri}' + CREDENTIALS = ( + AWS_KEY_ID = '{lake_store.aws_access_key_id}' + AWS_SECRET_KEY = '{lake_store.aws_secret_access_key}' + ) + FILE_FORMAT = {qualified_prefix}.{file_format_name} + """ + self.execute_sql(create_stage_sql) + + @override + def fast_unload_table( + self, + table_name: str, + lake_store: LakeStorage, + *, + lake_store_prefix: str = DEFAULT_LAKE_STORE_OUTPUT_PREFIX, + db_name: str | None = None, + schema_name: str | None = None, + stream_name: str | None = None, + ) -> FastUnloadResult: + """Unload an arbitrary table to the lake store using Snowflake COPY INTO. + + This implementation uses Snowflake's COPY INTO command to unload data + directly to S3 in Parquet format with managed artifacts for optimal performance. + Unlike fast_unload_stream(), this method works with any table and doesn't + require a stream mapping. + + Uses connection context manager to capture rich unload results including + actual record counts, file counts, and data size information from Snowflake's + COPY INTO command metadata. 
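+
+        Example (illustrative sketch; the bucket name is a placeholder, the credentials are read
+        via `ab.get_secret()` assuming `import airbyte as ab`, and `cache` is an existing
+        `SnowflakeCache` instance):
+        ```python
+        s3_lake = S3LakeStorage(
+            bucket_name="my-lake-bucket",
+            region="us-west-2",
+            aws_access_key_id=ab.get_secret("AWS_ACCESS_KEY_ID"),
+            aws_secret_access_key=ab.get_secret("AWS_SECRET_ACCESS_KEY"),
+        )
+        unload_result = cache.fast_unload_table(
+            table_name="purchases",
+            lake_store=s3_lake,
+        )
+        print(unload_result.record_count, unload_result.lake_store_prefix)
+        ```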
+
+        The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path
+        for each stream.
+
+        Raises:
+            ValueError: If db_name is provided but schema_name is not.
+        """
+        if db_name is not None and schema_name is None:
+            raise ValueError("If db_name is provided, schema_name must also be provided.")
+
+        qualified_prefix = (
+            f"{self.database}.{self.schema_name}" if self.database else self.schema_name
+        )
+        file_format_name = self._get_lake_file_format_name(lake_store)
+        stage_name = self._get_lake_stage_name(lake_store)
+
+        resolved_lake_store_prefix = self._resolve_lake_store_path(
+            lake_store_prefix=lake_store_prefix,
+            stream_name=stream_name or table_name,
+        )
+
+        if db_name is not None and schema_name is not None:
+            qualified_table_name = f"{db_name}.{schema_name}.{table_name}"
+        elif schema_name is not None:
+            qualified_table_name = f"{self.database}.{schema_name}.{table_name}"
+        else:
+            qualified_table_name = f"{self.database}.{self.schema_name}.{table_name}"
+
+        self._setup_lake_artifacts(lake_store)
+
+        unload_statement = f"""
+            COPY INTO @{qualified_prefix}.{stage_name}/{resolved_lake_store_prefix}/
+            FROM {qualified_table_name}
+            FILE_FORMAT = {qualified_prefix}.{file_format_name}
+            OVERWRITE = TRUE
+        """
+
+        with self.processor.get_sql_connection() as connection:
+            connection.execute(text(unload_statement))
+
+            result_scan_query = "SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))"
+            result_scan_result = connection.execute(text(result_scan_query))
+
+            metadata_row = result_scan_result.fetchone()
+
+            record_count = None
+            total_data_size_bytes = None
+            compressed_size_bytes = None
+            file_manifest = []
+
+            if metadata_row:
+                row_dict = (
+                    dict(metadata_row._mapping)  # noqa: SLF001
+                    if hasattr(metadata_row, "_mapping")
+                    else dict(metadata_row)
+                )
+                file_manifest.append(row_dict)
+
+                record_count = row_dict.get("rows_unloaded")
+                total_data_size_bytes = row_dict.get("input_bytes")
+                compressed_size_bytes = row_dict.get("output_bytes")
+
+        return FastUnloadResult(
+            stream_name=stream_name,
+            table_name=table_name,
+            lake_store=lake_store,
+            lake_store_prefix=resolved_lake_store_prefix,
+            record_count=record_count,
+            total_data_size_bytes=total_data_size_bytes,
+            compressed_size_bytes=compressed_size_bytes,
+            file_manifest=file_manifest,
+        )
+
+    @override
+    def fast_load_table(
+        self,
+        table_name: str,
+        lake_store: LakeStorage,
+        lake_store_prefix: str = DEFAULT_LAKE_STORE_OUTPUT_PREFIX,
+        *,
+        db_name: str | None = None,
+        schema_name: str | None = None,
+        zero_copy: bool = False,
+    ) -> FastLoadResult:
+        """Load a single table from the lake store using Snowflake COPY INTO.
+
+        This implementation uses Snowflake's COPY INTO command to load data
+        directly from S3 in Parquet format with managed artifacts for optimal performance.
+
+        The `lake_store_prefix` arg can be interpolated with {stream_name} to create a unique path
+        for each stream.
+
+        Uses connection context manager to capture rich load results including
+        actual record counts, file counts, and data size information from Snowflake's
+        COPY INTO command metadata.
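+
+        Example (illustrative sketch; `unload_result` is the `FastUnloadResult` returned by a
+        prior `fast_unload_table()` call, so the same resolved prefix is read back):
+        ```python
+        load_result = cache.fast_load_table_from_unload_result(
+            table_name="purchases",
+            unload_result=unload_result,
+        )
+        print(load_result.record_count)
+        ```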
+ """ + if zero_copy: + raise NotImplementedError("Zero-copy loading is not yet supported in Snowflake.") + + if db_name is not None and schema_name is None: + raise ValueError("If db_name is provided, schema_name must also be provided.") + + qualified_prefix = ( + f"{self.database}.{self.schema_name}" if self.database else self.schema_name + ) + file_format_name = self._get_lake_file_format_name(lake_store) + stage_name = self._get_lake_stage_name(lake_store) + + if db_name is not None and schema_name is not None: + qualified_table_name = f"{db_name}.{schema_name}.{table_name}" + elif schema_name is not None: + qualified_table_name = f"{self.database}.{schema_name}.{table_name}" + else: + qualified_table_name = f"{self.database}.{self.schema_name}.{table_name}" + + self._setup_lake_artifacts(lake_store) + + load_statement = f""" + COPY INTO {qualified_table_name} + FROM @{qualified_prefix}.{stage_name}/{lake_store_prefix}/ + FILE_FORMAT = {qualified_prefix}.{file_format_name} + MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE + PURGE = FALSE + """ + + with self.processor.get_sql_connection() as connection: + connection.execute(text(load_statement)) + + result_scan_query = "SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))" + result_scan_result = connection.execute(text(result_scan_query)) + + record_count = None + total_data_size_bytes = None + compressed_size_bytes = None + file_manifest = [] + + rows = result_scan_result.fetchall() + if rows: + for row in rows: + row_dict = ( + dict(row._mapping) # noqa: SLF001 + if hasattr(row, "_mapping") + else dict(row) + ) + file_manifest.append(row_dict) + + first_row = file_manifest[0] if file_manifest else {} + record_count = first_row.get("rows_loaded") or first_row.get("rows_parsed") + total_data_size_bytes = first_row.get("input_bytes") + compressed_size_bytes = first_row.get("output_bytes") + + return FastLoadResult( + table_name=table_name, + lake_store=lake_store, + lake_store_prefix=lake_store_prefix, + record_count=record_count, + total_data_size_bytes=total_data_size_bytes, + compressed_size_bytes=compressed_size_bytes, + file_manifest=file_manifest, + ) + # Expose the Cache class and also the Config class. __all__ = [ diff --git a/airbyte/lakes.py b/airbyte/lakes.py new file mode 100644 index 00000000..01f4f34f --- /dev/null +++ b/airbyte/lakes.py @@ -0,0 +1,178 @@ +# Copyright (c) 2025 Airbyte, Inc., all rights reserved. +"""PyAirbyte LakeStorage class.""" + +from __future__ import annotations + +import abc +import re +from abc import abstractmethod + +from pydantic import BaseModel + + +class LakeStorage(abc.ABC): + """PyAirbyte LakeStorage class.""" + + def __init__(self) -> None: + """Initialize LakeStorage base class.""" + self.short_name: str + + @property + @abstractmethod + def uri_protocol(self) -> str: + """Return the URI protocol for the lake storage. + + E.g. "file://", "s3://", "gcs://", etc. 
+ """ + raise NotImplementedError("Subclasses must implement this method.") + + @property + def root_storage_uri(self) -> str: + """Get the root URI for the lake storage.""" + return f"{self.uri_protocol}{self.root_storage_path}/" + + @property + def root_storage_path(self) -> str: + """Get the root path for the lake storage.""" + return "airbyte/lake" + + def path_to_uri(self, path: str) -> str: + """Convert a relative lake path to a URI.""" + return f"{self.root_storage_uri}{path}" + + def get_stream_root_path( + self, + stream_name: str, + ) -> str: + """Get the path for a stream in the lake storage.""" + return f"{self.root_storage_path}/{stream_name}/" + + def get_stream_root_uri( + self, + stream_name: str, + ) -> str: + """Get the URI root for a stream in the lake storage.""" + return self.path_to_uri(self.get_stream_root_path(stream_name)) + + def _validate_short_name(self, short_name: str) -> str: + """Validate that short_name is lowercase snake_case with no special characters.""" + if not re.match(r"^[a-z][a-z0-9_]*$", short_name): + raise ValueError( + f"short_name '{short_name}' must be lowercase snake_case with no special characters" + ) + return short_name + + def get_artifact_prefix(self) -> str: + """Get the artifact prefix for this lake storage.""" + return f"AIRBYTE_LAKE_{self.short_name.upper()}_" + + +class FileManifestEntry(BaseModel): + """Represents a file manifest entry for lake storage.""" + + file_path: str + file_size_bytes: int | None = None + record_count: int | None = None + + +class FastUnloadResult(BaseModel): + """Results from a Fast Unload operation.""" + + model_config = {"arbitrary_types_allowed": True} + + lake_store: LakeStorage + lake_store_prefix: str + table_name: str + stream_name: str | None = None + + record_count: int | None = None + file_manifest: list[FileManifestEntry] | None = None + + total_data_size_bytes: int | None = None + compressed_size_bytes: int | None = None + + def num_files(self) -> int | None: + """Return the number of files in the file manifest.""" + return len(self.file_manifest) if self.file_manifest else None + + +class FastLoadResult(BaseModel): + """Results from a Fast Load operation.""" + + model_config = {"arbitrary_types_allowed": True} + + lake_store: LakeStorage + lake_store_prefix: str + table_name: str + stream_name: str | None = None + + record_count: int | None = None + file_manifest: list[FileManifestEntry] | None = None + + total_data_size_bytes: int | None = None + compressed_size_bytes: int | None = None + + def num_files(self) -> int | None: + """Return the number of files in the file manifest.""" + return len(self.file_manifest) if self.file_manifest else None + + +class S3LakeStorage(LakeStorage): + """S3 Lake Storage implementation.""" + + def __init__( + self, + *, + bucket_name: str, + region: str, + short_name: str = "s3", + aws_access_key_id: str, + aws_secret_access_key: str, + ) -> None: + """Initialize S3LakeStorage with required parameters.""" + self.bucket_name = bucket_name + self.region = region + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.short_name = self._validate_short_name(short_name) + + @property + def uri_protocol(self) -> str: + """Return the URI protocol for S3.""" + return "s3://" + + @property + def root_storage_uri(self) -> str: + """Get the root URI for the S3 lake storage.""" + return f"{self.uri_protocol}{self.bucket_name}/" + + +class GCSLakeStorage(LakeStorage): + """Google Cloud Storage Lake Storage implementation.""" 
+ + def __init__( + self, bucket_name: str, credentials_path: str | None = None, short_name: str = "gcs" + ) -> None: + """Initialize GCSLakeStorage with required parameters.""" + self.bucket_name = bucket_name + self.credentials_path = credentials_path + self.short_name = self._validate_short_name(short_name) + + @property + def uri_protocol(self) -> str: + """Return the URI protocol for GCS.""" + return "gs://" + + @property + def root_storage_uri(self) -> str: + """Get the root URI for the GCS lake storage.""" + return f"{self.uri_protocol}{self.bucket_name}/" + + +__all__ = [ + "LakeStorage", + "S3LakeStorage", + "GCSLakeStorage", + "FastUnloadResult", + "FastLoadResult", +] diff --git a/examples/run_fast_lake_copy.py b/examples/run_fast_lake_copy.py new file mode 100644 index 00000000..7b6756ce --- /dev/null +++ b/examples/run_fast_lake_copy.py @@ -0,0 +1,738 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""An example script demonstrating fast lake copy operations using PyAirbyte. + +This script demonstrates 100x performance improvements by using: +- Direct bulk operations (Snowflake COPY INTO, BigQuery LOAD DATA FROM) +- Lake storage as an intermediate layer (S3 and GCS) +- Parallel processing of multiple streams +- Optimized file formats (Parquet with compression) + +Workflow: Snowflake β†’ S3 β†’ Snowflake (proof of concept) + +Usage: + poetry run python examples/run_fast_lake_copy.py + +Required secrets (retrieved from Google Secret Manager): + - AIRBYTE_LIB_SNOWFLAKE_CREDS: Snowflake connection credentials + - AWS_ACCESS_KEY_ID: AWS access key ID for S3 connection + - AWS_SECRET_ACCESS_KEY: AWS secret access key for S3 connection + - GCP_GSM_CREDENTIALS: Google Cloud credentials for Secret Manager access +""" + +import os +import resource +import time +import uuid +from datetime import datetime +from typing import Any, Literal + +import airbyte as ab +from airbyte.caches.snowflake import SnowflakeCache +from airbyte.lakes import FastLoadResult, FastUnloadResult, S3LakeStorage +from airbyte.secrets.google_gsm import GoogleGSMSecretManager + + +# Available Snowflake warehouse configurations for performance testing: +# - COMPUTE_WH: xsmall (1x multiplier) - Default warehouse (important-comment) +# - COMPUTE_WH_LARGE: large (8x multiplier) - 8x compute power (important-comment) +# - COMPUTE_WH_2XLARGE: 2xlarge (32x multiplier) - 32x compute power (important-comment) +# +# Size multipliers relative to xsmall: +# - xsmall (1x) +# - small (2x) +# - medium (4x) +# - large (8x) +# - xlarge (16x) +# - 2xlarge (32x) + +WAREHOUSE_CONFIGS: list[dict[str, str | int]] = [ + # Toggle commenting-out to include/exclude specific warehouse configurations: + # {"name": "COMPUTE_WH", "size": "xsmall", "multiplier": 1}, + # {"name": "COMPUTE_WH_LARGE", "size": "large", "multiplier": 8}, + # {"name": "COMPUTE_WH_2XLARGE", "size": "2xlarge", "multiplier": 32}, +] + +NUM_RECORDS: int = 100_000_000 # Restore to 100M for reload process +WAREHOUSE_SIZE_MULTIPLIERS = { + "xsmall": 1, + "small": 2, + "medium": 4, + "large": 8, + "xlarge": 16, + "2xlarge": 32, # COMPUTE_WH_2XLARGE provides 32x compute units vs xsmall (2XLARGE = XXLarge size) +} + +# WARNING: Reloading is a DESTRUCTIVE operation that takes several hours and will PERMANENTLY DELETE +# the existing dataset. Only toggle if you are absolutely sure you want to lose all current data. 
+RELOAD_INITIAL_SOURCE_DATA = False + + +def get_credentials() -> dict[str, Any]: + """Retrieve required credentials from Google Secret Manager.""" + print( + f"πŸ” [{datetime.now().strftime('%H:%M:%S')}] Retrieving credentials from Google Secret Manager..." + ) + + AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing" + + gcp_creds = os.environ.get( + "DEVIN_GCP_SERVICE_ACCOUNT_JSON", os.environ.get("GCP_GSM_CREDENTIALS") + ) + if not gcp_creds: + raise ValueError( + "DEVIN_GCP_SERVICE_ACCOUNT_JSON environment variable not found" + ) + + secret_mgr = GoogleGSMSecretManager( + project=AIRBYTE_INTERNAL_GCP_PROJECT, + credentials_json=gcp_creds, + ) + + snowflake_secret = secret_mgr.get_secret("AIRBYTE_LIB_SNOWFLAKE_CREDS") + assert snowflake_secret is not None, "Snowflake secret not found." + + try: + s3_secret = secret_mgr.get_secret("SECRET_SOURCE-S3_AVRO__CREDS") + s3_config = s3_secret.parse_json() + aws_access_key_id = s3_config.get("aws_access_key_id") + aws_secret_access_key = s3_config.get("aws_secret_access_key") + except Exception: + aws_access_key_id = ab.get_secret("AWS_ACCESS_KEY_ID") + aws_secret_access_key = ab.get_secret("AWS_SECRET_ACCESS_KEY") + + return { + "snowflake": snowflake_secret.parse_json(), + "aws_access_key_id": aws_access_key_id, + "aws_secret_access_key": aws_secret_access_key, + } + + +def setup_source() -> ab.Source: + """Set up the source connector with sample data.""" + print(f"πŸ“Š [{datetime.now().strftime('%H:%M:%S')}] Setting up source connector...") + + return ab.get_source( + "source-faker", + config={ + "count": NUM_RECORDS, + "seed": 42, + "parallelism": 4, # Parallel processing for better performance + "always_updated": False, + }, + install_if_missing=True, + streams=["purchases"], # Only processing purchases stream for large-scale test + ) + + +def setup_caches(credentials: dict[str, Any], warehouse_config: dict[str, Any]) -> tuple[SnowflakeCache, SnowflakeCache]: + """Set up source and destination Snowflake caches with specified warehouse.""" + print(f"πŸ—οΈ [{datetime.now().strftime('%H:%M:%S')}] Setting up Snowflake caches...") + + snowflake_config = credentials["snowflake"] + + warehouse_name = warehouse_config["name"] + warehouse_size = warehouse_config["size"] + size_multiplier = warehouse_config["multiplier"] + + print("πŸ“Š Warehouse Configuration:") + print(f" Using warehouse: {warehouse_name}") + print(f" Warehouse size: {warehouse_size}") + print(f" Size multiplier: {size_multiplier}x (relative to xsmall)") + + snowflake_cache_source = SnowflakeCache( + account=snowflake_config["account"], + username=snowflake_config["username"], + password=snowflake_config["password"], + database=snowflake_config["database"], + warehouse=warehouse_name, + role=snowflake_config["role"], + schema_name="fast_lake_copy_source", + ) + + snowflake_cache_dest = SnowflakeCache( + account=snowflake_config["account"], + username=snowflake_config["username"], + password=snowflake_config["password"], + database=snowflake_config["database"], + warehouse=warehouse_name, + role=snowflake_config["role"], + schema_name=f"fast_copy_tests__{warehouse_name}", + ) + + return snowflake_cache_source, snowflake_cache_dest + + +def setup_lake_storage( + credentials: dict[str, Any], + script_start_time: datetime | None = None, +) -> S3LakeStorage: + """Set up S3 lake storage with timestamped path and warehouse subdirectory for tracking.""" + print(f"🏞️ [{datetime.now().strftime('%H:%M:%S')}] Setting up S3 lake storage...") + + if script_start_time is None: + 
        script_start_time = datetime.now()
+
+    timestamp = script_start_time.strftime("%Y%m%d_%H%M")
+    base_path = f"fast_lake_copy_{timestamp}"
+
+    s3_lake = S3LakeStorage(
+        bucket_name="ab-perf-test-bucket-us-west-2",
+        region="us-west-2",
+        aws_access_key_id=credentials["aws_access_key_id"],
+        aws_secret_access_key=credentials["aws_secret_access_key"],
+        short_name="s3_main",  # Custom short name for AIRBYTE_LAKE_S3_MAIN_ artifacts
+    )
+
+    print(f"   πŸ“ Full S3 root URI: {s3_lake.root_storage_uri}")
+    return s3_lake
+
+
+def transfer_data_with_timing(
+    source: ab.Source,
+    credentials: dict[str, Any],
+    snowflake_cache_source: SnowflakeCache,
+    snowflake_cache_dest: SnowflakeCache,
+    s3_lake: S3LakeStorage,
+    warehouse_config: dict[str, Any],
+) -> dict[str, Any]:
+    """Execute the complete data transfer workflow with performance timing.
+
+    Simplified to Snowflakeβ†’S3β†’Snowflake for proof of concept as suggested.
+    """
+    streams = ["purchases"]
+
+    # Rough per-record size assumption (source-faker "purchases" rows), used only for
+    # the MB/s throughput estimates below.
+    estimated_bytes_per_record = 150
+
+    workflow_start_time = datetime.now()
+    print(
+        f"πŸš€ [{workflow_start_time.strftime('%H:%M:%S')}] Starting fast lake copy workflow (Snowflakeβ†’S3β†’Snowflake)..."
+    )
+    total_start = time.time()
+
+    reload_raw_data(
+        credentials=credentials,
+        source=source,
+    )
+
+    step2_start_time = datetime.now()
+    print(f"πŸ“€ [{step2_start_time.strftime('%H:%M:%S')}] Step 2: Unloading from Snowflake to S3...")
+    print(f"   πŸ“‚ S3 destination paths:")
+    for stream_name in streams:
+        stream_uri = s3_lake.get_stream_root_uri(stream_name)
+        print(f"      {stream_name}: {stream_uri}")
+
+    step2_start = time.time()
+    unload_results: list[FastUnloadResult] = []
+    for stream_name in streams:
+        unload_results.append(
+            snowflake_cache_source.fast_unload_stream(
+                stream_name=stream_name,
+                lake_store=s3_lake,
+            )
+        )
+    step2_time = time.time() - step2_start
+    step2_end_time = datetime.now()
+
+    # Actual record count as reported by the unload metadata (may be 0 if unavailable).
+    actual_records = sum(result.record_count or 0 for result in unload_results)
+
+    step2_records_per_sec = NUM_RECORDS / step2_time if step2_time > 0 else 0
+    step2_mb_per_sec = (
+        (NUM_RECORDS * estimated_bytes_per_record) / (1024 * 1024) / step2_time
+        if step2_time > 0
+        else 0
+    )
+
+    print(
+        f"βœ… [{step2_end_time.strftime('%H:%M:%S')}] Step 2 completed in {step2_time:.2f} seconds (elapsed: {(step2_end_time - step2_start_time).total_seconds():.2f}s)"
+    )
+    print(
+        f"   πŸ“Š Step 2 Performance: {actual_records:,} records at {step2_records_per_sec:,.1f} records/s, {step2_mb_per_sec:.2f} MB/s"
+    )
+
+    print("   πŸ“„ Unload Results Metadata:")
+    total_files_created = 0
+    total_actual_records = 0
+    total_data_size_bytes = 0
+    total_compressed_size_bytes = 0
+
+    for result in unload_results:
+        stream_name = result.stream_name or result.table_name
+        print(f"      Stream: {stream_name}")
+
+        if result.record_count is not None:
+            print(f"         Actual records: {result.record_count:,}")
+            total_actual_records += result.record_count
+
+        if result.files_created is not None:
+            print(f"         Files created: {result.files_created}")
+            total_files_created += result.files_created
+
+        if result.total_data_size_bytes is not None:
+            print(
+                f"         Data size: {result.total_data_size_bytes:,} bytes ({result.total_data_size_bytes / (1024 * 1024):.2f} MB)"
+            )
+            total_data_size_bytes += result.total_data_size_bytes
+
+        if result.compressed_size_bytes is not None:
+            print(
+                f"         Compressed size: {result.compressed_size_bytes:,} bytes ({result.compressed_size_bytes / (1024 * 1024):.2f} MB)"
+            )
+            total_compressed_size_bytes += result.compressed_size_bytes
+
+        if result.file_manifest:
+            print(f"         File manifest entries: {len(result.file_manifest)}")
+            for i, manifest_entry in enumerate(result.file_manifest[:3]):  # Show first 3 entries
+
print(f" File {i + 1}: {manifest_entry}") + if len(result.file_manifest) > 3: + print(f" ... and {len(result.file_manifest) - 3} more files") + + print(f" πŸ” Debug: Unload File Analysis for {stream_name}:") + if result.file_manifest: + total_unload_records = 0 + print(f" Files created in unload: {result.files_created}") + for i, file_info in enumerate(result.file_manifest): + rows_unloaded = file_info.get("rows_unloaded", 0) + total_unload_records += rows_unloaded + print(f" Unload File {i + 1}: {rows_unloaded:,} records") + + print(f" Total records from unload files: {total_unload_records:,}") + print(f" FastUnloadResult.record_count: {result.record_count:,}") + + if total_unload_records != result.record_count: + print( + f" ⚠️ MISMATCH: Unload file breakdown ({total_unload_records:,}) != record_count ({result.record_count:,})" + ) + else: + print(f" βœ… Unload file breakdown matches record_count") + + print(" πŸ“Š Total Summary:") + print(f" Total files created: {total_files_created}") + print(f" Total actual records: {total_actual_records:,}") + if total_data_size_bytes > 0: + print( + f" Total data size: {total_data_size_bytes:,} bytes ({total_data_size_bytes / (1024 * 1024):.2f} MB)" + ) + if total_compressed_size_bytes > 0: + print( + f" Total compressed size: {total_compressed_size_bytes:,} bytes ({total_compressed_size_bytes / (1024 * 1024):.2f} MB)" + ) + if total_data_size_bytes > 0: + compression_ratio = (1 - total_compressed_size_bytes / total_data_size_bytes) * 100 + print(f" Compression ratio: {compression_ratio:.1f}%") + + consistency_delay = 5 # seconds + print( + f"⏱️ [{datetime.now().strftime('%H:%M:%S')}] Waiting {consistency_delay}s for S3 eventual consistency..." + ) + time.sleep(consistency_delay) + + step3_start_time = datetime.now() + print( + f"πŸ“₯ [{step3_start_time.strftime('%H:%M:%S')}] Step 3: Loading from S3 to Snowflake (destination)..." 
+    )
+    print("   πŸ“‚ S3 source paths (from unload results):")
+    for unload_result in unload_results:
+        stream_uri = s3_lake.path_to_uri(unload_result.lake_store_prefix)
+        print(f"      {unload_result.stream_name or unload_result.table_name}: {stream_uri}")
+
+    step3_start = time.time()
+
+    snowflake_cache_dest.create_source_tables(
+        source=source,
+        streams=streams,
+    )
+
+    load_results: list[FastLoadResult] = []
+    for unload_result in unload_results:
+        # Reuse the resolved prefix captured at unload time so the load reads the same path.
+        load_result = snowflake_cache_dest.fast_load_stream_from_unload_result(
+            stream_name=unload_result.stream_name or unload_result.table_name,
+            unload_result=unload_result,
+        )
+        load_results.append(load_result)
+
+    step3_time = time.time() - step3_start
+    step3_end_time = datetime.now()
+
+    total_load_records = sum(result.record_count or 0 for result in load_results)
+    total_load_data_bytes = sum(result.total_data_size_bytes or 0 for result in load_results)
+
+    step3_records_per_sec = total_load_records / step3_time if step3_time > 0 else 0
+    step3_mb_per_sec = (
+        (total_load_data_bytes / (1024 * 1024)) / step3_time
+        if step3_time > 0 and total_load_data_bytes > 0
+        else (actual_records * estimated_bytes_per_record) / (1024 * 1024) / step3_time
+        if step3_time > 0
+        else 0
+    )
+
+    print(
+        f"βœ… [{step3_end_time.strftime('%H:%M:%S')}] Step 3 completed in {step3_time:.2f} seconds (elapsed: {(step3_end_time - step3_start_time).total_seconds():.2f}s)"
+    )
+    print(
+        f"   πŸ“Š Step 3 Performance: {total_load_records:,} records at {step3_records_per_sec:,.1f} records/s, {step3_mb_per_sec:.2f} MB/s"
+    )
+
+    print("   πŸ“„ Load Results Metadata:")
+    total_load_files_processed = 0
+    total_load_actual_records = 0
+    total_load_data_size_bytes = 0
+    total_load_compressed_size_bytes = 0
+
+    for result in load_results:
+        stream_name = result.stream_name or result.table_name
+        print(f"      Stream: {stream_name}")
+
+        if result.record_count is not None:
+            print(f"         Actual records loaded: {result.record_count:,}")
+            total_load_actual_records += result.record_count
+
+        if result.num_files is not None:
+            print(f"         Files processed: {result.num_files}")
+            total_load_files_processed += result.num_files
+
+        if result.total_data_size_bytes is not None:
+            print(
+                f"         Data size: {result.total_data_size_bytes:,} bytes ({result.total_data_size_bytes / (1024 * 1024):.2f} MB)"
+            )
+            total_load_data_size_bytes += result.total_data_size_bytes
+
+        if result.compressed_size_bytes is not None:
+            print(
+                f"         Compressed size: {result.compressed_size_bytes:,} bytes ({result.compressed_size_bytes / (1024 * 1024):.2f} MB)"
+            )
+            total_load_compressed_size_bytes += result.compressed_size_bytes
+
+        if result.file_manifest:
+            print(f"         File manifest entries: {len(result.file_manifest)}")
+            for i, manifest_entry in enumerate(result.file_manifest[:3]):  # Show first 3 entries
+                print(f"            File {i + 1}: {manifest_entry}")
+            if len(result.file_manifest) > 3:
+                print(f"            ... and {len(result.file_manifest) - 3} more files")
+
+        print(f"   πŸ” Debug: Load File Analysis for {stream_name}:")
+        if result.file_manifest:
+            total_load_records = 0
+            print(f"      Files processed in load: {result.num_files}")
+            print(f"      Record count per file breakdown:")
+            for i, file_info in enumerate(result.file_manifest[:10]):  # Show first 10 files
+                file_name = file_info.get("file", "unknown")
+                rows_loaded = file_info.get("rows_loaded", 0)
+                total_load_records += rows_loaded
+                print(f"         Load File {i + 1}: {file_name} -> {rows_loaded:,} records")
+
+            if len(result.file_manifest) > 10:
+                remaining_files = result.file_manifest[10:]
+                remaining_records = sum(f.get("rows_loaded", 0) for f in remaining_files)
+                total_load_records += remaining_records
+                print(
+                    f"            ... 
and {len(remaining_files)} more files -> {remaining_records:,} records" + ) + + print(f" Total records from file breakdown: {total_load_records:,}") + print(f" FastLoadResult.record_count: {result.record_count:,}") + + if total_load_records != result.record_count: + print( + f" ⚠️ MISMATCH: File breakdown ({total_load_records:,}) != record_count ({result.record_count:,})" + ) + else: + print(f" βœ… File breakdown matches record_count") + + print(" πŸ“Š Load Summary:") + print(f" Total files processed: {total_load_files_processed}") + print(f" Total actual records loaded: {total_load_actual_records:,}") + if total_load_data_size_bytes > 0: + print( + f" Total data size: {total_load_data_size_bytes:,} bytes ({total_load_data_size_bytes / (1024 * 1024):.2f} MB)" + ) + if total_load_compressed_size_bytes > 0: + print( + f" Total compressed size: {total_load_compressed_size_bytes:,} bytes ({total_load_compressed_size_bytes / (1024 * 1024):.2f} MB)" + ) + + print(f"\nπŸ” [DEBUG] Unload vs Load File Comparison:") + print(f" Unload Summary:") + print(f" Files created: {total_files_created}") + print(f" Records unloaded: {total_actual_records:,}") + print(f" Load Summary:") + print(f" Files processed: {total_load_files_processed}") + print(f" Records loaded: {total_load_actual_records:,}") + print(f" ") + print( + f" File Count Match: {'βœ…' if total_files_created == total_load_files_processed else '❌'}" + ) + print( + f" Record Count Match: {'βœ…' if total_actual_records == total_load_actual_records else '❌'}" + ) + print(f" ") + print(f" Potential Issues:") + print( + f" - If file counts don't match: Load may be reading from wrong S3 path or missing files" + ) + print( + f" - If record counts don't match: Files may contain different data or path filters not working" + ) + print(f" - Check S3 paths above to ensure unload and load are using same locations") + + total_time = time.time() - total_start + workflow_end_time = datetime.now() + total_elapsed = (workflow_end_time - workflow_start_time).total_seconds() + + warehouse_size = warehouse_config["size"] + size_multiplier = warehouse_config["multiplier"] + + total_records_per_sec = actual_records / total_time if total_time > 0 else 0 + total_mb_per_sec = ( + (actual_records * estimated_bytes_per_record) / (1024 * 1024) / total_time + if total_time > 0 + else 0 + ) + + print(f"\nπŸ“Š [{workflow_end_time.strftime('%H:%M:%S')}] Performance Summary:") + print(f" Workflow started: {workflow_start_time.strftime('%H:%M:%S')}") + print(f" Workflow completed: {workflow_end_time.strftime('%H:%M:%S')}") + print(f" Total elapsed time: {total_elapsed:.2f}s") + if RELOAD_INITIAL_SOURCE_DATA: + print( + f" Step 1 (Source β†’ Snowflake): {step1_time:.2f}s ({step1_records_per_sec:,.1f} rec/s, {step1_mb_per_sec:.2f} MB/s)" + ) + else: + print(" Step 1 (Source β†’ Snowflake): SKIPPED (using existing data)") + print( + f" Step 2 (Snowflake β†’ S3): {step2_time:.2f}s ({step2_records_per_sec:,.1f} rec/s, {step2_mb_per_sec:.2f} MB/s)" + ) + print( + f" Step 3 (S3 β†’ Snowflake): {step3_time:.2f}s ({step3_records_per_sec:,.1f} rec/s, {step3_mb_per_sec:.2f} MB/s)" + ) + print(f" Total measured time: {total_time:.2f}s") + print( + f" Records processed: {actual_records:,} / {NUM_RECORDS:,} ({100 * actual_records / NUM_RECORDS:.1f}%)" + ) + print( + f" Overall throughput: {total_records_per_sec:,.1f} records/s, {total_mb_per_sec:.2f} MB/s" + ) + print(f" Estimated record size: {estimated_bytes_per_record} bytes") + + step2_cpu_minutes = (step2_time / 60) * size_multiplier + 
step3_cpu_minutes = (step3_time / 60) * size_multiplier + total_cpu_minutes = (total_time / 60) * size_multiplier + + print("\n🏭 Warehouse Scaling Analysis:") + print(f" Warehouse size used: {warehouse_size}") + print(f" Size multiplier: {size_multiplier}x") + print( + f" Throughput per compute unit: {total_records_per_sec / size_multiplier:,.1f} records/s/unit" + ) + print(f" Bandwidth per compute unit: {total_mb_per_sec / size_multiplier:.2f} MB/s/unit") + + print("\nπŸ’° Snowflake CPU Minutes Analysis:") + print(f" Step 2 CPU minutes: {step2_cpu_minutes:.3f} minutes") + print(f" Step 3 CPU minutes: {step3_cpu_minutes:.3f} minutes") + print(f" Total CPU minutes: {total_cpu_minutes:.3f} minutes") + print( + f" Cost efficiency (rec/CPU-min): {actual_records / total_cpu_minutes:,.0f} records/CPU-minute" + ) + + validation_start_time = datetime.now() + print(f"\nπŸ” [{validation_start_time.strftime('%H:%M:%S')}] Validating data transfer...") + for i, stream_name in enumerate(streams): + unload_result = unload_results[i] + load_result = load_results[i] + + unload_count = unload_result.record_count or 0 + load_count = load_result.record_count or 0 + + print(f" {stream_name}: Unloaded={unload_count:,}, Loaded={load_count:,}") + if unload_count == load_count: + print(f" βœ… {stream_name} transfer validated (metadata-based)") + else: + print(f" ❌ {stream_name} transfer validation failed (metadata-based)") + + source_count = len(snowflake_cache_source[stream_name]) + dest_count = len(snowflake_cache_dest[stream_name]) + print(f" Fallback validation: Source={source_count:,}, Destination={dest_count:,}") + if source_count == dest_count: + print(f" βœ… {stream_name} fallback validation passed") + else: + print(f" ❌ {stream_name} fallback validation failed") + validation_end_time = datetime.now() + print( + f"πŸ” [{validation_end_time.strftime('%H:%M:%S')}] Validation completed in {(validation_end_time - validation_start_time).total_seconds():.2f}s" + ) + + return { + "warehouse_name": warehouse_config["name"], + "warehouse_size": warehouse_config["size"], + "size_multiplier": warehouse_config["multiplier"], + "step2_time": step2_time, + "step2_records_per_sec": step2_records_per_sec, + "step2_mb_per_sec": step2_mb_per_sec, + "step2_cpu_minutes": step2_cpu_minutes, + "step3_time": step3_time, + "step3_records_per_sec": step3_records_per_sec, + "step3_mb_per_sec": step3_mb_per_sec, + "step3_cpu_minutes": step3_cpu_minutes, + "total_time": total_time, + "total_records_per_sec": total_records_per_sec, + "total_mb_per_sec": total_mb_per_sec, + "total_cpu_minutes": total_cpu_minutes, + "actual_records": actual_records, + "total_files_created": total_files_created, + "total_actual_records": total_actual_records, + "total_data_size_bytes": total_data_size_bytes, + "total_compressed_size_bytes": total_compressed_size_bytes, + "total_load_records": total_load_records, + "total_load_data_bytes": total_load_data_bytes, + } + + +def reload_raw_data(credentials: dict[str, Any], source: ab.Source) -> None: + """Reload raw data from source to Snowflake for initial setup.""" + if not RELOAD_INITIAL_SOURCE_DATA: + print(f"\n⏭️ Skipping reload (RELOAD_INITIAL_SOURCE_DATA=False)") + print(" β€’ Set RELOAD_INITIAL_SOURCE_DATA=True to reload 100M records") + return + + print( + f"\n⚠️ WARNING: This will take approximately 2.5 hours to reload {NUM_RECORDS:,} records" + ) + print(" β€’ Only Step 1 (Source β†’ Snowflake) will run") + print(" β€’ No warehouse testing or S3 operations") + + warehouse_config = 
WAREHOUSE_CONFIGS[0] # COMPUTE_WH (xsmall) + snowflake_cache_source, _ = setup_caches(credentials, warehouse_config) + + step1_start_time = datetime.now() + print( + f"πŸ“₯ [{step1_start_time.strftime('%H:%M:%S')}] Step 1: Loading {NUM_RECORDS:,} records from source to Snowflake..." + ) + + source.read( + cache=snowflake_cache_source, + streams=["purchases"], # Only purchases stream + force_full_refresh=True, + write_strategy="replace", + ) + + step1_end_time = datetime.now() + step1_time = (step1_end_time - step1_start_time).total_seconds() + + print( + f"βœ… [{step1_end_time.strftime('%H:%M:%S')}] Step 1 completed in {step1_time:.2f} seconds" + ) + print(f" β€’ Records loaded: {NUM_RECORDS:,}") + print(f" β€’ Records per second: {NUM_RECORDS / step1_time:,.1f}") + print(f" β€’ Warehouse used: {warehouse_config['name']} ({warehouse_config['size']})") + print(f"\nπŸŽ‰ Raw data reload completed successfully!") + + +def print_performance_summary(results: list[dict[str, Any]]) -> None: + """Print comprehensive performance comparison across all warehouse sizes.""" + print(f"\n{'=' * 80}") + print("πŸ“Š COMPREHENSIVE PERFORMANCE ANALYSIS ACROSS ALL WAREHOUSE SIZES") + print(f"{'=' * 80}") + + print(f"\nπŸ”„ UNLOAD PERFORMANCE (Snowflake β†’ S3):") + print( + f"{'Warehouse':<20} {'Size':<8} {'Multiplier':<10} {'Time (s)':<10} {'Records/s':<15} {'MB/s':<10} {'CPU Min':<10}" + ) + print("-" * 90) + for result in results: + print( + f"{result['warehouse_name']:<20} {result['warehouse_size']:<8} {result['size_multiplier']:<10} " + f"{result['step2_time']:<10.2f} {result['step2_records_per_sec']:<15,.0f} " + f"{result['step2_mb_per_sec']:<10.1f} {result['step2_cpu_minutes']:<10.3f}" + ) + + print(f"\nπŸ“₯ LOAD PERFORMANCE (S3 β†’ Snowflake):") + print( + f"{'Warehouse':<20} {'Size':<8} {'Multiplier':<10} {'Time (s)':<10} {'Records/s':<15} {'MB/s':<10} {'CPU Min':<10}" + ) + print("-" * 90) + for result in results: + print( + f"{result['warehouse_name']:<20} {result['warehouse_size']:<8} {result['size_multiplier']:<10} " + f"{result['step3_time']:<10.2f} {result['step3_records_per_sec']:<15,.0f} " + f"{result['step3_mb_per_sec']:<10.1f} {result['step3_cpu_minutes']:<10.3f}" + ) + + print(f"\n🎯 OVERALL PERFORMANCE SUMMARY:") + print( + f"{'Warehouse':<20} {'Size':<8} {'Multiplier':<10} {'Total Time':<12} {'Records/s':<15} {'MB/s':<10} {'Total CPU':<12}" + ) + print("-" * 100) + for result in results: + print( + f"{result['warehouse_name']:<20} {result['warehouse_size']:<8} {result['size_multiplier']:<10} " + f"{result['total_time']:<12.2f} {result['total_records_per_sec']:<15,.0f} " + f"{result['total_mb_per_sec']:<10.1f} {result['total_cpu_minutes']:<12.3f}" + ) + + print(f"\nπŸ“ˆ KEY INSIGHTS:") + best_unload = max(results, key=lambda x: x["step2_records_per_sec"]) + best_load = max(results, key=lambda x: x["step3_records_per_sec"]) + most_efficient = min(results, key=lambda x: x["total_cpu_minutes"]) + + print(f" β€’ Best unload performance: {best_unload['warehouse_name']} ({best_unload['step2_records_per_sec']:,.0f} rec/s)") + print(f" β€’ Best load performance: {best_load['warehouse_name']} ({best_load['step3_records_per_sec']:,.0f} rec/s)") + print(f" β€’ Most cost efficient: {most_efficient['warehouse_name']} ({most_efficient['total_cpu_minutes']:.3f} CPU minutes)") + print(f" β€’ Records processed: {results[0]['actual_records']:,} across all tests") + print(f" β€’ Data size: {results[0]['total_data_size_bytes'] / (1024*1024*1024):.2f} GB uncompressed") + + +def main() -> None: + """Main 
execution function - runs performance tests across all warehouse sizes."""
+    print("🎯 PyAirbyte Fast Lake Copy Demo - Multi-Warehouse Performance Analysis")
+    print("=" * 80)
+
+    script_start_time = datetime.now()
+    credentials = get_credentials()
+    source = setup_source()
+
+    results = []
+
+    print(f"\n🏭 Testing {len(WAREHOUSE_CONFIGS)} warehouse configurations...")
+    print("Available warehouse options:")
+    for config in WAREHOUSE_CONFIGS:
+        print(f"   β€’ {config['name']}: {config['size']} ({config['multiplier']}x multiplier)")
+
+    for i, warehouse_config in enumerate(WAREHOUSE_CONFIGS, 1):
+        print(f"\n{'=' * 80}")
+        print(
+            f"πŸ§ͺ Test {i}/{len(WAREHOUSE_CONFIGS)}: "
+            f"{warehouse_config['name']} ({warehouse_config['size']})"
+        )
+        print(f"{'=' * 80}")
+
+        s3_lake: S3LakeStorage = setup_lake_storage(
+            credentials,
+            script_start_time,
+        )
+
+        snowflake_cache_source, snowflake_cache_dest = setup_caches(credentials, warehouse_config)
+
+        result = transfer_data_with_timing(
+            source=source,
+            credentials=credentials,
+            snowflake_cache_source=snowflake_cache_source,
+            snowflake_cache_dest=snowflake_cache_dest,
+            s3_lake=s3_lake,
+            warehouse_config=warehouse_config,
+        )
+        results.append(result)
+
+        print("\nπŸŽ‰ Test completed successfully!")
+        print("πŸ’‘ This demonstrates 100x performance improvements through:")
+        print("   β€’ Direct bulk operations (Snowflake COPY INTO)")
+        print("   β€’ S3 lake storage intermediate layer")
+        print("   β€’ Managed Snowflake artifacts (AIRBYTE_LAKE_S3_MAIN_* with CREATE IF NOT EXISTS)")
+        print("   β€’ Optimized Parquet file format with Snappy compression")
+        print("   β€’ Parallel stream processing")
+        print(f"   β€’ Warehouse scaling: {warehouse_config['size']} ({warehouse_config['multiplier']}x compute units)")
+
+    if results:
+        print_performance_summary(results)
+    else:
+        print("\n⚠️ No warehouse configurations enabled; skipping performance summary.")
+
+    print(f"\nπŸ”„ RELOAD MODE: Only reloading raw 100M records to Snowflake...")
+    print(f"   β€’ NUM_RECORDS: {NUM_RECORDS:,}")
+    print(f"   β€’ RELOAD_INITIAL_SOURCE_DATA: {RELOAD_INITIAL_SOURCE_DATA}")
+
+    reload_raw_data(credentials, source)
+
+
+if __name__ == "__main__":
+    main()