Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions awswrangler/athena/_write_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
_logger: logging.Logger = logging.getLogger(__name__)


def _escape_athena_string_literal(value: Any) -> str:
    """Return ``str(value)`` with every single quote doubled.

    Intended for caller-supplied values that get spliced *inside* a
    single-quoted SQL string literal, for example ``COMMENT '<value>'`` or
    ``TBLPROPERTIES ('<key>'='<value>')``. The quotes surrounding the splice
    site are delimiters belonging to the statement itself; this helper does not
    add them, it only rewrites the quotes found in the value.

    Without this step, a value containing ``'`` would end the literal early and
    whatever follows it would be parsed as additional DDL (e.g. a
    ``LOCATION '...'`` clause). With each embedded quote doubled, the whole
    value is read back as one literal.

    The value is converted with ``str()`` first, so non-string inputs are
    accepted. Characters other than the single quote (backslashes included)
    pass through unchanged.

    NOTE(review): this relies on quote doubling being the only escape the engine
    honours inside these literals -- confirm for the Athena DDL path.
    """
    text = str(value)
    # Splitting on the quote and re-joining with two quotes doubles every
    # occurrence, including leading, trailing and consecutive ones.
    return "''".join(text.split("'"))


def _create_iceberg_table(
df: pd.DataFrame,
database: str,
Expand All @@ -50,13 +65,19 @@ def _create_iceberg_table(
[
f"{k} {v}"
if (columns_comments is None or columns_comments.get(k) is None)
else f"{k} {v} COMMENT '{columns_comments[k]}'"
else f"{k} {v} COMMENT '{_escape_athena_string_literal(columns_comments[k])}'"
for k, v in columns_types.items()
]
)
partition_cols_str: str = f"PARTITIONED BY ({', '.join([col for col in partition_cols])})" if partition_cols else ""
table_properties_str: str = (
", " + ", ".join([f"'{key}'='{value}'" for key, value in additional_table_properties.items()])
", "
+ ", ".join(
[
f"'{_escape_athena_string_literal(key)}'='{_escape_athena_string_literal(value)}'"
for key, value in additional_table_properties.items()
]
)
if additional_table_properties
else ""
)
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_moto.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,45 @@ def mock_make_api_call(self, operation_name, kwarg):
allow_full_scan=True,
)
assert describe_table_calls == 1


def test_create_iceberg_table_escapes_single_quotes_in_columns_comments() -> None:
    """Quotes in column comments and table properties must be doubled in the DDL.

    Calls ``_create_iceberg_table`` with ``_start_query_execution`` replaced by
    a recorder and ``wait_query`` patched out, then inspects the first SQL
    string the recorder received. Both ``columns_comments`` and
    ``additional_table_properties`` carry payloads that try to close the
    surrounding string literal and append their own ``LOCATION`` clause.
    """
    from awswrangler.athena import _write_iceberg

    issued_statements: list[str] = []

    def record_statement(*, sql: str, **_unused) -> str:
        # Stand-in for _start_query_execution: keep the SQL, return a dummy id.
        issued_statements.append(sql)
        return "qid"

    empty_frame = pd.DataFrame(
        {
            "id": pd.Series(dtype="int64"),
            "user_name": pd.Series(dtype="string"),
        }
    )

    workgroup = mock.MagicMock()
    workgroup.enforce_workgroup_location = False

    # Each payload contains single quotes placed to break out of its literal.
    hostile_property = "val') LOCATION 's3://other/' --"
    hostile_comment = "') LOCATION 's3://other/' TBLPROPERTIES ('x'='y"

    start_patch = mock.patch.object(_write_iceberg, "_start_query_execution", side_effect=record_statement)
    wait_patch = mock.patch.object(_write_iceberg, "wait_query")
    with start_patch:
        with wait_patch:
            _write_iceberg._create_iceberg_table(
                df=empty_frame,
                database="db",
                table="t",
                path="s3://intended/output/",
                wg_config=workgroup,
                partition_cols=None,
                additional_table_properties={"prop": hostile_property},
                index=False,
                boto3_session=mock.MagicMock(),
                columns_comments={"user_name": hostile_comment},
            )

    ddl = issued_statements[0]

    # Every caller-supplied quote appears doubled, so the payloads remain inside
    # their COMMENT / TBLPROPERTIES literals instead of starting a new clause.
    assert "COMMENT ''') LOCATION ''s3://other/'' TBLPROPERTIES (''x''=''y'" in ddl
    assert "'prop'='val'') LOCATION ''s3://other/'' --'" in ddl

    # Only the intended LOCATION shows up with ordinary (un-doubled) quotes.
    assert "LOCATION 's3://intended/output/'" in ddl
    assert "LOCATION 's3://other/'" not in ddl
Loading