Skip to content

Commit 44565b1

Browse files
authored
Merge branch 'main' into fix/reduce-lambda-layer-size-icu
2 parents 17b7cf6 + 2234f5e commit 44565b1

4 files changed

Lines changed: 267 additions & 198 deletions

File tree

awswrangler/_data_types.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,8 @@ def pyarrow2athena( # noqa: PLR0911,PLR0912
5959
)
6060
if pa.types.is_map(dtype):
6161
return f"map<{pyarrow2athena(dtype=dtype.key_type, ignore_null=ignore_null)},{pyarrow2athena(dtype=dtype.item_type, ignore_null=ignore_null)}>"
62+
if isinstance(dtype, getattr(pa, "BaseExtensionType", pa.ExtensionType)):
63+
return pyarrow2athena(dtype=dtype.storage_type, ignore_null=ignore_null)
6264
if dtype == pa.null():
6365
if ignore_null:
6466
return ""

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ redshift = ["redshift-connector>=2.0.0,<3"]
3333
mysql = ["pymysql>=1.0.0,<2"]
3434
postgres = ["pg8000>=1.29.0,<2"]
3535
sqlserver = ["pyodbc>=4,<6"]
36-
oracle = ["oracledb>=1,<4"]
36+
oracle = ["oracledb>=1,<5"]
3737
gremlin = [
3838
"gremlinpython>=3.7.1,<4",
3939
"requests>=2.33.0,<3",
@@ -52,7 +52,7 @@ opensearch = [
5252
]
5353
openpyxl = ["openpyxl>=3.0.0,<4"]
5454
progressbar = ["progressbar2>=4.0.0,<5"]
55-
deltalake = ["deltalake>=0.18.0,<1.6.0"]
55+
deltalake = ["deltalake>=0.18.0,<1.7.0"]
5656
pyiceberg = ["pyiceberg[pyarrow]>=0.7.0,<1"]
5757
geopandas = ["geopandas>=1.0.0,<2", "pyproj>=3.6,<3.7.2"]
5858
modin = ["modin>=0.31,<0.38"]

tests/unit/test_s3_parquet.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,42 @@ def test_read_parquet_metadata_large_dtype(path):
8484
assert columns_types.get("c1") == "string"
8585

8686

87+
def test_pyarrow2athena_uuid_extension_type():
88+
from awswrangler._data_types import pyarrow2athena
89+
90+
assert pyarrow2athena(pa.uuid()) == "binary"
91+
92+
93+
def test_pyarrow2athena_custom_extension_type():
94+
from awswrangler._data_types import pyarrow2athena
95+
96+
class _TestExtType(pa.ExtensionType):
97+
def __init__(self):
98+
super().__init__(pa.int64(), "test.custom_ext")
99+
100+
def __arrow_ext_serialize__(self):
101+
return b""
102+
103+
@classmethod
104+
def __arrow_ext_deserialize__(cls, storage_type, serialized):
105+
return cls()
106+
107+
assert pyarrow2athena(_TestExtType()) == "bigint"
108+
109+
110+
def test_athena_types_from_pyarrow_schema_with_extension():
111+
from awswrangler._data_types import athena_types_from_pyarrow_schema
112+
113+
schema = pa.schema(
114+
[
115+
pa.field("id", pa.uuid()),
116+
pa.field("value", pa.int64()),
117+
]
118+
)
119+
result = athena_types_from_pyarrow_schema(schema)
120+
assert result == {"id": "binary", "value": "bigint"}
121+
122+
87123
@pytest.mark.parametrize(
88124
"partition_cols",
89125
[

0 commit comments

Comments
 (0)