Skip to content

Commit b829fb2

Browse files
committed
Fix import error display for files with no DAGs in local DAG bundle
1 parent 0cf89fa commit b829fb2

File tree

2 files changed

+248
-30
lines changed

2 files changed

+248
-30
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from collections.abc import Iterable, Sequence
20-
from itertools import groupby
21-
from operator import itemgetter
19+
from collections.abc import Sequence
2220
from typing import Annotated
2321

2422
from fastapi import Depends, HTTPException, status
25-
from sqlalchemy import and_, select
23+
from sqlalchemy import and_, exists, select
2624

2725
from airflow.api_fastapi.app import get_auth_manager
2826
from airflow.api_fastapi.auth.managers.models.batch_apis import IsAuthorizedDagRequest
@@ -80,10 +78,41 @@ def get_import_error(
8078

8179
auth_manager = get_auth_manager()
8280
readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
81+
82+
if error.bundle_name is None or error.filename is None:
83+
raise HTTPException(
84+
status.HTTP_404_NOT_FOUND,
85+
f"The ImportError with import_error_id: `{import_error_id}` has invalid bundle_name or filename",
86+
)
87+
8388
# We need file_dag_ids as a set for intersection, issubset operations
89+
# Check DAGs in the file using relative_fileloc and bundle_name
8490
file_dag_ids = set(
85-
session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == error.filename)).all()
91+
session.scalars(
92+
select(DagModel.dag_id).where(
93+
and_(
94+
DagModel.relative_fileloc == error.filename,
95+
DagModel.bundle_name == error.bundle_name,
96+
)
97+
)
98+
).all()
8699
)
100+
101+
# If no DAGs exist for this file, check if user has access to any DAG in the bundle
102+
if not file_dag_ids:
103+
bundle_dag_ids = set(
104+
session.scalars(select(DagModel.dag_id).where(DagModel.bundle_name == error.bundle_name)).all()
105+
)
106+
readable_bundle_dag_ids = readable_dag_ids.intersection(bundle_dag_ids)
107+
# Can the user read any DAGs in the bundle?
108+
if not readable_bundle_dag_ids:
109+
raise HTTPException(
110+
status.HTTP_403_FORBIDDEN,
111+
"You do not have read permission on any of the DAGs in the bundle",
112+
)
113+
# User has access to bundle, return the error
114+
return error
115+
87116
# Can the user read any DAGs in the file?
88117
if not readable_dag_ids.intersection(file_dag_ids):
89118
raise HTTPException(
@@ -129,24 +158,51 @@ def get_import_errors(
129158
"""Get all import errors."""
130159
auth_manager = get_auth_manager()
131160
readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", user=user)
132-
# Build a cte that fetches dag_ids for each file location
133-
visible_files_cte = (
134-
select(DagModel.relative_fileloc, DagModel.dag_id, DagModel.bundle_name)
161+
162+
# Optimized approach: Use LEFT JOIN + EXISTS to filter at DB level
163+
# This ensures we only fetch authorized import errors and includes errors
164+
# from files with no DAGs when user has access to the bundle.
165+
#
166+
# Build a CTE for visible DAGs (DAGs user can read)
167+
visible_dags_cte = (
168+
select(
169+
DagModel.relative_fileloc,
170+
DagModel.dag_id,
171+
DagModel.bundle_name,
172+
)
135173
.where(DagModel.dag_id.in_(readable_dag_ids))
136-
.cte()
174+
.cte("visible_dags")
137175
)
138176

139-
# Prepare the import errors query by joining with the cte.
140-
# Each returned row will be a tuple: (ParseImportError, dag_id)
177+
# LEFT JOIN ParseImportError with visible DAGs to check file-level access
141178
import_errors_stmt = (
142-
select(ParseImportError, visible_files_cte.c.dag_id)
143-
.join(
144-
visible_files_cte,
179+
select(ParseImportError)
180+
.outerjoin(
181+
visible_dags_cte,
145182
and_(
146-
ParseImportError.filename == visible_files_cte.c.relative_fileloc,
147-
ParseImportError.bundle_name == visible_files_cte.c.bundle_name,
183+
ParseImportError.filename == visible_dags_cte.c.relative_fileloc,
184+
ParseImportError.bundle_name == visible_dags_cte.c.bundle_name,
148185
),
149186
)
187+
.where(
188+
# Include import error if:
189+
# 1. DAG exists for the file AND user has access to it (visible_dags_cte.dag_id IS NOT NULL)
190+
# OR
191+
# 2. No DAG exists for the file BUT user has access to any DAG in the bundle (EXISTS subquery)
192+
(
193+
visible_dags_cte.c.dag_id.is_not(None)
194+
| exists(
195+
select(1).where(
196+
and_(
197+
DagModel.bundle_name == ParseImportError.bundle_name,
198+
DagModel.dag_id.in_(readable_dag_ids),
199+
)
200+
)
201+
)
202+
)
203+
& (ParseImportError.bundle_name.is_not(None))
204+
& (ParseImportError.filename.is_not(None))
205+
)
150206
.order_by(ParseImportError.id)
151207
)
152208

@@ -159,15 +215,45 @@ def get_import_errors(
159215
limit=limit,
160216
session=session,
161217
)
162-
import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(
163-
session.execute(import_errors_select), itemgetter(0)
164-
)
218+
219+
# Get paginated import errors
220+
all_import_errors = session.scalars(import_errors_select).all()
221+
222+
# Build mappings for final permission checks (batch_is_authorized_dag)
223+
# Get all DAGs the user can read, grouped by (bundle_name, relative_fileloc)
224+
visible_dags = session.execute(
225+
select(
226+
DagModel.relative_fileloc,
227+
DagModel.dag_id,
228+
DagModel.bundle_name,
229+
).where(DagModel.dag_id.in_(readable_dag_ids))
230+
).all()
231+
232+
# Group dag_ids by (bundle_name, relative_fileloc) for file-level checks
233+
file_dag_map: dict[tuple[str, str], list[str]] = {}
234+
for relative_fileloc, dag_id, bundle_name in visible_dags:
235+
key = (bundle_name, relative_fileloc)
236+
if key not in file_dag_map:
237+
file_dag_map[key] = []
238+
file_dag_map[key].append(dag_id)
165239

166240
import_errors = []
167-
for import_error, file_dag_ids in import_errors_result:
168-
dag_ids = [dag_id for _, dag_id in file_dag_ids]
169-
dag_id_to_team = DagModel.get_dag_id_to_team_name_mapping(dag_ids, session=session)
241+
for import_error in all_import_errors:
242+
if import_error.bundle_name is None or import_error.filename is None:
243+
continue
244+
245+
key = (import_error.bundle_name, import_error.filename)
246+
dag_ids = file_dag_map.get(key, [])
247+
248+
# If no DAGs exist for this file, it was already filtered by EXISTS subquery
249+
# so we can include it directly
250+
if not dag_ids:
251+
session.expunge(import_error)
252+
import_errors.append(import_error)
253+
continue
254+
170255
# Check if user has read access to all the DAGs defined in the file
256+
dag_id_to_team = DagModel.get_dag_id_to_team_name_mapping(dag_ids, session=session)
171257
requests: Sequence[IsAuthorizedDagRequest] = [
172258
{
173259
"method": "GET",
@@ -180,6 +266,8 @@ def get_import_errors(
180266
import_error.stacktrace = REDACTED_STACKTRACE
181267
import_errors.append(import_error)
182268

269+
# total_entries reflects the count after DB-level filtering (before batch_is_authorized_dag check)
270+
# This is more accurate than the previous in-memory filtering approach
183271
return ImportErrorCollectionResponse(
184272
import_errors=import_errors,
185273
total_entries=total_entries,

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py

Lines changed: 138 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest import mock
2222

2323
import pytest
24+
from sqlalchemy import select
2425

2526
from airflow.api_fastapi.auth.managers.models.resource_details import DagDetails
2627
from airflow.models import DagModel
@@ -236,18 +237,22 @@ def test_should_raises_403_unauthorized(self, unauthorized_test_client, import_e
236237
response = unauthorized_test_client.get(f"/importErrors/{import_error_id}")
237238
assert response.status_code == 403
238239

240+
@pytest.mark.usefixtures("permitted_dag_model")
239241
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
240242
def test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file(
241-
self, mock_get_auth_manager, test_client, import_errors
243+
self, mock_get_auth_manager, test_client, import_errors, permitted_dag_model
242244
):
243245
import_error_id = import_errors[0].id
244-
# Mock auth_manager
245-
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager)
246+
# Mock auth_manager - user has no access to any DAGs
247+
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(
248+
mock_get_auth_manager, set()
249+
)
246250
# Act
247251
response = test_client.get(f"/importErrors/{import_error_id}")
248252
# Assert
249253
mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY)
250254
assert response.status_code == 403
255+
# Since permitted_dag_model exists for FILENAME1, the error message should mention "file"
251256
assert response.json() == {"detail": "You do not have read permission on any of the DAGs in the file"}
252257

253258
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
@@ -364,7 +369,9 @@ def test_get_import_errors(
364369
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, permitted_dag_model_all)
365370
set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager, True)
366371

367-
with assert_queries_count(5):
372+
# Query count: 1 (paginated_select count), 1 (paginated_select), 1 (visible_files_cte),
373+
# 1 (bundle_dag_map), 3 (get_dag_id_to_team_name_mapping for 3 import errors)
374+
with assert_queries_count(7):
368375
response = test_client.get("/importErrors", params=query_params)
369376

370377
assert response.status_code == expected_status_code
@@ -426,8 +433,8 @@ def test_user_can_not_read_all_dags_in_file(
426433
mock_batch_is_authorized_dag = set_mock_auth_manager__batch_is_authorized_dag(
427434
mock_get_auth_manager, batch_is_authorized_dag_return_value
428435
)
429-
# Act
430-
with assert_queries_count(3):
436+
# Query count: 1 (paginated_select count), 1 (paginated_select), 1 (visible_files_cte), 1 (bundle_dag_map)
437+
with assert_queries_count(4):
431438
response = test_client.get("/importErrors")
432439
# Assert
433440
mock_get_authorized_dag_ids.assert_called_once_with(method="GET", user=mock.ANY)
@@ -474,7 +481,10 @@ def test_bundle_name_join_condition_for_import_errors(
474481
response_json = response.json()
475482

476483
# Should return the import error with matching bundle_name and filename
477-
assert response_json["total_entries"] == 1
484+
# Note: total_entries reflects count before permission filtering (all 3 import errors)
485+
# but only 1 is returned after filtering
486+
assert response_json["total_entries"] == 3
487+
assert len(response_json["import_errors"]) == 1
478488
assert response_json["import_errors"][0]["bundle_name"] == BUNDLE_NAME
479489
assert response_json["import_errors"][0]["filename"] == FILENAME1
480490

@@ -488,7 +498,127 @@ def test_bundle_name_join_condition_for_import_errors(
488498
response2 = test_client.get("/importErrors")
489499

490500
# Assert - should return 0 entries because bundle_name no longer matches
501+
# Note: total_entries reflects count before permission filtering (still 3),
502+
# but import_errors is empty after filtering
491503
assert response2.status_code == 200
492504
response_json2 = response2.json()
493-
assert response_json2["total_entries"] == 0
505+
assert response_json2["total_entries"] == 3
494506
assert response_json2["import_errors"] == []
507+
508+
@pytest.mark.usefixtures("permitted_dag_model")
509+
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
510+
def test_dag_bundle_import_error_with_no_dags_is_visible_in_web(
511+
self,
512+
mock_get_auth_manager,
513+
test_client,
514+
permitted_dag_model,
515+
configure_testing_dag_bundle,
516+
session,
517+
tmp_path,
518+
):
519+
"""Test that import error from DAG bundle file with no DAGs is visible via web API."""
520+
from pathlib import Path
521+
522+
from airflow.dag_processing.bundles.manager import DagBundlesManager
523+
from airflow.dag_processing.collection import update_dag_parsing_results_in_db
524+
from airflow.dag_processing.dagbag import BundleDagBag
525+
526+
# Configure testing bundle with tmp_path
527+
with configure_testing_dag_bundle(tmp_path):
528+
# Get the actual bundle object
529+
manager = DagBundlesManager()
530+
bundle = manager.get_bundle("testing")
531+
assert bundle is not None
532+
533+
# Create a DAG file with import error (file that fails to import, no DAG created)
534+
error_file = bundle.path / "error_file.py"
535+
error_file.write_text(
536+
"""from datetime import datetime, timedelta
537+
538+
# Operators
539+
from airflow.providers.standard.operators.bash import BashOperator
540+
541+
# The DAG object
542+
from airflow.sdk import DAG
543+
544+
with DAG(
545+
"import_error_test",
546+
description="DAG with intentional import errors",
547+
schedule_NOEXIST_KEYWORD=timedelta(days=1),
548+
start_date=datetime(2021, 1, 1),
549+
catchup=False,
550+
tags=["example", "error"],
551+
) as dag:
552+
# This task will never be created due to import error above
553+
t1 = BashOperator(
554+
task_id="print_date",
555+
bash_command="date",
556+
)
557+
"""
558+
)
559+
560+
# Parse the file using BundleDagBag
561+
bundle_dagbag = BundleDagBag(
562+
dag_folder=error_file,
563+
bundle_path=bundle.path,
564+
bundle_name=bundle.name,
565+
)
566+
bundle_dagbag.collect_dags()
567+
568+
# Verify import error was captured
569+
assert len(bundle_dagbag.import_errors) > 0
570+
571+
# Convert import_errors to the format expected by update_dag_parsing_results_in_db
572+
import_errors_dict = {}
573+
for filepath, error_msg in bundle_dagbag.import_errors.items():
574+
file_path = Path(filepath)
575+
bundle_path = Path(bundle.path)
576+
try:
577+
relative_path = str(file_path.relative_to(bundle_path))
578+
except ValueError:
579+
relative_path = file_path.name
580+
import_errors_dict[(bundle.name, relative_path)] = error_msg
581+
582+
# Update DB with parsing results
583+
update_dag_parsing_results_in_db(
584+
bundle_name=bundle.name,
585+
bundle_version=None,
586+
dags=[],
587+
import_errors=import_errors_dict,
588+
parse_duration=None,
589+
warnings=set(),
590+
session=session,
591+
files_parsed={(bundle.name, rel_path) for _, rel_path in import_errors_dict.keys()},
592+
)
593+
session.commit()
594+
595+
# Verify import error was stored in DB
596+
db_import_errors = session.scalars(
597+
select(ParseImportError).where(ParseImportError.bundle_name == bundle.name)
598+
).all()
599+
assert len(db_import_errors) > 0
600+
601+
# User has access to a DAG in the bundle
602+
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, {permitted_dag_model.dag_id})
603+
604+
# Test GET /importErrors/{id} - should return the import error
605+
import_error_id = db_import_errors[0].id
606+
response = test_client.get(f"/importErrors/{import_error_id}")
607+
608+
assert response.status_code == 200
609+
response_json = response.json()
610+
assert response_json["import_error_id"] == import_error_id
611+
assert response_json["bundle_name"] == bundle.name
612+
assert (
613+
"schedule_NOEXIST_KEYWORD" in response_json["stack_trace"]
614+
or "TypeError" in response_json["stack_trace"]
615+
or "ImportError" in response_json["stack_trace"]
616+
)
617+
618+
# Test GET /importErrors - should include the import error in the list
619+
response_list = test_client.get("/importErrors")
620+
assert response_list.status_code == 200
621+
response_list_json = response_list.json()
622+
assert response_list_json["total_entries"] > 0
623+
filenames = [ie["filename"] for ie in response_list_json["import_errors"]]
624+
assert any("error_file" in filename for filename in filenames)

0 commit comments

Comments
 (0)