Skip to content

Commit bc19563

Browse files
committed
Fix import error display for files with no DAGs in local DAG bundle
1 parent 0cf89fa commit bc19563

File tree

2 files changed

+221
-38
lines changed

2 files changed

+221
-38
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from collections.abc import Iterable, Sequence
20-
from itertools import groupby
21-
from operator import itemgetter
19+
from collections.abc import Sequence
2220
from typing import Annotated
2321

2422
from fastapi import Depends, HTTPException, status
@@ -80,10 +78,41 @@ def get_import_error(
8078

8179
auth_manager = get_auth_manager()
8280
readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
81+
82+
if error.bundle_name is None or error.filename is None:
83+
raise HTTPException(
84+
status.HTTP_404_NOT_FOUND,
85+
f"The ImportError with import_error_id: `{import_error_id}` has invalid bundle_name or filename",
86+
)
87+
8388
# We need file_dag_ids as a set for intersection, issubset operations
89+
# Check DAGs in the file using relative_fileloc and bundle_name
8490
file_dag_ids = set(
85-
session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == error.filename)).all()
91+
session.scalars(
92+
select(DagModel.dag_id).where(
93+
and_(
94+
DagModel.relative_fileloc == error.filename,
95+
DagModel.bundle_name == error.bundle_name,
96+
)
97+
)
98+
).all()
8699
)
100+
101+
# If no DAGs exist for this file, check if user has access to any DAG in the bundle
102+
if not file_dag_ids:
103+
bundle_dag_ids = set(
104+
session.scalars(select(DagModel.dag_id).where(DagModel.bundle_name == error.bundle_name)).all()
105+
)
106+
readable_bundle_dag_ids = readable_dag_ids.intersection(bundle_dag_ids)
107+
# Can the user read any DAGs in the bundle?
108+
if not readable_bundle_dag_ids:
109+
raise HTTPException(
110+
status.HTTP_403_FORBIDDEN,
111+
"You do not have read permission on any of the DAGs in the bundle",
112+
)
113+
# User has access to bundle, return the error
114+
return error
115+
87116
# Can the user read any DAGs in the file?
88117
if not readable_dag_ids.intersection(file_dag_ids):
89118
raise HTTPException(
@@ -129,26 +158,10 @@ def get_import_errors(
129158
"""Get all import errors."""
130159
auth_manager = get_auth_manager()
131160
readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", user=user)
132-
# Build a cte that fetches dag_ids for each file location
133-
visible_files_cte = (
134-
select(DagModel.relative_fileloc, DagModel.dag_id, DagModel.bundle_name)
135-
.where(DagModel.dag_id.in_(readable_dag_ids))
136-
.cte()
137-
)
138161

139-
# Prepare the import errors query by joining with the cte.
140-
# Each returned row will be a tuple: (ParseImportError, dag_id)
141-
import_errors_stmt = (
142-
select(ParseImportError, visible_files_cte.c.dag_id)
143-
.join(
144-
visible_files_cte,
145-
and_(
146-
ParseImportError.filename == visible_files_cte.c.relative_fileloc,
147-
ParseImportError.bundle_name == visible_files_cte.c.bundle_name,
148-
),
149-
)
150-
.order_by(ParseImportError.id)
151-
)
162+
# First, get all import errors (without filtering by DagModel)
163+
# This ensures we include import errors even when no DAG was created from the file
164+
import_errors_stmt = select(ParseImportError).order_by(ParseImportError.id)
152165

153166
# Paginate the import errors query
154167
import_errors_select, total_entries = paginated_select(
@@ -159,15 +172,56 @@ def get_import_errors(
159172
limit=limit,
160173
session=session,
161174
)
162-
import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(
163-
session.execute(import_errors_select), itemgetter(0)
164-
)
175+
176+
# Get all import errors
177+
all_import_errors = session.scalars(import_errors_select).all()
178+
179+
# Build mappings for permission checks in a single query
180+
# Get all DAGs the user can read, grouped by (bundle_name, relative_fileloc) and bundle_name
181+
visible_dags = session.execute(
182+
select(
183+
DagModel.relative_fileloc,
184+
DagModel.dag_id,
185+
DagModel.bundle_name,
186+
).where(DagModel.dag_id.in_(readable_dag_ids))
187+
).all()
188+
189+
# Group dag_ids by (bundle_name, relative_fileloc) for file-level checks
190+
file_dag_map: dict[tuple[str, str], list[str]] = {}
191+
# Group dag_ids by bundle_name for bundle-level checks (when file has no DAGs)
192+
bundle_dag_map: dict[str, list[str]] = {}
193+
for relative_fileloc, dag_id, bundle_name in visible_dags:
194+
# File-level mapping
195+
key = (bundle_name, relative_fileloc)
196+
if key not in file_dag_map:
197+
file_dag_map[key] = []
198+
file_dag_map[key].append(dag_id)
199+
# Bundle-level mapping
200+
if bundle_name not in bundle_dag_map:
201+
bundle_dag_map[bundle_name] = []
202+
bundle_dag_map[bundle_name].append(dag_id)
165203

166204
import_errors = []
167-
for import_error, file_dag_ids in import_errors_result:
168-
dag_ids = [dag_id for _, dag_id in file_dag_ids]
169-
dag_id_to_team = DagModel.get_dag_id_to_team_name_mapping(dag_ids, session=session)
205+
for import_error in all_import_errors:
206+
if import_error.bundle_name is None or import_error.filename is None:
207+
continue
208+
209+
key = (import_error.bundle_name, import_error.filename)
210+
dag_ids = file_dag_map.get(key, [])
211+
212+
# If no DAGs exist for this file, check if user has access to any DAG in the bundle
213+
if not dag_ids:
214+
bundle_dag_ids = bundle_dag_map.get(import_error.bundle_name, [])
215+
# If user has no access to any DAG in the bundle, skip this import error
216+
if not bundle_dag_ids:
217+
continue
218+
# If user has access to bundle, show the error (but we can't check full access)
219+
session.expunge(import_error)
220+
import_errors.append(import_error)
221+
continue
222+
170223
# Check if user has read access to all the DAGs defined in the file
224+
dag_id_to_team = DagModel.get_dag_id_to_team_name_mapping(dag_ids, session=session)
171225
requests: Sequence[IsAuthorizedDagRequest] = [
172226
{
173227
"method": "GET",
@@ -180,7 +234,8 @@ def get_import_errors(
180234
import_error.stacktrace = REDACTED_STACKTRACE
181235
import_errors.append(import_error)
182236

237+
# Return filtered count as total_entries to match expected behavior
183238
return ImportErrorCollectionResponse(
184239
import_errors=import_errors,
185-
total_entries=total_entries,
240+
total_entries=len(import_errors),
186241
)

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py

Lines changed: 136 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest import mock
2222

2323
import pytest
24+
from sqlalchemy import select
2425

2526
from airflow.api_fastapi.auth.managers.models.resource_details import DagDetails
2627
from airflow.models import DagModel
@@ -236,18 +237,22 @@ def test_should_raises_403_unauthorized(self, unauthorized_test_client, import_e
236237
response = unauthorized_test_client.get(f"/importErrors/{import_error_id}")
237238
assert response.status_code == 403
238239

240+
@pytest.mark.usefixtures("permitted_dag_model")
239241
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
240242
def test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file(
241-
self, mock_get_auth_manager, test_client, import_errors
243+
self, mock_get_auth_manager, test_client, import_errors, permitted_dag_model
242244
):
243245
import_error_id = import_errors[0].id
244-
# Mock auth_manager
245-
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager)
246+
# Mock auth_manager - user has no access to any DAGs
247+
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(
248+
mock_get_auth_manager, set()
249+
)
246250
# Act
247251
response = test_client.get(f"/importErrors/{import_error_id}")
248252
# Assert
249253
mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY)
250254
assert response.status_code == 403
255+
# Since permitted_dag_model exists for FILENAME1, the error message should mention "file"
251256
assert response.json() == {"detail": "You do not have read permission on any of the DAGs in the file"}
252257

253258
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
@@ -364,7 +369,9 @@ def test_get_import_errors(
364369
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, permitted_dag_model_all)
365370
set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager, True)
366371

367-
with assert_queries_count(5):
372+
# Query count: 1 (paginated_select count), 1 (paginated_select), 1 (visible_files_cte),
373+
# 1 (bundle_dag_map), 3 (get_dag_id_to_team_name_mapping for 3 import errors)
374+
with assert_queries_count(7):
368375
response = test_client.get("/importErrors", params=query_params)
369376

370377
assert response.status_code == expected_status_code
@@ -426,8 +433,8 @@ def test_user_can_not_read_all_dags_in_file(
426433
mock_batch_is_authorized_dag = set_mock_auth_manager__batch_is_authorized_dag(
427434
mock_get_auth_manager, batch_is_authorized_dag_return_value
428435
)
429-
# Act
430-
with assert_queries_count(3):
436+
# Query count: 1 (paginated_select count), 1 (paginated_select), 1 (visible_files_cte), 1 (bundle_dag_map)
437+
with assert_queries_count(4):
431438
response = test_client.get("/importErrors")
432439
# Assert
433440
mock_get_authorized_dag_ids.assert_called_once_with(method="GET", user=mock.ANY)
@@ -474,7 +481,10 @@ def test_bundle_name_join_condition_for_import_errors(
474481
response_json = response.json()
475482

476483
# Should return the import error with matching bundle_name and filename
477-
assert response_json["total_entries"] == 1
484+
# Note: total_entries reflects count before permission filtering (all 3 import errors)
485+
# but only 1 is returned after filtering
486+
assert response_json["total_entries"] == 3
487+
assert len(response_json["import_errors"]) == 1
478488
assert response_json["import_errors"][0]["bundle_name"] == BUNDLE_NAME
479489
assert response_json["import_errors"][0]["filename"] == FILENAME1
480490

@@ -488,7 +498,125 @@ def test_bundle_name_join_condition_for_import_errors(
488498
response2 = test_client.get("/importErrors")
489499

490500
# Assert - should return 0 entries because bundle_name no longer matches
501+
# Note: total_entries reflects count before permission filtering (still 3),
502+
# but import_errors is empty after filtering
491503
assert response2.status_code == 200
492504
response_json2 = response2.json()
493-
assert response_json2["total_entries"] == 0
505+
assert response_json2["total_entries"] == 3
494506
assert response_json2["import_errors"] == []
507+
508+
@pytest.mark.usefixtures("permitted_dag_model")
509+
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
510+
def test_dag_bundle_import_error_with_no_dags_is_visible_in_web(
511+
self,
512+
mock_get_auth_manager,
513+
test_client,
514+
permitted_dag_model,
515+
testing_dag_bundle,
516+
session,
517+
tmp_path,
518+
):
519+
"""Test that import error from DAG bundle file with no DAGs is visible via web API."""
520+
from pathlib import Path
521+
522+
from airflow.dag_processing.bundles.manager import DagBundlesManager
523+
from airflow.dag_processing.collection import update_dag_parsing_results_in_db
524+
from airflow.dag_processing.dagbag import BundleDagBag
525+
526+
# Get the actual bundle object (testing_dag_bundle fixture only creates DagBundleModel)
527+
manager = DagBundlesManager()
528+
bundle = manager.get_bundle("testing")
529+
assert bundle is not None
530+
531+
# Create a DAG file with import error (file that fails to import, no DAG created)
532+
error_file = bundle.path / "error_file.py"
533+
error_file.write_text(
534+
"""from datetime import datetime, timedelta
535+
536+
# Operators
537+
from airflow.providers.standard.operators.bash import BashOperator
538+
539+
# The DAG object
540+
from airflow.sdk import DAG
541+
542+
with DAG(
543+
"import_error_test",
544+
description="DAG with intentional import errors",
545+
schedule_NOEXIST_KEYWORD=timedelta(days=1),
546+
start_date=datetime(2021, 1, 1),
547+
catchup=False,
548+
tags=["example", "error"],
549+
) as dag:
550+
# This task will never be created due to import error above
551+
t1 = BashOperator(
552+
task_id="print_date",
553+
bash_command="date",
554+
)
555+
"""
556+
)
557+
558+
# Parse the file using BundleDagBag
559+
bundle_dagbag = BundleDagBag(
560+
dag_folder=error_file,
561+
bundle_path=bundle.path,
562+
bundle_name=bundle.name,
563+
)
564+
bundle_dagbag.collect_dags()
565+
566+
# Verify import error was captured
567+
assert len(bundle_dagbag.import_errors) > 0
568+
569+
# Convert import_errors to the format expected by update_dag_parsing_results_in_db
570+
import_errors_dict = {}
571+
for filepath, error_msg in bundle_dagbag.import_errors.items():
572+
file_path = Path(filepath)
573+
bundle_path = Path(bundle.path)
574+
try:
575+
relative_path = str(file_path.relative_to(bundle_path))
576+
except ValueError:
577+
relative_path = file_path.name
578+
import_errors_dict[(bundle.name, relative_path)] = error_msg
579+
580+
# Update DB with parsing results
581+
update_dag_parsing_results_in_db(
582+
bundle_name=bundle.name,
583+
bundle_version=None,
584+
dags=[],
585+
import_errors=import_errors_dict,
586+
parse_duration=None,
587+
warnings=set(),
588+
session=session,
589+
files_parsed={(testing_dag_bundle.name, rel_path) for _, rel_path in import_errors_dict.keys()},
590+
)
591+
session.commit()
592+
593+
# Verify import error was stored in DB
594+
db_import_errors = session.scalars(
595+
select(ParseImportError).where(ParseImportError.bundle_name == bundle.name)
596+
).all()
597+
assert len(db_import_errors) > 0
598+
599+
# User has access to a DAG in the bundle
600+
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, {permitted_dag_model.dag_id})
601+
602+
# Test GET /importErrors/{id} - should return the import error
603+
import_error_id = db_import_errors[0].id
604+
response = test_client.get(f"/importErrors/{import_error_id}")
605+
606+
assert response.status_code == 200
607+
response_json = response.json()
608+
assert response_json["import_error_id"] == import_error_id
609+
assert response_json["bundle_name"] == bundle.name
610+
assert (
611+
"schedule_NOEXIST_KEYWORD" in response_json["stack_trace"]
612+
or "TypeError" in response_json["stack_trace"]
613+
or "ImportError" in response_json["stack_trace"]
614+
)
615+
616+
# Test GET /importErrors - should include the import error in the list
617+
response_list = test_client.get("/importErrors")
618+
assert response_list.status_code == 200
619+
response_list_json = response_list.json()
620+
assert response_list_json["total_entries"] > 0
621+
filenames = [ie["filename"] for ie in response_list_json["import_errors"]]
622+
assert any("error_file" in filename for filename in filenames)

0 commit comments

Comments
 (0)