Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ docs/_build
.mcp.json
.serena
_version.py
.env
18 changes: 13 additions & 5 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,18 @@ def method_name(self, param1: str, param2: Optional[int] = None) -> List[str]:
To run tests locally, you need to set the following environment variables:

```bash
export AWS_DEFAULT_REGION=us-west-2
export AWS_ATHENA_S3_STAGING_DIR=s3://your-staging-bucket/path/
export AWS_ATHENA_WORKGROUP=primary
export AWS_ATHENA_SPARK_WORKGROUP=spark-primary
export AWS_DEFAULT_REGION=<your-region>
export AWS_ATHENA_S3_STAGING_DIR=s3://<your-bucket>/<path>/
export AWS_ATHENA_WORKGROUP=<your-workgroup>
export AWS_ATHENA_SPARK_WORKGROUP=<your-spark-workgroup>
```

**Using a `.env` file (Recommended)**:
Create a `.env` file in the project root (already in `.gitignore`) with your AWS settings, then load it before running tests:

```bash
# Load .env and run tests
export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v
```

**CRITICAL: Pre-test Requirements**
Expand All @@ -147,7 +155,7 @@ make chk

# Only after lint passes, install dependencies and run tests
uv sync
uv run pytest tests/pyathena/test_file.py -v
export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v
```

#### Writing Tests
Expand Down
14 changes: 7 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
RUFF_VERSION := 0.9.1
TOX_VERSION := 4.23.2
RUFF_VERSION := 0.14.14
TOX_VERSION := 4.34.1

.PHONY: fmt
fmt:
# TODO: https://github.com/astral-sh/uv/issues/5903
uvx ruff check --select I --fix .
uvx ruff format .
uvx ruff@$(RUFF_VERSION) check --select I --fix .
uvx ruff@$(RUFF_VERSION) format .

.PHONY: chk
chk:
uvx ruff check .
uvx ruff format --check .
uvx ruff@$(RUFF_VERSION) check .
uvx ruff@$(RUFF_VERSION) format --check .
uv run mypy .

.PHONY: test
Expand All @@ -23,7 +23,7 @@ test-sqla:

.PHONY: tox
tox:
uvx tox run
uvx tox@$(TOX_VERSION) -c pyproject.toml run

.PHONY: docs
docs:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/20180915/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import sys
import time

from pyathena.pandas_cursor import PandasCursor
from pyathenajdbc import connect as jdbc_connect

from pyathena import connect
from pyathena.pandas_cursor import PandasCursor

LOGGER = logging.getLogger(__name__)
LOGGER.addHandler(logging.StreamHandler(sys.stdout))
Expand Down
11 changes: 8 additions & 3 deletions pyathena/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
AthenaTypeCompiler,
)
from pyathena.sqlalchemy.preparer import AthenaDMLIdentifierPreparer
from pyathena.sqlalchemy.types import TINYINT, AthenaDate, AthenaStruct, AthenaTimestamp
from pyathena.sqlalchemy.types import (
TINYINT,
AthenaDate,
AthenaStruct,
AthenaTimestamp,
get_double_type,
)
from pyathena.sqlalchemy.util import _HashableDict
from pyathena.util import strtobool

Expand All @@ -61,8 +67,7 @@
ischema_names: Dict[str, Type[Any]] = {
"boolean": types.BOOLEAN,
"float": types.FLOAT,
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
"double": types.FLOAT,
"double": get_double_type(),
"real": types.FLOAT,
"tinyint": TINYINT,
"smallint": types.SMALLINT,
Expand Down
18 changes: 17 additions & 1 deletion pyathena/sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

from sqlalchemy import types
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine

Expand All @@ -12,6 +13,21 @@
from sqlalchemy.sql.type_api import _LiteralProcessorType


def get_double_type() -> Type[Any]:
    """Resolve the SQLAlchemy type class used for Athena ``double`` columns.

    SQLAlchemy 2.0+ exposes a native ``DOUBLE`` type in ``sqlalchemy.types``;
    earlier releases only offer ``FLOAT``. The lookup is done by attribute
    presence rather than by parsing a version string, so whichever type the
    installed SQLAlchemy actually provides is the one returned.

    Returns:
        ``types.DOUBLE`` when the installed SQLAlchemy defines it,
        otherwise ``types.FLOAT``.
    """
    # Fall back to FLOAT when the DOUBLE attribute is absent (SQLAlchemy 1.4).
    return getattr(types, "DOUBLE", types.FLOAT)


class AthenaTimestamp(TypeEngine[datetime]):
"""SQLAlchemy type for Athena TIMESTAMP values.

Expand Down
45 changes: 16 additions & 29 deletions tests/pyathena/sqlalchemy/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
from sqlalchemy.sql.schema import Column, MetaData, Table
from sqlalchemy.sql.selectable import TextualSelect

from pyathena.sqlalchemy.types import TINYINT, AthenaArray, AthenaMap, AthenaStruct, Tinyint
from pyathena.sqlalchemy.types import (
TINYINT,
AthenaArray,
AthenaMap,
AthenaStruct,
Tinyint,
get_double_type,
)
from tests.pyathena.conftest import ENV


Expand Down Expand Up @@ -507,8 +514,7 @@ def test_reflect_select(self, engine):
assert isinstance(one_row_complex.c.col_int.type, types.INTEGER)
assert isinstance(one_row_complex.c.col_bigint.type, types.BIGINT)
assert isinstance(one_row_complex.c.col_float.type, types.FLOAT)
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
assert isinstance(one_row_complex.c.col_double.type, types.FLOAT)
assert isinstance(one_row_complex.c.col_double.type, get_double_type())
assert isinstance(one_row_complex.c.col_string.type, types.String)
assert isinstance(one_row_complex.c.col_varchar.type, types.VARCHAR)
assert one_row_complex.c.col_varchar.type.length == 10
Expand Down Expand Up @@ -558,8 +564,7 @@ def test_get_column_type(self, engine):
assert isinstance(dialect._get_column_type("int"), types.INTEGER)
assert isinstance(dialect._get_column_type("bigint"), types.BIGINT)
assert isinstance(dialect._get_column_type("float"), types.FLOAT)
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
assert isinstance(dialect._get_column_type("double"), types.FLOAT)
assert isinstance(dialect._get_column_type("double"), get_double_type())
assert isinstance(dialect._get_column_type("real"), types.FLOAT)
assert isinstance(dialect._get_column_type("string"), types.String)
assert isinstance(dialect._get_column_type("varchar"), types.VARCHAR)
Expand Down Expand Up @@ -2145,30 +2150,12 @@ def test_numeric_type_variants(self, engine):
assert type(actual.c.col_integer2.type) in [types.INT, types.INTEGER, types.Integer]
assert type(actual.c.col_bigint.type) in [types.BIGINT, types.BigInteger]
assert type(actual.c.col_biginteger.type) in [types.BIGINT, types.BigInteger]
assert type(actual.c.col_double1.type) in [types.FLOAT, types.Float]
assert type(actual.c.col_double2.type) in [types.FLOAT, types.Float]
assert type(actual.c.col_double_precision.type) in [types.FLOAT, types.Float]
# TODO: types.DOUBLE is not defined in SQLAlchemy 1.4.
# assert type(actual.c.col_double_precision.type) in [
# types.DOUBLE,
# types.Double,
# types.DOUBLE_PRECISION,
# ]
# assert type(actual.c.col_double1.type) in [
# types.DOUBLE,
# types.Double,
# types.DOUBLE_PRECISION,
# ]
# assert type(actual.c.col_double2.type) in [
# types.DOUBLE,
# types.Double,
# types.DOUBLE_PRECISION,
# ]
# assert type(actual.c.col_double_precision.type) in [
# types.DOUBLE,
# types.Double,
# types.DOUBLE_PRECISION,
# ]
expected_double_types = [types.FLOAT, types.Float]
if hasattr(types, "DOUBLE"):
expected_double_types.extend([types.DOUBLE, types.Double, types.DOUBLE_PRECISION])
assert type(actual.c.col_double1.type) in expected_double_types
assert type(actual.c.col_double2.type) in expected_double_types
assert type(actual.c.col_double_precision.type) in expected_double_types
assert type(actual.c.col_float1.type) in [types.FLOAT, types.Float]
assert type(actual.c.col_float2.type) in [types.FLOAT, types.Float]
assert type(actual.c.col_decimal.type) in [types.DECIMAL]
Expand Down
23 changes: 21 additions & 2 deletions tests/pyathena/sqlalchemy/test_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
# -*- coding: utf-8 -*-
import pytest
from sqlalchemy import Integer, String
from sqlalchemy import Integer, String, types
from sqlalchemy.sql import sqltypes

from pyathena.sqlalchemy.types import ARRAY, MAP, STRUCT, AthenaArray, AthenaMap, AthenaStruct
from pyathena.sqlalchemy.types import (
ARRAY,
MAP,
STRUCT,
AthenaArray,
AthenaMap,
AthenaStruct,
get_double_type,
)


class TestAthenaStruct:
Expand Down Expand Up @@ -145,3 +153,14 @@ def test_array_with_map_type(self):
assert isinstance(array_type.item_type, AthenaMap)
assert isinstance(array_type.item_type.key_type, sqltypes.String)
assert isinstance(array_type.item_type.value_type, sqltypes.Integer)


def test_get_double_type():
    """Check that ``get_double_type`` picks DOUBLE when available, else FLOAT.

    Also verifies that the dialect's ``ischema_names`` mapping for ``double``
    is the very same class object the helper returns.
    """
    # Imported locally, as in the original test, to keep the dialect module
    # out of this test module's import-time dependencies.
    from pyathena.sqlalchemy.base import ischema_names

    expected_type = types.DOUBLE if hasattr(types, "DOUBLE") else types.FLOAT
    resolved_type = get_double_type()

    assert resolved_type is expected_type
    assert ischema_names["double"] is resolved_type