From 68132c38de855be80f0c661675a91b936011390b Mon Sep 17 00:00:00 2001 From: Fiona-Waters Date: Fri, 15 May 2026 15:47:08 +0100 Subject: [PATCH 1/3] Add E2E for GRPO/Art training hub implementation Signed-off-by: Fiona-Waters --- tests/common/support/environment.go | 6 + tests/trainer/kubeflow_sdk_test.go | 6 + tests/trainer/resources/grpo.ipynb | 437 ++++++++++++++++++ .../sdk_tests/grpo_traininghub_tests.go | 136 ++++++ 4 files changed, 585 insertions(+) create mode 100644 tests/trainer/resources/grpo.ipynb create mode 100644 tests/trainer/sdk_tests/grpo_traininghub_tests.go diff --git a/tests/common/support/environment.go b/tests/common/support/environment.go index fd441923c..04aed10a1 100644 --- a/tests/common/support/environment.go +++ b/tests/common/support/environment.go @@ -61,6 +61,7 @@ const ( storageBucketOsftDir = "AWS_STORAGE_BUCKET_OSFT_DIR" storageBucketSftDir = "AWS_STORAGE_BUCKET_SFT_DIR" storageBucketLoraDir = "AWS_STORAGE_BUCKET_LORA_DIR" + storageBucketGrpoDir = "AWS_STORAGE_BUCKET_GRPO_DIR" // Name of existing namespace to be used for test testNamespaceNameEnvVar = "TEST_NAMESPACE_NAME" @@ -224,6 +225,11 @@ func GetStorageBucketSftDir() (string, bool) { return storage_bucket_sft_dir, exists } +func GetStorageBucketGrpoDir() (string, bool) { + storage_bucket_grpo_dir, exists := os.LookupEnv(storageBucketGrpoDir) + return storage_bucket_grpo_dir, exists +} + func GetPipIndexURL() string { return lookupEnvOrDefault(pipIndexURL, "https://pypi.python.org/simple") } diff --git a/tests/trainer/kubeflow_sdk_test.go b/tests/trainer/kubeflow_sdk_test.go index ae940552f..db88ac931 100644 --- a/tests/trainer/kubeflow_sdk_test.go +++ b/tests/trainer/kubeflow_sdk_test.go @@ -64,6 +64,12 @@ func TestLoraTrainingHubSingleNodeSingleGPU(t *testing.T) { sdktests.RunLoraTrainingHubMultiGpuDistributedTraining(t, 1) } +// TestGrpoTrainingHubSingleNodeSingleGPU tests GRPO (RL) training on a single node with a single GPU +func TestGrpoTrainingHubSingleNodeSingleGPU(t *testing.T) { + Tags(t, KftoCuda, Gpu(support.NVIDIA)) + sdktests.RunGrpoTrainingHubMultiGpuDistributedTraining(t, 1) +} + // Multi-node, multi-GPU tests (2 nodes, 1 GPU each) // TestOsftTrainingHubMultiNodeMultiGPU tests OSFT training using TrainingHubTrainer diff --git a/tests/trainer/resources/grpo.ipynb b/tests/trainer/resources/grpo.ipynb new file mode 100644 index 000000000..393594f05 --- /dev/null +++ b/tests/trainer/resources/grpo.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install datasets --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "import sys\n", + "import time\n", + "from io import StringIO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from dotenv import load_dotenv\n", + " load_dotenv()\n", + "except ImportError:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from kubernetes import client as k8s, config as k8s_config\n", + "\n", + "api_server = os.getenv(\"OPENSHIFT_API_URL\")\n", + "token = os.getenv(\"NOTEBOOK_USER_TOKEN\")\n", + "if not api_server or not token:\n", + " raise RuntimeError(\"OPENSHIFT_API_URL and NOTEBOOK_USER_TOKEN environment variables are required\")\n", + "PVC_NAME = os.getenv(\"SHARED_PVC_NAME\", \"\")\n", + "\n", + "if not PVC_NAME:\n", + " raise RuntimeError(\"SHARED_PVC_NAME environment variable is required\")\n", + "\n", + "configuration = k8s.Configuration()\n", + "configuration.host = api_server\n", + "configuration.verify_ssl = False\n", + "configuration.api_key = {\"authorization\": f\"Bearer {token}\"}\n", + "api_client = k8s.ApiClient(configuration)\n", + "\n", + "PVC_MOUNT_PATH = \"/opt/app-root/src\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import gzip\n", + "import shutil\n", + "import time\n", + "import socket\n", + "import json\n", + "\n", + "try:\n", + " import s3fs\n", + " HAS_S3FS = True\n", + "except ImportError:\n", + " HAS_S3FS = False\n", + "\n", + "socket.setdefaulttimeout(10)\n", + "\n", + "PVC_NOTEBOOK_PATH = \"/opt/app-root/src/\"\n", + "DATASET_ROOT_NOTEBOOK = PVC_NOTEBOOK_PATH\n", + "MODEL_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, \"Qwen\", \"Qwen3-4B\")\n", + "os.makedirs(MODEL_DIR, exist_ok=True)\n", + "\n", + "s3_endpoint = os.getenv(\"AWS_DEFAULT_ENDPOINT\", \"\")\n", + "s3_access_key = os.getenv(\"AWS_ACCESS_KEY_ID\", \"\")\n", + "s3_secret_key = os.getenv(\"AWS_SECRET_ACCESS_KEY\", \"\")\n", + "s3_bucket = os.getenv(\"AWS_STORAGE_BUCKET\", \"\")\n", + "s3_prefix = os.getenv(\"AWS_STORAGE_BUCKET_GRPO_DIR\", \"\")\n", + "\n", + "data_download_successful = False\n", + "\n", + "if HAS_S3FS and s3_endpoint and s3_bucket:\n", + " try:\n", + " endpoint_url = (\n", + " s3_endpoint\n", + " if s3_endpoint.startswith(\"http\")\n", + " else f\"https://{s3_endpoint}\"\n", + " )\n", + " prefix = (s3_prefix or \"\").strip(\"/\")\n", + "\n", + " print(\n", + " f\"[notebook] S3 configured: \"\n", + " f\"endpoint={endpoint_url}, bucket={s3_bucket}, prefix={prefix or ''}\"\n", + " )\n", + "\n", + " fs = s3fs.S3FileSystem(\n", + " key=s3_access_key,\n", + " secret=s3_secret_key,\n", + " endpoint_url=endpoint_url,\n", + " use_ssl=endpoint_url.startswith(\"https\"),\n", + " config_kwargs={\"signature_version\": \"s3v4\"},\n", + " client_kwargs={\"verify\": False},\n", + " )\n", + "\n", + " remote_path = f\"{s3_bucket}/{prefix}\" if prefix else s3_bucket\n", + " pulled_any = False\n", + " file_count = 0\n", + "\n", + " print(f\"[notebook] Starting S3 download from prefix: {prefix}\")\n", + " for remote_file in fs.find(remote_path):\n", + " file_count += 1\n", + "\n", + " if remote_file.endswith(\"/\"):\n", + " continue\n", + "\n", + " rel = remote_file[len(remote_path):].lstrip(\"/\")\n", + " if not rel:\n", + " continue\n", + " print(f\"[notebook] Processing rel={rel}\")\n", + "\n", + " if \"qwen\" in rel.lower() or (prefix and any(rel.endswith(ext) for ext in [\".bin\", \".json\", \".model\", \".safetensors\", \".txt\"])):\n", + " dst = os.path.join(MODEL_DIR, rel.split(\"Qwen3-4B/\")[-1] if \"Qwen3-4B\" in rel else os.path.basename(rel))\n", + " print(f\"[notebook] Routing to model dir: {dst}\")\n", + " else:\n", + " dst = os.path.join(DATASET_ROOT_NOTEBOOK, rel)\n", + " print(f\"[notebook] Routing to default dir: {dst}\")\n", + "\n", + " os.makedirs(os.path.dirname(dst), exist_ok=True)\n", + "\n", + " if not os.path.exists(dst):\n", + " print(f\"[notebook] Downloading s3://{remote_file} -> {dst}\")\n", + " t0 = time.time()\n", + " fs.get(remote_file, dst)\n", + " print(f\"[notebook] DONE in {time.time() - t0:.2f}s\")\n", + " pulled_any = True\n", + " else:\n", + " print(f\"[notebook] Skipping existing file {dst}\")\n", + " pulled_any = True\n", + "\n", + " if dst.endswith(\".gz\") and os.path.exists(dst):\n", + " out_path = os.path.splitext(dst)[0]\n", + " if not os.path.exists(out_path):\n", + " print(f\"[notebook] Decompressing {dst} -> {out_path}\")\n", + " try:\n", + " with gzip.open(dst, \"rb\") as f_in, open(out_path, \"wb\") as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " except Exception as e:\n", + " print(f\"[notebook] Failed to decompress {dst}: {e}\")\n", + " else:\n", + " try:\n", + " os.remove(dst)\n", + " except Exception:\n", + " pass\n", + "\n", + " if pulled_any:\n", + " if os.path.exists(MODEL_DIR) and os.listdir(MODEL_DIR):\n", + " print(f\"[notebook] S3 download successful. Processed {file_count} files\")\n", + " data_download_successful = True\n", + " else:\n", + " print(f\"[notebook] S3 downloaded {file_count} files but model not found, will try HuggingFace fallback\")\n", + " else:\n", + " print(f\"[notebook] S3 download found no files to download\")\n", + "\n", + " except Exception as e:\n", + " print(f\"[notebook] S3 fetch failed: {e}\")\n", + " import traceback\n", + " traceback.print_exc()\n", + " print(\"[notebook] Will attempt HuggingFace fallback...\")\n", + "else:\n", + " if not HAS_S3FS:\n", + " print(\"[notebook] S3 not available: s3fs not installed\")\n", + " else:\n", + " print(\"[notebook] S3 not configured (missing endpoint or bucket env vars)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import snapshot_download\n", + "\n", + "hf_token = os.getenv(\"HUGGINGFACE_HUB_TOKEN\")\n", + "\n", + "if os.path.exists(MODEL_DIR) and os.listdir(MODEL_DIR):\n", + " print(f\"Using local model from S3: {MODEL_DIR}\")\n", + "else:\n", + " print(\"[notebook] Model not found in S3, downloading from HuggingFace...\")\n", + " snapshot_download(\n", + " repo_id=\"Qwen/Qwen3-4B\",\n", + " local_dir=MODEL_DIR,\n", + " token=hf_token,\n", + " resume_download=True,\n", + " local_dir_use_symlinks=False,\n", + " )\n", + " print(f\"Model downloaded to: {MODEL_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "LOCAL_MODEL_PATH = \"/opt/app-root/src/Qwen/Qwen3-4B\"\n", + "\n", + "# GRPO uses Agent-Ark/Toucan-1.5M (HuggingFace dataset downloaded at training time).\n", + "# For the E2E test we use minimal iterations to keep wall-clock time reasonable.\n", + "DATA_PATH = \"Agent-Ark/Toucan-1.5M\"\n", + "DATA_CONFIG = \"Qwen3\"\n", + "\n", + "params = {\n", + " # Model and data\n", + " 'model_path': LOCAL_MODEL_PATH,\n", + " 'data_path': DATA_PATH,\n", + " 'data_config': DATA_CONFIG,\n", + " 'ckpt_output_dir': '/opt/app-root/src/grpo-output',\n", + " 'backend': 'art',\n", + "\n", + " # GRPO hyperparameters — small values for fast E2E validation\n", + " 'num_iterations': 2,\n", + " 'group_size': 4,\n", + " 'prompt_batch_size': 10,\n", + " 'n_train': 50,\n", + " 'learning_rate': 1e-5,\n", + "\n", + " # LoRA\n", + " 'lora_r': 16,\n", + " 'lora_alpha': 8,\n", + "\n", + " # vLLM GPU memory sharing\n", + " 'gpu_memory_utilization': 0.45,\n", + "\n", + " # Single node (overridden by NNODES env var if set)\n", + " 'nnodes': int(os.getenv(\"NNODES\", \"1\")),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from kubeflow.trainer import TrainerClient\n", + "from kubeflow.trainer.rhai import TrainingHubAlgorithms\n", + "from kubeflow.trainer.rhai import TrainingHubTrainer\n", + "from kubeflow.common.types import KubernetesBackendConfig\n", + "\n", + "backend_cfg = KubernetesBackendConfig(client_configuration=api_client.configuration)\n", + "client = TrainerClient(backend_cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "training_runtime_name = os.getenv(\"TRAINING_RUNTIME\")\n", + "\n", + "if not training_runtime_name:\n", + " raise RuntimeError(\"TRAINING_RUNTIME environment variable is required\")\n", + "\n", + "th_runtime = None\n", + "for runtime in client.list_runtimes():\n", + " if runtime.name == training_runtime_name:\n", + " th_runtime = runtime\n", + " print(\"Found runtime: \" + str(th_runtime))\n", + " break\n", + "\n", + "if th_runtime is None:\n", + " raise RuntimeError(f\"Required runtime '{training_runtime_name}' not found\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from kubeflow.trainer.options.kubernetes import (\n", + " PodTemplateOverrides,\n", + " PodTemplateOverride,\n", + " PodSpecOverride,\n", + " ContainerOverride,\n", + ")\n", + "\n", + "cache_root = \"/opt/app-root/src/.cache/huggingface\"\n", + "triton_cache = \"/tmp/.triton\"\n", + "\n", + "job_name = client.train(\n", + " trainer=TrainingHubTrainer(\n", + " algorithm=TrainingHubAlgorithms.LORA_GRPO,\n", + " func_args=params,\n", + " env={\n", + " \"HF_HOME\": cache_root,\n", + " \"TRITON_CACHE_DIR\": triton_cache,\n", + " \"XDG_CACHE_HOME\": \"/opt/app-root/src/.cache\",\n", + " \"NCCL_DEBUG\": \"INFO\",\n", + " \"TRANSFORMERS_ATTN_BACKEND\": \"sdpa\",\n", + " },\n", + " resources_per_node={\n", + " \"cpu\": 8,\n", + " \"memory\": \"64Gi\",\n", + " \"nvidia.com/gpu\": 1\n", + " },\n", + " ),\n", + " options=[\n", + " PodTemplateOverrides(\n", + " PodTemplateOverride(\n", + " target_jobs=[\"node\"],\n", + " spec=PodSpecOverride(\n", + " volumes=[\n", + " {\"name\": \"work\", \"persistentVolumeClaim\": {\"claimName\": PVC_NAME}},\n", + " {\"name\": \"dshm\", \"emptyDir\": {\"medium\": \"Memory\"}},\n", + " ],\n", + " containers=[\n", + " ContainerOverride(\n", + " name=\"node\",\n", + " volume_mounts=[\n", + " {\"name\": \"work\", \"mountPath\": \"/opt/app-root/src\", \"readOnly\": False},\n", + " {\"name\": \"dshm\", \"mountPath\": \"/dev/shm\"},\n", + " ],\n", + " )\n", + " ],\n", + " ),\n", + " )\n", + " )\n", + " ],\n", + " runtime=th_runtime,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Wait for Running, then wait for Complete or Failed.\n", + "# GRPO takes longer than SFT/LoRA due to vLLM rollout + RL training loop.\n", + "client.wait_for_job_status(name=job_name, status={\"Running\"}, timeout=600)\n", + "client.wait_for_job_status(name=job_name, status={\"Complete\", \"Failed\"}, timeout=3600)\n", + "\n", + "job = client.get_job(name=job_name)\n", + "pod_logs = client.get_job_logs(name=job_name, follow=False)\n", + "\n", + "logs = []\n", + "for log_line in pod_logs:\n", + " logs.extend(str(log_line).splitlines())\n", + "\n", + "log_text = \"\\n\".join(logs)\n", + "\n", + "print(f\"Training job final status: {job.status}\")\n", + "\n", + "if job.status == \"Failed\":\n", + " print(f\"ERROR: Training job '{job_name}' has Failed status\")\n", + " print(\"Last 30 lines of logs:\")\n", + " for line in logs[-30:]:\n", + " print(line)\n", + " raise RuntimeError(f\"Training job '{job_name}' failed\")\n", + "\n", + "if \"[PY] LORA_GRPO training complete. Result=\" not in log_text:\n", + " print(f\"ERROR: Training completion message not found in logs\")\n", + " print(\"Last 50 lines of logs:\")\n", + " for line in logs[-50:]:\n", + " print(line)\n", + " raise RuntimeError(f\"Training did not complete successfully - missing completion message\")\n", + "\n", + "print(f\"Training job '{job_name}' completed successfully\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for c in client.get_job(name=job_name).steps:\n", + " print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "logs = client.get_job_logs(name=job_name, follow=False)\n", + "\n", + "logs = list(logs)\n", + "log_text = \"\\n\".join(str(line) for line in logs)\n", + "print(log_text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client.delete_job(job_name)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/trainer/sdk_tests/grpo_traininghub_tests.go b/tests/trainer/sdk_tests/grpo_traininghub_tests.go new file mode 100644 index 000000000..d4144c49b --- /dev/null +++ b/tests/trainer/sdk_tests/grpo_traininghub_tests.go @@ -0,0 +1,136 @@ +/* +Copyright 2025. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sdk_tests + +import ( + "fmt" + "os" + "testing" + + . "github.com/onsi/gomega" + + corev1 "k8s.io/api/core/v1" + + common "github.com/opendatahub-io/distributed-workloads/tests/common" + support "github.com/opendatahub-io/distributed-workloads/tests/common/support" + trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils" +) + +const ( + grpoNotebookName = "grpo.ipynb" + grpoNotebookPath = "resources/" + grpoNotebookName +) + +// Multi-GPU - Distributed Training with LORA_GRPO and TrainingHubTrainer +func RunGrpoTrainingHubMultiGpuDistributedTraining(t *testing.T, nnodes int) { + test := support.With(t) + + // Create a new test namespace + namespace := test.NewTestNamespace() + + // Ensure Notebook ServiceAccount exists (no extra RBAC) + trainerutils.EnsureNotebookServiceAccount(t, test, namespace.Name) + + // RBACs setup + userName := common.GetNotebookUserName(test) + userToken := common.GenerateNotebookUserToken(test) + support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin") + trainerutils.CreateUserClusterRoleBindingForTrainerRuntimes(test, userName) + + // Create ConfigMap with notebook and install script + localPath := grpoNotebookPath + nb, err := os.ReadFile(localPath) + test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath)) + + installScript, err := os.ReadFile(installScriptPath) + test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read install script: %s", installScriptPath)) + + cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{ + grpoNotebookName: nb, + installKubeflowScript: installScript, + }) + + // Build command with parameters and pinned deps, and print definitive status line to logs + endpoint, endpointOK := support.GetStorageBucketDefaultEndpoint() + accessKey, _ := support.GetStorageBucketAccessKeyId() + secretKey, _ := support.GetStorageBucketSecretKey() + bucket, bucketOK := support.GetStorageBucketName() + prefix, _ := support.GetStorageBucketGrpoDir() + if !endpointOK { + endpoint = "" + } + if !bucketOK { + bucket = "" + } + + // Create RWX PVC for shared dataset and pass the claim name to the notebook + storageClass, err := support.GetRWXStorageClass(test) + test.Expect(err).NotTo(HaveOccurred(), "Failed to find an RWX supporting StorageClass") + rwxPvc := support.CreatePersistentVolumeClaim( + test, + namespace.Name, + "20Gi", + support.AccessModes(corev1.ReadWriteMany), + support.StorageClassName(storageClass.Name), + ) + + sdkInstallExports := buildKubeflowInstallExports() + shellCmd := fmt.Sprintf( + "set -e; "+ + "export IPYTHONDIR='/tmp/.ipython'; "+ + "export OPENSHIFT_API_URL=%s; export NOTEBOOK_USER_TOKEN=%s; "+ + "export NOTEBOOK_NAMESPACE=%s; "+ + "export SHARED_PVC_NAME=%s; "+ + "export AWS_DEFAULT_ENDPOINT=%s; export AWS_ACCESS_KEY_ID=%s; "+ + "export AWS_SECRET_ACCESS_KEY=%s; "+ + "export AWS_STORAGE_BUCKET=%s; "+ + "export AWS_STORAGE_BUCKET_GRPO_DIR=%s; "+ + "export TRAINING_RUNTIME=%s; "+ + "export NNODES='%d'; "+ + "export GPU_TYPE='nvidia'; "+ + "%s"+ + "python -m pip install --quiet --no-cache-dir --break-system-packages papermill && "+ + "python /opt/app-root/notebooks/%s && "+ + "if python -m papermill -k python3 /opt/app-root/notebooks/%s /opt/app-root/src/out.ipynb --log-output; "+ + "then echo 'NOTEBOOK_STATUS: SUCCESS'; else echo 'NOTEBOOK_STATUS: FAILURE'; fi; sleep infinity", + shellQuote(support.GetOpenShiftApiUrl(test)), shellQuote(userToken), shellQuote(namespace.Name), shellQuote(rwxPvc.Name), + shellQuote(endpoint), shellQuote(accessKey), shellQuote(secretKey), shellQuote(bucket), shellQuote(prefix), + shellQuote(trainerutils.DefaultTrainingHubRuntimeCUDA), + nnodes, + sdkInstallExports, + installKubeflowScript, + grpoNotebookName, + ) + command := []string{"/bin/sh", "-c", shellCmd} + + // GRPO requires more memory than SFT/LoRA due to vLLM running alongside training + common.CreateNotebook(test, namespace, userToken, command, cm.Name, grpoNotebookName, 0, rwxPvc, common.ContainerSizeMedium, common.GetRecommendedNotebookImageFromImageStream(test, common.NotebookImageStreamTrainingHubCUDA)) + + // Cleanup - use longer timeout for GPU tests due to large runtime images + defer func() { + common.DeleteNotebook(test, namespace) + test.Eventually(common.Notebooks(test, namespace), support.TestTimeoutGpuProvisioning).Should(HaveLen(0)) + }() + + // Wait for the Notebook Pod and get pod/container names + podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name) + + // Poll logs to check if the notebook execution completed successfully + // GRPO training takes longer than SFT/LoRA due to generation + RL loop + err = trainerutils.PollNotebookLogsForStatus(test, namespace.Name, podName, containerName, support.TestTimeoutDouble) + test.Expect(err).ShouldNot(HaveOccurred(), "Notebook execution reported FAILURE") +} From ec8a9e29cee828601ea9dce2d3fe07db98bc0f34 Mon Sep 17 00:00:00 2001 From: Slowlybomb Date: Mon, 15 Jun 2026 14:51:39 +0100 Subject: [PATCH 2/3] Add GRPO training test and enhance GRPO notebook - Introduced TestGrpoTrainingHubMultiNodeMultiGPU to validate GRPO training with multiple nodes and GPUs. - Updated grpo.ipynb to clarify SDK installation and added checkpoint verification and metrics logging checks for improved robustness. Tested Signed-off-by: Slowlybomb --- tests/trainer/resources/grpo.ipynb | 221 ++++++++++++------ .../sdk_tests/grpo_traininghub_tests.go | 2 +- 2 files changed, 152 insertions(+), 71 deletions(-) diff --git a/tests/trainer/resources/grpo.ipynb b/tests/trainer/resources/grpo.ipynb index 393594f05..48f3f178d 100644 --- a/tests/trainer/resources/grpo.ipynb +++ b/tests/trainer/resources/grpo.ipynb @@ -2,44 +2,39 @@ "cells": [ { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "%pip install datasets --quiet" - ] + "# kubeflow SDK is installed by test harness via install_kubeflow.py" + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "import logging\n", "import os\n", - "import sys\n", - "import time\n", - "from io import StringIO" - ] + "import time" + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "try:\n", " from dotenv import load_dotenv\n", " load_dotenv()\n", "except ImportError:\n", " pass" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from kubernetes import client as k8s, config as k8s_config\n", "\n", @@ -59,18 +54,16 @@ "api_client = k8s.ApiClient(configuration)\n", "\n", "PVC_MOUNT_PATH = \"/opt/app-root/src\"" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "import os\n", "import gzip\n", "import shutil\n", - "import time\n", "import socket\n", "import json\n", "\n", @@ -187,13 +180,13 @@ " print(\"[notebook] S3 not available: s3fs not installed\")\n", " else:\n", " print(\"[notebook] S3 not configured (missing endpoint or bucket env vars)\")" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "\n", @@ -207,17 +200,15 @@ " repo_id=\"Qwen/Qwen3-4B\",\n", " local_dir=MODEL_DIR,\n", " token=hf_token,\n", - " resume_download=True,\n", - " local_dir_use_symlinks=False,\n", " )\n", " print(f\"Model downloaded to: {MODEL_DIR}\")" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "LOCAL_MODEL_PATH = \"/opt/app-root/src/Qwen/Qwen3-4B\"\n", "\n", @@ -251,13 +242,13 @@ " # Single node (overridden by NNODES env var if set)\n", " 'nnodes': int(os.getenv(\"NNODES\", \"1\")),\n", "}" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from kubeflow.trainer import TrainerClient\n", "from kubeflow.trainer.rhai import TrainingHubAlgorithms\n", @@ -266,13 +257,13 @@ "\n", "backend_cfg = KubernetesBackendConfig(client_configuration=api_client.configuration)\n", "client = TrainerClient(backend_cfg)" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "training_runtime_name = os.getenv(\"TRAINING_RUNTIME\")\n", "\n", @@ -288,13 +279,13 @@ "\n", "if th_runtime is None:\n", " raise RuntimeError(f\"Required runtime '{training_runtime_name}' not found\")" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from kubeflow.trainer.options.kubernetes import (\n", " PodTemplateOverrides,\n", @@ -347,78 +338,168 @@ " ],\n", " runtime=th_runtime,\n", ")" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "# Wait for Running, then wait for Complete or Failed.\n", + "# Wait for Running (also accept Complete/Failed in case the job finishes quickly).\n", "# GRPO takes longer than SFT/LoRA due to vLLM rollout + RL training loop.\n", - "client.wait_for_job_status(name=job_name, status={\"Running\"}, timeout=600)\n", - "client.wait_for_job_status(name=job_name, status={\"Complete\", \"Failed\"}, timeout=3600)\n", + "client.wait_for_job_status(name=job_name, status={\"Running\", \"Complete\", \"Failed\"}, timeout=600)\n", + "\n", + "# The ART/GRPO runtime may not exit cleanly, so the TrainJob can stay Running\n", + "# even after training completes. Poll logs for the completion marker instead\n", + "# of relying solely on job status.\n", + "completion_markers = [\n", + " \"[PY] LORA_GRPO training complete. Result=\",\n", + " \"train: 100%\",\n", + "]\n", + "\n", + "poll_timeout = 900\n", + "poll_interval = 30\n", + "elapsed = 0\n", + "completed = False\n", + "\n", + "while elapsed < poll_timeout:\n", + " job = client.get_job(name=job_name)\n", + " if job.status == \"Failed\":\n", + " pod_logs = client.get_job_logs(name=job_name, follow=False)\n", + " logs = []\n", + " for log_line in pod_logs:\n", + " logs.extend(str(log_line).splitlines())\n", + " print(f\"ERROR: Training job '{job_name}' has Failed status\")\n", + " print(\"Last 30 lines of logs:\")\n", + " for line in logs[-30:]:\n", + " print(line)\n", + " raise RuntimeError(f\"Training job '{job_name}' failed\")\n", + "\n", + " if job.status == \"Complete\":\n", + " print(f\"Training job reached Complete status after {elapsed}s\")\n", + " completed = True\n", + " break\n", + "\n", + " pod_logs = client.get_job_logs(name=job_name, follow=False)\n", + " logs = []\n", + " for log_line in pod_logs:\n", + " logs.extend(str(log_line).splitlines())\n", + " log_text = \"\\n\".join(logs)\n", + "\n", + " if any(marker in log_text for marker in completion_markers):\n", + " print(f\"Training completion detected in logs after {elapsed}s\")\n", + " completed = True\n", + " break\n", + "\n", + " print(f\"Waiting for training completion... ({elapsed}s/{poll_timeout}s, status={job.status})\")\n", + " time.sleep(poll_interval)\n", + " elapsed += poll_interval\n", + "\n", + "if not completed:\n", + " print(f\"ERROR: Training did not complete within {poll_timeout}s\")\n", + " print(f\"Looked for any of: {completion_markers}\")\n", + " print(\"Last 50 lines of logs:\")\n", + " for line in logs[-50:]:\n", + " print(line)\n", + " raise RuntimeError(f\"Training did not complete successfully within {poll_timeout}s\")\n", "\n", "job = client.get_job(name=job_name)\n", "pod_logs = client.get_job_logs(name=job_name, follow=False)\n", - "\n", "logs = []\n", "for log_line in pod_logs:\n", " logs.extend(str(log_line).splitlines())\n", - "\n", "log_text = \"\\n\".join(logs)\n", "\n", - "print(f\"Training job final status: {job.status}\")\n", + "print(f\"Training job '{job_name}' completed successfully (status={job.status})\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import re\n", "\n", - "if job.status == \"Failed\":\n", - " print(f\"ERROR: Training job '{job_name}' has Failed status\")\n", - " print(\"Last 30 lines of logs:\")\n", - " for line in logs[-30:]:\n", - " print(line)\n", - " raise RuntimeError(f\"Training job '{job_name}' failed\")\n", + "ckpt_dir = params['ckpt_output_dir']\n", + "print(f\"Checking checkpoint output directory: {ckpt_dir}\")\n", "\n", - "if \"[PY] LORA_GRPO training complete. Result=\" not in log_text:\n", - " print(f\"ERROR: Training completion message not found in logs\")\n", - " print(\"Last 50 lines of logs:\")\n", - " for line in logs[-50:]:\n", - " print(line)\n", - " raise RuntimeError(f\"Training did not complete successfully - missing completion message\")\n", + "expected_files = [\"adapter_config.json\", \"adapter_model.safetensors\"]\n", + "found_files = []\n", + "for root, dirs, files in os.walk(ckpt_dir):\n", + " for f in files:\n", + " rel = os.path.relpath(os.path.join(root, f), ckpt_dir)\n", + " found_files.append(rel)\n", + " print(f\" {rel}\")\n", + "\n", + "if not found_files:\n", + " raise RuntimeError(f\"Checkpoint directory '{ckpt_dir}' is empty — no LoRA adapter produced\")\n", + "\n", + "for expected in expected_files:\n", + " if not any(f.endswith(expected) for f in found_files):\n", + " raise RuntimeError(f\"Expected checkpoint file '{expected}' not found in {ckpt_dir}\")\n", "\n", - "print(f\"Training job '{job_name}' completed successfully\")" - ] + "print(f\"Checkpoint verification passed — found {len(found_files)} files including LoRA adapter\")" + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", + "metadata": {}, + "source": [ + "metric_pattern = re.compile(\n", + " r\"loss\\S*\\s*[:=]\\s*[-\\d.]+\"\n", + ")\n", + "metric_lines = [line for line in logs if metric_pattern.search(str(line))]\n", + "\n", + "if not metric_lines:\n", + " print(\"WARNING: No training metric lines found in logs\")\n", + " print(\"Last 20 lines of logs for debugging:\")\n", + " for line in logs[-20:]:\n", + " print(line)\n", + " raise RuntimeError(\"No training metrics found in logs — expected loss values\")\n", + "\n", + "print(f\"Found {len(metric_lines)} metric log entries\")\n", + "print(f\"First metric entry: {metric_lines[0]}\")\n", + "print(f\"Last metric entry: {metric_lines[-1]}\")\n", + "print(\"Metrics verification passed\")" + ], "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", "metadata": {}, - "outputs": [], "source": [ "for c in client.get_job(name=job_name).steps:\n", " print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "logs = client.get_job_logs(name=job_name, follow=False)\n", "\n", "logs = list(logs)\n", "log_text = \"\\n\".join(str(line) for line in logs)\n", "print(log_text)" - ] + ], + "execution_count": null, + "outputs": [] }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "client.delete_job(job_name)" - ] + ], + "execution_count": null, + "outputs": [] } ], "metadata": { @@ -434,4 +515,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/tests/trainer/sdk_tests/grpo_traininghub_tests.go b/tests/trainer/sdk_tests/grpo_traininghub_tests.go index d4144c49b..f7cf5628c 100644 --- a/tests/trainer/sdk_tests/grpo_traininghub_tests.go +++ b/tests/trainer/sdk_tests/grpo_traininghub_tests.go @@ -1,5 +1,5 @@ /* -Copyright 2025. +Copyright 2026. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From f07776ac93c178aa8d7487d122378a46f8d44c5a Mon Sep 17 00:00:00 2001 From: Slowlybomb Date: Thu, 18 Jun 2026 12:22:11 +0100 Subject: [PATCH 3/3] feat: add DspClient scaffold and pipeline e2e test structure RHOAIENG-43403: Spike to investigate e2e testing for reusable KFP pipelines. - Fix DSP env var typo (dspRoutURL -> dspRouteURL) in environment.go - Implement DspClient HTTP wrapper for KFP v2beta1 REST API (upload pipeline, create run, poll status, get details, delete) - Add unit tests for DspClient using httptest mock servers - Scaffold tests/pipelines/ package with SFT test skeleton Signed-off-by: Slowlybomb --- tests/common/environment.go | 24 +++ tests/common/support/dsp_client.go | 228 ++++++++++++++++++++++++ tests/common/support/dsp_client_test.go | 209 ++++++++++++++++++++++ tests/pipelines/pipelines_sft_test.go | 86 +++++++++ tests/pipelines/resources/.gitkeep | 1 + tests/pipelines/support.go | 40 +++++ 6 files changed, 588 insertions(+) create mode 100644 tests/common/support/dsp_client.go create mode 100644 tests/common/support/dsp_client_test.go create mode 100644 tests/pipelines/pipelines_sft_test.go create mode 100644 tests/pipelines/resources/.gitkeep create mode 100644 tests/pipelines/support.go diff --git a/tests/common/environment.go b/tests/common/environment.go index 52262cd05..78776f3ea 100644 --- a/tests/common/environment.go +++ b/tests/common/environment.go @@ -41,6 +41,10 @@ const ( testTierEnvVar = "TEST_TIER" // The environment variable for HuggingFace token to download models which require authentication huggingfaceTokenEnvVar = "HF_TOKEN" + // Data Science Pipelines (KFP) connectivity + dspRouteURL = "DSP_ROUTE_URL" + dspBearerToken = "DSP_BEARER_TOKEN" + dspNamespace = "DSP_NAMESPACE" ) const ( @@ -206,3 +210,23 @@ func lookupEnvOrDefault(key, value string) string { } return value } + +func GetDspRouteURL(t Test) string { + url, ok := os.LookupEnv(dspRouteURL) + if !ok { + t.T().Fatalf("Required env var %s not set. Set it to the DSP API route URL, e.g. https://ds-pipeline-dspa-.apps..", dspRouteURL) + } + return url +} + +func GetDspBearerToken(t Test) string { + token, ok := os.LookupEnv(dspBearerToken) + if !ok { + t.T().Fatalf("Required env var %s not set. Set it to an OpenShift service account token with access to the DSP namespace.", dspBearerToken) + } + return token +} + +func GetDspNamespace() (string, bool) { + return os.LookupEnv(dspNamespace) +} diff --git a/tests/common/support/dsp_client.go b/tests/common/support/dsp_client.go new file mode 100644 index 000000000..f374ca8ad --- /dev/null +++ b/tests/common/support/dsp_client.go @@ -0,0 +1,228 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package support + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "strings" + "time" +) + +const defaultPollInterval = 30 * time.Second + +type DspClient struct { + baseURL string + token string + httpClient *http.Client + PollInterval time.Duration +} + +func NewDspClient(routeURL, bearerToken string) *DspClient { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + Proxy: http.ProxyFromEnvironment, + } + return &DspClient{ + baseURL: strings.TrimRight(routeURL, "/"), + token: bearerToken, + httpClient: &http.Client{Transport: tr, Timeout: 30 * time.Second}, + PollInterval: defaultPollInterval, + } +} + +func (c *DspClient) do(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+c.token) + return c.httpClient.Do(req) +} + +func (c *DspClient) UploadPipeline(pipelineYAML []byte, name string) (string, error) { + var buf bytes.Buffer + w := multipart.NewWriter(&buf) + fw, err := w.CreateFormFile("uploadfile", name+".yaml") + if err != nil { + return "", fmt.Errorf("create form file: %w", err) + } + if _, err = fw.Write(pipelineYAML); err != nil { + return "", fmt.Errorf("write pipeline YAML: %w", err) + } + _ = w.WriteField("name", name) + w.Close() + + url := c.baseURL + "/apis/v2beta1/pipelines/upload" + req, err := http.NewRequest(http.MethodPost, url, &buf) + if err != nil { + return "", fmt.Errorf("new request: %w", err) + } + req.Header.Set("Content-Type", w.FormDataContentType()) + + resp, err := c.do(req) + if err != nil { + return "", fmt.Errorf("upload pipeline: %w", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("upload pipeline: HTTP %d: %s", resp.StatusCode, body) + } + + var result struct { + PipelineID string `json:"pipeline_id"` + } + if err = json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("parse upload response: %w — body: %s", err, body) + } + return result.PipelineID, nil +} + +func (c *DspClient) CreateRun(pipelineID, experimentID, runName string, params map[string]interface{}) (string, error) { + payload := map[string]interface{}{ + "display_name": runName, + "pipeline_version_reference": map[string]string{ + "pipeline_id": pipelineID, + }, + "runtime_config": map[string]interface{}{ + "parameters": params, + }, + } + if experimentID != "" { + payload["experiment_id"] = experimentID + } + body, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshal run payload: %w", err) + } + + url := c.baseURL + "/apis/v2beta1/runs" + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("new request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.do(req) + if err != nil { + return "", fmt.Errorf("create run: %w", err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("create run: HTTP %d: %s", resp.StatusCode, respBody) + } + + var result struct { + RunID string `json:"run_id"` + } + if err = json.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("parse create run response: %w — body: %s", err, respBody) + } + return result.RunID, nil +} + +// WaitForRunCompletion polls the run until it reaches a terminal state. +// Returns nil on SUCCEEDED; returns an error on FAILED, SKIPPED, or timeout. +func (c *DspClient) WaitForRunCompletion(t Test, runID string, timeout time.Duration) error { + t.T().Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + state, err := c.getRunState(runID) + if err != nil { + return err + } + t.T().Logf("DSP run %s state: %s", runID, state) + switch state { + case "SUCCEEDED": + return nil + case "FAILED", "SKIPPED", "CANCELED": + return fmt.Errorf("run %s reached terminal state: %s", runID, state) + } + time.Sleep(c.PollInterval) + } + return fmt.Errorf("run %s did not complete within %s", runID, timeout) +} + +func (c *DspClient) getRunState(runID string) (string, error) { + url := c.baseURL + "/apis/v2beta1/runs/" + runID + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("new request: %w", err) + } + resp, err := c.do(req) + if err != nil { + return "", fmt.Errorf("get run state: %w", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("get run: HTTP %d: %s", resp.StatusCode, body) + } + + var result struct { + State string `json:"state"` + } + if err = json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("parse run response: %w", err) + } + return result.State, nil +} + +// GetRunDetails returns the full JSON response for a run, useful for post-run assertions. +func (c *DspClient) GetRunDetails(runID string) (map[string]interface{}, error) { + url := c.baseURL + "/apis/v2beta1/runs/" + runID + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + resp, err := c.do(req) + if err != nil { + return nil, fmt.Errorf("get run details: %w", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("get run details: HTTP %d: %s", resp.StatusCode, body) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parse run details: %w", err) + } + return result, nil +} + +func (c *DspClient) DeletePipeline(pipelineID string) error { + url := c.baseURL + "/apis/v2beta1/pipelines/" + pipelineID + req, err := http.NewRequest(http.MethodDelete, url, nil) + if err != nil { + return fmt.Errorf("new request: %w", err) + } + resp, err := c.do(req) + if err != nil { + return fmt.Errorf("delete pipeline: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("delete pipeline: HTTP %d: %s", resp.StatusCode, body) + } + return nil +} diff --git a/tests/common/support/dsp_client_test.go b/tests/common/support/dsp_client_test.go new file mode 100644 index 000000000..5633ad3ad --- /dev/null +++ b/tests/common/support/dsp_client_test.go @@ -0,0 +1,209 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package support + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/onsi/gomega" +) + +func TestNewDspClient(t *testing.T) { + g := gomega.NewGomegaWithT(t) + c := NewDspClient("https://example.com", "token123") + g.Expect(c).NotTo(gomega.BeNil()) + g.Expect(c.baseURL).To(gomega.Equal("https://example.com")) + g.Expect(c.token).To(gomega.Equal("token123")) +} + +func TestNewDspClientTrimsTrailingSlash(t *testing.T) { + g := gomega.NewGomegaWithT(t) + c := NewDspClient("https://example.com/", "tok") + g.Expect(c.baseURL).To(gomega.Equal("https://example.com")) +} + +func TestUploadPipelineReturnsID(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.Expect(r.Method).To(gomega.Equal(http.MethodPost)) + g.Expect(r.URL.Path).To(gomega.Equal("/apis/v2beta1/pipelines/upload")) + g.Expect(r.Header.Get("Authorization")).To(gomega.Equal("Bearer tok")) + g.Expect(r.Header.Get("Content-Type")).To(gomega.ContainSubstring("multipart/form-data")) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"pipeline_id":"pipe-abc","display_name":"test-sft"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + id, err := c.UploadPipeline([]byte("apiVersion: v1"), "test-sft") + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(id).To(gomega.Equal("pipe-abc")) +} + +func TestUploadPipelineHTTPError(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"boom"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + _, err := c.UploadPipeline([]byte("yaml"), "test") + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("HTTP 500")) +} + +func TestCreateRunReturnsID(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.Expect(r.Method).To(gomega.Equal(http.MethodPost)) + g.Expect(r.URL.Path).To(gomega.Equal("/apis/v2beta1/runs")) + g.Expect(r.Header.Get("Content-Type")).To(gomega.Equal("application/json")) + + body, _ := io.ReadAll(r.Body) + var payload map[string]interface{} + g.Expect(json.Unmarshal(body, &payload)).To(gomega.Succeed()) + g.Expect(payload["display_name"]).To(gomega.Equal("e2e-sft-run")) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"run_id":"run-123"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + params := map[string]interface{}{"model": "Qwen/Qwen2.5-1.5B-Instruct"} + id, err := c.CreateRun("pipe-abc", "", "e2e-sft-run", params) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(id).To(gomega.Equal("run-123")) +} + +func TestCreateRunWithExperimentID(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var payload map[string]interface{} + g.Expect(json.Unmarshal(body, &payload)).To(gomega.Succeed()) + g.Expect(payload["experiment_id"]).To(gomega.Equal("exp-1")) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"run_id":"run-456"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + id, err := c.CreateRun("pipe-abc", "exp-1", "run", nil) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(id).To(gomega.Equal("run-456")) +} + +func TestWaitForRunCompletionSucceeded(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + callCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.Expect(r.URL.Path).To(gomega.Equal("/apis/v2beta1/runs/run-1")) + callCount++ + state := "RUNNING" + if callCount >= 2 { + state = "SUCCEEDED" + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"run_id":"run-1","state":"` + state + `"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + c.PollInterval = 10 * time.Millisecond + test := With(t) + err := c.WaitForRunCompletion(test, "run-1", 5*time.Second) + g.Expect(err).NotTo(gomega.HaveOccurred()) +} + +func TestWaitForRunCompletionFailed(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"run_id":"run-2","state":"FAILED"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + c.PollInterval = 10 * time.Millisecond + test := With(t) + err := c.WaitForRunCompletion(test, "run-2", 5*time.Second) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("FAILED")) +} + +func TestGetRunDetails(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.Expect(r.URL.Path).To(gomega.Equal("/apis/v2beta1/runs/run-1")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"run_id":"run-1","state":"SUCCEEDED","pipeline_spec":{"root":{}}}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + details, err := c.GetRunDetails("run-1") + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(details["state"]).To(gomega.Equal("SUCCEEDED")) + g.Expect(details["run_id"]).To(gomega.Equal("run-1")) +} + +func TestDeletePipeline(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.Expect(r.Method).To(gomega.Equal(http.MethodDelete)) + g.Expect(r.URL.Path).To(gomega.Equal("/apis/v2beta1/pipelines/pipe-abc")) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + err := c.DeletePipeline("pipe-abc") + g.Expect(err).NotTo(gomega.HaveOccurred()) +} + +func TestDeletePipelineHTTPError(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"not found"}`)) + })) + defer srv.Close() + + c := NewDspClient(srv.URL, "tok") + err := c.DeletePipeline("nonexistent") + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("HTTP 404")) +} diff --git a/tests/pipelines/pipelines_sft_test.go b/tests/pipelines/pipelines_sft_test.go new file mode 100644 index 000000000..dc20d576a --- /dev/null +++ b/tests/pipelines/pipelines_sft_test.go @@ -0,0 +1,86 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pipelines + +import ( + "fmt" + "testing" + + . "github.com/onsi/gomega" + + . "github.com/opendatahub-io/distributed-workloads/tests/common" + "github.com/opendatahub-io/distributed-workloads/tests/common/support" +) + +const ( + sftPipelineYAMLPath = "resources/sft_pipeline.yaml" +) + +func TestSftPipelineRun(t *testing.T) { + Tags(t, Tier1, Gpu(support.NVIDIA)) + test := support.With(t) + + dspURL := GetDspRouteURL(test) + dspToken := GetDspBearerToken(test) + client := support.NewDspClient(dspURL, dspToken) + + pipelineYAML := readFile(test, sftPipelineYAMLPath) + + pipelineID, err := client.UploadPipeline(pipelineYAML, "e2e-sft-pipeline") + test.Expect(err).NotTo(HaveOccurred(), "failed to upload SFT pipeline") + + t.Cleanup(func() { + if err := client.DeletePipeline(pipelineID); err != nil { + t.Logf("warning: failed to delete pipeline %s: %v", pipelineID, err) + } + }) + + dataURI := buildSftDataURI(test) + + params := map[string]interface{}{ + "phase_01_dataset_man_data_uri": dataURI, + "phase_01_dataset_opt_subset": "100", + "phase_02_train_man_model": "Qwen/Qwen2.5-1.5B-Instruct", + "phase_02_train_man_epochs": "1", + "phase_02_train_man_workers": "1", + "phase_02_train_man_gpu": "1", + "phase_02_train_opt_runtime": "training-hub", + "phase_03_eval_man_tasks": "[]", + "phase_04_registry_man_address": "", + } + + runID, err := client.CreateRun(pipelineID, "", "e2e-sft-run", params) + test.Expect(err).NotTo(HaveOccurred(), "failed to create SFT pipeline run") + + t.Logf("SFT pipeline run created: %s", runID) + + err = client.WaitForRunCompletion(test, runID, PipelineRunTimeout) + test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("SFT pipeline run %s did not succeed", runID)) +} + +func buildSftDataURI(test support.Test) string { + endpoint, endpointOK := support.GetStorageBucketDefaultEndpoint() + bucket, bucketOK := support.GetStorageBucketName() + prefix, prefixOK := support.GetStorageBucketSftDir() + + if endpointOK && bucketOK && prefixOK && endpoint != "" && bucket != "" && prefix != "" { + return fmt.Sprintf("s3://%s/%s", bucket, prefix) + } + + test.T().Log("S3 not fully configured; falling back to HuggingFace dataset URI") + return "hf://ibm/merlinite-9b-lab-processed" +} diff --git a/tests/pipelines/resources/.gitkeep b/tests/pipelines/resources/.gitkeep new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/pipelines/resources/.gitkeep @@ -0,0 +1 @@ + diff --git a/tests/pipelines/support.go b/tests/pipelines/support.go new file mode 100644 index 000000000..62a4af95c --- /dev/null +++ b/tests/pipelines/support.go @@ -0,0 +1,40 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pipelines + +import ( + "embed" + "time" + + "github.com/onsi/gomega" + + "github.com/opendatahub-io/distributed-workloads/tests/common/support" +) + +// PipelineRunTimeout is how long we wait for a KFP pipeline run to reach SUCCEEDED. +// Fine-tuning pipelines with GPU typically take 30-90 minutes. +const PipelineRunTimeout = 2 * time.Hour + +//go:embed resources/* +var files embed.FS + +func readFile(t support.Test, fileName string) []byte { + t.T().Helper() + file, err := files.ReadFile(fileName) + t.Expect(err).NotTo(gomega.HaveOccurred()) + return file +}