|
6 | 6 | from bson import json_util |
7 | 7 | from fastapi import BackgroundTasks, HTTPException |
8 | 8 | from overrides import override |
| 9 | +from sqlalchemy import MetaData, inspect |
9 | 10 |
|
10 | 11 | from dataherald.api import API |
11 | 12 | from dataherald.api.types import Query |
12 | 13 | from dataherald.config import System |
13 | 14 | from dataherald.context_store import ContextStore |
14 | 15 | from dataherald.db import DB |
15 | 16 | from dataherald.db_scanner import Scanner |
16 | | -from dataherald.db_scanner.models.types import TableSchemaDetail |
| 17 | +from dataherald.db_scanner.models.types import TableDescriptionStatus, TableSchemaDetail |
17 | 18 | from dataherald.db_scanner.repository.base import DBScannerRepository |
18 | 19 | from dataherald.eval import Evaluator |
19 | 20 | from dataherald.repositories.base import NLQueryResponseRepository |
@@ -221,10 +222,36 @@ def list_table_descriptions( |
221 | 222 | self, db_connection_id: str | None = None, table_name: str | None = None |
222 | 223 | ) -> list[TableSchemaDetail]: |
223 | 224 | scanner_repository = DBScannerRepository(self.storage) |
224 | | - return scanner_repository.find_by( |
| 225 | + table_descriptions = scanner_repository.find_by( |
225 | 226 | {"db_connection_id": db_connection_id, "table_name": table_name} |
226 | 227 | ) |
227 | 228 |
|
| 229 | + if db_connection_id: |
| 230 | + db_connection_repository = DatabaseConnectionRepository(self.storage) |
| 231 | + db_connection = db_connection_repository.find_by_id(db_connection_id) |
| 232 | + database = SQLDatabase.get_sql_engine(db_connection) |
| 233 | + inspector = inspect(database.engine) |
| 234 | + meta = MetaData(bind=database.engine) |
| 235 | + MetaData.reflect(meta, views=True) |
| 236 | + all_tables = inspector.get_table_names() + inspector.get_view_names() |
| 237 | + |
| 238 | + for table_description in table_descriptions: |
| 239 | + if table_description.table_name not in all_tables: |
| 240 | + table_description.status = TableDescriptionStatus.DEPRECATED.value |
| 241 | + else: |
| 242 | + all_tables.remove(table_description.table_name) |
| 243 | + for table in all_tables: |
| 244 | + table_descriptions.append( |
| 245 | + TableSchemaDetail( |
| 246 | + table_name=table, |
| 247 | + status=TableDescriptionStatus.NOT_SYNCHRONIZED.value, |
| 248 | + db_connection_id=db_connection_id, |
| 249 | + columns=[], |
| 250 | + ) |
| 251 | + ) |
| 252 | + |
| 253 | + return table_descriptions |
| 254 | + |
228 | 255 | @override |
229 | 256 | def add_golden_records( |
230 | 257 | self, golden_records: List[GoldenRecordRequest] |
|
0 commit comments