diff --git a/api/openapi-spec/swagger.json b/api/openapi-spec/swagger.json index f2d3315135..31aafb98f2 100644 --- a/api/openapi-spec/swagger.json +++ b/api/openapi-spec/swagger.json @@ -14122,6 +14122,23 @@ } } }, + "trainer.v1alpha1.Metric": { + "type": "object", + "required": [ + "name", + "value" + ], + "properties": { + "name": { + "description": "name is a user-defined label for the metric, e.g. \"loss\", \"eval_accuracy\".", + "type": "string" + }, + "value": { + "description": "value of the metric. Values must be serialized as a string.", + "type": "string" + } + } + }, "trainer.v1alpha1.ModelInitializer": { "description": "ModelInitializer represents the desired configuration to initialize pre-trained model. The ModelInitializer spec will override the runtime Job template which contains this label: `trainer.kubeflow.org/trainjob-ancestor-step: dataset-initializer`", "type": "object", @@ -14601,6 +14618,14 @@ "name" ], "x-kubernetes-list-type": "map" + }, + "trainerStatus": { + "description": "trainerStatus contains the latest observed runtime status of the Trainer step of the TrainJob. It reflects progress, remaining time, metrics, and the last update timestamp.\n\nThis field is nil if the TrainJob does not report trainer-level status, or if no status has been observed yet (for example, immediately after the TrainJob is created).\n\nThis is an alpha feature and requires enabling the TrainJobStatus feature gate.", + "allOf": [ + { + "$ref": "#/components/schemas/trainer.v1alpha1.TrainerStatus" + } + ] } } }, @@ -14666,6 +14691,43 @@ } } }, + "trainer.v1alpha1.TrainerStatus": { + "description": "TrainerStatus represents the latest known runtime status of the Trainer step of the TrainJob.", + "type": "object", + "properties": { + "estimatedRemainingSeconds": { + "description": "estimatedRemainingSeconds gives the estimated remaining training time in seconds before the train job is completed. 
The value will be empty if it is unknown.", + "type": "integer", + "format": "int32" + }, + "lastUpdatedTime": { + "description": "lastUpdatedTime is the timestamp when the runtime status was observed.", + "allOf": [ + { + "$ref": "#/components/schemas/io.k8s.apimachinery.pkg.apis.meta.v1.Time" + } + ] + }, + "metrics": { + "description": "metrics contains the current metrics for the model.", + "type": "array", + "items": { + "default": {}, + "allOf": [ + { + "$ref": "#/components/schemas/trainer.v1alpha1.Metric" + } + ] + }, + "x-kubernetes-list-type": "atomic" + }, + "progressPercentage": { + "description": "progressPercentage gives an estimate of how complete the TrainJob is as a percentage. The value will be between 0 and 100, or empty if unknown.", + "type": "integer", + "format": "int32" + } + } + }, "trainer.v1alpha1.TrainingRuntime": { "description": "TrainingRuntime represents a training runtime which can be referenced as part of `runtimeRef` API in TrainJob. This resource is a namespaced-scoped and can be referenced by TrainJob that created in the *same* namespace as the TrainingRuntime.", "type": "object", @@ -14781,6 +14843,20 @@ } } }, + "trainer.v1alpha1.UpdateTrainJobStatusRequest": { + "description": "UpdateTrainJobStatusRequest contains the current runtime status (e.g. progress and metrics) for the different stages of the TrainJob.", + "type": "object", + "properties": { + "trainerStatus": { + "description": "trainerStatus contains the latest observed runtime status of the Trainer step of the TrainJob. 
It reflects progress, remaining time, metrics, and the last update timestamp.\n\nThis field is nil if the TrainJob does not report trainer-level status, or if no status has been observed yet (for example, immediately after the TrainJob is created).\n\nThis is an alpha feature and requires enabling the TrainJobStatus feature gate.", + "allOf": [ + { + "$ref": "#/components/schemas/trainer.v1alpha1.TrainerStatus" + } + ] + } + } + }, "trainer.v1alpha1.VolcanoPodGroupPolicySource": { "description": "VolcanoPodGroupPolicySource represents configuration for the Volcano gang-scheduler.", "type": "object", diff --git a/api/python_api/kubeflow_trainer_api/models/__init__.py b/api/python_api/kubeflow_trainer_api/models/__init__.py index b64fbad192..1c05b8cc7c 100644 --- a/api/python_api/kubeflow_trainer_api/models/__init__.py +++ b/api/python_api/kubeflow_trainer_api/models/__init__.py @@ -372,6 +372,7 @@ from kubeflow_trainer_api.models.trainer_v1alpha1_ml_policy import TrainerV1alpha1MLPolicy from kubeflow_trainer_api.models.trainer_v1alpha1_ml_policy_source import TrainerV1alpha1MLPolicySource from kubeflow_trainer_api.models.trainer_v1alpha1_mpiml_policy_source import TrainerV1alpha1MPIMLPolicySource +from kubeflow_trainer_api.models.trainer_v1alpha1_metric import TrainerV1alpha1Metric from kubeflow_trainer_api.models.trainer_v1alpha1_model_initializer import TrainerV1alpha1ModelInitializer from kubeflow_trainer_api.models.trainer_v1alpha1_pod_group_policy import TrainerV1alpha1PodGroupPolicy from kubeflow_trainer_api.models.trainer_v1alpha1_pod_group_policy_source import TrainerV1alpha1PodGroupPolicySource @@ -385,8 +386,10 @@ from kubeflow_trainer_api.models.trainer_v1alpha1_train_job_spec import TrainerV1alpha1TrainJobSpec from kubeflow_trainer_api.models.trainer_v1alpha1_train_job_status import TrainerV1alpha1TrainJobStatus from kubeflow_trainer_api.models.trainer_v1alpha1_trainer import TrainerV1alpha1Trainer +from 
kubeflow_trainer_api.models.trainer_v1alpha1_trainer_status import TrainerV1alpha1TrainerStatus from kubeflow_trainer_api.models.trainer_v1alpha1_training_runtime import TrainerV1alpha1TrainingRuntime from kubeflow_trainer_api.models.trainer_v1alpha1_training_runtime_list import TrainerV1alpha1TrainingRuntimeList from kubeflow_trainer_api.models.trainer_v1alpha1_training_runtime_spec import TrainerV1alpha1TrainingRuntimeSpec from kubeflow_trainer_api.models.trainer_v1alpha1_training_runtime_spec_patch import TrainerV1alpha1TrainingRuntimeSpecPatch +from kubeflow_trainer_api.models.trainer_v1alpha1_update_train_job_status_request import TrainerV1alpha1UpdateTrainJobStatusRequest from kubeflow_trainer_api.models.trainer_v1alpha1_volcano_pod_group_policy_source import TrainerV1alpha1VolcanoPodGroupPolicySource diff --git a/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_metric.py b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_metric.py new file mode 100644 index 0000000000..a532002389 --- /dev/null +++ b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_metric.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +""" + Kubeflow Trainer OpenAPI Spec + + No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + + The version of the OpenAPI document: unversioned + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" # noqa: E501 + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + +from pydantic import BaseModel, ConfigDict, Field, StrictStr +from typing import Any, ClassVar, Dict, List +from typing import Optional, Set +from typing_extensions import Self + +class TrainerV1alpha1Metric(BaseModel): + """ + TrainerV1alpha1Metric + """ # noqa: E501 + name: StrictStr = Field(description="name is a user-defined label for the metric, e.g. 
\"loss\", \"eval_accuracy\".") + value: StrictStr = Field(description="value of the metric. Values must be serialized as a string.") + __properties: ClassVar[List[str]] = ["name", "value"] + + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + protected_namespaces=(), + ) + + + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias""" + # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> Optional[Self]: + """Create an instance of TrainerV1alpha1Metric from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. 
+ """ + excluded_fields: Set[str] = set([ + ]) + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + return _dict + + @classmethod + def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: + """Create an instance of TrainerV1alpha1Metric from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + _obj = cls.model_validate({ + "name": obj.get("name"), + "value": obj.get("value") + }) + return _obj + + diff --git a/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_train_job_status.py b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_train_job_status.py index d190778d85..46908e6628 100644 --- a/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_train_job_status.py +++ b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_train_job_status.py @@ -21,6 +21,7 @@ from typing import Any, ClassVar, Dict, List, Optional from kubeflow_trainer_api.models.io_k8s_apimachinery_pkg_apis_meta_v1_condition import IoK8sApimachineryPkgApisMetaV1Condition from kubeflow_trainer_api.models.trainer_v1alpha1_job_status import TrainerV1alpha1JobStatus +from kubeflow_trainer_api.models.trainer_v1alpha1_trainer_status import TrainerV1alpha1TrainerStatus from typing import Optional, Set from typing_extensions import Self @@ -30,7 +31,8 @@ class TrainerV1alpha1TrainJobStatus(BaseModel): """ # noqa: E501 conditions: Optional[List[IoK8sApimachineryPkgApisMetaV1Condition]] = Field(default=None, description="conditions for the TrainJob.") jobs_status: Optional[List[TrainerV1alpha1JobStatus]] = Field(default=None, description="jobsStatus tracks the child Jobs in TrainJob.", alias="jobsStatus") - __properties: ClassVar[List[str]] = ["conditions", "jobsStatus"] + trainer_status: Optional[TrainerV1alpha1TrainerStatus] = Field(default=None, description="trainerStatus contains the latest observed runtime status of the Trainer step of the TrainJob. 
It reflects progress, remaining time, metrics, and the last update timestamp. This field is nil if the TrainJob does not report trainer-level status, or if no status has been observed yet (for example, immediately after the TrainJob is created). This is an alpha feature and requires enabling the TrainJobStatus feature gate.", alias="trainerStatus") + __properties: ClassVar[List[str]] = ["conditions", "jobsStatus", "trainerStatus"] model_config = ConfigDict( populate_by_name=True, @@ -85,6 +87,9 @@ def to_dict(self) -> Dict[str, Any]: if _item_jobs_status: _items.append(_item_jobs_status.to_dict()) _dict['jobsStatus'] = _items + # override the default output from pydantic by calling `to_dict()` of trainer_status + if self.trainer_status: + _dict['trainerStatus'] = self.trainer_status.to_dict() return _dict @classmethod @@ -98,7 +103,8 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: _obj = cls.model_validate({ "conditions": [IoK8sApimachineryPkgApisMetaV1Condition.from_dict(_item) for _item in obj["conditions"]] if obj.get("conditions") is not None else None, - "jobsStatus": [TrainerV1alpha1JobStatus.from_dict(_item) for _item in obj["jobsStatus"]] if obj.get("jobsStatus") is not None else None + "jobsStatus": [TrainerV1alpha1JobStatus.from_dict(_item) for _item in obj["jobsStatus"]] if obj.get("jobsStatus") is not None else None, + "trainerStatus": TrainerV1alpha1TrainerStatus.from_dict(obj["trainerStatus"]) if obj.get("trainerStatus") is not None else None }) return _obj diff --git a/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_trainer_status.py b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_trainer_status.py new file mode 100644 index 0000000000..66e0852994 --- /dev/null +++ b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_trainer_status.py @@ -0,0 +1,102 @@ +# coding: utf-8 + +""" + Kubeflow Trainer OpenAPI Spec + + No description provided (generated by Openapi Generator 
https://github.com/openapitools/openapi-generator) + + The version of the OpenAPI document: unversioned + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" # noqa: E501 + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + +from datetime import datetime +from pydantic import BaseModel, ConfigDict, Field, StrictInt +from typing import Any, ClassVar, Dict, List, Optional +from kubeflow_trainer_api.models.trainer_v1alpha1_metric import TrainerV1alpha1Metric +from typing import Optional, Set +from typing_extensions import Self + +class TrainerV1alpha1TrainerStatus(BaseModel): + """ + TrainerStatus represents the latest known runtime status of the Trainer step of the TrainJob. + """ # noqa: E501 + estimated_remaining_seconds: Optional[StrictInt] = Field(default=None, description="estimatedRemainingSeconds gives the estimated remaining training time in seconds before the train job is completed. The value will be empty if it is unknown.", alias="estimatedRemainingSeconds") + last_updated_time: Optional[datetime] = Field(default=None, description="lastUpdatedTime is the timestamp when the runtime status was observed.", alias="lastUpdatedTime") + metrics: Optional[List[TrainerV1alpha1Metric]] = Field(default=None, description="metrics contains the current metrics for the model.") + progress_percentage: Optional[StrictInt] = Field(default=None, description="progressPercentage gives an estimate of how complete the TrainJob is as a percentage. 
The value will be between 0 and 100, or empty if unknown.", alias="progressPercentage") + __properties: ClassVar[List[str]] = ["estimatedRemainingSeconds", "lastUpdatedTime", "metrics", "progressPercentage"] + + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + protected_namespaces=(), + ) + + + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias""" + # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> Optional[Self]: + """Create an instance of TrainerV1alpha1TrainerStatus from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. 
+ """ + excluded_fields: Set[str] = set([ + ]) + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + # override the default output from pydantic by calling `to_dict()` of each item in metrics (list) + _items = [] + if self.metrics: + for _item_metrics in self.metrics: + if _item_metrics: + _items.append(_item_metrics.to_dict()) + _dict['metrics'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: + """Create an instance of TrainerV1alpha1TrainerStatus from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + _obj = cls.model_validate({ + "estimatedRemainingSeconds": obj.get("estimatedRemainingSeconds"), + "lastUpdatedTime": obj.get("lastUpdatedTime"), + "metrics": [TrainerV1alpha1Metric.from_dict(_item) for _item in obj["metrics"]] if obj.get("metrics") is not None else None, + "progressPercentage": obj.get("progressPercentage") + }) + return _obj + + diff --git a/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_update_train_job_status_request.py b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_update_train_job_status_request.py new file mode 100644 index 0000000000..f2ec58f719 --- /dev/null +++ b/api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_update_train_job_status_request.py @@ -0,0 +1,91 @@ +# coding: utf-8 + +""" + Kubeflow Trainer OpenAPI Spec + + No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + + The version of the OpenAPI document: unversioned + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. 
+""" # noqa: E501 + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + +from pydantic import BaseModel, ConfigDict, Field +from typing import Any, ClassVar, Dict, List, Optional +from kubeflow_trainer_api.models.trainer_v1alpha1_trainer_status import TrainerV1alpha1TrainerStatus +from typing import Optional, Set +from typing_extensions import Self + +class TrainerV1alpha1UpdateTrainJobStatusRequest(BaseModel): + """ + UpdateTrainJobStatusRequest contains the current runtime status (e.g. progress and metrics) for the different stages of the TrainJob. + """ # noqa: E501 + trainer_status: Optional[TrainerV1alpha1TrainerStatus] = Field(default=None, description="trainerStatus contains the latest observed runtime status of the Trainer step of the TrainJob. It reflects progress, remaining time, metrics, and the last update timestamp. This field is nil if the TrainJob does not report trainer-level status, or if no status has been observed yet (for example, immediately after the TrainJob is created). 
This is an alpha feature and requires enabling the TrainJobStatus feature gate.", alias="trainerStatus") + __properties: ClassVar[List[str]] = ["trainerStatus"] + + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + protected_namespaces=(), + ) + + + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias""" + # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> Optional[Self]: + """Create an instance of TrainerV1alpha1UpdateTrainJobStatusRequest from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. 
+ """ + excluded_fields: Set[str] = set([ + ]) + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + # override the default output from pydantic by calling `to_dict()` of trainer_status + if self.trainer_status: + _dict['trainerStatus'] = self.trainer_status.to_dict() + return _dict + + @classmethod + def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: + """Create an instance of TrainerV1alpha1UpdateTrainJobStatusRequest from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + _obj = cls.model_validate({ + "trainerStatus": TrainerV1alpha1TrainerStatus.from_dict(obj["trainerStatus"]) if obj.get("trainerStatus") is not None else None + }) + return _obj + + diff --git a/charts/kubeflow-trainer/README.md b/charts/kubeflow-trainer/README.md index aac6ad1546..2b71b5d4d8 100644 --- a/charts/kubeflow-trainer/README.md +++ b/charts/kubeflow-trainer/README.md @@ -136,7 +136,10 @@ manager: | manager.volumeMounts | list | `[]` | Volume mounts for manager containers. | | manager.resources | object | `{}` | Pod resource requests and limits for manager containers. | | manager.securityContext | object | `{"allowPrivilegeEscalation":false,"capabilities":{"drop":["ALL"]},"runAsNonRoot":true,"seccompProfile":{"type":"RuntimeDefault"}}` | Security context for manager containers. 
| -| manager.config | object | `{"certManagement":{"enable":true,"webhookSecretName":"","webhookServiceName":""},"controller":{"groupKindConcurrency":{"clusterTrainingRuntime":1,"trainJob":5,"trainingRuntime":1}},"featureGates":{},"health":{"healthProbeBindAddress":":8081","livenessEndpointName":"healthz","readinessEndpointName":"readyz"},"leaderElection":{"leaderElect":true,"leaseDuration":"15s","renewDeadline":"10s","resourceName":"trainer.kubeflow.org","resourceNamespace":"","retryPeriod":"2s"},"metrics":{"bindAddress":":8443","secureServing":true},"webhook":{"host":"","port":9443}}` | Controller manager configuration. This configuration is used to generate the ConfigMap for the controller manager. | +| manager.config | object | `{"certManagement":{"enable":true,"webhookSecretName":"","webhookServiceName":""},"controller":{"groupKindConcurrency":{"clusterTrainingRuntime":1,"trainJob":5,"trainingRuntime":1}},"featureGates":{},"health":{"healthProbeBindAddress":":8081","livenessEndpointName":"healthz","readinessEndpointName":"readyz"},"leaderElection":{"leaderElect":true,"leaseDuration":"15s","renewDeadline":"10s","resourceName":"trainer.kubeflow.org","resourceNamespace":"","retryPeriod":"2s"},"metrics":{"bindAddress":":8443","secureServing":true},"statusServer":{"burst":10,"port":10443,"qps":5},"webhook":{"host":"","port":9443}}` | Controller manager configuration. This configuration is used to generate the ConfigMap for the controller manager. | +| manager.config.statusServer.port | int | `10443` | Port that the TrainJob status server serves on. | +| manager.config.statusServer.qps | int | `5` | QPS rate limit for the TrainJob Status Server api client | +| manager.config.statusServer.burst | int | `10` | Burst rate limit for the TrainJob Status Server api client | | webhook.failurePolicy | string | `"Fail"` | Specifies how unrecognized errors are handled. Available options are `Ignore` or `Fail`. 
| | dataCache.enabled | bool | `false` | Enable/disable data cache support (LWS dependency, ClusterRole). Set to `true` to install data cache components. | | dataCache.lws.install | bool | `true` | Whether to install LeaderWorkerSet as a dependency. Set to `false` if LeaderWorkerSet is already installed in the cluster. | diff --git a/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainjobs.yaml b/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainjobs.yaml index 9197d96300..acac158aa8 100644 --- a/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainjobs.yaml +++ b/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainjobs.yaml @@ -5949,6 +5949,66 @@ spec: x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map + trainerStatus: + description: |- + trainerStatus contains the latest observed runtime status of the + Trainer step of the TrainJob. It reflects progress, remaining time, + metrics, and the last update timestamp. + + This field is nil if the TrainJob does not report trainer-level + status, or if no status has been observed yet (for example, + immediately after the TrainJob is created). + + This is an alpha feature and requires enabling the TrainJobStatus feature gate. + properties: + estimatedRemainingSeconds: + description: |- + estimatedRemainingSeconds gives the estimated remaining training time in seconds + before the train job is completed. + The value will be empty if it is unknown. + format: int32 + minimum: 0 + type: integer + lastUpdatedTime: + description: lastUpdatedTime is the timestamp when the runtime + status was observed. + format: date-time + type: string + metrics: + description: metrics contains the current metrics for the model. + items: + properties: + name: + description: name is a user-defined label for the metric, + e.g. "loss", "eval_accuracy". + maxLength: 64 + minLength: 1 + type: string + value: + description: value of the metric. Values must be serialized + as a string. 
+ maxLength: 64 + minLength: 1 + type: string + required: + - name + - value + type: object + maxItems: 256 + type: array + x-kubernetes-list-type: atomic + progressPercentage: + description: |- + progressPercentage gives an estimate of how complete the TrainJob is as a percentage. + The value will be between 0 and 100, or empty if unknown. + format: int32 + maximum: 100 + minimum: 0 + type: integer + type: object + x-kubernetes-validations: + - message: lastUpdatedTime is required when trainerStatus is present + rule: has(self.lastUpdatedTime) type: object type: object x-kubernetes-validations: diff --git a/charts/kubeflow-trainer/templates/manager/configmap.yaml b/charts/kubeflow-trainer/templates/manager/configmap.yaml index f32cbe3610..a9f76c1ff1 100644 --- a/charts/kubeflow-trainer/templates/manager/configmap.yaml +++ b/charts/kubeflow-trainer/templates/manager/configmap.yaml @@ -63,6 +63,11 @@ data: webhookServiceName: {{ if .Values.manager.config.certManagement.webhookServiceName }}{{ .Values.manager.config.certManagement.webhookServiceName }}{{ else }}{{ include "trainer.webhook.service.name" . }}{{ end }} webhookSecretName: {{ if .Values.manager.config.certManagement.webhookSecretName }}{{ .Values.manager.config.certManagement.webhookSecretName }}{{ else }}{{ include "trainer.webhook.secret.name" . }}{{ end }} + statusServer: + port: {{ .Values.manager.config.statusServer.port }} + qps: {{ .Values.manager.config.statusServer.qps }} + burst: {{ .Values.manager.config.statusServer.burst }} + {{ with .Values.manager.config.featureGates }} # Feature gates featureGates: {{ toYaml . 
| nindent 6 }} diff --git a/charts/kubeflow-trainer/templates/manager/deployment.yaml b/charts/kubeflow-trainer/templates/manager/deployment.yaml index 9cf019fc67..6f445a4d6a 100644 --- a/charts/kubeflow-trainer/templates/manager/deployment.yaml +++ b/charts/kubeflow-trainer/templates/manager/deployment.yaml @@ -64,6 +64,9 @@ spec: - name: webhook containerPort: 9443 protocol: TCP + - name: status-server + containerPort: {{ .Values.manager.config.statusServer.port }} + protocol: TCP {{- with .Values.manager.env }} env: diff --git a/charts/kubeflow-trainer/templates/manager/service.yaml b/charts/kubeflow-trainer/templates/manager/service.yaml index 9c935cdf8d..dfce02a6df 100644 --- a/charts/kubeflow-trainer/templates/manager/service.yaml +++ b/charts/kubeflow-trainer/templates/manager/service.yaml @@ -32,3 +32,7 @@ spec: port: 443 protocol: TCP targetPort: webhook + - name: status-server + port: 10443 + targetPort: status-server + protocol: TCP diff --git a/charts/kubeflow-trainer/values.yaml b/charts/kubeflow-trainer/values.yaml index e779d1f573..6d32e47214 100644 --- a/charts/kubeflow-trainer/values.yaml +++ b/charts/kubeflow-trainer/values.yaml @@ -138,6 +138,13 @@ manager: # webhookServiceName and webhookSecretName are auto-generated if not specified webhookServiceName: "" webhookSecretName: "" + statusServer: + # -- Port that the TrainJob status server serves on. 
+ port: 10443 + # -- QPS rate limit for the TrainJob Status Server api client + qps: 5 + # -- Burst rate limit for the TrainJob Status Server api client + burst: 10 featureGates: {} webhook: diff --git a/cmd/trainer-controller-manager/main.go b/cmd/trainer-controller-manager/main.go index 686de0d60b..6fee6f5b6f 100644 --- a/cmd/trainer-controller-manager/main.go +++ b/cmd/trainer-controller-manager/main.go @@ -40,8 +40,10 @@ import ( trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/config" "github.com/kubeflow/trainer/v2/pkg/controller" + "github.com/kubeflow/trainer/v2/pkg/features" "github.com/kubeflow/trainer/v2/pkg/runtime" runtimecore "github.com/kubeflow/trainer/v2/pkg/runtime/core" + "github.com/kubeflow/trainer/v2/pkg/statusserver" "github.com/kubeflow/trainer/v2/pkg/util/cert" "github.com/kubeflow/trainer/v2/pkg/webhooks" ) @@ -67,6 +69,7 @@ func init() { func main() { var configFile string var enableHTTP2 bool + var featureGates string flag.StringVar(&configFile, "config", "", "The controller will load its initial configuration from this file. "+ @@ -80,6 +83,9 @@ func main() { // - https://github.com/advisories/GHSA-4374-p667-p6c8 flag.BoolVar(&enableHTTP2, "enable-http2", false, "If set, HTTP/2 will be enabled for the metrics and webhook servers") + flag.StringVar(&featureGates, "feature-gates", "", + "A comma-separated list of key=value pairs that describe feature gates. 
"+ + "Command-line feature gates override those specified in the config file.") zapOpts := zap.Options{ TimeEncoder: zapcore.RFC3339NanoTimeEncoder, @@ -97,11 +103,20 @@ func main() { os.Exit(1) } + // Set feature gates from config file first if err := utilfeature.DefaultMutableFeatureGate.SetFromMap(cfg.FeatureGates); err != nil { - setupLog.Error(err, "Unable to set flag gates for known features") + setupLog.Error(err, "Unable to set feature gates from config file") os.Exit(1) } + // Command-line feature gates override config file settings + if featureGates != "" { + if err := utilfeature.DefaultMutableFeatureGate.Set(featureGates); err != nil { + setupLog.Error(err, "Unable to set feature gates from command line") + os.Exit(1) + } + } + setupLog.Info("Creating manager") mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), options) if err != nil { @@ -128,13 +143,13 @@ func main() { ctx := ctrl.SetupSignalHandler() setupProbeEndpoints(mgr, certsReady) - runtimes, err := runtimecore.New(ctx, mgr.GetClient(), mgr.GetFieldIndexer()) + runtimes, err := runtimecore.New(ctx, mgr.GetClient(), mgr.GetFieldIndexer(), &cfg) if err != nil { setupLog.Error(err, "Could not initialize runtimes") os.Exit(1) } - // Set up controllers using goroutines to start the manager quickly. - go setupControllers(mgr, runtimes, certsReady) + // Set up controllers and other components using goroutines to start the manager quickly. 
+ go setupManagerComponents(mgr, runtimes, &cfg, certsReady, enableHTTP2) setupLog.Info("Starting manager") if err = mgr.Start(ctx); err != nil { @@ -143,7 +158,7 @@ func main() { } } -func setupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime, certsReady <-chan struct{}) { +func setupManagerComponents(mgr ctrl.Manager, runtimes map[string]runtime.Runtime, cfg *configapi.Configuration, certsReady <-chan struct{}, enableHTTP2 bool) { setupLog.Info("Waiting for certificate generation to complete") <-certsReady setupLog.Info("Certs ready") @@ -156,6 +171,13 @@ func setupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime, cer setupLog.Error(err, "Could not create webhook", "webhook", failedWebhook) os.Exit(1) } + + if features.Enabled(features.TrainJobStatus) { + if err := statusserver.SetupServer(mgr, cfg.StatusServer, enableHTTP2); err != nil { + setupLog.Error(err, "Could not create runtime status server") + os.Exit(1) + } + } } func setupProbeEndpoints(mgr ctrl.Manager, certsReady <-chan struct{}) { diff --git a/docs/proposals/2779-trainjob-progress/README.md b/docs/proposals/2779-trainjob-progress/README.md index 31265059ec..aaaef4f4dc 100644 --- a/docs/proposals/2779-trainjob-progress/README.md +++ b/docs/proposals/2779-trainjob-progress/README.md @@ -63,7 +63,7 @@ We propose an approach with the following high-level **push-based** design: Users can choose not to instrument their runtime, in which case no progress and metrics will be available on the TrainJob. The feature is therefore optional and opt-in. -The feature will have an associated feature gate `TrainJobProgress`, defaulting to "disabled". Disabling the gate will disable the http service. +The feature will have an associated feature gate `TrainJobStatus`, defaulting to "disabled". Disabling the gate will disable the http service. 
### CRD changes @@ -322,9 +322,9 @@ Users will need to instrument their train jobs so that they periodically send tr To make it easier for training pods to update the training status, if the feature gate is enabled the control plane will inject the following environment variables into all containers of all pods of the training job: ```shell -KUBEFLOW_TRAINER_STATUS_URL=https://kubeflow-trainer-controller-manager.kubeflow:8082/apis/trainer.kubeflow.org/v1alpha1/namespaces/{namespace}/trainjobs/{name}/status -KUBEFLOW_TRAINER_STATUS_CA_CERT=/var/run/secrets/kubeflow/trainer/ca.crt -KUBEFLOW_TRAINER_STATUS_TOKEN=/var/run/secrets/kubeflow/trainer/token +KUBEFLOW_TRAINER_SERVER_URL=https://kubeflow-trainer-controller-manager.kubeflow:8082/apis/trainer.kubeflow.org/v1alpha1/namespaces/{namespace}/trainjobs/{name}/status +KUBEFLOW_TRAINER_SERVER_CA_CERT=/var/run/secrets/kubeflow/trainer/ca.crt +KUBEFLOW_TRAINER_SERVER_TOKEN=/var/run/secrets/kubeflow/trainer/token ``` These environment variables make it easy for code running in any pod to submit status updates, e.g.: @@ -335,9 +335,9 @@ import ssl def update_training_status(payload): try: - url = os.environ["KUBEFLOW_TRAINER_STATUS_URL"] - ca_file = os.environ["KUBEFLOW_TRAINER_STATUS_CA_CERT"] - token = open(os.environ["KUBEFLOW_TRAINER_STATUS_TOKEN"]).read() + url = os.environ["KUBEFLOW_TRAINER_SERVER_URL"] + ca_file = os.environ["KUBEFLOW_TRAINER_SERVER_CA_CERT"] + token = open(os.environ["KUBEFLOW_TRAINER_SERVER_TOKEN"]).read() ssl_context = ssl.create_default_context(cafile=ca_file) req = request.Request(url, data=payload, headers={"Authorization": f"Bearer {token}"}) request.urlopen(req, context=ssl_context) diff --git a/go.mod b/go.mod index 667d060118..3ec6767dbf 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kubeflow/trainer/v2 go 1.25.0 require ( + github.com/coreos/go-oidc/v3 v3.17.0 github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 github.com/onsi/ginkgo/v2
v2.28.1 @@ -40,6 +41,7 @@ require ( github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.1 // indirect github.com/go-openapi/jsonreference v0.21.0 // indirect diff --git a/go.sum b/go.sum index bb7e8210e3..00e07ff2f1 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -31,6 +33,8 @@ github.com/gkampitakis/go-diff v1.3.2 h1:Qyn0J9XJSDTgnsgHRdz9Zp24RaJeKMUHg2+PDZZ github.com/gkampitakis/go-diff v1.3.2/go.mod h1:LLgOrpqleQe26cte8s36HTWcTmMEur6OPYerdAAS9tk= github.com/gkampitakis/go-snaps v0.5.15 h1:amyJrvM1D33cPHwVrjo9jQxX8g/7E2wYdZ+01KS3zGE= github.com/gkampitakis/go-snaps v0.5.15/go.mod h1:HNpx/9GoKisdhw9AFOBT1N7DBs9DiHo/hGheFGBZ+mc= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod 
h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= diff --git a/hack/e2e-setup-cluster.sh b/hack/e2e-setup-cluster.sh index e17962a86f..f13c34bd24 100755 --- a/hack/e2e-setup-cluster.sh +++ b/hack/e2e-setup-cluster.sh @@ -63,6 +63,15 @@ cat <<EOF >"${E2E_MANIFESTS_DIR}/kustomization.yaml" images: - name: "${CONTROLLER_MANAGER_CI_IMAGE_NAME}" newTag: "${CI_IMAGE_TAG}" + patches: + - patch: |- + # enable feature flags + - op: add + path: /spec/template/spec/containers/0/args/- + value: --feature-gates=TrainJobStatus=true + target: + kind: Deployment + name: kubeflow-trainer-controller-manager EOF kubectl apply --server-side -k "${E2E_MANIFESTS_DIR}" diff --git a/hack/e2e-setup-gpu-cluster.sh b/hack/e2e-setup-gpu-cluster.sh index a569027550..d3b5adfb35 100755 --- a/hack/e2e-setup-gpu-cluster.sh +++ b/hack/e2e-setup-gpu-cluster.sh @@ -125,6 +125,15 @@ cat <<EOF >"${E2E_MANIFESTS_DIR}/kustomization.yaml" images: - name: "${CONTROLLER_MANAGER_CI_IMAGE_NAME}" newTag: "${CI_IMAGE_TAG}" + patches: + - patch: |- + # enable feature flags + - op: add + path: /spec/template/spec/containers/0/args/- + value: --feature-gates=TrainJobStatus=true + target: + kind: Deployment + name: kubeflow-trainer-controller-manager EOF kubectl apply --server-side -k "${E2E_MANIFESTS_DIR}" diff --git a/manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml b/manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml index 9197d96300..acac158aa8 100644 --- a/manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml +++ b/manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml @@ -5949,6 +5949,66 @@ spec: x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map + trainerStatus: + description: |- + trainerStatus contains the latest observed runtime status of the + Trainer step of the TrainJob. It reflects progress, remaining time, + metrics, and the last update timestamp.
+ + This field is nil if the TrainJob does not report trainer-level + status, or if no status has been observed yet (for example, + immediately after the TrainJob is created). + + This is an alpha feature and requires enabling the TrainJobStatus feature gate. + properties: + estimatedRemainingSeconds: + description: |- + estimatedRemainingSeconds gives the estimated remaining training time in seconds + before the train job is completed. + The value will be empty if it is unknown. + format: int32 + minimum: 0 + type: integer + lastUpdatedTime: + description: lastUpdatedTime is the timestamp when the runtime + status was observed. + format: date-time + type: string + metrics: + description: metrics contains the current metrics for the model. + items: + properties: + name: + description: name is a user-defined label for the metric, + e.g. "loss", "eval_accuracy". + maxLength: 64 + minLength: 1 + type: string + value: + description: value of the metric. Values must be serialized + as a string. + maxLength: 64 + minLength: 1 + type: string + required: + - name + - value + type: object + maxItems: 256 + type: array + x-kubernetes-list-type: atomic + progressPercentage: + description: |- + progressPercentage gives an estimate of how complete the TrainJob is as a percentage. + The value will be between 0 and 100, or empty if unknown. 
+ format: int32 + maximum: 100 + minimum: 0 + type: integer + type: object + x-kubernetes-validations: + - message: lastUpdatedTime is required when trainerStatus is present + rule: has(self.lastUpdatedTime) type: object type: object x-kubernetes-validations: diff --git a/manifests/base/manager/controller_manager_config.yaml b/manifests/base/manager/controller_manager_config.yaml index 7028a844bf..397b2a6106 100644 --- a/manifests/base/manager/controller_manager_config.yaml +++ b/manifests/base/manager/controller_manager_config.yaml @@ -42,3 +42,8 @@ certManagement: clientConnection: qps: 50 burst: 100 + +statusServer: + port: 10443 + qps: 5 + burst: 10 diff --git a/manifests/base/manager/manager.yaml b/manifests/base/manager/manager.yaml index 8f5dc3056c..209ee07c09 100644 --- a/manifests/base/manager/manager.yaml +++ b/manifests/base/manager/manager.yaml @@ -43,6 +43,9 @@ spec: - name: webhook containerPort: 9443 protocol: TCP + - name: status-server + containerPort: 10443 + protocol: TCP volumeMounts: - mountPath: /tmp/k8s-webhook-server/serving-certs @@ -85,5 +88,9 @@ spec: port: 443 protocol: TCP targetPort: webhook + - name: status-server + port: 10443 + protocol: TCP + targetPort: status-server selector: app.kubernetes.io/component: manager diff --git a/pkg/apis/config/v1alpha1/configuration_types.go b/pkg/apis/config/v1alpha1/configuration_types.go index e178590435..9be1961be3 100644 --- a/pkg/apis/config/v1alpha1/configuration_types.go +++ b/pkg/apis/config/v1alpha1/configuration_types.go @@ -60,6 +60,10 @@ type Configuration struct { // +optional ClientConnection *ClientConnection `json:"clientConnection,omitempty"` + // statusServer provides configuration options for the Runtime Status Server. + // +optional + StatusServer *StatusServer `json:"statusServer,omitempty"` + // featureGates is a map of feature names to bools that allows to override the // default enablement status of a feature. 
// +optional @@ -196,3 +200,25 @@ type ClientConnection struct { // +kubebuilder:default=100 Burst *int32 `json:"burst,omitempty"` } + +type StatusServer struct { + // port is the port that the status server serves at. + // Defaults to 10443. + // +optional + // +kubebuilder:default=10443 + Port *int32 `json:"port,omitempty"` + + // qps controls the number of queries per second allowed for the status server's + // Kubernetes client before client-side throttling. + // Defaults to 5. + // +optional + // +kubebuilder:default=5 + QPS *float32 `json:"qps,omitempty"` + + // burst allows extra queries to accumulate when the status server client is not + // using its full QPS allocation. + // Defaults to 10. + // +optional + // +kubebuilder:default=10 + Burst *int32 `json:"burst,omitempty"` +} diff --git a/pkg/apis/config/v1alpha1/defaults.go b/pkg/apis/config/v1alpha1/defaults.go index 5437fe6a4b..19a3a3beee 100644 --- a/pkg/apis/config/v1alpha1/defaults.go +++ b/pkg/apis/config/v1alpha1/defaults.go @@ -61,4 +61,16 @@ func SetDefaults_Configuration(cfg *Configuration) { if cfg.ClientConnection.Burst == nil { cfg.ClientConnection.Burst = ptr.To[int32](100) } + if cfg.StatusServer == nil { + cfg.StatusServer = &StatusServer{} + } + if cfg.StatusServer.Port == nil { + cfg.StatusServer.Port = ptr.To[int32](10443) + } + if cfg.StatusServer.QPS == nil { + cfg.StatusServer.QPS = ptr.To[float32](5) + } + if cfg.StatusServer.Burst == nil { + cfg.StatusServer.Burst = ptr.To[int32](10) + } } diff --git a/pkg/apis/config/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/config/v1alpha1/zz_generated.deepcopy.go index d64aa00baf..7e3493bd42 100644 --- a/pkg/apis/config/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/config/v1alpha1/zz_generated.deepcopy.go @@ -95,6 +95,11 @@ func (in *Configuration) DeepCopyInto(out *Configuration) { *out = new(ClientConnection) (*in).DeepCopyInto(*out) } + if in.StatusServer != nil { + in, out := &in.StatusServer, &out.StatusServer + *out = 
new(StatusServer) + (*in).DeepCopyInto(*out) + } if in.FeatureGates != nil { in, out := &in.FeatureGates, &out.FeatureGates *out = make(map[string]bool, len(*in)) @@ -203,3 +208,33 @@ func (in *ControllerWebhook) DeepCopy() *ControllerWebhook { in.DeepCopyInto(out) return out } + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *StatusServer) DeepCopyInto(out *StatusServer) { + *out = *in + if in.Port != nil { + in, out := &in.Port, &out.Port + *out = new(int32) + **out = **in + } + if in.QPS != nil { + in, out := &in.QPS, &out.QPS + *out = new(float32) + **out = **in + } + if in.Burst != nil { + in, out := &in.Burst, &out.Burst + *out = new(int32) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StatusServer. +func (in *StatusServer) DeepCopy() *StatusServer { + if in == nil { + return nil + } + out := new(StatusServer) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/apis/trainer/v1alpha1/trainjob_types.go b/pkg/apis/trainer/v1alpha1/trainjob_types.go index bc73fe4a99..86e378c7e3 100644 --- a/pkg/apis/trainer/v1alpha1/trainjob_types.go +++ b/pkg/apis/trainer/v1alpha1/trainjob_types.go @@ -496,6 +496,18 @@ type TrainJobStatus struct { // +kubebuilder:validation:MaxItems=100 // +optional JobsStatus []JobStatus `json:"jobsStatus,omitempty"` + + // trainerStatus contains the latest observed runtime status of the + // Trainer step of the TrainJob. It reflects progress, remaining time, + // metrics, and the last update timestamp. + // + // This field is nil if the TrainJob does not report trainer-level + // status, or if no status has been observed yet (for example, + // immediately after the TrainJob is created). + // + // This is an alpha feature and requires enabling the TrainJobStatus feature gate. 
+ // +optional + TrainerStatus *TrainerStatus `json:"trainerStatus,omitempty"` } type JobStatus struct { @@ -533,6 +545,68 @@ type JobStatus struct { Suspended *int32 `json:"suspended,omitempty"` } +// TrainerStatus represents the latest known runtime status of the Trainer step of the TrainJob. +// +kubebuilder:validation:XValidation:rule="has(self.lastUpdatedTime)",message="lastUpdatedTime is required when trainerStatus is present" +type TrainerStatus struct { + + // progressPercentage gives an estimate of how complete the TrainJob is as a percentage. + // The value will be between 0 and 100, or empty if unknown. + // + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=100 + // +optional + ProgressPercentage *int32 `json:"progressPercentage,omitempty"` + + // estimatedRemainingSeconds gives the estimated remaining training time in seconds + // before the train job is completed. + // The value will be empty if it is unknown. + // + // +kubebuilder:validation:Minimum=0 + // +optional + EstimatedRemainingSeconds *int32 `json:"estimatedRemainingSeconds,omitempty"` + + // metrics contains the current metrics for the model. + // + // +kubebuilder:validation:MaxItems=256 + // +listType=atomic + // +optional + Metrics []Metric `json:"metrics,omitempty"` + + // lastUpdatedTime is the timestamp when the runtime status was observed. + // +optional + LastUpdatedTime metav1.Time `json:"lastUpdatedTime,omitempty"` +} + +type Metric struct { + // name is a user-defined label for the metric, e.g. "loss", "eval_accuracy". + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=64 + // +required + Name string `json:"name,omitempty"` + + // value of the metric. Values must be serialized as a string. + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=64 + // +required + Value string `json:"value,omitempty"` +} + +// UpdateTrainJobStatusRequest contains the current runtime status (e.g. 
progress and metrics) for the different stages of the +// TrainJob. +type UpdateTrainJobStatusRequest struct { + // trainerStatus contains the latest observed runtime status of the + // Trainer step of the TrainJob. It reflects progress, remaining time, + // metrics, and the last update timestamp. + // + // This field is nil if the TrainJob does not report trainer-level + // status, or if no status has been observed yet (for example, + // immediately after the TrainJob is created). + // + // This is an alpha feature and requires enabling the TrainJobStatus feature gate. + // +optional + TrainerStatus *TrainerStatus `json:"trainerStatus,omitempty"` +} + func init() { SchemeBuilder.Register(&TrainJob{}, &TrainJobList{}) } diff --git a/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go index 5e9ca23912..d6fe1376a7 100644 --- a/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go @@ -492,6 +492,22 @@ func (in *MPIMLPolicySource) DeepCopy() *MPIMLPolicySource { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Metric) DeepCopyInto(out *Metric) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Metric. +func (in *Metric) DeepCopy() *Metric { + if in == nil { + return nil + } + out := new(Metric) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *ModelInitializer) DeepCopyInto(out *ModelInitializer) { *out = *in @@ -880,6 +896,11 @@ func (in *TrainJobStatus) DeepCopyInto(out *TrainJobStatus) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.TrainerStatus != nil { + in, out := &in.TrainerStatus, &out.TrainerStatus + *out = new(TrainerStatus) + (*in).DeepCopyInto(*out) + } return } @@ -946,6 +967,38 @@ func (in *Trainer) DeepCopy() *Trainer { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *TrainerStatus) DeepCopyInto(out *TrainerStatus) { + *out = *in + if in.ProgressPercentage != nil { + in, out := &in.ProgressPercentage, &out.ProgressPercentage + *out = new(int32) + **out = **in + } + if in.EstimatedRemainingSeconds != nil { + in, out := &in.EstimatedRemainingSeconds, &out.EstimatedRemainingSeconds + *out = new(int32) + **out = **in + } + if in.Metrics != nil { + in, out := &in.Metrics, &out.Metrics + *out = make([]Metric, len(*in)) + copy(*out, *in) + } + in.LastUpdatedTime.DeepCopyInto(&out.LastUpdatedTime) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TrainerStatus. +func (in *TrainerStatus) DeepCopy() *TrainerStatus { + if in == nil { + return nil + } + out := new(TrainerStatus) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *TrainingRuntime) DeepCopyInto(out *TrainingRuntime) { *out = *in @@ -1054,6 +1107,27 @@ func (in *TrainingRuntimeSpecPatch) DeepCopy() *TrainingRuntimeSpecPatch { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *UpdateTrainJobStatusRequest) DeepCopyInto(out *UpdateTrainJobStatusRequest) { + *out = *in + if in.TrainerStatus != nil { + in, out := &in.TrainerStatus, &out.TrainerStatus + *out = new(TrainerStatus) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new UpdateTrainJobStatusRequest. +func (in *UpdateTrainJobStatusRequest) DeepCopy() *UpdateTrainJobStatusRequest { + if in == nil { + return nil + } + out := new(UpdateTrainJobStatusRequest) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *VolcanoPodGroupPolicySource) DeepCopyInto(out *VolcanoPodGroupPolicySource) { *out = *in diff --git a/pkg/apis/trainer/v1alpha1/zz_generated.openapi.go b/pkg/apis/trainer/v1alpha1/zz_generated.openapi.go index 5dd559f35f..45818f2c24 100644 --- a/pkg/apis/trainer/v1alpha1/zz_generated.openapi.go +++ b/pkg/apis/trainer/v1alpha1/zz_generated.openapi.go @@ -50,6 +50,7 @@ func GetOpenAPIDefinitions(ref common.ReferenceCallback) map[string]common.OpenA "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.MLPolicy": schema_pkg_apis_trainer_v1alpha1_MLPolicy(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.MLPolicySource": schema_pkg_apis_trainer_v1alpha1_MLPolicySource(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.MPIMLPolicySource": schema_pkg_apis_trainer_v1alpha1_MPIMLPolicySource(ref), + "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.Metric": schema_pkg_apis_trainer_v1alpha1_Metric(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.ModelInitializer": schema_pkg_apis_trainer_v1alpha1_ModelInitializer(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.PodGroupPolicy": schema_pkg_apis_trainer_v1alpha1_PodGroupPolicy(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.PodGroupPolicySource": 
schema_pkg_apis_trainer_v1alpha1_PodGroupPolicySource(ref), @@ -64,10 +65,12 @@ func GetOpenAPIDefinitions(ref common.ReferenceCallback) map[string]common.OpenA "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainJobSpec": schema_pkg_apis_trainer_v1alpha1_TrainJobSpec(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainJobStatus": schema_pkg_apis_trainer_v1alpha1_TrainJobStatus(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.Trainer": schema_pkg_apis_trainer_v1alpha1_Trainer(ref), + "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainerStatus": schema_pkg_apis_trainer_v1alpha1_TrainerStatus(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainingRuntime": schema_pkg_apis_trainer_v1alpha1_TrainingRuntime(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainingRuntimeList": schema_pkg_apis_trainer_v1alpha1_TrainingRuntimeList(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainingRuntimeSpec": schema_pkg_apis_trainer_v1alpha1_TrainingRuntimeSpec(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainingRuntimeSpecPatch": schema_pkg_apis_trainer_v1alpha1_TrainingRuntimeSpecPatch(ref), + "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.UpdateTrainJobStatusRequest": schema_pkg_apis_trainer_v1alpha1_UpdateTrainJobStatusRequest(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.VolcanoPodGroupPolicySource": schema_pkg_apis_trainer_v1alpha1_VolcanoPodGroupPolicySource(ref), "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.XGBoostMLPolicySource": schema_pkg_apis_trainer_v1alpha1_XGBoostMLPolicySource(ref), v2.ContainerResourceMetricSource{}.OpenAPIModelName(): schema_k8sio_api_autoscaling_v2_ContainerResourceMetricSource(ref), @@ -1065,6 +1068,33 @@ func schema_pkg_apis_trainer_v1alpha1_MPIMLPolicySource(ref common.ReferenceCall } } +func schema_pkg_apis_trainer_v1alpha1_Metric(ref common.ReferenceCallback) 
common.OpenAPIDefinition { + return common.OpenAPIDefinition{ + Schema: spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"object"}, + Properties: map[string]spec.Schema{ + "name": { + SchemaProps: spec.SchemaProps{ + Description: "name is a user-defined label for the metric, e.g. \"loss\", \"eval_accuracy\".", + Type: []string{"string"}, + Format: "", + }, + }, + "value": { + SchemaProps: spec.SchemaProps{ + Description: "value of the metric. Values must be serialized as a string.", + Type: []string{"string"}, + Format: "", + }, + }, + }, + Required: []string{"name", "value"}, + }, + }, + } +} + func schema_pkg_apis_trainer_v1alpha1_ModelInitializer(ref common.ReferenceCallback) common.OpenAPIDefinition { return common.OpenAPIDefinition{ Schema: spec.Schema{ @@ -1715,11 +1745,17 @@ func schema_pkg_apis_trainer_v1alpha1_TrainJobStatus(ref common.ReferenceCallbac }, }, }, + "trainerStatus": { + SchemaProps: spec.SchemaProps{ + Description: "trainerStatus contains the latest observed runtime status of the Trainer step of the TrainJob. 
It reflects progress, remaining time, metrics, and the last update timestamp.\n\nThis field is nil if the TrainJob does not report trainer-level status, or if no status has been observed yet (for example, immediately after the TrainJob is created).\n\nThis is an alpha feature and requires enabling the TrainJobStatus feature gate.", + Ref: ref("github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainerStatus"), + }, + }, }, }, }, Dependencies: []string{ - "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.JobStatus", metav1.Condition{}.OpenAPIModelName()}, + "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.JobStatus", "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainerStatus", metav1.Condition{}.OpenAPIModelName()}, } } @@ -1827,6 +1863,60 @@ func schema_pkg_apis_trainer_v1alpha1_Trainer(ref common.ReferenceCallback) comm } } +func schema_pkg_apis_trainer_v1alpha1_TrainerStatus(ref common.ReferenceCallback) common.OpenAPIDefinition { + return common.OpenAPIDefinition{ + Schema: spec.Schema{ + SchemaProps: spec.SchemaProps{ + Description: "TrainerStatus represents the latest known runtime status of the Trainer step of the TrainJob.", + Type: []string{"object"}, + Properties: map[string]spec.Schema{ + "progressPercentage": { + SchemaProps: spec.SchemaProps{ + Description: "progressPercentage gives an estimate of how complete the TrainJob is as a percentage. The value will be between 0 and 100, or empty if unknown.", + Type: []string{"integer"}, + Format: "int32", + }, + }, + "estimatedRemainingSeconds": { + SchemaProps: spec.SchemaProps{ + Description: "estimatedRemainingSeconds gives the estimated remaining training time in seconds before the train job is completed. 
The value will be empty if it is unknown.", + Type: []string{"integer"}, + Format: "int32", + }, + }, + "metrics": { + VendorExtensible: spec.VendorExtensible{ + Extensions: spec.Extensions{ + "x-kubernetes-list-type": "atomic", + }, + }, + SchemaProps: spec.SchemaProps{ + Description: "metrics contains the current metrics for the model.", + Type: []string{"array"}, + Items: &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Default: map[string]interface{}{}, + Ref: ref("github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.Metric"), + }, + }, + }, + }, + }, + "lastUpdatedTime": { + SchemaProps: spec.SchemaProps{ + Description: "lastUpdatedTime is the timestamp when the runtime status was observed.", + Ref: ref(metav1.Time{}.OpenAPIModelName()), + }, + }, + }, + }, + }, + Dependencies: []string{ + "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.Metric", metav1.Time{}.OpenAPIModelName()}, + } +} + func schema_pkg_apis_trainer_v1alpha1_TrainingRuntime(ref common.ReferenceCallback) common.OpenAPIDefinition { return common.OpenAPIDefinition{ Schema: spec.Schema{ @@ -1976,6 +2066,27 @@ func schema_pkg_apis_trainer_v1alpha1_TrainingRuntimeSpecPatch(ref common.Refere } } +func schema_pkg_apis_trainer_v1alpha1_UpdateTrainJobStatusRequest(ref common.ReferenceCallback) common.OpenAPIDefinition { + return common.OpenAPIDefinition{ + Schema: spec.Schema{ + SchemaProps: spec.SchemaProps{ + Description: "UpdateTrainJobStatusRequest contains the current runtime status (e.g. progress and metrics) for the different stages of the TrainJob.", + Type: []string{"object"}, + Properties: map[string]spec.Schema{ + "trainerStatus": { + SchemaProps: spec.SchemaProps{ + Description: "trainerStatus contains the latest observed runtime status of the Trainer step of the TrainJob. 
It reflects progress, remaining time, metrics, and the last update timestamp.\n\nThis field is nil if the TrainJob does not report trainer-level status, or if no status has been observed yet (for example, immediately after the TrainJob is created).\n\nThis is an alpha feature and requires enabling the TrainJobStatus feature gate.", + Ref: ref("github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainerStatus"), + }, + }, + }, + }, + }, + Dependencies: []string{ + "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1.TrainerStatus"}, + } +} + func schema_pkg_apis_trainer_v1alpha1_VolcanoPodGroupPolicySource(ref common.ReferenceCallback) common.OpenAPIDefinition { return common.OpenAPIDefinition{ Schema: spec.Schema{ diff --git a/pkg/client/applyconfiguration/trainer/v1alpha1/metric.go b/pkg/client/applyconfiguration/trainer/v1alpha1/metric.go new file mode 100644 index 0000000000..330e56ebe1 --- /dev/null +++ b/pkg/client/applyconfiguration/trainer/v1alpha1/metric.go @@ -0,0 +1,48 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by applyconfiguration-gen. DO NOT EDIT. + +package v1alpha1 + +// MetricApplyConfiguration represents a declarative configuration of the Metric type for use +// with apply. +type MetricApplyConfiguration struct { + // name is a user-defined label for the metric, e.g. "loss", "eval_accuracy". + Name *string `json:"name,omitempty"` + // value of the metric. 
Values must be serialized as a string. + Value *string `json:"value,omitempty"` +} + +// MetricApplyConfiguration constructs a declarative configuration of the Metric type for use with +// apply. +func Metric() *MetricApplyConfiguration { + return &MetricApplyConfiguration{} +} + +// WithName sets the Name field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the Name field is set to the value of the last call. +func (b *MetricApplyConfiguration) WithName(value string) *MetricApplyConfiguration { + b.Name = &value + return b +} + +// WithValue sets the Value field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the Value field is set to the value of the last call. +func (b *MetricApplyConfiguration) WithValue(value string) *MetricApplyConfiguration { + b.Value = &value + return b +} diff --git a/pkg/client/applyconfiguration/trainer/v1alpha1/trainerstatus.go b/pkg/client/applyconfiguration/trainer/v1alpha1/trainerstatus.go new file mode 100644 index 0000000000..9a0327015b --- /dev/null +++ b/pkg/client/applyconfiguration/trainer/v1alpha1/trainerstatus.go @@ -0,0 +1,82 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by applyconfiguration-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// TrainerStatusApplyConfiguration represents a declarative configuration of the TrainerStatus type for use +// with apply. +// +// TrainerStatus represents the latest known runtime status of the Trainer step of the TrainJob. +type TrainerStatusApplyConfiguration struct { + // progressPercentage gives an estimate of how complete the TrainJob is as a percentage. + // The value will be between 0 and 100, or empty if unknown. + ProgressPercentage *int32 `json:"progressPercentage,omitempty"` + // estimatedRemainingSeconds gives the estimated remaining training time in seconds + // before the train job is completed. + // The value will be empty if it is unknown. + EstimatedRemainingSeconds *int32 `json:"estimatedRemainingSeconds,omitempty"` + // metrics contains the current metrics for the model. + Metrics []MetricApplyConfiguration `json:"metrics,omitempty"` + // lastUpdatedTime is the timestamp when the runtime status was observed. + LastUpdatedTime *v1.Time `json:"lastUpdatedTime,omitempty"` +} + +// TrainerStatusApplyConfiguration constructs a declarative configuration of the TrainerStatus type for use with +// apply. +func TrainerStatus() *TrainerStatusApplyConfiguration { + return &TrainerStatusApplyConfiguration{} +} + +// WithProgressPercentage sets the ProgressPercentage field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the ProgressPercentage field is set to the value of the last call. 
+func (b *TrainerStatusApplyConfiguration) WithProgressPercentage(value int32) *TrainerStatusApplyConfiguration { + b.ProgressPercentage = &value + return b +} + +// WithEstimatedRemainingSeconds sets the EstimatedRemainingSeconds field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the EstimatedRemainingSeconds field is set to the value of the last call. +func (b *TrainerStatusApplyConfiguration) WithEstimatedRemainingSeconds(value int32) *TrainerStatusApplyConfiguration { + b.EstimatedRemainingSeconds = &value + return b +} + +// WithMetrics adds the given value to the Metrics field in the declarative configuration +// and returns the receiver, so that objects can be build by chaining "With" function invocations. +// If called multiple times, values provided by each call will be appended to the Metrics field. +func (b *TrainerStatusApplyConfiguration) WithMetrics(values ...*MetricApplyConfiguration) *TrainerStatusApplyConfiguration { + for i := range values { + if values[i] == nil { + panic("nil value passed to WithMetrics") + } + b.Metrics = append(b.Metrics, *values[i]) + } + return b +} + +// WithLastUpdatedTime sets the LastUpdatedTime field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the LastUpdatedTime field is set to the value of the last call. 
+func (b *TrainerStatusApplyConfiguration) WithLastUpdatedTime(value v1.Time) *TrainerStatusApplyConfiguration { + b.LastUpdatedTime = &value + return b +} diff --git a/pkg/client/applyconfiguration/trainer/v1alpha1/trainjobstatus.go b/pkg/client/applyconfiguration/trainer/v1alpha1/trainjobstatus.go index 4a1b92e775..734a9a91bf 100644 --- a/pkg/client/applyconfiguration/trainer/v1alpha1/trainjobstatus.go +++ b/pkg/client/applyconfiguration/trainer/v1alpha1/trainjobstatus.go @@ -29,6 +29,16 @@ type TrainJobStatusApplyConfiguration struct { Conditions []v1.ConditionApplyConfiguration `json:"conditions,omitempty"` // jobsStatus tracks the child Jobs in TrainJob. JobsStatus []JobStatusApplyConfiguration `json:"jobsStatus,omitempty"` + // trainerStatus contains the latest observed runtime status of the + // Trainer step of the TrainJob. It reflects progress, remaining time, + // metrics, and the last update timestamp. + // + // This field is nil if the TrainJob does not report trainer-level + // status, or if no status has been observed yet (for example, + // immediately after the TrainJob is created). + // + // This is an alpha feature and requires enabling the TrainJobStatus feature gate. + TrainerStatus *TrainerStatusApplyConfiguration `json:"trainerStatus,omitempty"` } // TrainJobStatusApplyConfiguration constructs a declarative configuration of the TrainJobStatus type for use with @@ -62,3 +72,11 @@ func (b *TrainJobStatusApplyConfiguration) WithJobsStatus(values ...*JobStatusAp } return b } + +// WithTrainerStatus sets the TrainerStatus field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the TrainerStatus field is set to the value of the last call. 
+func (b *TrainJobStatusApplyConfiguration) WithTrainerStatus(value *TrainerStatusApplyConfiguration) *TrainJobStatusApplyConfiguration { + b.TrainerStatus = value + return b +} diff --git a/pkg/client/applyconfiguration/utils.go b/pkg/client/applyconfiguration/utils.go index 8b474d3639..3b80a9978d 100644 --- a/pkg/client/applyconfiguration/utils.go +++ b/pkg/client/applyconfiguration/utils.go @@ -54,6 +54,8 @@ func ForKind(kind schema.GroupVersionKind) interface{} { return &trainerv1alpha1.JobStatusApplyConfiguration{} case v1alpha1.SchemeGroupVersion.WithKind("JobTemplatePatch"): return &trainerv1alpha1.JobTemplatePatchApplyConfiguration{} + case v1alpha1.SchemeGroupVersion.WithKind("Metric"): + return &trainerv1alpha1.MetricApplyConfiguration{} case v1alpha1.SchemeGroupVersion.WithKind("MLPolicy"): return &trainerv1alpha1.MLPolicyApplyConfiguration{} case v1alpha1.SchemeGroupVersion.WithKind("MLPolicySource"): @@ -78,6 +80,8 @@ func ForKind(kind schema.GroupVersionKind) interface{} { return &trainerv1alpha1.RuntimeRefApplyConfiguration{} case v1alpha1.SchemeGroupVersion.WithKind("Trainer"): return &trainerv1alpha1.TrainerApplyConfiguration{} + case v1alpha1.SchemeGroupVersion.WithKind("TrainerStatus"): + return &trainerv1alpha1.TrainerStatusApplyConfiguration{} case v1alpha1.SchemeGroupVersion.WithKind("TrainingRuntime"): return &trainerv1alpha1.TrainingRuntimeApplyConfiguration{} case v1alpha1.SchemeGroupVersion.WithKind("TrainingRuntimeSpec"): diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index edc5f61359..b043230ae2 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -155,6 +155,18 @@ health: t.Fatal(err) } + statusServerConfig := filepath.Join(tmpDir, "statusServer.yaml") + if err := os.WriteFile(statusServerConfig, []byte(` +apiVersion: config.trainer.kubeflow.org/v1alpha1 +kind: Configuration +statusServer: + port: 12443 + qps: 1 + burst: 2 +`), os.FileMode(0600)); err != nil { + t.Fatal(err) + } + completeConfig 
:= filepath.Join(tmpDir, "complete.yaml") if err := os.WriteFile(completeConfig, []byte(` apiVersion: config.trainer.kubeflow.org/v1alpha1 @@ -188,6 +200,10 @@ certManagement: clientConnection: qps: 50 burst: 100 +statusServer: + port: 12443 + qps: 1 + burst: 2 `), os.FileMode(0600)); err != nil { t.Fatal(err) } @@ -264,6 +280,12 @@ this is not: valid: yaml: content Burst: ptr.To[int32](100), } + defaultStatusServer := &configapi.StatusServer{ + Port: ptr.To[int32](10443), + QPS: ptr.To[float32](5), + Burst: ptr.To[int32](10), + } + defaultWebhook := configapi.ControllerWebhook{ Port: ptr.To[int32](9443), } @@ -317,6 +339,7 @@ this is not: valid: yaml: content Health: defaultHealth, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, }, wantOptions: defaultOptions, }, @@ -330,6 +353,7 @@ this is not: valid: yaml: content Health: defaultHealth, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, }, wantOptions: defaultOptions, }, @@ -370,6 +394,7 @@ this is not: valid: yaml: content QPS: ptr.To[float32](100), Burst: ptr.To[int32](200), }, + StatusServer: defaultStatusServer, }, wantOptions: ctrl.Options{ HealthProbeBindAddress: ":8082", @@ -394,6 +419,7 @@ this is not: valid: yaml: content Health: defaultHealth, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, LeaderElection: &componentconfigv1alpha1.LeaderElectionConfiguration{ LeaderElect: ptr.To(true), ResourceName: "trainer-leader", @@ -434,6 +460,7 @@ this is not: valid: yaml: content Health: defaultHealth, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, Controller: &configapi.ControllerConfigurationSpec{ GroupKindConcurrency: map[string]int32{ "TrainJob.trainer.kubeflow.org": 10, @@ -471,6 +498,7 @@ this is not: valid: yaml: content 
Metrics: defaultMetrics, Health: defaultHealth, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, CertManagement: &configapi.CertManagement{ Enable: ptr.To(true), WebhookServiceName: "custom-webhook-service", @@ -488,6 +516,7 @@ this is not: valid: yaml: content Metrics: defaultMetrics, Health: defaultHealth, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, CertManagement: &configapi.CertManagement{ Enable: ptr.To(false), WebhookServiceName: "kubeflow-trainer-controller-manager", @@ -509,6 +538,7 @@ this is not: valid: yaml: content Health: defaultHealth, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, }, wantOptions: ctrl.Options{ HealthProbeBindAddress: ":8081", @@ -536,6 +566,7 @@ this is not: valid: yaml: content Health: defaultHealth, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, }, wantOptions: ctrl.Options{ HealthProbeBindAddress: ":8081", @@ -551,6 +582,35 @@ this is not: valid: yaml: content }, }, }, + { + name: "status server config", + configFile: statusServerConfig, + wantConfiguration: configapi.Configuration{ + TypeMeta: typeMeta, + Webhook: defaultWebhook, + Metrics: defaultMetrics, + Health: defaultHealth, + CertManagement: defaultCertManagement, + ClientConnection: defaultClientConnection, + StatusServer: &configapi.StatusServer{ + Port: ptr.To[int32](12443), + QPS: ptr.To[float32](1), + Burst: ptr.To[int32](2), + }, + }, + wantOptions: ctrl.Options{ + HealthProbeBindAddress: ":8081", + Metrics: metricsserver.Options{ + BindAddress: ":8443", + SecureServing: true, + }, + WebhookServer: &webhook.DefaultServer{ + Options: webhook.Options{ + Port: 9443, + }, + }, + }, + }, { name: "health config with custom endpoints", configFile: healthConfig, @@ -565,6 +625,7 @@ this is not: valid: yaml: content }, CertManagement: defaultCertManagement, 
ClientConnection: defaultClientConnection, + StatusServer: defaultStatusServer, }, wantOptions: ctrl.Options{ HealthProbeBindAddress: ":9090", @@ -607,6 +668,11 @@ this is not: valid: yaml: content }, CertManagement: defaultCertManagement, ClientConnection: defaultClientConnection, + StatusServer: &configapi.StatusServer{ + Port: ptr.To[int32](12443), + QPS: ptr.To[float32](1), + Burst: ptr.To[int32](2), + }, }, wantOptions: ctrl.Options{ HealthProbeBindAddress: ":8081", diff --git a/pkg/config/validation.go b/pkg/config/validation.go index b0e82fe610..13ee95d31d 100644 --- a/pkg/config/validation.go +++ b/pkg/config/validation.go @@ -50,5 +50,18 @@ func validate(cfg *configapi.Configuration) field.ErrorList { } } + // Validate status server config + if cfg.StatusServer != nil { + if cfg.StatusServer.Port != nil && (*cfg.StatusServer.Port < 1 || *cfg.StatusServer.Port > 65535) { + allErrs = append(allErrs, field.Invalid(field.NewPath("statusServer", "port"), *cfg.StatusServer.Port, "must be between 1 and 65535")) + } + if cfg.StatusServer.QPS != nil && *cfg.StatusServer.QPS < 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("statusServer", "qps"), *cfg.StatusServer.QPS, "must be greater than or equal to 0")) + } + if cfg.StatusServer.Burst != nil && *cfg.StatusServer.Burst < 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("statusServer", "burst"), *cfg.StatusServer.Burst, "must be greater than or equal to 0")) + } + } + return allErrs } diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index ed222e41a0..4c5e3529a5 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -141,7 +141,7 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c } if deadlineResult, deadlineErr := r.reconcileDeadline(ctx, &trainJob); deadlineErr != nil || deadlineResult.RequeueAfter > 0 { - if !equality.Semantic.DeepEqual(&trainJob.Status, 
prevTrainJob.Status) { + if !equality.Semantic.DeepEqual(&trainJob.Status, &prevTrainJob.Status) { return deadlineResult, errors.Join(err, r.client.Status().Patch(ctx, &trainJob, client.MergeFrom(prevTrainJob))) } return deadlineResult, errors.Join(err, deadlineErr) diff --git a/pkg/features/features.go b/pkg/features/features.go index 2db2e560d6..adba63fc03 100644 --- a/pkg/features/features.go +++ b/pkg/features/features.go @@ -30,13 +30,23 @@ func init() { runtime.Must(utilfeature.DefaultMutableFeatureGate.Add(defaultFeatureGates)) } +const ( + // owner: robert-bell + // kep: https://github.com/kubeflow/trainer/blob/main/docs/proposals/2779-trainjob-progress/README.md + // + // Enables status server allowing TrainJob pods to update their status. + TrainJobStatus featuregate.Feature = "TrainJobStatus" +) + // defaultFeatureGates consists of all known Trainer-specific feature keys. // To add a new feature, define a key for it above and add it here. The features will be // available throughout Trainer binaries. // // Entries are separated from each other with blank lines to avoid sweeping gofmt changes // when adding or removing one entry. 
-var defaultFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{} +var defaultFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{ + TrainJobStatus: {Default: false, PreRelease: featuregate.Alpha}, +} func SetFeatureGateDuringTest(tb testing.TB, f featuregate.Feature, value bool) { featuregatetesting.SetFeatureGateDuringTest(tb, utilfeature.DefaultFeatureGate, f, value) diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index d4a2ac8ab6..02548977b7 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -27,6 +27,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/constants" "github.com/kubeflow/trainer/v2/pkg/runtime" @@ -48,7 +49,7 @@ var ClusterTrainingRuntimeGroupKind = schema.GroupKind{ Kind: trainer.ClusterTrainingRuntimeKind, }.String() -func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndexer) (runtime.Runtime, error) { +func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndexer, *configapi.Configuration) (runtime.Runtime, error) { return &ClusterTrainingRuntime{ TrainingRuntime: trainingRuntimeFactory, }, nil diff --git a/pkg/runtime/core/clustertrainingruntime_test.go b/pkg/runtime/core/clustertrainingruntime_test.go index c20f3d3712..fac63a32c2 100644 --- a/pkg/runtime/core/clustertrainingruntime_test.go +++ b/pkg/runtime/core/clustertrainingruntime_test.go @@ -151,7 +151,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { } c := clientBuilder.Build() - trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder)) + trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder), nil) if err != 
nil { t.Fatal(err) } @@ -161,7 +161,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { t.Fatal("Failed type assertion from Runtime interface to TrainingRuntime") } - clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder)) + clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } diff --git a/pkg/runtime/core/core.go b/pkg/runtime/core/core.go index aad5c51e2e..9d46d96c0a 100644 --- a/pkg/runtime/core/core.go +++ b/pkg/runtime/core/core.go @@ -22,13 +22,14 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/runtime" ) // +kubebuilder:rbac:groups=trainer.kubeflow.org,resources=trainingruntimes,verbs=get;list;watch // +kubebuilder:rbac:groups=trainer.kubeflow.org,resources=clustertrainingruntimes,verbs=get;list;watch -func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) (map[string]runtime.Runtime, error) { +func New(ctx context.Context, client client.Client, indexer client.FieldIndexer, cfg *configapi.Configuration) (map[string]runtime.Runtime, error) { registry := NewRuntimeRegistry() runtimes := make(map[string]runtime.Runtime, len(registry)) for name, registrar := range registry { @@ -36,7 +37,7 @@ func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) depRegistrar, depExist := registry[dep] _, depRegistered := runtimes[dep] if depExist && !depRegistered { - r, err := depRegistrar.factory(ctx, client, indexer) + r, err := depRegistrar.factory(ctx, client, indexer, cfg) if err != nil { return nil, fmt.Errorf("initializing runtime %q on which %q depends: %w", dep, name, err) } @@ -44,7 +45,7 @@ func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) } } if _, ok := runtimes[name]; !ok { - r, err := registrar.factory(ctx, client, indexer) + r, err := 
registrar.factory(ctx, client, indexer, cfg) if err != nil { return nil, fmt.Errorf("initializing runtime %q: %w", name, err) } diff --git a/pkg/runtime/core/registry.go b/pkg/runtime/core/registry.go index 5831885180..de500a70cc 100644 --- a/pkg/runtime/core/registry.go +++ b/pkg/runtime/core/registry.go @@ -21,12 +21,13 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/runtime" ) type Registry map[string]RuntimeRegistrar type RuntimeRegistrar struct { - factory func(ctx context.Context, client client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) + factory func(ctx context.Context, client client.Client, indexer client.FieldIndexer, cfg *configapi.Configuration) (runtime.Runtime, error) dependencies []string } diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index 42cd88cc5c..807d9847e6 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -35,6 +35,7 @@ import ( jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -62,14 +63,14 @@ var _ runtime.Runtime = (*TrainingRuntime)(nil) var trainingRuntimeFactory *TrainingRuntime -func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) { +func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer, cfg *configapi.Configuration) (runtime.Runtime, error) { if err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil { return nil, fmt.Errorf("setting index 
on TrainingRuntime for TrainJob: %w", err) } if err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobClusterRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil { return nil, fmt.Errorf("setting index on ClusterTrainingRuntime for TrainJob: %w", err) } - fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer) + fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer, cfg) if err != nil { return nil, err } diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index 6848e17f8a..e933ec844f 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -2107,7 +2107,7 @@ test-job-node-0-1.test-job slots=8 } c := clientBuilder.Build() - trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder)) + trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index c878e54508..cd7394a8f2 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -25,6 +25,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/runtime" "github.com/kubeflow/trainer/v2/pkg/runtime/framework" @@ -46,7 +47,7 @@ type Framework struct { trainJobStatusPlugin framework.TrainJobStatusPlugin } -func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { +func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer, cfg *configapi.Configuration) (*Framework, error) { f := &Framework{ registry: r, } @@ -56,7 +57,7 @@ func 
New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl } for name, factory := range r { - plugin, err := factory(ctx, c, indexer) + plugin, err := factory(ctx, c, indexer, cfg) if err != nil { return nil, err } diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 6b4e82bb5e..b4c5a17fab 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -42,6 +42,7 @@ import ( schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -176,7 +177,7 @@ func TestNew(t *testing.T) { }) } clientBuilder := testingutil.NewClientBuilder() - fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder), nil) if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected errors (-want,+got):\n%s", diff) } @@ -336,7 +337,7 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) { t.Cleanup(cancel) clientBuilder := testingutil.NewClientBuilder() - fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } @@ -431,7 +432,7 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) { t.Cleanup(cancel) clientBuilder := testingutil.NewClientBuilder() - fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, 
testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } @@ -471,7 +472,7 @@ func TestRunCustomValidationPlugins(t *testing.T) { t.Cleanup(cancel) clientBuildr := testingutil.NewClientBuilder() - fwk, err := New(ctx, clientBuildr.Build(), tc.registry, testingutil.AsIndex(clientBuildr)) + fwk, err := New(ctx, clientBuildr.Build(), tc.registry, testingutil.AsIndex(clientBuildr), nil) if err != nil { t.Fatal(err) } @@ -2203,7 +2204,7 @@ test-job-node-0-1.test-job slots=1 clientBuilder := testingutil.NewClientBuilder() c := clientBuilder.Build() - fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } @@ -2292,7 +2293,7 @@ func TestWatchExtensionPlugins(t *testing.T) { t.Cleanup(cancel) clientBuilder := testingutil.NewClientBuilder() - fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } @@ -2308,7 +2309,7 @@ type fakeTrainJobStatusPlugin struct{} var _ framework.TrainJobStatusPlugin = (*fakeTrainJobStatusPlugin)(nil) -func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { +func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer, *configapi.Configuration) (framework.Plugin, error) { return &fakeTrainJobStatusPlugin{}, nil } @@ -2526,7 +2527,7 @@ func TestTrainJobStatusPlugins(t *testing.T) { } c := clientBuilder.Build() - fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder), nil) if err != nil { if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) @@ -2642,7 +2643,7 @@ func TestPodNetworkPlugins(t *testing.T) 
{ ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cliBuilder := testingutil.NewClientBuilder() - fwk, err := New(ctx, cliBuilder.Build(), tc.registry, testingutil.AsIndex(cliBuilder)) + fwk, err := New(ctx, cliBuilder.Build(), tc.registry, testingutil.AsIndex(cliBuilder), nil) if err != nil { t.Fatal(err) } diff --git a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go index ec725787d8..e91ea14525 100644 --- a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go @@ -42,6 +42,7 @@ import ( schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" schedulerpluginsv1alpha1ac "sigs.k8s.io/scheduler-plugins/pkg/generated/applyconfiguration/scheduling/v1alpha1" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/runtime" "github.com/kubeflow/trainer/v2/pkg/runtime/framework" @@ -65,7 +66,7 @@ const Name = "CoScheduling" // +kubebuilder:rbac:groups=node.k8s.io,resources=runtimeclasses,verbs=get;list;watch // +kubebuilder:rbac:groups="",resources=limitranges,verbs=get;list;watch -func New(_ context.Context, client client.Client, _ client.FieldIndexer) (framework.Plugin, error) { +func New(_ context.Context, client client.Client, _ client.FieldIndexer, _ *configapi.Configuration) (framework.Plugin, error) { return &CoScheduling{ client: client, restMapper: client.RESTMapper(), diff --git a/pkg/runtime/framework/plugins/coscheduling/coscheduling_test.go b/pkg/runtime/framework/plugins/coscheduling/coscheduling_test.go index 5302e7e551..f6fb659942 100644 --- a/pkg/runtime/framework/plugins/coscheduling/coscheduling_test.go +++ b/pkg/runtime/framework/plugins/coscheduling/coscheduling_test.go @@ -554,7 +554,7 @@ func TestCoScheduling(t *testing.T) { }, }) cli := clientBuilder.Build() - 
plugin, err := New(ctx, cli, utiltesting.AsIndex(clientBuilder)) + plugin, err := New(ctx, cli, utiltesting.AsIndex(clientBuilder), nil) if err != nil { t.Fatalf("Failed to create plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/flux/flux.go b/pkg/runtime/framework/plugins/flux/flux.go index fd2136e672..da52152dae 100644 --- a/pkg/runtime/framework/plugins/flux/flux.go +++ b/pkg/runtime/framework/plugins/flux/flux.go @@ -37,6 +37,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" jobsetapply "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -89,7 +90,7 @@ type Flux struct { scheme *apiruntime.Scheme } -func New(_ context.Context, client client.Client, _ client.FieldIndexer) (framework.Plugin, error) { +func New(_ context.Context, client client.Client, _ client.FieldIndexer, _ *configapi.Configuration) (framework.Plugin, error) { return &Flux{ client: client, scheme: client.Scheme(), diff --git a/pkg/runtime/framework/plugins/flux/flux_test.go b/pkg/runtime/framework/plugins/flux/flux_test.go index 376ba88322..02aec04d30 100644 --- a/pkg/runtime/framework/plugins/flux/flux_test.go +++ b/pkg/runtime/framework/plugins/flux/flux_test.go @@ -106,7 +106,7 @@ func TestFlux(t *testing.T) { t.Run(name, func(t *testing.T) { _, ctx := ktesting.NewTestContext(t) cli := utiltesting.NewClientBuilder().Build() - p, _ := New(ctx, cli, nil) + p, _ := New(ctx, cli, nil, nil) err := p.(framework.EnforceMLPolicyPlugin).EnforceMLPolicy(tc.info, tc.trainJob) if err != nil { @@ -176,7 +176,7 @@ func TestValidate(t *testing.T) { for name, tc := range cases { t.Run(name, func(t *testing.T) { _, ctx := ktesting.NewTestContext(t) - p, _ := New(ctx, utiltesting.NewClientBuilder().Build(), nil) + p, _ := New(ctx, 
utiltesting.NewClientBuilder().Build(), nil, nil) _, errs := p.(framework.CustomValidationPlugin).Validate(ctx, tc.info, nil, tc.newObj) if len(errs) > 0 { diff --git a/pkg/runtime/framework/plugins/jax/jax.go b/pkg/runtime/framework/plugins/jax/jax.go index 64b53cf684..4c43506e58 100644 --- a/pkg/runtime/framework/plugins/jax/jax.go +++ b/pkg/runtime/framework/plugins/jax/jax.go @@ -26,6 +26,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -40,7 +41,7 @@ var _ framework.CustomValidationPlugin = (*Jax)(nil) const Name = "JAX" -func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { +func New(context.Context, client.Client, client.FieldIndexer, *configapi.Configuration) (framework.Plugin, error) { return &Jax{}, nil } diff --git a/pkg/runtime/framework/plugins/jax/jax_test.go b/pkg/runtime/framework/plugins/jax/jax_test.go index 030d3ff839..4654d10b83 100644 --- a/pkg/runtime/framework/plugins/jax/jax_test.go +++ b/pkg/runtime/framework/plugins/jax/jax_test.go @@ -337,7 +337,7 @@ func TestJAXEnforceMLPolicy(t *testing.T) { ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cliBuilder := utiltesting.NewClientBuilder() - p, err := New(ctx, cliBuilder.Build(), nil) + p, err := New(ctx, cliBuilder.Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize JAX plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index ebe9fc8624..a629c6fcc8 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -42,6 +42,7 @@ import ( jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" jobsetv1alpha2ac 
"sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -72,7 +73,7 @@ const Name = constants.JobSetKind // +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets,verbs=create;delete;get;list;watch;update;patch -func New(ctx context.Context, client client.Client, _ client.FieldIndexer) (framework.Plugin, error) { +func New(ctx context.Context, client client.Client, _ client.FieldIndexer, _ *configapi.Configuration) (framework.Plugin, error) { return &JobSet{ client: client, restMapper: client.RESTMapper(), diff --git a/pkg/runtime/framework/plugins/jobset/jobset_test.go b/pkg/runtime/framework/plugins/jobset/jobset_test.go index 443db5e2bb..2e8e5d3968 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset_test.go +++ b/pkg/runtime/framework/plugins/jobset/jobset_test.go @@ -315,7 +315,7 @@ func TestJobSet(t *testing.T) { ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cli := utiltesting.NewClientBuilder().Build() - p, err := New(ctx, cli, nil) + p, err := New(ctx, cli, nil, nil) if err != nil { t.Fatalf("Failed to initialize JobSet plugin: %v", err) } @@ -1571,7 +1571,7 @@ func TestValidate(t *testing.T) { } cli := clientBuilder.Build() - p, err := New(ctx, cli, nil) + p, err := New(ctx, cli, nil, nil) if err != nil { t.Fatalf("Failed to initialize JobSet plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index 02e3436d20..ff36bab316 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -40,6 +40,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer 
"github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -64,7 +65,7 @@ const Name = "MPI" // +kubebuilder:rbac:groups="",resources=secrets,verbs=create;get;list;watch;update;patch // +kubebuilder:rbac:groups="",resources=configmaps,verbs=create;get;list;watch;update;patch -func New(_ context.Context, client client.Client, _ client.FieldIndexer) (framework.Plugin, error) { +func New(_ context.Context, client client.Client, _ client.FieldIndexer, _ *configapi.Configuration) (framework.Plugin, error) { return &MPI{ client: client, scheme: client.Scheme(), diff --git a/pkg/runtime/framework/plugins/mpi/mpi_test.go b/pkg/runtime/framework/plugins/mpi/mpi_test.go index 06a857aa59..db5628a5b8 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi_test.go +++ b/pkg/runtime/framework/plugins/mpi/mpi_test.go @@ -847,7 +847,7 @@ trainJob-node-1-0.trainJob slots=1 }) cli := b.Build() - p, err := New(ctx, cli, nil) + p, err := New(ctx, cli, nil, nil) if err != nil { t.Fatalf("Failed to initialize MPI plugin: %v", err) } @@ -994,7 +994,7 @@ func TestValidate(t *testing.T) { var cancel func() ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) - p, err := New(ctx, utiltesting.NewClientBuilder().Build(), nil) + p, err := New(ctx, utiltesting.NewClientBuilder().Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize MPI plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/plainml/plainml.go b/pkg/runtime/framework/plugins/plainml/plainml.go index 8bd82fe746..9859861e0f 100644 --- a/pkg/runtime/framework/plugins/plainml/plainml.go +++ b/pkg/runtime/framework/plugins/plainml/plainml.go @@ -21,6 +21,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" 
"github.com/kubeflow/trainer/v2/pkg/constants" @@ -34,7 +35,7 @@ type PlainML struct{} const Name = "PlainML" -func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { +func New(context.Context, client.Client, client.FieldIndexer, *configapi.Configuration) (framework.Plugin, error) { return &PlainML{}, nil } diff --git a/pkg/runtime/framework/plugins/plainml/plainml_test.go b/pkg/runtime/framework/plugins/plainml/plainml_test.go index 9e26edb766..0261b0aa8a 100644 --- a/pkg/runtime/framework/plugins/plainml/plainml_test.go +++ b/pkg/runtime/framework/plugins/plainml/plainml_test.go @@ -190,7 +190,7 @@ func TestPlainML(t *testing.T) { ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cliBuilder := utiltesting.NewClientBuilder() - p, err := New(ctx, cliBuilder.Build(), nil) + p, err := New(ctx, cliBuilder.Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize PlainML plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/registry.go b/pkg/runtime/framework/plugins/registry.go index 24ad9a9517..c53fba8b53 100644 --- a/pkg/runtime/framework/plugins/registry.go +++ b/pkg/runtime/framework/plugins/registry.go @@ -21,6 +21,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/features" "github.com/kubeflow/trainer/v2/pkg/runtime/framework" "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/coscheduling" "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/flux" @@ -29,14 +31,15 @@ import ( "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/mpi" "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/plainml" "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/torch" + "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/trainjobstatus" "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/volcano" 
"github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/xgboost" ) -type Registry map[string]func(ctx context.Context, client client.Client, indexer client.FieldIndexer) (framework.Plugin, error) +type Registry map[string]func(ctx context.Context, client client.Client, indexer client.FieldIndexer, cfg *configapi.Configuration) (framework.Plugin, error) func NewRegistry() Registry { - return Registry{ + registry := Registry{ coscheduling.Name: coscheduling.New, flux.Name: flux.New, volcano.Name: volcano.New, @@ -47,4 +50,10 @@ func NewRegistry() Registry { jax.Name: jax.New, xgboost.Name: xgboost.New, } + + if features.Enabled(features.TrainJobStatus) { + registry[trainjobstatus.Name] = trainjobstatus.New + } + + return registry } diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 5540709347..abd8fb9592 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -30,6 +30,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -44,7 +45,7 @@ var _ framework.CustomValidationPlugin = (*Torch)(nil) const Name = "Torch" -func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { +func New(context.Context, client.Client, client.FieldIndexer, *configapi.Configuration) (framework.Plugin, error) { return &Torch{}, nil } diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index a6a5d6cde2..aeb13c6ee6 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1374,7 +1374,7 @@ func TestTorchEnforceMLPolicy(t *testing.T) { 
ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cliBuilder := utiltesting.NewClientBuilder() - p, err := New(ctx, cliBuilder.Build(), nil) + p, err := New(ctx, cliBuilder.Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize Torch plugin: %v", err) } @@ -1682,7 +1682,7 @@ func TestTorchValidate(t *testing.T) { var cancel func() ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) - p, err := New(ctx, utiltesting.NewClientBuilder().Build(), nil) + p, err := New(ctx, utiltesting.NewClientBuilder().Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize Torch plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/trainjobstatus/trainjobstatus.go b/pkg/runtime/framework/plugins/trainjobstatus/trainjobstatus.go new file mode 100644 index 0000000000..d7722d978d --- /dev/null +++ b/pkg/runtime/framework/plugins/trainjobstatus/trainjobstatus.go @@ -0,0 +1,209 @@ +/* +Copyright 2026 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package trainjobstatus + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" + metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/apply" + "github.com/kubeflow/trainer/v2/pkg/constants" + "github.com/kubeflow/trainer/v2/pkg/runtime" + "github.com/kubeflow/trainer/v2/pkg/runtime/framework" + "github.com/kubeflow/trainer/v2/pkg/statusserver" + "github.com/kubeflow/trainer/v2/pkg/util/cert" +) + +const ( + Name = "TrainJobStatus" + + // Environment variable names + envNameStatusURL = "KUBEFLOW_TRAINER_SERVER_URL" + envNameCACert = "KUBEFLOW_TRAINER_SERVER_CA_CERT" + envNameToken = "KUBEFLOW_TRAINER_SERVER_TOKEN" + + // Volume and mount configuration + configMountPath = "/var/run/secrets/kubeflow/trainer" + caCertFileName = "ca.crt" + tokenFileName = "token" + tokenVolumeName = "kubeflow-trainer-token" + + // Service account token configuration + tokenExpirySeconds = 3600 + + // Server tls config + caCertKey = "ca.crt" +) + +type Status struct { + client client.Client + cfg *configapi.Configuration +} + +var _ framework.ComponentBuilderPlugin = (*Status)(nil) +var _ framework.EnforceMLPolicyPlugin = (*Status)(nil) + +func New(_ context.Context, c client.Client, _ client.FieldIndexer, cfg *configapi.Configuration) (framework.Plugin, error) { + return &Status{client: c, cfg: cfg}, nil +} + +func (p *Status) Name() string { + return Name +} + +func (p *Status) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { + if info == nil || trainJob == nil { + return nil + } + + envVars, err := p.createEnvVars(trainJob) + if err != nil { + return err + } + volumeMount := createTokenVolumeMount() + volume := 
createTokenVolume(trainJob)
+
+	// Inject into all trainer containers
+	trainerPS := info.FindPodSetByAncestor(constants.AncestorTrainer)
+	if trainerPS != nil {
+		for i := range trainerPS.Containers {
+			apply.UpsertEnvVars(&trainerPS.Containers[i].Env, envVars...)
+			apply.UpsertVolumeMounts(&trainerPS.Containers[i].VolumeMounts, volumeMount)
+		}
+		apply.UpsertVolumes(&trainerPS.Volumes, volume)
+	}
+
+	return nil
+}
+
+// Build returns the apply configuration of the ConfigMap that carries the
+// status server CA certificate for the given TrainJob.
+// It returns (nil, nil) when either info or trainJob is nil.
+func (p *Status) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]apiruntime.ApplyConfiguration, error) {
+	if info == nil || trainJob == nil {
+		return nil, nil
+	}
+
+	configMap, err := p.buildStatusServerCaCrtConfigMap(ctx, trainJob)
+	if err != nil {
+		return nil, err
+	}
+
+	return []apiruntime.ApplyConfiguration{configMap}, nil
+}
+
+// createEnvVars returns the environment variables injected into the trainer
+// containers: the status URL for this TrainJob and the file paths (under
+// configMountPath) of the CA certificate and the token.
+//
+// It returns an error, rather than panicking on a nil dereference, when the
+// plugin configuration has no status server port or no cert management section.
+func (p *Status) createEnvVars(trainJob *trainer.TrainJob) ([]corev1ac.EnvVarApplyConfiguration, error) {
+	if p.cfg == nil || p.cfg.StatusServer == nil || p.cfg.StatusServer.Port == nil {
+		return nil, fmt.Errorf("missing status server port")
+	}
+	// The service name is read from CertManagement, so that section must be present too.
+	if p.cfg.CertManagement == nil {
+		return nil, fmt.Errorf("missing cert management configuration")
+	}
+	// TODO: consider renaming the CertManagement.WebhookServiceName name?
+	svc := fmt.Sprintf("https://%s.%s.svc:%d", p.cfg.CertManagement.WebhookServiceName, cert.GetOperatorNamespace(), *p.cfg.StatusServer.Port)
+	path := statusserver.StatusUrl(trainJob.Namespace, trainJob.Name)
+	statusURL := svc + path
+
+	return []corev1ac.EnvVarApplyConfiguration{
+		*corev1ac.EnvVar().
+			WithName(envNameStatusURL).
+			WithValue(statusURL),
+		*corev1ac.EnvVar().
+			WithName(envNameCACert).
+			WithValue(fmt.Sprintf("%s/%s", configMountPath, caCertFileName)),
+		*corev1ac.EnvVar().
+			WithName(envNameToken).
+			WithValue(fmt.Sprintf("%s/%s", configMountPath, tokenFileName)),
+	}, nil
+}
+
+// createTokenVolumeMount returns the read-only mount of the token volume at configMountPath.
+func createTokenVolumeMount() corev1ac.VolumeMountApplyConfiguration {
+	return *corev1ac.VolumeMount().
+		WithName(tokenVolumeName).
+		WithMountPath(configMountPath).
+		WithReadOnly(true)
+}
+
+// caCertConfigMapName returns the name of the per-TrainJob ConfigMap that holds
+// the status server CA certificate. The projected volume and the ConfigMap
+// builder must agree on this name, so both derive it here.
+func caCertConfigMapName(trainJobName string) string {
+	return fmt.Sprintf("%s-tls-config", trainJobName)
+}
+
+// createTokenVolume returns the projected volume mounted into the trainer
+// containers. It combines a service account token whose audience is specific
+// to the TrainJob with the CA certificate taken from the TrainJob's CA ConfigMap.
+func createTokenVolume(trainJob *trainer.TrainJob) corev1ac.VolumeApplyConfiguration {
+	return *corev1ac.Volume().
+		WithName(tokenVolumeName).
+		WithProjected(
+			corev1ac.ProjectedVolumeSource().
+				WithSources(
+					corev1ac.VolumeProjection().
+						WithServiceAccountToken(
+							corev1ac.ServiceAccountTokenProjection().
+								WithAudience(statusserver.TokenAudience(trainJob.Namespace, trainJob.Name)).
+								WithExpirationSeconds(tokenExpirySeconds).
+								WithPath(tokenFileName),
+						),
+					corev1ac.VolumeProjection().
+						WithConfigMap(
+							corev1ac.ConfigMapProjection().
+								WithName(caCertConfigMapName(trainJob.Name)).
+								WithItems(
+									corev1ac.KeyToPath().
+										WithKey(caCertKey).
+										WithPath(caCertFileName),
+								),
+						),
+				),
+		)
+}
+
+// buildStatusServerCaCrtConfigMap creates a ConfigMap in the TrainJob namespace
+// that copies the ca.crt from the webhook secret. The ConfigMap carries an
+// owner reference to the TrainJob.
+//
+// It returns an error when the cert management configuration is missing, when
+// the secret cannot be fetched, or when the secret has no non-empty ca.crt entry.
+func (p *Status) buildStatusServerCaCrtConfigMap(ctx context.Context, trainJob *trainer.TrainJob) (*corev1ac.ConfigMapApplyConfiguration, error) {
+	// Fail with an error instead of a nil-pointer panic on an incomplete configuration.
+	if p.cfg == nil || p.cfg.CertManagement == nil {
+		return nil, fmt.Errorf("missing cert management configuration")
+	}
+
+	// Get the CA cert from the webhook secret
+	secret := &corev1.Secret{}
+	secretKey := client.ObjectKey{
+		Namespace: cert.GetOperatorNamespace(),
+		Name:      p.cfg.CertManagement.WebhookSecretName,
+	}
+	if err := p.client.Get(ctx, secretKey, secret); err != nil {
+		return nil, fmt.Errorf("failed to look up status server tls secret: %w", err)
+	}
+
+	// A missing key and an empty value are both treated as "no CA certificate".
+	caCert := secret.Data[caCertKey]
+	if len(caCert) == 0 {
+		return nil, fmt.Errorf("failed to find status server ca.crt in tls secret")
+	}
+
+	configMap := corev1ac.ConfigMap(caCertConfigMapName(trainJob.Name), trainJob.Namespace).
+		WithData(map[string]string{
+			caCertKey: string(caCert),
+		}).
+		WithOwnerReferences(
+			metav1ac.OwnerReference().
+				WithAPIVersion(trainer.GroupVersion.String()).
+				WithKind(trainer.TrainJobKind).
+				WithName(trainJob.Name).
+				WithUID(trainJob.UID).
+ WithController(true). + WithBlockOwnerDeletion(true), + ) + + return configMap, nil +} diff --git a/pkg/runtime/framework/plugins/trainjobstatus/trainjobstatus_test.go b/pkg/runtime/framework/plugins/trainjobstatus/trainjobstatus_test.go new file mode 100644 index 0000000000..a8d2a8802e --- /dev/null +++ b/pkg/runtime/framework/plugins/trainjobstatus/trainjobstatus_test.go @@ -0,0 +1,506 @@ +/* +Copyright 2026 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trainjobstatus + +import ( + "cmp" + "context" + "fmt" + "testing" + + gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" + "k8s.io/klog/v2/ktesting" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/constants" + "github.com/kubeflow/trainer/v2/pkg/runtime" + "github.com/kubeflow/trainer/v2/pkg/runtime/framework" + utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing" +) + +func TestEnforceMLPolicy(t *testing.T) { + cases := map[string]struct { + info *runtime.Info + trainJob *trainer.TrainJob + wantInfo *runtime.Info + wantError error + }{ + "does nothing if no trainer pods": { + info: 
&runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "launcher", + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + {Name: "launcher"}, + }, + }, + }, + }, + }, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(1).Obj()). + Obj(), + wantInfo: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "launcher", + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + {Name: "launcher"}, + }, + }, + }, + }, + }, + }, + "injects runtime configuration into trainer containers": { + info: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](2), + Containers: []runtime.Container{ + {Name: constants.Node}, + }, + }, + }, + }, + }, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + UID("test-uid"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(2).Obj()). + Obj(), + wantInfo: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](2), + Containers: []runtime.Container{ + { + Name: constants.Node, + Env: []corev1ac.EnvVarApplyConfiguration{ + *corev1ac.EnvVar(). + WithName(envNameStatusURL). + WithValue("https://kubeflow-trainer-controller-manager.kubeflow-system.svc:10443/apis/trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status"), + *corev1ac.EnvVar(). + WithName(envNameCACert). + WithValue(fmt.Sprintf("%s/%s", configMountPath, caCertFileName)), + *corev1ac.EnvVar(). + WithName(envNameToken). + WithValue(fmt.Sprintf("%s/%s", configMountPath, tokenFileName)), + }, + VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{ + *corev1ac.VolumeMount(). + WithName(tokenVolumeName). 
+ WithMountPath(configMountPath). + WithReadOnly(true), + }, + }, + }, + Volumes: []corev1ac.VolumeApplyConfiguration{ + *corev1ac.Volume(). + WithName(tokenVolumeName). + WithProjected( + corev1ac.ProjectedVolumeSource(). + WithSources( + corev1ac.VolumeProjection(). + WithServiceAccountToken( + corev1ac.ServiceAccountTokenProjection(). + WithAudience("trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status"). + WithExpirationSeconds(tokenExpirySeconds). + WithPath(tokenFileName), + ), + corev1ac.VolumeProjection(). + WithConfigMap( + corev1ac.ConfigMapProjection(). + WithName("test-job-tls-config"). + WithItems( + corev1ac.KeyToPath(). + WithKey(caCertKey). + WithPath(caCertFileName), + ), + ), + ), + ), + }, + }, + }, + }, + }, + }, + "injects runtime configuration into multiple containers": { + info: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + {Name: constants.Node}, + {Name: "sidecar"}, + }, + }, + }, + }, + }, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + UID("test-uid"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(1).Obj()). + Obj(), + wantInfo: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + { + Name: constants.Node, + Env: []corev1ac.EnvVarApplyConfiguration{ + *corev1ac.EnvVar(). + WithName(envNameStatusURL). + WithValue("https://kubeflow-trainer-controller-manager.kubeflow-system.svc:10443/apis/trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status"), + *corev1ac.EnvVar(). + WithName(envNameCACert). + WithValue(fmt.Sprintf("%s/%s", configMountPath, caCertFileName)), + *corev1ac.EnvVar(). + WithName(envNameToken). 
+ WithValue(fmt.Sprintf("%s/%s", configMountPath, tokenFileName)), + }, + VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{ + *corev1ac.VolumeMount(). + WithName(tokenVolumeName). + WithMountPath(configMountPath). + WithReadOnly(true), + }, + }, + { + Name: "sidecar", + Env: []corev1ac.EnvVarApplyConfiguration{ + *corev1ac.EnvVar(). + WithName(envNameStatusURL). + WithValue("https://kubeflow-trainer-controller-manager.kubeflow-system.svc:10443/apis/trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status"), + *corev1ac.EnvVar(). + WithName(envNameCACert). + WithValue(fmt.Sprintf("%s/%s", configMountPath, caCertFileName)), + *corev1ac.EnvVar(). + WithName(envNameToken). + WithValue(fmt.Sprintf("%s/%s", configMountPath, tokenFileName)), + }, + VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{ + *corev1ac.VolumeMount(). + WithName(tokenVolumeName). + WithMountPath(configMountPath). + WithReadOnly(true), + }, + }, + }, + Volumes: []corev1ac.VolumeApplyConfiguration{ + *corev1ac.Volume(). + WithName(tokenVolumeName). + WithProjected( + corev1ac.ProjectedVolumeSource(). + WithSources( + corev1ac.VolumeProjection(). + WithServiceAccountToken( + corev1ac.ServiceAccountTokenProjection(). + WithAudience("trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status"). + WithExpirationSeconds(tokenExpirySeconds). + WithPath(tokenFileName), + ), + corev1ac.VolumeProjection(). + WithConfigMap( + corev1ac.ConfigMapProjection(). + WithName("test-job-tls-config"). + WithItems( + corev1ac.KeyToPath(). + WithKey(caCertKey). 
+ WithPath(caCertFileName), + ), + ), + ), + ), + }, + }, + }, + }, + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) + var cancel func() + ctx, cancel = context.WithCancel(ctx) + t.Cleanup(cancel) + + cli := utiltesting.NewClientBuilder().Build() + cfg := &configapi.Configuration{ + CertManagement: &configapi.CertManagement{ + WebhookServiceName: "kubeflow-trainer-controller-manager", + WebhookSecretName: "kubeflow-trainer-webhook-cert", + }, + StatusServer: &configapi.StatusServer{ + Port: ptr.To[int32](10443), + QPS: ptr.To[float32](5), + Burst: ptr.To[int32](10), + }, + } + + p, err := New(ctx, cli, nil, cfg) + if err != nil { + t.Fatalf("Failed to initialize Status plugin: %v", err) + } + + err = p.(framework.EnforceMLPolicyPlugin).EnforceMLPolicy(tc.info, tc.trainJob) + if diff := gocmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from EnforceMLPolicy (-want, +got): %s", diff) + } + + if diff := gocmp.Diff(tc.wantInfo, tc.info, + cmpopts.SortSlices(func(a, b string) bool { return a < b }), + cmpopts.SortMaps(func(a, b int) bool { return a < b }), + ); len(diff) != 0 { + t.Errorf("Unexpected info from EnforceMLPolicy (-want, +got): %s", diff) + } + }) + } +} + +func TestBuild(t *testing.T) { + objCmpOpts := []gocmp.Option{ + cmpopts.SortSlices(func(a, b apiruntime.Object) int { + return cmp.Compare(a.GetObjectKind().GroupVersionKind().String(), b.GetObjectKind().GroupVersionKind().String()) + }), + } + + cases := map[string]struct { + info *runtime.Info + trainJob *trainer.TrainJob + objs []client.Object + wantObjs []apiruntime.Object + wantError string + }{ + "no action when info is nil": { + info: nil, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").Obj(), + wantObjs: nil, + }, + "no action when trainJob is nil": { + info: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + 
{ + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + {Name: constants.Node}, + }, + }, + }, + }, + }, + trainJob: nil, + wantObjs: nil, + }, + "creates ConfigMap with CA cert from secret": { + objs: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kubeflow-trainer-webhook-cert", + Namespace: "kubeflow-system", // Webhook secret is in operator namespace + }, + Data: map[string][]byte{ + caCertKey: []byte("test-ca-cert-data"), + }, + }, + }, + info: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](2), + Containers: []runtime.Container{ + {Name: constants.Node}, + }, + }, + }, + }, + }, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + UID("test-uid"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(2).Obj()). + Obj(), + wantObjs: []apiruntime.Object{ + &corev1.ConfigMap{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "ConfigMap", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-job-tls-config", + Namespace: metav1.NamespaceDefault, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: trainer.GroupVersion.String(), + Kind: trainer.TrainJobKind, + Name: "test-job", + UID: "test-uid", + Controller: ptr.To(true), + BlockOwnerDeletion: ptr.To(true), + }, + }, + }, + Data: map[string]string{ + caCertKey: "test-ca-cert-data", + }, + }, + }, + }, + "returns error when webhook secret not found": { + objs: []client.Object{}, + info: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + {Name: constants.Node}, + }, + }, + }, + }, + }, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). 
+ UID("test-uid"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(1).Obj()). + Obj(), + wantError: "failed to look up status server tls secret", + }, + "returns error when CA cert is missing in secret": { + objs: []client.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kubeflow-trainer-webhook-cert", + Namespace: "kubeflow-system", + }, + Data: map[string][]byte{ + // No ca.crt key + "other-key": []byte("other-data"), + }, + }, + }, + info: &runtime.Info{ + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: "trainer", + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + Containers: []runtime.Container{ + {Name: constants.Node}, + }, + }, + }, + }, + }, + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + UID("test-uid"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(1).Obj()). + Obj(), + wantError: "failed to find status server ca.crt in tls secret", + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) + var cancel func() + ctx, cancel = context.WithCancel(ctx) + t.Cleanup(cancel) + + b := utiltesting.NewClientBuilder().WithObjects(tc.objs...) 
+ cli := b.Build() + + cfg := &configapi.Configuration{ + CertManagement: &configapi.CertManagement{ + WebhookServiceName: "kubeflow-trainer-controller-manager", + WebhookSecretName: "kubeflow-trainer-webhook-cert", + }, + StatusServer: &configapi.StatusServer{ + Port: ptr.To[int32](10443), + QPS: ptr.To[float32](5), + Burst: ptr.To[int32](10), + }, + } + + p, err := New(ctx, cli, nil, cfg) + if err != nil { + t.Fatalf("Failed to initialize Status plugin: %v", err) + } + + var objs []apiruntime.ApplyConfiguration + objs, err = p.(framework.ComponentBuilderPlugin).Build(ctx, tc.info, tc.trainJob) + + if tc.wantError != "" { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tc.wantError) + } else if len(err.Error()) < len(tc.wantError) || err.Error()[:len(tc.wantError)] != tc.wantError { + t.Errorf("Expected error containing %q, got %q", tc.wantError, err.Error()) + } + return + } + + if err != nil { + t.Errorf("Unexpected error from Build: %v", err) + } + + var typedObjs []apiruntime.Object + typedObjs, err = utiltesting.ToObject(cli.Scheme(), objs...) 
+ if err != nil { + t.Errorf("Failed to convert object: %v", err) + } + + if diff := gocmp.Diff(tc.wantObjs, typedObjs, objCmpOpts...); len(diff) != 0 { + t.Errorf("Unexpected objects from Build (-want, +got): %s", diff) + } + }) + } +} diff --git a/pkg/runtime/framework/plugins/volcano/volcano.go b/pkg/runtime/framework/plugins/volcano/volcano.go index 2727480403..5468cfb923 100644 --- a/pkg/runtime/framework/plugins/volcano/volcano.go +++ b/pkg/runtime/framework/plugins/volcano/volcano.go @@ -48,6 +48,7 @@ import ( volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanov1beta1ac "volcano.sh/apis/pkg/client/applyconfiguration/scheduling/v1beta1" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/runtime" "github.com/kubeflow/trainer/v2/pkg/runtime/framework" @@ -71,7 +72,7 @@ const Name = "Volcano" // +kubebuilder:rbac:groups=node.k8s.io,resources=runtimeclasses,verbs=get;list;watch // +kubebuilder:rbac:groups="",resources=limitranges,verbs=get;list;watch -func New(_ context.Context, client client.Client, _ client.FieldIndexer) (framework.Plugin, error) { +func New(_ context.Context, client client.Client, _ client.FieldIndexer, _ *configapi.Configuration) (framework.Plugin, error) { return &Volcano{ client: client, restMapper: client.RESTMapper(), diff --git a/pkg/runtime/framework/plugins/volcano/volcano_test.go b/pkg/runtime/framework/plugins/volcano/volcano_test.go index 70c761b6c1..4150aee439 100644 --- a/pkg/runtime/framework/plugins/volcano/volcano_test.go +++ b/pkg/runtime/framework/plugins/volcano/volcano_test.go @@ -359,7 +359,7 @@ func TestVolcano(t *testing.T) { } cli := clientBuilder.Build() - plugin, err := New(ctx, cli, utiltesting.AsIndex(clientBuilder)) + plugin, err := New(ctx, cli, utiltesting.AsIndex(clientBuilder), nil) if err != nil { t.Fatalf("Failed to create plugin: %v", err) } @@ -507,7 +507,7 @@ func 
TestValidate(t *testing.T) { clientBuilder := utiltesting.NewClientBuilder().WithObjects(tc.objs...) cli := clientBuilder.Build() - v, err := New(ctx, cli, nil) + v, err := New(ctx, cli, nil, nil) if err != nil { t.Fatalf("failed to init Volcano plugin: %v", err) } diff --git a/pkg/runtime/framework/plugins/xgboost/xgboost.go b/pkg/runtime/framework/plugins/xgboost/xgboost.go index 1278a04646..4e78c74d74 100644 --- a/pkg/runtime/framework/plugins/xgboost/xgboost.go +++ b/pkg/runtime/framework/plugins/xgboost/xgboost.go @@ -27,6 +27,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/v2/pkg/apply" "github.com/kubeflow/trainer/v2/pkg/constants" @@ -41,7 +42,7 @@ var _ framework.CustomValidationPlugin = (*XGBoost)(nil) const Name = "XGBoost" -func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { +func New(context.Context, client.Client, client.FieldIndexer, *configapi.Configuration) (framework.Plugin, error) { return &XGBoost{}, nil } diff --git a/pkg/runtime/framework/plugins/xgboost/xgboost_test.go b/pkg/runtime/framework/plugins/xgboost/xgboost_test.go index b927b6455c..0c91847c6c 100644 --- a/pkg/runtime/framework/plugins/xgboost/xgboost_test.go +++ b/pkg/runtime/framework/plugins/xgboost/xgboost_test.go @@ -201,7 +201,7 @@ func TestXGBoostValidate(t *testing.T) { ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cliBuilder := utiltesting.NewClientBuilder() - p, err := New(ctx, cliBuilder.Build(), nil) + p, err := New(ctx, cliBuilder.Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize XGBoost plugin: %v", err) } @@ -931,7 +931,7 @@ func TestXGBoostEnforceMLPolicy(t *testing.T) { ctx, cancel = context.WithCancel(ctx) t.Cleanup(cancel) cliBuilder := utiltesting.NewClientBuilder() - p, err 
:= New(ctx, cliBuilder.Build(), nil) + p, err := New(ctx, cliBuilder.Build(), nil, nil) if err != nil { t.Fatalf("Failed to initialize XGBoost plugin: %v", err) } diff --git a/pkg/statusserver/auth.go b/pkg/statusserver/auth.go new file mode 100644 index 0000000000..603cf76b91 --- /dev/null +++ b/pkg/statusserver/auth.go @@ -0,0 +1,154 @@ +/* +Copyright 2026 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package statusserver + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "k8s.io/client-go/rest" +) + +type TokenAuthorizer interface { + Init(ctx context.Context) error + Authorize(ctx context.Context, rawIDToken, namespace, trainJobName string) (bool, error) +} + +type projectedServiceAccountTokenAuthorizer struct { + oidcProvider *oidc.Provider + config *rest.Config +} + +var _ TokenAuthorizer = &projectedServiceAccountTokenAuthorizer{} + +// projectedToken is the decoded JWT claims of a k8s projected service account token. +// Note: we only decode the subset of claims we actually need. +type projectedToken struct { + Issuer string `json:"iss"` + Kubernetes struct { + Namespace string `json:"namespace"` + } `json:"kubernetes.io"` +} + +// NewProjectedServiceAccountTokenAuthorizer creates a validator for checking a bearer token has permission to +// update the requested train job. 
+func NewProjectedServiceAccountTokenAuthorizer(config *rest.Config) TokenAuthorizer {
+	return &projectedServiceAccountTokenAuthorizer{
+		config: config,
+	}
+}
+
+// Init discovers the cluster token issuer and builds the OIDC provider used by
+// Authorize. It must complete successfully before Authorize is called;
+// Authorize returns an error while the provider is unset.
+func (p *projectedServiceAccountTokenAuthorizer) Init(ctx context.Context) error {
+	issuerURL, err := getClusterOIDCIssuerURL()
+	if err != nil {
+		return fmt.Errorf("failed to discover issuer URL: %w", err)
+	}
+
+	// Create an authenticated HTTP client using the provided rest config
+	httpClient, err := rest.HTTPClientFor(p.config)
+	if err != nil {
+		return fmt.Errorf("failed to create HTTP client: %w", err)
+	}
+
+	// Create context with the authenticated HTTP client
+	ctx = oidc.ClientContext(ctx, httpClient)
+
+	provider, err := oidc.NewProvider(ctx, issuerURL)
+	if err != nil {
+		return fmt.Errorf("failed to create OIDC provider: %w", err)
+	}
+	p.oidcProvider = provider
+
+	return nil
+}
+
+// Authorize reports whether the bearer token carried in authHeader may update
+// the status of the TrainJob namespace/trainJobName.
+//
+// authHeader is the raw value of the HTTP Authorization header. The token is
+// accepted only when it verifies against the cluster issuer with the
+// TrainJob-specific audience and its kubernetes.io namespace claim equals
+// namespace. A rejected token yields (false, nil); an error is returned only
+// when the authorizer itself is unusable (Init has not succeeded).
+func (p *projectedServiceAccountTokenAuthorizer) Authorize(ctx context.Context, authHeader, namespace, trainJobName string) (bool, error) {
+	if p.oidcProvider == nil {
+		return false, fmt.Errorf("OIDC provider has not been initialized")
+	}
+
+	rawToken := extractRawToken(authHeader)
+	if rawToken == "" {
+		// Missing or malformed Authorization header: reject without
+		// involving the verifier.
+		return false, nil
+	}
+
+	// Create authorizer with TrainJob-specific audience
+	expectedAudience := TokenAudience(namespace, trainJobName)
+	verifier := p.oidcProvider.Verifier(&oidc.Config{
+		ClientID: expectedAudience,
+	})
+
+	// Check token signature, expiry, and audience
+	idToken, err := verifier.Verify(ctx, rawToken)
+	if err != nil {
+		return false, nil
+	}
+
+	// Check token is bound to a pod in the same namespace as the train job
+	parsedToken := projectedToken{}
+	if err := idToken.Claims(&parsedToken); err != nil {
+		return false, nil
+	}
+	if parsedToken.Kubernetes.Namespace != namespace {
+		return false, nil
+	}
+
+	return true, nil
+}
+
+// extractRawToken returns the credentials from an Authorization header of the
+// form "Bearer <token>", or "" when the header is absent or malformed.
+//
+// The scheme is matched case-insensitively and any run of whitespace between
+// scheme and token is tolerated, so "bearer  <token>" is accepted as well.
+func extractRawToken(authHeader string) string {
+	parts := strings.Fields(authHeader)
+
+	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
+		return ""
+	}
+
+	return parts[1]
+}
+
+//
getClusterOIDCIssuerURL tries to look up the cluster token issuer from the in-cluster service account token +// Different clusters may use different issuers. This is a reliable way of discovering the issuer. +func getClusterOIDCIssuerURL() (string, error) { + tokenBytes, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/token") + if err != nil { + return "", err + } + + parts := strings.Split(strings.TrimSpace(string(tokenBytes)), ".") + if len(parts) != 3 { + return "", fmt.Errorf("serviceaccount token is not a jwt") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("serviceaccount token is not a jwt: %w", err) + } + + var token projectedToken + if err := json.Unmarshal(payload, &token); err != nil { + return "", fmt.Errorf("serviceaccount token is not a jwt: %w", err) + } + + if token.Issuer == "" { + return "", fmt.Errorf("serviceaccount token missing issuer claim") + } + + return token.Issuer, nil +} diff --git a/pkg/statusserver/middleware.go b/pkg/statusserver/middleware.go new file mode 100644 index 0000000000..b6d5a768e9 --- /dev/null +++ b/pkg/statusserver/middleware.go @@ -0,0 +1,64 @@ +package statusserver + +import ( + "fmt" + "net/http" + + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type Middleware func(http.Handler) http.Handler + +// chain applies middleware in order: first middleware wraps second, etc. +func chain(h http.Handler, middlewares ...Middleware) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + h = middlewares[i](h) + } + return h +} + +// recoveryMiddleware recovers from panics in HTTP handlers to prevent Server crashes. 
+func recoveryMiddleware(log logr.Logger) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Error(fmt.Errorf("panic: %v", err), "Panic in HTTP handler", + "path", r.URL.Path, "method", r.Method) + badRequest(w, log, "Internal Server Error", v1.StatusReasonInternalError, http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) + } +} + +// loggingMiddleware logs incoming HTTP requests. +func loggingMiddleware(log logr.Logger) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.V(5).Info("HTTP request", "method", r.Method, "path", r.URL.Path, "remote", r.RemoteAddr) + next.ServeHTTP(w, r) + }) + } +} + +// bodySizeLimitMiddleware enforces a maximum request body size. +func bodySizeLimitMiddleware(log logr.Logger, maxBytes int64) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Reject based on Content-Length header if present + if r.ContentLength > maxBytes { + badRequest(w, log, "Payload too large", + v1.StatusReasonRequestEntityTooLarge, + http.StatusRequestEntityTooLarge) + return + } + + // Wrap body to enforce limit for chunked/streaming requests + r.Body = http.MaxBytesReader(w, r.Body, maxBytes) + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/statusserver/middleware_test.go b/pkg/statusserver/middleware_test.go new file mode 100644 index 0000000000..0b2dd99d1b --- /dev/null +++ b/pkg/statusserver/middleware_test.go @@ -0,0 +1,31 @@ +package statusserver + +import ( + "net/http" + "net/http/httptest" + "testing" + + "k8s.io/klog/v2/ktesting" +) + +func TestRecoveryMiddleware(t *testing.T) { + logger, _ := ktesting.NewTestContext(t) + + // Create a handler that panics + panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + panic("test panic") + }) + + // Wrap with recovery middleware + handler := recoveryMiddleware(logger)(panicHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + // Should not panic, should return 500 + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %v, want %v", rec.Code, http.StatusInternalServerError) + } +} diff --git a/pkg/statusserver/server.go b/pkg/statusserver/server.go new file mode 100644 index 0000000000..d5a7ee9494 --- /dev/null +++ b/pkg/statusserver/server.go @@ -0,0 +1,263 @@ +/* +Copyright 2026 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package statusserver + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/go-logr/logr" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" + + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + trainerv1alpha1ac "github.com/kubeflow/trainer/v2/pkg/client/applyconfiguration/trainer/v1alpha1" +) + +const ( + shutdownTimeout = 5 * time.Second + + // HTTP Server timeouts to prevent resource exhaustion + readTimeout = 10 * time.Second + writeTimeout = 10 * time.Second + idleTimeout = 120 * time.Second + + // Maximum request body size (64kB) + maxBodySize = 1 << 16 +) + +// Server for collecting runtime status updates. +type Server struct { + log logr.Logger + httpServer *http.Server + client client.Client + authorizer TokenAuthorizer +} + +var ( + _ manager.Runnable = &Server{} + _ manager.LeaderElectionRunnable = &Server{} +) + +// NewServer creates a new Server for collecting runtime status updates. 
+func NewServer(c client.Client, cfg *configapi.StatusServer, tlsConfig *tls.Config, authorizer TokenAuthorizer) (*Server, error) {
+	// Reject incomplete wiring up front; the first failing check wins.
+	switch {
+	case cfg == nil || cfg.Port == nil:
+		return nil, fmt.Errorf("cfg info is required")
+	case tlsConfig == nil:
+		return nil, fmt.Errorf("tlsConfig is required")
+	case authorizer == nil:
+		return nil, fmt.Errorf("authorizer is required")
+	}
+
+	logger := ctrl.Log.WithName("runtime-status")
+	srv := &Server{
+		log:        logger,
+		client:     c,
+		authorizer: authorizer,
+	}
+
+	// The status endpoint is the only real route; every other request is
+	// answered by the default handler.
+	router := http.NewServeMux()
+	router.HandleFunc("POST "+StatusUrl("{namespace}", "{name}"), srv.handleTrainJobRuntimeStatus)
+	router.HandleFunc("/", srv.handleDefault)
+
+	srv.httpServer = &http.Server{
+		Addr: fmt.Sprintf(":%d", *cfg.Port),
+		// Middleware wraps the router outermost-first. Authentication is
+		// not a middleware: it happens inside the status handler.
+		Handler: chain(router,
+			recoveryMiddleware(logger),
+			loggingMiddleware(logger),
+			bodySizeLimitMiddleware(logger, maxBodySize),
+		),
+		TLSConfig:    tlsConfig,
+		ErrorLog:     slog.NewLogLogger(logr.ToSlogHandler(logger), slog.LevelInfo),
+		ReadTimeout:  readTimeout,
+		WriteTimeout: writeTimeout,
+		IdleTimeout:  idleTimeout,
+	}
+
+	return srv, nil
+}
+
+// Start implements manager.Runnable and starts the HTTPS Server.
+// It blocks until the Server stops, either due to an error or graceful shutdown.
+func (s *Server) Start(ctx context.Context) error {
+	s.log.Info("Initializing token authorizer")
+	if err := s.authorizer.Init(ctx); err != nil {
+		return fmt.Errorf("token authorizer initialization failed: %w", err)
+	}
+
+	// Handle graceful shutdown in background. serverShutdown is closed once
+	// Shutdown has returned, which is what releases the wait at the end of
+	// Start; without the close, Start would never return after ctx is
+	// cancelled.
+	serverShutdown := make(chan struct{})
+	go func() {
+		defer close(serverShutdown)
+		<-ctx.Done()
+		s.log.Info("Shutting down runtime status server")
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
+		defer cancel()
+		if err := s.httpServer.Shutdown(shutdownCtx); err != nil {
+			s.log.Error(err, "Error shutting down runtime status server")
+		}
+	}()
+
+	s.log.Info("Starting runtime status server with TLS", "address", s.httpServer.Addr)
+	// Certificates come from TLSConfig, so no cert/key files are passed.
+	if err := s.httpServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
+		return fmt.Errorf("runtime status server failed: %w", err)
+	}
+
+	// ListenAndServeTLS returns ErrServerClosed as soon as Shutdown is
+	// called; wait here until Shutdown itself has finished (or timed out).
+	<-serverShutdown
+	return nil
+}
+
+// NeedLeaderElection implements manager.LeaderElectionRunnable.
+func (s *Server) NeedLeaderElection() bool {
+	// server needs to run on all replicas
+	return false
+}
+
+// handleTrainJobRuntimeStatus handles POST requests to update TrainJob status.
+// Expected URL format: /apis/trainer.kubeflow.org/v1alpha1/namespaces/{namespace}/trainjobs/{name}/status
+//
+// The request is authorized against the Authorization header, the JSON body is
+// decoded into an UpdateTrainJobStatusRequest, and a non-empty trainerStatus is
+// applied to the TrainJob status subresource. On success the decoded request
+// is echoed back with 200; failures are reported as a kubernetes Status object.
+func (s *Server) handleTrainJobRuntimeStatus(w http.ResponseWriter, r *http.Request) {
+	namespace := r.PathValue("namespace")
+	trainJobName := r.PathValue("name")
+
+	authorized, err := s.authorizer.Authorize(r.Context(), r.Header.Get("Authorization"), namespace, trainJobName)
+	if err != nil {
+		// An error here means the authorizer itself is broken, not that the
+		// token was rejected; record it so the 500 can be diagnosed.
+		s.log.Error(err, "Failed to authorize request", "namespace", namespace, "name", trainJobName)
+		badRequest(w, s.log, "Internal error", metav1.StatusReasonInternalError, http.StatusInternalServerError)
+		return
+	}
+	if !authorized {
+		badRequest(w, s.log, "Forbidden", metav1.StatusReasonForbidden, http.StatusForbidden)
+		return
+	}
+
+	// Parse request body
+	var updateRequest trainer.UpdateTrainJobStatusRequest
+	decoder := json.NewDecoder(r.Body)
+	if err := decoder.Decode(&updateRequest); err != nil {
+		s.log.V(5).Error(err, "Failed to parse runtime status", "namespace", namespace, "trainJobName", trainJobName)
+
+		// A body without Content-Length (chunked/streaming) only hits the
+		// size limit while being read; report it the same way the
+		// Content-Length check does instead of as a malformed payload.
+		var tooLarge *http.MaxBytesError
+		if errors.As(err, &tooLarge) {
+			badRequest(w, s.log, "Payload too large", metav1.StatusReasonRequestEntityTooLarge, http.StatusRequestEntityTooLarge)
+			return
+		}
+
+		badRequest(w, s.log, "Invalid payload", metav1.StatusReasonInvalid, http.StatusUnprocessableEntity)
+		return
+	}
+
+	// If the update request is empty (no trainer status), return success without applying
+	if updateRequest.TrainerStatus == nil {
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		if err := json.NewEncoder(w).Encode(updateRequest); err != nil {
+			s.log.Error(err, "Failed to write response", "namespace", namespace, "name", trainJobName)
+		}
+		return
+	}
+
+	var trainJob = trainerv1alpha1ac.TrainJob(trainJobName, namespace).WithStatus(toApplyConfig(updateRequest))
+
+	if err := s.client.Status().Apply(r.Context(), trainJob, client.ForceOwnership, client.FieldOwner("trainer-status")); err != nil {
+		// Check if the error is due to validation failure
+		if apierrors.IsInvalid(err) || apierrors.IsBadRequest(err) {
+			// Extract the validation error message for the user. errors.As
+			// also finds a StatusError that has been wrapped.
+			var statusErr *apierrors.StatusError
+			if errors.As(err, &statusErr) && statusErr.ErrStatus.Message != "" {
+				badRequest(w, s.log, statusErr.ErrStatus.Message, metav1.StatusReasonInvalid, http.StatusUnprocessableEntity)
+			} else {
+				badRequest(w, s.log, "Validation failed: "+err.Error(), metav1.StatusReasonInvalid, http.StatusUnprocessableEntity)
+			}
+			return
+		}
+
+		// Check if the error is due to missing object
+		if apierrors.IsNotFound(err) {
+			badRequest(w, s.log, "Train job not found", metav1.StatusReasonNotFound, http.StatusNotFound)
+			return
+		}
+
+		// For other errors, return internal server error
+		s.log.Error(err, "Failed to update TrainJob", "namespace", namespace, "name", trainJobName)
+		badRequest(w, s.log, "Internal error", metav1.StatusReasonInternalError, http.StatusInternalServerError)
+		return
+	}
+
+	// Return the parsed payload
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+	if err := json.NewEncoder(w).Encode(updateRequest); err != nil {
+		s.log.Error(err, "Failed to write TrainJob status", "namespace", namespace, "name", trainJobName)
+	}
+}
+
+// handleDefault is the default handler for unknown requests.
+func (s *Server) handleDefault(w http.ResponseWriter, _ *http.Request) {
+	// Always answers 404 with a kubernetes Status body.
+	badRequest(w, s.log, "Not found", metav1.StatusReasonNotFound, http.StatusNotFound)
+}
+
+// badRequest sends a kubernetes Status response with the error message
+//
+// Despite its name it is the shared failure writer for every status code used
+// in this package (403, 404, 413, 422, 500): code is used both as the HTTP
+// status line and as the Code field of the JSON body. A failure to encode the
+// body is logged and otherwise ignored.
+func badRequest(w http.ResponseWriter, log logr.Logger, message string, reason metav1.StatusReason, code int32) {
+	status := metav1.Status{
+		Status:  metav1.StatusFailure,
+		Message: message,
+		Reason:  reason,
+		Code:    code,
+	}
+
+	// Headers must be set before WriteHeader sends the status line.
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(int(code))
+	if err := json.NewEncoder(w).Encode(status); err != nil {
+		log.Error(err, "Failed to write bad request details")
+	}
+}
+
+// toApplyConfig converts an UpdateTrainJobStatusRequest into the TrainJob
+// status apply configuration sent to the API server.
+//
+// When the request carries no trainerStatus the returned configuration is
+// empty. Optional pointer fields (progress, remaining seconds) are copied only
+// when set; lastUpdatedTime is always copied.
+func toApplyConfig(updateRequest trainer.UpdateTrainJobStatusRequest) *trainerv1alpha1ac.TrainJobStatusApplyConfiguration {
+	var s = trainerv1alpha1ac.TrainJobStatus()
+
+	trainerStatus := updateRequest.TrainerStatus
+	if trainerStatus != nil {
+		var ts = trainerv1alpha1ac.TrainerStatus()
+
+		if trainerStatus.ProgressPercentage != nil {
+			ts = ts.WithProgressPercentage(*trainerStatus.ProgressPercentage)
+		}
+		if trainerStatus.EstimatedRemainingSeconds != nil {
+			ts = ts.WithEstimatedRemainingSeconds(*trainerStatus.EstimatedRemainingSeconds)
+		}
+		for _, m := range trainerStatus.Metrics {
+			// NOTE(review): the return value is discarded here, unlike the
+			// calls above; presumably the builder mutates its receiver in
+			// place — confirm against the generated apply configuration.
+			ts.WithMetrics(
+				trainerv1alpha1ac.Metric().
+					WithName(m.Name).
+					WithValue(m.Value),
+			)
+		}
+
+		ts = ts.WithLastUpdatedTime(trainerStatus.LastUpdatedTime)
+		// NOTE(review): return value discarded as well; same assumption as
+		// for WithMetrics above.
+		s.WithTrainerStatus(ts)
+	}
+	return s
+}
diff --git a/pkg/statusserver/server_test.go b/pkg/statusserver/server_test.go
new file mode 100644
index 0000000000..04d6a47ac3
--- /dev/null
+++ b/pkg/statusserver/server_test.go
@@ -0,0 +1,156 @@
+/*
+Copyright 2026 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package statusserver + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing" +) + +type fakeAuthorizer struct { + authorized bool +} + +func (f fakeAuthorizer) Init(_ context.Context) error { + return nil +} + +func (f fakeAuthorizer) Authorize(_ context.Context, _, _, _ string) (bool, error) { + return f.authorized, nil +} + +func newTestServer(t *testing.T, cfg *configapi.StatusServer, authorizer TokenAuthorizer, objs ...client.Object) *httptest.Server { + t.Helper() + + fakeClient := utiltesting.NewClientBuilder(). + WithObjects(objs...). + WithStatusSubresource(objs...). 
+ Build() + + srv, err := NewServer(fakeClient, cfg, &tls.Config{}, authorizer) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + return httptest.NewServer(srv.httpServer.Handler) +} + +func TestServerErrorResponses(t *testing.T) { + cases := map[string]struct { + url string + body string + authorized bool + wantResponse *metav1.Status + }{ + "unauthorized fails with 403 unauthorized": { + url: "/apis/trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status", + authorized: false, + wantResponse: &metav1.Status{ + Status: metav1.StatusFailure, + Message: "Forbidden", + Reason: metav1.StatusReasonForbidden, + Code: http.StatusForbidden, + }, + }, + "invalid payload fails with 422 unprocessable entity": { + url: "/apis/trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status", + body: `invalid payload`, + authorized: true, + wantResponse: &metav1.Status{ + Status: metav1.StatusFailure, + Message: "Invalid payload", + Reason: metav1.StatusReasonInvalid, + Code: http.StatusUnprocessableEntity, + }, + }, + "oversized payload fails with 413 payload too large error": { + url: "/apis/trainer.kubeflow.org/v1alpha1/namespaces/default/trainjobs/test-job/status", + // Generate ~1MB payload (exceeds 64kB limit) + body: `{"trainerStatus": {"metrics": [` + strings.Repeat(`{"name":"m","value":"0.5"},`, 40000) + `]}}`, + authorized: true, + wantResponse: &metav1.Status{ + Status: metav1.StatusFailure, + Message: "Payload too large", + Reason: metav1.StatusReasonRequestEntityTooLarge, + Code: http.StatusRequestEntityTooLarge, + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + existingTrainJob := &trainer.TrainJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-job", + Namespace: "default", + }, + } + ts := newTestServer( + t, + &configapi.StatusServer{Port: ptr.To[int32](8080)}, + fakeAuthorizer{authorized: tc.authorized}, + existingTrainJob, + ) + defer ts.Close() + + // Make actual HTTP request + 
req, err := http.NewRequest("POST", ts.URL+tc.url, bytes.NewReader([]byte(tc.body))) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("HTTP POST failed: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close() }) + + if resp.StatusCode != int(tc.wantResponse.Code) { + t.Errorf("status = %v, want %v", resp.StatusCode, tc.wantResponse.Code) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var got metav1.Status + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if diff := cmp.Diff(tc.wantResponse, &got); diff != "" { + t.Errorf("response mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/statusserver/setup.go b/pkg/statusserver/setup.go new file mode 100644 index 0000000000..700809f5a2 --- /dev/null +++ b/pkg/statusserver/setup.go @@ -0,0 +1,73 @@ +/* +Copyright 2026 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package statusserver
+
+import (
+	"fmt"
+
+	"k8s.io/client-go/rest"
+	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1"
+	"github.com/kubeflow/trainer/v2/pkg/util/cert"
+)
+
+// SetupServer builds the runtime status server and registers it with mgr.
+//
+// cfg must be non-nil; it supplies the listen port and the optional QPS/Burst
+// overrides for the server's dedicated API client. enableHTTP2 is forwarded to
+// the TLS configuration.
+func SetupServer(mgr ctrl.Manager, cfg *configapi.StatusServer, enableHTTP2 bool) error {
+	// Validate before touching the manager: SetupTLSConfig registers a cert
+	// watcher as a side effect and createClient dereferences cfg.
+	if cfg == nil {
+		return fmt.Errorf("status server configuration is required")
+	}
+
+	tlsConfig, err := cert.SetupTLSConfig(mgr, enableHTTP2)
+	if err != nil {
+		return err
+	}
+
+	// Create a separate client with its own QPS/Burst limits
+	// to avoid impacting the main reconciler's rate limits
+	cli, err := createClient(mgr, cfg)
+	if err != nil {
+		return err
+	}
+
+	// Construct the token authorizer. Its OIDC provider is not contacted
+	// here; that happens in Init, which the server calls from Start.
+	// The provider will be used to create verifiers with TrainJob-specific audiences
+	authorizer := NewProjectedServiceAccountTokenAuthorizer(mgr.GetConfig())
+
+	server, err := NewServer(cli, cfg, tlsConfig, authorizer)
+	if err != nil {
+		return err
+	}
+	return mgr.Add(server)
+}
+
+// createClient returns an API client that shares the manager's scheme and REST
+// mapper but uses its own copy of the rest config, so the QPS/Burst values
+// from cfg do not affect the manager's client.
+func createClient(mgr ctrl.Manager, cfg *configapi.StatusServer) (client.Client, error) {
+	if cfg == nil {
+		return nil, fmt.Errorf("status server configuration is required")
+	}
+
+	// Copy the manager's rest config and override rate limits
+	mgrCfg := rest.CopyConfig(mgr.GetConfig())
+	if cfg.QPS != nil {
+		mgrCfg.QPS = *cfg.QPS
+	}
+	if cfg.Burst != nil {
+		mgrCfg.Burst = int(*cfg.Burst)
+	}
+
+	cli, err := client.New(mgrCfg, client.Options{
+		Scheme: mgr.GetScheme(),
+		Mapper: mgr.GetRESTMapper(),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create status server client: %w", err)
+	}
+
+	return cli, nil
+}
diff --git a/pkg/statusserver/utils.go b/pkg/statusserver/utils.go
new file mode 100644
index 0000000000..47ce55a2ca
--- /dev/null
+++ b/pkg/statusserver/utils.go
@@ -0,0 +1,34 @@
+/*
+Copyright 2026 The Kubeflow Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package statusserver + +import "fmt" + +const ( + // TokenAudiencePrefix is the prefix for projected service account token audiences + TokenAudiencePrefix = "trainer.kubeflow.org" +) + +// TokenAudience returns the required audience for a TrainJob's status endpoint. +func TokenAudience(namespace, name string) string { + return fmt.Sprintf("%s/v1alpha1/namespaces/%s/trainjobs/%s/status", TokenAudiencePrefix, namespace, name) +} + +// StatusUrl is the path of the endpoint for receiving status updates +func StatusUrl(namespace, name string) string { + return fmt.Sprintf("/apis/trainer.kubeflow.org/v1alpha1/namespaces/%s/trainjobs/%s/status", namespace, name) +} diff --git a/pkg/util/cert/cert.go b/pkg/util/cert/cert.go index 041d8ebb5c..321121024c 100644 --- a/pkg/util/cert/cert.go +++ b/pkg/util/cert/cert.go @@ -17,6 +17,7 @@ limitations under the License. package cert import ( + "crypto/tls" "fmt" "os" "strings" @@ -24,6 +25,7 @@ import ( cert "github.com/open-policy-agent/cert-controller/pkg/rotator" "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" ) const ( @@ -33,7 +35,7 @@ const ( defaultNamespace = "kubeflow-system" ) -func getOperatorNamespace() string { +func GetOperatorNamespace() string { if data, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"); err == nil { if ns := strings.TrimSpace(string(data)); len(ns) > 0 { return ns @@ -54,7 +56,7 @@ type Config struct { // ManageCerts creates all certs for webhooks. 
func ManageCerts(mgr ctrl.Manager, cfg Config, setupFinished chan struct{}) error { - ns := getOperatorNamespace() + ns := GetOperatorNamespace() // DNSName is ..svc dnsName := fmt.Sprintf("%s.%s.svc", cfg.WebhookServiceName, ns) @@ -77,3 +79,28 @@ func ManageCerts(mgr ctrl.Manager, cfg Config, setupFinished chan struct{}) erro RequireLeaderElection: false, }) } + +// SetupTLSConfig creates a TLS config with automatic certificate rotation support. +// It creates a cert watcher, adds it to the manager, and returns a TLS config +// that will automatically pick up rotated certificates. +func SetupTLSConfig(mgr ctrl.Manager, enableHTTP2 bool) (*tls.Config, error) { + certWatcher, err := certwatcher.New(certDir+"/tls.crt", certDir+"/tls.key") + if err != nil { + return nil, fmt.Errorf("error creating cert watcher: %w", err) + } + + if err := mgr.Add(certWatcher); err != nil { + return nil, fmt.Errorf("error adding cert watcher to manager: %w", err) + } + + tlsConfig := &tls.Config{ + GetCertificate: certWatcher.GetCertificate, + } + + // Disable HTTP/2 unless explicitly enabled (CVE-2023-44487, CVE-2023-39325) + if !enableHTTP2 { + tlsConfig.NextProtos = []string{"http/1.1"} + } + + return tlsConfig, nil +} diff --git a/pkg/webhooks/trainjob_webhook_test.go b/pkg/webhooks/trainjob_webhook_test.go index 01a24c8829..e4aa26c1b0 100644 --- a/pkg/webhooks/trainjob_webhook_test.go +++ b/pkg/webhooks/trainjob_webhook_test.go @@ -105,7 +105,8 @@ func TestValidateCreate(t *testing.T) { if tc.clusterTrainingRuntime != nil { clientBuilder = clientBuilder.WithObjects(tc.clusterTrainingRuntime) } - runtimes, err := runtimecore.New(context.Background(), clientBuilder.Build(), testingutil.AsIndex(clientBuilder)) + + runtimes, err := runtimecore.New(context.Background(), clientBuilder.Build(), testingutil.AsIndex(clientBuilder), nil) if err != nil { t.Fatal(err) } diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index e51feaf146..a72c76b6bf 100644 --- a/test/e2e/e2e_test.go +++ 
b/test/e2e/e2e_test.go @@ -14,6 +14,8 @@ import ( "github.com/kubeflow/trainer/v2/pkg/constants" testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing" "github.com/kubeflow/trainer/v2/test/util" + + _ "embed" ) const ( @@ -23,6 +25,9 @@ const ( xgboostRuntime = "xgboost-distributed" ) +//go:embed testdata/status_update.py +var statusUpdateScript string + var _ = ginkgo.Describe("TrainJob e2e", func() { // Each test runs in a separate namespace. var ns *corev1.Namespace @@ -394,4 +399,64 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) + + ginkgo.When("Creating TrainJob with runtime status server instrumentation", func() { + ginkgo.It("should inject runtime configuration which allows the runtime status endpoint to be called", func() { + // Create a TrainJob that sends a single runtime status update and exits + trainJob := testingutil.MakeTrainJobWrapper(ns.Name, "e2e-test-runtime-status"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), torchRuntime). + Trainer(&trainer.Trainer{ + Command: []string{"python3", "-c"}, + Args: []string{statusUpdateScript}, + }). 
+ Obj() + + ginkgo.By("Create a TrainJob that will call the runtime-status endpoint", func() { + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + }) + + ginkgo.By("Verify trainerStatus is updated with runtime status information", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + + // Verify trainerStatus is not nil + g.Expect(gotTrainJob.Status.TrainerStatus).ShouldNot(gomega.BeNil()) + + // Verify progress percentage + g.Expect(gotTrainJob.Status.TrainerStatus.ProgressPercentage).ShouldNot(gomega.BeNil()) + g.Expect(*gotTrainJob.Status.TrainerStatus.ProgressPercentage).Should(gomega.Equal(int32(42))) + + // Verify estimated remaining seconds + g.Expect(gotTrainJob.Status.TrainerStatus.EstimatedRemainingSeconds).ShouldNot(gomega.BeNil()) + g.Expect(*gotTrainJob.Status.TrainerStatus.EstimatedRemainingSeconds).Should(gomega.Equal(int32(120))) + + // Verify metrics + g.Expect(gotTrainJob.Status.TrainerStatus.Metrics).Should(gomega.HaveLen(2)) + g.Expect(gotTrainJob.Status.TrainerStatus.Metrics[0].Name).Should(gomega.Equal("loss")) + g.Expect(gotTrainJob.Status.TrainerStatus.Metrics[0].Value).Should(gomega.Equal("0.123")) + g.Expect(gotTrainJob.Status.TrainerStatus.Metrics[1].Name).Should(gomega.Equal("accuracy")) + g.Expect(gotTrainJob.Status.TrainerStatus.Metrics[1].Value).Should(gomega.Equal("0.95")) + + // Verify lastUpdatedTime is set + g.Expect(gotTrainJob.Status.TrainerStatus.LastUpdatedTime.IsZero()).Should(gomega.BeFalse()) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Wait for TrainJob to be in Succeeded status", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + 
g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{ + { + Type: trainer.TrainJobComplete, + Status: metav1.ConditionTrue, + Reason: jobsetconsts.AllJobsCompletedReason, + Message: jobsetconsts.AllJobsCompletedMessage, + }, + }, util.IgnoreConditions)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + }) + }) }) diff --git a/test/e2e/testdata/status_update.py b/test/e2e/testdata/status_update.py new file mode 100644 index 0000000000..28bf8d68ec --- /dev/null +++ b/test/e2e/testdata/status_update.py @@ -0,0 +1,53 @@ +""" +status_update.py + +E2E test script for TrainJob that updates the runtime status using the status server endpoint. + +This script validates that the runtime status endpoint can be called from within +a training container. It reads the status server URL, CA certificate, and service +account token from environment variables injected by the Status plugin, then sends +a single status update with test metrics to verify the TrainJob status is updated. 
+ +Environment variables required: +- KUBEFLOW_TRAINER_SERVER_URL: HTTPS URL for the status server endpoint +- KUBEFLOW_TRAINER_SERVER_CA_CERT: Path to CA certificate file for TLS verification +- KUBEFLOW_TRAINER_SERVER_TOKEN: Path to service account token file for authentication +""" + +import json +import os +import ssl +from datetime import datetime, timezone +from urllib import error, request + +url = os.environ["KUBEFLOW_TRAINER_SERVER_URL"] +ca_file = os.environ["KUBEFLOW_TRAINER_SERVER_CA_CERT"] +token = open(os.environ["KUBEFLOW_TRAINER_SERVER_TOKEN"]).read().strip() +ssl_context = ssl.create_default_context(cafile=ca_file) + +# Send a single status update +payload = { + "trainerStatus": { + "progressPercentage": 42, + "estimatedRemainingSeconds": 120, + "metrics": [ + {"name": "loss", "value": "0.123"}, + {"name": "accuracy", "value": "0.95"}, + ], + "lastUpdatedTime": datetime.now(timezone.utc).isoformat(), + } +} +data = json.dumps(payload) +req = request.Request( + url, data=data.encode("utf-8"), headers={"Authorization": f"Bearer {token}"} +) + +try: + resp = request.urlopen(req, context=ssl_context) +except error.HTTPError as ex: + body = ex.read().decode("utf-8", errors="replace") + print(f"Failed to update trainer status. 
{ex} {body}") + raise +else: + body = resp.read().decode("utf-8") + print(f"Success updating trainer status: {resp.getcode()} {body}") diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index 768cf68427..0185a5b70d 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -101,7 +101,7 @@ func (f *Framework) RunManager(cfg *rest.Config, startControllers bool) (context }) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "failed to create manager") - runtimes, err := runtimecore.New(ctx, mgr.GetClient(), mgr.GetFieldIndexer()) + runtimes, err := runtimecore.New(ctx, mgr.GetClient(), mgr.GetFieldIndexer(), nil) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) gomega.ExpectWithOffset(1, runtimes).NotTo(gomega.BeNil())