Skip to content

Commit 0a46ee7

Browse files
feat(data-types): Support PyArrow ExtensionTypes in pyarrow2athena (#3351)
* Update _data_types.py
* Update test_s3_parquet.py

Co-authored-by: Anton Kukushkin <kukushkin.anton@gmail.com>
1 parent a730a88 commit 0a46ee7

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

awswrangler/_data_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def pyarrow2athena( # noqa: PLR0911,PLR0912
5959
)
6060
if pa.types.is_map(dtype):
6161
return f"map<{pyarrow2athena(dtype=dtype.key_type, ignore_null=ignore_null)},{pyarrow2athena(dtype=dtype.item_type, ignore_null=ignore_null)}>"
62+
if isinstance(dtype, getattr(pa, "BaseExtensionType", pa.ExtensionType)):
63+
return pyarrow2athena(dtype=dtype.storage_type, ignore_null=ignore_null)
6264
if dtype == pa.null():
6365
if ignore_null:
6466
return ""

tests/unit/test_s3_parquet.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,42 @@ def test_read_parquet_metadata_large_dtype(path):
8484
assert columns_types.get("c1") == "string"
8585

8686

87+
def test_pyarrow2athena_uuid_extension_type():
    """Check that PyArrow's UUID extension type is reported to Athena as ``binary``.

    The expectation encodes that ``pyarrow2athena`` resolves an extension type
    through its storage type rather than rejecting it.
    """
    from awswrangler._data_types import pyarrow2athena

    # The ``pa.uuid()`` factory is not available in every PyArrow release
    # (believed to be 18.0.0+); skip instead of failing with AttributeError
    # on installations that predate it.
    if not hasattr(pa, "uuid"):
        pytest.skip("pa.uuid() is not available in this PyArrow version")

    assert pyarrow2athena(pa.uuid()) == "binary"
91+
92+
93+
def test_pyarrow2athena_custom_extension_type():
    """Check that a user-defined extension type maps to its storage type's Athena type.

    A minimal ``pa.ExtensionType`` subclass stored as ``int64`` is expected to
    be reported as ``bigint``.
    """
    from awswrangler._data_types import pyarrow2athena

    class _TestExtType(pa.ExtensionType):
        # Smallest possible extension type: int64 storage, no parameters.

        def __init__(self):
            super().__init__(pa.int64(), "test.custom_ext")

        def __arrow_ext_serialize__(self):
            # The type carries no parameters, so there is nothing to serialize.
            return b""

        @classmethod
        def __arrow_ext_deserialize__(cls, storage_type, serialized):
            return cls()

    ext_dtype = _TestExtType()
    athena_type = pyarrow2athena(ext_dtype)
    assert athena_type == "bigint"
108+
109+
110+
def test_athena_types_from_pyarrow_schema_with_extension():
    """Check schema conversion when an extension-typed column sits next to a plain one.

    The UUID column ``id`` is expected to come back as ``binary`` and the
    ``int64`` column ``value`` as ``bigint``.
    """
    from awswrangler._data_types import athena_types_from_pyarrow_schema

    # The ``pa.uuid()`` factory is not available in every PyArrow release
    # (believed to be 18.0.0+); skip instead of failing with AttributeError
    # on installations that predate it.
    if not hasattr(pa, "uuid"):
        pytest.skip("pa.uuid() is not available in this PyArrow version")

    schema = pa.schema(
        [
            pa.field("id", pa.uuid()),
            pa.field("value", pa.int64()),
        ]
    )
    result = athena_types_from_pyarrow_schema(schema)
    assert result == {"id": "binary", "value": "bigint"}
121+
122+
87123
@pytest.mark.parametrize(
88124
"partition_cols",
89125
[

0 commit comments

Comments
 (0)