|
1 | 1 | """Utilities Module for Amazon Athena.""" |
| 2 | +import base64 |
2 | 3 | import csv |
| 4 | +import json |
3 | 5 | import logging |
4 | 6 | import pprint |
5 | 7 | import time |
|
14 | 16 |
|
15 | 17 | from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts |
16 | 18 | from awswrangler._config import apply_configs |
| 19 | +from awswrangler.catalog._utils import _catalog_id, _transaction_id |
17 | 20 |
|
18 | 21 | from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results, _LocalMetadataCacheManager |
19 | 22 |
|
@@ -925,6 +928,102 @@ def show_create_table( |
925 | 928 | return cast(str, raw_result.createtab_stmt.str.strip().str.cat(sep=" ")) |
926 | 929 |
|
927 | 930 |
|
| 931 | +@apply_configs |
| 932 | +def generate_create_query( |
| 933 | + table: str, |
| 934 | + database: Optional[str] = None, |
| 935 | + transaction_id: Optional[str] = None, |
| 936 | + query_as_of_time: Optional[str] = None, |
| 937 | + catalog_id: Optional[str] = None, |
| 938 | + boto3_session: Optional[boto3.Session] = None, |
| 939 | +) -> str: |
| 940 | + """Generate the query that created a table(EXTERNAL_TABLE) or a view(VIRTUAL_TABLE). |
| 941 | +
|
| 942 | + Analyzes an existing table named table_name to generate the query that created it. |
| 943 | +
|
| 944 | + Parameters |
| 945 | + ---------- |
| 946 | + table : str |
| 947 | + Table name. |
| 948 | + database : str |
| 949 | + Database name. |
| 950 | + transaction_id: str, optional |
| 951 | + The ID of the transaction. |
| 952 | + query_as_of_time: str, optional |
| 953 | + The time as of when to read the table contents. Must be a valid Unix epoch timestamp. |
| 954 | + Cannot be specified alongside transaction_id. |
| 955 | + catalog_id : str, optional |
| 956 | + The ID of the Data Catalog from which to retrieve Databases. |
| 957 | + If none is provided, the AWS account ID is used by default. |
| 958 | + boto3_session : boto3.Session(), optional |
| 959 | + Boto3 Session. The default boto3 session will be used if boto3_session receive None. |
| 960 | +
|
| 961 | + Returns |
| 962 | + ------- |
| 963 | + str |
| 964 | + The query that created the table or view. |
| 965 | +
|
| 966 | + Examples |
| 967 | + -------- |
| 968 | + >>> import awswrangler as wr |
| 969 | + >>> view_create_query: str = wr.athena.generate_create_query(table='my_view', database='default') |
| 970 | +
|
| 971 | + """ |
| 972 | + |
| 973 | + def parse_columns(columns_description: List[Dict[str, str]]) -> str: |
| 974 | + columns_str: List[str] = [] |
| 975 | + for column in columns_description: |
| 976 | + column_str = f" `{column['Name']}` {column['Type']}" |
| 977 | + if "Comment" in column: |
| 978 | + column_str += f" COMMENT '{column['Comment']}'" |
| 979 | + columns_str.append(column_str) |
| 980 | + return ", \n".join(columns_str) |
| 981 | + |
| 982 | + def parse_properties(parameters: Dict[str, str]) -> str: |
| 983 | + properties_str: List[str] = [] |
| 984 | + for key, value in parameters.items(): |
| 985 | + if key == "EXTERNAL": |
| 986 | + continue |
| 987 | + property_key_value = f" '{key}'='{value}'" |
| 988 | + properties_str.append(property_key_value) |
| 989 | + return ", \n".join(properties_str) |
| 990 | + |
| 991 | + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) |
| 992 | + table_detail: Dict[str, Any] = client_glue.get_table( |
| 993 | + **_catalog_id( |
| 994 | + catalog_id=catalog_id, |
| 995 | + **_transaction_id( |
| 996 | + transaction_id=transaction_id, query_as_of_time=query_as_of_time, DatabaseName=database, Name=table |
| 997 | + ), |
| 998 | + ) |
| 999 | + )["Table"] |
| 1000 | + if table_detail["TableType"] == "VIRTUAL_VIEW": |
| 1001 | + glue_base64_query: str = table_detail["ViewOriginalText"].replace("/* Presto View: ", "").replace(" */", "") |
| 1002 | + glue_query: str = json.loads(base64.b64decode(glue_base64_query))["originalSql"] |
| 1003 | + return f"""CREATE OR REPLACE VIEW "{table}" AS \n{glue_query}""" |
| 1004 | + if table_detail["TableType"] == "EXTERNAL_TABLE": |
| 1005 | + columns: str = parse_columns(columns_description=table_detail["StorageDescriptor"]["Columns"]) |
| 1006 | + query_parts: List[str] = [f"""CREATE EXTERNAL TABLE `{table}`(\n{columns})"""] |
| 1007 | + partitioned_columns: str = parse_columns(columns_description=table_detail["PartitionKeys"]) |
| 1008 | + if partitioned_columns: |
| 1009 | + query_parts.append(f"""PARTITIONED BY ( \n{partitioned_columns})""") |
| 1010 | + tblproperties: str = parse_properties(parameters=table_detail["Parameters"]) |
| 1011 | + |
| 1012 | + query_parts += [ |
| 1013 | + """ROW FORMAT SERDE """, |
| 1014 | + f""" '{table_detail['StorageDescriptor']['SerdeInfo']['SerializationLibrary']}' """, |
| 1015 | + """STORED AS INPUTFORMAT """, |
| 1016 | + f""" '{table_detail['StorageDescriptor']['InputFormat']}' """, |
| 1017 | + """OUTPUTFORMAT """, |
| 1018 | + f""" '{table_detail['StorageDescriptor']['OutputFormat']}'""", |
| 1019 | + """LOCATION""", |
| 1020 | + f""" '{table_detail['StorageDescriptor']['Location']}'""", |
| 1021 | + f"""TBLPROPERTIES (\n{tblproperties})""", |
| 1022 | + ] |
| 1023 | + return "\n".join(query_parts) |
| 1024 | + raise NotImplementedError() |
| 1025 | + |
| 1026 | + |
928 | 1027 | def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]: |
929 | 1028 | """Return information about the workgroup with the specified name. |
930 | 1029 |
|
|
0 commit comments