|
20 | 20 | # software solely pursuant to the terms of the relevant commercial agreement.
|
21 | 21 |
|
22 | 22 | from datetime import datetime
|
23 |
| -from unittest import TestCase |
| 23 | +from unittest import TestCase, skipIf |
24 | 24 | from unittest.mock import MagicMock, patch
|
25 | 25 |
|
26 | 26 | import sqlalchemy as sa
|
27 | 27 |
|
28 | 28 | from crate.client.cursor import Cursor
|
| 29 | +from crate.client.sqlalchemy import SA_VERSION |
| 30 | +from crate.client.sqlalchemy.sa_version import SA_1_4, SA_2_0 |
29 | 31 | from crate.client.sqlalchemy.types import Object
|
30 | 32 | from sqlalchemy import inspect
|
31 | 33 | from sqlalchemy.orm import Session
|
32 | 34 | try:
|
33 | 35 | from sqlalchemy.orm import declarative_base
|
34 | 36 | except ImportError:
|
35 | 37 | from sqlalchemy.ext.declarative import declarative_base
|
36 |
| -from sqlalchemy.testing import eq_, in_ |
| 38 | +from sqlalchemy.testing import eq_, in_, is_true |
37 | 39 |
|
38 | 40 | FakeCursor = MagicMock(name='FakeCursor', spec=Cursor)
|
39 | 41 |
|
@@ -70,6 +72,13 @@ class Character(self.base):
|
70 | 72 |
|
71 | 73 | self.session = Session(bind=self.engine)
|
72 | 74 |
|
| 75 | + def init_mock(self, return_value=None): |
| 76 | + self.fake_cursor.rowcount = 1 |
| 77 | + self.fake_cursor.description = ( |
| 78 | + ('foo', None, None, None, None, None, None), |
| 79 | + ) |
| 80 | + self.fake_cursor.fetchall = MagicMock(return_value=return_value) |
| 81 | + |
73 | 82 | def test_primary_keys_2_3_0(self):
|
74 | 83 | insp = inspect(self.session.bind)
|
75 | 84 | self.engine.dialect.server_version_info = (2, 3, 0)
|
@@ -126,3 +135,22 @@ def test_get_view_names(self):
|
126 | 135 | ['v1', 'v2'])
|
127 | 136 | eq_(self.executed_statement, "SELECT table_name FROM information_schema.views "
|
128 | 137 | "ORDER BY table_name ASC, table_schema ASC")
|
| 138 | + |
| 139 | + @skipIf(SA_VERSION < SA_1_4, "Inspector.has_table only available on SQLAlchemy>=1.4") |
| 140 | + def test_has_table(self): |
| 141 | + self.init_mock(return_value=[["foo"], ["bar"]]) |
| 142 | + insp = inspect(self.session.bind) |
| 143 | + is_true(insp.has_table("bar")) |
| 144 | + eq_(self.executed_statement, |
| 145 | + "SELECT table_name FROM information_schema.tables " |
| 146 | + "WHERE table_schema = ? AND table_type = 'BASE TABLE' " |
| 147 | + "ORDER BY table_name ASC, table_schema ASC") |
| 148 | + |
| 149 | + @skipIf(SA_VERSION < SA_2_0, "Inspector.has_schema only available on SQLAlchemy>=2.0") |
| 150 | + def test_has_schema(self): |
| 151 | + self.init_mock( |
| 152 | + return_value=[["blob"], ["doc"], ["information_schema"], ["pg_catalog"], ["sys"]]) |
| 153 | + insp = inspect(self.session.bind) |
| 154 | + is_true(insp.has_schema("doc")) |
| 155 | + eq_(self.executed_statement, |
| 156 | + "select schema_name from information_schema.schemata order by schema_name asc") |
0 commit comments