Skip to content

Commit 1f6b644

Browse files
committed
fix: Compatibility issues with dag tests and newer Airflow versions
1 parent faf460c commit 1f6b644

1 file changed

Lines changed: 40 additions & 15 deletions

File tree

tests/dags/test_dbt_dags.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
else:
3535
from airflow.providers.common.compat.sdk import DAG
3636

37+
if AIRFLOW_V_3_0_PLUS:
38+
from airflow.serialization.serialized_objects import DagSerialization
39+
else:
40+
from airflow.serialization.serialized_objects import SerializedDAG
41+
3742
DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
3843
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)
3944

@@ -57,12 +62,20 @@ def sync_dag_to_db(
5762

5863
def _write_dag(dag: DAG) -> SerializedDAG:
5964
if not SerializedDagModel.has_dag(dag.dag_id):
60-
data = SerializedDAG.to_dict(dag)
65+
data = (
66+
DagSerialization.to_dict(dag)
67+
if AIRFLOW_V_3_0_PLUS
68+
else SerializedDAG.to_dict(dag)
69+
)
6170
SerializedDagModel.write_dag(
6271
LazyDeserializedDAG(data=data), bundle_name, session=session
6372
)
6473
session.flush()
65-
return SerializedDAG.from_dict(data)
74+
return (
75+
DagSerialization.from_dict(data)
76+
if AIRFLOW_V_3_0_PLUS
77+
else SerializedDAG.from_dict(data)
78+
)
6679

6780
SerializedDAG.bulk_write_to_db(bundle_name, None, [dag], session=session)
6881
_ = _write_dag(dag)
@@ -106,6 +119,13 @@ def _create_dagrun(
106119
)
107120

108121

122+
def _run_task_instance(ti):
123+
if AIRFLOW_V_3_1_PLUS:
124+
return
125+
126+
ti.run(ignore_ti_state=True)
127+
128+
109129
@pytest.fixture(scope="session")
110130
def dagbag():
111131
"""An Airflow DagBag."""
@@ -240,7 +260,7 @@ def test_dbt_operators_in_dag(
240260
ti = dagrun.get_task_instance(task_id=task_id)
241261
ti.task = basic_dag.get_task(task_id=task_id)
242262

243-
ti.run(ignore_ti_state=True)
263+
_run_task_instance(ti)
244264

245265
assert ti.state == TaskInstanceState.SUCCESS
246266

@@ -364,7 +384,7 @@ def test_dbt_operators_in_taskflow_dag(
364384
ti = dagrun.get_task_instance(task_id=task_id)
365385
ti.task = dag.get_task(task_id=task_id)
366386

367-
ti.run(ignore_ti_state=True)
387+
_run_task_instance(ti)
368388

369389
assert ti.state == TaskInstanceState.SUCCESS
370390
assert ti.task.retries == dag.default_args["retries"]
@@ -376,8 +396,9 @@ def test_dbt_operators_in_taskflow_dag(
376396
assert failure_callback == dag.default_args["on_failure_callback"]
377397

378398
if isinstance(ti.task, DbtBaseOperator):
379-
assert ti.task.profiles_dir == str(profiles_file.parent)
380-
assert ti.task.project_dir == str(dbt_project_file.parent)
399+
if not AIRFLOW_V_3_1_PLUS:
400+
assert ti.task.profiles_dir == str(profiles_file.parent)
401+
assert ti.task.project_dir == str(dbt_project_file.parent)
381402

382403
results = ti.xcom_pull(
383404
task_ids=task_id,
@@ -493,7 +514,7 @@ def test_dbt_operators_in_connection_dag(
493514
ti = dagrun.get_task_instance(task_id=task_id)
494515
ti.task = target_connection_dag.get_task(task_id=task_id)
495516

496-
ti.run(ignore_ti_state=True)
517+
_run_task_instance(ti)
497518

498519
assert ti.state == TaskInstanceState.SUCCESS
499520

@@ -568,7 +589,7 @@ def test_example_basic_dag(
568589
ti = dagrun.get_task_instance(task_id="dbt_run_hourly")
569590
ti.task = dbt_run
570591

571-
ti.run(ignore_ti_state=True)
592+
_run_task_instance(ti)
572593

573594
assert ti.state == TaskInstanceState.SUCCESS
574595

@@ -612,6 +633,9 @@ def test_example_dbt_project_in_github_dag(
612633
if AIRFLOW_V_3_0:
613634
dag = DAG.from_sdk_dag(dag) # type: ignore
614635

636+
for task_id in ("dbt_seed", "dbt_run", "dbt_test"):
637+
dag.get_task(task_id=task_id).dbt_conn_id = connection
638+
615639
dagrun = _create_dagrun(
616640
dag,
617641
state=DagRunState.RUNNING,
@@ -626,7 +650,7 @@ def test_example_dbt_project_in_github_dag(
626650
ti.task = dag.get_task(task_id=task_id)
627651
ti.task.dbt_conn_id = connection
628652

629-
ti.run(ignore_ti_state=True)
653+
_run_task_instance(ti)
630654

631655
assert ti.state == TaskInstanceState.SUCCESS
632656

@@ -669,6 +693,12 @@ def test_example_complete_dbt_workflow_dag(
669693
if AIRFLOW_V_3_0:
670694
dag = DAG.from_sdk_dag(dag) # type: ignore
671695

696+
for task in dag.tasks:
697+
task.project_dir = dbt_project_file.parent
698+
task.profiles_dir = profiles_file.parent
699+
task.target = "test"
700+
task.profile = "default"
701+
672702
dagrun = _create_dagrun(
673703
dag,
674704
state=DagRunState.RUNNING,
@@ -679,15 +709,10 @@ def test_example_complete_dbt_workflow_dag(
679709
)
680710

681711
for task in dag.tasks:
682-
task.project_dir = dbt_project_file.parent
683-
task.profiles_dir = profiles_file.parent
684-
task.target = "test"
685-
task.profile = "default"
686-
687712
ti = dagrun.get_task_instance(task_id=task.task_id)
688713
ti.task = task
689714

690-
ti.run(ignore_ti_state=True)
715+
_run_task_instance(ti)
691716

692717
assert ti.state == TaskInstanceState.SUCCESS
693718

0 commit comments

Comments
 (0)