Skip to content

Commit 01db99b

Browse files
committed
Fix hardcoded waiter logic in EmrCreateJobFlowOperator
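The operator accepted the `wait_policy` argument but ignored it when selecting a waiter: `execute` always used the waiter for `WaitPolicy.WAIT_FOR_COMPLETION`, regardless of the policy requested. This change persists `wait_policy` on the operator instance and uses it to select the appropriate waiter (e.g., for step completion), falling back to `WaitPolicy.WAIT_FOR_COMPLETION` when no policy is given. The selected waiter name is also passed to `EmrCreateJobFlowTrigger` through a new `waiter_name` parameter (default `"job_flow_waiting"`), so deferrable mode follows the same policy. The deprecation warning for `wait_policy` is removed; passing both `wait_for_completion` and `wait_policy` still raises a `ValueError`.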
1 parent e916980 commit 01db99b

File tree

3 files changed

+35
-19
lines changed

3 files changed

+35
-19
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818
from __future__ import annotations
1919

2020
import ast
21-
import warnings
2221
from collections.abc import Sequence
2322
from datetime import timedelta
2423
from typing import TYPE_CHECKING, Any
2524
from uuid import uuid4
2625

27-
from airflow.exceptions import AirflowProviderDeprecationWarning
2826
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
2927
from airflow.providers.amazon.aws.links.emr import (
3028
EmrClusterLink,
@@ -657,7 +655,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
657655
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
658656
completion (True)
659657
(default: None)
660-
:param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether to finish the task immediately after creation (None) or:
658+
:param wait_policy: Whether to finish the task immediately after creation (None) or:
661659
- wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
662660
- wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
663661
(default: None)
@@ -701,19 +699,13 @@ def __init__(
701699
self.waiter_max_attempts = waiter_max_attempts or 60
702700
self.waiter_delay = waiter_delay or 60
703701
self.deferrable = deferrable
702+
self.wait_policy = wait_policy
704703

705704
if wait_policy is not None:
706-
warnings.warn(
707-
"`wait_policy` parameter is deprecated and will be removed in a future release; "
708-
"please use `wait_for_completion` (bool) instead.",
709-
AirflowProviderDeprecationWarning,
710-
stacklevel=2,
711-
)
712-
713705
if wait_for_completion is not None:
714706
raise ValueError(
715-
"Cannot specify both `wait_for_completion` and deprecated `wait_policy`. "
716-
"Please use `wait_for_completion` (bool)."
707+
"Cannot specify both `wait_for_completion` and `wait_policy`. "
708+
"Use `wait_policy` if you need to control the specific waiting behavior."
717709
)
718710

719711
self.wait_for_completion = wait_policy in (
@@ -758,15 +750,24 @@ def execute(self, context: Context) -> str | None:
758750
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
759751
)
760752
if self.wait_for_completion:
761-
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
753+
# Determine which waiter to use. Prefer explicit wait_policy when provided,
754+
# otherwise default to WAIT_FOR_COMPLETION.
755+
wp = self.wait_policy
756+
if wp is not None:
757+
waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
758+
else:
759+
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
762760

763761
if self.deferrable:
762+
# Pass the selected waiter_name to the trigger so deferrable mode waits
763+
# according to the requested policy as well.
764764
self.defer(
765765
trigger=EmrCreateJobFlowTrigger(
766766
job_flow_id=self._job_flow_id,
767767
aws_conn_id=self.aws_conn_id,
768768
waiter_delay=self.waiter_delay,
769769
waiter_max_attempts=self.waiter_max_attempts,
770+
waiter_name=waiter_name,
770771
),
771772
method_name="execute_complete",
772773
# timeout is set to ensure that if a trigger dies, the timeout does not restart

providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,11 @@ def __init__(
8282
aws_conn_id: str | None = None,
8383
waiter_delay: int = 30,
8484
waiter_max_attempts: int = 60,
85+
waiter_name: str = "job_flow_waiting",
8586
):
8687
super().__init__(
8788
serialized_fields={"job_flow_id": job_flow_id},
88-
waiter_name="job_flow_waiting",
89+
waiter_name=waiter_name,
8990
waiter_args={"ClusterId": job_flow_id},
9091
failure_message="JobFlow creation failed",
9192
status_message="JobFlow creation in progress",

providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from botocore.waiter import Waiter
2727
from jinja2 import StrictUndefined
2828

29-
from airflow.exceptions import AirflowProviderDeprecationWarning
3029
from airflow.models import DAG, DagRun, TaskInstance
3130
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
3231
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
@@ -254,10 +253,25 @@ def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client):
254253
def test_template_fields(self):
255254
validate_template_fields(self.operator)
256255

257-
def test_wait_policy_deprecation_warning(self):
258-
"""Test that using wait_policy raises a deprecation warning."""
259-
with pytest.warns(AirflowProviderDeprecationWarning, match="`wait_policy` parameter is deprecated"):
256+
def test_wait_policy_behavior(self):
257+
"""Test that using wait_policy sets the operator attributes correctly."""
258+
op = EmrCreateJobFlowOperator(
259+
task_id=TASK_ID,
260+
wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
261+
)
262+
# wait_policy should be stored on the instance
263+
assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_COMPLETION
264+
# passing WAIT_FOR_COMPLETION should enable wait_for_completion
265+
assert op.wait_for_completion is True
266+
267+
def test_cannot_specify_both_wait_for_completion_and_wait_policy(self):
268+
"""Passing both wait_for_completion and wait_policy should raise ValueError."""
269+
with pytest.raises(
270+
ValueError,
271+
match=r"Cannot specify both `wait_for_completion` and `wait_policy`",
272+
):
260273
EmrCreateJobFlowOperator(
261274
task_id=TASK_ID,
262-
wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
275+
wait_for_completion=True,
276+
wait_policy=WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
263277
)

0 commit comments

Comments
 (0)