diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 237aacde975fb..1372901341891 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -136,14 +136,6 @@ def _eager_load_dag_run_for_validation() -> tuple[Load, Load]:
     )
 
 
-def _get_current_dag(dag_id: str, session: Session) -> SerializedDAG | None:
-    serdag = SerializedDagModel.get(dag_id=dag_id, session=session)  # grabs the latest version
-    if not serdag:
-        return None
-    serdag.load_op_links = False
-    return serdag.dag
-
-
 class ConcurrencyMap:
     """
     Dataclass to represent concurrency maps.
@@ -248,6 +240,17 @@ def __init__(
     def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
         Stats.incr("scheduler_heartbeat", 1, 1)
 
+    def _get_current_dag(self, dag_id: str, session: Session) -> SerializedDAG | None:
+        try:
+            serdag = SerializedDagModel.get(dag_id=dag_id, session=session)
+            if not serdag:
+                return None
+            serdag.load_op_links = False
+            return serdag.dag
+        except Exception:
+            self.log.exception("Failed to deserialize DAG '%s'", dag_id)
+            return None
+
     def register_signals(self) -> ExitStack:
         """Register signals that stop child processes."""
         resetter = ExitStack()
@@ -1601,7 +1604,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
         )
 
         for dag_model in dag_models:
-            dag = _get_current_dag(dag_id=dag_model.dag_id, session=session)
+            dag = self._get_current_dag(dag_id=dag_model.dag_id, session=session)
             if not dag:
                 self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
                 continue
@@ -1664,7 +1667,7 @@ def _create_dag_runs_asset_triggered(
         }
 
         for dag_model in dag_models:
-            dag = _get_current_dag(dag_id=dag_model.dag_id, session=session)
+            dag = self._get_current_dag(dag_id=dag_model.dag_id, session=session)
             if not dag:
                 self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
                 continue
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index aba109f9d410a..430833fb1ca3e 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -4752,6 +4752,39 @@ def _clear_serdags(self, dag_id, session):
         session.delete(sdm)
         session.commit()
 
+    def test_scheduler_create_dag_runs_does_not_crash_on_deserialization_error(self, caplog, dag_maker):
+        """
+        Test that scheduler._create_dag_runs does not crash when DAG deserialization fails.
+        This is a guardrail to ensure the scheduler continues processing other DAGs even if
+        one DAG has a deserialization error.
+        """
+        with dag_maker(dag_id="test_scheduler_create_dag_runs_deserialization_error"):
+            EmptyOperator(task_id="dummy")
+
+        scheduler_job = Job(executor=self.null_exec)
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        caplog.set_level("FATAL")
+        caplog.clear()
+        with (
+            create_session() as session,
+            caplog.at_level(
+                "ERROR",
+                logger="airflow.jobs.scheduler_job_runner",
+            ),
+            patch(
+                "airflow.models.serialized_dag.SerializedDagModel.get",
+                side_effect=Exception("Simulated deserialization error"),
+            ),
+        ):
+            self.job_runner._create_dag_runs([dag_maker.dag_model], session)
+        scheduler_messages = [
+            record.message for record in caplog.records if record.levelno >= logging.ERROR
+        ]
+        assert any("Failed to deserialize DAG" in msg for msg in scheduler_messages), (
+            f"Expected deserialization error log, got: {scheduler_messages}"
+        )
+
     def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle):
         """
         Test that externally triggered Dag Runs should not affect (by skipping) next