Skip to content

Commit e2ecf3b

Browse files
Return DOUBLE type for double columns in SQLAlchemy 2.0+
Add get_double_type() helper function that returns types.DOUBLE for SQLAlchemy 2.0+ and falls back to types.FLOAT for earlier versions. This fixes the issue where inspector.get_columns() returns FLOAT() instead of DOUBLE() for columns with 'double' data type. Changes: - Add get_double_type() helper in pyathena/sqlalchemy/types.py - Update ischema_names in pyathena/sqlalchemy/base.py - Update tests for version-compatible assertions - Add .env to .gitignore and document usage in CLAUDE.md - Fix import sorting in benchmark file Closes #647 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 41ff63e commit e2ecf3b

File tree

6 files changed

+76
-40
lines changed

6 files changed

+76
-40
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ docs/_build
2424
.mcp.json
2525
.serena
2626
_version.py
27+
.env

CLAUDE.md

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,18 @@ def method_name(self, param1: str, param2: Optional[int] = None) -> List[str]:
134134
To run tests locally, you need to set the following environment variables:
135135

136136
```bash
137-
export AWS_DEFAULT_REGION=us-west-2
138-
export AWS_ATHENA_S3_STAGING_DIR=s3://your-staging-bucket/path/
139-
export AWS_ATHENA_WORKGROUP=primary
140-
export AWS_ATHENA_SPARK_WORKGROUP=spark-primary
137+
export AWS_DEFAULT_REGION=<your-region>
138+
export AWS_ATHENA_S3_STAGING_DIR=s3://<your-bucket>/<path>/
139+
export AWS_ATHENA_WORKGROUP=<your-workgroup>
140+
export AWS_ATHENA_SPARK_WORKGROUP=<your-spark-workgroup>
141+
```
142+
143+
**Using .env file (Recommended)**:
144+
Create a `.env` file in the project root (already in `.gitignore`) with your AWS settings, then load it before running tests:
145+
146+
```bash
147+
# Load .env and run tests
148+
export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v
141149
```
142150

143151
**CRITICAL: Pre-test Requirements**
@@ -147,7 +155,7 @@ make chk
147155

148156
# Only after lint passes, install dependencies and run tests
149157
uv sync
150-
uv run pytest tests/pyathena/test_file.py -v
158+
export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v
151159
```
152160

153161
#### Writing Tests

pyathena/sqlalchemy/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@
3737
AthenaTypeCompiler,
3838
)
3939
from pyathena.sqlalchemy.preparer import AthenaDMLIdentifierPreparer
40-
from pyathena.sqlalchemy.types import TINYINT, AthenaDate, AthenaStruct, AthenaTimestamp
40+
from pyathena.sqlalchemy.types import (
41+
TINYINT,
42+
AthenaDate,
43+
AthenaStruct,
44+
AthenaTimestamp,
45+
get_double_type,
46+
)
4147
from pyathena.sqlalchemy.util import _HashableDict
4248
from pyathena.util import strtobool
4349

@@ -61,8 +67,7 @@
6167
ischema_names: Dict[str, Type[Any]] = {
6268
"boolean": types.BOOLEAN,
6369
"float": types.FLOAT,
64-
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
65-
"double": types.FLOAT,
70+
"double": get_double_type(),
6671
"real": types.FLOAT,
6772
"tinyint": TINYINT,
6873
"smallint": types.SMALLINT,

pyathena/sqlalchemy/types.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from __future__ import annotations
33

44
from datetime import date, datetime
5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
66

7+
from sqlalchemy import types
78
from sqlalchemy.sql import sqltypes
89
from sqlalchemy.sql.type_api import TypeEngine
910

@@ -12,6 +13,21 @@
1213
from sqlalchemy.sql.type_api import _LiteralProcessorType
1314

1415

16+
def get_double_type() -> Type[Any]:
17+
"""Get the appropriate type for DOUBLE based on SQLAlchemy version.
18+
19+
SQLAlchemy 2.0+ provides a native DOUBLE type, while earlier versions
20+
only have FLOAT. This function returns the appropriate type based on
21+
what's available.
22+
23+
Returns:
24+
types.DOUBLE for SQLAlchemy 2.0+, types.FLOAT for earlier versions.
25+
"""
26+
if hasattr(types, "DOUBLE"):
27+
return types.DOUBLE
28+
return types.FLOAT
29+
30+
1531
class AthenaTimestamp(TypeEngine[datetime]):
1632
"""SQLAlchemy type for Athena TIMESTAMP values.
1733

tests/pyathena/sqlalchemy/test_base.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
from sqlalchemy.sql.schema import Column, MetaData, Table
1919
from sqlalchemy.sql.selectable import TextualSelect
2020

21-
from pyathena.sqlalchemy.types import TINYINT, AthenaArray, AthenaMap, AthenaStruct, Tinyint
21+
from pyathena.sqlalchemy.types import (
22+
TINYINT,
23+
AthenaArray,
24+
AthenaMap,
25+
AthenaStruct,
26+
Tinyint,
27+
get_double_type,
28+
)
2229
from tests.pyathena.conftest import ENV
2330

2431

@@ -507,8 +514,7 @@ def test_reflect_select(self, engine):
507514
assert isinstance(one_row_complex.c.col_int.type, types.INTEGER)
508515
assert isinstance(one_row_complex.c.col_bigint.type, types.BIGINT)
509516
assert isinstance(one_row_complex.c.col_float.type, types.FLOAT)
510-
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
511-
assert isinstance(one_row_complex.c.col_double.type, types.FLOAT)
517+
assert isinstance(one_row_complex.c.col_double.type, get_double_type())
512518
assert isinstance(one_row_complex.c.col_string.type, types.String)
513519
assert isinstance(one_row_complex.c.col_varchar.type, types.VARCHAR)
514520
assert one_row_complex.c.col_varchar.type.length == 10
@@ -558,8 +564,7 @@ def test_get_column_type(self, engine):
558564
assert isinstance(dialect._get_column_type("int"), types.INTEGER)
559565
assert isinstance(dialect._get_column_type("bigint"), types.BIGINT)
560566
assert isinstance(dialect._get_column_type("float"), types.FLOAT)
561-
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
562-
assert isinstance(dialect._get_column_type("double"), types.FLOAT)
567+
assert isinstance(dialect._get_column_type("double"), get_double_type())
563568
assert isinstance(dialect._get_column_type("real"), types.FLOAT)
564569
assert isinstance(dialect._get_column_type("string"), types.String)
565570
assert isinstance(dialect._get_column_type("varchar"), types.VARCHAR)
@@ -2145,30 +2150,12 @@ def test_numeric_type_variants(self, engine):
21452150
assert type(actual.c.col_integer2.type) in [types.INT, types.INTEGER, types.Integer]
21462151
assert type(actual.c.col_bigint.type) in [types.BIGINT, types.BigInteger]
21472152
assert type(actual.c.col_biginteger.type) in [types.BIGINT, types.BigInteger]
2148-
assert type(actual.c.col_double1.type) in [types.FLOAT, types.Float]
2149-
assert type(actual.c.col_double2.type) in [types.FLOAT, types.Float]
2150-
assert type(actual.c.col_double_precision.type) in [types.FLOAT, types.Float]
2151-
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
2152-
# assert type(actual.c.col_double_precision.type) in [
2153-
# types.DOUBLE,
2154-
# types.Double,
2155-
# types.DOUBLE_PRECISION,
2156-
# ]
2157-
# assert type(actual.c.col_double1.type) in [
2158-
# types.DOUBLE,
2159-
# types.Double,
2160-
# types.DOUBLE_PRECISION,
2161-
# ]
2162-
# assert type(actual.c.col_double2.type) in [
2163-
# types.DOUBLE,
2164-
# types.Double,
2165-
# types.DOUBLE_PRECISION,
2166-
# ]
2167-
# assert type(actual.c.col_double_precision.type) in [
2168-
# types.DOUBLE,
2169-
# types.Double,
2170-
# types.DOUBLE_PRECISION,
2171-
# ]
2153+
expected_double_types = [types.FLOAT, types.Float]
2154+
if hasattr(types, "DOUBLE"):
2155+
expected_double_types.extend([types.DOUBLE, types.Double, types.DOUBLE_PRECISION])
2156+
assert type(actual.c.col_double1.type) in expected_double_types
2157+
assert type(actual.c.col_double2.type) in expected_double_types
2158+
assert type(actual.c.col_double_precision.type) in expected_double_types
21722159
assert type(actual.c.col_float1.type) in [types.FLOAT, types.Float]
21732160
assert type(actual.c.col_float2.type) in [types.FLOAT, types.Float]
21742161
assert type(actual.c.col_decimal.type) in [types.DECIMAL]

tests/pyathena/sqlalchemy/test_types.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
# -*- coding: utf-8 -*-
22
import pytest
3-
from sqlalchemy import Integer, String
3+
from sqlalchemy import Integer, String, types
44
from sqlalchemy.sql import sqltypes
55

6-
from pyathena.sqlalchemy.types import ARRAY, MAP, STRUCT, AthenaArray, AthenaMap, AthenaStruct
6+
from pyathena.sqlalchemy.types import (
7+
ARRAY,
8+
MAP,
9+
STRUCT,
10+
AthenaArray,
11+
AthenaMap,
12+
AthenaStruct,
13+
get_double_type,
14+
)
715

816

917
class TestAthenaStruct:
@@ -145,3 +153,14 @@ def test_array_with_map_type(self):
145153
assert isinstance(array_type.item_type, AthenaMap)
146154
assert isinstance(array_type.item_type.key_type, sqltypes.String)
147155
assert isinstance(array_type.item_type.value_type, sqltypes.Integer)
156+
157+
158+
def test_get_double_type():
159+
from pyathena.sqlalchemy.base import ischema_names
160+
161+
result = get_double_type()
162+
if hasattr(types, "DOUBLE"):
163+
assert result is types.DOUBLE
164+
else:
165+
assert result is types.FLOAT
166+
assert ischema_names["double"] is result

0 commit comments

Comments
 (0)