diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index e5e03a401a6..aa3e33c09d4 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -363,6 +363,8 @@ BATCH_EMIT_TAGS = from_conf("BATCH_EMIT_TAGS", False) # Default tags to add to AWS Batch jobs. These are in addition to the defaults set when BATCH_EMIT_TAGS is true. BATCH_DEFAULT_TAGS = from_conf("BATCH_DEFAULT_TAGS", {}) +# Extra boto3 client kwargs for the Batch client (e.g. {"endpoint_url": "http://localhost:8000"} for local emulators). +BATCH_CLIENT_PARAMS = from_conf("BATCH_CLIENT_PARAMS", {}) ### # AWS Step Functions configuration @@ -372,6 +374,10 @@ SFN_IAM_ROLE = from_conf("SFN_IAM_ROLE") # AWS DynamoDb Table name (with partition key - `pathspec` of type string) SFN_DYNAMO_DB_TABLE = from_conf("SFN_DYNAMO_DB_TABLE") +# Extra boto3 client kwargs for the Step Functions client (e.g. {"endpoint_url": "http://localhost:8082"}). +SFN_CLIENT_PARAMS = from_conf("SFN_CLIENT_PARAMS", {}) +# Extra boto3 client kwargs for the DynamoDB client used by Step Functions (e.g. {"endpoint_url": "http://localhost:8765"}). +SFN_DYNAMO_DB_CLIENT_PARAMS = from_conf("SFN_DYNAMO_DB_CLIENT_PARAMS", {}) # IAM role for AWS Events with AWS Step Functions access # https://docs.aws.amazon.com/eventbridge/latest/userguide/auth-and-access-control-eventbridge.html EVENTS_SFN_ACCESS_IAM_ROLE = from_conf("EVENTS_SFN_ACCESS_IAM_ROLE") @@ -489,6 +495,16 @@ AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT = from_conf( "AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT" ) +# Airflow REST API endpoint, e.g. http://localhost:8090/api/v1 +AIRFLOW_REST_API_URL = from_conf("AIRFLOW_REST_API_URL") +AIRFLOW_REST_API_USERNAME = from_conf("AIRFLOW_REST_API_USERNAME", "admin") +AIRFLOW_REST_API_PASSWORD = from_conf("AIRFLOW_REST_API_PASSWORD", "admin") +# Path inside Airflow pods where DAG files are stored +AIRFLOW_KUBERNETES_DAGS_PATH = from_conf( + "AIRFLOW_KUBERNETES_DAGS_PATH", "/opt/airflow/dags" +) +# Kubernetes namespace where Airflow runs (for kubectl cp DAG upload) +AIRFLOW_KUBERNETES_NAMESPACE = from_conf("AIRFLOW_KUBERNETES_NAMESPACE", "default") ### diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 3fc1d3f8db6..126d7acfaef 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -177,6 +177,7 @@ "step-functions", ".aws.step_functions.step_functions_deployer.StepFunctionsDeployer", ), + ("airflow", ".airflow.airflow_deployer.AirflowDeployer"), ] TL_PLUGINS_DESC = [ diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index a2c39899599..46deb04b6c7 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -353,6 +353,22 @@ def _to_job(self, node): } ) + # Pass flow config values (--config-value overrides) to the pod so + # config_expr / @project decorators evaluate correctly at task runtime. + try: + from metaflow.flowspec import FlowStateItems + + flow_configs = self.flow._flow_state[FlowStateItems.CONFIGS] + config_env = { + name: value + for name, (value, _is_plain) in flow_configs.items() + if value is not None + } + if config_env: + env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(config_env) + except Exception: + pass + # Extract the k8s decorators for constructing the arguments of the K8s Pod Operator on Airflow. 
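The config-passing block above folds every flow config that has a value into a single JSON environment variable. A minimal sketch of the payload it produces, assuming the stored value is the raw JSON string supplied via `--config-value` (the internal layout of `_flow_state[FlowStateItems.CONFIGS]` is an assumption here):

```python
import json

# Hypothetical flow state entries, mirroring the (value, is_plain) tuples unpacked above.
flow_configs = {
    "training": ('{"learning_rate": 0.01, "epochs": 5}', False),  # set via --config-value
    "unused": (None, False),                                      # skipped: no value
}

config_env = {
    name: value
    for name, (value, _is_plain) in flow_configs.items()
    if value is not None
}

env = {}
if config_env:
    env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(config_env)

print(env["METAFLOW_FLOW_CONFIG_VALUE"])
# {"training": "{\"learning_rate\": 0.01, \"epochs\": 5}"}
```

At task runtime the pod can rehydrate the same config values that were in effect at deploy time.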
k8s_deco = [deco for deco in node.decorators if deco.name == "kubernetes"][0] user_code_retries, _ = self._get_retries(node) diff --git a/metaflow/plugins/airflow/airflow_cli.py b/metaflow/plugins/airflow/airflow_cli.py index b3afba4189a..c994685db0a 100644 --- a/metaflow/plugins/airflow/airflow_cli.py +++ b/metaflow/plugins/airflow/airflow_cli.py @@ -1,4 +1,5 @@ import base64 +import json import os import re import sys @@ -212,6 +213,13 @@ def airflow(obj, name=None): show_default=True, help="Worker pool for Airflow DAG execution.", ) +@click.option( + "--deployer-attribute-file", + default=None, + type=str, + hidden=True, + help="Write the DAG name and metadata to the file specified. Used internally for Metaflow's Deployer API.", +) @click.pass_obj def create( obj, @@ -225,6 +233,7 @@ def create( max_workers=None, workflow_timeout=None, worker_pool=None, + deployer_attribute_file=None, ): if os.path.abspath(sys.argv[0]) == os.path.abspath(file): raise MetaflowException( @@ -262,6 +271,17 @@ def create( with open(file, "w") as f: f.write(flow.compile()) + if deployer_attribute_file: + with open(deployer_attribute_file, "w", encoding="utf-8") as f: + json.dump( + { + "name": obj.dag_name, + "flow_name": obj.flow.name, + "metadata": obj.metadata.metadata_str(), + }, + f, + ) + obj.echo( "DAG *{dag_name}* " "for flow *{name}* compiled to " @@ -270,6 +290,69 @@ def create( ) +@airflow.command(help="Trigger a new run of this Airflow DAG.") +@click.option( + "--run-id-file", + default=None, + show_default=True, + type=str, + help="Write the DAG run ID to the file specified.", +) +@click.option( + "--deployer-attribute-file", + default=None, + type=str, + hidden=True, + help="Write the run metadata and pathspec to the file specified. Used internally for Metaflow's Deployer API.", +) +@click.pass_obj +def trigger(obj, run_id_file=None, deployer_attribute_file=None): + from metaflow.metaflow_config import ( + AIRFLOW_REST_API_URL, + AIRFLOW_REST_API_USERNAME, + AIRFLOW_REST_API_PASSWORD, + ) + from .airflow_client import AirflowClient + + if not AIRFLOW_REST_API_URL: + raise MetaflowException( + "METAFLOW_AIRFLOW_REST_API_URL is not set. Cannot trigger Airflow DAG run." + ) + + client = AirflowClient( + AIRFLOW_REST_API_URL, + username=AIRFLOW_REST_API_USERNAME, + password=AIRFLOW_REST_API_PASSWORD, + ) + + dag_id = obj.dag_name + dag_run = client.trigger_dag_run(dag_id) + dag_run_id = dag_run.get("dag_run_id") or dag_run.get("run_id", "") + + if run_id_file: + with open(run_id_file, "w") as f: + f.write(dag_run_id) + + if deployer_attribute_file: + with open(deployer_attribute_file, "w", encoding="utf-8") as f: + json.dump( + { + "name": dag_run_id, + "dag_id": dag_id, + "metadata": obj.metadata.metadata_str(), + "pathspec": obj.flow.name, + }, + f, + ) + + obj.echo( + "DAG *{dag_id}* triggered on Airflow (run-id *{run_id}*).".format( + dag_id=dag_id, run_id=dag_run_id + ), + bold=True, + ) + + def make_flow( obj, dag_name, diff --git a/metaflow/plugins/airflow/airflow_client.py b/metaflow/plugins/airflow/airflow_client.py new file mode 100644 index 00000000000..7daa6852f41 --- /dev/null +++ b/metaflow/plugins/airflow/airflow_client.py @@ -0,0 +1,187 @@ +""" +Thin wrapper around the Airflow 2.x REST API (api/v1). + +All methods raise ``AirflowClientError`` on non-2xx responses so callers +don't have to inspect status codes themselves. 
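A typical caller, sketched here with a made-up endpoint and DAG id, triggers a run and relies on that single exception type instead of inspecting HTTP status codes:

```python
from metaflow.plugins.airflow.airflow_client import AirflowClient, AirflowClientError

client = AirflowClient("http://localhost:8090/api/v1", username="admin", password="admin")

try:
    run = client.trigger_dag_run("hellocloudflow", conf={"alpha": "0.5"})
    state = client.get_dag_run("hellocloudflow", run["dag_run_id"])["state"]
    print("dag run %s is %s" % (run["dag_run_id"], state))
except AirflowClientError as err:
    # Any non-2xx response from the REST API lands here.
    print("Airflow API call failed: %s" % err)
```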
+""" + +import json +import time +import urllib.request +import urllib.error +import base64 +from typing import Any, Dict, List, Optional + +from .exception import AirflowException + + +class AirflowClientError(AirflowException): + headline = "Airflow REST API error" + + +class AirflowClient: + """ + Minimal Airflow REST API client (Airflow >= 2.0). + + Parameters + ---------- + rest_api_url : str + Base URL of the Airflow REST API, e.g. ``http://localhost:8090/api/v1``. + username : str + Basic-auth username (default: ``"admin"``). + password : str + Basic-auth password (default: ``"admin"``). + """ + + def __init__( + self, + rest_api_url: str, + username: str = "admin", + password: str = "admin", + ): + self._base = rest_api_url.rstrip("/") + credentials = base64.b64encode( + ("%s:%s" % (username, password)).encode() + ).decode() + self._auth_header = "Basic %s" % credentials + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _request( + self, + method: str, + path: str, + body: Optional[Dict] = None, + ) -> Any: + url = "%s/%s" % (self._base, path.lstrip("/")) + data = json.dumps(body).encode() if body is not None else None + req = urllib.request.Request( + url, + data=data, + method=method, + headers={ + "Authorization": self._auth_header, + "Content-Type": "application/json", + "Accept": "application/json", + }, + ) + try: + with urllib.request.urlopen(req) as resp: + raw = resp.read() + return json.loads(raw) if raw else {} + except urllib.error.HTTPError as e: + body = e.read().decode(errors="replace") + raise AirflowClientError( + "Airflow API %s %s returned HTTP %d: %s" % (method, url, e.code, body) + ) + + # ------------------------------------------------------------------ + # DAG operations + # ------------------------------------------------------------------ + + def get_dag(self, dag_id: str) -> Dict: + """Return DAG metadata dict, or None if not found.""" + try: + return self._request("GET", "dags/%s" % dag_id) + except AirflowClientError as e: + if "HTTP 404" in str(e): + return None + raise + + def patch_dag(self, dag_id: str, **fields) -> Dict: + """Patch DAG fields (e.g. ``is_paused=False``).""" + return self._request("PATCH", "dags/%s" % dag_id, body=fields) + + def delete_dag(self, dag_id: str) -> bool: + """Delete a DAG. Returns True on success.""" + try: + self._request("DELETE", "dags/%s" % dag_id) + return True + except AirflowClientError: + return False + + def list_dags(self, tags: Optional[List[str]] = None) -> List[Dict]: + """List all visible DAGs, optionally filtered by tags.""" + params = "" + if tags: + params = "?" + "&".join("tags=%s" % t for t in tags) + result = self._request("GET", "dags%s" % params) + return result.get("dags", []) + + # ------------------------------------------------------------------ + # DAG run operations + # ------------------------------------------------------------------ + + def trigger_dag_run( + self, + dag_id: str, + conf: Optional[Dict] = None, + run_id: Optional[str] = None, + ) -> Dict: + """Trigger a DAG run. 
Returns the dag_run dict.""" + body: Dict[str, Any] = {} + if conf: + body["conf"] = conf + if run_id: + body["dag_run_id"] = run_id + return self._request("POST", "dags/%s/dagRuns" % dag_id, body=body) + + def get_dag_run(self, dag_id: str, dag_run_id: str) -> Dict: + """Return dag_run dict for a specific run.""" + return self._request("GET", "dags/%s/dagRuns/%s" % (dag_id, dag_run_id)) + + def list_dag_runs(self, dag_id: str, limit: int = 25) -> List[Dict]: + """List recent dag runs for a DAG.""" + result = self._request( + "GET", + "dags/%s/dagRuns?limit=%d&order_by=-execution_date" % (dag_id, limit), + ) + return result.get("dag_runs", []) + + def patch_dag_run(self, dag_id: str, dag_run_id: str, **fields) -> Dict: + """Patch a dag run (e.g. set state to 'failed' to terminate it).""" + return self._request( + "PATCH", + "dags/%s/dagRuns/%s" % (dag_id, dag_run_id), + body=fields, + ) + + # ------------------------------------------------------------------ + # Utility + # ------------------------------------------------------------------ + + def wait_for_dag( + self, + dag_id: str, + timeout: int = 120, + polling_interval: int = 5, + ) -> Dict: + """ + Poll until the DAG is visible in Airflow (after kubectl-cp / file copy). + + Returns the DAG metadata dict when found. + + Raises + ------ + TimeoutError + If the DAG is not discovered within *timeout* seconds. + """ + deadline = time.time() + timeout + while time.time() < deadline: + try: + dag = self.get_dag(dag_id) + except OSError: + # Transient connection error (e.g. RemoteDisconnected) — + # the webserver may still be starting up. Retry silently. + time.sleep(polling_interval) + continue + if dag is not None: + return dag + time.sleep(polling_interval) + raise TimeoutError( + "DAG '%s' did not appear in Airflow within %d seconds. " + "Ensure the DAG file was copied to the dags folder and " + "that the Airflow scheduler is running." % (dag_id, timeout) + ) diff --git a/metaflow/plugins/airflow/airflow_deployer.py b/metaflow/plugins/airflow/airflow_deployer.py new file mode 100644 index 00000000000..7fb0dd0e053 --- /dev/null +++ b/metaflow/plugins/airflow/airflow_deployer.py @@ -0,0 +1,187 @@ +import os +import subprocess +import tempfile +from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING, Type + +from metaflow.exception import MetaflowException +from metaflow.runner.deployer_impl import DeployerImpl + +if TYPE_CHECKING: + import metaflow.plugins.airflow.airflow_deployer_objects + + +class AirflowDeployer(DeployerImpl): + """ + Deployer implementation for Apache Airflow. + + The DAG file produced by ``airflow create`` is copied to the Airflow + scheduler pod with ``kubectl cp``, and the deployer waits for Airflow to + discover it before returning. + + Parameters + ---------- + name : str, optional, default None + Airflow DAG name. The flow name is used instead if this option is + not specified. 
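A sketch of the intended end-to-end use through Metaflow's `Deployer` API; the flow file name is a placeholder, and it assumes `AIRFLOW_REST_API_URL` plus the Kubernetes settings above are configured:

```python
from metaflow import Deployer

# Compile the DAG, kubectl-cp it into the scheduler pod, and wait for
# Airflow to discover it.
deployed = Deployer("hello_flow.py").airflow().create()

# Trigger a run through the Airflow REST API and poll until it finishes.
run = deployed.trigger()
run.wait_for_completion(check_interval=10)
print(run.status)
```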
+ """ + + TYPE: ClassVar[Optional[str]] = "airflow" + + def __init__(self, deployer_kwargs: Dict[str, str], **kwargs): + self._deployer_kwargs = deployer_kwargs + super().__init__(**kwargs) + + @property + def deployer_kwargs(self) -> Dict[str, Any]: + return self._deployer_kwargs + + @staticmethod + def deployed_flow_type() -> ( + Type["metaflow.plugins.airflow.airflow_deployer_objects.AirflowDeployedFlow"] + ): + from .airflow_deployer_objects import AirflowDeployedFlow + + return AirflowDeployedFlow + + def create( + self, **kwargs + ) -> "metaflow.plugins.airflow.airflow_deployer_objects.AirflowDeployedFlow": + """ + Compile and deploy this flow as an Airflow DAG. + + The DAG Python file is written locally by ``airflow create``, then + copied to the Airflow scheduler pod with ``kubectl cp``. The deployer + waits until Airflow discovers the DAG before returning. + + Parameters + ---------- + authorize : str, optional, default None + Authorize using this production token. + generate_new_token : bool, optional, default False + Generate a new production token for this flow. + given_token : str, optional, default None + Use the given production token for this flow. + tags : List[str], optional, default None + Annotate all objects produced by Airflow DAG runs with these tags. + user_namespace : str, optional, default None + Change the namespace from the default to the given tag. + is_paused_upon_creation : bool, optional, default False + Create the DAG in a paused state. + max_workers : int, optional, default 100 + Maximum number of parallel processes. + workflow_timeout : int, optional, default None + Workflow timeout in seconds. + worker_pool : str, optional, default None + Worker pool for Airflow DAG execution. + + Returns + ------- + AirflowDeployedFlow + The Flow deployed to Airflow. + """ + from metaflow.metaflow_config import ( + AIRFLOW_KUBERNETES_DAGS_PATH, + AIRFLOW_KUBERNETES_NAMESPACE, + ) + from .airflow_deployer_objects import AirflowDeployedFlow + + # Write the compiled DAG to a temp file; subprocess fills it in. + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as dag_file: + dag_file_path = dag_file.name + + try: + deployed_flow = self._create( + AirflowDeployedFlow, file=dag_file_path, **kwargs + ) + + dag_id = self.name + + # Copy DAG file to the Airflow scheduler pod. + self._kubectl_cp_dag( + dag_file_path, + dag_id, + AIRFLOW_KUBERNETES_NAMESPACE, + AIRFLOW_KUBERNETES_DAGS_PATH, + ) + + # Wait until Airflow discovers the DAG. + from .airflow_deployer_objects import _get_airflow_client + + client, _ = _get_airflow_client() + if client is not None: + client.wait_for_dag(dag_id) + finally: + try: + os.unlink(dag_file_path) + except OSError: + pass + + return deployed_flow + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_scheduler_pod(namespace: str) -> str: + """Return the name of the Airflow scheduler pod in *namespace*.""" + try: + result = subprocess.run( + [ + "kubectl", + "get", + "pods", + "-n", + namespace, + "-l", + "component=scheduler", + "-o", + "jsonpath={.items[0].metadata.name}", + ], + capture_output=True, + text=True, + check=True, + ) + pod_name = result.stdout.strip() + if not pod_name: + raise MetaflowException( + "No Airflow scheduler pod found in namespace '%s'. " + "Is Airflow running?" 
% namespace + ) + return pod_name + except subprocess.CalledProcessError as e: + raise MetaflowException("kubectl get pods failed: %s" % e.stderr) + + @classmethod + def _kubectl_cp_dag( + cls, + local_dag_path: str, + dag_id: str, + namespace: str, + dags_path: str, + ) -> None: + """Copy the compiled DAG file into the Airflow scheduler pod.""" + pod_name = cls._get_scheduler_pod(namespace) + remote_path = "%s/%s.py" % (dags_path.rstrip("/"), dag_id) + try: + subprocess.run( + [ + "kubectl", + "cp", + local_dag_path, + "%s:%s" % (pod_name, remote_path), + "-n", + namespace, + ], + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + raise MetaflowException( + "kubectl cp failed while deploying DAG '%s' to pod '%s': %s" + % (dag_id, pod_name, e.stderr) + ) + + +_addl_stubgen_modules = ["metaflow.plugins.airflow.airflow_deployer_objects"] diff --git a/metaflow/plugins/airflow/airflow_deployer_objects.py b/metaflow/plugins/airflow/airflow_deployer_objects.py new file mode 100644 index 00000000000..fe1fad90903 --- /dev/null +++ b/metaflow/plugins/airflow/airflow_deployer_objects.py @@ -0,0 +1,283 @@ +import hashlib +import json +import sys +import time +import tempfile +import os +from typing import ClassVar, Optional + +from metaflow.client.core import get_metadata +from metaflow.exception import MetaflowException +from metaflow.runner.deployer import ( + DeployedFlow, + TriggeredRun, +) +from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo + + +def _get_airflow_client(): + """Return an (AirflowClient, url) pair, or (None, None) if unconfigured.""" + from metaflow.metaflow_config import ( + AIRFLOW_REST_API_URL, + AIRFLOW_REST_API_USERNAME, + AIRFLOW_REST_API_PASSWORD, + ) + from .airflow_client import AirflowClient + + if not AIRFLOW_REST_API_URL: + return None, None + client = AirflowClient( + AIRFLOW_REST_API_URL, + username=AIRFLOW_REST_API_USERNAME, + password=AIRFLOW_REST_API_PASSWORD, + ) + return client, AIRFLOW_REST_API_URL + + +def _compute_metaflow_run_id(dag_run_id, dag_id): + """Compute the Metaflow run ID from an Airflow DAG run ID. + + Mirrors AIRFLOW_MACROS.RUN_ID in airflow_utils.py: + run_id_creator([run_id, dag_id]) = md5(run_id + "-" + dag_id)[:12] + prefixed with "airflow-". + """ + run_hash = hashlib.md5( + ("%s-%s" % (dag_run_id, dag_id)).encode("utf-8") + ).hexdigest()[:12] + return "airflow-%s" % run_hash + + +class AirflowTriggeredRun(TriggeredRun): + """ + A class representing a triggered Airflow DAG run execution. + """ + + @property + def status(self) -> Optional[str]: + """ + Get the status of the triggered run via the Airflow REST API. + + Returns + ------- + str, optional + The Airflow dag_run state (e.g. ``"running"``, ``"success"``, + ``"failed"``), or None if it could not be retrieved. + """ + try: + client, _ = _get_airflow_client() + if client is None: + return None + content = json.loads(self.content) + dag_run_id = content.get("name") + dag_id = content.get("dag_id") or self.deployer.name + dag_run = client.get_dag_run(dag_id, dag_run_id) + state = dag_run.get("state") + # Map Airflow states to conventional casing used by other deployers + if state is not None: + return state.upper() + return None + except Exception: + return None + + @property + def is_running(self) -> bool: + """ + Check if the DAG run is currently running or queued. 
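To make the run-id mapping above concrete, a hypothetical Airflow run id and DAG id hash down to the Metaflow run id like this (all values are placeholders):

```python
import hashlib

dag_run_id = "manual__2024-05-01T12:00:00+00:00"  # hypothetical Airflow dag_run_id
dag_id = "hellocloudflow"                          # hypothetical DAG id

run_hash = hashlib.md5(("%s-%s" % (dag_run_id, dag_id)).encode("utf-8")).hexdigest()[:12]
metaflow_run_id = "airflow-%s" % run_hash
pathspec = "HelloCloudFlow/%s" % metaflow_run_id   # what the triggered run exposes
```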
+ + Returns + ------- + bool + """ + status = self.status + return status is not None and status in ("RUNNING", "QUEUED") + + def wait_for_completion( + self, check_interval: int = 5, timeout: Optional[int] = None + ): + """ + Wait for the DAG run to complete. + + Parameters + ---------- + check_interval : int, default 5 + Polling interval in seconds. + timeout : int, optional, default None + Maximum wait time in seconds. Waits indefinitely if None. + + Raises + ------ + TimeoutError + If the run does not complete within *timeout* seconds. + """ + start_time = time.time() + while self.is_running: + if timeout is not None and (time.time() - start_time) > timeout: + raise TimeoutError( + "Airflow DAG run did not complete within specified timeout." + ) + time.sleep(check_interval) + + +class AirflowDeployedFlow(DeployedFlow): + """ + A class representing a deployed Airflow DAG. + """ + + TYPE: ClassVar[Optional[str]] = "airflow" + + def delete(self, **kwargs) -> bool: + """ + Delete the deployed DAG via the Airflow REST API. + + Returns + ------- + bool + True if deletion succeeded, False otherwise. + """ + try: + client, _ = _get_airflow_client() + if client is None: + return False + return client.delete_dag(self.deployer.name) + except Exception: + return False + + def trigger(self, **kwargs) -> AirflowTriggeredRun: + """ + Trigger a new DAG run via the Airflow REST API. + + Parameters + ---------- + **kwargs : Any + Additional arguments. Flow ``Parameters`` can be passed as keyword + arguments and will be forwarded as DAG ``conf``. + + Returns + ------- + AirflowTriggeredRun + The triggered run instance. + + Raises + ------ + Exception + If there is an error during the trigger process. + """ + from .airflow_client import AirflowClientError + + client, _ = _get_airflow_client() + if client is None: + raise MetaflowException( + "METAFLOW_AIRFLOW_REST_API_URL is not set. " + "Cannot trigger Airflow DAG run." + ) + + dag_id = self.deployer.name + # Pass any flow parameters as DAG conf + conf = {k: v for k, v in kwargs.items() if v is not None} or None + + try: + dag_run = client.trigger_dag_run(dag_id, conf=conf) + except AirflowClientError as e: + raise Exception("Error triggering DAG %s on Airflow: %s" % (dag_id, str(e))) + + dag_run_id = dag_run.get("dag_run_id") or dag_run.get("run_id", "") + flow_name = self.deployer.flow_name + metaflow_run_id = _compute_metaflow_run_id(dag_run_id, dag_id) + pathspec = "%s/%s" % (flow_name, metaflow_run_id) + + content = json.dumps( + { + "name": dag_run_id, + "dag_id": dag_id, + "metadata": self.deployer.metadata, + "pathspec": pathspec, + } + ) + + return AirflowTriggeredRun(deployer=self.deployer, content=content) + + @classmethod + def from_deployment(cls, identifier: str, metadata: Optional[str] = None): + """ + Retrieve an ``AirflowDeployedFlow`` for an existing Airflow DAG. + + Parameters + ---------- + identifier : str + The Airflow DAG ID. + metadata : str, optional, default None + Optional metadata string. + + Returns + ------- + AirflowDeployedFlow + """ + from metaflow.runner.deployer import Deployer, generate_fake_flow_file_contents + + client, _ = _get_airflow_client() + if client is None: + raise MetaflowException("METAFLOW_AIRFLOW_REST_API_URL is not set.") + + dag = client.get_dag(identifier) + if dag is None: + raise MetaflowException("No deployed flow found for DAG: %s" % identifier) + + # Extract flow_name from DAG tags (set by `airflow create`). + # The Airflow DAG ID has the form "project.branch.FlowName" (dotted). 
+ # The flow class name (CamelCase, no dots) is always the last component. + flow_name = identifier.split(".")[-1] # safe default + for tag in dag.get("tags", []): + tag_name = tag.get("name", "") + if tag_name.startswith("metaflow_flow_name:"): + flow_name = tag_name.split(":", 1)[1] + break + + fake_flow_file_contents = generate_fake_flow_file_contents( + flow_name=flow_name, param_info={}, project_name=None + ) + + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as fake_flow_file: + with open(fake_flow_file.name, "w") as fp: + fp.write(fake_flow_file_contents) + + d = Deployer(fake_flow_file.name).airflow(name=identifier) + d.name = identifier + d.flow_name = flow_name + d.metadata = metadata if metadata is not None else get_metadata() + + return cls(deployer=d) + + @classmethod + def get_triggered_run( + cls, identifier: str, run_id: str, metadata: Optional[str] = None + ): + """ + Retrieve an ``AirflowTriggeredRun`` for an existing DAG run. + + Parameters + ---------- + identifier : str + The Airflow DAG ID. + run_id : str + The Airflow DAG run ID. + metadata : str, optional, default None + + Returns + ------- + AirflowTriggeredRun + """ + deployed_flow_obj = cls.from_deployment(identifier, metadata) + metaflow_run_id = _compute_metaflow_run_id(run_id, identifier) + pathspec = "%s/%s" % (deployed_flow_obj.deployer.flow_name, metaflow_run_id) + content = json.dumps( + { + "name": run_id, + "dag_id": identifier, + "metadata": deployed_flow_obj.deployer.metadata, + "pathspec": pathspec, + } + ) + return AirflowTriggeredRun( + deployer=deployed_flow_obj.deployer, + content=content, + ) diff --git a/metaflow/plugins/airflow/airflow_utils.py b/metaflow/plugins/airflow/airflow_utils.py index d0574ad7401..a1cc50b0709 100644 --- a/metaflow/plugins/airflow/airflow_utils.py +++ b/metaflow/plugins/airflow/airflow_utils.py @@ -5,7 +5,6 @@ from collections import defaultdict from datetime import datetime, timedelta - TASK_ID_XCOM_KEY = "metaflow_task_id" FOREACH_CARDINALITY_XCOM_KEY = "metaflow_foreach_cardinality" FOREACH_XCOM_KEY = "metaflow_foreach_indexes" @@ -283,6 +282,13 @@ def parse_args(dd): data_dict[k] = v.isoformat() elif isinstance(v, timedelta): data_dict[k] = dict(seconds=v.total_seconds()) + elif isinstance(v, tuple): + # Airflow 2.x DAG.__init__ validates that `tags` (and + # similar args) is a list, not a tuple. Metaflow's CLI + # (Click) returns multi-value options as tuples, so + # normalise all tuples to lists here so that the + # serialised CONFIG dict round-trips correctly. 
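The normalisation described above is small but load-bearing: Click hands multi-value options to Metaflow as tuples, while Airflow's `DAG(...)` constructor insists on a list for `tags`. A standalone re-implementation of the same branch logic, for illustration only:

```python
from datetime import datetime, timedelta

def normalise(dd):
    # Mirrors the parse_args branches above (illustrative copy, not the real helper).
    out = {}
    for k, v in dd.items():
        if isinstance(v, datetime):
            out[k] = v.isoformat()
        elif isinstance(v, timedelta):
            out[k] = dict(seconds=v.total_seconds())
        elif isinstance(v, tuple):
            out[k] = list(v)  # DAG(tags=...) requires a list
        else:
            out[k] = v
    return out

print(normalise({"tags": ("metaflow", "project:demo"), "dagrun_timeout": timedelta(hours=2)}))
# {'tags': ['metaflow', 'project:demo'], 'dagrun_timeout': {'seconds': 7200.0}}
```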
+ data_dict[k] = list(v) else: data_dict[k] = v return data_dict diff --git a/metaflow/plugins/airflow/sensors/external_task_sensor.py b/metaflow/plugins/airflow/sensors/external_task_sensor.py index 264e47f18bc..8492ea1ac67 100644 --- a/metaflow/plugins/airflow/sensors/external_task_sensor.py +++ b/metaflow/plugins/airflow/sensors/external_task_sensor.py @@ -3,7 +3,6 @@ from ..exception import AirflowException from datetime import timedelta - AIRFLOW_STATES = dict( QUEUED="queued", RUNNING="running", diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py index 8675ad4a7a9..42fe856582b 100644 --- a/metaflow/plugins/aws/batch/batch_client.py +++ b/metaflow/plugins/aws/batch/batch_client.py @@ -12,14 +12,14 @@ basestring = str from metaflow.exception import MetaflowException -from metaflow.metaflow_config import AWS_SANDBOX_ENABLED +from metaflow.metaflow_config import AWS_SANDBOX_ENABLED, BATCH_CLIENT_PARAMS class BatchClient(object): def __init__(self): from ..aws_client import get_aws_client - self._client = get_aws_client("batch") + self._client = get_aws_client("batch", client_params=BATCH_CLIENT_PARAMS) def active_job_queues(self): paginator = self._client.get_paginator("describe_job_queues") diff --git a/metaflow/plugins/aws/step_functions/dynamo_db_client.py b/metaflow/plugins/aws/step_functions/dynamo_db_client.py index 2f4859e626d..b7b8cd6c76b 100644 --- a/metaflow/plugins/aws/step_functions/dynamo_db_client.py +++ b/metaflow/plugins/aws/step_functions/dynamo_db_client.py @@ -1,13 +1,15 @@ import time -from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE +from metaflow.metaflow_config import SFN_DYNAMO_DB_CLIENT_PARAMS, SFN_DYNAMO_DB_TABLE class DynamoDbClient(object): def __init__(self): from ..aws_client import get_aws_client - self._client = get_aws_client("dynamodb") + self._client = get_aws_client( + "dynamodb", client_params=SFN_DYNAMO_DB_CLIENT_PARAMS + ) self.name = SFN_DYNAMO_DB_TABLE def save_foreach_cardinality(self, foreach_split_task_id, foreach_cardinality, ttl): diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 3c36a97efe9..7273f56dffb 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -12,6 +12,7 @@ from metaflow.metaflow_config import ( EVENTS_SFN_ACCESS_IAM_ROLE, S3_ENDPOINT_URL, + SFN_CLIENT_PARAMS, SFN_DYNAMO_DB_TABLE, SFN_EXECUTION_LOG_GROUP_ARN, SFN_IAM_ROLE, @@ -82,6 +83,12 @@ def __init__( # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/ self.use_distributed_map = use_distributed_map + # Detect sfn-local (local emulator) by checking if the SFN endpoint is + # localhost. sfn-local v2 does not support ProcessorConfig in Map states, + # so we omit it when targeting the local emulator. + _sfn_endpoint = (SFN_CLIENT_PARAMS or {}).get("endpoint_url", "") + self._is_sfn_local = any(h in _sfn_endpoint for h in ("localhost", "127.0.0.1")) + # S3 command upload configuration self.compress_state_machine = compress_state_machine @@ -142,6 +149,9 @@ def deploy(self, log_execution_history): def schedule(self): # Scheduling is currently enabled via AWS Event Bridge. + # If no cron schedule is defined, nothing to do. 
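For the endpoint detection above to do anything, the new client-params knobs must be populated. A sketch of a local-emulator setup, reusing the endpoints mentioned in the config comments; treating the `METAFLOW_*` environment variables as JSON that `from_conf` parses into dicts is an assumption:

```python
import os

# Step Functions, DynamoDB, and Batch all pointed at local emulators.
os.environ["METAFLOW_SFN_CLIENT_PARAMS"] = '{"endpoint_url": "http://localhost:8082"}'
os.environ["METAFLOW_SFN_DYNAMO_DB_CLIENT_PARAMS"] = '{"endpoint_url": "http://localhost:8765"}'
os.environ["METAFLOW_BATCH_CLIENT_PARAMS"] = '{"endpoint_url": "http://localhost:8000"}'

# With SFN_CLIENT_PARAMS resolving to a localhost endpoint, _is_sfn_local becomes
# True and the sfn-local workarounds below are applied.
```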
+ if not self._cron: + return if EVENTS_SFN_ACCESS_IAM_ROLE is None: raise StepFunctionsSchedulingException( "No IAM role found for AWS " @@ -343,7 +353,7 @@ def _visit(node, workflow, exit_node=None): State(node.name) .batch(self._batch(node)) .output_path( - "$.['JobId', " "'Parameters', " "'Index', " "'SplitParentTaskId']" + "$['JobId', " "'Parameters', " "'Index', " "'SplitParentTaskId']" ) ) # End the (sub)workflow if we have reached the end of the flow or @@ -370,6 +380,13 @@ def _visit(node, workflow, exit_node=None): self.graph[n], Workflow(n).start_at(n), node.matching_join ) ) + # Add a ResultSelector that converts the Parallel output array into + # a named dict keyed by branch step name. This avoids array indexing + # ($[n].x) in downstream states, which is not supported by + # sfn-local v2.0.0. Instead, branches are accessed as $.step_name.x. + branch.result_selector( + {"%s.$" % n: "$[%d]" % i for i, n in enumerate(node.out_funcs)} + ) workflow.add_state(branch) # Continue the traversal from the matching_join. _visit(self.graph[node.matching_join], workflow, exit_node) @@ -388,7 +405,14 @@ def _visit(node, workflow, exit_node=None): workflow.add_state(cardinality_state.next(iterator_name)) workflow.add_state( Map(iterator_name) - .items_path("$.Result.Item.for_each_cardinality.NS") + # sfn-local serializes DynamoDB Number Set type as "Ns" (camelCase) + # instead of the standard "NS" (uppercase). Use Ns for sfn-local, + # NS for real AWS SFN. + .items_path( + "$.Result.Item.for_each_cardinality.Ns" + if self._is_sfn_local + else "$.Result.Item.for_each_cardinality.NS" + ) .parameter("JobId.$", "$.JobId") .parameter("SplitParentTaskId.$", "$.JobId") .parameter("Parameters.$", "$.Parameters") @@ -401,10 +425,18 @@ def _visit(node, workflow, exit_node=None): .iterator( _visit( self.graph[node.out_funcs[0]], - Workflow(node.out_funcs[0]) - .start_at(node.out_funcs[0]) - .mode( - "DISTRIBUTED" if self.use_distributed_map else "INLINE" + ( + Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]) + # sfn-local v2 does not support ProcessorConfig + # in Map states; omit it for the local emulator. + if self._is_sfn_local + else Workflow(node.out_funcs[0]) + .start_at(node.out_funcs[0]) + .mode( + "DISTRIBUTED" + if self.use_distributed_map + else "INLINE" + ) ), node.matching_join, ) @@ -432,7 +464,7 @@ def _visit(node, workflow, exit_node=None): else (None, None) ) ) - .output_path("$" if self.use_distributed_map else "$.[0]") + .output_path("$" if self.use_distributed_map else "$[0]") ) if self.use_distributed_map: workflow.add_state( @@ -464,7 +496,7 @@ def _visit(node, workflow, exit_node=None): .parameter("Bucket.$", "$.Body.DestinationBucket") .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED[0].Key") ) - .output_path("$.[0]") + .output_path("$[0]") ) # Continue the traversal from the matching_join. @@ -676,7 +708,7 @@ def _batch(self, node): "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1] ) # Inherit the run id from the parent and pass it along to children. - attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']" + attrs["metaflow.run_id.$"] = "$.Parameters['metaflow.run_id']" else: # Set appropriate environment variables for runtime replacement. if len(node.in_funcs) == 1: @@ -686,7 +718,7 @@ def _batch(self, node): ) env["METAFLOW_PARENT_TASK_ID"] = "$.JobId" # Inherit the run id from the parent and pass it along to children. 
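The ResultSelector added to Parallel states above reshapes the branch output from a positional array into a dict keyed by branch step name, which is what the join-step plumbing below relies on. A sketch of the transformation with two hypothetical branches, `train` and `eval`:

```python
# Raw output of a Parallel state: one entry per branch, in declaration order.
parallel_output = [
    {"JobId": "job-train-123", "Parameters": {"step_name": "train"}},
    {"JobId": "job-eval-456", "Parameters": {"step_name": "eval"}},
]

# Equivalent of ResultSelector {"train.$": "$[0]", "eval.$": "$[1]"}.
reshaped = {name: parallel_output[i] for i, name in enumerate(["train", "eval"])}

# The join step can now reference "$.train.JobId" instead of "$[0].JobId",
# which sfn-local v2.0.0 cannot evaluate.
assert reshaped["train"]["JobId"] == "job-train-123"
```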
- attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']" + attrs["metaflow.run_id.$"] = "$.Parameters['metaflow.run_id']" else: # Generate the input paths in a quasi-compressed format. # See util.decompress_list for why this is written the way @@ -696,13 +728,18 @@ def _batch(self, node): "${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx) for idx, _ in enumerate(node.in_funcs) ) - # Inherit the run id from the parent and pass it along to children. - attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']" - for idx, _ in enumerate(node.in_funcs): - env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx - env["METAFLOW_PARENT_%s_STEP" % idx] = ( - "$.[%s].Parameters.step_name" % idx + # Use the context object reference for the run id. The Parallel + # state's ResultSelector transforms the output array to a named + # dict ($.step_name.x), so branch outputs are accessed via the + # step name rather than $[n].x array indexing (unsupported by + # sfn-local v2.0.0). + attrs["metaflow.run_id.$"] = "$$.Execution.Name" + for idx, branch_name in enumerate(node.in_funcs): + env["METAFLOW_PARENT_%s_TASK_ID" % idx] = ( + "$.%s.JobId" % branch_name ) + # Step name is known at compile time (it's the branch name). + env["METAFLOW_PARENT_%s_STEP" % idx] = branch_name env["METAFLOW_INPUT_PATHS"] = input_paths if node.is_inside_foreach: @@ -745,7 +782,7 @@ def _batch(self, node): for parent in node.split_parents: if self.graph[parent].type == "foreach": attrs["split_parent_task_id_%s.$" % parent] = ( - "$.[0].Parameters.split_parent_task_id_%s" % parent + "$[0].Parameters.split_parent_task_id_%s" % parent ) else: for parent in node.split_parents: @@ -1124,16 +1161,18 @@ def batch(self, job): # set retry strategy for AWS Batch job submission to account for the # measily 50 jobs / second queue admission limit which people can # run into very quickly. - self.retry_strategy( - { - "ErrorEquals": ["Batch.AWSBatchException"], - "BackoffRate": 2, - "IntervalSeconds": 2, - "MaxDelaySeconds": 60, - "MaxAttempts": 10, - "JitterStrategy": "FULL", - } - ) + retry = { + "ErrorEquals": ["Batch.AWSBatchException"], + "BackoffRate": 2, + "IntervalSeconds": 2, + "MaxAttempts": 10, + } + # sfn-local v2.0.0 does not support MaxDelaySeconds or JitterStrategy. 
+ _sfn_endpoint = (SFN_CLIENT_PARAMS or {}).get("endpoint_url", "") + if not any(h in _sfn_endpoint for h in ("localhost", "127.0.0.1")): + retry["MaxDelaySeconds"] = 60 + retry["JitterStrategy"] = "FULL" + self.retry_strategy(retry) return self def dynamo_db(self, table_name, primary_key, values): @@ -1192,6 +1231,10 @@ def result_path(self, result_path): self.payload["ResultPath"] = result_path return self + def result_selector(self, selector): + self.payload["ResultSelector"] = selector + return self + class Map(object): def __init__(self, name): diff --git a/metaflow/plugins/aws/step_functions/step_functions_client.py b/metaflow/plugins/aws/step_functions/step_functions_client.py index ceec8e4d0ce..284250326df 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_client.py +++ b/metaflow/plugins/aws/step_functions/step_functions_client.py @@ -1,6 +1,7 @@ from metaflow.metaflow_config import ( AWS_SANDBOX_ENABLED, AWS_SANDBOX_REGION, + SFN_CLIENT_PARAMS, SFN_EXECUTION_LOG_GROUP_ARN, ) @@ -9,7 +10,7 @@ class StepFunctionsClient(object): def __init__(self): from ..aws_client import get_aws_client - self._client = get_aws_client("stepfunctions") + self._client = get_aws_client("stepfunctions", client_params=SFN_CLIENT_PARAMS) def search(self, name): paginator = self._client.get_paginator("list_state_machines") @@ -81,6 +82,12 @@ def list_executions(self, state_machine_arn, states): for execution in page["executions"] ) + def describe_execution(self, execution_arn): + try: + return self._client.describe_execution(executionArn=execution_arn) + except self._client.exceptions.ExecutionDoesNotExist: + return None + def terminate_execution(self, execution_arn): try: response = self._client.stop_execution(executionArn=execution_arn) diff --git a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py index 161eb91a638..e2466dc7bc8 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +++ b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py @@ -13,6 +13,43 @@ class StepFunctionsTriggeredRun(TriggeredRun): A class representing a triggered AWS Step Functions state machine execution. """ + @property + def status(self) -> Optional[str]: + """ + Get the status of the triggered execution. + + Returns + ------- + str, optional + One of RUNNING, SUCCEEDED, FAILED, TIMED_OUT, ABORTED, or None. + """ + try: + from metaflow.plugins.aws.step_functions.step_functions import ( + StepFunctions, + ) + from metaflow.plugins.aws.step_functions.step_functions_client import ( + StepFunctionsClient, + ) + + _, run_id = self.pathspec.split("/") + execution_name = run_id[4:] # strip "sfn-" + state_machine_name = self.deployer.name + + state_machine = StepFunctionsClient().get(state_machine_name) + if state_machine is None: + return None + sm_arn = state_machine["stateMachineArn"] + # Execution ARN: replace :stateMachine: with :execution: and append name + execution_arn = ( + sm_arn.replace(":stateMachine:", ":execution:") + ":" + execution_name + ) + result = StepFunctionsClient().describe_execution(execution_arn) + if result is None: + return None + return result.get("status") + except Exception: + return None + def terminate(self, **kwargs) -> bool: """ Terminate the running state machine execution. 
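The `status` property above derives the execution ARN from the state machine ARN by swapping the resource type and appending the execution name. A worked example with a made-up account and run id:

```python
sm_arn = "arn:aws:states:us-east-1:123456789012:stateMachine:HelloFlow"  # hypothetical
run_id = "sfn-6d5a0ed3f8f4"                                              # Metaflow run id
execution_name = run_id[4:]                                              # strip "sfn-"

execution_arn = sm_arn.replace(":stateMachine:", ":execution:") + ":" + execution_name
# arn:aws:states:us-east-1:123456789012:execution:HelloFlow:6d5a0ed3f8f4
# This is the ARN handed to describe_execution() to read back the run's status.
```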
@@ -56,6 +93,71 @@ class StepFunctionsDeployedFlow(DeployedFlow): TYPE: ClassVar[Optional[str]] = "step-functions" + def _run_deployer_command(self, method, return_content=False, **kwargs): + """Run a deployer CLI command and return the result. + + Parameters + ---------- + method : str + CLI subcommand name (e.g. "trigger", "resume", "delete"). + return_content : bool + If True, read content from the attribute FIFO and return a + StepFunctionsTriggeredRun on success. If False, return bool. + **kwargs + Passed to the CLI subcommand. + """ + if return_content: + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + kwargs["deployer_attribute_file"] = attribute_file_path + command = getattr( + get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ), + method, + )(**kwargs) + + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + command_obj = self.deployer.spm.get(pid) + content = handle_timeout( + attribute_file_fd, command_obj, self.deployer.file_read_timeout + ) + command_obj.sync_wait() + if command_obj.process.returncode == 0: + return StepFunctionsTriggeredRun( + deployer=self.deployer, content=content + ) + raise Exception( + "Error running %s for %s on %s" + % (method, self.deployer.flow_file, self.deployer.TYPE) + ) + else: + command = getattr( + get_lower_level_group( + self.deployer.api, + self.deployer.top_level_kwargs, + self.deployer.TYPE, + self.deployer.deployer_kwargs, + ), + method, + )(**kwargs) + pid = self.deployer.spm.run_command( + [sys.executable, *command], + env=self.deployer.env_vars, + cwd=self.deployer.cwd, + show_output=self.deployer.show_output, + ) + command_obj = self.deployer.spm.get(pid) + command_obj.sync_wait() + return command_obj.process.returncode == 0 + @classmethod def list_deployed_flows(cls, flow_name: Optional[str] = None): """ @@ -73,31 +175,104 @@ def list_deployed_flows(cls, flow_name: Optional[str] = None): @classmethod def from_deployment(cls, identifier: str, metadata: Optional[str] = None): """ - This method is not currently implemented for Step Functions. + Retrieves a `StepFunctionsDeployedFlow` object from an identifier and optional + metadata. + + Parameters + ---------- + identifier : str + State machine name for the workflow to retrieve. + metadata : str, optional, default None + Optional deployer specific metadata. + + Returns + ------- + StepFunctionsDeployedFlow + A `StepFunctionsDeployedFlow` object representing the deployed flow. Raises ------ - NotImplementedError - This method is not implemented for Step Functions. + MetaflowException + If no deployed flow is found for the given identifier. """ - raise NotImplementedError( - "from_deployment is not implemented for StepFunctions" + import tempfile + from metaflow.exception import MetaflowException + from metaflow.runner.deployer import Deployer, generate_fake_flow_file_contents + from metaflow.client.core import get_metadata + from metaflow.plugins.aws.step_functions.step_functions_client import ( + StepFunctionsClient, ) + workflow = StepFunctionsClient().get(identifier) + if workflow is None: + raise MetaflowException("No deployed flow found for: %s" % identifier) + + # Extract flow metadata stored in the start state's Parameters. 
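The extraction that follows assumes the compiled state machine stores Metaflow's system attributes under the start state's nested Parameters. An abridged, hypothetical sketch of the definition shape being navigated:

```python
import json

definition = json.dumps({
    "States": {
        "start": {
            "Parameters": {
                "Parameters": {
                    "metaflow.flow_name": "HelloFlow",  # placeholder values
                    "metaflow.owner": "jane",
                }
            }
        }
    }
})

params = json.loads(definition)["States"]["start"]["Parameters"]["Parameters"]
flow_name = params.get("metaflow.flow_name", "")   # -> "HelloFlow"
username = params.get("metaflow.owner", "")        # -> "jane"
```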
+ try: + start = json.loads(workflow["definition"])["States"]["start"] + parameters = start["Parameters"]["Parameters"] + flow_name = parameters.get("metaflow.flow_name", "") + username = parameters.get("metaflow.owner", "") + except (KeyError, json.JSONDecodeError): + raise MetaflowException( + "Could not extract flow metadata from state machine: %s" % identifier + ) + + fake_flow_file_contents = generate_fake_flow_file_contents( + flow_name=flow_name, param_info={}, project_name=None + ) + + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as fake_flow_file: + with open(fake_flow_file.name, "w") as fp: + fp.write(fake_flow_file_contents) + + d = Deployer( + fake_flow_file.name, + env={"METAFLOW_USER": username}, + ).step_functions(name=identifier) + + d.name = identifier + d.flow_name = flow_name + if metadata is None: + d.metadata = get_metadata() + else: + d.metadata = metadata + + return cls(deployer=d) + @classmethod def get_triggered_run( cls, identifier: str, run_id: str, metadata: Optional[str] = None ): """ - This method is not currently implemented for Step Functions. + Retrieves a `StepFunctionsTriggeredRun` object from an identifier and run id. - Raises - ------ - NotImplementedError - This method is not implemented for Step Functions. + Parameters + ---------- + identifier : str + State machine name for the workflow. + run_id : str + Run ID for the triggered run. + metadata : str, optional, default None + Optional deployer specific metadata. + + Returns + ------- + StepFunctionsTriggeredRun + A `StepFunctionsTriggeredRun` object representing the triggered run. """ - raise NotImplementedError( - "get_triggered_run is not implemented for StepFunctions" + deployed_flow_obj = cls.from_deployment(identifier, metadata) + return StepFunctionsTriggeredRun( + deployer=deployed_flow_obj.deployer, + content=json.dumps( + { + "metadata": deployed_flow_obj.deployer.metadata, + "pathspec": "/".join( + (deployed_flow_obj.deployer.flow_name, run_id) + ), + "name": run_id, + } + ), ) @property @@ -189,23 +364,43 @@ def delete(self, **kwargs) -> bool: bool True if the command was successful, False otherwise. """ - command = get_lower_level_group( - self.deployer.api, - self.deployer.top_level_kwargs, - self.deployer.TYPE, - self.deployer.deployer_kwargs, - ).delete(**kwargs) + return self._run_deployer_command("delete", **kwargs) + + def resume( + self, + origin_run_id: str, + step_to_rerun: Optional[str] = None, + **kwargs, + ) -> StepFunctionsTriggeredRun: + """ + Resume a failed or stopped run on AWS Step Functions. - pid = self.deployer.spm.run_command( - [sys.executable, *command], - env=self.deployer.env_vars, - cwd=self.deployer.cwd, - show_output=self.deployer.show_output, - ) + Successful steps from the origin run will be cloned rather than + re-executed, unless they are downstream of *step_to_rerun*. - command_obj = self.deployer.spm.get(pid) - command_obj.sync_wait() - return command_obj.process.returncode == 0 + Parameters + ---------- + origin_run_id : str + Run ID of the run to resume (e.g., ``"sfn-"``). + step_to_rerun : str, optional + Name of a specific step from which to rerun. All downstream + steps will also be rerun. If not specified, only steps whose + origin task was not successful will be rerun. + **kwargs : Any + Additional arguments to pass to the resume command, + `Parameters` in particular. + + Returns + ------- + StepFunctionsTriggeredRun + The triggered run instance. 
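For example (flow file, run id, and step name are placeholders):

```python
from metaflow import Deployer

deployed = Deployer("hello_flow.py").step_functions().create()

# A previous execution "sfn-6d5a0ed3f8f4" failed in the "train" step; clone the
# successful steps and rerun from "train" onwards.
resumed = deployed.resume(origin_run_id="sfn-6d5a0ed3f8f4", step_to_rerun="train")
print(resumed.status)
```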
+ """ + resume_kwargs = dict(origin_run_id=origin_run_id, **kwargs) + if step_to_rerun is not None: + resume_kwargs["step_to_rerun"] = step_to_rerun + return self._run_deployer_command( + "resume", return_content=True, **resume_kwargs + ) def trigger(self, **kwargs) -> StepFunctionsTriggeredRun: """ @@ -221,44 +416,5 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun: ------- StepFunctionsTriggeredRun The triggered run instance. - - Raises - ------ - Exception - If there is an error during the trigger process. """ - with temporary_fifo() as (attribute_file_path, attribute_file_fd): - # every subclass needs to have `self.deployer_kwargs` - command = get_lower_level_group( - self.deployer.api, - self.deployer.top_level_kwargs, - self.deployer.TYPE, - self.deployer.deployer_kwargs, - ).trigger(deployer_attribute_file=attribute_file_path, **kwargs) - - pid = self.deployer.spm.run_command( - [sys.executable, *command], - env=self.deployer.env_vars, - cwd=self.deployer.cwd, - show_output=self.deployer.show_output, - ) - - command_obj = self.deployer.spm.get(pid) - content = handle_timeout( - attribute_file_fd, command_obj, self.deployer.file_read_timeout - ) - - command_obj.sync_wait() - if command_obj.process.returncode == 0: - return StepFunctionsTriggeredRun( - deployer=self.deployer, content=content - ) - - raise Exception( - "Error triggering %s on %s for %s" - % ( - self.deployer.name, - self.deployer.TYPE, - self.deployer.flow_file, - ) - ) + return self._run_deployer_command("trigger", return_content=True, **kwargs) diff --git a/metaflow/plugins/datastores/local_storage.py b/metaflow/plugins/datastores/local_storage.py index d00d6567a57..00968b2c9e7 100644 --- a/metaflow/plugins/datastores/local_storage.py +++ b/metaflow/plugins/datastores/local_storage.py @@ -42,7 +42,10 @@ def get_datastore_root_from_config(cls, echo, create_on_absent=True): "Creating %s datastore in current directory (%s)" % (cls.TYPE, orig_path) ) - os.mkdir(orig_path) + try: + os.mkdir(orig_path) + except FileExistsError: + pass # Another process created it concurrently result = orig_path else: return None diff --git a/metaflow/plugins/pypi/conda_decorator.py b/metaflow/plugins/pypi/conda_decorator.py index a21bf76bf18..6ad476bd25e 100644 --- a/metaflow/plugins/pypi/conda_decorator.py +++ b/metaflow/plugins/pypi/conda_decorator.py @@ -43,6 +43,7 @@ class CondaStepDecorator(StepDecorator): _metaflow_home = None _addl_env_vars = None + disabled = False # To define conda channels for the whole solve, users can specify # CONDA_CHANNELS in their environment. 
For pinning specific packages to specific @@ -248,6 +249,8 @@ def runtime_step_cli( ): if self.disabled: return + if self.__class__._metaflow_home is None: + return # Ensure local installation of Metaflow is visible to user code python_path = self.__class__._metaflow_home.name addl_env_vars = {} diff --git a/test/unit/localbatch/test_localbatch.py b/test/unit/localbatch/test_localbatch.py index 12bdfdc2147..8a9f138f347 100644 --- a/test/unit/localbatch/test_localbatch.py +++ b/test/unit/localbatch/test_localbatch.py @@ -333,13 +333,6 @@ def test_inject_env_is_visible_inside_container(self): # --------------------------------------------------------------------------- -_NEEDS_CORE_BATCH_PARAMS = pytest.mark.xfail( - reason="requires npow/core-deployer-changes: BATCH_CLIENT_PARAMS must be added " - "to metaflow_config.py so METAFLOW_BATCH_CLIENT_PARAMS env var is recognized", - strict=False, -) - - @pytest.mark.docker class TestMetaflowE2E: """ @@ -357,7 +350,6 @@ def _require_docker(self): except Exception: pytest.skip("Docker not available") - @_NEEDS_CORE_BATCH_PARAMS def test_batch_step_artifacts_are_persisted(self, simple_batch_run): """ The @batch step writes message='hello from localbatch' and value=42. @@ -367,11 +359,9 @@ def test_batch_step_artifacts_are_persisted(self, simple_batch_run): assert task["message"].data == "hello from localbatch" assert task["value"].data == 42 - @_NEEDS_CORE_BATCH_PARAMS - def test_run_succeeds(self, simple_batch_run): +def test_run_succeeds(self, simple_batch_run): assert simple_batch_run.successful - @_NEEDS_CORE_BATCH_PARAMS - def test_all_steps_have_tasks(self, simple_batch_run): +def test_all_steps_have_tasks(self, simple_batch_run): step_names = {s.id for s in simple_batch_run.steps()} assert {"start", "end"} <= step_names diff --git a/test/ux/core/test_sfn_compilation.py b/test/ux/core/test_sfn_compilation.py index c63ef3087a1..c8bddb5f577 100644 --- a/test/ux/core/test_sfn_compilation.py +++ b/test/ux/core/test_sfn_compilation.py @@ -201,11 +201,6 @@ def test_linear_flow(self): result.get("result", "OK") == "OK" ), f"Validation failed: {result.get('diagnostics', result)}" - @pytest.mark.xfail( - reason="requires npow/core-deployer-changes: step_functions.py must add " - "ResultSelector to Parallel states for sfn-local compatibility", - strict=False, - ) def test_branch_flow(self): """Parallel branch flow produces valid Parallel states with ResultSelector.""" definition = _compile_flow_to_json("dag/branch_flow.py") diff --git a/test/ux/ux_test_config.yaml b/test/ux/ux_test_config.yaml index 00b9e74cd1e..7cbc2077697 100644 --- a/test/ux/ux_test_config.yaml +++ b/test/ux/ux_test_config.yaml @@ -35,10 +35,6 @@ backends: cluster: null decospec: "batch:image=python:3.9" enabled: true - # xfail: requires core changes from npow/core-deployer-changes (step_functions.py - # sfn-local compatibility fixes: ResultSelector on Parallel states, no ProcessorConfig - # on Map states). Remove once that PR is merged into master. - xfail_reason: "requires npow/core-deployer-changes (sfn-local compatibility fixes in step_functions.py)" # Apache Airflow + Kubernetes (devstack: minikube + airflow helm chart) - name: airflow-kubernetes @@ -46,8 +42,4 @@ backends: cluster: default decospec: "kubernetes:image=python:3.9" enabled: true - # xfail: requires AirflowDeployer from npow/core-deployer-changes (__init__.py - # registration + airflow_deployer.py + airflow_client.py). - # Remove once that PR is merged into master. 
- xfail_reason: "requires npow/core-deployer-changes (AirflowDeployer not yet in core)"