1010# See the License for the specific language governing permissions and
1111# limitations under the License.
1212
13- import logging
1413from uuid import uuid4
1514
1615import airflow .models
1716import time
18- from airflow import LoggingMixin
17+
1918from airflow .contrib .operators .bigquery_operator import BigQueryOperator
2019from airflow .operators .postgres_operator import PostgresOperator
2120from airflow .utils .state import State
3029 DagUtils
3130)
3231
# Handling of import of different airflow versions:
# LoggingMixin moved between modules across releases, so pick the
# import path that matches the installed Airflow.
from pkg_resources import parse_version

from airflow.version import version as AIRFLOW_VERSION

if parse_version(AIRFLOW_VERSION) >= parse_version("1.10.11"):
    from airflow import LoggingMixin
else:
    # Corrects path of import for Airflow versions below 1.10.11
    from airflow.utils.log.logging_mixin import LoggingMixin
40+
3341from marquez_airflow .marquez import Marquez
3442
35- log = logging .getLogger (__name__ )
3643
# Module-level Marquez client shared by every DAG in this process.
_MARQUEZ = Marquez()

# TODO: Manually define operator->extractor mappings for now,
# but we'll want to encapsulate this logic in an 'Extractors' class
# with more convenient methods (ex: 'Extractors.extractor_for_task()')
_EXTRACTORS = {
    PostgresOperator: PostgresExtractor,
    BigQueryOperator: BigQueryExtractor,
    # Append new extractors here
}
3754
38- class DAG (airflow .models .DAG , LoggingMixin ):
39- _job_id_mapping = None
40- _marquez = None
4155
56+ class DAG (airflow .models .DAG , LoggingMixin ):
4257 def __init__ (self , * args , ** kwargs ):
4358 super ().__init__ (* args , ** kwargs )
4459
45- self ._job_id_mapping = JobIdMapping ()
46- self ._marquez = Marquez ()
47- # TODO: Manually define operator->extractor mappings for now,
48- # but we'll want to encapsulate this logic in an 'Extractors' class
49- # with more convenient methods (ex: 'Extractors.extractor_for_task()')
50- self ._extractors = {
51- PostgresOperator : PostgresExtractor ,
52- BigQueryOperator : BigQueryExtractor
53- # Append new extractors here
54- }
55- self .log .debug (
56- f"DAG successfully created with extractors: { self ._extractors } "
57- )
58-
5960 def create_dagrun (self , * args , ** kwargs ):
6061 # run Airflow's create_dagrun() first
6162 dagrun = super (DAG , self ).create_dagrun (* args , ** kwargs )
6263
6364 create_dag_start_ms = self ._now_ms ()
6465 try :
65- self . _marquez .create_namespace ()
66+ _MARQUEZ .create_namespace ()
6667 self ._register_dagrun (
6768 dagrun ,
6869 DagUtils .get_execution_date (** kwargs ),
@@ -83,10 +84,10 @@ def _register_dagrun(self, dagrun, execution_date, run_args):
8384 t = self ._now_ms ()
8485 try :
8586 steps = self ._extract_metadata (dagrun , task )
86- [self . _marquez .create_job (
87+ [_MARQUEZ .create_job (
8788 step , self ._get_location (task ), self .description )
8889 for step in steps ]
89- marquez_jobrun_ids = [self . _marquez .create_run (
90+ marquez_jobrun_ids = [_MARQUEZ .create_run (
9091 self .new_run_id (),
9192 step ,
9293 run_args ,
@@ -95,7 +96,7 @@ def _register_dagrun(self, dagrun, execution_date, run_args):
9596 execution_date ,
9697 self .following_schedule (execution_date ))
9798 ) for step in steps ]
98- self . _job_id_mapping .set (
99+ JobIdMapping .set (
99100 self ._marquez_job_name (self .dag_id , task .task_id ),
100101 dagrun .run_id ,
101102 marquez_jobrun_ids )
@@ -110,7 +111,7 @@ def handle_callback(self, *args, **kwargs):
110111 try :
111112 dagrun = args [0 ]
112113 self .log .debug (f"handle_callback() dagrun : { dagrun } " )
113- self . _marquez .create_namespace ()
114+ _MARQUEZ .create_namespace ()
114115 self ._report_task_instances (
115116 dagrun ,
116117 DagUtils .get_run_args (** kwargs ),
@@ -137,17 +138,17 @@ def _report_task_instances(self, dagrun, run_args, session):
137138
138139 def _report_task_instance (self , ti , dagrun , run_args , session ):
139140 task = self .get_task (ti .task_id )
140- run_ids = self . _job_id_mapping .pop (
141+ run_ids = JobIdMapping .pop (
141142 self ._marquez_job_name_from_ti (ti ), dagrun .run_id , session )
142143 steps = self ._extract_metadata (dagrun , task , ti )
143144
144145 # Note: run_ids could be missing if it was removed from airflow
145146 # or the job could not be registered.
146147 if not run_ids :
147- [self . _marquez .create_job (
148+ [_MARQUEZ .create_job (
148149 step , self ._get_location (task ), self .description )
149- for step in steps ]
150- run_ids = [self . _marquez .create_run (
150+ for step in steps ]
151+ run_ids = [_MARQUEZ .create_run (
151152 self .new_run_id (),
152153 step ,
153154 run_args ,
@@ -159,30 +160,30 @@ def _report_task_instance(self, ti, dagrun, run_args, session):
159160
160161 for step in steps :
161162 for run_id in run_ids :
162- self . _marquez .create_job (
163+ _MARQUEZ .create_job (
163164 step , self ._get_location (task ), self .description ,
164165 ti .state , run_id )
165- self . _marquez .start_run (
166+ _MARQUEZ .start_run (
166167 run_id ,
167168 DagUtils .to_iso_8601 (ti .start_date ))
168169
169170 self .log .debug (f'Setting task state: { ti .state } '
170171 f' for { ti .task_id } ' )
171172 if ti .state in {State .SUCCESS , State .SKIPPED }:
172- self . _marquez .complete_run (
173+ _MARQUEZ .complete_run (
173174 run_id ,
174175 DagUtils .to_iso_8601 (ti .end_date ))
175176 else :
176- self . _marquez .fail_run (
177+ _MARQUEZ .fail_run (
177178 run_id ,
178179 DagUtils .to_iso_8601 (ti .end_date ))
179180
180181 def _extract_metadata (self , dagrun , task , ti = None ):
181182 extractor = self ._get_extractor (task )
182183 task_info = f'task_type={ task .__class__ .__name__ } ' \
183- f'airflow_dag_id={ self .dag_id } ' \
184- f'task_id={ task .task_id } ' \
185- f'airflow_run_id={ dagrun .run_id } '
184+ f'airflow_dag_id={ self .dag_id } ' \
185+ f'task_id={ task .task_id } ' \
186+ f'airflow_run_id={ dagrun .run_id } '
186187 if extractor :
187188 try :
188189 self .log .debug (
@@ -216,13 +217,13 @@ def _extract(self, extractor, task, ti):
216217 return extractor (task ).extract ()
217218
218219 def _get_extractor (self , task ):
219- extractor = self . _extractors .get (task .__class__ )
220- log .debug (f'extractor for { task .__class__ } is { extractor } ' )
220+ extractor = _EXTRACTORS .get (task .__class__ )
221+ self . log .debug (f'extractor for { task .__class__ } is { extractor } ' )
221222 return extractor
222223
223224 def _timed_log_message (self , start_time ):
224225 return f'airflow_dag_id={ self .dag_id } ' \
225- f'duration_ms={ (self ._now_ms () - start_time )} '
226+ f'duration_ms={ (self ._now_ms () - start_time )} '
226227
227228 def new_run_id (self ) -> str :
228229 return str (uuid4 ())
@@ -239,8 +240,6 @@ def _get_location(task):
239240 else :
240241 return get_location (task .dag .fileloc )
241242 except Exception :
242- log .warning (f"Failed to get location for task '{ task .task_id } '." ,
243- exc_info = True )
244243 return None
245244
246245 @staticmethod