From 6f7b02cf7b36e83d3856cab40169f695d91783ff Mon Sep 17 00:00:00 2001 From: MStokluska Date: Thu, 6 Nov 2025 10:23:12 +0100 Subject: [PATCH 1/5] test: initial implementation of SDK e2e --- .gitignore | 1 + tests/trainer/kubeflow_sdk_test.go | 30 ++ tests/trainer/resources/mnist.ipynb | 275 ++++++++++++++++++ .../trainer/sdk_tests/fashion_mnist_tests.go | 86 ++++++ tests/trainer/utils/utils_cluster_prep.go | 83 ++++++ tests/trainer/utils/utils_notebook.go | 85 ++++++ 6 files changed, 560 insertions(+) create mode 100644 tests/trainer/kubeflow_sdk_test.go create mode 100644 tests/trainer/resources/mnist.ipynb create mode 100644 tests/trainer/sdk_tests/fashion_mnist_tests.go create mode 100644 tests/trainer/utils/utils_cluster_prep.go create mode 100644 tests/trainer/utils/utils_notebook.go diff --git a/.gitignore b/.gitignore index 36f971e32..dc4d504ac 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ bin/* +.vscode/* \ No newline at end of file diff --git a/tests/trainer/kubeflow_sdk_test.go b/tests/trainer/kubeflow_sdk_test.go new file mode 100644 index 000000000..cd5a75b59 --- /dev/null +++ b/tests/trainer/kubeflow_sdk_test.go @@ -0,0 +1,30 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trainer + +import ( + "testing" + + . "github.com/opendatahub-io/distributed-workloads/tests/common" + sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests" +) + +func TestKubeflowSDK_Sanity(t *testing.T) { + Tags(t, Sanity) + sdktests.RunFashionMnistCpuDistributedTraining(t) + // ADD MORE SANITY TESTS HERE +} diff --git a/tests/trainer/resources/mnist.ipynb b/tests/trainer/resources/mnist.ipynb new file mode 100644 index 000000000..56e7351ea --- /dev/null +++ b/tests/trainer/resources/mnist.ipynb @@ -0,0 +1,275 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T13:19:46.917723Z", + "iopub.status.busy": "2025-09-03T13:19:46.917308Z", + "iopub.status.idle": "2025-09-03T13:19:46.935181Z", + "shell.execute_reply": "2025-09-03T13:19:46.934697Z", + "shell.execute_reply.started": "2025-09-03T13:19:46.917698Z" + } + }, + "outputs": [], + "source": [ + "def train_fashion_mnist():\n", + " import os\n", + "\n", + " import torch\n", + " import torch.distributed as dist\n", + " import torch.nn.functional as F\n", + " from torch import nn\n", + " from torch.utils.data import DataLoader, DistributedSampler\n", + " from torchvision import datasets, transforms\n", + "\n", + " # Define the PyTorch CNN model to be trained\n", + " class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", + " self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", + " self.fc1 = nn.Linear(4 * 4 * 50, 500)\n", + " self.fc2 = nn.Linear(500, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.conv1(x))\n", + " x = F.max_pool2d(x, 2, 2)\n", + " x = F.relu(self.conv2(x))\n", + " x = F.max_pool2d(x, 2, 2)\n", + " x = x.view(-1, 4 * 4 * 
50)\n", + " x = F.relu(self.fc1(x))\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n", + " device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n", + " print(f\"Using Device: {device}, Backend: {backend}\")\n", + "\n", + " # Setup PyTorch distributed.\n", + " local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n", + " dist.init_process_group(backend=backend)\n", + " print(\n", + " \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n", + " dist.get_world_size(),\n", + " dist.get_rank(),\n", + " local_rank,\n", + " )\n", + " )\n", + "\n", + " # Create the model and load it into the device.\n", + " device = torch.device(f\"{device}:{local_rank}\")\n", + " model = nn.parallel.DistributedDataParallel(Net().to(device))\n", + " optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n", + "\n", + " \n", + " # Download FashionMNIST dataset only on local_rank=0 process.\n", + " if local_rank == 0:\n", + " dataset = datasets.FashionMNIST(\n", + " \"./data\",\n", + " train=True,\n", + " download=True,\n", + " transform=transforms.Compose([transforms.ToTensor()]),\n", + " )\n", + " dist.barrier()\n", + " dataset = datasets.FashionMNIST(\n", + " \"./data\",\n", + " train=True,\n", + " download=False,\n", + " transform=transforms.Compose([transforms.ToTensor()]),\n", + " )\n", + "\n", + "\n", + " # Shard the dataset accross workers.\n", + " train_loader = DataLoader(\n", + " dataset,\n", + " batch_size=100,\n", + " sampler=DistributedSampler(dataset)\n", + " )\n", + "\n", + " # TODO(astefanutti): add parameters to the training function\n", + " dist.barrier()\n", + " for epoch in range(1, 3):\n", + " model.train()\n", + "\n", + " # Iterate over mini-batches from the training set\n", + " for batch_idx, (inputs, labels) in enumerate(train_loader):\n", + " # Copy the data to the GPU device if available\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + " # Forward pass\n", + " outputs = model(inputs)\n", + " loss = F.nll_loss(outputs, labels)\n", + " # Backward pass\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if batch_idx % 10 == 0 and dist.get_rank() == 0:\n", + " print(\n", + " \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n", + " epoch,\n", + " batch_idx * len(inputs),\n", + " len(train_loader.dataset),\n", + " 100.0 * batch_idx / len(train_loader),\n", + " loss.item(),\n", + " )\n", + " )\n", + "\n", + " # Wait for the distributed training to complete\n", + " dist.barrier()\n", + " if dist.get_rank() == 0:\n", + " print(\"Training is finished\")\n", + "\n", + " # Finally clean up PyTorch distributed\n", + " dist.destroy_process_group()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T13:19:49.832393Z", + "iopub.status.busy": "2025-09-03T13:19:49.832117Z", + "iopub.status.idle": "2025-09-03T13:19:51.924613Z", + "shell.execute_reply": "2025-09-03T13:19:51.924264Z", + "shell.execute_reply.started": "2025-09-03T13:19:49.832371Z" + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "from kubeflow.trainer import CustomTrainer, TrainerClient\n", + "\n", + "client = TrainerClient()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for runtime in 
client.list_runtimes():\n", + " print(runtime)\n", + " if runtime.name == \"universal\": # Update to actual universal image runtime once available\n", + " torch_runtime = runtime" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T13:19:56.525591Z", + "iopub.status.busy": "2025-09-03T13:19:56.524936Z", + "iopub.status.idle": "2025-09-03T13:19:56.721404Z", + "shell.execute_reply": "2025-09-03T13:19:56.720565Z", + "shell.execute_reply.started": "2025-09-03T13:19:56.525536Z" + } + }, + "outputs": [], + "source": [ + "job_name = client.train(\n", + " trainer=CustomTrainer(\n", + " func=train_fashion_mnist,\n", + " num_nodes=2,\n", + " resources_per_node={\n", + " \"cpu\": 2,\n", + " \"memory\": \"8Gi\",\n", + " },\n", + " packages_to_install=[\"torchvision\"],\n", + " ),\n", + " runtime=torch_runtime,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T13:20:01.378158Z", + "iopub.status.busy": "2025-09-03T13:20:01.377707Z", + "iopub.status.idle": "2025-09-03T13:20:12.713960Z", + "shell.execute_reply": "2025-09-03T13:20:12.713295Z", + "shell.execute_reply.started": "2025-09-03T13:20:01.378130Z" + } + }, + "outputs": [], + "source": [ + "# Wait for the running status.\n", + "client.wait_for_job_status(name=job_name, status={\"Running\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T13:20:24.045774Z", + "iopub.status.busy": "2025-09-03T13:20:24.045480Z", + "iopub.status.idle": "2025-09-03T13:20:24.772877Z", + "shell.execute_reply": "2025-09-03T13:20:24.772178Z", + "shell.execute_reply.started": "2025-09-03T13:20:24.045755Z" + } + }, + "outputs": [], + "source": [ + "for c in client.get_job(name=job_name).steps:\n", + " print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T13:20:26.729486Z", + "iopub.status.busy": "2025-09-03T13:20:26.728951Z", + "iopub.status.idle": "2025-09-03T13:20:29.596510Z", + "shell.execute_reply": "2025-09-03T13:20:29.594741Z", + "shell.execute_reply.started": "2025-09-03T13:20:26.729446Z" + } + }, + "outputs": [], + "source": [ + "for logline in client.get_job_logs(job_name, follow=True):\n", + " print(logline)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client.delete_job(job_name)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/trainer/sdk_tests/fashion_mnist_tests.go b/tests/trainer/sdk_tests/fashion_mnist_tests.go new file mode 100644 index 000000000..16d357114 --- /dev/null +++ b/tests/trainer/sdk_tests/fashion_mnist_tests.go @@ -0,0 +1,86 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sdk_tests + +import ( + "fmt" + "os" + "testing" + + . "github.com/onsi/gomega" + + corev1 "k8s.io/api/core/v1" + + common "github.com/opendatahub-io/distributed-workloads/tests/common" + support "github.com/opendatahub-io/distributed-workloads/tests/common/support" + trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils" +) + +const ( + notebookName = "mnist.ipynb" + notebookPath = "resources/" + notebookName +) + +// CPU Only - Distributed Training +func RunFashionMnistCpuDistributedTraining(t *testing.T) { + test := support.With(t) + + // Create a new test namespace + namespace := test.NewTestNamespace() + + // Ensure pre-requisites to run the test are met + trainerutils.EnsureTrainerClusterReady(t, test) + + // Ensure Notebook SA and RBACs are set for this namespace + trainerutils.EnsureNotebookRBAC(t, test, namespace.Name) + + // RBACs setup + userName := common.GetNotebookUserName(test) + userToken := common.GetNotebookUserToken(test) + support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin") + + // Read notebook from directory + localPath := notebookPath + nb, err := os.ReadFile(localPath) + test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath)) + + // Create ConfigMap with notebook + cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: nb}) + + // Build command + marker := "/opt/app-root/src/notebook_completion_marker" + shellCmd := trainerutils.BuildPapermillShellCmd(notebookName, marker, nil) + command := []string{"/bin/sh", "-c", shellCmd} + + // Create Notebook CR (with default 10Gi PVC) + pvc := support.CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", support.AccessModes(corev1.ReadWriteOnce)) + common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, pvc, common.ContainerSizeSmall) + + // Cleanup + defer func() { + common.DeleteNotebook(test, namespace) + test.Eventually(common.Notebooks(test, namespace), support.TestTimeoutLong).Should(HaveLen(0)) + }() + + // Wait for the Notebook Pod and get pod/container names + podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name) + + // Poll marker file to check if the notebook execution completed successfully + if err := trainerutils.PollNotebookCompletionMarker(test, namespace.Name, podName, containerName, marker, support.TestTimeoutDouble); err != nil { + test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE") + } +} diff --git a/tests/trainer/utils/utils_cluster_prep.go b/tests/trainer/utils/utils_cluster_prep.go new file mode 100644 index 000000000..235c86d0c --- /dev/null +++ b/tests/trainer/utils/utils_cluster_prep.go @@ -0,0 +1,83 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trainer + +import ( + "fmt" + "os/exec" + "testing" + + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + . "github.com/opendatahub-io/distributed-workloads/tests/common/support" +) + +// EnsureTrainerClusterReady verifies cluster dependencies required by Kubeflow Trainer tests. +func EnsureTrainerClusterReady(t *testing.T, test Test) { + t.Helper() + // JobSet CRD present + // TODO: Remove once trainer is part of installation + if out, err := exec.Command("kubectl", "get", "crd", "jobsets.jobset.x-k8s.io").CombinedOutput(); err != nil { + t.Fatalf("JobSet CRD missing: %v\n%s", err, string(out)) + } + // Trainer controller deployment available + // TODO: Remove once trainer is part of installation + if out, err := exec.Command("kubectl", "-n", "opendatahub", "wait", "--for=condition=available", "--timeout=180s", "deploy/kubeflow-trainer-controller-manager").CombinedOutput(); err != nil { + t.Fatalf("Trainer controller not available: %v\n%s", err, string(out)) + } + // Required ClusterTrainingRuntimes present + runtimes, err := test.Client().Trainer().TrainerV1alpha1().ClusterTrainingRuntimes().List(test.Ctx(), metav1.ListOptions{}) + test.Expect(err).NotTo(HaveOccurred(), "Failed to list ClusterTrainingRuntimes") + found := map[string]bool{} + for _, rt := range runtimes.Items { + found[rt.Name] = true + } + // TODO: Extend / tweak with universal image runtime once available + for _, name := range []string{"torch-cuda-241", "torch-cuda-251", "torch-rocm-241", "torch-rocm-251"} { + test.Expect(found[name]).To(BeTrue(), fmt.Sprintf("Expected ClusterTrainingRuntime '%s' not found", name)) + } +} + +// EnsureNotebookRBAC sets up the Notebook ServiceAccount and RBAC so that notebooks can +// read ClusterTrainingRuntimes (cluster-scoped), and create/read TrainJobs and pod logs in the namespace. 
+func EnsureNotebookRBAC(t *testing.T, test Test, namespace string) { + t.Helper() + + // Ensure ServiceAccount exists + saName := "jupyter-nb-kube-3aadmin" + sa := &corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: saName, Namespace: namespace}} + _, _ = test.Client().Core().CoreV1().ServiceAccounts(namespace).Create(test.Ctx(), sa, metav1.CreateOptions{}) + // Get current SA (created or existing) + saObj, err := test.Client().Core().CoreV1().ServiceAccounts(namespace).Get(test.Ctx(), saName, metav1.GetOptions{}) + test.Expect(err).NotTo(HaveOccurred()) + + // Cluster-scoped read for ClusterTrainingRuntimes + ctrRead := CreateClusterRole(test, []rbacv1.PolicyRule{ + {APIGroups: []string{"trainer.kubeflow.org"}, Resources: []string{"clustertrainingruntimes"}, Verbs: []string{"get", "list", "watch"}}, + }) + CreateClusterRoleBinding(test, saObj, ctrRead) + + // Namespace Role for TrainJobs and pods/log access + role := CreateRole(test, namespace, []rbacv1.PolicyRule{ + {APIGroups: []string{"trainer.kubeflow.org"}, Resources: []string{"trainjobs", "trainjobs/status"}, Verbs: []string{"get", "list", "watch", "create", "update", "patch", "delete"}}, + {APIGroups: []string{""}, Resources: []string{"pods", "pods/log"}, Verbs: []string{"get", "list", "watch"}}, + }) + CreateRoleBinding(test, namespace, saObj, role) +} diff --git a/tests/trainer/utils/utils_notebook.go b/tests/trainer/utils/utils_notebook.go new file mode 100644 index 000000000..92438e730 --- /dev/null +++ b/tests/trainer/utils/utils_notebook.go @@ -0,0 +1,85 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trainer + +import ( + "fmt" + "os/exec" + "strings" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + . "github.com/onsi/gomega" + + common "github.com/opendatahub-io/distributed-workloads/tests/common" + . "github.com/opendatahub-io/distributed-workloads/tests/common/support" +) + +// BuildPapermillShellCmd builds a shell command to execute a notebook with papermill and write a SUCCESS/FAILURE marker. +// extraPipPackages, if provided, are installed alongside papermill. +func BuildPapermillShellCmd(notebookName string, marker string, extraPipPackages []string) string { + pipLine := "pip install --quiet --no-cache-dir papermill" + if len(extraPipPackages) > 0 { + pipLine = pipLine + " " + strings.Join(extraPipPackages, " ") + } + return fmt.Sprintf( + "set -e; %s; if papermill -k python3 /opt/app-root/notebooks/%s /opt/app-root/src/out.ipynb --log-output; "+ + "then echo 'SUCCESS' > %s; else echo 'FAILURE' > %s; fi; sleep infinity", + pipLine, notebookName, marker, marker, + ) +} + +// CreateNotebookFromBytes creates a ConfigMap with the notebook and a Notebook CR to run it. 
+func CreateNotebookFromBytes(test Test, namespace *corev1.Namespace, userToken string, notebookName string, notebookBytes []byte, command []string, numGpus int, containerSize common.ContainerSize) *corev1.ConfigMap { + cm := CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: notebookBytes}) + common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, numGpus, CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", AccessModes(corev1.ReadWriteOnce)), containerSize) + return cm +} + +// WaitForNotebookPodRunning waits for the Notebook pod (identified by the template's label) to be Running and returns pod/container names. +func WaitForNotebookPodRunning(test Test, namespace string) (string, string) { + labelSelector := fmt.Sprintf("notebook-name=%s", common.NOTEBOOK_CONTAINER_NAME) + test.Eventually(func() []corev1.Pod { + return GetPods(test, namespace, metav1.ListOptions{LabelSelector: labelSelector, FieldSelector: "status.phase=Running"}) + }, TestTimeoutLong).Should(HaveLen(1), "Expected exactly one notebook pod") + + pods := GetPods(test, namespace, metav1.ListOptions{LabelSelector: labelSelector, FieldSelector: "status.phase=Running"}) + return pods[0].Name, pods[0].Spec.Containers[0].Name +} + +// PollNotebookCompletionMarker polls the given marker file inside the notebook pod until SUCCESS/FAILURE or timeout. +func PollNotebookCompletionMarker(test Test, namespace, podName, containerName, marker string, timeout time.Duration) error { + var finalErr error + test.Eventually(func() bool { + out, err := exec.Command("kubectl", "-n", namespace, "exec", podName, "-c", containerName, "--", "cat", marker).CombinedOutput() + if err != nil { + return false + } + switch strings.TrimSpace(string(out)) { + case "SUCCESS": + return true + case "FAILURE": + finalErr = fmt.Errorf("Notebook execution failed") + return true + default: + return false + } + }, timeout).Should(BeTrue(), "Notebook did not reach definitive state") + return finalErr +} From aa1afeb65ccd3e82ebf975e9b3f25732e9ec4010 Mon Sep 17 00:00:00 2001 From: MStokluska Date: Thu, 6 Nov 2025 10:37:14 +0100 Subject: [PATCH 2/5] cleanup imports --- tests/trainer/utils/utils_cluster_prep.go | 1 + tests/trainer/utils/utils_notebook.go | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/trainer/utils/utils_cluster_prep.go b/tests/trainer/utils/utils_cluster_prep.go index 235c86d0c..4fda01d60 100644 --- a/tests/trainer/utils/utils_cluster_prep.go +++ b/tests/trainer/utils/utils_cluster_prep.go @@ -22,6 +22,7 @@ import ( "testing" . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/tests/trainer/utils/utils_notebook.go b/tests/trainer/utils/utils_notebook.go index 92438e730..6d76b1ece 100644 --- a/tests/trainer/utils/utils_notebook.go +++ b/tests/trainer/utils/utils_notebook.go @@ -22,11 +22,11 @@ import ( "strings" "time" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - . "github.com/onsi/gomega" - common "github.com/opendatahub-io/distributed-workloads/tests/common" . 
"github.com/opendatahub-io/distributed-workloads/tests/common/support" ) From f7b02856205cd058cac47722c70bbedfc19761ab Mon Sep 17 00:00:00 2001 From: MStokluska Date: Tue, 18 Nov 2025 16:33:28 +0100 Subject: [PATCH 3/5] address comments --- tests/trainer/kubeflow_sdk_test.go | 1 - tests/trainer/resources/mnist.ipynb | 374 ++++++++++++++++-- .../trainer/sdk_tests/fashion_mnist_tests.go | 62 ++- tests/trainer/utils/utils_cluster_prep.go | 61 +-- tests/trainer/utils/utils_notebook.go | 27 +- 5 files changed, 404 insertions(+), 121 deletions(-) diff --git a/tests/trainer/kubeflow_sdk_test.go b/tests/trainer/kubeflow_sdk_test.go index cd5a75b59..f1de0753a 100644 --- a/tests/trainer/kubeflow_sdk_test.go +++ b/tests/trainer/kubeflow_sdk_test.go @@ -26,5 +26,4 @@ import ( func TestKubeflowSDK_Sanity(t *testing.T) { Tags(t, Sanity) sdktests.RunFashionMnistCpuDistributedTraining(t) - // ADD MORE SANITY TESTS HERE } diff --git a/tests/trainer/resources/mnist.ipynb b/tests/trainer/resources/mnist.ipynb index 56e7351ea..c44fe2953 100644 --- a/tests/trainer/resources/mnist.ipynb +++ b/tests/trainer/resources/mnist.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2025-09-03T13:19:46.917723Z", @@ -21,8 +21,9 @@ " import torch.distributed as dist\n", " import torch.nn.functional as F\n", " from torch import nn\n", - " from torch.utils.data import DataLoader, DistributedSampler\n", - " from torchvision import datasets, transforms\n", + " from torch.utils.data import DataLoader, DistributedSampler, Dataset\n", + " import numpy as np\n", + " import struct\n", "\n", " # Define the PyTorch CNN model to be trained\n", " class Net(nn.Module):\n", @@ -48,7 +49,7 @@ " print(f\"Using Device: {device}, Backend: {backend}\")\n", "\n", " # Setup PyTorch distributed.\n", - " local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n", + " local_rank = int(os.getenv(\"PET_NODE_RANK\", 0))\n", " dist.init_process_group(backend=backend)\n", " print(\n", " \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n", @@ -64,31 +65,61 @@ " optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n", "\n", " \n", - " # Download FashionMNIST dataset only on local_rank=0 process.\n", - " if local_rank == 0:\n", - " dataset = datasets.FashionMNIST(\n", - " \"./data\",\n", - " train=True,\n", - " download=True,\n", - " transform=transforms.Compose([transforms.ToTensor()]),\n", - " )\n", - " dist.barrier()\n", - " dataset = datasets.FashionMNIST(\n", - " \"./data\",\n", - " train=True,\n", - " download=False,\n", - " transform=transforms.Compose([transforms.ToTensor()]),\n", - " )\n", + " # Prefer shared PVC if present; else fallback to internet download (rank 0 only)\n", + " from urllib.parse import urlparse\n", + " import gzip, shutil\n", + " \n", + " pvc_root = \"/mnt/shared\"\n", + " pvc_raw = os.path.join(pvc_root, \"FashionMNIST\", \"raw\")\n", + "\n", + "\n", + " use_pvc = os.path.isdir(pvc_raw) and any(os.scandir(pvc_raw))\n", + "\n", + " if not use_pvc:\n", + " raise RuntimeError(\"Shared PVC not mounted or empty at /mnt/shared/FashionMNIST/raw; this test requires a pre-populated RWX PVC\")\n", "\n", + " print(\"Using dataset from shared PVC at /mnt/shared\")\n", "\n", - " # Shard the dataset accross workers.\n", + " def _read_idx_images(path):\n", + " with open(path, \"rb\") as f:\n", + " magic, num, rows, cols = struct.unpack(\">IIII\", f.read(16))\n", + " if magic != 2051:\n", + " raise 
RuntimeError(f\"Unexpected images magic: {magic}\")\n", + " data = f.read()\n", + " return np.frombuffer(data, dtype=np.uint8).reshape(num, rows, cols)\n", + "\n", + " def _read_idx_labels(path):\n", + " with open(path, \"rb\") as f:\n", + " magic, num = struct.unpack(\">II\", f.read(8))\n", + " if magic != 2049:\n", + " raise RuntimeError(f\"Unexpected labels magic: {magic}\")\n", + " data = f.read()\n", + " return np.frombuffer(data, dtype=np.uint8)\n", + "\n", + " class MnistIdxDataset(Dataset):\n", + " def __init__(self, images_path: str, labels_path: str):\n", + " self.images = _read_idx_images(images_path)\n", + " self.labels = _read_idx_labels(labels_path)\n", + " if len(self.images) != len(self.labels):\n", + " raise RuntimeError(\"Images and labels count mismatch\")\n", + " def __len__(self):\n", + " return len(self.labels)\n", + " def __getitem__(self, idx: int):\n", + " import torch as _torch\n", + " img = _torch.from_numpy(self.images[idx][None, ...].astype(\"float32\") / 255.0)\n", + " label = int(self.labels[idx])\n", + " return img, label\n", + "\n", + " images_path = os.path.join(pvc_root, \"FashionMNIST\", \"raw\", \"train-images-idx3-ubyte\")\n", + " labels_path = os.path.join(pvc_root, \"FashionMNIST\", \"raw\", \"train-labels-idx1-ubyte\")\n", + "\n", + " dataset = MnistIdxDataset(images_path, labels_path)\n", " train_loader = DataLoader(\n", " dataset,\n", " batch_size=100,\n", " sampler=DistributedSampler(dataset)\n", " )\n", "\n", - " # TODO(astefanutti): add parameters to the training function\n", " dist.barrier()\n", " for epoch in range(1, 3):\n", " model.train()\n", @@ -125,6 +156,231 @@ " dist.destroy_process_group()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import gzip\n", + "import shutil\n", + "import socket\n", + "import time\n", + "\n", + "import boto3\n", + "from botocore.config import Config as BotoConfig\n", + "from botocore.exceptions import ClientError\n", + "\n", + "# --- Global networking safety net: cap all socket operations ---\n", + "# Any blocking socket operation (like reading S3 object data) will raise\n", + "# socket.timeout after this many seconds instead of hanging forever.\n", + "socket.setdefaulttimeout(10) # seconds\n", + "\n", + "# Notebook's PVC mount path (per Notebook CR). Training pods will mount the same PVC at /mnt/shared\n", + "PVC_NOTEBOOK_PATH = \"/opt/app-root/src\"\n", + "DATASET_ROOT_NOTEBOOK = PVC_NOTEBOOK_PATH # place FashionMNIST under this root\n", + "FASHION_RAW_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, \"FashionMNIST\", \"raw\")\n", + "os.makedirs(FASHION_RAW_DIR, exist_ok=True)\n", + "\n", + "# Env config for S3/MinIO\n", + "s3_endpoint = os.getenv(\"AWS_DEFAULT_ENDPOINT\", \"\")\n", + "s3_access_key = os.getenv(\"AWS_ACCESS_KEY_ID\", \"\")\n", + "s3_secret_key = os.getenv(\"AWS_SECRET_ACCESS_KEY\", \"\")\n", + "s3_bucket = os.getenv(\"AWS_STORAGE_BUCKET\", \"\")\n", + "s3_prefix = os.getenv(\"AWS_STORAGE_BUCKET_MNIST_DIR\", \"\") # e.g. 
\"data\"\n", + "\n", + "def stream_download(s3, bucket, key, dst):\n", + " \"\"\"\n", + " Download an object from S3/MinIO using get_object and streaming reads.\n", + " This bypasses boto3's TransferManager / download_file entirely.\n", + " Returns True on success, False on any error (including timeouts).\n", + " \"\"\"\n", + " print(f\"[notebook] STREAM download s3://{bucket}/{key} -> {dst}\")\n", + " t0 = time.time()\n", + "\n", + " try:\n", + " # Metadata / headers fetch — should be quick or fail clearly\n", + " resp = s3.get_object(Bucket=bucket, Key=key)\n", + " except ClientError as e:\n", + " err = e.response.get(\"Error\", {})\n", + " print(f\"[notebook] CLIENT ERROR (get_object) for {key}: {err}\")\n", + " return False\n", + " except Exception as e:\n", + " print(f\"[notebook] OTHER ERROR (get_object) for {key}: {e}\")\n", + " return False\n", + "\n", + " body = resp[\"Body\"]\n", + " try:\n", + " with open(dst, \"wb\") as f:\n", + " while True:\n", + " try:\n", + " # Each read is bounded by socket.setdefaulttimeout(...)\n", + " chunk = body.read(1024 * 1024) # 1MB per chunk\n", + " except socket.timeout as e:\n", + " print(f\"[notebook] socket.timeout while reading {key}: {e}\")\n", + " return False\n", + " if not chunk:\n", + " break\n", + " f.write(chunk)\n", + " except Exception as e:\n", + " print(f\"[notebook] ERROR writing to {dst} for {key}: {e}\")\n", + " return False\n", + "\n", + " t1 = time.time()\n", + " print(f\"[notebook] DONE stream {key} in {t1 - t0:.2f}s\")\n", + " return True\n", + "\n", + "\n", + "if s3_endpoint and s3_bucket:\n", + " try:\n", + " # Normalize endpoint URL\n", + " endpoint_url = (\n", + " s3_endpoint\n", + " if s3_endpoint.startswith(\"http\")\n", + " else f\"https://{s3_endpoint}\"\n", + " )\n", + " prefix = (s3_prefix or \"\").strip(\"/\")\n", + "\n", + " print(\n", + " f\"S3 configured (boto3, notebook): \"\n", + " f\"endpoint={endpoint_url}, bucket={s3_bucket}, prefix={prefix or ''}\"\n", + " )\n", + "\n", + " # Boto config: single attempt, reasonable connect/read timeouts.\n", + " boto_cfg = BotoConfig(\n", + " signature_version=\"s3v4\",\n", + " s3={\"addressing_style\": \"path\"},\n", + " retries={\"max_attempts\": 1, \"mode\": \"standard\"},\n", + " connect_timeout=5,\n", + " read_timeout=10,\n", + " )\n", + "\n", + " # Create S3/MinIO client\n", + " s3 = boto3.client(\n", + " \"s3\",\n", + " endpoint_url=endpoint_url,\n", + " aws_access_key_id=s3_access_key,\n", + " aws_secret_access_key=s3_secret_key,\n", + " config=boto_cfg,\n", + " verify=False,\n", + " )\n", + "\n", + " # Optional: quick debug HEAD of the problematic key\n", + " # (will just log if there's an access or existence problem)\n", + " test_key = \"data/t10k-labels-idx1-ubyte.gz\"\n", + " try:\n", + " print(f\"[debug] HEAD s3://{s3_bucket}/{test_key}\")\n", + " meta = s3.head_object(Bucket=s3_bucket, Key=test_key)\n", + " print(f\"[debug] HEAD OK: size={meta.get('ContentLength')}\")\n", + " except ClientError as e:\n", + " print(f\"[debug] HEAD ERROR for {test_key}: {e.response.get('Error')}\")\n", + "\n", + " # List and download all objects under the prefix\n", + " paginator = s3.get_paginator(\"list_objects_v2\")\n", + " pulled_any = False\n", + "\n", + " for page in paginator.paginate(Bucket=s3_bucket, Prefix=prefix or \"\"):\n", + " contents = page.get(\"Contents\", [])\n", + " if not contents:\n", + " continue\n", + "\n", + " for obj in contents:\n", + " key = obj[\"Key\"]\n", + "\n", + " # Skip \"directory markers\"\n", + " if key.endswith(\"/\"):\n", + " 
continue\n", + "\n", + " # Determine relative path under prefix for local storage\n", + " rel = key[len(prefix):].lstrip(\"/\") if prefix else key\n", + " dst = os.path.join(FASHION_RAW_DIR, rel)\n", + " os.makedirs(os.path.dirname(dst), exist_ok=True)\n", + "\n", + " # Download only if missing\n", + " if not os.path.exists(dst):\n", + " ok = stream_download(s3, s3_bucket, key, dst)\n", + " if not ok:\n", + " # Skip decompression and move on to next object\n", + " continue\n", + " else:\n", + " print(f\"[notebook] Skipping existing file {dst}\")\n", + "\n", + " # If the file is .gz, decompress and remove the .gz\n", + " if dst.endswith(\".gz\") and os.path.exists(dst):\n", + " out_path = os.path.splitext(dst)[0]\n", + " if not os.path.exists(out_path):\n", + " print(f\"[notebook] Decompressing {dst} -> {out_path}\")\n", + " try:\n", + " with gzip.open(dst, \"rb\") as f_in, open(out_path, \"wb\") as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " except Exception as e:\n", + " print(f\"[notebook] Failed to decompress {dst}: {e}\")\n", + " else:\n", + " try:\n", + " os.remove(dst)\n", + " except Exception:\n", + " # Not critical if we can't delete the gzip\n", + " pass\n", + "\n", + " pulled_any = True\n", + "\n", + " print(f\"[notebook] S3 pulled_any={pulled_any}\")\n", + "\n", + " except Exception as e:\n", + " print(f\"[notebook] S3 fetch failed: {e}\")\n", + "else:\n", + " print(\"[notebook] S3 not configured: missing endpoint or bucket env vars\")\n", + "\n", + "# Check if we have data; if not, try downloading from internet\n", + "files_needed = [\n", + " \"train-images-idx3-ubyte\",\n", + " \"train-labels-idx1-ubyte\",\n", + " \"t10k-images-idx3-ubyte\",\n", + " \"t10k-labels-idx1-ubyte\",\n", + "]\n", + "files_present = all(os.path.exists(os.path.join(FASHION_RAW_DIR, f)) for f in files_needed)\n", + "\n", + "if not files_present:\n", + " print(\"[notebook] Dataset not complete, attempting internet download...\")\n", + " try:\n", + " import urllib.request\n", + " \n", + " base_url = \"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/\"\n", + " gz_files = [\n", + " \"train-images-idx3-ubyte.gz\",\n", + " \"train-labels-idx1-ubyte.gz\",\n", + " \"t10k-images-idx3-ubyte.gz\",\n", + " \"t10k-labels-idx1-ubyte.gz\",\n", + " ]\n", + " \n", + " for gz_file in gz_files:\n", + " url = base_url + gz_file\n", + " dst_gz = os.path.join(FASHION_RAW_DIR, gz_file)\n", + " dst = os.path.splitext(dst_gz)[0]\n", + " \n", + " if os.path.exists(dst):\n", + " print(f\"[notebook] Already have {dst}, skipping download\")\n", + " continue\n", + " \n", + " print(f\"[notebook] Downloading {url} ...\")\n", + " urllib.request.urlretrieve(url, dst_gz)\n", + " \n", + " print(f\"[notebook] Decompressing {dst_gz} ...\")\n", + " with gzip.open(dst_gz, \"rb\") as f_in, open(dst, \"wb\") as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " \n", + " os.remove(dst_gz)\n", + " print(f\"[notebook] Done: {dst}\")\n", + " \n", + " print(\"[notebook] Internet download completed successfully\")\n", + " except Exception as e:\n", + " print(f\"[notebook] Internet download failed: {e}\")\n", + " print(\"[notebook] WARNING: Dataset may be incomplete!\")\n", + "else:\n", + " print(\"[notebook] Dataset files already present, skipping download\")\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -142,9 +398,27 @@ }, "outputs": [], "source": [ - "from kubeflow.trainer import CustomTrainer, TrainerClient\n", + "# Init SDK client with user token/API URL (no Backend types import)\n", + "import os\n", + 
"from kubernetes import client as k8s_client\n", + "from kubeflow.trainer import TrainerClient\n", + "from kubeflow.common.types import KubernetesBackendConfig\n", + "\n", + "openshift_api_url = os.getenv(\"OPENSHIFT_API_URL\", \"\")\n", + "token = os.getenv(\"NOTEBOOK_TOKEN\", \"\")\n", + "\n", + "cfg = k8s_client.Configuration()\n", + "cfg.host = openshift_api_url\n", + "cfg.verify_ssl = False\n", + "cfg.api_key = {\"authorization\": f\"Bearer {token}\"}\n", + "\n", + "api_client = k8s_client.ApiClient(cfg)\n", + "\n", + "backend_cfg = KubernetesBackendConfig(\n", + " client_configuration=api_client.configuration,\n", + ")\n", "\n", - "client = TrainerClient()\n" + "client = TrainerClient(backend_cfg)" ] }, { @@ -153,10 +427,16 @@ "metadata": {}, "outputs": [], "source": [ - "for runtime in client.list_runtimes():\n", - " print(runtime)\n", - " if runtime.name == \"universal\": # Update to actual universal image runtime once available\n", - " torch_runtime = runtime" + "\n", + "try:\n", + " torch_runtime = client.get_runtime(\"torch-distributed\")\n", + "except Exception:\n", + " torch_runtime = next(\n", + " (r for r in client.list_runtimes() if r.name == \"torch-distributed\"),\n", + " None,\n", + " )\n", + " if torch_runtime is None:\n", + " raise RuntimeError(\"Runtime 'torch-distributed' not found\")" ] }, { @@ -173,6 +453,13 @@ }, "outputs": [], "source": [ + "import os\n", + "from kubeflow.trainer import CustomTrainer\n", + "from kubeflow.trainer.options import PodTemplateOverrides, PodTemplateOverride, PodSpecOverride, ContainerOverride\n", + "\n", + "pvc_name = os.getenv(\"SHARED_PVC_NAME\", \"\")\n", + "print(f\"[notebook] Using PVC: {pvc_name}\")\n", + "\n", "job_name = client.train(\n", " trainer=CustomTrainer(\n", " func=train_fashion_mnist,\n", @@ -181,10 +468,34 @@ " \"cpu\": 2,\n", " \"memory\": \"8Gi\",\n", " },\n", - " packages_to_install=[\"torchvision\"],\n", " ),\n", " runtime=torch_runtime,\n", - ")" + " options=[\n", + " PodTemplateOverrides(\n", + " PodTemplateOverride(\n", + " target_jobs=[\"node\"],\n", + " spec=PodSpecOverride(\n", + " volumes=[\n", + " {\n", + " \"name\": \"work\",\n", + " \"persistentVolumeClaim\": {\"claimName\": pvc_name},\n", + " }\n", + " ],\n", + " containers=[\n", + " ContainerOverride(\n", + " name=\"node\",\n", + " volume_mounts=[\n", + " {\"name\": \"work\", \"mountPath\": \"/mnt/shared\", \"readOnly\": False}\n", + " ],\n", + " )\n", + " ],\n", + " )\n", + " )\n", + " )\n", + " ],\n", + ")\n", + "\n", + "print(f\"[notebook] Job submitted: {job_name}\") " ] }, { @@ -201,8 +512,9 @@ }, "outputs": [], "source": [ - "# Wait for the running status.\n", - "client.wait_for_job_status(name=job_name, status={\"Running\"})" + "# Wait for the running status, then completion.\n", + "client.wait_for_job_status(name=job_name, status={\"Running\"})\n", + "client.wait_for_job_status(name=job_name, status={\"Complete\"})" ] }, { diff --git a/tests/trainer/sdk_tests/fashion_mnist_tests.go b/tests/trainer/sdk_tests/fashion_mnist_tests.go index 16d357114..0cfef91ce 100644 --- a/tests/trainer/sdk_tests/fashion_mnist_tests.go +++ b/tests/trainer/sdk_tests/fashion_mnist_tests.go @@ -42,33 +42,62 @@ func RunFashionMnistCpuDistributedTraining(t *testing.T) { // Create a new test namespace namespace := test.NewTestNamespace() - // Ensure pre-requisites to run the test are met - trainerutils.EnsureTrainerClusterReady(t, test) - - // Ensure Notebook SA and RBACs are set for this namespace - trainerutils.EnsureNotebookRBAC(t, test, namespace.Name) + // Ensure 
Notebook ServiceAccount exists (no extra RBAC) + trainerutils.EnsureNotebookServiceAccount(t, test, namespace.Name) // RBACs setup userName := common.GetNotebookUserName(test) userToken := common.GetNotebookUserToken(test) support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin") - // Read notebook from directory + // Create ConfigMap with notebook localPath := notebookPath nb, err := os.ReadFile(localPath) test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath)) - - // Create ConfigMap with notebook cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: nb}) - // Build command - marker := "/opt/app-root/src/notebook_completion_marker" - shellCmd := trainerutils.BuildPapermillShellCmd(notebookName, marker, nil) + // Build command with parameters and pinned deps, and print definitive status line to logs + endpoint, endpointOK := support.GetStorageBucketDefaultEndpoint() + accessKey, _ := support.GetStorageBucketAccessKeyId() + secretKey, _ := support.GetStorageBucketSecretKey() + bucket, bucketOK := support.GetStorageBucketName() + prefix, _ := support.GetStorageBucketMnistDir() + if !endpointOK { + endpoint = "" + } + if !bucketOK { + bucket = "" + } + // Create RWX PVC for shared dataset and pass the claim name to the notebook + storageClass, err := support.GetRWXStorageClass(test) + test.Expect(err).NotTo(HaveOccurred(), "Failed to find an RWX supporting StorageClass") + rwxPvc := support.CreatePersistentVolumeClaim( + test, + namespace.Name, + "20Gi", + support.AccessModes(corev1.ReadWriteMany), + support.StorageClassName(storageClass.Name), + ) + + shellCmd := fmt.Sprintf( + "set -e; "+ + "export OPENSHIFT_API_URL='%s'; export NOTEBOOK_TOKEN='%s'; "+ + "export NOTEBOOK_NAMESPACE='%s'; "+ + "export SHARED_PVC_NAME='%s'; "+ + "export AWS_DEFAULT_ENDPOINT='%s'; export AWS_ACCESS_KEY_ID='%s'; "+ + "export AWS_SECRET_ACCESS_KEY='%s'; export AWS_STORAGE_BUCKET='%s'; "+ + "export AWS_STORAGE_BUCKET_MNIST_DIR='%s'; "+ + "python -m pip install --quiet --no-cache-dir papermill boto3==1.34.162 && "+ + "if python -m papermill -k python3 /opt/app-root/notebooks/%s /opt/app-root/src/out.ipynb --log-output; "+ + "then echo 'NOTEBOOK_STATUS: SUCCESS'; else echo 'NOTEBOOK_STATUS: FAILURE'; fi; sleep infinity", + support.GetOpenShiftApiUrl(test), userToken, namespace.Name, rwxPvc.Name, + endpoint, accessKey, secretKey, bucket, prefix, + notebookName, + ) command := []string{"/bin/sh", "-c", shellCmd} - // Create Notebook CR (with default 10Gi PVC) - pvc := support.CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", support.AccessModes(corev1.ReadWriteOnce)) - common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, pvc, common.ContainerSizeSmall) + // Create Notebook CR using the RWX PVC + common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, rwxPvc, common.ContainerSizeSmall) // Cleanup defer func() { @@ -79,8 +108,9 @@ func RunFashionMnistCpuDistributedTraining(t *testing.T) { // Wait for the Notebook Pod and get pod/container names podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name) - // Poll marker file to check if the notebook execution completed successfully - if err := trainerutils.PollNotebookCompletionMarker(test, namespace.Name, podName, containerName, marker, support.TestTimeoutDouble); err != nil { + // Poll logs to check if the notebook execution completed successfully + if err := 
trainerutils.PollNotebookLogsForStatus(test, namespace.Name, podName, containerName, support.TestTimeoutDouble); err != nil { test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE") } + } diff --git a/tests/trainer/utils/utils_cluster_prep.go b/tests/trainer/utils/utils_cluster_prep.go index 4fda01d60..b4cb02ea5 100644 --- a/tests/trainer/utils/utils_cluster_prep.go +++ b/tests/trainer/utils/utils_cluster_prep.go @@ -17,68 +17,23 @@ limitations under the License. package trainer import ( - "fmt" - "os/exec" "testing" - . "github.com/onsi/gomega" - corev1 "k8s.io/api/core/v1" - rbacv1 "k8s.io/api/rbac/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" . "github.com/opendatahub-io/distributed-workloads/tests/common/support" ) -// EnsureTrainerClusterReady verifies cluster dependencies required by Kubeflow Trainer tests. -func EnsureTrainerClusterReady(t *testing.T, test Test) { - t.Helper() - // JobSet CRD present - // TODO: Remove once trainer is part of installation - if out, err := exec.Command("kubectl", "get", "crd", "jobsets.jobset.x-k8s.io").CombinedOutput(); err != nil { - t.Fatalf("JobSet CRD missing: %v\n%s", err, string(out)) - } - // Trainer controller deployment available - // TODO: Remove once trainer is part of installation - if out, err := exec.Command("kubectl", "-n", "opendatahub", "wait", "--for=condition=available", "--timeout=180s", "deploy/kubeflow-trainer-controller-manager").CombinedOutput(); err != nil { - t.Fatalf("Trainer controller not available: %v\n%s", err, string(out)) - } - // Required ClusterTrainingRuntimes present - runtimes, err := test.Client().Trainer().TrainerV1alpha1().ClusterTrainingRuntimes().List(test.Ctx(), metav1.ListOptions{}) - test.Expect(err).NotTo(HaveOccurred(), "Failed to list ClusterTrainingRuntimes") - found := map[string]bool{} - for _, rt := range runtimes.Items { - found[rt.Name] = true - } - // TODO: Extend / tweak with universal image runtime once available - for _, name := range []string{"torch-cuda-241", "torch-cuda-251", "torch-rocm-241", "torch-rocm-251"} { - test.Expect(found[name]).To(BeTrue(), fmt.Sprintf("Expected ClusterTrainingRuntime '%s' not found", name)) - } -} - -// EnsureNotebookRBAC sets up the Notebook ServiceAccount and RBAC so that notebooks can -// read ClusterTrainingRuntimes (cluster-scoped), and create/read TrainJobs and pod logs in the namespace. -func EnsureNotebookRBAC(t *testing.T, test Test, namespace string) { +// EnsureNotebookServiceAccount ensures the Notebook ServiceAccount exists in the target namespace. +// This avoids webhook/controller failures when creating the Notebook CR. 
+func EnsureNotebookServiceAccount(t *testing.T, test Test, namespace string) { t.Helper() - - // Ensure ServiceAccount exists saName := "jupyter-nb-kube-3aadmin" sa := &corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: saName, Namespace: namespace}} - _, _ = test.Client().Core().CoreV1().ServiceAccounts(namespace).Create(test.Ctx(), sa, metav1.CreateOptions{}) - // Get current SA (created or existing) - saObj, err := test.Client().Core().CoreV1().ServiceAccounts(namespace).Get(test.Ctx(), saName, metav1.GetOptions{}) - test.Expect(err).NotTo(HaveOccurred()) - - // Cluster-scoped read for ClusterTrainingRuntimes - ctrRead := CreateClusterRole(test, []rbacv1.PolicyRule{ - {APIGroups: []string{"trainer.kubeflow.org"}, Resources: []string{"clustertrainingruntimes"}, Verbs: []string{"get", "list", "watch"}}, - }) - CreateClusterRoleBinding(test, saObj, ctrRead) - - // Namespace Role for TrainJobs and pods/log access - role := CreateRole(test, namespace, []rbacv1.PolicyRule{ - {APIGroups: []string{"trainer.kubeflow.org"}, Resources: []string{"trainjobs", "trainjobs/status"}, Verbs: []string{"get", "list", "watch", "create", "update", "patch", "delete"}}, - {APIGroups: []string{""}, Resources: []string{"pods", "pods/log"}, Verbs: []string{"get", "list", "watch"}}, - }) - CreateRoleBinding(test, namespace, saObj, role) + _, err := test.Client().Core().CoreV1().ServiceAccounts(namespace).Create(test.Ctx(), sa, metav1.CreateOptions{}) + if err != nil && !apierrors.IsAlreadyExists(err) { + t.Fatalf("Failed to create ServiceAccount %s/%s: %v", namespace, saName, err) + } } diff --git a/tests/trainer/utils/utils_notebook.go b/tests/trainer/utils/utils_notebook.go index 6d76b1ece..c2eb9d7dd 100644 --- a/tests/trainer/utils/utils_notebook.go +++ b/tests/trainer/utils/utils_notebook.go @@ -31,20 +31,6 @@ import ( . "github.com/opendatahub-io/distributed-workloads/tests/common/support" ) -// BuildPapermillShellCmd builds a shell command to execute a notebook with papermill and write a SUCCESS/FAILURE marker. -// extraPipPackages, if provided, are installed alongside papermill. -func BuildPapermillShellCmd(notebookName string, marker string, extraPipPackages []string) string { - pipLine := "pip install --quiet --no-cache-dir papermill" - if len(extraPipPackages) > 0 { - pipLine = pipLine + " " + strings.Join(extraPipPackages, " ") - } - return fmt.Sprintf( - "set -e; %s; if papermill -k python3 /opt/app-root/notebooks/%s /opt/app-root/src/out.ipynb --log-output; "+ - "then echo 'SUCCESS' > %s; else echo 'FAILURE' > %s; fi; sleep infinity", - pipLine, notebookName, marker, marker, - ) -} - // CreateNotebookFromBytes creates a ConfigMap with the notebook and a Notebook CR to run it. func CreateNotebookFromBytes(test Test, namespace *corev1.Namespace, userToken string, notebookName string, notebookBytes []byte, command []string, numGpus int, containerSize common.ContainerSize) *corev1.ConfigMap { cm := CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: notebookBytes}) @@ -63,18 +49,19 @@ func WaitForNotebookPodRunning(test Test, namespace string) (string, string) { return pods[0].Name, pods[0].Spec.Containers[0].Name } -// PollNotebookCompletionMarker polls the given marker file inside the notebook pod until SUCCESS/FAILURE or timeout. -func PollNotebookCompletionMarker(test Test, namespace, podName, containerName, marker string, timeout time.Duration) error { +// PollNotebookLogsForStatus polls the notebook container logs until a definitive NOTEBOOK_STATUS line appears or timeout. 
+func PollNotebookLogsForStatus(test Test, namespace, podName, containerName string, timeout time.Duration) error { var finalErr error test.Eventually(func() bool { - out, err := exec.Command("kubectl", "-n", namespace, "exec", podName, "-c", containerName, "--", "cat", marker).CombinedOutput() + out, err := exec.Command("kubectl", "-n", namespace, "logs", podName, "-c", containerName, "--tail=2000").CombinedOutput() if err != nil { return false } - switch strings.TrimSpace(string(out)) { - case "SUCCESS": + logs := string(out) + switch { + case strings.Contains(logs, "NOTEBOOK_STATUS: SUCCESS"): return true - case "FAILURE": + case strings.Contains(logs, "NOTEBOOK_STATUS: FAILURE"): finalErr = fmt.Errorf("Notebook execution failed") return true default: From 8ae0dee4fb04f28ec73fc8c98342cd9222c887b5 Mon Sep 17 00:00:00 2001 From: MStokluska Date: Fri, 21 Nov 2025 11:36:07 +0100 Subject: [PATCH 4/5] address comments --- tests/trainer/resources/mnist.ipynb | 11 +++-------- tests/trainer/utils/utils_notebook.go | 23 ++++++++++++++++------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/trainer/resources/mnist.ipynb b/tests/trainer/resources/mnist.ipynb index c44fe2953..032da559b 100644 --- a/tests/trainer/resources/mnist.ipynb +++ b/tests/trainer/resources/mnist.ipynb @@ -376,7 +376,7 @@ " print(\"[notebook] Internet download completed successfully\")\n", " except Exception as e:\n", " print(f\"[notebook] Internet download failed: {e}\")\n", - " print(\"[notebook] WARNING: Dataset may be incomplete!\")\n", + " raise RuntimeError(\"Internet download failed; aborting test\")\n", "else:\n", " print(\"[notebook] Dataset files already present, skipping download\")\n" ] @@ -430,13 +430,8 @@ "\n", "try:\n", " torch_runtime = client.get_runtime(\"torch-distributed\")\n", - "except Exception:\n", - " torch_runtime = next(\n", - " (r for r in client.list_runtimes() if r.name == \"torch-distributed\"),\n", - " None,\n", - " )\n", - " if torch_runtime is None:\n", - " raise RuntimeError(\"Runtime 'torch-distributed' not found\")" + "except Exception as e:\n", + " raise RuntimeError(\"Runtime 'torch-distributed' not found or not accessible\") from e" ] }, { diff --git a/tests/trainer/utils/utils_notebook.go b/tests/trainer/utils/utils_notebook.go index c2eb9d7dd..566d90e80 100644 --- a/tests/trainer/utils/utils_notebook.go +++ b/tests/trainer/utils/utils_notebook.go @@ -18,7 +18,6 @@ package trainer import ( "fmt" - "os/exec" "strings" "time" @@ -52,21 +51,31 @@ func WaitForNotebookPodRunning(test Test, namespace string) (string, string) { // PollNotebookLogsForStatus polls the notebook container logs until a definitive NOTEBOOK_STATUS line appears or timeout. 
func PollNotebookLogsForStatus(test Test, namespace, podName, containerName string, timeout time.Duration) error { var finalErr error + + // Tail last N lines similar to the previous kubectl --tail + var tail int64 = 2000 + getLogs := PodLog(test, namespace, podName, corev1.PodLogOptions{ + Container: containerName, + TailLines: &tail, + }) + + // Track failure signal while polling + sawFailure := false test.Eventually(func() bool { - out, err := exec.Command("kubectl", "-n", namespace, "logs", podName, "-c", containerName, "--tail=2000").CombinedOutput() - if err != nil { - return false - } - logs := string(out) + logs := getLogs(test) switch { case strings.Contains(logs, "NOTEBOOK_STATUS: SUCCESS"): return true case strings.Contains(logs, "NOTEBOOK_STATUS: FAILURE"): - finalErr = fmt.Errorf("Notebook execution failed") + sawFailure = true return true default: return false } }, timeout).Should(BeTrue(), "Notebook did not reach definitive state") + + if sawFailure { + finalErr = fmt.Errorf("Notebook execution failed") + } return finalErr } From 098deb683f659d8651ca19b506e1113e132f0700 Mon Sep 17 00:00:00 2001 From: MStokluska Date: Mon, 24 Nov 2025 18:26:14 +0100 Subject: [PATCH 5/5] address comments and make training CPU only --- tests/trainer/kubeflow_sdk_test.go | 2 +- tests/trainer/resources/mnist.ipynb | 16 +++++++--------- tests/trainer/sdk_tests/fashion_mnist_tests.go | 5 ++--- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/trainer/kubeflow_sdk_test.go b/tests/trainer/kubeflow_sdk_test.go index f1de0753a..a43a66768 100644 --- a/tests/trainer/kubeflow_sdk_test.go +++ b/tests/trainer/kubeflow_sdk_test.go @@ -23,7 +23,7 @@ import ( sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests" ) -func TestKubeflowSDK_Sanity(t *testing.T) { +func TestKubeflowSdkSanity(t *testing.T) { Tags(t, Sanity) sdktests.RunFashionMnistCpuDistributedTraining(t) } diff --git a/tests/trainer/resources/mnist.ipynb b/tests/trainer/resources/mnist.ipynb index 032da559b..7c81c8027 100644 --- a/tests/trainer/resources/mnist.ipynb +++ b/tests/trainer/resources/mnist.ipynb @@ -44,12 +44,13 @@ " x = self.fc2(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", - " # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n", - " device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n", - " print(f\"Using Device: {device}, Backend: {backend}\")\n", + " # Force CPU-only for this test to avoid accidental NCCL/GPU usage\n", + " backend = \"gloo\"\n", + " device = torch.device(\"cpu\")\n", + " print(f\"Using Device: cpu, Backend: {backend}\")\n", "\n", " # Setup PyTorch distributed.\n", - " local_rank = int(os.getenv(\"PET_NODE_RANK\", 0))\n", + " local_rank = int(os.getenv(\"LOCAL_RANK\") or os.getenv(\"PET_NODE_RANK\") or 0)\n", " dist.init_process_group(backend=backend)\n", " print(\n", " \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n", @@ -60,19 +61,16 @@ " )\n", "\n", " # Create the model and load it into the device.\n", - " device = torch.device(f\"{device}:{local_rank}\")\n", " model = nn.parallel.DistributedDataParallel(Net().to(device))\n", " optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n", "\n", - " \n", " # Prefer shared PVC if present; else fallback to internet download (rank 0 only)\n", " from urllib.parse import urlparse\n", " import gzip, shutil\n", - " \n", + "\n", " pvc_root = \"/mnt/shared\"\n", " pvc_raw = os.path.join(pvc_root, 
\"FashionMNIST\", \"raw\")\n", "\n", - "\n", " use_pvc = os.path.isdir(pvc_raw) and any(os.scandir(pvc_raw))\n", "\n", " if not use_pvc:\n", @@ -126,7 +124,7 @@ "\n", " # Iterate over mini-batches from the training set\n", " for batch_idx, (inputs, labels) in enumerate(train_loader):\n", - " # Copy the data to the GPU device if available\n", + " # Move the data to the selected device\n", " inputs, labels = inputs.to(device), labels.to(device)\n", " # Forward pass\n", " outputs = model(inputs)\n", diff --git a/tests/trainer/sdk_tests/fashion_mnist_tests.go b/tests/trainer/sdk_tests/fashion_mnist_tests.go index 0cfef91ce..123ff39c6 100644 --- a/tests/trainer/sdk_tests/fashion_mnist_tests.go +++ b/tests/trainer/sdk_tests/fashion_mnist_tests.go @@ -109,8 +109,7 @@ func RunFashionMnistCpuDistributedTraining(t *testing.T) { podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name) // Poll logs to check if the notebook execution completed successfully - if err := trainerutils.PollNotebookLogsForStatus(test, namespace.Name, podName, containerName, support.TestTimeoutDouble); err != nil { - test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE") - } + err = trainerutils.PollNotebookLogsForStatus(test, namespace.Name, podName, containerName, support.TestTimeoutDouble) + test.Expect(err).ShouldNot(HaveOccurred(), "Notebook execution reported FAILURE") }