
Commit 6003a54

[dagster-airlift] filter (#28606)
## Summary & Motivation

Adds a simple filter abstraction that limits the number of DAGs retrieved up front, for performance reasons.

## How I Tested These Changes

## Changelog

- Added an `AirflowFilter` API for use with `dagster-airlift`, which allows you to filter down the set of DAGs retrieved up front for performance improvements.
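A hedged sketch of the new surface area (the auth backend arguments and the instance name below are illustrative, not from this commit; `AirflowBasicAuthBackend` is the existing basic-auth backend in `dagster_airlift.core`):

```python
from dagster_airlift.core import (
    AirflowBasicAuthBackend,
    AirflowInstance,
    build_defs_from_airflow_instance,
)
from dagster_airlift.core.filter import AirflowFilter

airflow_instance = AirflowInstance(
    auth_backend=AirflowBasicAuthBackend(
        webserver_url="http://localhost:8080",  # illustrative URL
        username="admin",  # illustrative credentials
        password="admin",
    ),
    name="my_instance",  # illustrative name
)

# Instead of listing every DAG in the instance up front, only retrieve DAGs
# whose dag_id ILIKE-matches the pattern and which carry the "analytics" tag.
defs = build_defs_from_airflow_instance(
    airflow_instance=airflow_instance,
    retrieval_filter=AirflowFilter(dag_id_ilike="warehouse", airflow_tags=["analytics"]),
)
```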
1 parent 71f05ad commit 6003a54

14 files changed: +251 -17 lines changed

Diff for: docs/docs/guides/migrate/airflow-to-dagster/federation/observe.md

+1-1
@@ -22,7 +22,7 @@ For a full list of `dagster-airlift` classes and methods, see the [API docs](htt
 
 ## Observe the `warehouse` Airflow instance
 
-Next, declare a reference to the `warehouse` Airflow instance, which is running at `http://localhost:8081`:
+Next, in your `airlift_federation_tutorial/dagster_defs/definitions.py` file, declare a reference to the `warehouse` Airflow instance, which is running at `http://localhost:8081`:
 
 <CodeExample
   path="airlift-federation-tutorial/snippets/observe.py"

Diff for: examples/airlift-federation-tutorial/Makefile

+2-1
@@ -27,7 +27,8 @@ help:
 airflow_install:
 	pip install uv && \
 	uv pip install dagster-airlift[tutorial] && \
-	uv pip install -e $(MAKEFILE_DIR)
+	uv pip install -e $(MAKEFILE_DIR) && \
+	uv pip install duckdb pandas
 
 airflow_setup: wipe
 	mkdir -p $$WAREHOUSE_AIRFLOW_HOME

Diff for: examples/airlift-federation-tutorial/scripts/airflow_setup.sh

+102
@@ -17,6 +17,9 @@ if [[ "$DAGS_FOLDER" != /* ]] || [[ "$AIRFLOW_HOME_DIR" != /* ]]; then
   exit 1
 fi
 
+# Generate a unique secret key
+SECRET_KEY=$(openssl rand -hex 30)
+
 # Create the airflow.cfg file in the specified AIRFLOW_HOME_DIR
 cat <<EOL > $AIRFLOW_HOME_DIR/airflow.cfg
 [core]
@@ -28,7 +31,106 @@ auth_backend = airflow.api.auth.backend.basic_auth
 [webserver]
 expose_config = True
 web_server_port = $PORT
+secret_key = $SECRET_KEY
+
+EOL
+
+# Create the webserver_config.py file
+cat <<EOL > $AIRFLOW_HOME_DIR/webserver_config.py
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+from flask_appbuilder.security.manager import AUTH_DB
+
+# from airflow.www.fab_security.manager import AUTH_LDAP
+
+basedir = os.path.abspath(os.path.dirname(__file__))
+
+# The SQLAlchemy connection string.
+SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(basedir, 'webserver.db')
+
+# Flask-WTF flag for CSRF
+WTF_CSRF_ENABLED = True
+
+# Use a unique cookie name for this instance based on port
+SESSION_COOKIE_NAME = "airflow_session_${PORT}"
+
+# ----------------------------------------------------
+# AUTHENTICATION CONFIG
+# ----------------------------------------------------
+# For details on how to set up each of the following authentication, see
+# http://flask-appbuilder.readthedocs.io/en/latest/security.html
+
+# The authentication type
+# AUTH_OID : Is for OpenID
+# AUTH_DB : Is for database
+# AUTH_LDAP : Is for LDAP
+# AUTH_REMOTE_USER : Is for using REMOTE_USER from web server
+# AUTH_OAUTH : Is for OAuth
+AUTH_TYPE = AUTH_DB
+
+# When using LDAP Auth, setup the ldap server
+# LDAP_SERVER = "ldap://ldapserver.new"
+# LDAP_PORT = 389
+# LDAP_USE_TLS = False
+# LDAP_SEARCH_SCOPE = "LEVEL"
+# LDAP_BIND_USER = "uid=admin,ou=users,dc=example,dc=com"
+# LDAP_BIND_PASSWORD = "admin_password"
+# LDAP_BASEDN = "dc=example,dc=com"
+# LDAP_USER_DN = "ou=users"
+# LDAP_USER_FILTER = "(uid=%s)"
+# LDAP_GROUP_DN = "ou=groups"
+# LDAP_GROUP_FILTER = "(member=%s)"
+# LDAP_USER_NAME_FORMAT = "uid=%s,ou=users,dc=example,dc=com"
+# LDAP_GROUP_NAME_FORMAT = "cn=%s,ou=groups,dc=example,dc=com"
+
+# Uncomment to setup Full admin role name
+# AUTH_ROLE_ADMIN = 'Admin'
+
+# Uncomment to setup Public role name, no authentication needed
+# AUTH_ROLE_PUBLIC = 'Public'
+
+# Will allow user self registration
+# AUTH_USER_REGISTRATION = True
+
+# The default user self registration role
+# AUTH_USER_REGISTRATION_ROLE = "Public"
 
+# When using OAuth Auth, uncomment to setup provider(s) info
+# Google OAuth example:
+# OAUTH_PROVIDERS = [{
+#   'name':'google',
+#   'token_key':'access_token',
+#   'icon':'fa-google',
+#   'remote_app': {
+#     'api_base_url':'https://www.googleapis.com/oauth2/v2/',
+#     'client_kwargs':{
+#       'scope': 'email profile'
+#     },
+#     'access_token_url':'https://accounts.google.com/o/oauth2/token',
+#     'authorize_url':'https://accounts.google.com/o/oauth2/auth',
+#     'request_token_url': None,
+#     'client_id': GOOGLE_KEY,
+#     'client_secret': GOOGLE_SECRET_KEY,
+#   }
+# }]
 EOL
 
 # call airflow command to create the default user
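A likely motivation for these additions, stated here as an inference from the tutorial setup rather than from the commit message: the federation tutorial runs two Airflow webservers on different localhost ports, and browser session cookies are scoped by host rather than port, so giving each instance its own `secret_key` and a port-specific `SESSION_COOKIE_NAME` keeps the two logins from invalidating each other.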

Diff for: python_modules/libraries/dagster-airlift/dagster_airlift/core/airflow_instance.py

+11-2
@@ -12,6 +12,7 @@
 from dagster._record import record
 from dagster._time import get_current_datetime
 
+from dagster_airlift.core.filter import AirflowFilter
 from dagster_airlift.core.serialization.serialized_data import DagInfo, TaskInfo
 
 TERMINAL_STATES = {"success", "failed", "canceled"}
@@ -79,13 +80,21 @@ def normalized_name(self) -> str:
     def get_api_url(self) -> str:
         return f"{self.auth_backend.get_webserver_url()}/api/v1"
 
-    def list_dags(self) -> list["DagInfo"]:
+    def list_dags(self, retrieval_filter: Optional[AirflowFilter] = None) -> list["DagInfo"]:
+        retrieval_filter = retrieval_filter or AirflowFilter()
         dag_responses = []
         webserver_url = self.auth_backend.get_webserver_url()
+
         while True:
+            params = retrieval_filter.augment_request_params(
+                {
+                    "limit": self.dag_list_limit,
+                    "offset": len(dag_responses),
+                }
+            )
             response = self.auth_backend.get_session().get(
                 f"{self.get_api_url()}/dags",
-                params={"limit": self.dag_list_limit, "offset": len(dag_responses)},
+                params=params,
             )
             if response.status_code == 200:
                 dags = response.json()
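Callers of `list_dags` can now narrow what is fetched server-side while pagination continues to work as before. A minimal sketch, assuming an already-constructed `AirflowInstance` named `airflow_instance`:

```python
from dagster_airlift.core.filter import AirflowFilter

# Only paginate through DAGs whose dag_id ILIKE-matches the pattern; passing
# no filter (or a bare AirflowFilter()) preserves the old list-everything behavior.
dags = airflow_instance.list_dags(retrieval_filter=AirflowFilter(dag_id_ilike="warehouse"))
```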
Diff for: python_modules/libraries/dagster-airlift/dagster_airlift/core/filter.py

+25

@@ -0,0 +1,25 @@
+from collections.abc import Sequence
+from typing import Optional
+
+from dagster._record import record
+
+
+@record
+class AirflowFilter:
+    """Filters the set of Airflow objects to fetch.
+
+    Args:
+        dag_id_ilike (Optional[str]): A pattern used to match the set of dag_ids to retrieve. Uses the SQL ILIKE operator Airflow-side.
+        airflow_tags (Optional[Sequence[str]]): Filters down to the set of Airflow DAGs which contain the particular tags provided.
+    """
+
+    dag_id_ilike: Optional[str] = None
+    airflow_tags: Optional[Sequence[str]] = None
+
+    def augment_request_params(self, request_params: dict) -> dict:
+        new_request_params = request_params.copy()
+        if self.dag_id_ilike is not None:
+            new_request_params["dag_id_pattern"] = self.dag_id_ilike
+        if self.airflow_tags is not None:
+            new_request_params["tags"] = self.airflow_tags
+        return new_request_params
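The added keys line up with the `dag_id_pattern` and `tags` query parameters of the Airflow REST API's `GET /dags` endpoint. A small illustration of the merge behavior, assuming a page size of 100:

```python
from dagster_airlift.core.filter import AirflowFilter

retrieval_filter = AirflowFilter(dag_id_ilike="warehouse", airflow_tags=["analytics"])

# augment_request_params copies the input dict rather than mutating it,
# then layers the filter fields on top.
params = retrieval_filter.augment_request_params({"limit": 100, "offset": 0})
assert params == {
    "limit": 100,
    "offset": 0,
    "dag_id_pattern": "warehouse",
    "tags": ["analytics"],
}
```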

Diff for: python_modules/libraries/dagster-airlift/dagster_airlift/core/load_defs.py

+9
@@ -15,6 +15,7 @@
 
 from dagster_airlift.core.airflow_defs_data import MappedAsset
 from dagster_airlift.core.airflow_instance import AirflowInstance
+from dagster_airlift.core.filter import AirflowFilter
 from dagster_airlift.core.sensor.event_translation import (
     DagsterEventTransformerFn,
     default_event_transformer,
@@ -38,6 +39,7 @@
 @dataclass
 class AirflowInstanceDefsLoader(StateBackedDefinitionsLoader[SerializedAirflowDefinitionsData]):
     airflow_instance: AirflowInstance
+    retrieval_filter: AirflowFilter
     mapped_assets: Sequence[MappedAsset]
     source_code_retrieval_enabled: Optional[bool]
     sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS
@@ -54,6 +56,7 @@ def fetch_state(self) -> SerializedAirflowDefinitionsData:
             dag_selector_fn=self.dag_selector_fn,
             automapping_enabled=False,
             source_code_retrieval_enabled=self.source_code_retrieval_enabled,
+            retrieval_filter=self.retrieval_filter,
         )
 
     def defs_from_state(  # pyright: ignore[reportIncompatibleMethodOverride]
@@ -77,6 +80,7 @@ def build_defs_from_airflow_instance(
     dag_selector_fn: Optional[Callable[[DagInfo], bool]] = None,
     source_code_retrieval_enabled: Optional[bool] = None,
     default_sensor_status: Optional[DefaultSensorStatus] = None,
+    retrieval_filter: Optional[AirflowFilter] = None,
 ) -> Definitions:
     """Builds a :py:class:`dagster.Definitions` object from an Airflow instance.
 
@@ -221,6 +225,7 @@ def only_include_dag(dag_info: DagInfo) -> bool:
         mapped_assets=mapped_assets,
         dag_selector_fn=dag_selector_fn,
         source_code_retrieval_enabled=source_code_retrieval_enabled,
+        retrieval_filter=retrieval_filter or AirflowFilter(),
     ).get_or_fetch_state()
     mapped_and_constructed_assets = [
         *_apply_airflow_data_to_specs(mapped_assets, serialized_airflow_data),
@@ -292,12 +297,14 @@ def enrich_airflow_mapped_assets(
     mapped_assets: Sequence[MappedAsset],
     airflow_instance: AirflowInstance,
     source_code_retrieval_enabled: Optional[bool],
+    retrieval_filter: Optional[AirflowFilter] = None,
 ) -> Sequence[AssetsDefinition]:
     """Enrich Airflow-mapped assets with metadata from the provided :py:class:`AirflowInstance`."""
     serialized_data = AirflowInstanceDefsLoader(
         airflow_instance=airflow_instance,
         mapped_assets=mapped_assets,
         source_code_retrieval_enabled=source_code_retrieval_enabled,
+        retrieval_filter=retrieval_filter or AirflowFilter(),
     ).get_or_fetch_state()
     return list(_apply_airflow_data_to_specs(mapped_assets, serialized_data))
 
@@ -308,12 +315,14 @@ def load_airflow_dag_asset_specs(
     mapped_assets: Optional[Sequence[MappedAsset]] = None,
     dag_selector_fn: Optional[Callable[[DagInfo], bool]] = None,
     source_code_retrieval_enabled: Optional[bool] = None,
+    retrieval_filter: Optional[AirflowFilter] = None,
 ) -> Sequence[AssetSpec]:
     """Load asset specs for Airflow DAGs from the provided :py:class:`AirflowInstance`, and link upstreams from mapped assets."""
     serialized_data = AirflowInstanceDefsLoader(
         airflow_instance=airflow_instance,
         mapped_assets=mapped_assets or [],
         dag_selector_fn=dag_selector_fn,
         source_code_retrieval_enabled=source_code_retrieval_enabled,
+        retrieval_filter=retrieval_filter or AirflowFilter(),
     ).get_or_fetch_state()
     return list(spec_iterator(construct_dag_assets_defs(serialized_data)))
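All three public entry points default to an empty `AirflowFilter()`, i.e. no filtering. A hedged sketch of the spec-loading path, again assuming an existing `airflow_instance`:

```python
from dagster_airlift.core import load_airflow_dag_asset_specs
from dagster_airlift.core.filter import AirflowFilter

# Load asset specs only for DAGs tagged "analytics" rather than for every
# DAG in the instance; upstream-mapping behavior is unchanged.
specs = load_airflow_dag_asset_specs(
    airflow_instance=airflow_instance,
    retrieval_filter=AirflowFilter(airflow_tags=["analytics"]),
)
```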

Diff for: python_modules/libraries/dagster-airlift/dagster_airlift/core/serialization/compute.py

+9-2
@@ -8,6 +8,7 @@
 
 from dagster_airlift.core.airflow_instance import AirflowInstance, DagInfo
 from dagster_airlift.core.dag_asset import get_leaf_assets_for_dag
+from dagster_airlift.core.filter import AirflowFilter
 from dagster_airlift.core.serialization.serialized_data import (
     DagHandle,
     KeyScopedDagHandles,
@@ -144,10 +145,11 @@ def fetch_all_airflow_data(
     mapping_info: AirliftMetadataMappingInfo,
     dag_selector_fn: Optional[DagSelectorFn],
     automapping_enabled: bool,
+    retrieval_filter: AirflowFilter,
 ) -> FetchedAirflowData:
     dag_infos = {
         dag.dag_id: dag
-        for dag in airflow_instance.list_dags()
+        for dag in airflow_instance.list_dags(retrieval_filter=retrieval_filter)
         if dag_selector_fn is None or dag_selector_fn(dag)
     }
     # To limit the number of API calls, only fetch task infos for the dags that we absolutely have to.
@@ -185,10 +187,15 @@ def compute_serialized_data(
     dag_selector_fn: Optional[DagSelectorFn],
     automapping_enabled: bool,
     source_code_retrieval_enabled: Optional[bool],
+    retrieval_filter: AirflowFilter,
 ) -> "SerializedAirflowDefinitionsData":
     mapping_info = build_airlift_metadata_mapping_info(mapped_assets)
     fetched_airflow_data = fetch_all_airflow_data(
-        airflow_instance, mapping_info, dag_selector_fn, automapping_enabled=automapping_enabled
+        airflow_instance,
+        mapping_info,
+        dag_selector_fn,
+        automapping_enabled=automapping_enabled,
+        retrieval_filter=retrieval_filter,
    )
     source_code_retrieval_enabled = infer_code_retrieval_enabled(
         source_code_retrieval_enabled, fetched_airflow_data
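Note that two layers of filtering coexist in `fetch_all_airflow_data`: `retrieval_filter` narrows the set of DAGs requested from the Airflow API, while `dag_selector_fn` still runs client-side over whatever comes back. A small sketch of combining them (the denylist is illustrative):

```python
from dagster_airlift.core.filter import AirflowFilter

# Server-side: only fetch DAGs carrying the "analytics" tag.
retrieval_filter = AirflowFilter(airflow_tags=["analytics"])

# Client-side: of the fetched DAGs, drop any whose id is in a local denylist.
denylist = {"analytics_scratch"}

def dag_selector_fn(dag_info) -> bool:
    return dag_info.dag_id not in denylist
```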

Diff for: python_modules/libraries/dagster-airlift/dagster_airlift/test/airflow_test_instance.py

+37-7
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from datetime import datetime, timedelta
 from typing import Any, Optional
 
@@ -8,6 +8,7 @@
 from dagster_airlift.core import AirflowInstance
 from dagster_airlift.core.airflow_instance import DagInfo, DagRun, TaskInfo, TaskInstance
 from dagster_airlift.core.basic_auth import AirflowAuthBackend
+from dagster_airlift.core.filter import AirflowFilter
 
 
 class DummyAuthBackend(AirflowAuthBackend):
@@ -18,6 +19,9 @@ def get_webserver_url(self) -> str:
         return "http://dummy.domain"
 
 
+DEFAULT_FAKE_INSTANCE_NAME = "test_instance"
+
+
 class AirflowInstanceFake(AirflowInstance):
     """Loads a set of provided DagInfo, TaskInfo, and TaskInstance objects for testing."""
 
@@ -50,11 +54,29 @@ def __init__(
         self._max_runs_per_batch = max_runs_per_batch
         super().__init__(
             auth_backend=DummyAuthBackend(),
-            name="test_instance" if instance_name is None else instance_name,
+            name=DEFAULT_FAKE_INSTANCE_NAME if instance_name is None else instance_name,
         )
 
-    def list_dags(self) -> list[DagInfo]:
-        return list(self._dag_infos_by_dag_id.values())
+    def list_dags(self, retrieval_filter: Optional[AirflowFilter] = None) -> list[DagInfo]:
+        retrieval_filter = retrieval_filter or AirflowFilter()
+        # Very basic filtering for testing purposes
+        dags_to_retrieve = list(self._dag_infos_by_dag_id.values())
+        if retrieval_filter.dag_id_ilike:
+            dags_to_retrieve = [
+                dag_info
+                for dag_info in dags_to_retrieve
+                if retrieval_filter.dag_id_ilike in dag_info.dag_id
+            ]
+        if retrieval_filter.airflow_tags:
+            dags_to_retrieve = [
+                dag_info
+                for dag_info in dags_to_retrieve
+                if all(
+                    tag in dag_info.metadata.get("tags", [])
+                    for tag in retrieval_filter.airflow_tags
+                )
+            ]
+        return dags_to_retrieve
 
     def list_variables(self) -> list[dict[str, Any]]:
         return self._variables
@@ -150,11 +172,13 @@ def get_dag_source_code(self, file_token: str) -> str:
         return "indicates found source code"
 
 
-def make_dag_info(dag_id: str, file_token: Optional[str]) -> DagInfo:
+def make_dag_info(
+    instance_name: str, dag_id: str, file_token: Optional[str], dag_props: Mapping[str, Any]
+) -> DagInfo:
     return DagInfo(
         webserver_url="http://dummy.domain",
         dag_id=dag_id,
-        metadata={"file_token": file_token if file_token else "dummy_file_token"},
+        metadata={"file_token": file_token if file_token else "dummy_file_token", **dag_props},
     )
 
 
@@ -220,6 +244,7 @@ def make_instance(
     task_deps: dict[str, list[str]] = {},
     instance_name: Optional[str] = None,
     max_runs_per_batch: Optional[int] = None,
+    dag_props: dict[str, Any] = {},
 ) -> AirflowInstanceFake:
     """Constructs DagInfo, TaskInfo, and TaskInstance objects from provided data.
 
@@ -231,7 +256,12 @@
     dag_infos = []
     task_infos = []
     for dag_id, task_ids in dag_and_task_structure.items():
-        dag_info = make_dag_info(dag_id=dag_id, file_token=dag_id)
+        dag_info = make_dag_info(
+            instance_name=instance_name or DEFAULT_FAKE_INSTANCE_NAME,
+            dag_id=dag_id,
+            file_token=dag_id,
+            dag_props=dag_props.get(dag_id, {}),
+        )
         dag_infos.append(dag_info)
         task_infos.extend(
             [
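Taken together with `make_instance`'s new `dag_props` parameter, a test could exercise the fake's filtering roughly like this (a sketch; it assumes `make_instance` is re-exported from `dagster_airlift.test`, where this module lives):

```python
from dagster_airlift.core.filter import AirflowFilter
from dagster_airlift.test import make_instance

instance = make_instance(
    dag_and_task_structure={"warehouse_dag": ["load"], "other_dag": ["noop"]},
    dag_props={"warehouse_dag": {"tags": ["analytics"]}},
)

# The fake does substring matching on dag_id and containment checks on tags.
filtered = instance.list_dags(
    retrieval_filter=AirflowFilter(dag_id_ilike="warehouse", airflow_tags=["analytics"])
)
assert [dag.dag_id for dag in filtered] == ["warehouse_dag"]
```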
