Skip to content

Commit

Permalink
fix: k8s spec not read correctly (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
vijayvammi authored Feb 15, 2025
1 parent fedf29b commit 5bf10c6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
24 changes: 12 additions & 12 deletions examples/11-jobs/k8s-job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ job-executor:
config:
pvc_claim_name: runnable
config_path:
mock: false
mock: true
namespace: enterprise-mlops
jobSpec:
# activeDeadlineSeconds: Optional[int]
activeDeadlineSeconds: 32000
# selector: Optional[LabelSelector]
# ttlSecondsAfterFinished: Optional[int]
template:
# metadata:
# annotations: Optional[Dict[str, str]]
# generate_name: Optional[str] = run_id
spec:
# activeDeadlineSeconds: Optional[int]
activeDeadlineSeconds: 86400
# nodeSelector: Optional[Dict[str, str]]
# tolerations: Optional[List[Toleration]]
# volumes:
Expand All @@ -30,15 +30,15 @@ job-executor:
# value: str
image: harbor.csis.astrazeneca.net/mlops/runnable:latest
# imagePullPolicy: Optional[str] = choose from [Always, Never, IfNotPresent]
# resources:
# limits:
# cpu: str
# memory: str
# gpu: str
# requests:
# cpu: str
# memory: str
# gpu: str
resources:
limits:
cpu: "$"
memory: "3"
gpu: "1"
# requests:
# cpu: "3"
# memory: "3"
# gpu: "1"
# volumeMounts:
# - name: str
# mountPath: str
61 changes: 37 additions & 24 deletions extensions/job_executor/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,42 @@ class TolerationOperator(str, Enum):
EQUAL = "Equal"


class Toleration(BaseModel):
class BaseModelWIthConfig(BaseModel, use_enum_values=True):
    """Shared pydantic base for every Kubernetes spec model in this module.

    - ``use_enum_values=True``: enum-typed fields are stored/serialized as
      their raw values rather than Enum members.
    - ``extra="forbid"``: unknown keys in an incoming spec raise a
      validation error instead of being silently dropped.
    - ``alias_generator=to_camel``: snake_case field names get camelCase
      aliases, matching Kubernetes manifest conventions (e.g. the
      ``activeDeadlineSeconds`` / ``jobSpec`` keys seen in the YAML config).
    - ``populate_by_name=True``: fields may also be populated by their
      Python (snake_case) names, not only by alias.
    """

    # NOTE(review): "WIth" in the class name is a typo; renaming would touch
    # every subclass in this file, so it is left unchanged here.
    model_config = ConfigDict(
        extra="forbid",
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
        validate_default=True,
    )


class Toleration(BaseModelWIthConfig):
    """A single pod toleration; dumped into ``V1Toleration`` kwargs downstream."""

    key: str
    # Defaults to "Equal"; stored as the plain string value because the base
    # model sets use_enum_values=True.
    operator: TolerationOperator = TolerationOperator.EQUAL
    # Optional type but NO default: callers must still supply a value
    # (possibly None) — presumably intentional; TODO confirm.
    value: Optional[str]
    effect: str
    # Accepts the camelCase alias "tolerationSeconds" via the base config.
    toleration_seconds: Optional[int] = Field(default=None)


class LabelSelectorRequirement(BaseModel):
class LabelSelectorRequirement(BaseModelWIthConfig):
    """One expression of a label selector: ``key <operator> values``."""

    key: str
    # Operator enum is defined elsewhere in this file (not visible here).
    operator: Operator
    values: list[str]


class LabelSelector(BaseModel):
class LabelSelector(BaseModelWIthConfig):
    """Kubernetes label selector; both fields are required (no defaults)."""

    # Accepts/serializes as matchExpressions / matchLabels via the camelCase
    # alias generator on the base model.
    match_expressions: list[LabelSelectorRequirement]
    match_labels: dict[str, str]


class ObjectMetaData(BaseModel):
class ObjectMetaData(BaseModelWIthConfig):
    """Subset of Kubernetes object metadata used for the job/pod template."""

    # Optional types but NO defaults on the first two fields: a value
    # (possibly None) must be provided explicitly — TODO confirm intended.
    generate_name: Optional[str]
    annotations: Optional[dict[str, str]]
    namespace: Optional[str] = "default"


class EnvVar(BaseModel):
class EnvVar(BaseModelWIthConfig):
    """Name/value environment variable injected into the job container
    (dumped into ``V1EnvVar`` kwargs in ``submit_k8s_job``)."""

    name: str
    value: str

Expand All @@ -75,7 +85,7 @@ class EnvVar(BaseModel):
]


class Request(BaseModel):
class Request(BaseModelWIthConfig):
"""
The default requests
"""
Expand All @@ -85,7 +95,7 @@ class Request(BaseModel):
gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")


class Limit(BaseModel):
class Limit(BaseModelWIthConfig):
"""
The default limits
"""
Expand All @@ -95,34 +105,34 @@ class Limit(BaseModel):
gpu: VendorGPU = Field(default=None, serialization_alias="nvidia.com/gpu")


class Resources(BaseModel):
class Resources(BaseModelWIthConfig):
limits: Limit = Limit()
requests: Request = Request()
requests: Optional[Request] = Field(default=None)


class VolumeMount(BaseModel):
class VolumeMount(BaseModelWIthConfig):
    """Mount point of a named volume inside the container."""

    name: str
    # Accepts/serializes as "mountPath" via the camelCase alias generator.
    mount_path: str


class Container(BaseModel):
class Container(BaseModelWIthConfig):
image: str
env: list[EnvVar] = Field(default_factory=list)
image_pull_policy: ImagePullPolicy = ImagePullPolicy.NEVER
image_pull_policy: ImagePullPolicy = Field(default=ImagePullPolicy.NEVER)
resources: Resources = Resources()
volume_mounts: Optional[list[VolumeMount]] = Field(default_factory=lambda: [])


class HostPath(BaseModel):
class HostPath(BaseModelWIthConfig):
    """Filesystem path on the node backing a hostPath volume."""

    path: str


class HostPathVolume(BaseModel):
class HostPathVolume(BaseModelWIthConfig):
    """A named volume backed by a path on the node (hostPath)."""

    name: str
    # Accepts/serializes as "hostPath" via the camelCase alias generator.
    host_path: HostPath


class PVCClaim(BaseModel):
class PVCClaim(BaseModelWIthConfig):
claim_name: str

model_config = ConfigDict(
Expand All @@ -132,12 +142,12 @@ class PVCClaim(BaseModel):
)


class PVCVolume(BaseModel):
class PVCVolume(BaseModelWIthConfig):
    """A named volume backed by a PersistentVolumeClaim."""

    name: str
    # Accepts/serializes as "persistentVolumeClaim" via the camelCase alias.
    persistent_volume_claim: PVCClaim


class K8sTemplateSpec(BaseModel):
class K8sTemplateSpec(BaseModelWIthConfig):
active_deadline_seconds: int = Field(default=60 * 60 * 2) # 2 hours
node_selector: Optional[dict[str, str]] = None
tolerations: Optional[list[Toleration]] = None
Expand All @@ -149,12 +159,12 @@ class K8sTemplateSpec(BaseModel):
container: Container


class K8sTemplate(BaseModel):
class K8sTemplate(BaseModelWIthConfig):
    """Pod template for the job: the pod spec plus optional metadata."""

    spec: K8sTemplateSpec
    metadata: Optional[ObjectMetaData] = None


class Spec(BaseModel):
class Spec(BaseModelWIthConfig):
active_deadline_seconds: Optional[int] = Field(default=60 * 60 * 2) # 2 hours
backoff_limit: int = 6
selector: Optional[LabelSelector] = None
Expand Down Expand Up @@ -251,7 +261,7 @@ def submit_k8s_job(self, task: BaseTaskType):
command = utils.get_job_execution_command()

container_env = [
self._client.V1EnvVar(**env.model_dump(by_alias=True))
self._client.V1EnvVar(**env.model_dump())
for env in self.job_spec.template.spec.container.env
]

Expand All @@ -260,23 +270,26 @@ def submit_k8s_job(self, task: BaseTaskType):
env=container_env,
name="default",
volume_mounts=container_volume_mounts,
resources=self.job_spec.template.spec.container.resources.model_dump(
by_alias=True, exclude_none=True
),
**self.job_spec.template.spec.container.model_dump(
exclude_none=True, exclude={"volume_mounts", "command", "env"}
exclude_none=True,
exclude={"volume_mounts", "command", "env", "resources"},
),
)

if self.job_spec.template.spec.volumes:
self._volumes += self.job_spec.template.spec.volumes

spec_volumes = [
self._client.V1Volume(**vol.model_dump(by_alias=True))
for vol in self._volumes
self._client.V1Volume(**vol.model_dump()) for vol in self._volumes
]

tolerations = None
if self.job_spec.template.spec.tolerations:
tolerations = [
self._client.V1Toleration(**toleration.model_dump(by_alias=True))
self._client.V1Toleration(**toleration.model_dump())
for toleration in self.job_spec.template.spec.tolerations
]

Expand Down

0 comments on commit 5bf10c6

Please sign in to comment.