Commit 76f6d5f

svdimchenko and nicor88 authored
feat: Implement iceberg retry logic (#657)
Co-authored-by: nicor88 <[email protected]>
1 parent 97430f9 · commit 76f6d5f

File tree

3 files changed (+191 −36):

- README.md
- dbt/adapters/athena/connections.py
- tests/functional/adapter/test_retries_iceberg.py


Diff for: README.md

+3 −2
@@ -119,7 +119,7 @@ You can either:
 A dbt profile can be configured to run against AWS Athena using the following configuration:

 | Option | Description | Required? | Example |
-| --------------------- | ---------------------------------------------------------------------------------------- | --------- | ------------------------------------------ |
+|-----------------------|------------------------------------------------------------------------------------------|-----------|--------------------------------------------|
 | s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` |
 | s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` |
 | s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` |
@@ -134,8 +134,9 @@ A dbt profile can be configured to run against AWS Athena using the following co
 | aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` |
 | work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
 | num_retries | Number of times to retry a failing query | Optional | `3` |
-| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` |
 | num_boto3_retries | Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) | Optional | `5` |
+| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `0` |
+| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` |
 | seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` |
 | lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` |
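The new `num_iceberg_retries` option is set in the profile target alongside the existing retry options. As a minimal sketch, written as a Python dict in the same style as the functional test's `base_dbt_profile` further below; all concrete values are illustrative placeholders, not part of this commit:

```python
# Hypothetical profile target; values are placeholders for illustration only.
profile_target = {
    "type": "athena",
    "s3_staging_dir": "s3://my-bucket/dbt/",
    "schema": "analytics",
    "database": "awsdatacatalog",
    "region_name": "eu-west-1",
    "threads": 4,
    "num_retries": 3,          # generic Athena query retries
    "num_boto3_retries": 5,    # boto3 request retries (e.g. S3 deletes)
    "num_iceberg_retries": 3,  # retries for ICEBERG_COMMIT_ERROR; credentials default below is 3
}
```

In a `profiles.yml` target the new key sits at the same level as `num_retries` and `num_boto3_retries`.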

Diff for: dbt/adapters/athena/connections.py

+53 −34
@@ -25,11 +25,12 @@
 from pyathena.result_set import AthenaResultSet
 from pyathena.util import RetryConfig
 from tenacity import (
-    Retrying,
+    retry,
     retry_if_exception,
     stop_after_attempt,
     wait_random_exponential,
 )
+from typing_extensions import Self

 from dbt.adapters.athena.config import get_boto3_config
 from dbt.adapters.athena.constants import LOGGER
@@ -64,8 +65,9 @@ class AthenaCredentials(Credentials):
     _ALIASES = {"catalog": "database"}
     num_retries: int = 5
     num_boto3_retries: Optional[int] = None
+    num_iceberg_retries: int = 3
     s3_data_dir: Optional[str] = None
-    s3_data_naming: Optional[str] = "schema_table_unique"
+    s3_data_naming: str = "schema_table_unique"
     spark_work_group: Optional[str] = None
     s3_tmp_table_dir: Optional[str] = None
     # Unfortunately we can not just use dict, must be Dict because we'll get the following error:
@@ -147,7 +149,7 @@ def __poll(self, query_id: str) -> AthenaQueryExecution:
                 LOGGER.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...")
             time.sleep(self._poll_interval)

-    def execute(  # type: ignore
+    def execute(
         self,
         operation: str,
         parameters: Optional[Dict[str, Any]] = None,
@@ -157,35 +159,9 @@ def execute(  # type: ignore
         cache_size: int = 0,
         cache_expiration_time: int = 0,
         catch_partitions_limit: bool = False,
-        **kwargs,
-    ):
-        def inner() -> AthenaCursor:
-            query_id = self._execute(
-                operation,
-                parameters=parameters,
-                work_group=work_group,
-                s3_staging_dir=s3_staging_dir,
-                cache_size=cache_size,
-                cache_expiration_time=cache_expiration_time,
-            )
-
-            LOGGER.debug(f"Athena query ID {query_id}")
-
-            query_execution = self._executor.submit(self._collect_result_set, query_id).result()
-            if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
-                self.result_set = self._result_set_class(
-                    self._connection,
-                    self._converter,
-                    query_execution,
-                    self.arraysize,
-                    self._retry_config,
-                )
-
-            else:
-                raise OperationalError(query_execution.state_change_reason)
-            return self
-
-        retry = Retrying(
+        **kwargs: Dict[str, Any],
+    ) -> Self:
+        @retry(
             # No need to retry if TOO_MANY_OPEN_PARTITIONS occurs.
             # Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry,
             # because not all files are removed immediately after first try to create table
@@ -200,7 +176,47 @@ def inner() -> AthenaCursor:
             ),
             reraise=True,
         )
-        return retry(inner)
+        def inner() -> AthenaCursor:
+            num_iceberg_retries = self.connection.cursor_kwargs.get("num_iceberg_retries") + 1
+
+            @retry(
+                # Nested retry is needed to handle ICEBERG_COMMIT_ERROR for parallel inserts
+                retry=retry_if_exception(lambda e: "ICEBERG_COMMIT_ERROR" in str(e)),
+                stop=stop_after_attempt(num_iceberg_retries),
+                wait=wait_random_exponential(
+                    multiplier=num_iceberg_retries,
+                    max=self._retry_config.max_delay,
+                    exp_base=self._retry_config.exponential_base,
+                ),
+                reraise=True,
+            )
+            def execute_with_iceberg_retries() -> AthenaCursor:
+                query_id = self._execute(
+                    operation,
+                    parameters=parameters,
+                    work_group=work_group,
+                    s3_staging_dir=s3_staging_dir,
+                    cache_size=cache_size,
+                    cache_expiration_time=cache_expiration_time,
+                )
+
+                LOGGER.debug(f"Athena query ID {query_id}")
+
+                query_execution = self._executor.submit(self._collect_result_set, query_id).result()
+                if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
+                    self.result_set = self._result_set_class(
+                        self._connection,
+                        self._converter,
+                        query_execution,
+                        self.arraysize,
+                        self._retry_config,
+                    )
+                    return self
+                raise OperationalError(query_execution.state_change_reason)
+
+            return execute_with_iceberg_retries()  # type: ignore
+
+        return inner()  # type: ignore


 class AthenaConnectionManager(SQLConnectionManager):
@@ -243,7 +259,10 @@ def open(cls, connection: Connection) -> Connection:
                 schema_name=creds.schema,
                 work_group=creds.work_group,
                 cursor_class=AthenaCursor,
-                cursor_kwargs={"debug_query_state": creds.debug_query_state},
+                cursor_kwargs={
+                    "debug_query_state": creds.debug_query_state,
+                    "num_iceberg_retries": creds.num_iceberg_retries,
+                },
                 formatter=AthenaParameterFormatter(),
                 poll_interval=creds.poll_interval,
                 session=get_boto3_session(connection),
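The core of the change above: the single `Retrying` object wrapping `inner` is replaced by two nested `tenacity` decorators. The outer `@retry` keeps the existing behaviour (retry general failures, but not a deliberate `TOO_MANY_OPEN_PARTITIONS`), while the new inner decorator retries only queries failing with `ICEBERG_COMMIT_ERROR`, which Athena raises when parallel writers race to commit to the same Iceberg table. The number of inner attempts and the backoff multiplier come from the `num_iceberg_retries` value passed in through `cursor_kwargs`. Below is a stripped-down, self-contained sketch of that nesting outside the adapter; `run_query`, `MAX_DELAY`, and `EXP_BASE` are illustrative stand-ins (in the adapter these come from pyathena's `RetryConfig`), and the outer predicate is simplified:

```python
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

NUM_RETRIES = 3          # analogue of num_retries (outer, generic failures)
NUM_ICEBERG_RETRIES = 3  # analogue of num_iceberg_retries (inner, commit conflicts)
MAX_DELAY = 30           # stand-in for RetryConfig.max_delay
EXP_BASE = 2             # stand-in for RetryConfig.exponential_base


def run_query(sql: str) -> str:
    """Stand-in for the real Athena call."""
    raise RuntimeError("ICEBERG_COMMIT_ERROR: concurrent update to table metadata")


@retry(  # outer retry: any failure except the partitions-limit error (simplified predicate)
    retry=retry_if_exception(lambda e: "TOO_MANY_OPEN_PARTITIONS" not in str(e)),
    stop=stop_after_attempt(NUM_RETRIES),
    wait=wait_random_exponential(multiplier=NUM_RETRIES, max=MAX_DELAY, exp_base=EXP_BASE),
    reraise=True,
)
def execute(sql: str) -> str:
    @retry(  # inner retry: only commit conflicts from parallel Iceberg writers
        retry=retry_if_exception(lambda e: "ICEBERG_COMMIT_ERROR" in str(e)),
        stop=stop_after_attempt(NUM_ICEBERG_RETRIES + 1),  # +1: the first attempt is not a retry
        wait=wait_random_exponential(multiplier=NUM_ICEBERG_RETRIES, max=MAX_DELAY, exp_base=EXP_BASE),
        reraise=True,
    )
    def execute_with_iceberg_retries() -> str:
        return run_query(sql)

    return execute_with_iceberg_retries()
```

Nesting the Iceberg retry inside the generic retry gives commit conflicts their own bounded, jittered backoff loop before a failure counts against the outer retry budget.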

Diff for: tests/functional/adapter/test_retries_iceberg.py

+135 (new file)
"""Test parallel insert into iceberg table."""
import copy
import os

import pytest

from dbt.artifacts.schemas.results import RunStatus
from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture

PARALLELISM = 10

base_dbt_profile = {
    "type": "athena",
    "s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"),
    "s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"),
    "schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"),
    "database": os.getenv("DBT_TEST_ATHENA_DATABASE"),
    "region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"),
    "threads": PARALLELISM,
    "poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")),
    "num_retries": 0,
    "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"),
    "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None,
}

models__target = """
{{
    config(
        table_type='iceberg',
        materialized='table'
    )
}}

select * from (
    values
    (1, -1)
) as t (id, status)
limit 0

"""

models__source = {
    f"model_{i}.sql": f"""
{{{{
    config(
        table_type='iceberg',
        materialized='table',
        tags=['src'],
        pre_hook='insert into target values ({i}, {i})'
    )
}}}}

select 1 as col
"""
    for i in range(PARALLELISM)
}

seeds__expected_target_init = "id,status"
seeds__expected_target_post = "id,status\n" + "\n".join([f"{i},{i}" for i in range(PARALLELISM)])


class TestIcebergRetriesDisabled:
    @pytest.fixture(scope="class")
    def dbt_profile_target(self):
        profile = copy.deepcopy(base_dbt_profile)
        profile["num_iceberg_retries"] = 0
        return profile

    @pytest.fixture(scope="class")
    def models(self):
        return {**{"target.sql": models__target}, **models__source}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "expected_target_init.csv": seeds__expected_target_init,
            "expected_target_post.csv": seeds__expected_target_post,
        }

    def test__retries_iceberg(self, project):
        """Seed should match the model after run"""

        expected__init_seed_name = "expected_target_init"
        run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"])

        relation_name = "target"
        model_run = run_dbt(["run", "--select", relation_name])
        model_run_result = model_run.results[0]
        assert model_run_result.status == RunStatus.Success
        check_relations_equal(project.adapter, [relation_name, expected__init_seed_name])

        expected__post_seed_name = "expected_target_post"
        run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"])

        run, log = run_dbt_and_capture(["run", "--select", "tag:src"], expect_pass=False)
        assert any(model_run_result.status == RunStatus.Error for model_run_result in run.results)
        assert "ICEBERG_COMMIT_ERROR" in log


class TestIcebergRetriesEnabled:
    @pytest.fixture(scope="class")
    def dbt_profile_target(self):
        profile = copy.deepcopy(base_dbt_profile)
        profile["num_iceberg_retries"] = 1
        return profile

    @pytest.fixture(scope="class")
    def models(self):
        return {**{"target.sql": models__target}, **models__source}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "expected_target_init.csv": seeds__expected_target_init,
            "expected_target_post.csv": seeds__expected_target_post,
        }

    def test__retries_iceberg(self, project):
        """Seed should match the model after run"""

        expected__init_seed_name = "expected_target_init"
        run_dbt(["seed", "--select", expected__init_seed_name, "--full-refresh"])

        relation_name = "target"
        model_run = run_dbt(["run", "--select", relation_name])
        model_run_result = model_run.results[0]
        assert model_run_result.status == RunStatus.Success
        check_relations_equal(project.adapter, [relation_name, expected__init_seed_name])

        expected__post_seed_name = "expected_target_post"
        run_dbt(["seed", "--select", expected__post_seed_name, "--full-refresh"])

        run = run_dbt(["run", "--select", "tag:src"])
        assert all([model_run_result.status == RunStatus.Success for model_run_result in run.results])
        check_relations_equal(project.adapter, [relation_name, expected__post_seed_name])