Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dagster-client): support forced run termination #28339

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions python_modules/dagster-graphql/dagster_graphql/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,17 +383,23 @@ def shutdown_repository_location(
else:
raise Exception(f"Unexpected query result type {query_result_type}")

def terminate_run(self, run_id: str):
def terminate_run(self, run_id: str, force: bool = False):
"""Terminates a pipeline run. This method it is useful when you would like to stop a pipeline run
based on an external event.

Args:
run_id (str): The run id of the pipeline run to terminate
force (bool, optional): if false, run will be terminated using terminatePolicy SAFE_TERMINATE.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I'm not a huge fan of hiding terminatePolicy behind a flag like this. What if, down the road, we add an additional terminatePolicy value? It seems like that would be hard to support here. How about we just take terminatePolicy directly instead?

If true, terminatePolicy is MARK_AS_CANCELED_IMMEDIATELY. Defaults to false.
"""
check.str_param(run_id, "run_id")

res_data: dict[str, dict[str, Any]] = self._execute(
TERMINATE_RUN_JOB_MUTATION, {"runId": run_id}
TERMINATE_RUN_JOB_MUTATION,
{
"runId": run_id,
"terminatePolicy": "MARK_AS_CANCELED_IMMEDIATELY" if force else "SAFE_TERMINATE",
},
)

query_result: dict[str, Any] = res_data["terminateRun"]
Expand All @@ -406,18 +412,23 @@ def terminate_run(self, run_id: str):
else:
raise DagsterGraphQLClientError(query_result_type, query_result["message"])

def terminate_runs(self, run_ids: list[str]):
def terminate_runs(self, run_ids: list[str], force: bool = False):
"""Terminates a list of pipeline runs. This method it is useful when you would like to stop a list of pipeline runs
based on an external event.

Args:
run_ids (List[str]): The list of run ids of the pipeline runs to terminate
force (bool, optional): If false, the runs will be terminated using terminatePolicy SAFE_TERMINATE.
If true, terminatePolicy is MARK_AS_CANCELED_IMMEDIATELY. Defaults to false.
"""
check.list_param(run_ids, "run_ids", of_type=str)

res_data: dict[str, dict[str, Any]] = self._execute(
TERMINATE_RUNS_JOB_MUTATION,
{"runIds": run_ids},
{
"runIds": run_ids,
"terminatePolicy": "MARK_AS_CANCELED_IMMEDIATELY" if force else "SAFE_TERMINATE",
},
)

query_result: dict[str, Any] = res_data["terminateRuns"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,14 @@
"""

TERMINATE_RUN_JOB_MUTATION = """
mutation GraphQLClientTerminateRun($runId: String!) {
terminateRun(runId: $runId){
mutation GraphQLClientTerminateRun(
$runId: String!
$terminatePolicy: TerminateRunPolicy = SAFE_TERMINATE
) {
terminateRun(
runId: $runId
terminatePolicy: $terminatePolicy
) {
__typename
... on TerminateRunSuccess{
run {
Expand All @@ -151,8 +157,14 @@
"""

TERMINATE_RUNS_JOB_MUTATION = """
mutation GraphQLClientTerminateRuns($runIds: [String!]!) {
terminateRuns(runIds: $runIds) {
mutation GraphQLClientTerminateRuns(
$runIds: [String!]!
$terminatePolicy: TerminateRunPolicy = SAFE_TERMINATE
) {
terminateRuns(
runIds: $runIds
terminatePolicy: $terminatePolicy
) {
__typename
... on TerminateRunsResult {
terminateRunResults {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ def test_terminate_run_status_success(mock_client: MockClient):
assert actual_result == expected_result


@python_client_test_suite
def test_force_terminate_run_status_success(mock_client: MockClient):
    """Check that a forced ``terminate_run`` call returns ``None`` on success.

    The mocked GraphQL client is primed with a ``TerminateRunSuccess`` payload,
    and the client call is made with ``force`` enabled.
    """
    # NOTE(review): only the return value is asserted here; nothing verifies
    # that the forced terminatePolicy was actually sent to the mocked GraphQL
    # client. Consider also asserting on mock_gql_client.execute.call_args.
    mock_client.mock_gql_client.execute.return_value = {
        "terminateRun": {"__typename": "TerminateRunSuccess", "run": None}
    }

    expected_result = None
    actual_result = mock_client.python_client.terminate_run(RUN_ID, force=True)
    assert actual_result == expected_result


@python_client_test_suite
def test_terminate_run_not_failure(mock_client: MockClient):
error_type, error_message = "TerminateRunFailure", "Unable to terminate run"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@ def test_successful_run_termination(mock_client: MockClient):
assert actual_result == expected_result


@python_client_test_suite
def test_successful_forced_run_termination(mock_client: MockClient):
    """Check that a forced ``terminate_runs`` call returns ``None`` on success.

    The mocked GraphQL client returns one ``TerminateRunSuccess`` entry per id
    in ``RUN_IDS``, and the client call is made with ``force`` enabled.
    """
    # NOTE(review): only the return value is asserted here; nothing verifies
    # that the forced terminatePolicy was actually sent to the mocked GraphQL
    # client. Consider also asserting on mock_gql_client.execute.call_args.
    per_run_results = [
        {"__typename": "TerminateRunSuccess", "run": {"runId": run_id}} for run_id in RUN_IDS
    ]
    mock_client.mock_gql_client.execute.return_value = {
        "terminateRuns": {"terminateRunResults": per_run_results}
    }

    expected_result = None
    actual_result = mock_client.python_client.terminate_runs(RUN_IDS, force=True)
    assert actual_result == expected_result


@python_client_test_suite
def test_complete_failure_run_not_found(mock_client: MockClient):
error_messages = [("RunNotFoundError", f"Run Id {run_id} not found") for run_id in RUN_IDS]
Expand Down