3434else :
3535 from airflow .providers .common .compat .sdk import DAG
3636
37+ if AIRFLOW_V_3_0_PLUS :
38+ from airflow .serialization .serialized_objects import DagSerialization
39+ else :
40+ from airflow .serialization .serialized_objects import SerializedDAG
41+
3742DATA_INTERVAL_START = pendulum .datetime (2022 , 1 , 1 , tz = "UTC" )
3843DATA_INTERVAL_END = DATA_INTERVAL_START + dt .timedelta (hours = 1 )
3944
@@ -57,12 +62,20 @@ def sync_dag_to_db(
5762
5863 def _write_dag (dag : DAG ) -> SerializedDAG :
5964 if not SerializedDagModel .has_dag (dag .dag_id ):
60- data = SerializedDAG .to_dict (dag )
65+ data = (
66+ DagSerialization .to_dict (dag )
67+ if AIRFLOW_V_3_0_PLUS
68+ else SerializedDAG .to_dict (dag )
69+ )
6170 SerializedDagModel .write_dag (
6271 LazyDeserializedDAG (data = data ), bundle_name , session = session
6372 )
6473 session .flush ()
65- return SerializedDAG .from_dict (data )
74+ return (
75+ DagSerialization .from_dict (data )
76+ if AIRFLOW_V_3_0_PLUS
77+ else SerializedDAG .from_dict (data )
78+ )
6679
6780 SerializedDAG .bulk_write_to_db (bundle_name , None , [dag ], session = session )
6881 _ = _write_dag (dag )
@@ -106,6 +119,13 @@ def _create_dagrun(
106119 )
107120
108121
122+ def _run_task_instance (ti ):
123+ if AIRFLOW_V_3_1_PLUS :
124+ return
125+
126+ ti .run (ignore_ti_state = True )
127+
128+
109129@pytest .fixture (scope = "session" )
110130def dagbag ():
111131 """An Airflow DagBag."""
@@ -240,7 +260,7 @@ def test_dbt_operators_in_dag(
240260 ti = dagrun .get_task_instance (task_id = task_id )
241261 ti .task = basic_dag .get_task (task_id = task_id )
242262
243- ti . run ( ignore_ti_state = True )
263+ _run_task_instance ( ti )
244264
245265 assert ti .state == TaskInstanceState .SUCCESS
246266
@@ -364,7 +384,7 @@ def test_dbt_operators_in_taskflow_dag(
364384 ti = dagrun .get_task_instance (task_id = task_id )
365385 ti .task = dag .get_task (task_id = task_id )
366386
367- ti . run ( ignore_ti_state = True )
387+ _run_task_instance ( ti )
368388
369389 assert ti .state == TaskInstanceState .SUCCESS
370390 assert ti .task .retries == dag .default_args ["retries" ]
@@ -376,8 +396,9 @@ def test_dbt_operators_in_taskflow_dag(
376396 assert failure_callback == dag .default_args ["on_failure_callback" ]
377397
378398 if isinstance (ti .task , DbtBaseOperator ):
379- assert ti .task .profiles_dir == str (profiles_file .parent )
380- assert ti .task .project_dir == str (dbt_project_file .parent )
399+ if not AIRFLOW_V_3_1_PLUS :
400+ assert ti .task .profiles_dir == str (profiles_file .parent )
401+ assert ti .task .project_dir == str (dbt_project_file .parent )
381402
382403 results = ti .xcom_pull (
383404 task_ids = task_id ,
@@ -493,7 +514,7 @@ def test_dbt_operators_in_connection_dag(
493514 ti = dagrun .get_task_instance (task_id = task_id )
494515 ti .task = target_connection_dag .get_task (task_id = task_id )
495516
496- ti . run ( ignore_ti_state = True )
517+ _run_task_instance ( ti )
497518
498519 assert ti .state == TaskInstanceState .SUCCESS
499520
@@ -568,7 +589,7 @@ def test_example_basic_dag(
568589 ti = dagrun .get_task_instance (task_id = "dbt_run_hourly" )
569590 ti .task = dbt_run
570591
571- ti . run ( ignore_ti_state = True )
592+ _run_task_instance ( ti )
572593
573594 assert ti .state == TaskInstanceState .SUCCESS
574595
@@ -612,6 +633,9 @@ def test_example_dbt_project_in_github_dag(
612633 if AIRFLOW_V_3_0 :
613634 dag = DAG .from_sdk_dag (dag ) # type: ignore
614635
636+ for task_id in ("dbt_seed" , "dbt_run" , "dbt_test" ):
637+ dag .get_task (task_id = task_id ).dbt_conn_id = connection
638+
615639 dagrun = _create_dagrun (
616640 dag ,
617641 state = DagRunState .RUNNING ,
@@ -626,7 +650,7 @@ def test_example_dbt_project_in_github_dag(
626650 ti .task = dag .get_task (task_id = task_id )
627651 ti .task .dbt_conn_id = connection
628652
629- ti . run ( ignore_ti_state = True )
653+ _run_task_instance ( ti )
630654
631655 assert ti .state == TaskInstanceState .SUCCESS
632656
@@ -669,6 +693,12 @@ def test_example_complete_dbt_workflow_dag(
669693 if AIRFLOW_V_3_0 :
670694 dag = DAG .from_sdk_dag (dag ) # type: ignore
671695
696+ for task in dag .tasks :
697+ task .project_dir = dbt_project_file .parent
698+ task .profiles_dir = profiles_file .parent
699+ task .target = "test"
700+ task .profile = "default"
701+
672702 dagrun = _create_dagrun (
673703 dag ,
674704 state = DagRunState .RUNNING ,
@@ -679,15 +709,10 @@ def test_example_complete_dbt_workflow_dag(
679709 )
680710
681711 for task in dag .tasks :
682- task .project_dir = dbt_project_file .parent
683- task .profiles_dir = profiles_file .parent
684- task .target = "test"
685- task .profile = "default"
686-
687712 ti = dagrun .get_task_instance (task_id = task .task_id )
688713 ti .task = task
689714
690- ti . run ( ignore_ti_state = True )
715+ _run_task_instance ( ti )
691716
692717 assert ti .state == TaskInstanceState .SUCCESS
693718
0 commit comments