Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 101 additions & 31 deletions composer/tools/composer_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ def pause_dag(
logger.info("Unable to pause DAG %s", dag_id)
logger.info(command_output[1])

@staticmethod
def pause_all_dags(
project_name: str,
environment: str,
location: str,
sdk_endpoint: str,
) -> None:
"""Pause all the DAGs in the given environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments"
f" run {environment} --project={project_name} --location={location}"
f" dags pause -- \"^(?!airflow_monitoring$).*\" --treat-dag-id-as-regex -y"
)
command_output = DAG._run_shell_command_locally_once(command=command)
logger.info(command_output[1])

@staticmethod
def unpause_dag(
project_name: str,
Expand All @@ -136,6 +152,22 @@ def unpause_dag(
logger.info("Unable to Unpause DAG %s", dag_id)
logger.info(command_output[1])

@staticmethod
def unpause_all_dags(
project_name: str,
environment: str,
location: str,
sdk_endpoint: str,
) -> None:
"""UnPause all the DAGs in the given environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments"
f" run {environment} --project={project_name} --location={location}"
f" dags unpause -- \".*\" --treat-dag-id-as-regex -y"
)
command_output = DAG._run_shell_command_locally_once(command=command)
logger.info(command_output[1])

@staticmethod
def describe_environment(
project_name: str, environment: str, location: str, sdk_endpoint: str
Expand All @@ -151,9 +183,74 @@ def describe_environment(
logger.info("Environment Info:\n %s", environment_json["name"])
return environment_json

@staticmethod
def pause_unpause_dags_individually(
project_name: str,
environment: str,
location: str,
sdk_endpoint: str,
airflow_version: tuple[int, int, int],
operation: str,
) -> None:
"""Pause or unpause DAGs individually."""
list_of_dags = DAG.get_list_of_dags(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
airflow_version=airflow_version,
)
logger.info("List of dags : %s", list_of_dags)

if operation == "pause":
for dag in list_of_dags:
if dag == "airflow_monitoring":
continue
DAG.pause_dag(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
dag_id=dag,
airflow_version=airflow_version,
)
else:
for dag in list_of_dags:
DAG.unpause_dag(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
dag_id=dag,
airflow_version=airflow_version,
)

@staticmethod
def pause_unpause_all_dags_at_once(
project_name: str,
environment: str,
location: str,
sdk_endpoint: str,
operation: str,
) -> None:
"""Pause or unpause all DAGs at once."""
if operation == "pause":
DAG.pause_all_dags(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
)
else:
DAG.unpause_all_dags(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
)

def main(
project_name: str, environment: str, location: str, operation: str, sdk_endpoint=str
project_name: str, environment: str, location: str, operation: str, sdk_endpoint: str
) -> int:
logger.info("DAG Pause/UnPause Script for Cloud Composer")
environment_info = DAG.describe_environment(
Expand All @@ -170,37 +267,10 @@ def main(
environment_info["config"]["softwareConfig"]["imageVersion"],
)
airflow_version = (int(versions[3]), int(versions[4]), int(versions[5]))
list_of_dags = DAG.get_list_of_dags(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
airflow_version=airflow_version,
)
logger.info("List of dags : %s", list_of_dags)

if operation == "pause":
for dag in list_of_dags:
if dag == "airflow_monitoring":
continue
DAG.pause_dag(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
dag_id=dag,
airflow_version=airflow_version,
)
if airflow_version < (2, 9, 0):
DAG.pause_unpause_dags_individually(project_name, environment, location, sdk_endpoint, airflow_version, operation)
else:
for dag in list_of_dags:
DAG.unpause_dag(
project_name=project_name,
environment=environment,
location=location,
sdk_endpoint=sdk_endpoint,
dag_id=dag,
airflow_version=airflow_version,
)
DAG.pause_unpause_all_dags_at_once(project_name, environment, location, sdk_endpoint, operation)
return 0


Expand Down