
Commit 97430f9

pierrebzl and nicor88 authored
feat: Allow custom schema def for tmp tables generated by incremental (#659)
Co-authored-by: nicor88 <[email protected]>
1 parent 8faa921 commit 97430f9

File tree: 7 files changed, +174 -39 lines changed


Diff for: README.md

+4
@@ -213,6 +213,10 @@ athena:
   - For incremental models using insert overwrite strategy on hive table
   - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid
   - Useful if you are looking to run multiple dbt build inserting in the same table in parallel
+- `temp_schema` (`default=none`)
+  - For incremental models, allows defining a schema to hold the temporary
+    tables created during incremental model runs
+  - The schema will be created in the model's target database if it does not exist
 - `lf_tags_config` (`default=none`)
   - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns
   - `enabled` (`default=False`) whether LF tags management is enabled for a model
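
As a usage sketch (model, column, and schema names here are hypothetical, not taken from this commit), the new option slots into a model config like any other dbt-athena setting:

```sql
-- models/my_incremental_model.sql (hypothetical model)
{{ config(
    materialized='incremental',
    incremental_strategy='insert_overwrite',
    partitioned_by=['date_column'],
    temp_schema='analytics_tmp'  -- tmp tables are built here instead of the target schema
) }}

select
    id,
    date_column
from {{ ref('my_upstream_model') }}  -- hypothetical upstream model
{% if is_incremental() %}
where date_column > (select max(date_column) from {{ this }})
{% endif %}
```

With a config like this, the scratch table for each incremental run lands in `analytics_tmp` rather than next to the target table, keeping the target schema free of transient `__dbt_tmp` objects while runs are in flight.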

Diff for: dbt/adapters/athena/impl.py

+2
@@ -99,6 +99,7 @@ class AthenaConfig(AdapterConfig):
         partitions_limit: Maximum numbers of partitions when batching.
         force_batch: Skip creating the table as ctas and run the operation directly in batch insert mode.
         unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp.
+        temp_schema: Define in which schema to create temporary tables used in incremental runs.
     """

     work_group: Optional[str] = None
@@ -120,6 +121,7 @@ class AthenaConfig(AdapterConfig):
     partitions_limit: Optional[int] = None
     force_batch: bool = False
     unique_tmp_table_suffix: bool = False
+    temp_schema: Optional[str] = None


 class AthenaAdapter(SQLAdapter):

Diff for: dbt/include/athena/macros/adapters/relation.sql

+15
@@ -36,6 +36,21 @@
   {%- endcall %}
 {%- endmacro %}

+{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %}
+  {%- set temp_identifier = base_relation.identifier ~ suffix -%}
+  {%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%}
+
+  {%- if temp_schema is not none -%}
+    {%- set temp_relation = temp_relation.incorporate(path={
+      "identifier": temp_identifier,
+      "schema": temp_schema
+    }) -%}
+    {%- do create_schema(temp_relation) -%}
+  {% endif %}
+
+  {{ return(temp_relation) }}
+{% endmacro %}
+
 {% macro athena__rename_relation(from_relation, to_relation) %}
   {% call statement('rename_relation') -%}
     alter table {{ from_relation.render_hive() }} rename to `{{ to_relation.schema }}`.`{{ to_relation.identifier }}`
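
A minimal sketch of what the new macro resolves to (relation and schema names are illustrative, assuming a model that lives in an `analytics` schema):

```sql
{# Without temp_schema, the tmp relation stays in the model's own schema: #}
{% set tmp = make_temp_relation(this) %}
{# -> analytics.my_model__dbt_tmp #}

{# With temp_schema, the tmp relation is re-pointed at the given schema, and
   create_schema() is invoked so the schema exists before it is used: #}
{% set tmp = make_temp_relation(this, temp_schema='analytics_tmp') %}
{# -> analytics_tmp.my_model__dbt_tmp #}
```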

Diff for: dbt/include/athena/macros/materializations/models/incremental/incremental.sql

+2-1
@@ -10,6 +10,7 @@
   {% set partitioned_by = config.get('partitioned_by') %}
   {% set force_batch = config.get('force_batch', False) | as_bool -%}
   {% set unique_tmp_table_suffix = config.get('unique_tmp_table_suffix', False) | as_bool -%}
+  {% set temp_schema = config.get('temp_schema') %}
   {% set target_relation = this.incorporate(type='table') %}
   {% set existing_relation = load_relation(this) %}
   -- If using insert_overwrite on Hive table, allow to set a unique tmp table suffix
@@ -22,7 +23,7 @@
   {% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix,
                                                  schema=schema,
                                                  database=database) %}
-  {% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix) %}
+  {% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix, temp_schema=temp_schema) %}

   -- If no partitions are used with insert_overwrite, we fall back to append mode.
   {% if partitioned_by is none and strategy == 'insert_overwrite' %}
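
Putting the two changes together, an incremental run with `temp_schema` set boils down to roughly the following sequence of Athena statements (a sketch with hypothetical names; the exact DDL depends on table format, location, and strategy):

```sql
-- 1. Build the scratch table in the temp schema instead of the target schema
create table "analytics_tmp"."my_model__dbt_tmp"
with (partitioned_by = array['date_column']) as
select ...;  -- compiled model SQL

-- 2. Move the results into the target table; with insert_overwrite the
--    matching partitions of the target are replaced
insert into "analytics"."my_model"
select * from "analytics_tmp"."my_model__dbt_tmp";

-- 3. Clean up the scratch table
drop table if exists "analytics_tmp"."my_model__dbt_tmp";
```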
+108
@@ -0,0 +1,108 @@
+import pytest
+import yaml
+from tests.functional.adapter.utils.parse_dbt_run_output import (
+    extract_create_statement_table_names,
+    extract_running_create_statements,
+)
+
+from dbt.contracts.results import RunStatus
+from dbt.tests.util import run_dbt
+
+models__schema_tmp_sql = """
+{{ config(
+        materialized='incremental',
+        incremental_strategy='insert_overwrite',
+        partitioned_by=['date_column'],
+        temp_schema=var('temp_schema_name')
+    )
+}}
+select
+    random() as rnd,
+    cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column
+"""
+
+
+class TestIncrementalTmpSchema:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"schema_tmp.sql": models__schema_tmp_sql}
+
+    def test__schema_tmp(self, project, capsys):
+        relation_name = "schema_tmp"
+        temp_schema_name = f"{project.test_schema}_tmp"
+        drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade"
+        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+        vars_dict = {
+            "temp_schema_name": temp_schema_name,
+            "logical_date": "2024-01-01",
+        }
+
+        first_model_run = run_dbt(
+            [
+                "run",
+                "--select",
+                relation_name,
+                "--vars",
+                yaml.safe_dump(vars_dict),
+                "--log-level",
+                "debug",
+                "--log-format",
+                "json",
+            ]
+        )
+
+        first_model_run_result = first_model_run.results[0]
+
+        assert first_model_run_result.status == RunStatus.Success
+
+        records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert records_count_first_run == 1
+
+        out, _ = capsys.readouterr()
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+
+        assert len(athena_running_create_statements) == 1
+
+        incremental_model_run_result_table_name = extract_create_statement_table_names(
+            athena_running_create_statements[0]
+        )[0]
+
+        assert temp_schema_name not in incremental_model_run_result_table_name
+
+        vars_dict["logical_date"] = "2024-01-02"
+        incremental_model_run = run_dbt(
+            [
+                "run",
+                "--select",
+                relation_name,
+                "--vars",
+                yaml.safe_dump(vars_dict),
+                "--log-level",
+                "debug",
+                "--log-format",
+                "json",
+            ]
+        )
+
+        incremental_model_run_result = incremental_model_run.results[0]
+
+        assert incremental_model_run_result.status == RunStatus.Success
+
+        records_count_incremental_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert records_count_incremental_run == 2
+
+        out, _ = capsys.readouterr()
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+
+        assert len(athena_running_create_statements) == 1
+
+        incremental_model_run_result_table_name = extract_create_statement_table_names(
+            athena_running_create_statements[0]
+        )[0]
+
+        assert temp_schema_name == incremental_model_run_result_table_name.split(".")[1].strip('"')
+
+        project.run_sql(drop_temp_schema)

Diff for: tests/functional/adapter/test_unique_tmp_table_suffix.py

+7-38
@@ -1,8 +1,10 @@
-import json
 import re
-from typing import List

 import pytest
+from tests.functional.adapter.utils.parse_dbt_run_output import (
+    extract_create_statement_table_names,
+    extract_running_create_statements,
+)

 from dbt.contracts.results import RunStatus
 from dbt.tests.util import run_dbt
@@ -21,39 +23,6 @@
 """


-def extract_running_create_statements(dbt_run_capsys_output: str) -> List[str]:
-    sql_create_statements = []
-    # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
-    for events_msg in dbt_run_capsys_output.split("\n")[1:]:
-        base_msg_data = None
-        # Best effort solution to avoid invalid records and blank lines
-        try:
-            base_msg_data = json.loads(events_msg).get("data")
-        except json.JSONDecodeError:
-            pass
-        """First run will not produce data.sql object in the execution logs, only data.base_msg
-        containing the "Running Athena query:" initial create statement.
-        Subsequent incremental runs will only contain the insert from the tmp table into the model
-        table destination.
-        Since we want to compare both run create statements, we need to handle both cases"""
-        if base_msg_data:
-            base_msg = base_msg_data.get("base_msg")
-            if "Running Athena query:" in str(base_msg):
-                if "create table" in base_msg:
-                    sql_create_statements.append(base_msg)
-
-            if base_msg_data.get("conn_name") == "model.test.unique_tmp_table_suffix" and "sql" in base_msg_data:
-                if "create table" in base_msg_data.get("sql"):
-                    sql_create_statements.append(base_msg_data.get("sql"))
-
-    return sql_create_statements
-
-
-def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
-    table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
-    return [table_name.rstrip() for table_name in table_names]
-
-
 class TestUniqueTmpTableSuffix:
     @pytest.fixture(scope="class")
     def models(self):
@@ -86,7 +55,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert first_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out)
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)

         assert len(athena_running_create_statements) == 1

@@ -118,7 +87,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert incremental_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out)
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)

         assert len(athena_running_create_statements) == 1

@@ -150,7 +119,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert incremental_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out)
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)

         incremental_model_run_result_table_name_2 = extract_create_statement_table_names(
             athena_running_create_statements[0]
Diff for: tests/functional/adapter/utils/parse_dbt_run_output.py

+36

@@ -0,0 +1,36 @@
+import json
+import re
+from typing import List
+
+
+def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]:
+    sql_create_statements = []
+    # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
+    for events_msg in dbt_run_capsys_output.split("\n")[1:]:
+        base_msg_data = None
+        # Best effort solution to avoid invalid records and blank lines
+        try:
+            base_msg_data = json.loads(events_msg).get("data")
+        except json.JSONDecodeError:
+            pass
+        """First run will not produce data.sql object in the execution logs, only data.base_msg
+        containing the "Running Athena query:" initial create statement.
+        Subsequent incremental runs will only contain the insert from the tmp table into the model
+        table destination.
+        Since we want to compare both run create statements, we need to handle both cases"""
+        if base_msg_data:
+            base_msg = base_msg_data.get("base_msg")
+            if "Running Athena query:" in str(base_msg):
+                if "create table" in base_msg:
+                    sql_create_statements.append(base_msg)
+
+            if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data:
+                if "create table" in base_msg_data.get("sql"):
+                    sql_create_statements.append(base_msg_data.get("sql"))
+
+    return sql_create_statements
+
+
+def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
+    table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
+    return [table_name.rstrip() for table_name in table_names]
