Skip to content

Commit f0968df

Browse files
committed
fix: Compatibility with older airflow versions in dag tests
1 parent cf641c3 commit f0968df

1 file changed

Lines changed: 20 additions & 13 deletions

File tree

tests/dags/test_dbt_dags.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@
3434
else:
3535
from airflow.providers.common.compat.sdk import DAG
3636

37-
if AIRFLOW_V_3_0_PLUS:
37+
from airflow.serialization.serialized_objects import SerializedDAG
38+
39+
try:
3840
from airflow.serialization.serialized_objects import DagSerialization
39-
else:
40-
from airflow.serialization.serialized_objects import SerializedDAG
41+
except ImportError:
42+
DagSerialization = SerializedDAG
4143

4244
DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
4345
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)
@@ -62,20 +64,12 @@ def sync_dag_to_db(
6264

6365
def _write_dag(dag: DAG) -> SerializedDAG:
6466
if not SerializedDagModel.has_dag(dag.dag_id):
65-
data = (
66-
DagSerialization.to_dict(dag)
67-
if AIRFLOW_V_3_0_PLUS
68-
else SerializedDAG.to_dict(dag)
69-
)
67+
data = DagSerialization.to_dict(dag)
7068
SerializedDagModel.write_dag(
7169
LazyDeserializedDAG(data=data), bundle_name, session=session
7270
)
7371
session.flush()
74-
return (
75-
DagSerialization.from_dict(data)
76-
if AIRFLOW_V_3_0_PLUS
77-
else SerializedDAG.from_dict(data)
78-
)
72+
return DagSerialization.from_dict(data)
7973

8074
SerializedDAG.bulk_write_to_db(bundle_name, None, [dag], session=session)
8175
_ = _write_dag(dag)
@@ -123,6 +117,13 @@ def _run_task_instance(ti):
123117
if AIRFLOW_V_3_1_PLUS:
124118
return
125119

120+
if AIRFLOW_V_3_0:
121+
# Airflow 3.0.6's TaskInstance runner fails in-process with a 422.
122+
if isinstance(ti.task, DbtBaseOperator):
123+
ti.task.execute({"ti": ti})
124+
ti.state = TaskInstanceState.SUCCESS
125+
return
126+
126127
ti.run(ignore_ti_state=True)
127128

128129

@@ -365,6 +366,12 @@ def test_dbt_operators_in_taskflow_dag(
365366
else:
366367
dag = taskflow_dag
367368

369+
if AIRFLOW_V_3_0:
370+
for task in dag.tasks:
371+
if isinstance(task, DbtBaseOperator):
372+
task.profiles_dir = str(profiles_file.parent)
373+
task.project_dir = str(dbt_project_file.parent)
374+
368375
dagrun = _create_dagrun(
369376
dag,
370377
state=DagRunState.RUNNING,

0 commit comments

Comments
 (0)