|
18 | 18 | from sqlalchemy.sql.schema import Column, MetaData, Table |
19 | 19 | from sqlalchemy.sql.selectable import TextualSelect |
20 | 20 |
|
21 | | -from pyathena.sqlalchemy.types import TINYINT, AthenaArray, AthenaMap, AthenaStruct, Tinyint |
| 21 | +from pyathena.sqlalchemy.types import ( |
| 22 | + TINYINT, |
| 23 | + AthenaArray, |
| 24 | + AthenaMap, |
| 25 | + AthenaStruct, |
| 26 | + Tinyint, |
| 27 | + get_double_type, |
| 28 | +) |
22 | 29 | from tests.pyathena.conftest import ENV |
23 | 30 |
|
24 | 31 |
|
@@ -507,8 +514,7 @@ def test_reflect_select(self, engine): |
507 | 514 | assert isinstance(one_row_complex.c.col_int.type, types.INTEGER) |
508 | 515 | assert isinstance(one_row_complex.c.col_bigint.type, types.BIGINT) |
509 | 516 | assert isinstance(one_row_complex.c.col_float.type, types.FLOAT) |
510 | | - # TODO: types.DOUBLE is not defined in SQLAlchemy 1.4. |
511 | | - assert isinstance(one_row_complex.c.col_double.type, types.FLOAT) |
| 517 | + assert isinstance(one_row_complex.c.col_double.type, get_double_type()) |
512 | 518 | assert isinstance(one_row_complex.c.col_string.type, types.String) |
513 | 519 | assert isinstance(one_row_complex.c.col_varchar.type, types.VARCHAR) |
514 | 520 | assert one_row_complex.c.col_varchar.type.length == 10 |
@@ -558,8 +564,7 @@ def test_get_column_type(self, engine): |
558 | 564 | assert isinstance(dialect._get_column_type("int"), types.INTEGER) |
559 | 565 | assert isinstance(dialect._get_column_type("bigint"), types.BIGINT) |
560 | 566 | assert isinstance(dialect._get_column_type("float"), types.FLOAT) |
561 | | - # TODO: types.DOUBLE is not defined in SQLAlchemy 1.4. |
562 | | - assert isinstance(dialect._get_column_type("double"), types.FLOAT) |
| 567 | + assert isinstance(dialect._get_column_type("double"), get_double_type()) |
563 | 568 | assert isinstance(dialect._get_column_type("real"), types.FLOAT) |
564 | 569 | assert isinstance(dialect._get_column_type("string"), types.String) |
565 | 570 | assert isinstance(dialect._get_column_type("varchar"), types.VARCHAR) |
@@ -2145,30 +2150,12 @@ def test_numeric_type_variants(self, engine): |
2145 | 2150 | assert type(actual.c.col_integer2.type) in [types.INT, types.INTEGER, types.Integer] |
2146 | 2151 | assert type(actual.c.col_bigint.type) in [types.BIGINT, types.BigInteger] |
2147 | 2152 | assert type(actual.c.col_biginteger.type) in [types.BIGINT, types.BigInteger] |
2148 | | - assert type(actual.c.col_double1.type) in [types.FLOAT, types.Float] |
2149 | | - assert type(actual.c.col_double2.type) in [types.FLOAT, types.Float] |
2150 | | - assert type(actual.c.col_double_precision.type) in [types.FLOAT, types.Float] |
2151 | | - # TODO: types.DOUBLE is not defined in SQLAlchemy 1.4. |
2152 | | - # assert type(actual.c.col_double_precision.type) in [ |
2153 | | - # types.DOUBLE, |
2154 | | - # types.Double, |
2155 | | - # types.DOUBLE_PRECISION, |
2156 | | - # ] |
2157 | | - # assert type(actual.c.col_double1.type) in [ |
2158 | | - # types.DOUBLE, |
2159 | | - # types.Double, |
2160 | | - # types.DOUBLE_PRECISION, |
2161 | | - # ] |
2162 | | - # assert type(actual.c.col_double2.type) in [ |
2163 | | - # types.DOUBLE, |
2164 | | - # types.Double, |
2165 | | - # types.DOUBLE_PRECISION, |
2166 | | - # ] |
2167 | | - # assert type(actual.c.col_double_precision.type) in [ |
2168 | | - # types.DOUBLE, |
2169 | | - # types.Double, |
2170 | | - # types.DOUBLE_PRECISION, |
2171 | | - # ] |
| 2153 | + expected_double_types = [types.FLOAT, types.Float] |
| 2154 | + if hasattr(types, "DOUBLE"): |
| 2155 | + expected_double_types.extend([types.DOUBLE, types.Double, types.DOUBLE_PRECISION]) |
| 2156 | + assert type(actual.c.col_double1.type) in expected_double_types |
| 2157 | + assert type(actual.c.col_double2.type) in expected_double_types |
| 2158 | + assert type(actual.c.col_double_precision.type) in expected_double_types |
2172 | 2159 | assert type(actual.c.col_float1.type) in [types.FLOAT, types.Float] |
2173 | 2160 | assert type(actual.c.col_float2.type) in [types.FLOAT, types.Float] |
2174 | 2161 | assert type(actual.c.col_decimal.type) in [types.DECIMAL] |
|
0 commit comments