[dagster-airlift][migration 1/n] explicit methods for external jobs, use them in airlift #29706

Open · wants to merge 1 commit into base: master

18 changes: 18 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/asset_layer.py
@@ -36,6 +36,7 @@ class AssetLayer(NamedTuple):
check_names_by_asset_key_by_node_handle: Mapping[
NodeHandle, Mapping[AssetKey, AbstractSet[str]]
]
asset_keys: AbstractSet[AssetKey]
outer_node_names_by_asset_key: Mapping[AssetKey, str] = {}

@staticmethod
@@ -131,6 +132,23 @@ def from_graph_and_assets_node_mapping(
node_output_handles_by_asset_check_key=node_output_handles_by_asset_check_key,
check_names_by_asset_key_by_node_handle=check_names_by_asset_key_by_node_handle,
outer_node_names_by_asset_key=outer_node_names_by_asset_key,
asset_keys=set(outer_node_names_by_asset_key.keys()),
)

@staticmethod
def for_external_job(asset_keys: Iterable[AssetKey]) -> "AssetLayer":
from dagster._core.definitions.asset_graph import AssetGraph

return AssetLayer(
asset_graph=AssetGraph.from_assets([]),
assets_defs_by_node_handle={},
asset_keys_by_node_input_handle={},
asset_keys_by_node_output_handle={},
check_key_by_node_output_handle={},
node_output_handles_by_asset_check_key={},
check_names_by_asset_key_by_node_handle={},
outer_node_names_by_asset_key={},
asset_keys=set(asset_keys),
)

@property
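A minimal sketch of what the new constructor produces, assuming the private module path shown above stays importable as `dagster._core.definitions.asset_layer`: every node-indexed mapping is empty, and only the externally declared keys are carried.

    from dagster import AssetKey
    from dagster._core.definitions.asset_layer import AssetLayer

    # Hypothetical keys, for illustration; any Iterable[AssetKey] is accepted.
    layer = AssetLayer.for_external_job(asset_keys=[AssetKey("orders"), AssetKey("users")])

    assert set(layer.asset_keys) == {AssetKey("orders"), AssetKey("users")}
    # No ops back these assets, so every handle-indexed mapping stays empty.
    assert layer.assets_defs_by_node_handle == {}
    assert layer.outer_node_names_by_asset_key == {}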
23 changes: 23 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/job_definition.py
@@ -292,6 +292,28 @@ def dagster_internal_init(
_was_explicitly_provided_resources=_was_explicitly_provided_resources,
)

@staticmethod
def for_external_job(
asset_keys: Iterable[AssetKey],
name: str,
metadata: Optional[Mapping[str, Any]] = None,
tags: Optional[Mapping[str, Any]] = None,
) -> "JobDefinition":
from dagster._core.definitions import op

@op(name=f"{name}_op_inner")
def _op():
pass

return JobDefinition(
graph_def=GraphDefinition(name=name, node_defs=[_op]),
resource_defs={},
executor_def=None,
asset_layer=AssetLayer.for_external_job(asset_keys),
metadata=metadata,
tags=tags,
)

@property
def name(self) -> str:
return self._name
@@ -1321,6 +1343,7 @@ def _infer_asset_layer_from_source_asset_deps(job_graph_def: GraphDefinition) ->
check_names_by_asset_key_by_node_handle={},
check_key_by_node_output_handle={},
outer_node_names_by_asset_key={},
asset_keys=set(),
)


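Mirroring the test added further down, a hedged usage sketch of the new entry point; the name, metadata, and tag values here are hypothetical:

    from dagster import AssetKey, JobDefinition

    external_job = JobDefinition.for_external_job(
        asset_keys=[AssetKey("raw_events")],
        name="ingest_dag",
        metadata={"source": "airflow"},
        tags={"external": "true"},
    )

    # The job wraps a single no-op op; its asset layer carries only the keys.
    assert set(external_job.asset_layer.asset_keys) == {AssetKey("raw_events")}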
(diff continues in another file; filename not shown)
@@ -1642,17 +1642,17 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
# key. This is the node that will be used to populate the AssetNodeSnap. We need to identify
# a primary node because the same asset can be materialized as part of multiple jobs.
primary_node_pairs_by_asset_key: dict[AssetKey, tuple[NodeOutputHandle, JobDefinition]] = {}
job_defs_by_asset_key: dict[AssetKey, list[JobDefinition]] = {}
job_defs_by_asset_key: dict[AssetKey, list[JobDefinition]] = defaultdict(list)
for job_def in repo.get_all_jobs():
asset_layer = job_def.asset_layer
for asset_key in asset_layer.asset_keys:
job_defs_by_asset_key[asset_key].append(job_def)
asset_keys_by_node_output = asset_layer.asset_keys_by_node_output_handle
for node_output_handle, asset_key in asset_keys_by_node_output.items():
if asset_key not in asset_layer.asset_keys_for_node(node_output_handle.node_handle):
continue
if asset_key not in primary_node_pairs_by_asset_key:
primary_node_pairs_by_asset_key[asset_key] = (node_output_handle, job_def)
job_defs_by_asset_key.setdefault(asset_key, []).append(job_def)

asset_node_snaps: list[AssetNodeSnap] = []
asset_graph = repo.asset_graph
for key in sorted(asset_graph.get_all_asset_keys()):
@@ -1680,7 +1680,6 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
pools = {op_def.pool for op_def in op_defs if op_def.pool}
op_names = sorted([str(handle) for handle in node_handles])
op_name = graph_name or next(iter(op_names), None) or node_def.name
job_names = sorted([jd.name for jd in job_defs_by_asset_key[key]])
compute_kind = node_def.tags.get(COMPUTE_KIND_TAG)
node_definition_name = node_def.name

@@ -1698,7 +1697,6 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
pools = set()
op_names = []
op_name = None
job_names = []
compute_kind = None
node_definition_name = None
output_name = None
@@ -1741,7 +1739,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
node_definition_name=node_definition_name,
graph_name=graph_name,
description=asset_node.description,
job_names=job_names,
job_names=sorted([jd.name for jd in job_defs_by_asset_key[key]]),
partitions=(
PartitionsSnap.from_def(asset_node.partitions_def)
if asset_node.partitions_def
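The `defaultdict(list)` switch, together with the new loop over `asset_layer.asset_keys`, means an asset is attributed to every job that declares its key, not only to jobs with a node output for it — which is exactly what external jobs lack. A small self-contained sketch of the grouping pattern, with strings standing in for `JobDefinition` objects and `AssetKey`s:

    from collections import defaultdict

    # Hypothetical jobs and the asset keys each declares, purely for illustration.
    asset_keys_by_job = {
        "materializing_job": {"orders"},
        "external_job": {"orders", "users"},
    }

    job_names_by_asset_key: dict[str, list[str]] = defaultdict(list)
    for job_name, keys in asset_keys_by_job.items():
        for key in keys:
            job_names_by_asset_key[key].append(job_name)

    # "orders" is attributed to both jobs, including the external one that
    # has no node output producing it.
    assert sorted(job_names_by_asset_key["orders"]) == ["external_job", "materializing_job"]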
(diff continues in another file; filename not shown)
@@ -35,6 +35,7 @@
)
from dagster._core.definitions.decorators.asset_check_decorator import asset_check
from dagster._core.definitions.executor_definition import multi_or_in_process_executor
from dagster._core.definitions.metadata.metadata_value import TextMetadataValue
from dagster._core.definitions.partition import PartitionedConfig, StaticPartitionsDefinition
from dagster._core.errors import DagsterInvalidSubsetError
from dagster._loggers import default_loggers
@@ -1484,3 +1485,26 @@ def repo():
AutomationConditionSensorDefinition("a", target=[asset1]),
AutomationConditionSensorDefinition("b", target=[asset1, asset2]),
]


def test_external_job_assets() -> None:
@asset
def my_asset():
pass

my_job = JobDefinition.for_external_job(
asset_keys=[my_asset.key],
name="my_job",
metadata={"foo": "bar"},
tags={"baz": "qux"},
)

assert set(my_job.asset_layer.asset_keys) == {my_asset.key}
assert my_job.metadata == {"foo": TextMetadataValue("bar")}
assert my_job.tags == {"baz": "qux"}

@repository
def repo():
return [my_job, my_asset]

assert repo.assets_defs_by_key[my_asset.key] == my_asset
(diff continues in another file; filename not shown)
@@ -1,14 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Union

from dagster import JobDefinition
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.decorators.job_decorator import job
from dagster._core.definitions.unresolved_asset_job_definition import (
UnresolvedAssetJobDefinition,
define_asset_job,
)

from dagster_airlift.core.dag_asset import dag_asset_metadata
from dagster_airlift.core.serialization.serialized_data import (
@@ -20,37 +14,18 @@

def construct_dag_jobs(
serialized_data: SerializedAirflowDefinitionsData,
mapped_assets: Mapping[str, Sequence[Union[AssetSpec, AssetsDefinition]]],
) -> Sequence[Union[UnresolvedAssetJobDefinition, JobDefinition]]:
mapped_specs: Mapping[str, Sequence[AssetSpec]],
) -> Sequence[JobDefinition]:
"""Constructs a job for each DAG in the serialized data. The job will be used to power runs."""
jobs = []
for dag_id, dag_data in serialized_data.dag_datas.items():
assets_produced_by_dag = mapped_assets.get(dag_id)
if assets_produced_by_dag:
jobs.append(dag_asset_job(dag_data, assets_produced_by_dag))
else:
jobs.append(dag_non_asset_job(dag_data))
return jobs


def dag_asset_job(
dag_data: SerializedDagData, assets: Sequence[Union[AssetsDefinition, AssetSpec]]
) -> UnresolvedAssetJobDefinition:
specs: list[AssetSpec] = []
for asset in assets:
if not isinstance(asset, AssetSpec):
raise Exception(
"Fully resolved assets definition passed to dag job creation not yet supported."
)
specs.append(asset)
# Eventually we'll have to handle fully resolved AssetsDefinition objects here but it's a whole
# can of worms. For now, we enforce that only assetSpec objects are passed in.
return define_asset_job(
name=job_name(dag_data.dag_id),
metadata=dag_asset_metadata(dag_data.dag_info),
tags=airflow_job_tags(dag_data.dag_id),
selection=[asset.key for asset in specs],
)
return [
JobDefinition.for_external_job(
asset_keys=[spec.key for spec in mapped_specs[dag_id]],
name=job_name(dag_id),
metadata=dag_asset_metadata(dag_data.dag_info),
tags=airflow_job_tags(dag_id),
)
for dag_id, dag_data in serialized_data.dag_datas.items()
]


def job_name(dag_id: str) -> str:
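Judging from the tests further down (jobs are fetched as `repo.get_job("producer1")`), `job_name(dag_id)` appears to resolve to the dag id itself. Under that assumption, the per-DAG construction reduces to a single call like this sketch, with the metadata and tags helpers stubbed out:

    from dagster import AssetKey, AssetSpec, JobDefinition

    # Hypothetical per-DAG specs, shaped like the mapped_specs argument.
    mapped_specs = {"ingest_dag": [AssetSpec("raw_events")]}

    job = JobDefinition.for_external_job(
        asset_keys=[spec.key for spec in mapped_specs["ingest_dag"]],
        name="ingest_dag",  # the real call uses job_name(dag_id)
        metadata={},        # the real call uses dag_asset_metadata(dag_data.dag_info)
        tags={},            # the real call uses airflow_job_tags(dag_id)
    )
    assert set(job.asset_layer.asset_keys) == {AssetKey("raw_events")}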
(diff continues in another file; filename not shown)
@@ -7,7 +7,6 @@
from dagster._annotations import beta
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_spec import map_asset_specs
from dagster._core.definitions.decorators.asset_decorator import multi_asset
from dagster._core.definitions.definitions_load_context import StateBackedDefinitionsLoader
from dagster._core.definitions.external_asset import external_asset_from_spec
from dagster._core.definitions.sensor_definition import DefaultSensorStatus
@@ -368,35 +367,20 @@ def construct_dataset_specs(
)


def _get_dag_to_asset_mapping(
mapped_assets: Sequence[AssetSpec],
def _get_dag_to_spec_mapping(
mapped_assets: Sequence[Union[AssetSpec, AssetsDefinition]],
) -> Mapping[str, Sequence[Union[AssetSpec, AssetsDefinition]]]:
res = defaultdict(list)
for asset in mapped_assets:
if is_task_mapped_asset_spec(asset):
for task_handle in task_handles_for_spec(asset):
res[task_handle.dag_id].append(asset)
elif is_dag_mapped_asset_spec(asset):
for dag_handle in dag_handles_for_spec(asset):
res[dag_handle.dag_id].append(asset)
for spec in spec_iterator(mapped_assets):
if is_task_mapped_asset_spec(spec):
for task_handle in task_handles_for_spec(spec):
res[task_handle.dag_id].append(spec)
elif is_dag_mapped_asset_spec(spec):
for dag_handle in dag_handles_for_spec(spec):
res[dag_handle.dag_id].append(spec)
return res


def _global_assets_def(
specs: Sequence[AssetSpec],
instance_name: str,
) -> AssetsDefinition:
@multi_asset(
specs=specs,
name=f"{instance_name}_global_assets_def",
can_subset=True,
)
def _global_assets():
pass

return _global_assets


def build_job_based_airflow_defs(
*,
airflow_instance: AirflowInstance,
@@ -411,29 +395,20 @@ def build_job_based_airflow_defs(
source_code_retrieval_enabled=True,
retrieval_filter=AirflowFilter(),
).get_or_fetch_state()
assets_with_airflow_data = cast(
"Sequence[AssetSpec]",
_apply_airflow_data_to_specs(
[
*mapped_assets,
*construct_dataset_specs(serialized_airflow_data),
],
serialized_airflow_data,
),
assets_with_airflow_data = _apply_airflow_data_to_specs(
[
*mapped_assets,
*construct_dataset_specs(serialized_airflow_data),
],
serialized_airflow_data,
)
dag_to_assets_mapping = _get_dag_to_asset_mapping(assets_with_airflow_data)
dag_to_spec_mapping = _get_dag_to_spec_mapping(assets_with_airflow_data)
jobs = construct_dag_jobs(
serialized_data=serialized_airflow_data,
mapped_assets=dag_to_assets_mapping,
mapped_specs=dag_to_spec_mapping,
)

full_assets_def = _global_assets_def(
specs=[spec for assets in dag_to_assets_mapping.values() for spec in spec_iterator(assets)],
instance_name=airflow_instance.name,
)

return Definitions.merge(
replace_assets_in_defs(defs=mapped_defs, assets=[full_assets_def]),
replace_assets_in_defs(defs=mapped_defs, assets=assets_with_airflow_data),
Definitions(jobs=jobs),
build_airflow_monitoring_defs(airflow_instance=airflow_instance),
)
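The reworked `_get_dag_to_spec_mapping` now accepts both bare `AssetSpec`s and `AssetsDefinition`s, flattening them through `spec_iterator` before grouping by dag id. A minimal sketch of that flattening step, with a hand-rolled iterator standing in for the library helper (it assumes `AssetsDefinition.specs` exposes the definition's specs):

    from collections.abc import Iterable, Iterator
    from typing import Union

    from dagster import AssetSpec, AssetsDefinition, asset


    def iter_specs(assets: Iterable[Union[AssetSpec, AssetsDefinition]]) -> Iterator[AssetSpec]:
        # Stand-in for spec_iterator: definitions contribute all of their
        # specs, bare specs pass through unchanged.
        for a in assets:
            if isinstance(a, AssetsDefinition):
                yield from a.specs
            else:
                yield a


    @asset
    def my_asset():
        pass


    specs = list(iter_specs([AssetSpec("standalone"), my_asset]))
    assert {s.key.to_user_string() for s in specs} == {"standalone", "my_asset"}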
(diff continues in another file; filename not shown)
@@ -21,7 +21,6 @@
from dagster._core.definitions.definitions_class import get_job_from_defs
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.reconstruct import initialize_repository_def_from_pointer
from dagster._core.definitions.unresolved_asset_job_definition import UnresolvedAssetJobDefinition
from dagster._utils.test.definitions import (
definitions,
scoped_reconstruction_metadata,
@@ -840,14 +839,19 @@ def test_load_job_defs() -> None:
key="a", metadata=metadata_for_task_mapping(task_id="producing_task", dag_id="producer1")
)

# Add an additional materializable asset to the same task
@asset(metadata=metadata_for_task_mapping(task_id="producing_task", dag_id="producer1"))
def b():
pass

defs = build_job_based_airflow_defs(
airflow_instance=af_instance,
mapped_defs=Definitions(assets=[spec]),
mapped_defs=Definitions(assets=[spec, b]),
)
Definitions.validate_loadable(defs)
assert isinstance(get_job_from_defs("producer1", defs), UnresolvedAssetJobDefinition)
assert isinstance(get_job_from_defs("producer2", defs), UnresolvedAssetJobDefinition)
assert isinstance(get_job_from_defs("consumer1", defs), UnresolvedAssetJobDefinition)
assert isinstance(get_job_from_defs("producer1", defs), JobDefinition)
assert isinstance(get_job_from_defs("producer2", defs), JobDefinition)
assert isinstance(get_job_from_defs("consumer1", defs), JobDefinition)
assert isinstance(get_job_from_defs("consumer2", defs), JobDefinition)

airflow_defs_data = AirflowDefinitionsData(
Expand All @@ -864,7 +868,7 @@ def test_load_job_defs() -> None:
DagHandle(dag_id="consumer2"): repo.get_job("consumer2"),
}
assert airflow_defs_data.assets_per_job == {
"producer1": {AssetKey("example1"), AssetKey("a")},
"producer1": {AssetKey("example1"), AssetKey("a"), AssetKey("b")},
"producer2": {AssetKey("example1")},
"consumer1": {AssetKey("example2")},
"consumer2": set(),