Skip to content

Commit 5946854

Browse files
Pass security context for Intel GPU based training jobs (#157)
1 parent daaad01 commit 5946854

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

interactive_ai/workflows/geti_domain/common/jobs_common/k8s_helpers/trainer_image_info.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class TrainerImageInfo:
4848

4949
train_image_name: str
5050
sidecar_image_name: str
51+
render_gid: int = 0 # Should be non-zero value when training with Intel GPUs
5152

5253
@classmethod
5354
def create(cls, training_framework: TrainingFramework) -> "TrainerImageInfo":
@@ -64,6 +65,8 @@ def create(cls, training_framework: TrainingFramework) -> "TrainerImageInfo":
6465

6566
configmap = asyncio.run(get_config_map(namespace=namespace, name=name))
6667

68+
render_gid = 0
69+
6770
msg = "Cannot get `{0}` field from config map `{1}/{2}`"
6871

6972
# This information is from `impt-configuration` config map in the namespace `impt`
@@ -73,7 +76,8 @@ def create(cls, training_framework: TrainingFramework) -> "TrainerImageInfo":
7376
raise ValueError(msg.format("ote_image", namespace, name))
7477
if (otx2_image := configmap.data.get("otx2_image")) is None:
7578
raise ValueError(msg.format("otx2_image", namespace, name))
76-
79+
if render_gid_value := configmap.data.get("render_gid"):
80+
render_gid = int(render_gid_value)
7781
if FeatureFlagProvider.is_enabled(FeatureFlag.FEATURE_FLAG_OTX_VERSION_SELECTION):
7882
if parse_version(training_framework.version) < parse_version("2.0.0"):
7983
image_name = ote_image
@@ -86,7 +90,7 @@ def create(cls, training_framework: TrainingFramework) -> "TrainerImageInfo":
8690
f"Trainer image has been selected {image_name}, where a model has trainer "
8791
f"identification for {training_framework.version}."
8892
)
89-
return cls(train_image_name=image_name, sidecar_image_name=mlflow_sidecar_image)
93+
return cls(train_image_name=image_name, sidecar_image_name=mlflow_sidecar_image, render_gid=render_gid)
9094

9195
def to_primary_image_full_name(self) -> str:
9296
"""Get primary image full name.

interactive_ai/workflows/geti_domain/common/jobs_common/k8s_helpers/trainer_pod_definition.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88

99
from flytekit import ContainerTask, PodTemplate, current_context
10-
from kubernetes.client import V1PodSpec
10+
from kubernetes.client import V1Capabilities, V1PodSpec, V1SecurityContext
1111
from kubernetes.client.models import (
1212
V1ConfigMapEnvSource,
1313
V1ConfigMapKeySelector,
@@ -200,6 +200,16 @@ def create_flyte_container_task( # noqa: PLR0913
200200
runtime_class_name = "nvidia" if accelerator_name == "nvidia.com/gpu" else None
201201
logger.info(f"Create runtime_class_name={runtime_class_name}")
202202

203+
security_context = None
204+
if trainer_image_info.render_gid != 0:
205+
security_context = V1SecurityContext(
206+
run_as_group=trainer_image_info.render_gid,
207+
allow_privilege_escalation=False,
208+
read_only_root_filesystem=False,
209+
run_as_non_root=True,
210+
run_as_user=10001,
211+
capabilities=V1Capabilities(drop=["ALL"]),
212+
)
203213
role = "flyte_workflows"
204214

205215
pod_spec = V1PodSpec(
@@ -296,6 +306,7 @@ def create_flyte_container_task( # noqa: PLR0913
296306
V1VolumeMount(mount_path="/dev/shm", name="shared-memory"), # noqa : S108 # nosec: B108
297307
V1VolumeMount(mount_path="/shard_files", name="shard-files-dir"),
298308
],
309+
security_context=security_context,
299310
)
300311
],
301312
init_containers=[

interactive_ai/workflows/geti_domain/common/tests/unit/k8s_helpers/test_trainer_image_info.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,36 @@
1313
@pytest.mark.JobsComponent
1414
class TestTrainerImageInfo:
1515
@pytest.mark.parametrize(
16-
"training_framework, feature_flag_otx_version_selection, primary_image_full_name, sidecar_image_full_name",
16+
"training_framework, feature_flag_otx_version_selection, "
17+
"primary_image_full_name, sidecar_image_full_name, render_gid",
1718
[
1819
(
1920
TrainingFramework(type=TrainingFrameworkType.OTX, version="2.1.0"),
2021
"false",
2122
"otx2_image",
2223
"mlflow_sidecar_image",
24+
0,
2325
),
2426
(
2527
TrainingFramework(type=TrainingFrameworkType.OTX, version="1.6.0"),
2628
"false",
2729
"otx2_image",
2830
"mlflow_sidecar_image",
31+
0,
2932
),
3033
(
3134
TrainingFramework(type=TrainingFrameworkType.OTX, version="2.1.0"),
3235
"true",
3336
"otx2_image",
3437
"mlflow_sidecar_image",
38+
992,
3539
),
3640
(
3741
TrainingFramework(type=TrainingFrameworkType.OTX, version="1.6.0"),
3842
"true",
3943
"ote_image",
4044
"mlflow_sidecar_image",
45+
992,
4146
),
4247
],
4348
)
@@ -49,6 +54,7 @@ def test_create(
4954
feature_flag_otx_version_selection,
5055
primary_image_full_name,
5156
sidecar_image_full_name,
57+
render_gid,
5258
):
5359
os.environ.update(
5460
{
@@ -62,6 +68,7 @@ def test_create(
6268
"mlflow_sidecar_image": "mlflow_sidecar_image",
6369
"ote_image": "ote_image",
6470
"otx2_image": "otx2_image",
71+
"render_gid": str(render_gid),
6572
}
6673
mock_get_config_map.return_value = MagicMock()
6774
mock_get_config_map.return_value.data.get.side_effect = lambda key: configmap_data.get(key)
@@ -70,3 +77,4 @@ def test_create(
7077

7178
assert trainer_image_info.to_primary_image_full_name() == primary_image_full_name
7279
assert trainer_image_info.to_sidecar_image_full_name() == sidecar_image_full_name
80+
assert trainer_image_info.render_gid == render_gid

0 commit comments

Comments
 (0)