diff --git a/airbyte/_processors/sql/bigquery.py b/airbyte/_processors/sql/bigquery.py index f8651c86..d4d31a5e 100644 --- a/airbyte/_processors/sql/bigquery.py +++ b/airbyte/_processors/sql/bigquery.py @@ -5,7 +5,7 @@ import warnings from pathlib import Path -from typing import TYPE_CHECKING, cast, final +from typing import TYPE_CHECKING, final import google.oauth2 import sqlalchemy @@ -103,7 +103,7 @@ class BigQueryTypeConverter(SQLTypeConverter): @classmethod def get_string_type(cls) -> sqlalchemy.types.TypeEngine: """Return the string type for BigQuery.""" - return cast("sqlalchemy.types.TypeEngine", "String") # BigQuery uses STRING for all strings + return sqlalchemy.types.String() @overrides def to_sql_type( @@ -210,7 +210,7 @@ def _ensure_schema_exists( project = self.sql_config.project_name schema = self.sql_config.schema_name location = self.sql_config.dataset_location - sql = f"CREATE SCHEMA IF NOT EXISTS `{project}.{schema}` " f'OPTIONS(location="{location}")' + sql = f'CREATE SCHEMA IF NOT EXISTS `{project}.{schema}` OPTIONS(location="{location}")' try: self._execute_sql(sql) except Exception as ex: @@ -294,8 +294,7 @@ def _swap_temp_table_with_final_table( deletion_name = f"{final_table_name}_deleteme" commands = "\n".join( [ - f"ALTER TABLE {self._fully_qualified(final_table_name)} " - f"RENAME TO {deletion_name};", + f"ALTER TABLE {self._fully_qualified(final_table_name)} RENAME TO {deletion_name};", f"ALTER TABLE {self._fully_qualified(temp_table_name)} " f"RENAME TO {final_table_name};", f"DROP TABLE {self._fully_qualified(deletion_name)};", diff --git a/airbyte/shared/sql_processor.py b/airbyte/shared/sql_processor.py index a602078b..ab0bc157 100644 --- a/airbyte/shared/sql_processor.py +++ b/airbyte/shared/sql_processor.py @@ -585,8 +585,9 @@ def _create_table_for_loading( ) -> str: """Create a new table for loading data.""" temp_table_name = self._get_temp_table_name(stream_name, batch_id) + engine = self.get_sql_engine() column_definition_str = ",\n ".join( - f"{self._quote_identifier(column_name)} {sql_type}" + f"{self._quote_identifier(column_name)} {sql_type.compile(engine.dialect)}" for column_name, sql_type in self._get_sql_column_definitions(stream_name).items() ) self._create_table(temp_table_name, column_definition_str) @@ -634,9 +635,10 @@ def _ensure_final_table_exists( """ table_name = self.get_sql_table_name(stream_name) did_exist = self._table_exists(table_name) + engine = self.get_sql_engine() if not did_exist and create_if_missing: column_definition_str = ",\n ".join( - f"{self._quote_identifier(column_name)} {sql_type}" + f"{self._quote_identifier(column_name)} {sql_type.compile(engine.dialect)}" for column_name, sql_type in self._get_sql_column_definitions( stream_name, ).items() @@ -698,7 +700,7 @@ def _get_sql_column_definitions( ) columns[AB_RAW_ID_COLUMN] = self.type_converter_class.get_string_type() - columns[AB_EXTRACTED_AT_COLUMN] = sqlalchemy.TIMESTAMP() + columns[AB_EXTRACTED_AT_COLUMN] = sqlalchemy.TIMESTAMP(timezone=True) columns[AB_META_COLUMN] = self.type_converter_class.get_json_type() return columns diff --git a/airbyte/types.py b/airbyte/types.py index 4d78353f..859b9b56 100644 --- a/airbyte/types.py +++ b/airbyte/types.py @@ -18,8 +18,8 @@ "number": sqlalchemy.types.DECIMAL(38, 9), "boolean": sqlalchemy.types.BOOLEAN, "date": sqlalchemy.types.DATE, - "timestamp_with_timezone": sqlalchemy.types.TIMESTAMP, - "timestamp_without_timezone": sqlalchemy.types.TIMESTAMP, + "timestamp_with_timezone": sqlalchemy.types.TIMESTAMP(timezone=True), + "timestamp_without_timezone": sqlalchemy.types.TIMESTAMP(timezone=False), "time_with_timezone": sqlalchemy.types.TIME, "time_without_timezone": sqlalchemy.types.TIME, # Technically 'object' and 'array' as JSON Schema types, not airbyte types. @@ -151,7 +151,7 @@ def to_sql_type( # noqa: PLR0911 # Too many return statements return sqlalchemy.types.DATE() if json_schema_type == "string" and json_schema_format == "date-time": - return sqlalchemy.types.TIMESTAMP() + return sqlalchemy.types.TIMESTAMP(timezone=True) if json_schema_type == "array": return sqlalchemy.types.JSON() diff --git a/tests/integration_tests/test_all_cache_types.py b/tests/integration_tests/test_all_cache_types.py index 944db3bf..7a01f5ed 100644 --- a/tests/integration_tests/test_all_cache_types.py +++ b/tests/integration_tests/test_all_cache_types.py @@ -12,16 +12,18 @@ import os import sys from pathlib import Path +from unittest.mock import patch import airbyte as ab import pytest from airbyte import get_source +from airbyte._processors.sql.duckdb import DuckDBConfig, DuckDBSqlProcessor from airbyte._util.venv_util import get_bin_dir from airbyte.results import ReadResult +from airbyte.shared.catalog_providers import CatalogProvider from sqlalchemy import text from viztracer import VizTracer - # Product count is always the same, regardless of faker scale. NUM_PRODUCTS = 100 @@ -284,3 +286,46 @@ def test_auto_add_columns( result = source_faker_seed_a.read(cache=new_generic_cache, write_strategy="auto") assert "_airbyte_raw_id" in result["users"].to_sql_table().columns + + +@pytest.mark.slow +def test_cache_columns_for_datetime_types_are_timezone_aware(): + """Ensures sql types are correctly converted to the correct sql timezone aware column types""" + expected_sql = """ + CREATE TABLE airbyte."products" ( + "id" BIGINT, + "make" VARCHAR, + "model" VARCHAR, + "year" BIGINT, + "price" DECIMAL(38, 9), + "created_at" TIMESTAMP WITH TIME ZONE, + "updated_at" TIMESTAMP WITH TIME ZONE, + "_airbyte_raw_id" VARCHAR, + "_airbyte_extracted_at" TIMESTAMP WITH TIME ZONE, + "_airbyte_meta" JSON + ) + \n """ + source = get_source( + name="source-faker", + config={}, + ) + + config = DuckDBConfig( + schema_name="airbyte", + db_path=":memory:", + ) + + processor = DuckDBSqlProcessor( + catalog_provider=CatalogProvider(source.configured_catalog), + temp_dir=Path(), + temp_file_cleanup=True, + sql_config=config, + ) + + with ( + patch.object(processor, "_execute_sql") as _execute_sql_mock, + ): + processor._ensure_final_table_exists( + stream_name="products", + ) + _execute_sql_mock.assert_called_with(expected_sql) diff --git a/tests/unit_tests/test_type_translation.py b/tests/unit_tests/test_type_translation.py index 84b0368a..3835a59a 100644 --- a/tests/unit_tests/test_type_translation.py +++ b/tests/unit_tests/test_type_translation.py @@ -25,7 +25,7 @@ "format": "date-time", "airbyte_type": "timestamp_without_timezone", }, - types.TIMESTAMP, + types.TIMESTAMP(timezone=False), ), ( { @@ -33,7 +33,7 @@ "format": "date-time", "airbyte_type": "timestamp_with_timezone", }, - types.TIMESTAMP, + types.TIMESTAMP(timezone=True), ), ( { @@ -68,7 +68,11 @@ def test_to_sql_type(json_schema_property_def, expected_sql_type): converter = SQLTypeConverter() sql_type = converter.to_sql_type(json_schema_property_def) - assert isinstance(sql_type, expected_sql_type) + if isinstance(expected_sql_type, types.TIMESTAMP): + assert isinstance(sql_type, types.TIMESTAMP) + assert sql_type.timezone == expected_sql_type.timezone + else: + assert isinstance(sql_type, expected_sql_type) @pytest.mark.parametrize(