diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index e5e03a401a6..f3f313e445b 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -437,6 +437,10 @@ KUBERNETES_DISK = from_conf("KUBERNETES_DISK", None) # Default kubernetes QoS class KUBERNETES_QOS = from_conf("KUBERNETES_QOS", "burstable") +# Default container security context (JSON) for kubernetes pods +KUBERNETES_SECURITY_CONTEXT = from_conf("KUBERNETES_SECURITY_CONTEXT", "") +# Default pod security context (JSON) for kubernetes pods +KUBERNETES_POD_SECURITY_CONTEXT = from_conf("KUBERNETES_POD_SECURITY_CONTEXT", "") # Architecture of kubernetes nodes - used for @conda/@pypi in metaflow-dev KUBERNETES_CONDA_ARCH = from_conf("KUBERNETES_CONDA_ARCH") diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b7b25c8c69d..c9e33ee2f2d 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -2581,6 +2581,15 @@ def _container_templates(self): ) } + pod_security_context = resources.get("pod_security_context", None) + _pod_security_context = {} + if pod_security_context is not None and len(pod_security_context) > 0: + _pod_security_context = { + "security_context": kubernetes_sdk.V1PodSecurityContext( + **pod_security_context + ) + } + # Create a ContainerTemplate for this node. Ideally, we would have # liked to inline this ContainerTemplate and avoid scanning the workflow # twice, but due to issues with variable substitution, we will have to @@ -2639,6 +2648,7 @@ def _container_templates(self): port=port, qos=resources["qos"], security_context=security_context, + pod_security_context=pod_security_context, ) for k, v in env.items(): @@ -2804,6 +2814,16 @@ def _container_templates(self): if resources["image_pull_secrets"] else None ) + # Set pod security context via pod_spec_patch + .pod_spec_patch( + { + "securityContext": kubernetes_sdk.V1PodSecurityContext( + **pod_security_context + ).to_dict() + } + if pod_security_context + else None + ) # Set container .container( # TODO: Unify the logic with kubernetes.py @@ -4503,7 +4523,12 @@ def pod_spec_patch(self, pod_spec_patch=None): if pod_spec_patch is None: return self - self.payload["podSpecPatch"] = json.dumps(pod_spec_patch) + if "podSpecPatch" in self.payload: + existing = json.loads(self.payload["podSpecPatch"]) + existing.update(pod_spec_patch) + self.payload["podSpecPatch"] = json.dumps(existing) + else: + self.payload["podSpecPatch"] = json.dumps(pod_spec_patch) return self diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index c19b3efe3b9..515bab0b152 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -198,6 +198,7 @@ def create_jobset( num_parallel=None, qos=None, security_context=None, + pod_security_context=None, ): name = "js-%s" % str(uuid4())[:6] jobset = ( @@ -233,6 +234,7 @@ def create_jobset( num_parallel=num_parallel, qos=qos, security_context=security_context, + pod_security_context=pod_security_context, ) .environment_variable("METAFLOW_CODE_METADATA", code_package_metadata) .environment_variable("METAFLOW_CODE_SHA", code_package_sha) @@ -499,6 +501,7 @@ def create_job_object( qos=None, annotations=None, security_context=None, + pod_security_context=None, ): if env is None: env = {} @@ -544,6 +547,7 @@ def create_job_object( port=port, qos=qos, security_context=security_context, + pod_security_context=pod_security_context, ) .environment_variable("METAFLOW_CODE_METADATA", code_package_metadata) .environment_variable("METAFLOW_CODE_SHA", code_package_sha) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index e15f7b06cb9..06332ebaaf3 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -158,6 +158,12 @@ def kubernetes(): type=JSONTypeClass(), multiple=False, ) +@click.option( + "--pod-security-context", + default=None, + type=JSONTypeClass(), + multiple=False, +) @click.pass_context def step( ctx, @@ -192,6 +198,7 @@ def step( labels=None, annotations=None, security_context=None, + pod_security_context=None, **kwargs ): def echo(msg, stream="stderr", job_id=None, **kwargs): @@ -338,6 +345,7 @@ def _sync_metadata(): labels=labels, annotations=annotations, security_context=security_context, + pod_security_context=pod_security_context, ) except Exception: traceback.print_exc(chain=False) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index bd3ae7e12c4..e9034572138 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -27,6 +27,8 @@ KUBERNETES_NODE_SELECTOR, KUBERNETES_PERSISTENT_VOLUME_CLAIMS, KUBERNETES_PORT, + KUBERNETES_SECURITY_CONTEXT, + KUBERNETES_POD_SECURITY_CONTEXT, KUBERNETES_SERVICE_ACCOUNT, KUBERNETES_SHARED_MEMORY, KUBERNETES_TOLERATIONS, @@ -136,6 +138,17 @@ class KubernetesDecorator(StepDecorator): - run_as_user: int, optional, default None - run_as_group: int, optional, default None - run_as_non_root: bool, optional, default None + - read_only_root_filesystem: bool, optional, default None + - capabilities: Dict[str, List[str]], optional, default None + Can also be set via METAFLOW_KUBERNETES_SECURITY_CONTEXT (JSON). + pod_security_context: Dict[str, Any], optional, default None + Pod-level security context. Applies to all containers in the pod. Allows the following keys: + - run_as_user: int, optional, default None + - run_as_group: int, optional, default None + - run_as_non_root: bool, optional, default None + - fs_group: int, optional, default None + - supplemental_groups: List[int], optional, default None + Can also be set via METAFLOW_KUBERNETES_POD_SECURITY_CONTEXT (JSON). """ name = "kubernetes" @@ -168,6 +181,7 @@ class KubernetesDecorator(StepDecorator): "hostname_resolution_timeout": 10 * 60, "qos": KUBERNETES_QOS, "security_context": None, + "pod_security_context": None, } package_metadata = None package_url = None @@ -310,6 +324,19 @@ def init(self): if not self.attributes["port"]: self.attributes["port"] = KUBERNETES_PORT + # Security context: decorator takes precedence over env var + if not self.attributes["security_context"] and KUBERNETES_SECURITY_CONTEXT: + self.attributes["security_context"] = json.loads( + KUBERNETES_SECURITY_CONTEXT + ) + if ( + not self.attributes["pod_security_context"] + and KUBERNETES_POD_SECURITY_CONTEXT + ): + self.attributes["pod_security_context"] = json.loads( + KUBERNETES_POD_SECURITY_CONTEXT + ) + # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger): # Executing Kubernetes jobs requires a non-local datastore. @@ -500,6 +527,7 @@ def runtime_step_cli( "labels", "annotations", "security_context", + "pod_security_context", ]: cli_args.command_options[k] = json.dumps(v) else: diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index b81777bcc7b..95b86f13b05 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -93,6 +93,13 @@ def create_job_spec(self): "security_context": client.V1SecurityContext(**security_context) } + pod_security_context = self._kwargs.get("pod_security_context", {}) + _pod_security_context = {} + if pod_security_context is not None and len(pod_security_context) > 0: + _pod_security_context = { + "security_context": client.V1PodSecurityContext(**pod_security_context) + } + return client.V1JobSpec( # Retries are handled by Metaflow when it is responsible for # executing the flow. The responsibility is moved to Kubernetes @@ -277,6 +284,7 @@ def create_job_spec(self): if self._kwargs["persistent_volume_claims"] is not None else [] ), + **_pod_security_context, ), ), ) diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index da0f0fc3130..474fbe4cd39 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -569,6 +569,13 @@ def dump(self): _security_context = { "security_context": client.V1SecurityContext(**security_context) } + + pod_security_context = self._kwargs.get("pod_security_context", {}) + _pod_security_context = {} + if pod_security_context is not None and len(pod_security_context) > 0: + _pod_security_context = { + "security_context": client.V1PodSecurityContext(**pod_security_context) + } return dict( name=self.name, template=client.api_client.ApiClient().sanitize_for_serialization( @@ -784,6 +791,7 @@ def dump(self): is not None else [] ), + **_pod_security_context, ), ), ), diff --git a/test/unit/test_kubernetes_security_context.py b/test/unit/test_kubernetes_security_context.py new file mode 100644 index 00000000000..17e60a9de00 --- /dev/null +++ b/test/unit/test_kubernetes_security_context.py @@ -0,0 +1,374 @@ +"""Tests for Kubernetes security context support (container and pod level).""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_kubernetes_client(): + """Create a mock Kubernetes client that tracks calls to V1SecurityContext and V1PodSecurityContext.""" + with patch("metaflow.plugins.kubernetes.kubernetes_job.KubernetesJob") as _: + from kubernetes import client + + yield client + + +class TestContainerSecurityContext: + """Tests for container-level security context in KubernetesJob.""" + + def test_security_context_applied_to_container(self): + """Verify that security_context dict is passed to V1SecurityContext.""" + from unittest.mock import MagicMock + + mock_client_wrapper = MagicMock() + from kubernetes import client as real_client + + mock_client_wrapper.get.return_value = real_client + + from metaflow.plugins.kubernetes.kubernetes_job import KubernetesJob + + job = KubernetesJob( + client=mock_client_wrapper, + step_name="test_step", + command=["echo", "hello"], + namespace="default", + service_account="default", + image="python:3.9", + image_pull_policy="Always", + image_pull_secrets=[], + cpu="1", + memory="4096", + disk="10240", + gpu=None, + gpu_vendor="nvidia", + timeout_in_seconds=300, + retries=0, + port=None, + use_tmpfs=False, + tmpfs_size=None, + tmpfs_path="/metaflow_temp", + persistent_volume_claims=None, + shared_memory=None, + tolerations=[], + labels={}, + annotations={}, + qos="Burstable", + security_context={"run_as_user": 1000, "run_as_non_root": True}, + pod_security_context=None, + ) + + spec = job.create_job_spec() + container = spec.template.spec.containers[0] + assert container.security_context is not None + assert container.security_context.run_as_user == 1000 + assert container.security_context.run_as_non_root is True + + def test_empty_security_context_not_applied(self): + """Verify that an empty security_context does not set anything.""" + from unittest.mock import MagicMock + + mock_client_wrapper = MagicMock() + from kubernetes import client as real_client + + mock_client_wrapper.get.return_value = real_client + + from metaflow.plugins.kubernetes.kubernetes_job import KubernetesJob + + job = KubernetesJob( + client=mock_client_wrapper, + step_name="test_step", + command=["echo", "hello"], + namespace="default", + service_account="default", + image="python:3.9", + image_pull_policy="Always", + image_pull_secrets=[], + cpu="1", + memory="4096", + disk="10240", + gpu=None, + gpu_vendor="nvidia", + timeout_in_seconds=300, + retries=0, + port=None, + use_tmpfs=False, + tmpfs_size=None, + tmpfs_path="/metaflow_temp", + persistent_volume_claims=None, + shared_memory=None, + tolerations=[], + labels={}, + annotations={}, + qos="Burstable", + security_context=None, + pod_security_context=None, + ) + + spec = job.create_job_spec() + container = spec.template.spec.containers[0] + assert container.security_context is None + + def test_security_context_with_capabilities(self): + """Verify that capabilities can be set in security_context.""" + from unittest.mock import MagicMock + from kubernetes import client as real_client + + mock_client_wrapper = MagicMock() + mock_client_wrapper.get.return_value = real_client + + from metaflow.plugins.kubernetes.kubernetes_job import KubernetesJob + + caps = real_client.V1Capabilities(drop=["ALL"], add=["NET_BIND_SERVICE"]) + job = KubernetesJob( + client=mock_client_wrapper, + step_name="test_step", + command=["echo", "hello"], + namespace="default", + service_account="default", + image="python:3.9", + image_pull_policy="Always", + image_pull_secrets=[], + cpu="1", + memory="4096", + disk="10240", + gpu=None, + gpu_vendor="nvidia", + timeout_in_seconds=300, + retries=0, + port=None, + use_tmpfs=False, + tmpfs_size=None, + tmpfs_path="/metaflow_temp", + persistent_volume_claims=None, + shared_memory=None, + tolerations=[], + labels={}, + annotations={}, + qos="Burstable", + security_context={ + "capabilities": caps, + "read_only_root_filesystem": True, + }, + pod_security_context=None, + ) + + spec = job.create_job_spec() + container = spec.template.spec.containers[0] + assert container.security_context.read_only_root_filesystem is True + assert container.security_context.capabilities is not None + + +class TestPodSecurityContext: + """Tests for pod-level security context in KubernetesJob.""" + + def test_pod_security_context_applied(self): + """Verify that pod_security_context dict is passed to V1PodSecurityContext.""" + from unittest.mock import MagicMock + from kubernetes import client as real_client + + mock_client_wrapper = MagicMock() + mock_client_wrapper.get.return_value = real_client + + from metaflow.plugins.kubernetes.kubernetes_job import KubernetesJob + + job = KubernetesJob( + client=mock_client_wrapper, + step_name="test_step", + command=["echo", "hello"], + namespace="default", + service_account="default", + image="python:3.9", + image_pull_policy="Always", + image_pull_secrets=[], + cpu="1", + memory="4096", + disk="10240", + gpu=None, + gpu_vendor="nvidia", + timeout_in_seconds=300, + retries=0, + port=None, + use_tmpfs=False, + tmpfs_size=None, + tmpfs_path="/metaflow_temp", + persistent_volume_claims=None, + shared_memory=None, + tolerations=[], + labels={}, + annotations={}, + qos="Burstable", + security_context=None, + pod_security_context={"fs_group": 2000, "run_as_non_root": True}, + ) + + spec = job.create_job_spec() + pod_spec = spec.template.spec + assert pod_spec.security_context is not None + assert pod_spec.security_context.fs_group == 2000 + assert pod_spec.security_context.run_as_non_root is True + + def test_empty_pod_security_context_not_applied(self): + """Verify that empty pod_security_context does not set anything.""" + from unittest.mock import MagicMock + from kubernetes import client as real_client + + mock_client_wrapper = MagicMock() + mock_client_wrapper.get.return_value = real_client + + from metaflow.plugins.kubernetes.kubernetes_job import KubernetesJob + + job = KubernetesJob( + client=mock_client_wrapper, + step_name="test_step", + command=["echo", "hello"], + namespace="default", + service_account="default", + image="python:3.9", + image_pull_policy="Always", + image_pull_secrets=[], + cpu="1", + memory="4096", + disk="10240", + gpu=None, + gpu_vendor="nvidia", + timeout_in_seconds=300, + retries=0, + port=None, + use_tmpfs=False, + tmpfs_size=None, + tmpfs_path="/metaflow_temp", + persistent_volume_claims=None, + shared_memory=None, + tolerations=[], + labels={}, + annotations={}, + qos="Burstable", + security_context=None, + pod_security_context=None, + ) + + spec = job.create_job_spec() + pod_spec = spec.template.spec + assert pod_spec.security_context is None + + def test_both_security_contexts_applied(self): + """Verify that both container and pod security contexts can be set simultaneously.""" + from unittest.mock import MagicMock + from kubernetes import client as real_client + + mock_client_wrapper = MagicMock() + mock_client_wrapper.get.return_value = real_client + + from metaflow.plugins.kubernetes.kubernetes_job import KubernetesJob + + job = KubernetesJob( + client=mock_client_wrapper, + step_name="test_step", + command=["echo", "hello"], + namespace="default", + service_account="default", + image="python:3.9", + image_pull_policy="Always", + image_pull_secrets=[], + cpu="1", + memory="4096", + disk="10240", + gpu=None, + gpu_vendor="nvidia", + timeout_in_seconds=300, + retries=0, + port=None, + use_tmpfs=False, + tmpfs_size=None, + tmpfs_path="/metaflow_temp", + persistent_volume_claims=None, + shared_memory=None, + tolerations=[], + labels={}, + annotations={}, + qos="Burstable", + security_context={"run_as_user": 1000, "allow_privilege_escalation": False}, + pod_security_context={ + "fs_group": 2000, + "run_as_group": 3000, + "supplemental_groups": [4000], + }, + ) + + spec = job.create_job_spec() + + # Check container-level + container = spec.template.spec.containers[0] + assert container.security_context.run_as_user == 1000 + assert container.security_context.allow_privilege_escalation is False + + # Check pod-level + pod_spec = spec.template.spec + assert pod_spec.security_context.fs_group == 2000 + assert pod_spec.security_context.run_as_group == 3000 + assert pod_spec.security_context.supplemental_groups == [4000] + + +class TestSecurityContextEnvVarDefaults: + """Tests for environment variable-based security context defaults.""" + + def test_env_var_security_context_parsed(self): + """Verify METAFLOW_KUBERNETES_SECURITY_CONTEXT env var is parsed as JSON.""" + with patch( + "metaflow.plugins.kubernetes.kubernetes_decorator.KUBERNETES_SECURITY_CONTEXT", + '{"run_as_user": 1000}', + ), patch( + "metaflow.plugins.kubernetes.kubernetes_decorator.KUBERNETES_POD_SECURITY_CONTEXT", + "", + ): + from metaflow.plugins.kubernetes.kubernetes_decorator import ( + KubernetesDecorator, + ) + + deco = KubernetesDecorator.__new__(KubernetesDecorator) + deco.attributes = dict(KubernetesDecorator.defaults) + deco.init() + assert deco.attributes["security_context"] == {"run_as_user": 1000} + + def test_env_var_pod_security_context_parsed(self): + """Verify METAFLOW_KUBERNETES_POD_SECURITY_CONTEXT env var is parsed as JSON.""" + with patch( + "metaflow.plugins.kubernetes.kubernetes_decorator.KUBERNETES_SECURITY_CONTEXT", + "", + ), patch( + "metaflow.plugins.kubernetes.kubernetes_decorator.KUBERNETES_POD_SECURITY_CONTEXT", + '{"fs_group": 2000}', + ): + from metaflow.plugins.kubernetes.kubernetes_decorator import ( + KubernetesDecorator, + ) + + deco = KubernetesDecorator.__new__(KubernetesDecorator) + deco.attributes = dict(KubernetesDecorator.defaults) + deco.init() + assert deco.attributes["pod_security_context"] == {"fs_group": 2000} + + def test_decorator_overrides_env_var(self): + """Verify decorator value takes precedence over env var.""" + with patch( + "metaflow.plugins.kubernetes.kubernetes_decorator.KUBERNETES_SECURITY_CONTEXT", + '{"run_as_user": 1000}', + ), patch( + "metaflow.plugins.kubernetes.kubernetes_decorator.KUBERNETES_POD_SECURITY_CONTEXT", + '{"fs_group": 2000}', + ): + from metaflow.plugins.kubernetes.kubernetes_decorator import ( + KubernetesDecorator, + ) + + deco = KubernetesDecorator.__new__(KubernetesDecorator) + deco.attributes = dict(KubernetesDecorator.defaults) + # Simulate decorator explicitly setting the values + deco.attributes["security_context"] = {"run_as_user": 5000} + deco.attributes["pod_security_context"] = {"fs_group": 6000} + deco.init() + # Decorator values should be preserved + assert deco.attributes["security_context"] == {"run_as_user": 5000} + assert deco.attributes["pod_security_context"] == {"fs_group": 6000}