Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ deploy-rhoai: ## Deploy operator using RHOAI manifests with kustomize
crd/trainingruntimes.trainer.kubeflow.org \
crd/trainjobs.trainer.kubeflow.org 2>/dev/null || sleep 5
@echo "Applying operator and resources..."
@$(KUBECTL) apply --server-side=true -k $(RHOAI_MANIFESTS_DIR)
@$(KUBECTL) apply -k $(RHOAI_MANIFESTS_DIR) --server-side=true --force-conflicts
@echo "Waiting for deployment to be ready..."
@$(KUBECTL) wait --for=condition=available --timeout=300s \
deployment/kubeflow-trainer-controller-manager -n $(NAMESPACE) 2>/dev/null || true
Expand Down
7 changes: 7 additions & 0 deletions manifests/rhoai/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ patches:
# through a ComponentConfig type
- path: manager_config_patch.yaml
- path: manager_metrics_patch.yaml
# RHAI-specific RBAC for progression tracking
- path: rbac_progression_patch.yaml
target:
group: rbac.authorization.k8s.io
version: v1
kind: ClusterRole
name: kubeflow-trainer-controller-manager
- patch: |-
- op: remove
path: /spec/ports/0
Expand Down
3 changes: 3 additions & 0 deletions manifests/rhoai/manager_config_patch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ spec:
containers:
- name: manager
image: $(image)
env:
- name: ENABLE_RHAI_FEATURES
value: "true"
args:
- --config=/controller_manager_config.yaml
- --zap-log-level=2
Expand Down
13 changes: 13 additions & 0 deletions manifests/rhoai/rbac_progression_patch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# RHAI-specific: Permissions for progression tracking
# Allows the controller to list/get pods to read IPs for HTTP metrics polling
# Note: watch not needed - controller polls during TrainJob reconciliation, not on pod events
- op: add
path: /rules/-
value:
apiGroups:
- ""
resources:
- pods
verbs:
- get
- list
1 change: 1 addition & 0 deletions pkg/controller/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime, opt
}
if err := NewTrainJobReconciler(
mgr.GetClient(),
mgr.GetAPIReader(),
mgr.GetEventRecorderFor("trainer-trainjob-controller"),
runtimes,
WithWatchers(runtimeRec, clRuntimeRec),
Expand Down
30 changes: 18 additions & 12 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (

trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/v2/pkg/constants"
"github.com/kubeflow/trainer/v2/pkg/rhai/progression"
jobruntimes "github.com/kubeflow/trainer/v2/pkg/runtime"
)

Expand All @@ -51,11 +52,12 @@ type TrainJobWatcher interface {
}

type TrainJobReconciler struct {
log logr.Logger
client client.Client
recorder record.EventRecorder
runtimes map[string]jobruntimes.Runtime
watchers iter.Seq[TrainJobWatcher]
log logr.Logger
client client.Client
apiReader client.Reader
recorder record.EventRecorder
runtimes map[string]jobruntimes.Runtime
watchers iter.Seq[TrainJobWatcher]
}

type TrainJobReconcilerOptions struct {
Expand All @@ -73,17 +75,18 @@ func WithWatchers(watchers ...TrainJobWatcher) TrainJobReconcilerOption {
var _ reconcile.Reconciler = (*TrainJobReconciler)(nil)
var _ predicate.TypedPredicate[*trainer.TrainJob] = (*TrainJobReconciler)(nil)

func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runtimes map[string]jobruntimes.Runtime, opts ...TrainJobReconcilerOption) *TrainJobReconciler {
func NewTrainJobReconciler(client client.Client, apiReader client.Reader, recorder record.EventRecorder, runtimes map[string]jobruntimes.Runtime, opts ...TrainJobReconcilerOption) *TrainJobReconciler {
options := &TrainJobReconcilerOptions{}
for _, opt := range opts {
opt(options)
}
return &TrainJobReconciler{
log: ctrl.Log.WithName("trainjob-controller"),
client: client,
recorder: recorder,
runtimes: runtimes,
watchers: options.Watchers,
log: ctrl.Log.WithName("trainjob-controller"),
client: client,
apiReader: apiReader,
recorder: recorder,
runtimes: runtimes,
watchers: options.Watchers,
}
}

Expand Down Expand Up @@ -139,7 +142,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) {
return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob))
}
return ctrl.Result{}, err

// RHAI progression tracking (use APIReader to avoid pod watches)
result, progressionErr := progression.ReconcileProgression(ctx, r.client, r.apiReader, log, &trainJob)
return result, errors.Join(err, progressionErr)
}

func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
Expand Down
69 changes: 69 additions & 0 deletions pkg/rhai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# RHAI (Red Hat AI) Extensions

This directory contains RHAI-specific extensions for the Kubeflow Trainer operator.

## Purpose

The `rhai/` package provides midstream-specific features that are not part of upstream Kubeflow:
- **Progression tracking**: Real-time training metrics polling and status updates
- **Custom annotations**: RHAI-specific metadata for training jobs
- **Extended RBAC**: Additional permissions for pod access

## Structure

```
pkg/rhai/
├── constants/ # RHAI-specific constants and annotations
├── progression/ # Core progression tracking logic and tests
└── test/ # End-to-end tests
```

## Integration

Progression tracking is integrated into `TrainJobReconciler` with 2 lines:

```go
result, progressionErr := progression.ReconcileProgression(ctx, r.client, r.apiReader, log, &trainJob)
return result, errors.Join(err, progressionErr)
```

- All RHAI logic isolated in `pkg/rhai/`
- Enabled per-TrainJob via annotation
- No-op when disabled

## Usage

Enable progression tracking via annotation:

```yaml
metadata:
annotations:
trainer.opendatahub.io/progression-tracking: "enabled"
trainer.opendatahub.io/metrics-port: "28080" # optional
trainer.opendatahub.io/metrics-poll-interval: "30s" # optional
```

Your training container exposes metrics at `http://localhost:28080/metrics`:

```json
{
"progressPercentage": 45,
"currentStep": 450,
"totalSteps": 1000,
"trainMetrics": {"loss": 0.235}
}
```

The controller updates the `trainer.opendatahub.io/trainerStatus` annotation with the latest progress.

Users can monitor progress and the reported training metrics in real time with the `watch` command below:

```bash
watch -n 2 'kubectl get trainjob <job-name> -n <namespace> -o jsonpath="{.metadata.annotations.trainer\.opendatahub\.io/trainerStatus}" | jq'
```

## Development

```bash
go test ./pkg/rhai/test/... -v -timeout 30m -ginkgo.v -ginkgo.progress
```
63 changes: 63 additions & 0 deletions pkg/rhai/constants/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright 2024 The Kubeflow Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package constants contains shared constants for all RHAI features.
package constants

// Annotation keys used by the progression tracking feature on a TrainJob.
const (
	// AnnotationProgressionTracking is the per-TrainJob opt-in switch.
	// Tracking is on only when the value is "enabled"; a missing annotation
	// or any other value leaves it off.
	// Example: trainer.opendatahub.io/progression-tracking: "enabled"
	AnnotationProgressionTracking string = "trainer.opendatahub.io/progression-tracking"

	// AnnotationTrainerStatus holds the training status/progress as JSON.
	// The controller writes this annotation itself, refreshing it with
	// real-time metrics.
	// Example: trainer.opendatahub.io/trainerStatus: '{"status":"training","progress":{"percent":45.2},...}'
	AnnotationTrainerStatus string = "trainer.opendatahub.io/trainerStatus"

	// AnnotationMetricsPort names the port on which the training pod exposes
	// its metrics. When unset, DefaultMetricsPort (28080) applies.
	// Example: trainer.opendatahub.io/metrics-port: "8080"
	AnnotationMetricsPort string = "trainer.opendatahub.io/metrics-port"

	// AnnotationMetricsPollInterval sets how often metrics are polled.
	// The value may be a duration string ("30s", "1m") or a bare number of
	// seconds ("30"); the allowed range is 5s to 300s and the default is 30s.
	// Example: trainer.opendatahub.io/metrics-poll-interval: "45s"
	AnnotationMetricsPollInterval string = "trainer.opendatahub.io/metrics-poll-interval"
)

// Defaults, limits and timing buffers for metrics polling and pod shutdown.
const (
	// DefaultMetricsPort is the metrics endpoint port assumed for training
	// pods when none is specified.
	DefaultMetricsPort string = "28080"

	// DefaultMetricsPollIntervalSecs is the polling interval, in seconds,
	// used when none is specified.
	DefaultMetricsPollIntervalSecs int = 30

	// MinMetricsPollIntervalSecs is the lower bound on the poll interval,
	// in seconds; it guards the controller against excessive load.
	MinMetricsPollIntervalSecs int = 5

	// MaxMetricsPollIntervalSecs is the upper bound on the poll interval,
	// in seconds; it keeps tracking responsive.
	MaxMetricsPollIntervalSecs int = 300

	// PreStopBufferSecs is the extra time, in seconds, added on top of
	// (2 × poll interval) to obtain the preStop hook duration, so that at
	// least two polls can happen after training completes.
	PreStopBufferSecs int = 10

	// TerminationGraceBufferSecs is the extra time, in seconds, added to the
	// preStop duration to obtain the pod termination grace period, leaving
	// room for graceful process shutdown once the preStop hook finishes.
	TerminationGraceBufferSecs int = 30
)
Loading
Loading