diff --git a/python_modules/dagster-graphql/dagster_graphql/client/client.py b/python_modules/dagster-graphql/dagster_graphql/client/client.py index 08d443e6d019f..e7af0cfa1c4df 100644 --- a/python_modules/dagster-graphql/dagster_graphql/client/client.py +++ b/python_modules/dagster-graphql/dagster_graphql/client/client.py @@ -383,17 +383,23 @@ def shutdown_repository_location( else: raise Exception(f"Unexpected query result type {query_result_type}") - def terminate_run(self, run_id: str): + def terminate_run(self, run_id: str, force: bool = False): """Terminates a pipeline run. This method it is useful when you would like to stop a pipeline run based on a external event. Args: run_id (str): The run id of the pipeline run to terminate + force (bool, optional): if false, run will be terminated using terminatePolicy SAFE_TERMINATE. + If true, terminatePolicy is MARK_AS_CANCELED_IMMEDIATELY, Defaults to false. """ check.str_param(run_id, "run_id") res_data: dict[str, dict[str, Any]] = self._execute( - TERMINATE_RUN_JOB_MUTATION, {"runId": run_id} + TERMINATE_RUN_JOB_MUTATION, + { + "runId": run_id, + "terminatePolicy": "MARK_AS_CANCELED_IMMEDIATELY" if force else "SAFE_TERMINATE", + }, ) query_result: dict[str, Any] = res_data["terminateRun"] @@ -406,18 +412,23 @@ def terminate_run(self, run_id: str): else: raise DagsterGraphQLClientError(query_result_type, query_result["message"]) - def terminate_runs(self, run_ids: list[str]): + def terminate_runs(self, run_ids: list[str], force: bool = False): """Terminates a list of pipeline runs. This method it is useful when you would like to stop a list of pipeline runs based on a external event. Args: run_ids (List[str]): The list run ids of the pipeline runs to terminate + force (bool, optional): if false, run will be terminated using terminatePolicy SAFE_TERMINATE. + If true, terminatePolicy is MARK_AS_CANCELED_IMMEDIATELY, Defaults to false. """ check.list_param(run_ids, "run_ids", of_type=str) res_data: dict[str, dict[str, Any]] = self._execute( TERMINATE_RUNS_JOB_MUTATION, - {"runIds": run_ids}, + { + "runIds": run_ids, + "terminatePolicy": "MARK_AS_CANCELED_IMMEDIATELY" if force else "SAFE_TERMINATE", + }, ) query_result: dict[str, Any] = res_data["terminateRuns"] diff --git a/python_modules/dagster-graphql/dagster_graphql/client/client_queries.py b/python_modules/dagster-graphql/dagster_graphql/client/client_queries.py index 7f36bd0e82c35..f0122da781e9e 100644 --- a/python_modules/dagster-graphql/dagster_graphql/client/client_queries.py +++ b/python_modules/dagster-graphql/dagster_graphql/client/client_queries.py @@ -128,8 +128,14 @@ """ TERMINATE_RUN_JOB_MUTATION = """ -mutation GraphQLClientTerminateRun($runId: String!) { - terminateRun(runId: $runId){ +mutation GraphQLClientTerminateRun( + $runId: String! + $terminatePolicy: TerminateRunPolicy = SAFE_TERMINATE +) { + terminateRun( + runId: $runId + terminatePolicy: $terminatePolicy + ) { __typename ... on TerminateRunSuccess{ run { @@ -151,8 +157,14 @@ """ TERMINATE_RUNS_JOB_MUTATION = """ -mutation GraphQLClientTerminateRuns($runIds: [String!]!) { - terminateRuns(runIds: $runIds) { +mutation GraphQLClientTerminateRuns( + $runIds: [String!]! + $terminatePolicy: TerminateRunPolicy = SAFE_TERMINATE +) { + terminateRuns( + runIds: $runIds + terminatePolicy: $terminatePolicy + ) { __typename ... on TerminateRunsResult { terminateRunResults { diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_run.py b/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_run.py index 0bee502cac529..cee870565810d 100644 --- a/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_run.py +++ b/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_run.py @@ -17,6 +17,16 @@ def test_terminate_run_status_success(mock_client: MockClient): assert actual_result == expected_result +@python_client_test_suite +def test_force_terminate_run_status_success(mock_client: MockClient): + expected_result = None + response = {"terminateRun": {"__typename": "TerminateRunSuccess", "run": expected_result}} + mock_client.mock_gql_client.execute.return_value = response + + actual_result = mock_client.python_client.terminate_run(RUN_ID, True) + assert actual_result == expected_result + + @python_client_test_suite def test_terminate_run_not_failure(mock_client: MockClient): error_type, error_message = "TerminateRunFailure", "Unable to terminate run" diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_runs.py b/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_runs.py index 7120754cfef28..0f29ba6caef69 100644 --- a/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_runs.py +++ b/python_modules/dagster-graphql/dagster_graphql_tests/client_tests/test_terminate_runs.py @@ -24,6 +24,23 @@ def test_successful_run_termination(mock_client: MockClient): assert actual_result == expected_result +@python_client_test_suite +def test_successful_forced_run_termination(mock_client: MockClient): + expected_result = None + response = { + "terminateRuns": { + "terminateRunResults": [ + {"__typename": "TerminateRunSuccess", "run": {"runId": run_id}} + for run_id in RUN_IDS + ] + } + } + mock_client.mock_gql_client.execute.return_value = response + + actual_result = mock_client.python_client.terminate_runs(RUN_IDS, True) + assert actual_result == expected_result + + @python_client_test_suite def test_complete_failure_run_not_found(mock_client: MockClient): error_messages = [("RunNotFoundError", f"Run Id {run_id} not found") for run_id in RUN_IDS]