feat: add support for tracking TrainJob progress and training metrics #3227
andreyvelich merged 65 commits into kubeflow:master
Conversation
cc @andreyvelich @tenzen-y @astefanutti @akshaychitneni The progress tracking implementation is ready for an initial review once you have bandwidth. I still have some bits to work through (I'm updating the task list as I go along), but I'd appreciate any early feedback on the approach I've taken. Particular areas I'd be keen for feedback on:
Pull request overview
Implements the TrainJobProgress KEP by adding a progress/metrics reporting surface to TrainJob and wiring an authenticated HTTPS “progress server” plus a runtime plugin that injects the required runtime configuration into training pods.
Changes:
- Add `status.trainerStatus` (progress %, ETA, metrics, timestamp) to the TrainJob API and generated clients/specs.
- Introduce the `TrainJobProgress` alpha feature gate and a controller-side HTTPS progress server with authn/authz and a rate-limited client.
- Add a runtime framework "progress" plugin that injects env vars, a projected SA token, and a CA bundle ConfigMap mount into trainer pods; add e2e coverage.
Reviewed changes
Copilot reviewed 74 out of 75 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| test/integration/framework/framework.go | Pass default controller configuration into runtime initialization for integration tests. |
| test/e2e/testdata/progress.py | E2E helper script that posts a progress payload from inside a training container. |
| test/e2e/e2e_test.go | Adds an e2e that validates trainerStatus gets updated via the progress endpoint. |
| pkg/webhooks/trainjob_webhook_test.go | Updates webhook tests for new runtime initialization signature (config param). |
| pkg/util/testing/config.go | Adds a test helper to build a defaulted Configuration object. |
| pkg/util/runtime/runtime.go | Adds shared helper to detect operator namespace (with local default). |
| pkg/util/cert/cert.go | Refactors namespace lookup and adds cert-watcher-backed TLS config helper. |
| pkg/runtime/framework/plugins/volcano/volcano_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/volcano/volcano.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/plugins/torch/torch_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/torch/torch.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/plugins/registry.go | Extends plugin factory signature to pass config; conditionally registers Progress plugin behind feature gate. |
| pkg/runtime/framework/plugins/progress/progress_test.go | Unit tests for Progress plugin injection and ConfigMap creation behavior. |
| pkg/runtime/framework/plugins/progress/progress.go | New Progress plugin that injects env/token/CA mount + creates CA ConfigMap per TrainJob. |
| pkg/runtime/framework/plugins/plainml/plainml_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/plainml/plainml.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/plugins/mpi/mpi_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/mpi/mpi.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/plugins/jobset/jobset_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/jobset/jobset.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/plugins/jax/jax_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/jax/jax.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/plugins/coscheduling/coscheduling_test.go | Updates plugin constructor signature in tests (adds cfg param). |
| pkg/runtime/framework/plugins/coscheduling/coscheduling.go | Updates plugin constructor signature to accept config. |
| pkg/runtime/framework/core/framework_test.go | Updates framework construction in tests to pass default config. |
| pkg/runtime/framework/core/framework.go | Threads config through plugin factories during framework initialization. |
| pkg/runtime/core/trainingruntime_test.go | Updates runtime tests to pass config into runtime construction. |
| pkg/runtime/core/trainingruntime.go | Threads config into framework initialization. |
| pkg/runtime/core/registry.go | Updates runtime registrar factory signature to accept config. |
| pkg/runtime/core/core.go | Threads config through runtime registry initialization. |
| pkg/runtime/core/clustertrainingruntime_test.go | Updates cluster runtime tests to pass config into runtime construction. |
| pkg/runtime/core/clustertrainingruntime.go | Updates cluster runtime constructor signature to accept config. |
| pkg/progress/setup.go | Adds manager wiring for the progress server (TLS watcher, verifier, separate client). |
| pkg/progress/server_test.go | Adds server tests for success + error responses and auth behavior. |
| pkg/progress/server.go | Implements HTTPS progress server endpoint that patches status.trainerStatus with authz based on pod label + token claims. |
| pkg/progress/middleware_test.go | Adds recovery middleware test. |
| pkg/progress/middleware.go | Adds middleware chain: recovery, logging, authn, body-size limits. |
| pkg/progress/constants.go | Adds shared label/audience constants and status URL helper. |
| pkg/progress/auth.go | Adds projected SA token OIDC verification + issuer discovery from in-cluster token. |
| pkg/features/features.go | Introduces TrainJobProgress alpha feature gate defaulting to disabled. |
| pkg/controller/trainjob_controller.go | Avoids reconciler overwriting externally-managed status.trainerStatus updates. |
| pkg/config/validation.go | Adds validation for ProgressServer config (port/qps/burst). |
| pkg/config/config_test.go | Extends config tests to cover ProgressServer defaults and file loading. |
| pkg/client/applyconfiguration/utils.go | Adds applyconfiguration mappings for Metric and TrainJobTrainerStatus. |
| pkg/client/applyconfiguration/trainer/v1alpha1/trainjobtrainerstatus.go | Generated applyconfiguration for TrainJobTrainerStatus. |
| pkg/client/applyconfiguration/trainer/v1alpha1/trainjobstatus.go | Adds TrainerStatus to applyconfiguration for TrainJobStatus. |
| pkg/client/applyconfiguration/trainer/v1alpha1/metric.go | Generated applyconfiguration for Metric. |
| pkg/apis/trainer/v1alpha1/zz_generated.openapi.go | Updates OpenAPI defs for Metric/ProgressStatus/TrainJobTrainerStatus and status field. |
| pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go | Adds deepcopy support for new API types/fields. |
| pkg/apis/trainer/v1alpha1/trainjob_types.go | Adds status.trainerStatus, Metric, ProgressStatus, and TrainJobTrainerStatus types + validations. |
| pkg/apis/config/v1alpha1/zz_generated.deepcopy.go | Adds deepcopy support for ProgressServer config. |
| pkg/apis/config/v1alpha1/defaults.go | Adds defaulting for ProgressServer config. |
| pkg/apis/config/v1alpha1/configuration_types.go | Adds ProgressServer config API type and Configuration field. |
| manifests/base/rbac/role.yaml | Grants controller permission to get pods (required for progress server authz). |
| manifests/base/manager/manager.yaml | Exposes the progress-server container port and service port. |
| manifests/base/manager/controller_manager_config.yaml | Adds progressServer config defaults to the base config. |
| manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml | Updates CRD schema to include status.trainerStatus. |
| hack/e2e-setup-gpu-cluster.sh | Enables TrainJobProgress feature gate for e2e cluster setup. |
| hack/e2e-setup-cluster.sh | Enables TrainJobProgress feature gate for e2e cluster setup. |
| go.sum | Adds module checksums for OIDC dependencies. |
| go.mod | Adds go-oidc dependency (and go-jose indirect). |
| cmd/trainer-controller-manager/main.go | Adds --feature-gates, threads config into runtimes, and conditionally starts progress server. |
| charts/kubeflow-trainer/values.yaml | Adds progressServer config values. |
| charts/kubeflow-trainer/templates/manager/service.yaml | Adds Service port for progress-server. |
| charts/kubeflow-trainer/templates/manager/deployment.yaml | Adds progress-server container port. |
| charts/kubeflow-trainer/templates/manager/configmap.yaml | Renders progressServer config into manager configmap. |
| charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainjobs.yaml | Updates Helm CRD schema to include status.trainerStatus. |
| charts/kubeflow-trainer/README.md | Updates values documentation to include progressServer config. |
| api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_train_job_trainer_status.py | Adds generated Python model for TrainJobTrainerStatus. |
| api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_train_job_status.py | Adds TrainerStatus field to generated Python TrainJobStatus model. |
| api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_progress_status.py | Adds generated Python model for ProgressStatus. |
| api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_metric.py | Adds generated Python model for Metric. |
| api/python_api/kubeflow_trainer_api/models/init.py | Exports newly generated Python models. |
| api/openapi-spec/swagger.json | Updates published OpenAPI spec with new schemas/fields. |
| Makefile | Includes pkg/progress in controller-gen targets (CRDs/RBAC generation). |
astefanutti left a comment
Thanks @robert-bell!
I would suggest using "TrainJobStatus" or "TrainJobRuntimeStatus" instead of TrainJobProgress to be more future-proof.
pkg/status/server.go
Outdated
		TrainerStatus: progressStatus.TrainerStatus,
	},
}
if err := s.client.Status().Patch(r.Context(), &trainJob, client.Merge); err != nil {
Use SSA with a dedicated manager.
I think I've updated to SSA, but I'm not sure if I used the right approach - would you mind double checking? I've updated the client to use apply and the generated applyconfiguration client code, and reverted the change in the trainjob reconciler that was unsetting the trainer status.
I couldn't get the existing unit tests to work with SSA - the client apply is failing with `Operation cannot be fulfilled on trainjobs.trainer.kubeflow.org "test-job": object was modified`. Is that expected - is it what the TODO on the trainjob controller is referring to?
// TODO(astefanutti): Consider using SSA once controller-runtime client has SSA support
// for sub-resources. See: kubernetes-sigs/controller-runtime#3183
Yes, you've taken the right approach on the status server side with:
s.client.Status().Apply(r.Context(), trainJob, client.ForceOwnership,
Now you can apply it to the controller side, which is indeed what that TODO is about.
Thanks for confirming. If it's OK, can we address that TODO in a separate PR to keep this one focused? This PR doesn't depend on that TODO being fixed.
pkg/progress/setup.go
Outdated
	"github.com/kubeflow/trainer/v2/pkg/util/cert"
)

func SetupServer(mgr ctrl.Manager, cfg *configapi.ProgressServer) error {
Should there be a network policy that only authorizes ingress from TrainJob Pods?
Yeah, that'd be a nice extra layer, though the controller pod is also serving metrics and the webhook, so I think the netpol would also need to explicitly allow ingress into those ports from any IP. I also don't think I can disable the netpol with the feature gate, so it'd always be active.
Are we OK with that?
The ingress rules for metrics and webhooks could probably be restricted to allowed incoming namespaces.
Also, the NetworkPolicy could be applied dynamically at start-time.
This can be tackled in a follow-up PR to keep the scope of this PR focused on the KEP.
pkg/status/server.go
Outdated
}

// Verify the pod has the label identifying it belongs to this TrainJob
trainJobNameFromLabel, ok := pod.Labels[LabelTrainJobName]
Would using a finer-grained audience avoid getting Pods?
Are you thinking something like putting the TrainJob name in the audience? That would let us avoid looking up the pod. It'd also mean I wouldn't need to decode the SA token, which is cleaner.
Did you have any thoughts about the token format? It might be worth including the whole path in there for future extensibility - something like this? Or is this overkill?
trainer.kubeflow.org/v1alpha1/namespace/{namespace}/trainjobs/{name}/status
Just to add: I think encoding information in the audience is non-conventional. It's normally just supposed to be a static string identifying the target service.
I think it works for us because the pods should only ever be updating their parent TrainJob, but I'm flagging it because it could introduce brittleness in the future.
Are you thinking something like putting the train job name in the audience?
Yes
I think encoding information in the audience is non-conventional. It's normally just supposed to be a static string identifying the target service.
That's a fair point I agree, though I'm not sure if "conventional" usage is very well defined either.
Curious: would we lose the pod binding check? Would tokens from deleted pods still be allowed to update status?
I've pushed an update that uses a different audience per TrainJob. @astefanutti please take a look.
Curious if we would lose the pod binding check? would it mean that the tokens used with deleted pods will be allowed to update status?
Yeah, we do lose this, but I think the practical impact is limited because the endpoint itself has fairly limited capabilities and the token expiry is short (1 hr). Happy to discuss alternatives though.
@robert-bell That looks great, I agree with you the impact is limited in practice.
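For illustration, here is a minimal sketch of what a per-TrainJob audience helper could look like, assuming the path-style format floated above. The names `AudienceFor` and `ParseAudience` are hypothetical, not the merged implementation:

```go
package main

import (
	"fmt"
	"strings"
)

const audiencePrefix = "trainer.kubeflow.org/v1alpha1"

// AudienceFor builds a per-TrainJob audience string in the path-style
// format discussed in the review thread.
func AudienceFor(namespace, name string) string {
	return fmt.Sprintf("%s/namespace/%s/trainjobs/%s/status", audiencePrefix, namespace, name)
}

// ParseAudience extracts the namespace and TrainJob name back out,
// returning ok=false for anything that does not match the expected shape
// (e.g. a conventional static-service audience).
func ParseAudience(aud string) (namespace, name string, ok bool) {
	rest, found := strings.CutPrefix(aud, audiencePrefix+"/namespace/")
	if !found {
		return "", "", false
	}
	parts := strings.Split(rest, "/")
	if len(parts) != 4 || parts[1] != "trainjobs" || parts[3] != "status" {
		return "", "", false
	}
	return parts[0], parts[2], true
}

func main() {
	aud := AudienceFor("team-a", "llm-sft")
	fmt.Println(aud)
	ns, name, ok := ParseAudience(aud)
	fmt.Println(ns, name, ok)
}
```

With this shape, the server can authorize a request purely from the verified token's audience, which is what removes the need to fetch the pod.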
pkg/status/auth.go
Outdated
}

verifier := provider.Verifier(&oidc.Config{
	ClientID: TokenAudience,
Does it also verify the issuer?
Yep - the issuer is configured when the provider is created, and the verifier only accepts that issuer.
pkg/status/server.go
Outdated
// authorizeRequest checks whether the service account token bearer token used by this request comes from
// a pod that is part of the TrainJob that is being updated.
func (s *Server) authorizeRequest(r *http.Request, namespace, trainJobName string) bool {
	token, ok := serviceAccountTokenFromContext(r.Context())
Any reason to have auth in the handler, saving the token in the context, vs using middleware?
No strong reason. I kept it in the handler mainly because the authz logic depends on the job name/namespace from the request path. Keeping it in the handler makes that dependency more explicit, whereas in middleware it would be a bit more implicit and coupled to path parsing.
Btw, in the latest version I'm not storing the parsed token in the context. The authn has also moved to the handler because it now depends on job name/namespace too.
akshaychitneni left a comment
Thanks @robert-bell
andreyvelich left a comment
Thank you for this awesome work @robert-bell! Overall looks great!
I left a few comments, will check tests soon as well.
}

mux := http.NewServeMux()
mux.HandleFunc("POST "+StatusUrl("{namespace}", "{name}"), s.handleTrainJobRuntimeStatus)
Do we need to implement health check endpoint, so our instrumentation can check if server is healthy before reporting metrics?
Similar to /readyz for Kubernetes API server: https://kubernetes.io/docs/reference/using-api/health-checks/
Good shout: yes, it's needed. It has to hook into the existing health checks because the server runs in the same process as the controllers and webhooks, but it looks like the controller Manager will let me add extra checks to the probes.
It looks like it might need a bit of rejigging. Leave it with me and I'll try to implement it this week.
@robert-bell Sounds good, we can also do that in the followup PR. Let's just create an issue to track it.
	WithTrainerStatus(toApplyConfig(runtimeStatus.TrainerStatus)),
)

if err := s.client.Status().Apply(r.Context(), trainJob, client.ForceOwnership, client.FieldOwner("trainer-status")); err != nil {
What if client calls server multiple times?
Do we need to have any guardrails for client calls?
Client-side rate limiting gets triggered, which limits the API calls; I've tested that it works manually. FYI, I deliberately gave the server a separate client to prevent any rate limiting from affecting the reconcile loop.
We could look at additional protection, but I wonder whether it'd be more effective to put that protection in the instrumentation code going into the SDK? WDYT?
Probably. Let's try to discuss it once we have initial implementation.
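For intuition, client-side rate limiting behaves like a token bucket: a burst of calls is absorbed up front, and anything beyond the sustained QPS is throttled. The real implementation relies on the Kubernetes client's built-in QPS/Burst limiter; this is only a stand-in sketch of the behavior:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// limiter is a tiny token bucket: tokens refill at `rate` per second up to
// `max`, and each allowed call spends one token.
type limiter struct {
	mu     sync.Mutex
	tokens float64
	max    float64
	rate   float64
	last   time.Time
}

func newLimiter(qps float64, burst int) *limiter {
	return &limiter{tokens: float64(burst), max: float64(burst), rate: qps, last: time.Now()}
}

// Allow reports whether a call may proceed right now (non-blocking).
func (l *limiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	now := time.Now()
	l.tokens += now.Sub(l.last).Seconds() * l.rate
	if l.tokens > l.max {
		l.tokens = l.max
	}
	l.last = now
	if l.tokens >= 1 {
		l.tokens--
		return true
	}
	return false
}

func main() {
	lim := newLimiter(1, 3) // 1 QPS with a burst of 3
	allowed := 0
	for i := 0; i < 10; i++ {
		if lim.Allow() {
			allowed++
		}
	}
	// The burst is consumed; the remaining rapid calls are throttled.
	fmt.Println(allowed)
}
```

Separating the server's client from the reconciler's client, as the PR does, means a misbehaving training pod can exhaust only the server's bucket, never the reconcile loop's.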
@andreyvelich @astefanutti @akshaychitneni just a quick update - I'm still working through the changes from your comments. I've tried to remember to resolve the threads I've addressed so you could take a look if you wanted. Otherwise I'm hoping I can have the existing comments addressed before the end of the week.
@andreyvelich just FYI - some of the tests still need updating/have been temporarily removed. I'll cycle back to them once the implementation settles down. Do feel free to check out the existing tests though; I'd really welcome any feedback.
Hey @andreyvelich, @astefanutti, @akshaychitneni, I've pushed another update and rebased, and I think this is ready for re-review. I think I've addressed all the comments bar the ones that sounded like they needed more discussion; I'm happy to update those once folks agree on the way forward. Please take another look at the threads.
I've kept track of the follow-on issues that have been requested; I'll open them once the PR is merged, if that's OK.
It'd be nice to have envtest integration tests, but I haven't included them. Could we leave them for a follow-on PR? They'll be fairly chunky.
Signed-off-by: Rob Bell <robell@redhat.com>
/lgtm
astefanutti left a comment
Thanks @robert-bell!
/lgtm
/approve
[APPROVALNOTIFIER] This PR is APPROVED. This pull request has been approved by: astefanutti.
@robert-bell here's the prototype: https://github.com/krishdef7/kubeflow-llm-trainer-prototype Still a work in progress and actively evolving, but it covers the core instrumentation. The callback reports loss, reward/KL divergence, and step/epoch progress on each logging step, covering the metrics I mentioned in my earlier comment. Two things I noticed from the client side while building this:
Also flagging a known gap: QLoRA with TRL requires a
The prototype also includes TorchTune and Unsloth backends with the same callback wired in. Happy to iterate on any of this once #3227 lands.
Thanks @andreyvelich, and thank you and @astefanutti and @akshaychitneni for your reviews. I'm really excited to have this over the line. 🎉
What this PR does / why we need it:
This PR implements TrainJobProgress (#2905), enabling real-time progress and metrics tracking for TrainJobs. Training pods can now push status updates (progress %, estimated time remaining, custom metrics), which are exposed via `status.trainerStatus` in the TrainJob CR. It's still a WIP, and there are a few bits that still need working out, but I'd be keen for any feedback on the current approach.
Headline changes:
- Adds a `trainerStatus` field to the TrainJob status, using the spec from feat(docs): KEP-2779: Track TrainJob progress and expose training metrics #2905
- Adds a `TrainJobProgress` feature gate, defaulting to disabled. Everything is disabled if the gate is disabled.
- Exposes the progress endpoint on the `kubeflow-trainer-controller-manager` service, which writes updates to the `trainerStatus` field.

The implementation mostly follows what we agreed in #2905, with a few changes worth pointing out:
- Error responses use a `metav1.Status` object to align with the k8s API server. I'd be happy to take any input on this.

There are a few TODOs left before this is ready for actual merging; I'll work through these, but please do start reviewing these changes as I'd like to check folk are happy with the general approach.
Which issue(s) this PR fixes (optional, in `Fixes #<issue number>, #<issue number>, ...` format, will close the issue(s) when PR gets merged):
Part of #2779
Checklist: