Skip to content

Commit 753c2ac

Browse files
feat(caches): add create_source_tables method to CacheBase class (do not merge) (#631)
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Aaron <AJ> Steers <[email protected]>
1 parent 0d4029a commit 753c2ac

File tree

2 files changed

+106
-1
lines changed

2 files changed

+106
-1
lines changed

airbyte/caches/base.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import annotations
55

66
from pathlib import Path
7-
from typing import IO, TYPE_CHECKING, Any, ClassVar, final
7+
from typing import IO, TYPE_CHECKING, Any, ClassVar, Literal, final
88

99
import pandas as pd
1010
import pyarrow as pa
@@ -34,6 +34,7 @@
3434
from airbyte.shared.sql_processor import SqlProcessorBase
3535
from airbyte.shared.state_providers import StateProviderBase
3636
from airbyte.shared.state_writers import StateWriterBase
37+
from airbyte.sources.base import Source
3738
from airbyte.strategies import WriteStrategy
3839

3940

@@ -293,6 +294,35 @@ def register_source(
293294
incoming_stream_names=stream_names,
294295
)
295296

297+
def create_source_tables(
    self,
    source: Source,
    streams: Literal["*"] | list[str] | None = None,
) -> None:
    """Create tables in the cache for the provided source if they do not exist already.

    Tables are created based upon the Source's catalog.

    Args:
        source: The source to create tables for.
        streams: Stream names to create tables for. If None, use the Source's selected_streams
            or "*" if neither is set. If "*", all available streams will be used.
    """
    # Resolve the effective stream selection: an explicit argument wins;
    # otherwise fall back to the source's selected streams, or "*" (all).
    selection: Literal["*"] | list[str] = (
        streams if streams is not None else (source.get_selected_streams() or "*")
    )

    configured_catalog = source.get_configured_catalog(streams=selection)
    provider = CatalogProvider(configured_catalog)

    # The destination schema must exist before any table DDL is issued.
    self.processor._ensure_schema_exists()  # noqa: SLF001 # Accessing non-public member

    # Create the final table for every stream in the catalog, if missing.
    for name in provider.stream_names:
        self.processor._ensure_final_table_exists(  # noqa: SLF001
            stream_name=name,
            create_if_missing=True,
        )
325+
296326
def __getitem__(self, stream: str) -> CachedDataset:
297327
"""Return a dataset by stream name."""
298328
return self.streams[stream]

tests/unit_tests/test_caches.py

+75
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pathlib import Path
55

6+
67
from airbyte.caches.base import CacheBase
78
from airbyte.caches.duckdb import DuckDBCache
89

@@ -60,3 +61,77 @@ def test_duck_db_cache_config_get_database_name_with_default_schema_name():
6061

6162
def test_duck_db_cache_config_inheritance_from_sql_cache_config_base():
6263
assert issubclass(DuckDBCache, CacheBase)
64+
65+
66+
def test_create_source_tables(mocker):
    """Test that the create_source_tables method correctly creates tables based on the source's catalog.

    Covers both call paths:
    - streams=None: the source's selected streams are used.
    - streams=[...]: the explicit list is used and selection is not consulted.
    """
    # Import here to avoid circular imports
    from airbyte_protocol.models import (
        ConfiguredAirbyteCatalog,
        ConfiguredAirbyteStream,
    )

    def make_stream(name: str) -> ConfiguredAirbyteStream:
        # Minimal full-refresh stream definition for the test catalog.
        return ConfiguredAirbyteStream(
            stream={
                "name": name,
                "json_schema": {},
                "supported_sync_modes": ["full_refresh"],
            },
            sync_mode="full_refresh",
            destination_sync_mode="overwrite",
        )

    # Create a proper ConfiguredAirbyteCatalog for mocking
    catalog = ConfiguredAirbyteCatalog(
        streams=[make_stream("stream1"), make_stream("stream2")],
    )

    # Mock the catalog provider. NOTE: patch the name where it is *used*
    # (airbyte.caches.base binds CatalogProvider via a from-import), not where
    # it is defined — patching "airbyte.shared.catalog_providers.CatalogProvider"
    # leaves base.py's binding untouched and the mock is never hit.
    mock_catalog_provider = mocker.Mock()
    mock_catalog_provider.stream_names = ["stream1", "stream2"]
    mocker.patch(
        "airbyte.caches.base.CatalogProvider",
        return_value=mock_catalog_provider,
    )

    # Mock a source with configured catalog and selected streams
    mock_source = mocker.Mock()
    mock_source.get_configured_catalog.return_value = catalog
    mock_source.get_selected_streams.return_value = ["stream1"]

    # Create a DuckDBCache instance with mocked processor
    cache = DuckDBCache(db_path=UNIT_TEST_DB_PATH)

    # Mock the processor property
    mock_processor = mocker.Mock()
    mocker.patch.object(
        DuckDBCache, "processor", mocker.PropertyMock(return_value=mock_processor)
    )

    # Test with default (None) stream parameter - should use source's selected streams
    cache.create_source_tables(mock_source)

    # Verify the correct methods were called
    mock_source.get_selected_streams.assert_called_once()
    mock_source.get_configured_catalog.assert_called_once_with(streams=["stream1"])
    mock_processor._ensure_schema_exists.assert_called_once()
    assert mock_processor._ensure_final_table_exists.call_count == 2

    # Reset mocks
    mock_source.reset_mock()
    mock_processor.reset_mock()

    # Test with explicit stream list - selection must not be consulted
    cache.create_source_tables(mock_source, streams=["stream2"])

    # Verify the correct methods were called
    mock_source.get_selected_streams.assert_not_called()
    mock_source.get_configured_catalog.assert_called_once_with(streams=["stream2"])
    mock_processor._ensure_schema_exists.assert_called_once()
    assert mock_processor._ensure_final_table_exists.call_count == 2

0 commit comments

Comments (0)