Skip to content

Commit 909812c

Browse files
authored
Merge branch 'main' into release-3.0.0
2 parents 5a1f275 + 35a09a0 commit 909812c

File tree

5 files changed

+175
-1
lines changed

5 files changed

+175
-1
lines changed

.github/workflows/git-hygiene.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
repo-token: ${{ secrets.GITHUB_TOKEN }}
1919
days-before-stale: 60
2020
days-before-close: 7
21-
exempt-issue-labels: 'needs-triage,help wanted'
21+
exempt-issue-labels: 'needs-triage,help wanted,backlog'
2222
exempt-pr-labels: 'needs-triage'
2323
stale-issue-label: 'closing-soon'
2424
operations-per-run: 100

awswrangler/athena/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
create_athena_bucket,
66
create_ctas_table,
77
describe_table,
8+
generate_create_query,
89
get_named_query_statement,
910
get_query_columns_types,
1011
get_query_execution,
@@ -26,6 +27,7 @@
2627
"get_query_results",
2728
"get_named_query_statement",
2829
"get_work_group",
30+
"generate_create_query",
2931
"repair_table",
3032
"create_ctas_table",
3133
"show_create_table",

awswrangler/athena/_utils.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Utilities Module for Amazon Athena."""
2+
import base64
23
import csv
4+
import json
35
import logging
46
import pprint
57
import time
@@ -14,6 +16,7 @@
1416

1517
from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts
1618
from awswrangler._config import apply_configs
19+
from awswrangler.catalog._utils import _catalog_id, _transaction_id
1720

1821
from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results, _LocalMetadataCacheManager
1922

@@ -925,6 +928,102 @@ def show_create_table(
925928
return cast(str, raw_result.createtab_stmt.str.strip().str.cat(sep=" "))
926929

927930

931+
@apply_configs
def generate_create_query(
    table: str,
    database: Optional[str] = None,
    transaction_id: Optional[str] = None,
    query_as_of_time: Optional[str] = None,
    catalog_id: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> str:
    """Generate the query that created a table(EXTERNAL_TABLE) or a view(VIRTUAL_VIEW).

    Analyzes an existing table named table_name to generate the query that created it.

    Parameters
    ----------
    table : str
        Table name.
    database : str
        Database name.
    transaction_id: str, optional
        The ID of the transaction.
    query_as_of_time: str, optional
        The time as of when to read the table contents. Must be a valid Unix epoch timestamp.
        Cannot be specified alongside transaction_id.
    catalog_id : str, optional
        The ID of the Data Catalog from which to retrieve Databases.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    str
        The query that created the table or view.

    Raises
    ------
    NotImplementedError
        If the table type is neither EXTERNAL_TABLE nor VIRTUAL_VIEW.

    Examples
    --------
    >>> import awswrangler as wr
    >>> view_create_query: str = wr.athena.generate_create_query(table='my_view', database='default')

    """

    def parse_columns(columns_description: List[Dict[str, str]]) -> str:
        # Render one "`name` type [COMMENT '...']" entry per column.
        columns_str: List[str] = []
        for column in columns_description:
            column_str = f" `{column['Name']}` {column['Type']}"
            if "Comment" in column:
                column_str += f" COMMENT '{column['Comment']}'"
            columns_str.append(column_str)
        return ", \n".join(columns_str)

    def parse_properties(parameters: Dict[str, str]) -> str:
        # "EXTERNAL" is implied by the CREATE EXTERNAL TABLE statement itself, so skip it.
        properties_str: List[str] = []
        for key, value in parameters.items():
            if key == "EXTERNAL":
                continue
            properties_str.append(f" '{key}'='{value}'")
        return ", \n".join(properties_str)

    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    table_detail: Dict[str, Any] = client_glue.get_table(
        **_catalog_id(
            catalog_id=catalog_id,
            **_transaction_id(
                transaction_id=transaction_id, query_as_of_time=query_as_of_time, DatabaseName=database, Name=table
            ),
        )
    )["Table"]
    table_type: str = table_detail["TableType"]
    if table_type == "VIRTUAL_VIEW":
        # Glue stores the original Presto/Trino SQL base64-encoded inside a comment wrapper.
        glue_base64_query: str = table_detail["ViewOriginalText"].replace("/* Presto View: ", "").replace(" */", "")
        glue_query: str = json.loads(base64.b64decode(glue_base64_query))["originalSql"]
        return f'CREATE OR REPLACE VIEW "{table}" AS \n{glue_query}'
    if table_type == "EXTERNAL_TABLE":
        columns: str = parse_columns(columns_description=table_detail["StorageDescriptor"]["Columns"])
        query_parts: List[str] = [f"CREATE EXTERNAL TABLE `{table}`(\n{columns})"]
        # GetTable marks "PartitionKeys"/"Parameters" as optional; treat a missing key as empty.
        partitioned_columns: str = parse_columns(columns_description=table_detail.get("PartitionKeys", []))
        if partitioned_columns:
            query_parts.append(f"PARTITIONED BY ( \n{partitioned_columns})")
        tblproperties: str = parse_properties(parameters=table_detail.get("Parameters", {}))

        query_parts += [
            "ROW FORMAT SERDE ",
            f" '{table_detail['StorageDescriptor']['SerdeInfo']['SerializationLibrary']}' ",
            "STORED AS INPUTFORMAT ",
            f" '{table_detail['StorageDescriptor']['InputFormat']}' ",
            "OUTPUTFORMAT ",
            f" '{table_detail['StorageDescriptor']['OutputFormat']}'",
            "LOCATION",
            f" '{table_detail['StorageDescriptor']['Location']}'",
            f"TBLPROPERTIES (\n{tblproperties})",
        ]
        return "\n".join(query_parts)
    # e.g. GOVERNED or MANAGED tables — name the type instead of failing opaquely.
    raise NotImplementedError(f"Unsupported table type: {table_type}")
1025+
1026+
9281027
def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
9291028
"""Return information about the workgroup with the specified name.
9301029

docs/source/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,10 @@ Amazon Athena
116116

117117
create_athena_bucket
118118
create_ctas_table
119+
generate_create_query
119120
get_query_columns_types
120121
get_query_execution
122+
get_query_results
121123
get_named_query_statement
122124
get_work_group
123125
read_sql_query

tests/test_athena.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,3 +1150,74 @@ def test_get_query_results(path, glue_table, glue_database):
11501150
query_id_regular = df_regular.query_metadata["QueryExecutionId"]
11511151
df_get_query_results_df_regular = wr.athena.get_query_results(query_execution_id=query_id_regular)
11521152
pd.testing.assert_frame_equal(df_get_query_results_df_regular, df_regular)
1153+
1154+
1155+
def test_athena_generate_create_query(path, glue_database, glue_table):
    """generate_create_query must reproduce the DDL for plain tables, partitioned tables, and views."""
    # The storage/serde/properties tail is identical for both parquet-table cases.
    expected_tail = [
        "ROW FORMAT SERDE ",
        " 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' ",
        "STORED AS INPUTFORMAT ",
        " 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' ",
        "OUTPUTFORMAT ",
        " 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'",
        "LOCATION",
        f" '{path}'",
        "TBLPROPERTIES (",
        " 'classification'='parquet', ",
        " 'compressionType'='none', ",
        " 'projection.enabled'='false', ",
        " 'typeOfData'='file')",
    ]
    expected_header = [f"CREATE EXTERNAL TABLE `{glue_table}`(", " `c0` int)"]

    # Case 1: external table without partitions.
    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
    wr.catalog.create_parquet_table(database=glue_database, table=glue_table, path=path, columns_types={"c0": "int"})
    generated = wr.athena.generate_create_query(database=glue_database, table=glue_table)
    assert generated == "\n".join(expected_header + expected_tail)

    # Case 2: external table with a date partition column.
    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
    wr.catalog.create_parquet_table(
        database=glue_database,
        table=glue_table,
        path=path,
        columns_types={"c0": "int"},
        partitions_types={"col2": "date"},
    )
    expected_partition_clause = ["PARTITIONED BY ( ", " `col2` date)"]
    generated = wr.athena.generate_create_query(database=glue_database, table=glue_table)

    assert generated == "\n".join(expected_header + expected_partition_clause + expected_tail)

    # Case 3: view — generate_create_query must round-trip the exact DDL used to create it.
    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
    view_ddl = "\n".join(
        [
            f'CREATE OR REPLACE VIEW "{glue_table}" AS ',
            (
                "SELECT CAST(ROW (1, ROW (2, ROW (3, '4'))) AS "
                "row(field0 bigint,field1 row(field2 bigint,field3 row(field4 bigint,field5 varchar)))) col0\n\n"
            ),
        ]
    )
    wr.athena.start_query_execution(sql=view_ddl, database=glue_database, wait=True)
    assert view_ddl == wr.athena.generate_create_query(database=glue_database, table=glue_table)

0 commit comments

Comments
 (0)