16 changes: 16 additions & 0 deletions metaflow/metaflow_config.py
@@ -363,6 +363,8 @@
BATCH_EMIT_TAGS = from_conf("BATCH_EMIT_TAGS", False)
# Default tags to add to AWS Batch jobs. These are in addition to the defaults set when BATCH_EMIT_TAGS is true.
BATCH_DEFAULT_TAGS = from_conf("BATCH_DEFAULT_TAGS", {})
# Extra boto3 client kwargs for the Batch client (e.g. {"endpoint_url": "http://localhost:8000"} for local emulators).
BATCH_CLIENT_PARAMS = from_conf("BATCH_CLIENT_PARAMS", {})

###
# AWS Step Functions configuration
@@ -372,6 +374,10 @@
SFN_IAM_ROLE = from_conf("SFN_IAM_ROLE")
# AWS DynamoDb Table name (with partition key - `pathspec` of type string)
SFN_DYNAMO_DB_TABLE = from_conf("SFN_DYNAMO_DB_TABLE")
# Extra boto3 client kwargs for the Step Functions client (e.g. {"endpoint_url": "http://localhost:8082"}).
SFN_CLIENT_PARAMS = from_conf("SFN_CLIENT_PARAMS", {})
# Extra boto3 client kwargs for the DynamoDB client used by Step Functions (e.g. {"endpoint_url": "http://localhost:8765"}).
SFN_DYNAMO_DB_CLIENT_PARAMS = from_conf("SFN_DYNAMO_DB_CLIENT_PARAMS", {})
# IAM role for AWS Events with AWS Step Functions access
# https://docs.aws.amazon.com/eventbridge/latest/userguide/auth-and-access-control-eventbridge.html
EVENTS_SFN_ACCESS_IAM_ROLE = from_conf("EVENTS_SFN_ACCESS_IAM_ROLE")
@@ -489,6 +495,16 @@
AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT = from_conf(
"AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT"
)
# Airflow REST API endpoint, e.g. http://localhost:8090/api/v1
AIRFLOW_REST_API_URL = from_conf("AIRFLOW_REST_API_URL")
Comment on lines +498 to +499 (Contributor):
P2 Default Airflow credentials are "admin"/"admin"

AIRFLOW_REST_API_USERNAME and AIRFLOW_REST_API_PASSWORD default to "admin" and "admin". While these are meant to be overridden for real deployments, the defaults could cause accidental authentication against a remote Airflow instance using well-known credentials. Consider defaulting to None and raising a clear error at call-time when they are not set.

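A minimal sketch of that suggestion, assuming from_conf returns None when no default is given and that the guard lives in the trigger path (the error message below is illustrative, not part of this PR):

AIRFLOW_REST_API_USERNAME = from_conf("AIRFLOW_REST_API_USERNAME")  # None if unset
AIRFLOW_REST_API_PASSWORD = from_conf("AIRFLOW_REST_API_PASSWORD")

# At call time, e.g. in the `trigger` command:
if not (AIRFLOW_REST_API_USERNAME and AIRFLOW_REST_API_PASSWORD):
    raise MetaflowException(
        "Set METAFLOW_AIRFLOW_REST_API_USERNAME and "
        "METAFLOW_AIRFLOW_REST_API_PASSWORD to trigger Airflow DAG runs."
    )
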
AIRFLOW_REST_API_USERNAME = from_conf("AIRFLOW_REST_API_USERNAME", "admin")
AIRFLOW_REST_API_PASSWORD = from_conf("AIRFLOW_REST_API_PASSWORD", "admin")
# Path inside Airflow pods where DAG files are stored
AIRFLOW_KUBERNETES_DAGS_PATH = from_conf(
"AIRFLOW_KUBERNETES_DAGS_PATH", "/opt/airflow/dags"
)
# Kubernetes namespace where Airflow runs (for kubectl cp DAG upload)
AIRFLOW_KUBERNETES_NAMESPACE = from_conf("AIRFLOW_KUBERNETES_NAMESPACE", "default")


###
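Taken together, the three new *_CLIENT_PARAMS knobs let the AWS clients point at local emulators. A hedged sketch of how they would be set and consumed: from_conf is expected to JSON-decode environment values when the default is a dict, and the kwargs would presumably be splatted into the boto3 client constructor (the actual wiring lives in Metaflow's AWS client code, not in this diff):

# Environment configuration (endpoint values are illustrative):
#   export METAFLOW_BATCH_CLIENT_PARAMS='{"endpoint_url": "http://localhost:8000"}'
#   export METAFLOW_SFN_CLIENT_PARAMS='{"endpoint_url": "http://localhost:8082"}'
#   export METAFLOW_SFN_DYNAMO_DB_CLIENT_PARAMS='{"endpoint_url": "http://localhost:8765"}'
import boto3

from metaflow.metaflow_config import BATCH_CLIENT_PARAMS

# Presumed consumption pattern: extra kwargs merge into the client call.
batch = boto3.client("batch", **BATCH_CLIENT_PARAMS)
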
1 change: 1 addition & 0 deletions metaflow/plugins/__init__.py
@@ -177,6 +177,7 @@
"step-functions",
".aws.step_functions.step_functions_deployer.StepFunctionsDeployer",
),
("airflow", ".airflow.airflow_deployer.AirflowDeployer"),
]

TL_PLUGINS_DESC = [
16 changes: 16 additions & 0 deletions metaflow/plugins/airflow/airflow.py
@@ -353,6 +353,22 @@ def _to_job(self, node):
}
)

# Pass flow config values (--config-value overrides) to the pod so
# config_expr / @project decorators evaluate correctly at task runtime.
try:
from metaflow.flowspec import FlowStateItems

flow_configs = self.flow._flow_state[FlowStateItems.CONFIGS]
config_env = {
name: value
for name, (value, _is_plain) in flow_configs.items()
if value is not None
}
if config_env:
env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(config_env)
except Exception:
# Best-effort: older Metaflow versions may not expose
# FlowStateItems.CONFIGS, and a flow without configs has nothing to forward.
pass

# Extract the k8s decorators for constructing the arguments of the K8s Pod Operator on Airflow.
k8s_deco = [deco for deco in node.decorators if deco.name == "kubernetes"][0]
user_code_retries, _ = self._get_retries(node)
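The snippet above only produces the METAFLOW_FLOW_CONFIG_VALUE variable; the reader sits elsewhere in Metaflow's config machinery. A sketch of the assumed consumer side inside the pod:

import json
import os

# Assumption: Metaflow rehydrates --config-value overrides from this variable
# at task startup, one JSON entry per Config parameter.
overrides = json.loads(os.environ.get("METAFLOW_FLOW_CONFIG_VALUE", "{}"))
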
83 changes: 83 additions & 0 deletions metaflow/plugins/airflow/airflow_cli.py
@@ -1,4 +1,5 @@
import base64
import json
import os
import re
import sys
@@ -212,6 +213,13 @@ def airflow(obj, name=None):
show_default=True,
help="Worker pool for Airflow DAG execution.",
)
@click.option(
"--deployer-attribute-file",
default=None,
type=str,
hidden=True,
help="Write the DAG name and metadata to the file specified. Used internally for Metaflow's Deployer API.",
)
@click.pass_obj
def create(
obj,
@@ -225,6 +233,7 @@
max_workers=None,
workflow_timeout=None,
worker_pool=None,
deployer_attribute_file=None,
):
if os.path.abspath(sys.argv[0]) == os.path.abspath(file):
raise MetaflowException(
@@ -262,6 +271,17 @@
with open(file, "w") as f:
f.write(flow.compile())

if deployer_attribute_file:
with open(deployer_attribute_file, "w", encoding="utf-8") as f:
json.dump(
{
"name": obj.dag_name,
"flow_name": obj.flow.name,
"metadata": obj.metadata.metadata_str(),
},
f,
)

obj.echo(
"DAG *{dag_name}* "
"for flow *{name}* compiled to "
@@ -270,6 +290,69 @@
)


@airflow.command(help="Trigger a new run of this Airflow DAG.")
@click.option(
"--run-id-file",
default=None,
show_default=True,
type=str,
help="Write the DAG run ID to the file specified.",
)
@click.option(
"--deployer-attribute-file",
default=None,
type=str,
hidden=True,
help="Write the run metadata and pathspec to the file specified. Used internally for Metaflow's Deployer API.",
)
@click.pass_obj
def trigger(obj, run_id_file=None, deployer_attribute_file=None):
from metaflow.metaflow_config import (
AIRFLOW_REST_API_URL,
AIRFLOW_REST_API_USERNAME,
AIRFLOW_REST_API_PASSWORD,
)
from .airflow_client import AirflowClient

if not AIRFLOW_REST_API_URL:
raise MetaflowException(
"METAFLOW_AIRFLOW_REST_API_URL is not set. Cannot trigger Airflow DAG run."
)

client = AirflowClient(
AIRFLOW_REST_API_URL,
username=AIRFLOW_REST_API_USERNAME,
password=AIRFLOW_REST_API_PASSWORD,
)

dag_id = obj.dag_name
dag_run = client.trigger_dag_run(dag_id)
dag_run_id = dag_run.get("dag_run_id") or dag_run.get("run_id", "")

if run_id_file:
with open(run_id_file, "w") as f:
f.write(dag_run_id)

if deployer_attribute_file:
with open(deployer_attribute_file, "w", encoding="utf-8") as f:
json.dump(
{
"name": dag_run_id,
"dag_id": dag_id,
"metadata": obj.metadata.metadata_str(),
"pathspec": obj.flow.name,
},
f,
)

obj.echo(
"DAG *{dag_id}* triggered on Airflow (run-id *{run_id}*).".format(
dag_id=dag_id, run_id=dag_run_id
),
bold=True,
)


def make_flow(
obj,
dag_name,
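The hidden --deployer-attribute-file options exist so the programmatic Deployer API can read back the DAG name, metadata, and run id. A sketch of the intended end-to-end usage; the airflow() accessor name follows the plugin registration above, while the exact shape of the returned objects is an assumption:

from metaflow import Deployer

deployer = Deployer("flow.py")
deployed = deployer.airflow().create()  # shells out to `airflow create`
run = deployed.trigger()                # shells out to `airflow trigger`
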
187 changes: 187 additions & 0 deletions metaflow/plugins/airflow/airflow_client.py
@@ -0,0 +1,187 @@
"""
Thin wrapper around the Airflow 2.x REST API (api/v1).

All HTTP calls raise ``AirflowClientError`` on non-2xx responses, so callers
don't have to inspect status codes themselves; ``get_dag`` maps 404 to ``None``
and ``delete_dag`` maps failure to ``False``.
"""

import json
import time
import urllib.request
import urllib.error
import urllib.parse
import base64
from typing import Any, Dict, List, Optional

from .exception import AirflowException


class AirflowClientError(AirflowException):
headline = "Airflow REST API error"


class AirflowClient:
"""
Minimal Airflow REST API client (Airflow >= 2.0).

Parameters
----------
rest_api_url : str
Base URL of the Airflow REST API, e.g. ``http://localhost:8090/api/v1``.
username : str
Basic-auth username (default: ``"admin"``).
password : str
Basic-auth password (default: ``"admin"``).
"""

def __init__(
self,
rest_api_url: str,
username: str = "admin",
password: str = "admin",
):
self._base = rest_api_url.rstrip("/")
credentials = base64.b64encode(
("%s:%s" % (username, password)).encode()
).decode()
self._auth_header = "Basic %s" % credentials

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

def _request(
self,
method: str,
path: str,
body: Optional[Dict] = None,
) -> Any:
url = "%s/%s" % (self._base, path.lstrip("/"))
data = json.dumps(body).encode() if body is not None else None
req = urllib.request.Request(
url,
data=data,
method=method,
headers={
"Authorization": self._auth_header,
"Content-Type": "application/json",
"Accept": "application/json",
},
)
try:
with urllib.request.urlopen(req) as resp:
raw = resp.read()
return json.loads(raw) if raw else {}
except urllib.error.HTTPError as e:
body = e.read().decode(errors="replace")
raise AirflowClientError(
"Airflow API %s %s returned HTTP %d: %s" % (method, url, e.code, body)
)

# ------------------------------------------------------------------
# DAG operations
# ------------------------------------------------------------------

def get_dag(self, dag_id: str) -> Optional[Dict]:
"""Return DAG metadata dict, or None if not found."""
try:
return self._request("GET", "dags/%s" % dag_id)
except AirflowClientError as e:
if "HTTP 404" in str(e):
return None
raise

def patch_dag(self, dag_id: str, **fields) -> Dict:
"""Patch DAG fields (e.g. ``is_paused=False``)."""
return self._request("PATCH", "dags/%s" % dag_id, body=fields)

def delete_dag(self, dag_id: str) -> bool:
"""Delete a DAG. Returns True on success."""
try:
self._request("DELETE", "dags/%s" % dag_id)
return True
except AirflowClientError:
return False

def list_dags(self, tags: Optional[List[str]] = None) -> List[Dict]:
"""List all visible DAGs, optionally filtered by tags."""
params = ""
if tags:
params = "?" + "&".join("tags=%s" % urllib.parse.quote(t) for t in tags)
result = self._request("GET", "dags%s" % params)
return result.get("dags", [])

# ------------------------------------------------------------------
# DAG run operations
# ------------------------------------------------------------------

def trigger_dag_run(
self,
dag_id: str,
conf: Optional[Dict] = None,
run_id: Optional[str] = None,
) -> Dict:
"""Trigger a DAG run. Returns the dag_run dict."""
body: Dict[str, Any] = {}
if conf:
body["conf"] = conf
if run_id:
body["dag_run_id"] = run_id
return self._request("POST", "dags/%s/dagRuns" % dag_id, body=body)

def get_dag_run(self, dag_id: str, dag_run_id: str) -> Dict:
"""Return dag_run dict for a specific run."""
return self._request("GET", "dags/%s/dagRuns/%s" % (dag_id, dag_run_id))

def list_dag_runs(self, dag_id: str, limit: int = 25) -> List[Dict]:
"""List recent dag runs for a DAG."""
result = self._request(
"GET",
"dags/%s/dagRuns?limit=%d&order_by=-execution_date" % (dag_id, limit),
)
return result.get("dag_runs", [])

def patch_dag_run(self, dag_id: str, dag_run_id: str, **fields) -> Dict:
"""Patch a dag run (e.g. set state to 'failed' to terminate it)."""
return self._request(
"PATCH",
"dags/%s/dagRuns/%s" % (dag_id, dag_run_id),
body=fields,
)

# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------

def wait_for_dag(
self,
dag_id: str,
timeout: int = 120,
polling_interval: int = 5,
) -> Dict:
"""
Poll until the DAG is visible in Airflow (after kubectl-cp / file copy).

Returns the DAG metadata dict when found.

Raises
------
TimeoutError
If the DAG is not discovered within *timeout* seconds.
"""
deadline = time.time() + timeout
while time.time() < deadline:
try:
dag = self.get_dag(dag_id)
except OSError:
# Transient connection error (e.g. RemoteDisconnected) —
# the webserver may still be starting up. Retry silently.
time.sleep(polling_interval)
continue
if dag is not None:
return dag
time.sleep(polling_interval)
raise TimeoutError(
"DAG '%s' did not appear in Airflow within %d seconds. "
"Ensure the DAG file was copied to the dags folder and "
"that the Airflow scheduler is running." % (dag_id, timeout)
)
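For reference, an illustrative round trip with the client above against a local webserver (URL, credentials, and DAG id are placeholders):

from metaflow.plugins.airflow.airflow_client import (
    AirflowClient,
    AirflowClientError,
)

client = AirflowClient(
    "http://localhost:8090/api/v1", username="admin", password="admin"
)
try:
    client.wait_for_dag("myflow", timeout=180)   # wait for scheduler discovery
    client.patch_dag("myflow", is_paused=False)  # unpause before triggering
    run = client.trigger_dag_run("myflow", conf={"triggered_by": "metaflow"})
    print(run.get("dag_run_id"), run.get("state"))
except AirflowClientError as e:
    print("Airflow API call failed: %s" % e)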