-
Notifications
You must be signed in to change notification settings - Fork 723
/
Copy pathbase_node.py
282 lines (233 loc) · 10.1 KB
/
base_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for TFX nodes."""
import abc
import copy
from typing import Any, Dict, Optional, Type
from tfx.dsl.components.base import base_driver
from tfx.dsl.components.base import base_executor
from tfx.dsl.components.base import executor_spec as executor_spec_module
from tfx.dsl.context_managers import dsl_context_registry
from tfx.dsl.experimental.node_execution_options import utils
from tfx.utils import deprecation_utils
from tfx.utils import doc_controls
from tfx.utils import json_utils
from tfx.utils import name_utils
import typing_extensions
def _abstract_property() -> Any:
"""Returns an abstract property for use in an ABC abstract class."""
return abc.abstractmethod(lambda: None)
class BaseNode(json_utils.Jsonable, abc.ABC):
  """Base class for a node in TFX pipeline."""

  def __new__(cls, *args, **kwargs):
    # Record invocation details for tracing. No backwards-compatibility
    # guarantees; for TFX-internal use only.
    result = super(BaseNode, cls).__new__(cls)
    result._CONSTRUCT_CLS = cls
    result._CONSTRUCT_ARGS = args
    result._CONSTRUCT_KWARGS = kwargs
    return result

  def __init__(
      self,
      executor_spec: Optional[executor_spec_module.ExecutorSpec] = None,
      driver_class: Optional[Type[base_driver.BaseDriver]] = None,
  ):
    """Initialize a node.

    Args:
      executor_spec: Optional instance of executor_spec.ExecutorSpec which
        describes how to execute this node (defaults to an empty executor,
        which indicates a no-op).
      driver_class: Optional subclass of base_driver.BaseDriver as a custom
        driver for this node (defaults to base_driver.BaseDriver). Nodes
        usually use the default driver class, but may override it.
    """
    if executor_spec is None:
      executor_spec = executor_spec_module.ExecutorClassSpec(
          base_executor.EmptyExecutor)
    if driver_class is None:
      driver_class = base_driver.BaseDriver
    self.executor_spec = executor_spec
    self.driver_class = driver_class
    # Task-level dependency edges. The add_*/remove_* methods below keep the
    # upstream/downstream relationship symmetric between the two nodes.
    self._upstream_nodes = set()
    self._downstream_nodes = set()
    # User-assigned node id; when unset, the `id` property falls back to the
    # (first non-deprecated) class name.
    self._id = None
    self._node_execution_options: Optional[utils.NodeExecutionOptions] = None
    # Register this node with the innermost active DSL context.
    dsl_context_registry.get().put_node(self)

  @doc_controls.do_not_doc_in_subclasses
  def to_json_dict(self) -> Dict[str, Any]:
    """Convert from an object to a JSON serializable dictionary."""
    # Dependency edges reference other node objects and would make the
    # output non-serializable (and cyclic), so they are excluded.
    return {
        k: v
        for k, v in self.__dict__.items()
        if k not in ('_upstream_nodes', '_downstream_nodes')
    }

  @classmethod
  @doc_controls.do_not_doc_in_subclasses
  def get_class_type(cls) -> str:
    """Returns the full (module-qualified) name of the node class."""
    nondeprecated_class = deprecation_utils.get_first_nondeprecated_class(cls)
    # TODO(b/221166027): Turn strict_check=True once failing tests are fixed.
    return name_utils.get_full_name(nondeprecated_class, strict_check=False)

  @property
  @doc_controls.do_not_doc_in_subclasses
  def type(self) -> str:
    """Full name of the node type; see `get_class_type`."""
    return self.__class__.get_class_type()

  @property
  @deprecation_utils.deprecated(None,
                                'component_type is deprecated, use type instead'
                               )
  @doc_controls.do_not_doc_in_subclasses
  def component_type(self) -> str:
    """Deprecated alias of `type`."""
    return self.type

  @property
  @doc_controls.do_not_doc_in_subclasses
  def id(self) -> str:
    """Node id, unique across all TFX nodes in a pipeline.

    If `id` is set by the user, return it directly.
    Otherwise, return <node_class_name>.

    Returns:
      node id.
    """
    if self._id:
      return self._id
    node_class = deprecation_utils.get_first_nondeprecated_class(self.__class__)
    return node_class.__name__

  @property
  @deprecation_utils.deprecated(None,
                                'component_id is deprecated, use id instead')
  @doc_controls.do_not_doc_in_subclasses
  def component_id(self) -> str:
    """Deprecated alias of `id`."""
    return self.id

  @id.setter
  @doc_controls.do_not_doc_in_subclasses
  def id(self, id: str) -> None:  # noqa: A002
    self._id = id

  # TODO(kmonte): Update this to Self once we're on 3.11 everywhere
  @doc_controls.do_not_doc_in_subclasses
  def with_id(self, id: str) -> typing_extensions.Self:  # noqa: A002
    """Sets the node id and returns the node itself, for call chaining."""
    self._id = id
    return self

  @property
  @abc.abstractmethod
  def inputs(self) -> Dict[str, Any]:
    pass

  @property
  @abc.abstractmethod
  def outputs(self) -> Dict[str, Any]:
    pass

  @property
  @abc.abstractmethod
  def exec_properties(self) -> Dict[str, Any]:
    pass

  @property
  @doc_controls.do_not_doc_in_subclasses
  def upstream_nodes(self):
    """Nodes that must run before this one (task-based dependencies)."""
    return self._upstream_nodes

  @property
  @doc_controls.do_not_doc_in_subclasses
  def node_execution_options(self) -> Optional[utils.NodeExecutionOptions]:
    return self._node_execution_options

  @node_execution_options.setter
  @doc_controls.do_not_doc_in_subclasses
  def node_execution_options(
      self,
      node_execution_options: utils.NodeExecutionOptions
  ):
    # Deep-copied so later caller-side mutation cannot change this node.
    self._node_execution_options = copy.deepcopy(node_execution_options)

  # TODO(kmonte): Update this to Self once we're on 3.11 everywhere
  def with_node_execution_options(
      self, node_execution_options: utils.NodeExecutionOptions
  ) -> typing_extensions.Self:
    """Sets execution options and returns the node itself, for chaining."""
    self.node_execution_options = node_execution_options
    return self

  @doc_controls.do_not_doc_in_subclasses
  def add_upstream_node(self, upstream_node):
    """Experimental: Add another component that must run before this one.

    This method enables task-based dependencies by enforcing execution order for
    synchronous pipelines on supported platforms. Currently, the supported
    platforms are Airflow, Beam, and Kubeflow Pipelines.

    Note that this API call should be considered experimental, and may not work
    with asynchronous pipelines, sub-pipelines and pipelines with conditional
    nodes. We also recommend relying on data for capturing dependencies where
    possible to ensure data lineage is fully captured within MLMD.

    It is symmetric with `add_downstream_node`.

    Args:
      upstream_node: a component that must run before this node.
    """
    self._upstream_nodes.add(upstream_node)
    # Keep the edge symmetric; the membership check prevents infinite mutual
    # recursion between add_upstream_node and add_downstream_node.
    if self not in upstream_node.downstream_nodes:
      upstream_node.add_downstream_node(self)

  @doc_controls.do_not_doc_in_subclasses
  def add_upstream_nodes(self, upstream_nodes):
    """Experimental: Add components that must run before this one.

    This method enables task-based dependencies by enforcing execution order for
    synchronous pipelines on supported platforms. Currently, the supported
    platforms are Airflow, Beam, and Kubeflow Pipelines.

    Note that this API call should be considered experimental, and may not work
    with asynchronous pipelines, sub-pipelines and pipelines with conditional
    nodes. We also recommend relying on data for capturing dependencies where
    possible to ensure data lineage is fully captured within MLMD.

    Args:
      upstream_nodes: a list of components that must run before this node.
    """
    self._upstream_nodes.update(upstream_nodes)
    for upstream_node in upstream_nodes:
      if self not in upstream_node.downstream_nodes:
        upstream_node.add_downstream_node(self)

  @doc_controls.do_not_doc_in_subclasses
  def remove_upstream_node(self, upstream_node):
    """Removes an upstream dependency edge (and its symmetric counterpart)."""
    self._upstream_nodes.remove(upstream_node)
    if self in upstream_node.downstream_nodes:
      upstream_node.remove_downstream_node(self)

  @property
  @doc_controls.do_not_doc_in_subclasses
  def downstream_nodes(self):
    """Nodes that must run after this one (task-based dependencies)."""
    return self._downstream_nodes

  @doc_controls.do_not_doc_in_subclasses
  def add_downstream_node(self, downstream_node):
    """Experimental: Add another component that must run after this one.

    This method enables task-based dependencies by enforcing execution order for
    synchronous pipelines on supported platforms. Currently, the supported
    platforms are Airflow, Beam, and Kubeflow Pipelines.

    Note that this API call should be considered experimental, and may not work
    with asynchronous pipelines, sub-pipelines and pipelines with conditional
    nodes. We also recommend relying on data for capturing dependencies where
    possible to ensure data lineage is fully captured within MLMD.

    It is symmetric with `add_upstream_node`.

    Args:
      downstream_node: a component that must run after this node.
    """
    self._downstream_nodes.add(downstream_node)
    if self not in downstream_node.upstream_nodes:
      downstream_node.add_upstream_node(self)

  @doc_controls.do_not_doc_in_subclasses
  def add_downstream_nodes(self, downstream_nodes):
    """Experimental: Add components that must run after this one.

    This method enables task-based dependencies by enforcing execution order for
    synchronous pipelines on supported platforms. Currently, the supported
    platforms are Airflow, Beam, and Kubeflow Pipelines.

    Note that this API call should be considered experimental, and may not work
    with asynchronous pipelines, sub-pipelines and pipelines with conditional
    nodes. We also recommend relying on data for capturing dependencies where
    possible to ensure data lineage is fully captured within MLMD.

    It is symmetric with `add_upstream_nodes`.

    Args:
      downstream_nodes: a list of components that must run after this node.
    """
    self._downstream_nodes.update(downstream_nodes)
    for downstream_node in downstream_nodes:
      if self not in downstream_node.upstream_nodes:
        downstream_node.add_upstream_node(self)

  @doc_controls.do_not_doc_in_subclasses
  def remove_downstream_node(self, downstream_node):
    """Removes a downstream dependency edge (and its symmetric counterpart)."""
    self._downstream_nodes.remove(downstream_node)
    if self in downstream_node.upstream_nodes:
      downstream_node.remove_upstream_node(self)