Skip to content

Commit b22f653

Browse files
feat: Add training progression tracking feature for experimental implementation
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 9212662 commit b22f653

20 files changed

Lines changed: 3122 additions & 5 deletions

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ deploy-rhoai: ## Deploy operator using RHOAI manifests with kustomize
254254
crd/trainingruntimes.trainer.kubeflow.org \
255255
crd/trainjobs.trainer.kubeflow.org 2>/dev/null || sleep 5
256256
@echo "Applying operator and resources..."
257-
@$(KUBECTL) apply --server-side=true -k $(RHOAI_MANIFESTS_DIR)
257+
@$(KUBECTL) apply -k $(RHOAI_MANIFESTS_DIR) --server-side=true --force-conflicts
258258
@echo "Waiting for deployment to be ready..."
259259
@$(KUBECTL) wait --for=condition=available --timeout=300s \
260260
deployment/kubeflow-trainer-controller-manager -n $(NAMESPACE) 2>/dev/null || true

cmd/trainer-controller-manager/main.go

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040
trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
4141
"github.com/kubeflow/trainer/v2/pkg/config"
4242
"github.com/kubeflow/trainer/v2/pkg/controller"
43+
rhaisetup "github.com/kubeflow/trainer/v2/pkg/rhai"
4344
"github.com/kubeflow/trainer/v2/pkg/runtime"
4445
runtimecore "github.com/kubeflow/trainer/v2/pkg/runtime/core"
4546
"github.com/kubeflow/trainer/v2/pkg/util/cert"
@@ -150,9 +151,47 @@ func setupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime, cer
150151
<-certsReady
151152
setupLog.Info("Certs ready")
152153

153-
if failedCtrlName, err := controller.SetupControllers(mgr, runtimes, ctrlpkg.Options{}); err != nil {
154-
setupLog.Error(err, "Could not create controller", "controller", failedCtrlName)
155-
os.Exit(1)
154+
// Setup controllers with optional RHAI features (decorator pattern for midstream extensions)
155+
rhaiEnabled := os.Getenv("ENABLE_RHAI_FEATURES") == "true"
156+
if rhaiEnabled {
157+
setupLog.Info("RHAI features enabled")
158+
159+
// Setup TrainingRuntime controller
160+
runtimeRec := controller.NewTrainingRuntimeReconciler(
161+
mgr.GetClient(),
162+
mgr.GetEventRecorderFor("trainer-trainingruntime-controller"),
163+
)
164+
if err := runtimeRec.SetupWithManager(mgr, ctrlpkg.Options{}); err != nil {
165+
setupLog.Error(err, "Could not create controller", "controller", "TrainingRuntime")
166+
os.Exit(1)
167+
}
168+
169+
// Setup ClusterTrainingRuntime controller
170+
clRuntimeRec := controller.NewClusterTrainingRuntimeReconciler(
171+
mgr.GetClient(),
172+
mgr.GetEventRecorderFor("trainer-clustertrainingruntime-controller"),
173+
)
174+
if err := clRuntimeRec.SetupWithManager(mgr, ctrlpkg.Options{}); err != nil {
175+
setupLog.Error(err, "Could not create controller", "controller", "ClusterTrainingRuntime")
176+
os.Exit(1)
177+
}
178+
179+
// Wrap base TrainJob reconciler with RHAI progression tracking
180+
baseReconciler := controller.NewTrainJobReconciler(
181+
mgr.GetClient(),
182+
mgr.GetEventRecorderFor("trainer-trainjob-controller"),
183+
runtimes,
184+
controller.WithWatchers(runtimeRec, clRuntimeRec),
185+
)
186+
if err := rhaisetup.SetupWithManager(mgr, baseReconciler); err != nil {
187+
setupLog.Error(err, "Could not setup RHAI features")
188+
os.Exit(1)
189+
}
190+
} else {
191+
if failedCtrlName, err := controller.SetupControllers(mgr, runtimes, ctrlpkg.Options{}); err != nil {
192+
setupLog.Error(err, "Could not create controller", "controller", failedCtrlName)
193+
os.Exit(1)
194+
}
156195
}
157196
if failedWebhook, err := webhooks.Setup(mgr, runtimes); err != nil {
158197
setupLog.Error(err, "Could not create webhook", "webhook", failedWebhook)

manifests/rhoai/kustomization.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ patches:
5454
# through a ComponentConfig type
5555
- path: manager_config_patch.yaml
5656
- path: manager_metrics_patch.yaml
57+
# RHAI-specific RBAC for progression tracking
58+
- path: rbac_progression_patch.yaml
59+
target:
60+
group: rbac.authorization.k8s.io
61+
version: v1
62+
kind: ClusterRole
63+
name: kubeflow-trainer-controller-manager
5764
- patch: |-
5865
- op: remove
5966
path: /spec/ports/0

manifests/rhoai/manager_config_patch.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ spec:
88
containers:
99
- name: manager
1010
image: $(image)
11+
env:
12+
- name: ENABLE_RHAI_FEATURES
13+
value: "true"
1114
args:
1215
- --config=/controller_manager_config.yaml
1316
- --zap-log-level=2

manifests/rhoai/params.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
odh-kubeflow-trainer-controller-image=quay.io/opendatahub/trainer:v2.1.0
1+
odh-kubeflow-trainer-controller-image=quay.io/abdhumal/trainer:v2.1.0-rhai-progression-18Nov-2
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# RHAI-specific: Permissions for progression tracking
2+
# Grants the controller full access to pods (create/delete/get/list/patch/update/watch) and permission to exec into them (pods/exec) for metrics polling
3+
- op: add
4+
path: /rules/-
5+
value:
6+
apiGroups:
7+
- ""
8+
resources:
9+
- pods
10+
verbs:
11+
- create
12+
- delete
13+
- get
14+
- list
15+
- patch
16+
- update
17+
- watch
18+
- op: add
19+
path: /rules/-
20+
value:
21+
apiGroups:
22+
- ""
23+
resources:
24+
- pods/exec
25+
verbs:
26+
- create

pkg/rhai/README.md

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# RHAI (Red Hat AI) Extensions
2+
3+
This directory contains RHAI-specific extensions for the Kubeflow Trainer operator.
4+
5+
## Purpose
6+
7+
The `rhai/` package provides midstream-specific features that are not part of upstream Kubeflow:
8+
- **Progression tracking**: Real-time training metrics polling and status updates
9+
- **Custom annotations**: RHAI-specific metadata for training jobs
10+
- **Extended RBAC**: Additional permissions on `pods` and `pods/exec` needed for metrics polling
11+
12+
## Structure
13+
14+
```
15+
pkg/rhai/
16+
├── README.md # This file
17+
├── setup.go # RHAI feature registration
18+
├── controller/
19+
│ └── progression_controller.go # Wraps base controller with progression tracking
20+
└── progression/
21+
├── progression.go # Core progression tracking logic
22+
└── progression_test.go # Tests for progression tracking
23+
```
24+
25+
## How It Works
26+
27+
### 1. Controller Wrapping
28+
29+
The `ProgressionReconciler` wraps the base `TrainJobReconciler` and adds:
30+
- Metrics polling from training pods
31+
- Progress annotation updates
32+
- Automatic requeuing for ongoing polling
33+
34+
### 2. Progression Tracking
35+
36+
When enabled via annotation `trainer.opendatahub.io/progression-tracking: "enabled"`:
37+
- Controller polls training pod's metrics endpoint (default port: 28080)
38+
- Updates TrainJob annotations with real-time progress
39+
- Captures final metrics on job completion/failure
40+
41+
### 3. Manifest Integration
42+
43+
RHAI-specific manifests in `manifests/rhoai/`:
44+
- `rbac_progression_patch.yaml`: Additional RBAC for pod access
45+
- `manager_config_patch.yaml`: Sets the `ENABLE_RHAI_FEATURES` environment variable on the manager container
46+
47+
## Enabling RHAI Features
48+
49+
RHAI features are controlled via the `ENABLE_RHAI_FEATURES` environment variable. When enabled, the operator uses a wrapping controller that adds progression tracking to the base upstream functionality.
50+
51+
### Deployment Configuration
52+
53+
**For Kubernetes/OpenShift deployments**, set the environment variable in your deployment manifest:
54+
55+
```yaml
56+
# manifests/rhoai/manager_config_patch.yaml
57+
apiVersion: apps/v1
58+
kind: Deployment
59+
spec:
60+
template:
61+
spec:
62+
containers:
63+
- name: manager
64+
env:
65+
- name: ENABLE_RHAI_FEATURES
66+
value: "true"
67+
```
68+
69+
**For local development**, export the variable before running:
70+
71+
```bash
72+
export ENABLE_RHAI_FEATURES=true
73+
go run ./cmd/trainer-controller-manager/main.go
74+
```
75+
76+
## Usage Example
77+
78+
Create a TrainJob with progression tracking:
79+
80+
```yaml
81+
apiVersion: trainer.kubeflow.org/v1alpha1
82+
kind: TrainJob
83+
metadata:
84+
name: pytorch-example
85+
annotations:
86+
trainer.opendatahub.io/progression-tracking: "enabled"
87+
trainer.opendatahub.io/metrics-port: "28080" # optional, default: 28080
88+
trainer.opendatahub.io/metrics-poll-interval: "30s" # optional, default: 30s
89+
spec:
90+
# ... your training job spec ...
91+
```
92+
93+
The controller will:
94+
1. Poll the primary pod's metrics endpoint every 30s by default (configurable via the `trainer.opendatahub.io/metrics-poll-interval` annotation)
95+
2. Update the `trainer.opendatahub.io/trainerStatus` annotation with:
96+
- Progress percentage
97+
- Current step/epoch
98+
- Loss and learning rate
99+
- Time elapsed/remaining
100+
- Custom metrics
101+
3. Capture final status when job completes
102+
103+
## Development
104+
105+
### Running Tests
106+
107+
```bash
108+
go test ./pkg/rhai/...
109+
```
110+
111+
### Adding New RHAI Features
112+
113+
1. Create new package under `pkg/rhai/yourfeature/`
114+
2. Add controller wrapper if needed in `pkg/rhai/controller/`
115+
3. Update `pkg/rhai/setup.go` to register the feature
116+
4. Add manifest patches in `manifests/rhoai/`
117+
5. Document in this README
118+
119+
## Maintenance
120+
121+
When rebasing from upstream:
122+
1. Pull upstream changes: `git pull upstream master`
123+
2. Rebase: `git rebase upstream/master`
124+
3. `pkg/rhai/` should auto-merge (no conflicts expected)
125+
4. Review controller integration in `main.go` if base controller changed
126+
5. Run tests: `make test`

pkg/rhai/constants/constants.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
Copyright 2024 The Kubeflow Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
// Package constants contains shared constants for all RHAI features.
18+
package constants
19+
20+
const (
21+
// Annotation keys and default values for the progression tracking feature.
22+
23+
// AnnotationProgressionTracking enables/disables progression tracking for a TrainJob.
24+
// Value: "enabled" to enable tracking, any other value or absence disables it.
25+
// Example: trainer.opendatahub.io/progression-tracking: "enabled"
26+
AnnotationProgressionTracking string = "trainer.opendatahub.io/progression-tracking"
27+
28+
// AnnotationTrainerStatus stores the JSON-encoded training status/progress.
29+
// This annotation is automatically updated by the controller with real-time metrics.
30+
// Example: trainer.opendatahub.io/trainerStatus: '{"status":"training","progress":{"percent":45.2},...}'
31+
AnnotationTrainerStatus string = "trainer.opendatahub.io/trainerStatus"
32+
33+
// AnnotationMetricsPort specifies the port where the training pod exposes metrics.
34+
// Default: 28080 (see DefaultMetricsPort).
35+
// Example: trainer.opendatahub.io/metrics-port: "8080"
36+
AnnotationMetricsPort string = "trainer.opendatahub.io/metrics-port"
37+
38+
// AnnotationMetricsPollInterval specifies how often to poll metrics (supports duration format).
39+
// Accepts: "30s", "1m", or integer seconds "30" (min: 5s, max: 300s)
40+
// Default: 30s (see DefaultMetricsPollIntervalSecs).
41+
// Example: trainer.opendatahub.io/metrics-poll-interval: "45s"
42+
AnnotationMetricsPollInterval string = "trainer.opendatahub.io/metrics-poll-interval"
43+
44+
// AnnotationFramework specifies the training framework (for framework-specific handling).
45+
// Example: trainer.opendatahub.io/framework: "pytorch"
46+
AnnotationFramework string = "trainer.opendatahub.io/framework"
47+
48+
// DefaultMetricsPort is the default port for metrics endpoints in training pods (string-typed, matching the AnnotationMetricsPort value format).
49+
DefaultMetricsPort string = "28080"
50+
51+
// DefaultMetricsPollIntervalSecs is the default interval (in seconds) for polling training metrics; it matches the 30s default documented on AnnotationMetricsPollInterval.
52+
DefaultMetricsPollIntervalSecs int = 30
53+
)

0 commit comments

Comments
 (0)