diff --git a/.flake8 b/.flake8 index 67fad892a..0fba9bd8c 100644 --- a/.flake8 +++ b/.flake8 @@ -33,6 +33,9 @@ fcn_exclude_functions = add_to_assignees, validate_inference_output, # TODO: function should be fixed to get rid of this group + writelines, + temp_file, + sleep, enable-extensions = FCN, diff --git a/.github/workflows/scripts/pr_workflow.py b/.github/workflows/scripts/pr_workflow.py index 3b8b0ec37..0a6b266e9 100644 --- a/.github/workflows/scripts/pr_workflow.py +++ b/.github/workflows/scripts/pr_workflow.py @@ -2,15 +2,9 @@ import re import sys -from github.PullRequest import PullRequest -from github.Repository import Repository -from github.MainClass import Github -from github.GithubException import UnknownObjectException -from github.Organization import Organization -from github.Team import Team - from constants import ( ALL_LABELS_DICT, + APPROVED, CANCEL_ACTION, CHANGED_REQUESTED_BY_LABEL_PREFIX, COMMENTED_BY_LABEL_PREFIX, @@ -22,8 +16,13 @@ SUPPORTED_LABELS, VERIFIED_LABEL_STR, WELCOME_COMMENT, - APPROVED, ) +from github.GithubException import UnknownObjectException +from github.MainClass import Github +from github.Organization import Organization +from github.PullRequest import PullRequest +from github.Repository import Repository +from github.Team import Team from simple_logger.logger import get_logger LOGGER = get_logger(name="pr_labeler") @@ -35,7 +34,7 @@ class SupportedActions: pr_size_action_name: str = "add-pr-size-label" welcome_comment_action_name: str = "add-welcome-comment-set-assignee" build_push_pr_image_action_name: str = "push-container-on-comment" - supported_actions: set[str] = { + supported_actions: set[str] = { # noqa: RUF012 pr_size_action_name, add_remove_labels_action_name, welcome_comment_action_name, @@ -48,7 +47,7 @@ def __init__(self) -> None: self.gh_client: Github self.repo_name = os.environ["GITHUB_REPOSITORY"] - self.pr_number = int(os.getenv("GITHUB_PR_NUMBER", 0)) + self.pr_number = int(os.getenv("GITHUB_PR_NUMBER", 
"0")) self.action = os.getenv("ACTION") self.event_action = os.getenv("GITHUB_EVENT_ACTION") self.event_name = os.getenv("GITHUB_EVENT_NAME") @@ -110,7 +109,7 @@ def verify_allowed_user(self) -> bool: # check if the user is a member of opendatahub-tests-contributors membership = team.get_team_membership(member=self.user_login) LOGGER.info(f"User {self.user_login} is a member of the test contributor team. {membership}") - return True + return True # noqa: TRY300 except UnknownObjectException: LOGGER.error(f"User {self.user_login} is not allowed for this action. Exiting.") return False @@ -123,9 +122,7 @@ def verify_labeler_config(self) -> None: if not self.user_login: sys.exit("`GITHUB_USER_LOGIN` is not set") - if ( - self.event_name == "issue_comment" or self.event_name == "pull_request_review" - ) and not self.comment_body: + if (self.event_name in {"issue_comment", "pull_request_review"}) and not self.comment_body: LOGGER.info("No comment, nothing to do. Exiting.") sys.exit(0) @@ -133,13 +130,11 @@ def run_pr_label_action(self) -> None: if self.action == self.SupportedActions.pr_size_action_name: self.set_pr_size() - if self.action == self.SupportedActions.build_push_pr_image_action_name: - if not self.verify_allowed_user(): - sys.exit(1) + if self.action == self.SupportedActions.build_push_pr_image_action_name and not self.verify_allowed_user(): + sys.exit(1) - if self.action == self.SupportedActions.add_remove_labels_action_name: - if self.verify_allowed_user(): - self.add_remove_pr_labels() + if self.action == self.SupportedActions.add_remove_labels_action_name and self.verify_allowed_user(): + self.add_remove_pr_labels() if self.action == self.SupportedActions.welcome_comment_action_name: self.add_welcome_comment_set_assignee() @@ -187,10 +182,9 @@ def set_label_in_repository(self, label: str) -> None: LOGGER.info(f"repo labels: {repo_labels}") try: - if _repo_label := self.repo.get_label(name=label): - if _repo_label.color != label_color: - 
LOGGER.info(f"Edit repository label: {label}, color: {label_color}") - _repo_label.edit(name=_repo_label.name, color=label_color) + if (_repo_label := self.repo.get_label(name=label)) and _repo_label.color != label_color: + LOGGER.info(f"Edit repository label: {label}, color: {label_color}") + _repo_label.edit(name=_repo_label.name, color=label_color) except UnknownObjectException: LOGGER.info(f"Add repository label: {label}, color: {label_color}") @@ -249,16 +243,15 @@ def add_remove_pr_labels(self) -> None: return - elif self.event_name == "pull_request_review": + elif ( + self.event_name == "pull_request_review" + or self.event_name == "workflow_run" + and self.event_action == "submitted" + ): self.pull_request_review_label_actions() return - # We will only reach here if the PR was created from a fork - elif self.event_name == "workflow_run" and self.event_action == "submitted": - self.pull_request_review_label_actions() - return - LOGGER.warning("`add_remove_pr_label` called without a supported event") def pull_request_review_label_actions( @@ -317,7 +310,7 @@ def issue_comment_label_actions( if not action[CANCEL_ACTION] or self.event_action == "deleted": self.approve_pr() - label_in_pr = any([label == _label.lower() for _label in self.pr_labels]) + label_in_pr = any(label == _label.lower() for _label in self.pr_labels) LOGGER.info(f"Processing label: {label}, action: {action}") if action[CANCEL_ACTION] or self.event_action == "deleted": diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 185bc336d..5f2b4af79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: exclude: .*/__snapshots__/.*|.*-input\.json$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.1 + rev: v0.15.2 hooks: - id: ruff - id: ruff-format diff --git a/conftest.py b/conftest.py index c5595c120..3c84b4154 100644 --- a/conftest.py +++ b/conftest.py @@ -1,43 +1,43 @@ +import datetime import logging import os import pathlib 
import shutil -import datetime import traceback +from typing import Any import pytest import shortuuid -from _pytest.runner import CallInfo +from _pytest.nodes import Node from _pytest.reports import TestReport +from _pytest.runner import CallInfo +from _pytest.terminal import TerminalReporter +from kubernetes.dynamic import DynamicClient +from ocp_resources.cluster_service_version import ClusterServiceVersion +from ocp_resources.resource import get_client from pytest import ( - Parser, - Session, - FixtureRequest, - FixtureDef, - Item, Collector, - Config, CollectReport, + Config, + FixtureDef, + FixtureRequest, + Item, + Parser, + Session, ) -from _pytest.nodes import Node -from _pytest.terminal import TerminalReporter -from typing import Optional, Any from pytest_testconfig import config as py_config -from utilities.constants import KServeDeploymentType, MODEL_REGISTRY_CUSTOM_NAMESPACE +from utilities.constants import MODEL_REGISTRY_CUSTOM_NAMESPACE, KServeDeploymentType from utilities.database import Database +from utilities.infra import get_data_science_cluster, get_dsci_applications_namespace, get_operator_distribution from utilities.logger import separator, setup_logging from utilities.must_gather_collector import ( - set_must_gather_collector_directory, - set_must_gather_collector_values, - get_must_gather_collector_dir, collect_rhoai_must_gather, get_base_dir, + get_must_gather_collector_dir, + set_must_gather_collector_directory, + set_must_gather_collector_values, ) -from kubernetes.dynamic import DynamicClient -from utilities.infra import get_operator_distribution, get_dsci_applications_namespace, get_data_science_cluster -from ocp_resources.resource import get_client -from ocp_resources.cluster_service_version import ClusterServiceVersion LOGGER = logging.getLogger(name=__name__) BASIC_LOGGER = logging.getLogger(name="basic") @@ -229,7 +229,7 @@ def _add_upgrade_test(_item: Item, _upgrade_deployment_modes: list[str]) -> bool if not 
_upgrade_deployment_modes: return True - return any([keyword for keyword in _item.keywords if keyword in _upgrade_deployment_modes]) + return any(keyword for keyword in _item.keywords if keyword in _upgrade_deployment_modes) pre_upgrade_tests: list[Item] = [] post_upgrade_tests: list[Item] = [] @@ -310,7 +310,7 @@ def pytest_sessionstart(session: Session) -> None: value = log_cli_override.split("=", 1)[1].lower() enable_console_value = value not in ("false", "0", "no", "off") - except Exception as e: + except Exception as e: # noqa: BLE001 # If there's any issue with option detection, fall back to default behavior LOGGER.error(f"Error detecting log_cli option: {e}") enable_console_value = True @@ -397,9 +397,9 @@ def pytest_runtest_setup(item: Item) -> None: db = item.config.option.must_gather_db db.insert_test_start_time( test_name=f"{item.fspath}::{item.name}", - start_time=int(datetime.datetime.now().timestamp()), + start_time=int(datetime.datetime.now().timestamp()), # noqa: DTZ005 ) - except Exception as db_exception: + except Exception as db_exception: # noqa: BLE001 LOGGER.error(f"Database error: {db_exception}. 
Must-gather collection may not be accurate") if KServeDeploymentType.RAW_DEPLOYMENT.lower() in item.keywords: @@ -460,7 +460,7 @@ def pytest_sessionfinish(session: Session, exitstatus: int) -> None: LOGGER.info(f"Deleting pytest base dir {session.config.option.basetemp}") shutil.rmtree(path=session.config.option.basetemp, ignore_errors=True) - reporter: Optional[TerminalReporter] = session.config.pluginmanager.get_plugin("terminalreporter") + reporter: TerminalReporter | None = session.config.pluginmanager.get_plugin("terminalreporter") if reporter: reporter.summary_stats() @@ -468,7 +468,7 @@ def pytest_sessionfinish(session: Session, exitstatus: int) -> None: def calculate_must_gather_timer(test_start_time: int) -> int: default_duration = 300 if test_start_time > 0: - duration = int(datetime.datetime.now().timestamp()) - test_start_time + duration = int(datetime.datetime.now().timestamp()) - test_start_time # noqa: DTZ005 return duration if duration > 60 else default_duration else: LOGGER.warning(f"Could not get start time of test. 
Collecting must-gather for last {default_duration}s") @@ -492,7 +492,7 @@ def pytest_exception_interact(node: Item | Collector, call: CallInfo[Any], repor try: db = node.config.option.must_gather_db test_start_time = db.get_test_start_time(test_name=test_name) - except Exception as db_exception: + except Exception as db_exception: # noqa: BLE001 test_start_time = 0 LOGGER.warning(f"Error: {db_exception} in accessing database.") @@ -503,7 +503,7 @@ def pytest_exception_interact(node: Item | Collector, call: CallInfo[Any], repor target_dir=os.path.join(get_must_gather_collector_dir(), "pytest_exception_interact"), ) - except Exception as current_exception: + except Exception as current_exception: # noqa: BLE001 LOGGER.warning(f"Failed to collect logs: {test_name}: {current_exception} {traceback.format_exc()}") diff --git a/pyproject.toml b/pyproject.toml index d153506b9..78f480b00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,9 @@ fix = true output-format = "grouped" extend-exclude = ["utilities/manifests"] +[tool.ruff.lint] +external = ["E501"] + [tool.ruff.format] exclude = [".git", ".venv", ".mypy_cache", ".tox", "__pycache__", "utilities/manifests"] diff --git a/tests/cluster_health/test_operator_health.py b/tests/cluster_health/test_operator_health.py index f0a740edd..a34dd2c68 100644 --- a/tests/cluster_health/test_operator_health.py +++ b/tests/cluster_health/test_operator_health.py @@ -2,11 +2,12 @@ from kubernetes.dynamic import DynamicClient from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.dsc_initialization import DSCInitialization -from utilities.general import wait_for_pods_running -from utilities.infra import wait_for_dsci_status_ready, wait_for_dsc_status_ready from pytest_testconfig import config as py_config from simple_logger.logger import get_logger +from utilities.general import wait_for_pods_running +from utilities.infra import wait_for_dsc_status_ready, wait_for_dsci_status_ready + LOGGER = 
get_logger(name=__name__) diff --git a/tests/conftest.py b/tests/conftest.py index 6ba16b474..6dd24014a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,70 +1,71 @@ import base64 import binascii +import json import os import shutil from ast import literal_eval -from typing import Any, Callable, Generator +from collections.abc import Callable, Generator +from typing import Any + import pytest -from ocp_resources.route import Route -from semver import Version import shortuuid import yaml from _pytest._py.path import LocalPath from _pytest.legacypath import TempdirFactory from _pytest.tmpdir import TempPathFactory +from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError - +from ocp_resources.authentication_config_openshift_io import Authentication from ocp_resources.cluster_service_version import ClusterServiceVersion from ocp_resources.cluster_version import ClusterVersion from ocp_resources.config_map import ConfigMap +from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.deployment import Deployment from ocp_resources.dsc_initialization import DSCInitialization from ocp_resources.mariadb_operator import MariadbOperator +from ocp_resources.namespace import Namespace from ocp_resources.node import Node from ocp_resources.pod import Pod +from ocp_resources.resource import get_client +from ocp_resources.route import Route from ocp_resources.secret import Secret from ocp_resources.service import Service from ocp_resources.subscription import Subscription from ocp_utilities.monitoring import Prometheus +from ocp_utilities.operators import install_operator, uninstall_operator from pyhelper_utils.shell import run_command -from pytest import FixtureRequest, Config -from kubernetes.dynamic import DynamicClient -from ocp_resources.data_science_cluster import DataScienceCluster -from ocp_resources.namespace import Namespace -from ocp_resources.resource import get_client +from pytest 
import Config, FixtureRequest from pytest_testconfig import config as py_config +from semver import Version from simple_logger.logger import get_logger -import json -from ocp_utilities.operators import uninstall_operator, install_operator from utilities.certificates_utils import create_ca_bundle_file -from utilities.data_science_cluster_utils import update_components_in_dsc -from utilities.exceptions import ClusterLoginError -from utilities.infra import ( - verify_cluster_sanity, - create_ns, - login_with_user_password, - get_openshift_token, - download_oc_console_cli, - get_cluster_authentication, -) from utilities.constants import ( + OPENSHIFT_OPERATORS, AcceleratorType, DscComponents, Labels, MinIo, + OCIRegistry, Protocols, Timeout, - OPENSHIFT_OPERATORS, - OCIRegistry, ) -from utilities.infra import update_configmap_data +from utilities.data_science_cluster_utils import update_components_in_dsc +from utilities.exceptions import ClusterLoginError +from utilities.infra import ( + create_ns, + download_oc_console_cli, + get_cluster_authentication, + get_openshift_token, + login_with_user_password, + update_configmap_data, + verify_cluster_sanity, +) from utilities.logger import RedactedString from utilities.mariadb_utils import wait_for_mariadb_operator_deployments from utilities.minio import create_minio_data_connection_secret -from utilities.operator_utils import get_csv_related_images, get_cluster_service_version -from ocp_resources.authentication_config_openshift_io import Authentication -from utilities.user_utils import get_oidc_tokens, get_byoidc_issuer_url +from utilities.operator_utils import get_cluster_service_version, get_csv_related_images +from utilities.user_utils import get_byoidc_issuer_url, get_oidc_tokens LOGGER = get_logger(name=__name__) @@ -83,7 +84,7 @@ def admin_client() -> DynamicClient: @pytest.fixture(scope="session") -def tests_tmp_dir(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> Generator[None, None, None]: +def 
tests_tmp_dir(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> Generator[None]: base_path = os.path.join(request.config.option.basetemp, "tests") tests_tmp_path = tmp_path_factory.mktemp(basename=base_path) py_config["tmp_base_dir"] = str(tests_tmp_path) @@ -100,7 +101,7 @@ def current_client_token(admin_client: DynamicClient) -> str: def teardown_resources(pytestconfig: pytest.Config) -> bool: delete_resources = True - if pytestconfig.option.pre_upgrade: + if pytestconfig.option.pre_upgrade: # noqa: SIM102 if delete_resources := pytestconfig.option.delete_pre_upgrade_resources: LOGGER.warning("Upgrade resources will be deleted") @@ -164,7 +165,7 @@ def registry_pull_secret(pytestconfig: Config) -> list[str]: try: for secret in registry_pull_secret: base64.b64decode(s=secret, validate=True) - return registry_pull_secret + return registry_pull_secret # noqa: TRY300 except binascii.Error: raise ValueError("Registry pull secret is not a valid base64 encoded string") @@ -249,8 +250,8 @@ def modelcar_yaml_config(pytestconfig: pytest.Config) -> dict[str, Any] | None: try: modelcar_yaml = yaml.safe_load(file) if not isinstance(modelcar_yaml, dict): - raise ValueError("modelcar.yaml should contain a dictionary.") - return modelcar_yaml + raise ValueError("modelcar.yaml should contain a dictionary.") # noqa: TRY004 + return modelcar_yaml # noqa: TRY300 except yaml.YAMLError as e: raise ValueError(f"Error parsing modelcar.yaml: {e}") from e @@ -338,7 +339,7 @@ def use_unprivileged_client(pytestconfig: pytest.Config) -> bool: return literal_eval(_use_unprivileged_client) else: - raise ValueError( + raise ValueError( # noqa: TRY004 "use_unprivileged_client is not defined.\n" "Either pass with `--use-unprivileged-client` or " "set in `use_unprivileged_client` in `tests/global_config.py`" @@ -448,7 +449,9 @@ def unprivileged_client( # get the current context and modify the referenced user in place current_context_name = kubeconfig_content["current-context"] - 
current_context = [c for c in kubeconfig_content["contexts"] if c["name"] == current_context_name][0] + current_context = next((c for c in kubeconfig_content["contexts"] if c["name"] == current_context_name), None) + if current_context is None: + raise ValueError(f"Context '{current_context_name}' not found in kubeconfig") current_context["context"]["user"] = non_admin_user_password[0] unprivileged_client = get_client( @@ -681,8 +684,7 @@ def prometheus(admin_client: DynamicClient) -> Prometheus: @pytest.fixture(scope="session") def related_images_refs(admin_client: DynamicClient) -> set[str]: related_images = get_csv_related_images(admin_client=admin_client) - related_images_refs = {img["image"] for img in related_images} - return related_images_refs + return {img["image"] for img in related_images} @pytest.fixture(scope="session") diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 687e85f13..cb2bff402 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -1,6 +1,7 @@ -from typing import Callable, Dict -import pytest import os +from collections.abc import Callable + +import pytest from _pytest.fixtures import FixtureRequest S3_AUTO_CREATE_BUCKET = os.getenv("LLS_FILES_S3_AUTO_CREATE_BUCKET", "true") @@ -9,7 +10,7 @@ @pytest.fixture(scope="class") def files_provider_config_factory( request: FixtureRequest, -) -> Callable[[str], list[Dict[str, str]]]: +) -> Callable[[str], list[dict[str, str]]]: """ Factory fixture for configuring external files providers and returning their configuration. @@ -44,7 +45,7 @@ def test_with_s3(files_provider_config_factory): # env_vars contains S3_BUCKET_NAME, S3_BUCKET_ENDPOINT_URL, etc. 
""" - def _factory(provider_name: str) -> list[Dict[str, str]]: + def _factory(provider_name: str) -> list[dict[str, str]]: env_vars: list[dict[str, str]] = [] if provider_name == "local" or provider_name is None: diff --git a/tests/fixtures/guardrails.py b/tests/fixtures/guardrails.py index 3fbc6b1cd..bd2fb96cb 100644 --- a/tests/fixtures/guardrails.py +++ b/tests/fixtures/guardrails.py @@ -1,4 +1,5 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest @@ -11,7 +12,7 @@ from ocp_resources.resource import ResourceEditor from ocp_resources.route import Route -from utilities.constants import Labels, Annotations +from utilities.constants import Annotations, Labels from utilities.guardrails import check_guardrails_health_endpoint GUARDRAILS_ORCHESTRATOR_NAME: str = "guardrails-orchestrator" @@ -129,11 +130,15 @@ def guardrails_orchestrator_pod( model_namespace: Namespace, guardrails_orchestrator: GuardrailsOrchestrator, ) -> Pod: - return list( - Pod.get( - namespace=model_namespace.name, label_selector=f"app.kubernetes.io/instance={GUARDRAILS_ORCHESTRATOR_NAME}" + pods = Pod.get( + namespace=model_namespace.name, label_selector=f"app.kubernetes.io/instance={GUARDRAILS_ORCHESTRATOR_NAME}" + ) + pod = next(iter(pods), None) + if pod is None: + raise RuntimeError( + f"No guardrails orchestrator pod found with label app.kubernetes.io/instance={GUARDRAILS_ORCHESTRATOR_NAME}" ) - )[0] + return pod @pytest.fixture(scope="class") diff --git a/tests/fixtures/inference.py b/tests/fixtures/inference.py index e86cd044f..eac5b29ff 100644 --- a/tests/fixtures/inference.py +++ b/tests/fixtures/inference.py @@ -1,4 +1,5 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -13,15 +14,14 @@ from ocp_resources.serving_runtime import ServingRuntime from pytest_testconfig 
import py_config from simple_logger.logger import get_logger +from timeout_sampler import retry from utilities.constants import ( - RuntimeTemplates, - KServeDeploymentType, QWEN_MODEL_NAME, + KServeDeploymentType, LLMdInferenceSimConfig, + RuntimeTemplates, ) -from timeout_sampler import retry - from utilities.inference_utils import create_isvc from utilities.infra import get_data_science_cluster, wait_for_dsc_status_ready from utilities.serving_runtime import ServingRuntimeFromTemplate @@ -221,7 +221,7 @@ def kserve_controller_manager_deployment(admin_client: DynamicClient) -> Generat @pytest.fixture(scope="class") def patched_dsc_kserve_headed( admin_client, kserve_controller_manager_deployment: Deployment -) -> Generator[DataScienceCluster, None, None]: +) -> Generator[DataScienceCluster]: """Configure KServe Services to work in Headed mode i.e. using the Service port instead of the Pod port""" def _kserve_status(dsc_resource: DataScienceCluster) -> str: @@ -231,10 +231,10 @@ def _kserve_status(dsc_resource: DataScienceCluster) -> str: @retry(wait_timeout=30, sleep=1) def _wait_for_kserve_upgrade(dsc_resource: DataScienceCluster): - return not _kserve_status(dsc_resource) == "True" + return _kserve_status(dsc_resource) != "True" dsc = get_data_science_cluster(client=admin_client) - if not dsc.instance.spec.components.kserve.rawDeploymentServiceConfig == "Headed": + if dsc.instance.spec.components.kserve.rawDeploymentServiceConfig != "Headed": with ResourceEditor( patches={dsc: {"spec": {"components": {"kserve": {"rawDeploymentServiceConfig": "Headed"}}}}} ): diff --git a/tests/fixtures/trustyai.py b/tests/fixtures/trustyai.py index 4b0d2ed2b..add6d9e44 100644 --- a/tests/fixtures/trustyai.py +++ b/tests/fixtures/trustyai.py @@ -1,10 +1,9 @@ +from collections.abc import Generator + import pytest from kubernetes.dynamic import DynamicClient from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.deployment import Deployment - -from 
typing import Generator - from ocp_resources.resource import ResourceEditor from pytest_testconfig import py_config @@ -25,7 +24,7 @@ def trustyai_operator_deployment(admin_client: DynamicClient) -> Deployment: @pytest.fixture(scope="class") def patched_dsc_lmeval_allow_all( admin_client, trustyai_operator_deployment: Deployment -) -> Generator[DataScienceCluster, None, None]: +) -> Generator[DataScienceCluster]: """Enable LMEval PermitOnline and PermitCodeExecution flags in the Datascience cluster.""" dsc = get_data_science_cluster(client=admin_client) with ResourceEditor( diff --git a/tests/fixtures/vector_io.py b/tests/fixtures/vector_io.py index 62a38b678..9ccbaaf3e 100644 --- a/tests/fixtures/vector_io.py +++ b/tests/fixtures/vector_io.py @@ -1,14 +1,15 @@ -from typing import Generator, Any, Callable, Dict -import pytest import os import secrets +from collections.abc import Callable, Generator +from typing import Any + +import pytest from _pytest.fixtures import FixtureRequest from kubernetes.dynamic import DynamicClient from ocp_resources.deployment import Deployment from ocp_resources.namespace import Namespace -from ocp_resources.service import Service from ocp_resources.secret import Secret - +from ocp_resources.service import Service MILVUS_IMAGE = os.getenv( "LLS_VECTOR_IO_MILVUS_IMAGE", @@ -47,7 +48,7 @@ @pytest.fixture(scope="class") def vector_io_provider_deployment_config_factory( request: FixtureRequest, -) -> Callable[[str], list[Dict[str, Any]]]: +) -> Callable[[str], list[dict[str, Any]]]: """ Factory fixture for deploying vector I/O providers and returning their configuration. @@ -94,7 +95,7 @@ def test_with_milvus(vector_io_provider_deployment_config_factory): # env_vars contains MILVUS_ENDPOINT, MILVUS_TOKEN, etc. 
""" - def _factory(provider_name: str) -> list[Dict[str, Any]]: + def _factory(provider_name: str) -> list[dict[str, Any]]: env_vars: list[dict[str, Any]] = [] if provider_name is None or provider_name == "milvus": @@ -265,7 +266,7 @@ def milvus_service( yield service -def get_milvus_deployment_template() -> Dict[str, Any]: +def get_milvus_deployment_template() -> dict[str, Any]: """Return the Kubernetes deployment template for Milvus standalone.""" return { "metadata": {"labels": {"app": "milvus-standalone"}}, @@ -300,7 +301,7 @@ def get_milvus_deployment_template() -> Dict[str, Any]: } -def get_etcd_deployment_template() -> Dict[str, Any]: +def get_etcd_deployment_template() -> dict[str, Any]: """Return the Kubernetes deployment template for etcd.""" return { "metadata": {"labels": {"app": "etcd"}}, @@ -390,7 +391,7 @@ def pgvector_service( yield service -def get_pgvector_deployment_template() -> Dict[str, Any]: +def get_pgvector_deployment_template() -> dict[str, Any]: """Return a Kubernetes deployment for PGVector""" return { "metadata": {"labels": {"app": "pgvector"}}, @@ -490,7 +491,7 @@ def qdrant_service( yield service -def get_qdrant_deployment_template() -> Dict[str, Any]: +def get_qdrant_deployment_template() -> dict[str, Any]: """Return a Kubernetes deployment for Qdrant""" return { "metadata": {"labels": {"app": "qdrant"}}, diff --git a/tests/global_config.py b/tests/global_config.py index 41037e5b0..8e6a0935e 100644 --- a/tests/global_config.py +++ b/tests/global_config.py @@ -1,6 +1,6 @@ from utilities.constants import RHOAI_OPERATOR_NAMESPACE -global config # type:ignore[unused-ignore] +global config # type:ignore[unused-ignore] # noqa: PLW0604 dsc_name: str = "default-dsc" must_gather_base_dir: str = "must-gather-base-dir" diff --git a/tests/llama_stack/conftest.py b/tests/llama_stack/conftest.py index 6fea48812..9284c1562 100644 --- a/tests/llama_stack/conftest.py +++ b/tests/llama_stack/conftest.py @@ -1,8 +1,8 @@ -from typing import Generator, 
Any, Dict, Callable import os +from collections.abc import Callable, Generator +from typing import Any + import httpx -from ocp_resources.route import Route -from ocp_resources.resource import ResourceEditor import pytest from _pytest.fixtures import FixtureRequest from kubernetes.dynamic import DynamicClient @@ -10,26 +10,28 @@ from llama_stack_client.types.vector_store import VectorStore from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.deployment import Deployment - -from utilities.resources.llama_stack_distribution import LlamaStackDistribution from ocp_resources.namespace import Namespace +from ocp_resources.resource import ResourceEditor +from ocp_resources.route import Route +from ocp_resources.secret import Secret +from ocp_resources.service import Service from semver import Version from simple_logger.logger import get_logger -from utilities.general import generate_random_name + +from tests.llama_stack.constants import ( + LLS_OPENSHIFT_MINIMAL_VERSION, + ModelInfo, +) from tests.llama_stack.utils import ( create_llama_stack_distribution, - wait_for_llama_stack_client_ready, vector_store_create_file_from_url, + wait_for_llama_stack_client_ready, wait_for_unique_llama_stack_pod, ) -from utilities.constants import DscComponents, Annotations +from utilities.constants import Annotations, DscComponents from utilities.data_science_cluster_utils import update_components_in_dsc -from tests.llama_stack.constants import ( - LLS_OPENSHIFT_MINIMAL_VERSION, - ModelInfo, -) -from ocp_resources.service import Service -from ocp_resources.secret import Secret +from utilities.general import generate_random_name +from utilities.resources.llama_stack_distribution import LlamaStackDistribution LOGGER = get_logger(name=__name__) @@ -139,9 +141,9 @@ def enabled_llama_stack_operator(dsc_resource: DataScienceCluster) -> Generator[ @pytest.fixture(scope="class") def llama_stack_server_config( request: FixtureRequest, - 
vector_io_provider_deployment_config_factory: Callable[[str], list[Dict[str, str]]], - files_provider_config_factory: Callable[[str], list[Dict[str, str]]], -) -> Dict[str, Any]: + vector_io_provider_deployment_config_factory: Callable[[str], list[dict[str, str]]], + files_provider_config_factory: Callable[[str], list[dict[str, str]]], +) -> dict[str, Any]: """ Generate server configuration for LlamaStack distribution deployment and deploy vector I/O provider resources. @@ -326,7 +328,7 @@ def test_with_remote_milvus(llama_stack_server_config): env_vars_vector_io = vector_io_provider_deployment_config_factory(provider_name=vector_io_provider) env_vars.extend(env_vars_vector_io) - server_config: Dict[str, Any] = { + server_config: dict[str, Any] = { "containerSpec": { "resources": { "requests": {"cpu": "1", "memory": "3Gi"}, @@ -382,7 +384,7 @@ def unprivileged_llama_stack_distribution( unprivileged_model_namespace: Namespace, enabled_llama_stack_operator: DataScienceCluster, request: FixtureRequest, - llama_stack_server_config: Dict[str, Any], + llama_stack_server_config: dict[str, Any], ci_s3_bucket_name: str, ci_s3_bucket_endpoint: str, ci_s3_bucket_region: str, @@ -391,7 +393,7 @@ def unprivileged_llama_stack_distribution( unprivileged_llama_stack_distribution_secret: Secret, unprivileged_postgres_deployment: Deployment, unprivileged_postgres_service: Service, -) -> Generator[LlamaStackDistribution, None, None]: +) -> Generator[LlamaStackDistribution]: # Distribution name needs a random substring due to bug RHAIENG-999 / RHAIENG-1139 distribution_name = generate_random_name(prefix="llama-stack-distribution") with create_llama_stack_distribution( @@ -417,10 +419,10 @@ def unprivileged_llama_stack_distribution( access_key_id=aws_access_key_id, secret_access_key=aws_secret_access_key, ) - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Failed to clean up S3 files: {e}") - except Exception as e: + except Exception as e: # noqa: BLE001 
LOGGER.warning(f"Failed to clean up S3 files: {e}") @@ -430,7 +432,7 @@ def llama_stack_distribution( model_namespace: Namespace, enabled_llama_stack_operator: DataScienceCluster, request: FixtureRequest, - llama_stack_server_config: Dict[str, Any], + llama_stack_server_config: dict[str, Any], ci_s3_bucket_name: str, ci_s3_bucket_endpoint: str, ci_s3_bucket_region: str, @@ -439,7 +441,7 @@ def llama_stack_distribution( llama_stack_distribution_secret: Secret, postgres_deployment: Deployment, postgres_service: Service, -) -> Generator[LlamaStackDistribution, None, None]: +) -> Generator[LlamaStackDistribution]: # Distribution name needs a random substring due to bug RHAIENG-999 / RHAIENG-1139 with create_llama_stack_distribution( client=admin_client, @@ -464,10 +466,10 @@ def llama_stack_distribution( access_key_id=aws_access_key_id, secret_access_key=aws_secret_access_key, ) - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Failed to clean up S3 files: {e}") - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Failed to clean up S3 files: {e}") @@ -573,14 +575,15 @@ def _create_llama_stack_test_route( Generator[Route, Any, Any]: Route resource with TLS edge termination """ route_name = generate_random_name(prefix="llama-stack", length=12) - with Route( - client=client, - namespace=namespace.name, - name=route_name, - service=f"{deployment.name}-service", - wait_for_resource=True, - ) as route: - with ResourceEditor( + with ( + Route( + client=client, + namespace=namespace.name, + name=route_name, + service=f"{deployment.name}-service", + wait_for_resource=True, + ) as route, + ResourceEditor( patches={ route: { "spec": { @@ -594,9 +597,10 @@ def _create_llama_stack_test_route( }, } } - ): - route.wait(timeout=60) - yield route + ), + ): + route.wait(timeout=60) + yield route @pytest.fixture(scope="class") @@ -736,7 +740,7 @@ def vector_store( unprivileged_llama_stack_client: LlamaStackClient, 
llama_stack_models: ModelInfo, request: FixtureRequest, -) -> Generator[VectorStore, None, None]: +) -> Generator[VectorStore]: """ Creates a vector store for testing and automatically cleans it up. @@ -769,14 +773,14 @@ def vector_store( try: unprivileged_llama_stack_client.vector_stores.delete(vector_store_id=vector_store.id) LOGGER.info(f"Deleted vector store {vector_store.id}") - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Failed to delete vector store {vector_store.id}: {e}") @pytest.fixture(scope="class") def vector_store_with_example_docs( unprivileged_llama_stack_client: LlamaStackClient, vector_store: VectorStore -) -> Generator[VectorStore, None, None]: +) -> Generator[VectorStore]: """ Creates a vector store with the IBM fourth-quarter 2025 earnings report uploaded. @@ -888,7 +892,7 @@ def postgres_deployment( yield deployment -def get_postgres_deployment_template() -> Dict[str, Any]: +def get_postgres_deployment_template() -> dict[str, Any]: """Return a Kubernetes deployment for PostgreSQL""" return { "metadata": {"labels": {"app": "postgres"}}, diff --git a/tests/llama_stack/constants.py b/tests/llama_stack/constants.py index 8a97a9009..311a12cc5 100644 --- a/tests/llama_stack/constants.py +++ b/tests/llama_stack/constants.py @@ -1,9 +1,10 @@ from dataclasses import dataclass from enum import Enum -from typing import List, NamedTuple, TypedDict +from typing import NamedTuple, TypedDict + +import semver from llama_stack_client.types import Model from semver import VersionInfo -import semver class LlamaStackProviders: @@ -35,16 +36,16 @@ class ModelInfo(NamedTuple): class TurnExpectation(TypedDict): question: str - expected_keywords: List[str] + expected_keywords: list[str] description: str class TurnResult(TypedDict): question: str description: str - expected_keywords: List[str] - found_keywords: List[str] - missing_keywords: List[str] + expected_keywords: list[str] + found_keywords: list[str] + missing_keywords: 
list[str] response_content: str response_length: int event_count: int @@ -63,7 +64,7 @@ class ValidationSummary(TypedDict): class ValidationResult(TypedDict): success: bool - results: List[TurnResult] + results: list[TurnResult] summary: ValidationSummary @@ -72,11 +73,11 @@ class TorchTuneTestExpectation: """Test expectation for TorchTune documentation questions.""" question: str - expected_keywords: List[str] + expected_keywords: list[str] description: str -TORCHTUNE_TEST_EXPECTATIONS: List[TorchTuneTestExpectation] = [ +TORCHTUNE_TEST_EXPECTATIONS: list[TorchTuneTestExpectation] = [ TorchTuneTestExpectation( question="what is torchtune", expected_keywords=["torchtune", "pytorch", "fine-tuning", "training", "model"], diff --git a/tests/llama_stack/eval/conftest.py b/tests/llama_stack/eval/conftest.py index 859d70d18..aa932d801 100644 --- a/tests/llama_stack/eval/conftest.py +++ b/tests/llama_stack/eval/conftest.py @@ -1,4 +1,5 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -35,7 +36,7 @@ def dataset_pvc(admin_client, model_namespace) -> Generator[PersistentVolumeClai @pytest.fixture(scope="class") def dataset_upload( admin_client: DynamicClient, model_namespace: Namespace, dataset_pvc: PersistentVolumeClaim -) -> Generator[dict[str, Any], None, None]: +) -> Generator[dict[str, Any]]: """ Copies dataset files from an image into the PVC at the location expected by LM-Eval """ diff --git a/tests/llama_stack/eval/test_lmeval_provider.py b/tests/llama_stack/eval/test_lmeval_provider.py index 4f183d7da..ed3a72ac7 100644 --- a/tests/llama_stack/eval/test_lmeval_provider.py +++ b/tests/llama_stack/eval/test_lmeval_provider.py @@ -2,8 +2,7 @@ from tests.llama_stack.constants import LlamaStackProviders from tests.llama_stack.eval.utils import wait_for_eval_job_completion -from utilities.constants import MinIo, QWEN_MODEL_NAME - +from utilities.constants 
import QWEN_MODEL_NAME, MinIo TRUSTYAI_LMEVAL_ARCEASY = f"{LlamaStackProviders.Eval.TRUSTYAI_LMEVAL}::arc_easy" TRUSTYAI_LMEVAL_CUSTOM = f"{LlamaStackProviders.Eval.TRUSTYAI_LMEVAL}::dk-bench" diff --git a/tests/llama_stack/eval/test_ragas_provider.py b/tests/llama_stack/eval/test_ragas_provider.py index b74a26ce9..54e1249cd 100644 --- a/tests/llama_stack/eval/test_ragas_provider.py +++ b/tests/llama_stack/eval/test_ragas_provider.py @@ -4,7 +4,7 @@ from tests.llama_stack.constants import LlamaStackProviders from tests.llama_stack.eval.utils import wait_for_eval_job_completion -from utilities.constants import MinIo, QWEN_MODEL_NAME +from utilities.constants import QWEN_MODEL_NAME, MinIo RAGAS_DATASET_ID: str = "ragas_dataset" RAGAS_INLINE_BENCHMARK_ID = "ragas_benchmark_inline" @@ -53,7 +53,7 @@ def test_ragas_inline_register_dataset(self, minio_pod, minio_data_connection, l "description": "Sample RAG evaluation dataset for Ragas demo", "size": len(RAGAS_TEST_DATASET), "format": "ragas", - "created_at": datetime.now().isoformat(), + "created_at": datetime.now().isoformat(), # noqa: DTZ005 }, ) @@ -130,7 +130,7 @@ def test_ragas_remote_register_dataset(self, minio_pod, minio_data_connection, l "description": "Sample RAG evaluation dataset for Ragas demo", "size": len(RAGAS_TEST_DATASET), "format": "ragas", - "created_at": datetime.now().isoformat(), + "created_at": datetime.now().isoformat(), # noqa: DTZ005 }, ) diff --git a/tests/llama_stack/inference/test_completions.py b/tests/llama_stack/inference/test_completions.py index 269d085b1..c4a1c6685 100644 --- a/tests/llama_stack/inference/test_completions.py +++ b/tests/llama_stack/inference/test_completions.py @@ -1,6 +1,7 @@ import pytest -from simple_logger.logger import get_logger from llama_stack_client import LlamaStackClient +from simple_logger.logger import get_logger + from tests.llama_stack.constants import ModelInfo LOGGER = get_logger(name=__name__) diff --git a/tests/llama_stack/operator/conftest.py 
b/tests/llama_stack/operator/conftest.py index 83d807a31..3cda34e6e 100644 --- a/tests/llama_stack/operator/conftest.py +++ b/tests/llama_stack/operator/conftest.py @@ -1,8 +1,11 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any + import pytest from kubernetes.dynamic import DynamicClient -from ocp_resources.pod import Pod from ocp_resources.namespace import Namespace +from ocp_resources.pod import Pod + from tests.llama_stack.constants import LLS_CORE_POD_FILTER from utilities.general import wait_for_pods_by_labels diff --git a/tests/llama_stack/operator/test_llama_stack_distribution.py b/tests/llama_stack/operator/test_llama_stack_distribution.py index 92e9e8e68..652663294 100644 --- a/tests/llama_stack/operator/test_llama_stack_distribution.py +++ b/tests/llama_stack/operator/test_llama_stack_distribution.py @@ -1,6 +1,8 @@ -from typing import Self, Set -from ocp_resources.pod import Pod +from typing import Self + import pytest +from ocp_resources.pod import Pod + from utilities.general import validate_container_images @@ -29,7 +31,7 @@ class TestLlamaStackDistribution: def test_llamastackdistribution_verify_images( self: Self, llama_stack_distribution_pods: Pod, - related_images_refs: Set[str], + related_images_refs: set[str], ) -> None: """ Verify that LlamaStackDistribution container images meet the requirements: diff --git a/tests/llama_stack/responses/test_responses.py b/tests/llama_stack/responses/test_responses.py index d3bf63a31..4e3c69d47 100644 --- a/tests/llama_stack/responses/test_responses.py +++ b/tests/llama_stack/responses/test_responses.py @@ -1,5 +1,6 @@ import pytest from llama_stack_client import LlamaStackClient + from tests.llama_stack.constants import ModelInfo diff --git a/tests/llama_stack/safety/conftest.py b/tests/llama_stack/safety/conftest.py index 0af5acbf4..5933b3191 100644 --- a/tests/llama_stack/safety/conftest.py +++ b/tests/llama_stack/safety/conftest.py @@ -1,7 +1,8 @@ import 
os import subprocess from base64 import b64encode -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -24,6 +25,7 @@ def guardrails_orchestrator_ssl_cert(guardrails_orchestrator_route: Route): capture_output=True, text=True, timeout=30, + check=False, ) if result.returncode != 0 and "CONNECTED" not in result.stdout: @@ -46,9 +48,9 @@ def guardrails_orchestrator_ssl_cert(guardrails_orchestrator_route: Route): with open(filepath, "w") as f: f.write("\n".join(cert_lines)) - return filepath + return filepath # noqa: TRY300 - except Exception as e: + except Exception as e: # noqa: BLE001 raise RuntimeError(f"Could not get certificate from {hostname}: {e}") @@ -57,7 +59,7 @@ def guardrails_orchestrator_ssl_cert_secret( admin_client: DynamicClient, model_namespace: Namespace, guardrails_orchestrator_ssl_cert: str, # ← Add dependency and use correct cert -) -> Generator[Secret, Any, None]: +) -> Generator[Secret, Any]: with open(guardrails_orchestrator_ssl_cert, "r") as f: cert_content = f.read() diff --git a/tests/llama_stack/safety/test_trustyai_fms_provider.py b/tests/llama_stack/safety/test_trustyai_fms_provider.py index 2484085ab..240aef87a 100644 --- a/tests/llama_stack/safety/test_trustyai_fms_provider.py +++ b/tests/llama_stack/safety/test_trustyai_fms_provider.py @@ -3,8 +3,7 @@ from simple_logger.logger import get_logger from tests.llama_stack.constants import LlamaStackProviders - -from utilities.constants import MinIo, CHAT_GENERATION_CONFIG, BUILTIN_DETECTOR_CONFIG, QWEN_MODEL_NAME +from utilities.constants import BUILTIN_DETECTOR_CONFIG, CHAT_GENERATION_CONFIG, QWEN_MODEL_NAME, MinIo LOGGER = get_logger(name=__name__) SECURE_SHIELD_ID: str = "secure_shield" diff --git a/tests/llama_stack/utils.py b/tests/llama_stack/utils.py index 330fedff2..6a293530b 100644 --- a/tests/llama_stack/utils.py +++ b/tests/llama_stack/utils.py @@ -1,32 +1,27 @@ +import os 
+import tempfile +from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, List, cast +from typing import Any, cast +import requests from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError -from llama_stack_client import LlamaStackClient, APIConnectionError, InternalServerError +from llama_stack_client import APIConnectionError, InternalServerError, LlamaStackClient from llama_stack_client.types.vector_store import VectorStore - -from utilities.resources.llama_stack_distribution import LlamaStackDistribution from ocp_resources.pod import Pod from simple_logger.logger import get_logger from timeout_sampler import retry -from utilities.exceptions import UnexpectedResourceCountError - - from tests.llama_stack.constants import ( + LLS_CORE_POD_FILTER, TORCHTUNE_TEST_EXPECTATIONS, - TurnExpectation, ModelInfo, + TurnExpectation, ValidationResult, - LLS_CORE_POD_FILTER, ) - -import os -import tempfile - -import requests - +from utilities.exceptions import UnexpectedResourceCountError +from utilities.resources.llama_stack_distribution import LlamaStackDistribution LOGGER = get_logger(name=__name__) @@ -37,7 +32,7 @@ def create_llama_stack_distribution( name: str, namespace: str, replicas: int, - server: Dict[str, Any], + server: dict[str, Any], teardown: bool = True, ) -> Generator[LlamaStackDistribution, Any, Any]: """ @@ -47,7 +42,7 @@ def create_llama_stack_distribution( # Starting with RHOAI 3.3, pods in the 'openshift-ingress' namespace must be allowed # to access the llama-stack-service. This is required for the llama_stack_test_route # to function properly. 
- network: Dict[str, Any] = { + network: dict[str, Any] = { "allowedFrom": { "namespaces": ["openshift-ingress"], }, @@ -108,19 +103,19 @@ def wait_for_llama_stack_client_ready(client: LlamaStackClient) -> bool: f"vector_stores:{len(vector_stores.data)} " f"files:{len(files.data)})" ) - return True + return True # noqa: TRY300 except (APIConnectionError, InternalServerError) as error: LOGGER.debug(f"Llama Stack server not ready yet: {error}") - LOGGER.debug(f"Base URL: {client.base_url}, Error type: {type(error)}, Error details: {str(error)}") + LOGGER.debug(f"Base URL: {client.base_url}, Error type: {type(error)}, Error details: {error!s}") return False - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Unexpected error checking Llama Stack readiness: {e}") return False -def get_torchtune_test_expectations() -> List[TurnExpectation]: +def get_torchtune_test_expectations() -> list[TurnExpectation]: """ Helper function to get the test expectations for TorchTune documentation questions. 
@@ -171,7 +166,7 @@ def _response_fn(*, question: str) -> str: def validate_api_responses( response_fn: Callable[..., str], - test_cases: List[TurnExpectation], + test_cases: list[TurnExpectation], min_keywords_required: int = 1, ) -> ValidationResult: """ diff --git a/tests/llama_stack/vector_io/test_vector_stores.py b/tests/llama_stack/vector_io/test_vector_stores.py index b2c5929d0..152b03003 100644 --- a/tests/llama_stack/vector_io/test_vector_stores.py +++ b/tests/llama_stack/vector_io/test_vector_stores.py @@ -2,11 +2,12 @@ from llama_stack_client import LlamaStackClient from llama_stack_client.types.vector_store import VectorStore from simple_logger.logger import get_logger + from tests.llama_stack.constants import ModelInfo from tests.llama_stack.utils import ( - validate_api_responses, create_response_function, get_torchtune_test_expectations, + validate_api_responses, ) LOGGER = get_logger(name=__name__) diff --git a/tests/model_explainability/conftest.py b/tests/model_explainability/conftest.py index d36bd7e23..5c0e048f1 100644 --- a/tests/model_explainability/conftest.py +++ b/tests/model_explainability/conftest.py @@ -1,4 +1,5 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient diff --git a/tests/model_explainability/guardrails/conftest.py b/tests/model_explainability/guardrails/conftest.py index 6d6ae992e..c09d21909 100644 --- a/tests/model_explainability/guardrails/conftest.py +++ b/tests/model_explainability/guardrails/conftest.py @@ -1,4 +1,5 @@ -from typing import Generator, Any, List +from collections.abc import Generator +from typing import Any import portforward import pytest @@ -29,10 +30,10 @@ from utilities.certificates_utils import create_ca_bundle_file from utilities.constants import ( KServeDeploymentType, - Timeout, RuntimeTemplates, + Timeout, ) -from utilities.inference_utils import create_isvc, LOGGER +from 
utilities.inference_utils import LOGGER, create_isvc from utilities.operator_utils import get_cluster_service_version from utilities.serving_runtime import ServingRuntimeFromTemplate @@ -155,7 +156,7 @@ def hap_detector_route( @pytest.fixture(scope="class") -def installed_tempo_operator(admin_client: DynamicClient, model_namespace: Namespace) -> Generator[None, Any, None]: +def installed_tempo_operator(admin_client: DynamicClient, model_namespace: Namespace) -> Generator[None, Any]: """ Installs the Tempo operator and waits for its deployment. """ @@ -206,7 +207,7 @@ def tempo_stack( admin_client: DynamicClient, model_namespace: Namespace, minio_secret_otel: Secret, -) -> Generator[Any, Any, None]: +) -> Generator[Any, Any]: """ Create a TempoStack CR in the test namespace, configured to use MinIO backend. """ @@ -265,7 +266,7 @@ def tempo_stack( @pytest.fixture(scope="class") -def installed_opentelemetry_operator(admin_client: DynamicClient) -> Generator[None, Any, None]: +def installed_opentelemetry_operator(admin_client: DynamicClient) -> Generator[None, Any]: """ Installs the Red Hat OpenTelemetry Operator and waits for its deployment. 
""" @@ -401,7 +402,7 @@ def wait_for_pods_by_label( timeout: Maximum wait time in seconds """ - def _get_pods() -> List[Pod]: + def _get_pods() -> list[Pod]: return [ pod for pod in Pod.get( diff --git a/tests/model_explainability/guardrails/test_guardrails.py b/tests/model_explainability/guardrails/test_guardrails.py index 06d8f1e2d..71b781554 100644 --- a/tests/model_explainability/guardrails/test_guardrails.py +++ b/tests/model_explainability/guardrails/test_guardrails.py @@ -6,30 +6,30 @@ from tests.model_explainability.guardrails.constants import ( AUTOCONFIG_DETECTOR_LABEL, - PII_INPUT_DETECTION_PROMPT, - PII_OUTPUT_DETECTION_PROMPT, - PROMPT_INJECTION_INPUT_DETECTION_PROMPT, + AUTOCONFIG_GATEWAY_ENDPOINT, + CHAT_COMPLETIONS_DETECTION_ENDPOINT, + HAP_DETECTOR, HAP_INPUT_DETECTION_PROMPT, - PII_ENDPOINT, HARMLESS_PROMPT, + PII_ENDPOINT, + PII_INPUT_DETECTION_PROMPT, + PII_OUTPUT_DETECTION_PROMPT, PROMPT_INJECTION_DETECTOR, - HAP_DETECTOR, - CHAT_COMPLETIONS_DETECTION_ENDPOINT, + PROMPT_INJECTION_INPUT_DETECTION_PROMPT, STANDALONE_DETECTION_ENDPOINT, - AUTOCONFIG_GATEWAY_ENDPOINT, ) from tests.model_explainability.guardrails.utils import ( create_detector_config, - verify_health_info_response, - send_and_verify_unsuitable_input_detection, - send_and_verify_unsuitable_output_detection, send_and_verify_negative_detection, send_and_verify_standalone_detection, + send_and_verify_unsuitable_input_detection, + send_and_verify_unsuitable_output_detection, + verify_health_info_response, ) from tests.model_explainability.utils import validate_tai_component_images from utilities.constants import ( - LLM_D_CHAT_GENERATION_CONFIG, BUILTIN_DETECTOR_CONFIG, + LLM_D_CHAT_GENERATION_CONFIG, LLMdInferenceSimConfig, Timeout, ) diff --git a/tests/model_explainability/guardrails/upgrade/test_guardrails_upgrade.py b/tests/model_explainability/guardrails/upgrade/test_guardrails_upgrade.py index 750cfaeec..395632076 100644 --- 
a/tests/model_explainability/guardrails/upgrade/test_guardrails_upgrade.py +++ b/tests/model_explainability/guardrails/upgrade/test_guardrails_upgrade.py @@ -1,18 +1,19 @@ import pytest import yaml + from tests.model_explainability.guardrails.constants import ( - PII_INPUT_DETECTION_PROMPT, - PII_OUTPUT_DETECTION_PROMPT, HARMLESS_PROMPT, PII_ENDPOINT, + PII_INPUT_DETECTION_PROMPT, + PII_OUTPUT_DETECTION_PROMPT, ) from tests.model_explainability.guardrails.utils import ( - verify_health_info_response, + send_and_verify_negative_detection, send_and_verify_unsuitable_input_detection, send_and_verify_unsuitable_output_detection, - send_and_verify_negative_detection, + verify_health_info_response, ) -from utilities.constants import LLM_D_CHAT_GENERATION_CONFIG, BUILTIN_DETECTOR_CONFIG, LLMdInferenceSimConfig +from utilities.constants import BUILTIN_DETECTOR_CONFIG, LLM_D_CHAT_GENERATION_CONFIG, LLMdInferenceSimConfig from utilities.plugins.constant import OpenAIEnpoints diff --git a/tests/model_explainability/guardrails/utils.py b/tests/model_explainability/guardrails/utils.py index 752d1b3c5..64c16c2ff 100644 --- a/tests/model_explainability/guardrails/utils.py +++ b/tests/model_explainability/guardrails/utils.py @@ -1,21 +1,20 @@ import http import json +from typing import Any import requests from requests import Response from simple_logger.logger import get_logger -from typing import Dict, Any, List, Optional - from timeout_sampler import retry +from tests.model_explainability.guardrails.constants import GuardrailsDetectionPrompt from utilities.exceptions import UnexpectedValueError from utilities.guardrails import get_auth_headers -from tests.model_explainability.guardrails.constants import GuardrailsDetectionPrompt LOGGER = get_logger(name=__name__) -def get_chat_detections_payload(content: str, model: str, detectors: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def get_chat_detections_payload(content: str, model: str, detectors: dict[str, Any] | None = 
None) -> dict[str, Any]: """ Constructs a chat detections payload for a given content string. @@ -29,7 +28,7 @@ def get_chat_detections_payload(content: str, model: str, detectors: Optional[Di A dictionary representing the chat detections payload. """ - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "model": model, "messages": [ {"role": "user", "content": content}, @@ -53,17 +52,17 @@ def verify_and_parse_response(response: Response) -> Any: return response_json -def assert_no_errors(errors: List[str], failure_message_prefix: str) -> None: +def assert_no_errors(errors: list[str], failure_message_prefix: str) -> None: assert not errors, f"{failure_message_prefix}:\n" + "\n".join(f"- {error}" for error in errors) def verify_detection( - detections_list: List[Dict[str, Any]], + detections_list: list[dict[str, Any]], detector_id: str, detection_name: str, detection_type: str, - expected_detection_text: Optional[str] = None, -) -> List[str]: + expected_detection_text: str | None = None, +) -> list[str]: """ Helper to verify detection results. 
@@ -235,7 +234,7 @@ def verify_negative_detection_response(response: Response) -> None: assert_no_errors(errors=errors, failure_message_prefix="Negative detection verification failed") -def create_detector_config(*detector_names: str) -> Dict[str, Dict[str, Any]]: +def create_detector_config(*detector_names: str) -> dict[str, dict[str, Any]]: detectors_dict = {name: {} for name in detector_names} return { "input": detectors_dict.copy(), @@ -261,7 +260,7 @@ def _send_guardrails_orchestrator_post_request( url: str, token: str, ca_bundle_file: str, - payload: Dict[str, Any], + payload: dict[str, Any], ) -> requests.Response: response = requests.post( url=url, @@ -282,7 +281,7 @@ def send_chat_detections_request( ca_bundle_file: str, content: str, model: str, - detectors: Dict[str, Any] = None, + detectors: dict[str, Any] | None = None, ) -> requests.Response: payload = get_chat_detections_payload(content=content, model=model, detectors=detectors) return _send_guardrails_orchestrator_post_request( @@ -297,7 +296,7 @@ def send_and_verify_unsuitable_input_detection( ca_bundle_file: str, prompt: GuardrailsDetectionPrompt, model: str, - detectors: Dict[str, Any] = None, + detectors: dict[str, Any] | None = None, ): """Send a prompt to the GuardrailsOrchestrator and verify that it triggers an unsuitable input detection""" response = send_chat_detections_request( @@ -321,7 +320,7 @@ def send_and_verify_unsuitable_output_detection( ca_bundle_file: str, prompt: GuardrailsDetectionPrompt, model: str, - detectors: Dict[str, Any] = None, + detectors: dict[str, Any] | None = None, ): """Send a prompt to the GuardrailsOrchestrator and verify that it triggers an unsuitable output detection""" @@ -345,7 +344,7 @@ def send_and_verify_negative_detection( ca_bundle_file: str, content: str, model: str, - detectors: Dict[str, Any] = None, + detectors: dict[str, Any] | None = None, ): """Send a prompt to the GuardrailsOrchestrator and verify that it doesn't trigger any detection""" diff 
--git a/tests/model_explainability/lm_eval/conftest.py b/tests/model_explainability/lm_eval/conftest.py index 8a25a2a6c..56c5bf5b5 100644 --- a/tests/model_explainability/lm_eval/conftest.py +++ b/tests/model_explainability/lm_eval/conftest.py @@ -1,5 +1,6 @@ import json -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -17,11 +18,11 @@ from tests.model_explainability.lm_eval.constants import ( ARC_EASY_DATASET_IMAGE, FLAN_T5_IMAGE, - LMEVAL_OCI_TAG, LMEVAL_OCI_REPO, + LMEVAL_OCI_TAG, ) from tests.model_explainability.lm_eval.utils import get_lmevaljob_pod -from utilities.constants import Labels, MinIo, Protocols, Timeout, ApiGroups +from utilities.constants import ApiGroups, Labels, MinIo, Protocols, Timeout from utilities.exceptions import MissingParameter from utilities.general import b64_encoded_string @@ -37,7 +38,7 @@ def lmevaljob_hf( model_namespace: Namespace, patched_dsc_lmeval_allow_all: DataScienceCluster, lmeval_hf_access_token: Secret, -) -> Generator[LMEvalJob, None, None]: +) -> Generator[LMEvalJob]: with LMEvalJob( client=admin_client, name=LMEVALJOB_NAME, @@ -217,7 +218,7 @@ def lmevaljob_vllm_emulator( {"name": "model", "value": "emulatedModel"}, { "name": "base_url", - "value": f"http://{vllm_emulator_service.name}:{str(VLLM_EMULATOR_PORT)}/v1/completions", + "value": f"http://{vllm_emulator_service.name}:{VLLM_EMULATOR_PORT!s}/v1/completions", }, {"name": "num_concurrent", "value": "1"}, {"name": "max_retries", "value": "3"}, diff --git a/tests/model_explainability/lm_eval/test_lm_eval.py b/tests/model_explainability/lm_eval/test_lm_eval.py index fa29c7d4c..7f180520c 100644 --- a/tests/model_explainability/lm_eval/test_lm_eval.py +++ b/tests/model_explainability/lm_eval/test_lm_eval.py @@ -1,29 +1,26 @@ import pytest -from typing import List - from kubernetes.dynamic import DynamicClient from ocp_resources.namespace import Namespace 
from ocp_resources.pod import Pod +from simple_logger.logger import get_logger -from tests.model_explainability.lm_eval.constants import LMEVAL_OCI_REPO, LMEVAL_OCI_TAG from tests.model_explainability.lm_eval.constants import ( - LLMAAJ_TASK_DATA, - CUSTOM_UNITXT_TASK_DATA, ARC_EASY_DATASET_IMAGE, + CUSTOM_UNITXT_TASK_DATA, + LLMAAJ_TASK_DATA, + LMEVAL_OCI_REPO, + LMEVAL_OCI_TAG, ) -from tests.model_explainability.utils import validate_tai_component_images - from tests.model_explainability.lm_eval.utils import get_lmeval_tasks, validate_lmeval_job_pod_and_logs +from tests.model_explainability.utils import validate_tai_component_images from utilities.constants import OCIRegistry from utilities.registry_utils import pull_manifest_from_oci_registry -from simple_logger.logger import get_logger - LMEVALJOB_COMPLETE_STATE: str = "Complete" -TIER1_LMEVAL_TASKS: List[str] = get_lmeval_tasks(min_downloads=10000) +TIER1_LMEVAL_TASKS: list[str] = get_lmeval_tasks(min_downloads=10000) -TIER2_LMEVAL_TASKS: List[str] = list( +TIER2_LMEVAL_TASKS: list[str] = list( set(get_lmeval_tasks(min_downloads=0.70, max_downloads=10000)) - set(TIER1_LMEVAL_TASKS) ) diff --git a/tests/model_explainability/lm_eval/utils.py b/tests/model_explainability/lm_eval/utils.py index 2634a7041..33c587734 100644 --- a/tests/model_explainability/lm_eval/utils.py +++ b/tests/model_explainability/lm_eval/utils.py @@ -1,16 +1,14 @@ -from typing import List import re -from pyhelper_utils.general import tts + +import pandas as pd from kubernetes.dynamic import DynamicClient from ocp_resources.lm_eval_job import LMEvalJob from ocp_resources.pod import Pod - -from utilities.constants import Timeout +from pyhelper_utils.general import tts from simple_logger.logger import get_logger from timeout_sampler import TimeoutExpiredError -import pandas as pd - +from utilities.constants import Timeout from utilities.exceptions import PodLogMissMatchError, UnexpectedFailureError LOGGER = get_logger(name=__name__) @@ -39,7 
+37,7 @@ def get_lmevaljob_pod(client: DynamicClient, lmevaljob: LMEvalJob, timeout: int return lmeval_pod -def get_lmeval_tasks(min_downloads: int | float, max_downloads: int | float | None = None) -> List[str]: +def get_lmeval_tasks(min_downloads: float, max_downloads: float | None = None) -> list[str]: """ Gets the list of supported LM-Eval tasks that have above a certain number of minimum downloads on HuggingFace. diff --git a/tests/model_explainability/trustyai_service/conftest.py b/tests/model_explainability/trustyai_service/conftest.py index 3b401d171..81b97e5d0 100644 --- a/tests/model_explainability/trustyai_service/conftest.py +++ b/tests/model_explainability/trustyai_service/conftest.py @@ -1,5 +1,6 @@ import json -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest import yaml @@ -13,6 +14,7 @@ from ocp_resources.mariadb_operator import MariadbOperator from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod +from ocp_resources.resource import ResourceEditor from ocp_resources.role import Role from ocp_resources.role_binding import RoleBinding from ocp_resources.secret import Secret @@ -23,44 +25,43 @@ from pytest_testconfig import py_config from tests.model_explainability.trustyai_service.constants import ( - TAI_DATA_CONFIG, - TAI_METRICS_CONFIG, - TAI_PVC_STORAGE_CONFIG, + GAUSSIAN_CREDIT_MODEL, + GAUSSIAN_CREDIT_MODEL_RESOURCES, + GAUSSIAN_CREDIT_MODEL_STORAGE_PATH, + ISVC_GETTER, KSERVE_MLSERVER, + KSERVE_MLSERVER_ANNOTATIONS, KSERVE_MLSERVER_CONTAINERS, KSERVE_MLSERVER_SUPPORTED_MODEL_FORMATS, - KSERVE_MLSERVER_ANNOTATIONS, - GAUSSIAN_CREDIT_MODEL_RESOURCES, - XGBOOST, + TAI_DATA_CONFIG, TAI_DB_STORAGE_CONFIG, - ISVC_GETTER, - GAUSSIAN_CREDIT_MODEL_STORAGE_PATH, - GAUSSIAN_CREDIT_MODEL, + TAI_METRICS_CONFIG, + TAI_PVC_STORAGE_CONFIG, + XGBOOST, ) from tests.model_explainability.trustyai_service.trustyai_service_utils import ( 
wait_for_isvc_deployment_registered_by_trustyai_service, ) from tests.model_explainability.trustyai_service.utils import ( - create_trustyai_service, - wait_for_mariadb_pods, create_isvc_getter_role, create_isvc_getter_role_binding, create_isvc_getter_service_account, create_isvc_getter_token_secret, + create_trustyai_service, + wait_for_mariadb_pods, ) -from utilities.logger import RedactedString -from utilities.operator_utils import get_cluster_service_version from utilities.constants import ( - KServeDeploymentType, - Labels, - OPENSHIFT_OPERATORS, MARIADB, + OPENSHIFT_OPERATORS, TRUSTYAI_SERVICE_NAME, Annotations, + KServeDeploymentType, + Labels, ) from utilities.inference_utils import create_isvc -from ocp_resources.resource import ResourceEditor from utilities.infra import create_inference_token, get_kserve_storage_initialize_image, update_configmap_data +from utilities.logger import RedactedString +from utilities.operator_utils import get_cluster_service_version DB_CREDENTIALS_SECRET_NAME: str = "db-credentials" DB_NAME: str = "trustyai_db" @@ -229,7 +230,7 @@ def mariadb( @pytest.fixture(scope="class") def trustyai_db_ca_secret( admin_client: DynamicClient, model_namespace: Namespace, mariadb: MariaDB -) -> Generator[Secret, Any, None]: +) -> Generator[Secret, Any]: mariadb_ca_secret = Secret( client=admin_client, name=f"{mariadb.name}-ca", namespace=model_namespace.name, ensure_exists=True ) diff --git a/tests/model_explainability/trustyai_service/constants.py b/tests/model_explainability/trustyai_service/constants.py index eac1ad327..33492d9bc 100644 --- a/tests/model_explainability/trustyai_service/constants.py +++ b/tests/model_explainability/trustyai_service/constants.py @@ -1,12 +1,12 @@ -from typing import Dict, Any, List +from typing import Any -from utilities.constants import Ports, ApiGroups +from utilities.constants import ApiGroups, Ports DRIFT_BASE_DATA_PATH: str = "./tests/model_explainability/trustyai_service/drift/model_data" 
-TAI_DATA_CONFIG: Dict[str, str] = {"filename": "data.csv", "format": "CSV"} -TAI_METRICS_CONFIG: Dict[str, str] = {"schedule": "5s"} -TAI_PVC_STORAGE_CONFIG: Dict[str, str] = {"format": "PVC", "folder": "/inputs", "size": "1Gi"} -TAI_DB_STORAGE_CONFIG: Dict[str, str] = { +TAI_DATA_CONFIG: dict[str, str] = {"filename": "data.csv", "format": "CSV"} +TAI_METRICS_CONFIG: dict[str, str] = {"schedule": "5s"} +TAI_PVC_STORAGE_CONFIG: dict[str, str] = {"format": "PVC", "folder": "/inputs", "size": "1Gi"} +TAI_DB_STORAGE_CONFIG: dict[str, str] = { "format": "DATABASE", "size": "1Gi", "databaseConfigurations": "db-credentials", @@ -21,13 +21,13 @@ GAUSSIAN_CREDIT_MODEL: str = "gaussian-credit-model" GAUSSIAN_CREDIT_MODEL_STORAGE_PATH: str = f"{SKLEARN}/{GAUSSIAN_CREDIT_MODEL.replace('-', '_')}/1" -GAUSSIAN_CREDIT_MODEL_RESOURCES: Dict[str, Dict[str, str]] = { +GAUSSIAN_CREDIT_MODEL_RESOURCES: dict[str, dict[str, str]] = { "requests": {"cpu": "1", "memory": "2Gi"}, "limits": {"cpu": "1", "memory": "2Gi"}, } KSERVE_MLSERVER: str = f"kserve-{MLSERVER}" -KSERVE_MLSERVER_SUPPORTED_MODEL_FORMATS: List[Dict[str, Any]] = [ +KSERVE_MLSERVER_SUPPORTED_MODEL_FORMATS: list[dict[str, Any]] = [ {"name": "sklearn", "version": "0", "autoSelect": True, "priority": 2}, {"name": "sklearn", "version": "1", "autoSelect": True, "priority": 2}, {"name": "xgboost", "version": "1", "autoSelect": True, "priority": 2}, @@ -37,7 +37,7 @@ {"name": "mlflow", "version": "1", "autoSelect": True, "priority": 1}, {"name": "mlflow", "version": "2", "autoSelect": True, "priority": 1}, ] -KSERVE_MLSERVER_CONTAINERS: List[Dict[str, Any]] = [ +KSERVE_MLSERVER_CONTAINERS: list[dict[str, Any]] = [ { "name": "kserve-container", "image": "quay.io/trustyai_testing/mlserver" @@ -51,7 +51,7 @@ "resources": {"requests": {"cpu": "1", "memory": "2Gi"}, "limits": {"cpu": "1", "memory": "2Gi"}}, } ] -KSERVE_MLSERVER_ANNOTATIONS: Dict[str, str] = { +KSERVE_MLSERVER_ANNOTATIONS: dict[str, str] = { 
f"{ApiGroups.OPENDATAHUB_IO}/accelerator-name": "", f"{ApiGroups.OPENDATAHUB_IO}/template-display-name": "KServe MLServer", "prometheus.kserve.io/path": "/metrics", diff --git a/tests/model_explainability/trustyai_service/drift/test_drift.py b/tests/model_explainability/trustyai_service/drift/test_drift.py index 03dda2bbe..783b69e84 100644 --- a/tests/model_explainability/trustyai_service/drift/test_drift.py +++ b/tests/model_explainability/trustyai_service/drift/test_drift.py @@ -4,16 +4,16 @@ from tests.model_explainability.trustyai_service.constants import DRIFT_BASE_DATA_PATH from tests.model_explainability.trustyai_service.trustyai_service_utils import ( + TrustyAIServiceMetrics, send_inferences_and_verify_trustyai_service_registered, - verify_upload_data_to_trustyai_service, + verify_trustyai_service_metric_delete_request, verify_trustyai_service_metric_request, - TrustyAIServiceMetrics, verify_trustyai_service_metric_scheduling_request, - verify_trustyai_service_metric_delete_request, + verify_upload_data_to_trustyai_service, ) from utilities.constants import MinIo from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG -from utilities.monitoring import validate_metrics_field, get_metric_label +from utilities.monitoring import get_metric_label, validate_metrics_field DRIFT_METRICS = [ TrustyAIServiceMetrics.Drift.MEANSHIFT, diff --git a/tests/model_explainability/trustyai_service/fairness/conftest.py b/tests/model_explainability/trustyai_service/fairness/conftest.py index 83861202e..5fa8fb7f6 100644 --- a/tests/model_explainability/trustyai_service/fairness/conftest.py +++ b/tests/model_explainability/trustyai_service/fairness/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -13,7 +14,7 @@ from tests.model_explainability.trustyai_service.trustyai_service_utils import ( 
wait_for_isvc_deployment_registered_by_trustyai_service, ) -from utilities.constants import MinIo, ModelFormat, KServeDeploymentType, RuntimeTemplates +from utilities.constants import KServeDeploymentType, MinIo, ModelFormat, RuntimeTemplates from utilities.inference_utils import create_isvc from utilities.serving_runtime import ServingRuntimeFromTemplate diff --git a/tests/model_explainability/trustyai_service/fairness/test_fairness.py b/tests/model_explainability/trustyai_service/fairness/test_fairness.py index 67ec052ce..c0fab936d 100644 --- a/tests/model_explainability/trustyai_service/fairness/test_fairness.py +++ b/tests/model_explainability/trustyai_service/fairness/test_fairness.py @@ -5,16 +5,16 @@ from ocp_resources.inference_service import InferenceService from tests.model_explainability.trustyai_service.trustyai_service_utils import ( - send_inferences_and_verify_trustyai_service_registered, - verify_trustyai_service_name_mappings, - verify_trustyai_service_metric_request, TrustyAIServiceMetrics, + send_inferences_and_verify_trustyai_service_registered, verify_trustyai_service_metric_delete_request, + verify_trustyai_service_metric_request, verify_trustyai_service_metric_scheduling_request, + verify_trustyai_service_name_mappings, ) from utilities.constants import MinIo from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG -from utilities.monitoring import validate_metrics_field, get_metric_label +from utilities.monitoring import get_metric_label, validate_metrics_field BASE_DATA_PATH: str = "./tests/model_explainability/trustyai_service/fairness/model_data" IS_MALE_IDENTIFYING: str = "Is Male-Identifying?" 
diff --git a/tests/model_explainability/trustyai_service/service/conftest.py b/tests/model_explainability/trustyai_service/service/conftest.py index 0d8de462d..02a0a8289 100644 --- a/tests/model_explainability/trustyai_service/service/conftest.py +++ b/tests/model_explainability/trustyai_service/service/conftest.py @@ -1,4 +1,5 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -8,16 +9,14 @@ from ocp_resources.secret import Secret from ocp_resources.trustyai_service import TrustyAIService - from tests.model_explainability.trustyai_service.constants import ( - TAI_METRICS_CONFIG, TAI_DB_STORAGE_CONFIG, + TAI_METRICS_CONFIG, ) -from utilities.constants import TRUSTYAI_SERVICE_NAME from tests.model_explainability.trustyai_service.utils import ( create_trustyai_service, ) - +from utilities.constants import TRUSTYAI_SERVICE_NAME INVALID_TLS_CERTIFICATE: str = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUJnRENDQVNlZ0F3SUJBZ0lRRGtTcXVuUWRzRmZwdi8zSm\ 5TS2ZoVEFLQmdncWhrak9QUVFEQWpBVk1STXcKRVFZRFZRUURFd3B0WVhKcFlXUmlMV05oTUI0WERUSTFNRFF4TkRFME1EUXhOMW9YRFRJNE1EUXhNekUx\ @@ -37,7 +36,7 @@ def trustyai_service_with_invalid_db_cert( user_workload_monitoring_config: ConfigMap, mariadb: MariaDB, trustyai_invalid_db_ca_secret: None, -) -> Generator[TrustyAIService, None, None]: +) -> Generator[TrustyAIService]: """Create a TrustyAIService deployment with an invalid database certificate set as secret. 
Yields: @@ -56,7 +55,7 @@ def trustyai_service_with_invalid_db_cert( @pytest.fixture(scope="class") def trustyai_invalid_db_ca_secret( admin_client: DynamicClient, model_namespace: Namespace, mariadb: MariaDB -) -> Generator[Secret, Any, None]: +) -> Generator[Secret, Any]: with Secret( client=admin_client, name=f"{TRUSTYAI_SERVICE_NAME}-db-ca", diff --git a/tests/model_explainability/trustyai_service/service/multi_ns/conftest.py b/tests/model_explainability/trustyai_service/service/multi_ns/conftest.py index 60c3b793d..e2a0200c2 100644 --- a/tests/model_explainability/trustyai_service/service/multi_ns/conftest.py +++ b/tests/model_explainability/trustyai_service/service/multi_ns/conftest.py @@ -1,54 +1,56 @@ import copy +from collections.abc import Generator from contextlib import ExitStack -from typing import Generator, Any, List +from typing import Any + import pytest +from kubernetes.dynamic import DynamicClient +from kubernetes.dynamic.exceptions import ResourceNotFoundError +from ocp_resources.cluster_service_version import ClusterServiceVersion +from ocp_resources.config_map import ConfigMap +from ocp_resources.inference_service import InferenceService +from ocp_resources.maria_db import MariaDB from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod +from ocp_resources.role import Role +from ocp_resources.role_binding import RoleBinding from ocp_resources.secret import Secret from ocp_resources.service import Service +from ocp_resources.service_account import ServiceAccount from ocp_resources.serving_runtime import ServingRuntime -from ocp_resources.inference_service import InferenceService from ocp_resources.trustyai_service import TrustyAIService -from ocp_resources.service_account import ServiceAccount -from ocp_resources.role import Role -from ocp_resources.role_binding import RoleBinding -from utilities.constants import OPENSHIFT_OPERATORS, MARIADB, TRUSTYAI_SERVICE_NAME -from ocp_resources.maria_db import MariaDB -from 
ocp_resources.config_map import ConfigMap + from tests.model_explainability.trustyai_service.constants import ( - TAI_METRICS_CONFIG, - TAI_DATA_CONFIG, - TAI_PVC_STORAGE_CONFIG, - GAUSSIAN_CREDIT_MODEL_STORAGE_PATH, + GAUSSIAN_CREDIT_MODEL, GAUSSIAN_CREDIT_MODEL_RESOURCES, + GAUSSIAN_CREDIT_MODEL_STORAGE_PATH, + ISVC_GETTER, KSERVE_MLSERVER, + KSERVE_MLSERVER_ANNOTATIONS, KSERVE_MLSERVER_CONTAINERS, KSERVE_MLSERVER_SUPPORTED_MODEL_FORMATS, - KSERVE_MLSERVER_ANNOTATIONS, - XGBOOST, - ISVC_GETTER, - GAUSSIAN_CREDIT_MODEL, + TAI_DATA_CONFIG, TAI_DB_STORAGE_CONFIG, + TAI_METRICS_CONFIG, + TAI_PVC_STORAGE_CONFIG, + XGBOOST, ) from tests.model_explainability.trustyai_service.trustyai_service_utils import ( wait_for_isvc_deployment_registered_by_trustyai_service, ) from tests.model_explainability.trustyai_service.utils import ( - create_trustyai_service, - create_isvc_getter_service_account, create_isvc_getter_role, create_isvc_getter_role_binding, + create_isvc_getter_service_account, create_isvc_getter_token_secret, + create_trustyai_service, wait_for_mariadb_pods, ) -from utilities.constants import KServeDeploymentType +from utilities.constants import MARIADB, OPENSHIFT_OPERATORS, TRUSTYAI_SERVICE_NAME, KServeDeploymentType from utilities.inference_utils import create_isvc from utilities.infra import create_inference_token, create_ns from utilities.minio import create_minio_data_connection_secret from utilities.operator_utils import get_cluster_service_version -from ocp_resources.cluster_service_version import ClusterServiceVersion -from kubernetes.dynamic.exceptions import ResourceNotFoundError -from kubernetes.dynamic import DynamicClient DB_CREDENTIALS_SECRET_NAME: str = "db-credentials" DB_NAME: str = "trustyai_db" @@ -57,7 +59,7 @@ @pytest.fixture(scope="class") -def model_namespaces(request, admin_client) -> Generator[List[Namespace], Any, None]: +def model_namespaces(request, admin_client) -> Generator[list[Namespace], Any]: with ExitStack() as stack: 
namespaces = [ stack.enter_context(create_ns(admin_client=admin_client, name=param["name"])) for param in request.param @@ -68,7 +70,7 @@ def model_namespaces(request, admin_client) -> Generator[List[Namespace], Any, N @pytest.fixture(scope="class") def minio_data_connection_multi_ns( request, admin_client, model_namespaces, minio_service -) -> Generator[List[Secret], Any, None]: +) -> Generator[list[Secret], Any]: with ExitStack() as stack: secrets = [ stack.enter_context( @@ -87,7 +89,7 @@ def minio_data_connection_multi_ns( @pytest.fixture(scope="class") def trustyai_service_with_pvc_storage_multi_ns( admin_client, model_namespaces, cluster_monitoring_config, user_workload_monitoring_config -) -> Generator[List[TrustyAIService], Any, None]: +) -> Generator[list[TrustyAIService], Any]: with ExitStack() as stack: services = [ stack.enter_context( @@ -109,8 +111,8 @@ def trustyai_service_with_pvc_storage_multi_ns( @pytest.fixture(scope="class") def kserve_logger_ca_bundle_multi_ns( - admin_client: DynamicClient, model_namespaces: List[Namespace] -) -> Generator[List[ConfigMap], Any, None]: + admin_client: DynamicClient, model_namespaces: list[Namespace] +) -> Generator[list[ConfigMap], Any]: """Create CA certificate ConfigMaps required for KServeRaw logger in each namespace.""" with ExitStack() as stack: ca_bundles = [ @@ -130,7 +132,7 @@ def kserve_logger_ca_bundle_multi_ns( @pytest.fixture(scope="class") -def mlserver_runtime_multi_ns(admin_client, model_namespaces) -> Generator[List[ServingRuntime], Any, None]: +def mlserver_runtime_multi_ns(admin_client, model_namespaces) -> Generator[list[ServingRuntime], Any]: with ExitStack() as stack: runtimes = [ stack.enter_context( @@ -154,14 +156,14 @@ def mlserver_runtime_multi_ns(admin_client, model_namespaces) -> Generator[List[ @pytest.fixture(scope="class") def gaussian_credit_model_multi_ns( admin_client: DynamicClient, - model_namespaces: List[Namespace], + model_namespaces: list[Namespace], minio_pod: Pod, 
minio_service: Service, - minio_data_connection_multi_ns: List[Secret], - mlserver_runtime_multi_ns: List[ServingRuntime], + minio_data_connection_multi_ns: list[Secret], + mlserver_runtime_multi_ns: list[ServingRuntime], kserve_raw_config: ConfigMap, - kserve_logger_ca_bundle_multi_ns: List[ConfigMap], -) -> Generator[List[InferenceService], Any, None]: + kserve_logger_ca_bundle_multi_ns: list[ConfigMap], +) -> Generator[list[InferenceService], Any]: with ExitStack() as stack: models = [] for ns, secret, runtime in zip(model_namespaces, minio_data_connection_multi_ns, mlserver_runtime_multi_ns): @@ -179,7 +181,7 @@ def gaussian_credit_model_multi_ns( wait_for_predictor_pods=False, resources=GAUSSIAN_CREDIT_MODEL_RESOURCES, ) - isvc = stack.enter_context(isvc_context) # noqa: FCN001 + isvc = stack.enter_context(cm=isvc_context) wait_for_isvc_deployment_registered_by_trustyai_service( client=admin_client, @@ -193,7 +195,7 @@ def gaussian_credit_model_multi_ns( @pytest.fixture(scope="class") -def isvc_getter_service_account_multi_ns(admin_client, model_namespaces) -> Generator[List[ServiceAccount], None, None]: +def isvc_getter_service_account_multi_ns(admin_client, model_namespaces) -> Generator[list[ServiceAccount]]: with ExitStack() as stack: sas = [ stack.enter_context(create_isvc_getter_service_account(admin_client, ns, ISVC_GETTER)) @@ -203,7 +205,7 @@ def isvc_getter_service_account_multi_ns(admin_client, model_namespaces) -> Gene @pytest.fixture(scope="class") -def isvc_getter_role_multi_ns(admin_client, model_namespaces) -> Generator[List[Role], None, None]: +def isvc_getter_role_multi_ns(admin_client, model_namespaces) -> Generator[list[Role]]: with ExitStack() as stack: roles = [ stack.enter_context(create_isvc_getter_role(admin_client, ns, f"isvc-getter-{ns.name}")) @@ -218,7 +220,7 @@ def isvc_getter_role_binding_multi_ns( model_namespaces, isvc_getter_role_multi_ns, isvc_getter_service_account_multi_ns, -) -> Generator[List[RoleBinding], None, None]: +) 
-> Generator[list[RoleBinding]]: with ExitStack() as stack: bindings = [ stack.enter_context( @@ -241,7 +243,7 @@ def isvc_getter_token_secret_multi_ns( model_namespaces, isvc_getter_service_account_multi_ns, isvc_getter_role_binding_multi_ns, -) -> Generator[List[Secret], None, None]: +) -> Generator[list[Secret]]: with ExitStack() as stack: secrets = [ stack.enter_context( @@ -261,7 +263,7 @@ def isvc_getter_token_secret_multi_ns( def isvc_getter_token_multi_ns( isvc_getter_service_account_multi_ns, isvc_getter_token_secret_multi_ns, -) -> List[str]: +) -> list[str]: return [create_inference_token(model_service_account=sa) for sa in isvc_getter_service_account_multi_ns] @@ -273,7 +275,7 @@ def trustyai_service_with_db_storage_multi_ns( user_workload_monitoring_config, mariadb_multi_ns, trustyai_db_ca_secret_multi_ns: None, -) -> Generator[List[TrustyAIService], Any, None]: +) -> Generator[list[TrustyAIService], Any]: with ExitStack() as stack: services = [ stack.enter_context( @@ -295,9 +297,9 @@ def trustyai_service_with_db_storage_multi_ns( @pytest.fixture(scope="class") def trustyai_db_ca_secret_multi_ns( admin_client, - model_namespaces: List[Namespace], - mariadb_multi_ns: List, -) -> Generator[List[Secret], None, None]: + model_namespaces: list[Namespace], + mariadb_multi_ns: list, +) -> Generator[list[Secret]]: """ Creates one trustyai-db-ca secret per namespace, using the corresponding MariaDB CA cert. 
""" @@ -313,8 +315,8 @@ def trustyai_db_ca_secret_multi_ns( ) ca_cert = mariadb_ca_secret.instance.data["ca.crt"] - secret = stack.enter_context( # noqa: FCN001 - Secret( + secret = stack.enter_context( + cm=Secret( client=admin_client, name=f"{TRUSTYAI_SERVICE_NAME}-db-ca", namespace=ns.name, @@ -327,16 +329,14 @@ def trustyai_db_ca_secret_multi_ns( @pytest.fixture(scope="class") -def db_credentials_secret_multi_ns( - admin_client, model_namespaces: List[Namespace] -) -> Generator[List[Secret], None, None]: +def db_credentials_secret_multi_ns(admin_client, model_namespaces: list[Namespace]) -> Generator[list[Secret]]: """Creates DB credentials Secret in each model namespace.""" with ExitStack() as stack: secrets = [] for ns in model_namespaces: - secret = stack.enter_context( # noqa: FCN001 - Secret( + secret = stack.enter_context( + cm=Secret( client=admin_client, name=DB_CREDENTIALS_SECRET_NAME, namespace=ns.name, @@ -359,10 +359,10 @@ def db_credentials_secret_multi_ns( @pytest.fixture(scope="class") def mariadb_multi_ns( admin_client: DynamicClient, - model_namespaces: List[Namespace], - db_credentials_secret_multi_ns: List[Secret], + model_namespaces: list[Namespace], + db_credentials_secret_multi_ns: list[Secret], mariadb_operator_cr, -) -> Generator[List[MariaDB], Any, Any]: +) -> Generator[list[MariaDB], Any, Any]: mariadb_csv: ClusterServiceVersion = get_cluster_service_version( client=admin_client, prefix=MARIADB, namespace=OPENSHIFT_OPERATORS ) @@ -372,7 +372,7 @@ def mariadb_multi_ns( if not mariadb_dict_template: raise ResourceNotFoundError(f"No MariaDB dict found in alm_examples for CSV {mariadb_csv.name}") - mariadb_instances: List[MariaDB] = [] + mariadb_instances: list[MariaDB] = [] with ExitStack() as stack: for ns, secret in zip(model_namespaces, db_credentials_secret_multi_ns): @@ -395,9 +395,7 @@ def mariadb_multi_ns( mariadb_dict["spec"]["rootPasswordSecretKeyRef"] = password_secret_key_ref mariadb_dict["spec"]["passwordSecretKeyRef"] = 
password_secret_key_ref - mariadb_instance = stack.enter_context( # noqa: FCN001 - MariaDB(kind_dict=mariadb_dict) - ) + mariadb_instance = stack.enter_context(cm=MariaDB(kind_dict=mariadb_dict)) wait_for_mariadb_pods(client=admin_client, mariadb=mariadb_instance) mariadb_instances.append(mariadb_instance) yield mariadb_instances diff --git a/tests/model_explainability/trustyai_service/service/multi_ns/test_trustyai_service_multi_ns.py b/tests/model_explainability/trustyai_service/service/multi_ns/test_trustyai_service_multi_ns.py index c03fd89f2..1d7cda9d2 100644 --- a/tests/model_explainability/trustyai_service/service/multi_ns/test_trustyai_service_multi_ns.py +++ b/tests/model_explainability/trustyai_service/service/multi_ns/test_trustyai_service_multi_ns.py @@ -1,11 +1,12 @@ import pytest + from tests.model_explainability.trustyai_service.constants import DRIFT_BASE_DATA_PATH from tests.model_explainability.trustyai_service.trustyai_service_utils import ( - send_inferences_and_verify_trustyai_service_registered, - verify_upload_data_to_trustyai_service, TrustyAIServiceMetrics, - verify_trustyai_service_metric_scheduling_request, + send_inferences_and_verify_trustyai_service_registered, verify_trustyai_service_metric_delete_request, + verify_trustyai_service_metric_scheduling_request, + verify_upload_data_to_trustyai_service, ) from utilities.constants import MinIo from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG diff --git a/tests/model_explainability/trustyai_service/service/test_trustyai_service.py b/tests/model_explainability/trustyai_service/service/test_trustyai_service.py index 17eb62ae8..9cf045ebb 100644 --- a/tests/model_explainability/trustyai_service/service/test_trustyai_service.py +++ b/tests/model_explainability/trustyai_service/service/test_trustyai_service.py @@ -6,19 +6,19 @@ DRIFT_BASE_DATA_PATH, TRUSTYAI_DB_MIGRATION_PATCH, ) +from tests.model_explainability.trustyai_service.service.utils import ( + 
patch_trustyai_service_cr, + wait_for_trustyai_db_migration_complete_log, +) from tests.model_explainability.trustyai_service.trustyai_service_utils import ( - verify_upload_data_to_trustyai_service, TrustyAIServiceMetrics, verify_trustyai_service_metric_scheduling_request, + verify_upload_data_to_trustyai_service, ) from tests.model_explainability.trustyai_service.utils import ( validate_trustyai_service_db_conn_failure, validate_trustyai_service_images, ) -from tests.model_explainability.trustyai_service.service.utils import ( - wait_for_trustyai_db_migration_complete_log, - patch_trustyai_service_cr, -) from utilities.constants import MinIo diff --git a/tests/model_explainability/trustyai_service/service/utils.py b/tests/model_explainability/trustyai_service/service/utils.py index 03d36404f..82392b82e 100644 --- a/tests/model_explainability/trustyai_service/service/utils.py +++ b/tests/model_explainability/trustyai_service/service/utils.py @@ -7,18 +7,21 @@ from ocp_resources.trustyai_service import TrustyAIService from timeout_sampler import retry -from utilities.constants import Timeout, TRUSTYAI_SERVICE_NAME +from utilities.constants import TRUSTYAI_SERVICE_NAME, Timeout @retry(wait_timeout=Timeout.TIMEOUT_5MIN, sleep=5) def wait_for_trustyai_db_migration_complete_log(client: DynamicClient, trustyai_service: TrustyAIService) -> bool: - trustyai_pod = list( - Pod.get( - client=client, - namespace=trustyai_service.namespace, - label_selector=f"app.kubernetes.io/instance={trustyai_service.name}", - ) - )[0] + pods = Pod.get( + client=client, + namespace=trustyai_service.namespace, + label_selector=f"app.kubernetes.io/instance={trustyai_service.name}", + ) + trustyai_pod = next(iter(pods), None) + if trustyai_pod is None: + raise RuntimeError( + f"No TrustyAI pod found for service {trustyai_service.name} in namespace {trustyai_service.namespace}" + ) # noqa: E501 return bool( re.search( r".+INFO.+Migration complete, the PVC is now safe to remove\.", diff --git 
a/tests/model_explainability/trustyai_service/trustyai_service_utils.py b/tests/model_explainability/trustyai_service/trustyai_service_utils.py index 76d4da8d7..e957af46a 100644 --- a/tests/model_explainability/trustyai_service/trustyai_service_utils.py +++ b/tests/model_explainability/trustyai_service/trustyai_service_utils.py @@ -14,7 +14,7 @@ from timeout_sampler import TimeoutSampler from utilities.certificates_utils import create_ca_bundle_file -from utilities.constants import Protocols, Timeout, TRUSTYAI_SERVICE_NAME +from utilities.constants import TRUSTYAI_SERVICE_NAME, Protocols, Timeout from utilities.exceptions import MetricValidationError from utilities.general import create_isvc_label_selector_str from utilities.inference_utils import Inference, UserInference @@ -25,8 +25,6 @@ class NoMetricsFoundError(ValueError): """Raised when no metrics are available for the requested operation.""" - pass - class TrustyAIServiceMetrics: class Fairness: @@ -264,7 +262,7 @@ def get_num_observations_from_trustyai_service( raise KeyError("Observations data not found in model metadata") except Exception as e: - LOGGER.error(f"Failed to parse response: {str(e)}") + LOGGER.error(f"Failed to parse response: {e!s}") raise @@ -432,9 +430,11 @@ def verify_trustyai_service_response( # Validate required non-empty fields if required_fields: - for field in required_fields: - if field in response_data and response_data[field] == "": - errors.append(f"{field.capitalize()} is empty") + errors.extend([ + f"{field.capitalize()} is empty" + for field in required_fields + if field in response_data and response_data[field] == "" + ]) # Validate expected values if expected_values: diff --git a/tests/model_explainability/trustyai_service/upgrade/test_trustyai_service_upgrade.py b/tests/model_explainability/trustyai_service/upgrade/test_trustyai_service_upgrade.py index 35394031d..f6641fee4 100644 --- a/tests/model_explainability/trustyai_service/upgrade/test_trustyai_service_upgrade.py +++ 
b/tests/model_explainability/trustyai_service/upgrade/test_trustyai_service_upgrade.py @@ -1,20 +1,20 @@ import pytest +from timeout_sampler import retry +from tests.model_explainability.trustyai_service.constants import DRIFT_BASE_DATA_PATH, TRUSTYAI_DB_MIGRATION_PATCH from tests.model_explainability.trustyai_service.service.utils import ( - wait_for_trustyai_db_migration_complete_log, patch_trustyai_service_cr, + wait_for_trustyai_db_migration_complete_log, ) -from tests.model_explainability.trustyai_service.constants import DRIFT_BASE_DATA_PATH, TRUSTYAI_DB_MIGRATION_PATCH from tests.model_explainability.trustyai_service.trustyai_service_utils import ( + TrustyAIServiceMetrics, send_inferences_and_verify_trustyai_service_registered, - verify_upload_data_to_trustyai_service, verify_trustyai_service_metric_delete_request, - TrustyAIServiceMetrics, verify_trustyai_service_metric_scheduling_request, + verify_upload_data_to_trustyai_service, ) from utilities.constants import MinIo from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG -from timeout_sampler import retry @pytest.mark.parametrize( diff --git a/tests/model_explainability/trustyai_service/utils.py b/tests/model_explainability/trustyai_service/utils.py index 5357da01c..7bbb2eae5 100644 --- a/tests/model_explainability/trustyai_service/utils.py +++ b/tests/model_explainability/trustyai_service/utils.py @@ -1,12 +1,13 @@ -from contextlib import contextmanager -from typing import Generator, Any, Optional import re +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any from kubernetes.dynamic import DynamicClient from ocp_resources.config_map import ConfigMap from ocp_resources.deployment import Deployment -from ocp_resources.mariadb_operator import MariadbOperator from ocp_resources.maria_db import MariaDB +from ocp_resources.mariadb_operator import MariadbOperator from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod 
from ocp_resources.role import Role @@ -15,12 +16,11 @@ from ocp_resources.service_account import ServiceAccount from ocp_resources.trustyai_service import TrustyAIService from simple_logger.logger import get_logger -from timeout_sampler import TimeoutSampler -from utilities.constants import Timeout, TRUSTYAI_SERVICE_NAME -from timeout_sampler import retry +from timeout_sampler import TimeoutSampler, retry +from utilities.constants import TRUSTYAI_SERVICE_NAME, Timeout from utilities.exceptions import TooManyPodsError, UnexpectedFailureError -from utilities.general import wait_for_pods_by_labels, validate_container_images +from utilities.general import validate_container_images, wait_for_pods_by_labels LOGGER = get_logger(name=__name__) @@ -40,7 +40,7 @@ def wait_for_mariadb_operator_deployments(mariadb_operator: MariadbOperator, cli def wait_for_mariadb_pods(client: DynamicClient, mariadb: MariaDB, timeout: int = Timeout.TIMEOUT_15MIN) -> None: def _get_mariadb_pods() -> list[Pod]: - _pods = [ + return [ _pod for _pod in Pod.get( client=client, @@ -48,7 +48,6 @@ def _get_mariadb_pods() -> list[Pod]: label_selector=f"app.kubernetes.io/instance={mariadb.name}", ) ] - return _pods sampler = TimeoutSampler(wait_timeout=timeout, sleep=1, func=lambda: bool(_get_mariadb_pods())) @@ -67,10 +66,10 @@ def _get_mariadb_pods() -> list[Pod]: @retry( wait_timeout=Timeout.TIMEOUT_2MIN, sleep=5, - exceptions_dict={TooManyPodsError: list(), UnexpectedFailureError: list()}, + exceptions_dict={TooManyPodsError: [], UnexpectedFailureError: []}, ) def validate_trustyai_service_db_conn_failure( - client: DynamicClient, namespace: Namespace, label_selector: Optional[str] + client: DynamicClient, namespace: Namespace, label_selector: str | None ) -> bool: """Validate if invalid DB Certificate leads to pod crash loop. 
@@ -122,7 +121,7 @@ def create_trustyai_service( storage: dict[str, str], metrics: dict[str, str], name: str = TRUSTYAI_SERVICE_NAME, - data: Optional[dict[str, str]] = None, + data: dict[str, str] | None = None, wait_for_replicas: bool = True, teardown: bool = True, ) -> Generator[TrustyAIService, Any, Any]: @@ -276,11 +275,11 @@ def validate_trustyai_service_images( Raises: AssertionError: If any of the related images references are not present or invalid. """ - tai_image_refs = set( - v - for k, v in trustyai_operator_configmap.instance.data.items() - if k in ["kube-rbac-proxy", "trustyaiServiceImage"] - ) + tai_image_refs = { + value + for key, value in trustyai_operator_configmap.instance.data.items() + if key in ["kube-rbac-proxy", "trustyaiServiceImage"] + } trustyai_service_pod = wait_for_pods_by_labels( admin_client=client, namespace=model_namespace.name, label_selector=label_selector, expected_num_pods=1 )[0] diff --git a/tests/model_explainability/utils.py b/tests/model_explainability/utils.py index cb2585a66..903fd2561 100644 --- a/tests/model_explainability/utils.py +++ b/tests/model_explainability/utils.py @@ -1,4 +1,5 @@ import re + from ocp_resources.config_map import ConfigMap from ocp_resources.pod import Pod diff --git a/tests/model_registry/component_health/conftest.py b/tests/model_registry/component_health/conftest.py index 3e28d7a19..edd03325c 100644 --- a/tests/model_registry/component_health/conftest.py +++ b/tests/model_registry/component_health/conftest.py @@ -1,4 +1,5 @@ import shlex + import pytest from ocp_utilities.monitoring import Prometheus from pyhelper_utils.shell import run_command diff --git a/tests/model_registry/component_health/test_mr_health_check.py b/tests/model_registry/component_health/test_mr_health_check.py index 3bbb6de21..2424d3225 100644 --- a/tests/model_registry/component_health/test_mr_health_check.py +++ b/tests/model_registry/component_health/test_mr_health_check.py @@ -1,13 +1,11 @@ import pytest -from 
utilities.constants import DscComponents - +from kubernetes.dynamic import DynamicClient from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.namespace import Namespace - -from simple_logger.logger import get_logger from pytest_testconfig import config as py_config -from kubernetes.dynamic import DynamicClient +from simple_logger.logger import get_logger +from utilities.constants import DscComponents from utilities.general import wait_for_pods_running LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/component_health/test_mr_operator_health.py b/tests/model_registry/component_health/test_mr_operator_health.py index 8f832c8e5..d0748d8a3 100644 --- a/tests/model_registry/component_health/test_mr_operator_health.py +++ b/tests/model_registry/component_health/test_mr_operator_health.py @@ -1,7 +1,6 @@ -from simple_logger.logger import get_logger - import pytest from ocp_utilities.monitoring import Prometheus +from simple_logger.logger import get_logger LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/conftest.py b/tests/model_registry/conftest.py index 5d04472b7..168e86d22 100644 --- a/tests/model_registry/conftest.py +++ b/tests/model_registry/conftest.py @@ -1,52 +1,52 @@ -from contextlib import ExitStack -import pytest -from pytest import Config, FixtureRequest -from typing import Generator, Any import os +from collections.abc import Generator +from contextlib import ExitStack +from typing import Any +import pytest +from kubernetes.dynamic import DynamicClient +from ocp_resources.config_map import ConfigMap +from ocp_resources.data_science_cluster import DataScienceCluster +from ocp_resources.deployment import Deployment from ocp_resources.infrastructure import Infrastructure +from ocp_resources.namespace import Namespace from ocp_resources.oauth import OAuth +from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.pod import Pod +from ocp_resources.resource 
import ResourceEditor from ocp_resources.secret import Secret -from ocp_resources.namespace import Namespace -from ocp_resources.data_science_cluster import DataScienceCluster -from ocp_resources.config_map import ConfigMap from ocp_resources.service import Service -from ocp_resources.persistent_volume_claim import PersistentVolumeClaim -from ocp_resources.deployment import Deployment from ocp_resources.service_account import ServiceAccount -from ocp_resources.resource import ResourceEditor - -from simple_logger.logger import get_logger -from kubernetes.dynamic import DynamicClient +from pytest import Config, FixtureRequest from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger -from utilities.general import wait_for_oauth_openshift_deployment -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry +from tests.model_registry.constants import ( + DB_BASE_RESOURCES_NAME, + DB_RESOURCE_NAME, + KUBERBACPROXY_STR, + MR_INSTANCE_BASE_NAME, + MR_INSTANCE_NAME, + MR_OPERATOR_NAME, +) from tests.model_registry.utils import ( + generate_namespace_name, get_byoidc_user_credentials, - get_model_registry_objects, get_model_registry_metadata_resources, - wait_for_default_resource_cleanedup, - generate_namespace_name, + get_model_registry_objects, get_rest_headers, + wait_for_default_resource_cleanedup, ) - -from utilities.general import generate_random_name, wait_for_pods_running - -from tests.model_registry.constants import ( - MR_OPERATOR_NAME, - MR_INSTANCE_BASE_NAME, - DB_BASE_RESOURCES_NAME, - DB_RESOURCE_NAME, - MR_INSTANCE_NAME, - KUBERBACPROXY_STR, +from utilities.constants import DscComponents, Labels +from utilities.general import ( + generate_random_name, + wait_for_oauth_openshift_deployment, + wait_for_pods_by_labels, + wait_for_pods_running, ) -from utilities.constants import Labels -from utilities.constants import DscComponents -from utilities.general import wait_for_pods_by_labels -from 
utilities.infra import get_data_science_cluster, wait_for_dsc_status_ready, login_with_user_password -from utilities.user_utils import UserTestSession, wait_for_user_creation, create_htpasswd_file +from utilities.infra import get_data_science_cluster, login_with_user_password, wait_for_dsc_status_ready +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry +from utilities.user_utils import UserTestSession, create_htpasswd_file, wait_for_user_creation DEFAULT_TOKEN_DURATION = "10m" LOGGER = get_logger(name=__name__) @@ -150,7 +150,7 @@ def test_idp_user( api_server_url: str, is_byoidc: bool, admin_client: DynamicClient, -) -> Generator[UserTestSession | None, None, None]: +) -> Generator[UserTestSession | None]: """ Session-scoped fixture that creates a test IDP user and cleans it up after all tests. Returns a UserTestSession object that contains all necessary credentials and contexts. @@ -204,7 +204,7 @@ def api_server_url(admin_client: DynamicClient) -> str: @pytest.fixture(scope="module") def created_htpasswd_secret( is_byoidc: bool, admin_client: DynamicClient, original_user: str, user_credentials_rbac: dict[str, str] -) -> Generator[UserTestSession | None, None, None]: +) -> Generator[UserTestSession | None]: """ Session-scoped fixture that creates a test IDP user and cleans it up after all tests. Returns a UserTestSession object that contains all necessary credentials and contexts. 
@@ -235,7 +235,7 @@ def created_htpasswd_secret( @pytest.fixture(scope="module") def updated_oauth_config( is_byoidc: bool, admin_client: DynamicClient, original_user: str, user_credentials_rbac: dict[str, str] -) -> Generator[Any, None, None]: +) -> Generator[Any]: if is_byoidc: yield else: @@ -332,7 +332,7 @@ def model_registry_metadata_db_resources( pytestconfig: Config, teardown_resources: bool, model_registry_namespace: str, -) -> Generator[dict[Any, Any], None, None]: +) -> Generator[dict[Any, Any]]: num_resources = getattr(request, "param", {}).get("num_resources", 1) db_backend = getattr(request, "param", {}).get("db_name", "mysql") @@ -401,7 +401,7 @@ def model_registry_rest_headers(current_client_token: str) -> dict[str, str]: @pytest.fixture(scope="class") -def sa_namespace(request: pytest.FixtureRequest, admin_client: DynamicClient) -> Generator[Namespace, None, None]: +def sa_namespace(request: pytest.FixtureRequest, admin_client: DynamicClient) -> Generator[Namespace]: """ Creates a namespace """ @@ -415,9 +415,7 @@ def sa_namespace(request: pytest.FixtureRequest, admin_client: DynamicClient) -> @pytest.fixture() -def login_as_test_user( - is_byoidc: bool, api_server_url: str, original_user: str, test_idp_user -) -> Generator[None, None, None]: +def login_as_test_user(is_byoidc: bool, api_server_url: str, original_user: str, test_idp_user) -> Generator[None]: """ Fixture to log in as a test user and restore original user after test. @@ -451,7 +449,7 @@ def login_as_test_user( @pytest.fixture(scope="class") -def service_account(admin_client: DynamicClient, sa_namespace: Namespace) -> Generator[Any, None, None]: +def service_account(admin_client: DynamicClient, sa_namespace: Namespace) -> Generator[Any]: """ Creates a ServiceAccount. 
""" diff --git a/tests/model_registry/constants.py b/tests/model_registry/constants.py index 105a4f7d8..ed4597f8b 100644 --- a/tests/model_registry/constants.py +++ b/tests/model_registry/constants.py @@ -5,6 +5,7 @@ from ocp_resources.resource import Resource from ocp_resources.secret import Secret from ocp_resources.service import Service + from utilities.constants import ModelFormat diff --git a/tests/model_registry/image_validation/conftest.py b/tests/model_registry/image_validation/conftest.py index f62e217b6..8159e61d6 100644 --- a/tests/model_registry/image_validation/conftest.py +++ b/tests/model_registry/image_validation/conftest.py @@ -1,28 +1,28 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient - from ocp_resources.pod import Pod -from utilities.general import wait_for_pods_by_labels from pytest import FixtureRequest +from utilities.general import wait_for_pods_by_labels + @pytest.fixture(scope="class") def model_registry_instance_pods_by_label( request: FixtureRequest, admin_client: DynamicClient, model_registry_namespace: str ) -> Generator[list[Pod], Any, Any]: """Get the model registry instance pod.""" - pods = [] - for label in request.param["label_selectors"]: - pods.append( - wait_for_pods_by_labels( - admin_client=admin_client, - namespace=model_registry_namespace, - label_selector=label, - expected_num_pods=1, - )[0] - ) + pods = [ + wait_for_pods_by_labels( + admin_client=admin_client, + namespace=model_registry_namespace, + label_selector=label, + expected_num_pods=1, + )[0] + for label in request.param["label_selectors"] + ] yield pods diff --git a/tests/model_registry/image_validation/test_verify_rhoai_images.py b/tests/model_registry/image_validation/test_verify_rhoai_images.py index 2ac635f20..48f9afb41 100644 --- a/tests/model_registry/image_validation/test_verify_rhoai_images.py +++ 
b/tests/model_registry/image_validation/test_verify_rhoai_images.py @@ -6,16 +6,17 @@ 3. Images are listed in the CSV's relatedImages section """ +from typing import Self + import pytest -from typing import Self, Set -from simple_logger.logger import get_logger from kubernetes.dynamic import DynamicClient +from ocp_resources.pod import Pod +from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger -from tests.model_registry.constants import MR_INSTANCE_NAME, MR_POSTGRES_DEPLOYMENT_NAME_STR, MR_OPERATOR_NAME +from tests.model_registry.constants import MR_INSTANCE_NAME, MR_OPERATOR_NAME, MR_POSTGRES_DEPLOYMENT_NAME_STR from tests.model_registry.image_validation.utils import validate_images from utilities.constants import Labels -from ocp_resources.pod import Pod -from pytest_testconfig import config as py_config LOGGER = get_logger(name=__name__) pytestmark = [pytest.mark.downstream_only, pytest.mark.skip_must_gather, pytest.mark.smoke] @@ -45,7 +46,7 @@ def test_verify_pod_images( self: Self, admin_client: DynamicClient, resource_pods: list[Pod], - related_images_refs: Set[str], + related_images_refs: set[str], ): validate_images(pods_to_validate=resource_pods, related_images_refs=related_images_refs) @@ -78,6 +79,6 @@ def test_verify_model_registry_pod_images( self: Self, admin_client: DynamicClient, model_registry_instance_pods_by_label: list[Pod], - related_images_refs: Set[str], + related_images_refs: set[str], ): validate_images(pods_to_validate=model_registry_instance_pods_by_label, related_images_refs=related_images_refs) diff --git a/tests/model_registry/image_validation/utils.py b/tests/model_registry/image_validation/utils.py index 1932d3199..d3192e0a3 100644 --- a/tests/model_registry/image_validation/utils.py +++ b/tests/model_registry/image_validation/utils.py @@ -1,13 +1,13 @@ -from typing import Set -from simple_logger.logger import get_logger import pytest from ocp_resources.pod import Pod +from 
simple_logger.logger import get_logger + from utilities.general import validate_container_images LOGGER = get_logger(name=__name__) -def validate_images(pods_to_validate: list[Pod], related_images_refs: Set[str]): +def validate_images(pods_to_validate: list[Pod], related_images_refs: set[str]): validation_errors = [] for pod in pods_to_validate: LOGGER.info(f"Validating {pod.name} in {pod.namespace}") diff --git a/tests/model_registry/model_catalog/catalog_config/conftest.py b/tests/model_registry/model_catalog/catalog_config/conftest.py index de50b38e8..bc54ddff4 100644 --- a/tests/model_registry/model_catalog/catalog_config/conftest.py +++ b/tests/model_registry/model_catalog/catalog_config/conftest.py @@ -1,22 +1,22 @@ -import pytest import re -from typing import Generator -from simple_logger.logger import get_logger +from collections.abc import Generator +import pytest from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import NotFoundError from ocp_resources.config_map import ConfigMap from ocp_resources.resource import ResourceEditor from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger from timeout_sampler import TimeoutSampler from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG -from tests.model_registry.model_catalog.constants import REDHAT_AI_CATALOG_ID, REDHAT_AI_CATALOG_NAME from tests.model_registry.model_catalog.catalog_config.utils import ( filter_models_by_pattern, modify_catalog_source, wait_for_catalog_source_restore, ) +from tests.model_registry.model_catalog.constants import REDHAT_AI_CATALOG_ID, REDHAT_AI_CATALOG_NAME from tests.model_registry.utils import ( get_model_catalog_pod, wait_for_model_catalog_api, @@ -129,7 +129,7 @@ def redhat_ai_models_with_filter( model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], catalog_pod_model_counts: dict[str, int], -) -> Generator[set[str], None, None]: +) -> Generator[set[str]]: """ Unified 
fixture for applying filters to redhat_ai catalog and yielding expected models. @@ -199,7 +199,7 @@ def disabled_redhat_ai_source( model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], catalog_pod_model_counts: dict[str, int], -) -> Generator[None, None, None]: +) -> Generator[None]: """ Fixture that disables the redhat_ai catalog source and yields control. diff --git a/tests/model_registry/model_catalog/catalog_config/test_custom_model_catalog.py b/tests/model_registry/model_catalog/catalog_config/test_custom_model_catalog.py index f018b690c..e73404fcd 100644 --- a/tests/model_registry/model_catalog/catalog_config/test_custom_model_catalog.py +++ b/tests/model_registry/model_catalog/catalog_config/test_custom_model_catalog.py @@ -1,24 +1,25 @@ +from typing import Self + +import pytest +from kubernetes.dynamic.exceptions import ResourceNotFoundError +from ocp_resources.config_map import ConfigMap +from simple_logger.logger import get_logger + +from tests.model_registry.constants import CUSTOM_CATALOG_ID1, SAMPLE_MODEL_NAME1 from tests.model_registry.model_catalog.constants import ( - EXPECTED_CUSTOM_CATALOG_VALUES, CUSTOM_CATALOG_ID2, - SAMPLE_MODEL_NAME2, - MULTIPLE_CUSTOM_CATALOG_VALUES, - SAMPLE_MODEL_NAME3, + EXPECTED_CUSTOM_CATALOG_VALUES, EXPECTED_HF_CATALOG_VALUES, EXPECTED_MULTIPLE_HF_CATALOG_VALUES, + MULTIPLE_CUSTOM_CATALOG_VALUES, + SAMPLE_MODEL_NAME2, + SAMPLE_MODEL_NAME3, ) -from tests.model_registry.constants import SAMPLE_MODEL_NAME1, CUSTOM_CATALOG_ID1 -from ocp_resources.config_map import ConfigMap -import pytest -from simple_logger.logger import get_logger -from typing import Self -from kubernetes.dynamic.exceptions import ResourceNotFoundError - from tests.model_registry.model_catalog.utils import get_hf_catalog_str from tests.model_registry.utils import ( execute_get_command, - get_sample_yaml_str, get_catalog_str, + get_sample_yaml_str, validate_model_catalog_sources, ) diff --git 
a/tests/model_registry/model_catalog/catalog_config/test_default_model_catalog.py b/tests/model_registry/model_catalog/catalog_config/test_default_model_catalog.py index 8871a09e9..a7a2f3ab6 100644 --- a/tests/model_registry/model_catalog/catalog_config/test_default_model_catalog.py +++ b/tests/model_registry/model_catalog/catalog_config/test_default_model_catalog.py @@ -1,30 +1,29 @@ -import pytest import random +from typing import Any, Self +import pytest import yaml -from kubernetes.dynamic import DynamicClient from dictdiffer import diff -from ocp_resources.deployment import Deployment -from simple_logger.logger import get_logger -from typing import Self, Any -from timeout_sampler import TimeoutSampler +from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError - -from ocp_resources.pod import Pod from ocp_resources.config_map import ConfigMap +from ocp_resources.deployment import Deployment +from ocp_resources.pod import Pod from ocp_resources.route import Route from ocp_resources.service import Service +from simple_logger.logger import get_logger +from timeout_sampler import TimeoutSampler -from tests.model_registry.constants import DEFAULT_MODEL_CATALOG_CM, DEFAULT_CUSTOM_MODEL_CATALOG -from tests.model_registry.model_catalog.constants import REDHAT_AI_CATALOG_ID, CATALOG_CONTAINER, DEFAULT_CATALOGS +from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM from tests.model_registry.model_catalog.catalog_config.utils import ( - validate_model_catalog_enabled, - validate_model_catalog_resource, - get_validate_default_model_catalog_source, extract_schema_fields, + get_validate_default_model_catalog_source, validate_default_catalog, + validate_model_catalog_enabled, + validate_model_catalog_resource, ) -from tests.model_registry.utils import get_rest_headers, get_model_catalog_pod, execute_get_command +from tests.model_registry.model_catalog.constants import 
CATALOG_CONTAINER, DEFAULT_CATALOGS, REDHAT_AI_CATALOG_ID +from tests.model_registry.utils import execute_get_command, get_model_catalog_pod, get_rest_headers from utilities.user_utils import UserTestSession LOGGER = get_logger(name=__name__) @@ -193,9 +192,7 @@ def test_model_catalog_default_catalog_sources( assert result items_to_validate = [] if pytestconfig.option.pre_upgrade or pytestconfig.option.post_upgrade: - for catalog in result: - if catalog["id"] in DEFAULT_CATALOGS.keys(): - items_to_validate.append(catalog) + items_to_validate.extend([catalog for catalog in result if catalog["id"] in DEFAULT_CATALOGS]) assert len(items_to_validate) + 1 == len(result) else: items_to_validate = result diff --git a/tests/model_registry/model_catalog/catalog_config/test_default_source_inclusion_exclusion_cleanup.py b/tests/model_registry/model_catalog/catalog_config/test_default_source_inclusion_exclusion_cleanup.py index ee14b237a..0efabb2db 100644 --- a/tests/model_registry/model_catalog/catalog_config/test_default_source_inclusion_exclusion_cleanup.py +++ b/tests/model_registry/model_catalog/catalog_config/test_default_source_inclusion_exclusion_cleanup.py @@ -1,21 +1,22 @@ import pytest +from kubernetes.dynamic.client import DynamicClient +from ocp_resources.resource import ResourceEditor from simple_logger.logger import get_logger from timeout_sampler import TimeoutExpiredError -from ocp_resources.resource import ResourceEditor -from kubernetes.dynamic.client import DynamicClient -from tests.model_registry.model_catalog.constants import ( - REDHAT_AI_CATALOG_ID, - REDHAT_AI_CATALOG_NAME, -) + from tests.model_registry.model_catalog.catalog_config.utils import ( - modify_catalog_source, + filter_models_by_pattern, get_models_from_database_by_source, - wait_for_model_set_match, + modify_catalog_source, validate_cleanup_logging, - filter_models_by_pattern, validate_filter_test_result, validate_source_disabling_result, wait_for_catalog_source_restore, + 
wait_for_model_set_match, +) +from tests.model_registry.model_catalog.constants import ( + REDHAT_AI_CATALOG_ID, + REDHAT_AI_CATALOG_NAME, ) from tests.model_registry.utils import wait_for_model_catalog_api diff --git a/tests/model_registry/model_catalog/catalog_config/test_model_catalog_negative.py b/tests/model_registry/model_catalog/catalog_config/test_model_catalog_negative.py index f9e382dc2..3a46d54f8 100644 --- a/tests/model_registry/model_catalog/catalog_config/test_model_catalog_negative.py +++ b/tests/model_registry/model_catalog/catalog_config/test_model_catalog_negative.py @@ -1,12 +1,13 @@ -import pytest -from simple_logger.logger import get_logger from typing import Self +import pytest from ocp_resources.config_map import ConfigMap from ocp_resources.resource import ResourceEditor +from simple_logger.logger import get_logger + from tests.model_registry.constants import DEFAULT_MODEL_CATALOG_CM -from tests.model_registry.model_catalog.constants import DEFAULT_CATALOGS from tests.model_registry.model_catalog.catalog_config.utils import validate_model_catalog_configmap_data +from tests.model_registry.model_catalog.constants import DEFAULT_CATALOGS from tests.model_registry.model_catalog.utils import assert_source_error_state_message LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/catalog_config/utils.py b/tests/model_registry/model_catalog/catalog_config/utils.py index 29f4d2f01..f45ab670e 100644 --- a/tests/model_registry/model_catalog/catalog_config/utils.py +++ b/tests/model_registry/model_catalog/catalog_config/utils.py @@ -1,29 +1,29 @@ -from typing import Any -import subprocess -import yaml import re +import subprocess +from typing import Any import pytest +import yaml from kubernetes.dynamic import DynamicClient +from ocp_resources.config_map import ConfigMap +from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from timeout_sampler import retry, TimeoutExpiredError +from timeout_sampler 
import TimeoutExpiredError, retry -from ocp_resources.pod import Pod -from ocp_resources.config_map import ConfigMap +from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG from tests.model_registry.model_catalog.constants import ( DEFAULT_CATALOGS, REDHAT_AI_CATALOG_ID, REDHAT_AI_CATALOG_NAME, ) -from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG -from tests.model_registry.utils import get_model_catalog_pod -from utilities.constants import Timeout +from tests.model_registry.model_catalog.db_constants import GET_MODELS_BY_SOURCE_ID_DB_QUERY from tests.model_registry.model_catalog.utils import ( - get_models_from_catalog_api, execute_database_query, + get_models_from_catalog_api, parse_psql_output, ) -from tests.model_registry.model_catalog.db_constants import GET_MODELS_BY_SOURCE_ID_DB_QUERY +from tests.model_registry.utils import get_model_catalog_pod +from utilities.constants import Timeout LOGGER = get_logger(name=__name__) @@ -266,9 +266,9 @@ def modify_catalog_source( admin_client: DynamicClient, namespace: str, source_id: str, - enabled: bool = None, - included_models: list[str] = None, - excluded_models: list[str] = None, + enabled: bool | None = None, + included_models: list[str] | None = None, + excluded_models: list[str] | None = None, ) -> dict[str, ConfigMap | dict[str, Any] | str]: """ Modify a catalog source with various configuration changes. 
diff --git a/tests/model_registry/model_catalog/conftest.py b/tests/model_registry/model_catalog/conftest.py index 3d274ee12..0dd1ea0d2 100644 --- a/tests/model_registry/model_catalog/conftest.py +++ b/tests/model_registry/model_catalog/conftest.py @@ -1,39 +1,40 @@ import random -from typing import Generator, Any -import requests +from collections.abc import Generator +from typing import Any -from simple_logger.logger import get_logger -import yaml import pytest +import requests +import yaml from kubernetes.dynamic import DynamicClient - from ocp_resources.config_map import ConfigMap from ocp_resources.resource import ResourceEditor from ocp_resources.route import Route from ocp_resources.service_account import ServiceAccount +from simple_logger.logger import get_logger + +from tests.model_registry.constants import ( + CUSTOM_CATALOG_ID1, + DEFAULT_CUSTOM_MODEL_CATALOG, +) +from tests.model_registry.model_catalog.catalog_config.utils import get_models_from_database_by_source from tests.model_registry.model_catalog.constants import ( - SAMPLE_MODEL_NAME3, - DEFAULT_CATALOG_FILE, CATALOG_CONTAINER, - REDHAT_AI_CATALOG_ID, + DEFAULT_CATALOG_FILE, DEFAULT_CATALOGS, + REDHAT_AI_CATALOG_ID, + SAMPLE_MODEL_NAME3, ) from tests.model_registry.model_catalog.utils import get_models_from_catalog_api -from tests.model_registry.constants import ( - CUSTOM_CATALOG_ID1, - DEFAULT_CUSTOM_MODEL_CATALOG, -) from tests.model_registry.utils import ( - get_rest_headers, - wait_for_model_catalog_pod_ready_after_deletion, - get_model_catalog_pod, - wait_for_model_catalog_api, execute_get_command, + get_model_catalog_pod, get_model_str, get_mr_user_token, + get_rest_headers, + wait_for_model_catalog_api, + wait_for_model_catalog_pod_ready_after_deletion, ) -from utilities.infra import get_openshift_token, create_inference_token, login_with_user_password -from tests.model_registry.model_catalog.catalog_config.utils import get_models_from_database_by_source +from utilities.infra import 
create_inference_token, get_openshift_token, login_with_user_password LOGGER = get_logger(name=__name__) @@ -45,7 +46,7 @@ def sparse_override_catalog_source( model_registry_namespace: str, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], -) -> Generator[dict, None, None]: +) -> Generator[dict]: """ Creates a sparse override for an existing default catalog source. @@ -120,7 +121,7 @@ def updated_catalog_config_map( admin_client: DynamicClient, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], -) -> Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: if pytestconfig.option.post_upgrade or pytestconfig.option.pre_upgrade: yield catalog_config_map else: @@ -159,7 +160,7 @@ def update_configmap_data_add_model( admin_client: DynamicClient, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], -) -> Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: patches = catalog_config_map.instance.to_dict() patches["data"][f"{CUSTOM_CATALOG_ID1.replace('_', '-')}.yaml"] += get_model_str(model=SAMPLE_MODEL_NAME3) with ResourceEditor(patches={catalog_config_map: patches}): @@ -184,7 +185,7 @@ def user_token_for_api_calls( user_credentials_rbac: dict[str, str], service_account: ServiceAccount, model_catalog_rest_url: list[str], -) -> Generator[str, None, None]: +) -> Generator[str]: param = getattr(request, "param", {}) user = param.get("user_type", "admin") LOGGER.info("User used: %s", user) @@ -359,7 +360,7 @@ def labels_configmap_patch( model_registry_namespace: str, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], -) -> Generator[dict[str, Any], None, None]: +) -> Generator[dict[str, Any]]: # Get the editable ConfigMap sources_cm = ConfigMap(name=DEFAULT_CUSTOM_MODEL_CATALOG, client=admin_client, namespace=model_registry_namespace) @@ -405,7 +406,7 @@ def updated_catalog_config_map_scope_function( admin_client: DynamicClient, 
model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], -) -> Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: patches = {"data": {"sources.yaml": request.param}} with ResourceEditor(patches={catalog_config_map: patches}): wait_for_model_catalog_pod_ready_after_deletion( diff --git a/tests/model_registry/model_catalog/constants.py b/tests/model_registry/model_catalog/constants.py index 620840aff..5504b9e0e 100644 --- a/tests/model_registry/model_catalog/constants.py +++ b/tests/model_registry/model_catalog/constants.py @@ -1,6 +1,6 @@ from typing import Any -from tests.model_registry.constants import SAMPLE_MODEL_NAME1, CUSTOM_CATALOG_ID1 +from tests.model_registry.constants import CUSTOM_CATALOG_ID1, SAMPLE_MODEL_NAME1 CUSTOM_CATALOG_ID2: str = "sample_custom_catalog2" diff --git a/tests/model_registry/model_catalog/huggingface/conftest.py b/tests/model_registry/model_catalog/huggingface/conftest.py index 8e7c037c6..bac3c8ec3 100644 --- a/tests/model_registry/model_catalog/huggingface/conftest.py +++ b/tests/model_registry/model_catalog/huggingface/conftest.py @@ -1,28 +1,38 @@ -import pytest +import base64 import time -from typing import Any, Generator +from collections.abc import Generator +from typing import Any + +import portforward +import pytest from huggingface_hub import HfApi -from simple_logger.logger import get_logger from kubernetes.dynamic import DynamicClient -from ocp_resources.inference_service import InferenceService from ocp_resources.config_map import ConfigMap -from tests.model_registry.model_catalog.constants import HF_CUSTOM_MODE -from tests.model_registry.model_catalog.utils import get_models_from_catalog_api +from ocp_resources.inference_service import InferenceService +from ocp_resources.namespace import Namespace +from ocp_resources.pod import Pod +from ocp_resources.secret import Secret +from ocp_resources.serving_runtime import ServingRuntime +from pytest_testconfig import py_config +from 
simple_logger.logger import get_logger +from tests.model_registry.model_catalog.constants import HF_CUSTOM_MODE from tests.model_registry.model_catalog.huggingface.utils import get_huggingface_model_from_api +from tests.model_registry.model_catalog.utils import get_models_from_catalog_api from utilities.infra import create_ns from utilities.operator_utils import get_cluster_service_version -from ocp_resources.serving_runtime import ServingRuntime -from ocp_resources.namespace import Namespace -from ocp_resources.secret import Secret -from ocp_resources.pod import Pod -from pytest_testconfig import py_config -import base64 -import portforward LOGGER = get_logger(name=__name__) +class OpenVINOImageNotFoundError(Exception): + """Exception raised when OpenVINO image is not found in RHOAI CSV relatedImages.""" + + +class PredictorPodNotFoundError(Exception): + """Exception raised when predictor pods are not found for an InferenceService.""" + + def get_openvino_image_from_rhoai_csv(admin_client: DynamicClient) -> str: """ Get the OpenVINO model server image from the RHOAI ClusterServiceVersion. 
@@ -47,7 +57,7 @@ def get_openvino_image_from_rhoai_csv(admin_client: DynamicClient) -> str: LOGGER.info(f"Found OpenVINO image from RHOAI CSV: {image_url}") return image_url - raise Exception("Could not find odh-openvino-model-server image in RHOAI CSV relatedImages") + raise OpenVINOImageNotFoundError("Could not find odh-openvino-model-server image in RHOAI CSV relatedImages") @pytest.fixture() @@ -385,7 +395,9 @@ def huggingface_predictor_pod( predictor_pods = [pod for pod in pods if "predictor" in pod.name] if not predictor_pods: - raise Exception(f"No predictor pods found for InferenceService {huggingface_inference_service.name}") + raise PredictorPodNotFoundError( + f"No predictor pods found for InferenceService {huggingface_inference_service.name}" + ) pod = predictor_pods[0] # Use the first predictor pod LOGGER.info(f"Found predictor pod: {pod.name} in namespace: {namespace}") @@ -397,7 +409,7 @@ def huggingface_model_portforward( hugging_face_deployment_ns: Namespace, huggingface_inference_service: InferenceService, huggingface_predictor_pod: Pod, -) -> Generator[str, Any, None]: +) -> Generator[str, Any]: """ Port-forwards the HuggingFace OpenVINO model server pod to access the model API locally. 
Equivalent CLI: diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_exclude_models.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_exclude_models.py index 30598f0a4..31abe7273 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_exclude_models.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_exclude_models.py @@ -1,9 +1,10 @@ -import pytest from typing import Self + +import pytest from ocp_resources.config_map import ConfigMap from simple_logger.logger import get_logger -from tests.model_registry.model_catalog.utils import get_models_from_catalog_api, get_hf_catalog_str +from tests.model_registry.model_catalog.utils import get_hf_catalog_str, get_models_from_catalog_api LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_deployment.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_deployment.py index 3beb11445..97cafbb94 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_deployment.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_deployment.py @@ -2,8 +2,8 @@ import requests from kubernetes.dynamic import DynamicClient from ocp_resources.inference_service import InferenceService -from ocp_resources.serving_runtime import ServingRuntime from ocp_resources.namespace import Namespace +from ocp_resources.serving_runtime import ServingRuntime from simple_logger.logger import get_logger from tests.model_registry.model_catalog.utils import get_hf_catalog_str diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_search.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_search.py index f32f661ea..c3901c715 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_search.py +++ 
b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_search.py @@ -1,11 +1,11 @@ from typing import Self import pytest - from ocp_resources.config_map import ConfigMap -from tests.model_registry.model_catalog.utils import get_hf_catalog_str, get_models_from_catalog_api from simple_logger.logger import get_logger +from tests.model_registry.model_catalog.utils import get_hf_catalog_str, get_models_from_catalog_api + LOGGER = get_logger(name=__name__) pytestmark = [pytest.mark.skip_on_disconnected] diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_sorting.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_sorting.py index be912a82c..25a07952c 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_sorting.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_sorting.py @@ -1,7 +1,8 @@ -import pytest from typing import Self +import pytest from ocp_resources.config_map import ConfigMap + from tests.model_registry.model_catalog.sorting.utils import assert_model_sorting from tests.model_registry.model_catalog.utils import get_hf_catalog_str diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_type_classification.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_type_classification.py index e489a5c53..98800b653 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_type_classification.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_type_classification.py @@ -5,9 +5,9 @@ from typing import Self import pytest +from simple_logger.logger import get_logger from tests.model_registry.model_catalog.utils import get_hf_catalog_str, get_models_from_catalog_api -from simple_logger.logger import get_logger LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_validation.py 
b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_validation.py index 90b4d8b2e..f123f4ca8 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_model_validation.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_model_validation.py @@ -1,20 +1,23 @@ +from collections.abc import Generator +from typing import Self + import pytest -from typing import Self, Generator +from huggingface_hub import HfApi +from kubernetes.dynamic import DynamicClient from ocp_resources.config_map import ConfigMap from simple_logger.logger import get_logger + from tests.model_registry.model_catalog.constants import HF_MODELS, HF_SOURCE_ID -from tests.model_registry.model_catalog.utils import ( - get_hf_catalog_str, -) from tests.model_registry.model_catalog.huggingface.utils import ( assert_huggingface_values_matches_model_catalog_api_values, - wait_for_huggingface_retrival_match, + get_huggingface_model_from_api, wait_for_hugging_face_model_import, + wait_for_huggingface_retrival_match, wait_for_last_sync_update, - get_huggingface_model_from_api, ) -from huggingface_hub import HfApi -from kubernetes.dynamic import DynamicClient +from tests.model_registry.model_catalog.utils import ( + get_hf_catalog_str, +) LOGGER = get_logger(name=__name__) @@ -48,7 +51,7 @@ class TestLastSyncedMetadataValidation: ) def test_huggingface_last_synced_custom( self: Self, - updated_catalog_config_map_scope_function: Generator[ConfigMap, None, None], + updated_catalog_config_map_scope_function: Generator[ConfigMap], initial_last_synced_values: str, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], @@ -219,7 +222,7 @@ def test_hugging_face_models( self: Self, admin_client: DynamicClient, model_registry_namespace: str, - updated_catalog_config_map_scope_function: Generator[ConfigMap, None, None], + updated_catalog_config_map_scope_function: Generator[ConfigMap], model_catalog_rest_url: list[str], 
model_registry_rest_headers: dict[str, str], huggingface_api: bool, diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_negative.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_negative.py index b99fc0e55..e51d1b4d3 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_negative.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_negative.py @@ -1,8 +1,9 @@ -import pytest -from simple_logger.logger import get_logger from typing import Self +import pytest from ocp_resources.config_map import ConfigMap +from simple_logger.logger import get_logger + from tests.model_registry.model_catalog.utils import assert_source_error_state_message LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/huggingface/test_huggingface_source_error_validation.py b/tests/model_registry/model_catalog/huggingface/test_huggingface_source_error_validation.py index f883e7149..25811a601 100644 --- a/tests/model_registry/model_catalog/huggingface/test_huggingface_source_error_validation.py +++ b/tests/model_registry/model_catalog/huggingface/test_huggingface_source_error_validation.py @@ -1,13 +1,13 @@ -import pytest import re -from simple_logger.logger import get_logger from typing import Self +import pytest +from kubernetes.dynamic.exceptions import ResourceNotFoundError from ocp_resources.config_map import ConfigMap +from simple_logger.logger import get_logger + from tests.model_registry.model_catalog.huggingface.utils import assert_accessible_models_via_catalog_api from tests.model_registry.utils import execute_get_command -from kubernetes.dynamic.exceptions import ResourceNotFoundError - LOGGER = get_logger(name=__name__) INACCESSIBLE_MODELS: list[str] = [ diff --git a/tests/model_registry/model_catalog/huggingface/utils.py b/tests/model_registry/model_catalog/huggingface/utils.py index 3dbcd041d..790ceb9f7 100644 --- a/tests/model_registry/model_catalog/huggingface/utils.py 
+++ b/tests/model_registry/model_catalog/huggingface/utils.py @@ -1,14 +1,14 @@ import ast from typing import Any +from huggingface_hub import HfApi +from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger +from timeout_sampler import retry from tests.model_registry.model_catalog.constants import HF_SOURCE_ID from tests.model_registry.model_catalog.utils import get_models_from_catalog_api from tests.model_registry.utils import execute_get_command, get_model_catalog_pod -from huggingface_hub import HfApi -from timeout_sampler import retry -from kubernetes.dynamic import DynamicClient LOGGER = get_logger(name=__name__) @@ -57,11 +57,11 @@ def get_huggingface_nested_attributes(obj, attr_path) -> Any: if not hasattr(current_obj, attr): return None current_obj = getattr(current_obj, attr) - return current_obj + return current_obj # noqa: TRY300 except AttributeError as e: LOGGER.error(f"AttributeError getting '{attr_path}': {e}") return None - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.error(f"Unexpected error getting '{attr_path}': {e}") return None diff --git a/tests/model_registry/model_catalog/metadata/test_catalog_preview.py b/tests/model_registry/model_catalog/metadata/test_catalog_preview.py index 8eb04cac4..4e023a009 100644 --- a/tests/model_registry/model_catalog/metadata/test_catalog_preview.py +++ b/tests/model_registry/model_catalog/metadata/test_catalog_preview.py @@ -1,15 +1,17 @@ -import pytest from typing import Self + +import pytest import requests import yaml from simple_logger.logger import get_logger + +from tests.model_registry.model_catalog.constants import VALIDATED_CATALOG_FILE, VALIDATED_CATALOG_ID from tests.model_registry.model_catalog.metadata.utils import ( - execute_model_catalog_post_command, build_catalog_preview_config, + execute_model_catalog_post_command, validate_catalog_preview_counts, validate_catalog_preview_items, ) -from tests.model_registry.model_catalog.constants 
import VALIDATED_CATALOG_ID, VALIDATED_CATALOG_FILE LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/metadata/test_custom_properties.py b/tests/model_registry/model_catalog/metadata/test_custom_properties.py index b612904ce..21197af74 100644 --- a/tests/model_registry/model_catalog/metadata/test_custom_properties.py +++ b/tests/model_registry/model_catalog/metadata/test_custom_properties.py @@ -1,16 +1,16 @@ -import pytest from typing import Any +import pytest from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger -from tests.model_registry.model_catalog.constants import VALIDATED_CATALOG_ID, REDHAT_AI_CATALOG_ID +from tests.model_registry.model_catalog.constants import REDHAT_AI_CATALOG_ID, VALIDATED_CATALOG_ID from tests.model_registry.model_catalog.metadata.utils import ( extract_custom_property_values, - validate_custom_properties_match_metadata, get_metadata_from_catalog_pod, + validate_custom_properties_match_metadata, ) -from tests.model_registry.utils import get_model_catalog_pod, execute_get_command +from tests.model_registry.utils import execute_get_command, get_model_catalog_pod LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/metadata/test_filter_options_endpoint.py b/tests/model_registry/model_catalog/metadata/test_filter_options_endpoint.py index fa0eb6d2f..36fc6cf5b 100644 --- a/tests/model_registry/model_catalog/metadata/test_filter_options_endpoint.py +++ b/tests/model_registry/model_catalog/metadata/test_filter_options_endpoint.py @@ -1,6 +1,14 @@ -import pytest from typing import Self + +import pytest +from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger + +from tests.model_registry.model_catalog.db_constants import ( + API_COMPUTED_FILTER_FIELDS, + API_EXCLUDED_FILTER_FIELDS, + FILTER_OPTIONS_DB_QUERY, +) from tests.model_registry.model_catalog.metadata.utils import ( compare_filter_options_with_database, ) @@ 
-8,14 +16,8 @@ execute_database_query, parse_psql_output, ) -from tests.model_registry.model_catalog.db_constants import ( - FILTER_OPTIONS_DB_QUERY, - API_EXCLUDED_FILTER_FIELDS, - API_COMPUTED_FILTER_FIELDS, -) -from tests.model_registry.utils import get_rest_headers, execute_get_command +from tests.model_registry.utils import execute_get_command, get_rest_headers from utilities.user_utils import UserTestSession -from kubernetes.dynamic import DynamicClient LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/metadata/test_labels_endpoint.py b/tests/model_registry/model_catalog/metadata/test_labels_endpoint.py index c7c4512d7..fbe9972b9 100644 --- a/tests/model_registry/model_catalog/metadata/test_labels_endpoint.py +++ b/tests/model_registry/model_catalog/metadata/test_labels_endpoint.py @@ -1,17 +1,16 @@ -import pytest from typing import Any + +import pytest from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger - - -from utilities.infra import get_openshift_token from timeout_sampler import TimeoutSampler from tests.model_registry.model_catalog.metadata.utils import ( - get_labels_from_configmaps, get_labels_from_api, + get_labels_from_configmaps, verify_labels_match, ) +from utilities.infra import get_openshift_token LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/metadata/test_sources_endpoint.py b/tests/model_registry/model_catalog/metadata/test_sources_endpoint.py index 3382e4817..9d190778d 100644 --- a/tests/model_registry/model_catalog/metadata/test_sources_endpoint.py +++ b/tests/model_registry/model_catalog/metadata/test_sources_endpoint.py @@ -1,5 +1,4 @@ import pytest - from simple_logger.logger import get_logger from tests.model_registry.model_catalog.constants import REDHAT_AI_CATALOG_ID diff --git a/tests/model_registry/model_catalog/metadata/utils.py b/tests/model_registry/model_catalog/metadata/utils.py index 9345be51c..f4694351d 100644 --- 
a/tests/model_registry/model_catalog/metadata/utils.py +++ b/tests/model_registry/model_catalog/metadata/utils.py @@ -1,13 +1,14 @@ -from typing import Any, List, Dict, Tuple import json -import requests -import yaml from fnmatch import fnmatch +from typing import Any -from simple_logger.logger import get_logger -from ocp_resources.pod import Pod -from ocp_resources.config_map import ConfigMap +import requests +import yaml from kubernetes.dynamic import DynamicClient +from ocp_resources.config_map import ConfigMap +from ocp_resources.pod import Pod +from simple_logger.logger import get_logger + from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM from tests.model_registry.utils import execute_get_command, get_rest_headers @@ -293,7 +294,7 @@ def get_metadata_from_catalog_pod(model_catalog_pod: Pod, model_name: str) -> di metadata_json = model_catalog_pod.execute(command=["cat", metadata_path], container=CATALOG_CONTAINER) metadata = json.loads(metadata_json) LOGGER.info(f"Successfully loaded metadata.json for model '{model_name}'") - return metadata + return metadata # noqa: TRY300 except Exception as e: LOGGER.error(f"Failed to read metadata.json for model '{model_name}': {e}") raise @@ -301,7 +302,7 @@ def get_metadata_from_catalog_pod(model_catalog_pod: Pod, model_name: str) -> di def compare_filter_options_with_database( api_filters: dict[str, Any], db_properties: dict[str, list[str]], excluded_fields: set[str] -) -> Tuple[bool, List[str]]: +) -> tuple[bool, list[str]]: """ Compare API filter options response with database query results. 
@@ -323,7 +324,8 @@ def compare_filter_options_with_database( LOGGER.info(f"Database returned {len(db_properties)} total properties") LOGGER.info( - f"After applying API filtering, expecting {len(expected_properties)} properties: {list(expected_properties.keys())}" # noqa: E501 + f"After applying API filtering, expecting {len(expected_properties)}" + f" properties: {list(expected_properties.keys())}" ) # Check for missing/extra properties @@ -412,7 +414,7 @@ def compare_filter_options_with_database( return is_valid, comparison_errors -def get_labels_from_configmaps(admin_client: DynamicClient, namespace: str) -> List[Dict[str, Any]]: +def get_labels_from_configmaps(admin_client: DynamicClient, namespace: str) -> list[dict[str, Any]]: """ Get all labels from both model catalog ConfigMaps. @@ -440,7 +442,7 @@ def get_labels_from_configmaps(admin_client: DynamicClient, namespace: str) -> L return labels -def get_labels_from_api(model_catalog_rest_url: str, user_token: str) -> List[Dict[str, Any]]: +def get_labels_from_api(model_catalog_rest_url: str, user_token: str) -> list[dict[str, Any]]: """ Get labels from the API endpoint. @@ -458,7 +460,7 @@ def get_labels_from_api(model_catalog_rest_url: str, user_token: str) -> List[Di return response["items"] -def verify_labels_match(expected_labels: List[Dict[str, Any]], api_labels: List[Dict[str, Any]]) -> None: +def verify_labels_match(expected_labels: list[dict[str, Any]], api_labels: list[dict[str, Any]]) -> None: """ Verify that all expected labels are present in the API response. 
diff --git a/tests/model_registry/model_catalog/rbac/test_catalog_rbac.py b/tests/model_registry/model_catalog/rbac/test_catalog_rbac.py index 4c13f827a..c914b6f83 100644 --- a/tests/model_registry/model_catalog/rbac/test_catalog_rbac.py +++ b/tests/model_registry/model_catalog/rbac/test_catalog_rbac.py @@ -3,12 +3,11 @@ """ import pytest -from simple_logger.logger import get_logger - -from kubernetes.dynamic import DynamicClient from kubernetes.client.rest import ApiException +from kubernetes.dynamic import DynamicClient from ocp_resources.config_map import ConfigMap from ocp_resources.resource import get_client +from simple_logger.logger import get_logger from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM diff --git a/tests/model_registry/model_catalog/search/test_model_artifact_search.py b/tests/model_registry/model_catalog/search/test_model_artifact_search.py index 3afbcc4e9..8f967be81 100644 --- a/tests/model_registry/model_catalog/search/test_model_artifact_search.py +++ b/tests/model_registry/model_catalog/search/test_model_artifact_search.py @@ -1,20 +1,21 @@ -import pytest -from typing import Self, Any import random +from typing import Any, Self + +import pytest from dictdiffer import diff +from simple_logger.logger import get_logger +from tests.model_registry.model_catalog.constants import ( + METRICS_ARTIFACT_TYPE, + MODEL_ARTIFACT_TYPE, + VALIDATED_CATALOG_ID, +) from tests.model_registry.model_catalog.search.utils import ( fetch_all_artifacts_with_dynamic_paging, validate_model_artifacts_match_criteria_and, validate_model_artifacts_match_criteria_or, validate_recommendations_subset, ) -from tests.model_registry.model_catalog.constants import ( - VALIDATED_CATALOG_ID, - MODEL_ARTIFACT_TYPE, - METRICS_ARTIFACT_TYPE, -) -from simple_logger.logger import get_logger LOGGER = get_logger(name=__name__) pytestmark = [pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace")] 
diff --git a/tests/model_registry/model_catalog/search/test_model_search.py b/tests/model_registry/model_catalog/search/test_model_search.py index 188a87063..16db2e673 100644 --- a/tests/model_registry/model_catalog/search/test_model_search.py +++ b/tests/model_registry/model_catalog/search/test_model_search.py @@ -1,26 +1,28 @@ +from typing import Any, Self + import pytest from dictdiffer import diff +from kubernetes.dynamic import DynamicClient +from kubernetes.dynamic.exceptions import ResourceNotFoundError from simple_logger.logger import get_logger -from typing import Self, Any + from tests.model_registry.model_catalog.constants import ( REDHAT_AI_CATALOG_ID, - VALIDATED_CATALOG_ID, REDHAT_AI_CATALOG_NAME, REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME, + VALIDATED_CATALOG_ID, ) -from tests.model_registry.model_catalog.utils import get_models_from_catalog_api from tests.model_registry.model_catalog.search.utils import ( fetch_all_artifacts_with_dynamic_paging, - validate_model_contains_search_term, - validate_search_results_against_database, validate_filter_query_results_against_database, - validate_performance_data_files_on_pod, validate_model_artifacts_match_criteria_and, validate_model_artifacts_match_criteria_or, + validate_model_contains_search_term, + validate_performance_data_files_on_pod, + validate_search_results_against_database, ) +from tests.model_registry.model_catalog.utils import get_models_from_catalog_api from tests.model_registry.utils import get_model_catalog_pod -from kubernetes.dynamic import DynamicClient -from kubernetes.dynamic.exceptions import ResourceNotFoundError LOGGER = get_logger(name=__name__) pytestmark = [pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace")] @@ -135,7 +137,12 @@ class TestSearchModelCatalogQParameter: "search_term", [ pytest.param( - "The Llama 4 collection of models are natively multimodal AI models that enable text and multimodal experiences. 
These models leverage a mixture-of-experts architecture to offer industry-leading performance in text and image understanding. These Llama 4 models mark the beginning of a new era for the Llama ecosystem. We are launching two efficient models in the Llama 4 series, Llama 4 Scout, a 17 billion parameter model with 16 experts, and Llama 4 Maverick, a 17 billion parameter model with 128 experts.", # noqa: E501 + "The Llama 4 collection of models are natively multimodal AI models that enable text and multimodal " + "experiences. These models leverage a mixture-of-experts architecture to offer industry-leading " + "performance in text and image understanding. These Llama 4 models mark the beginning of a new era " + "for the Llama ecosystem. We are launching two efficient models in the Llama 4 series, Llama 4 " + "Scout, a 17 billion parameter model with 16 experts, and Llama 4 Maverick, a 17 billion parameter " + "model with 128 experts.", id="long_description", ), ], @@ -213,8 +220,8 @@ def test_q_parameter_with_source_label_filter( ) # Combined filter results should be a subset of search-only results - search_only_model_ids = set(m.get("id") for m in search_only_response.get("items", [])) - combined_model_ids = set(m.get("id") for m in models) + search_only_model_ids = {m.get("id") for m in search_only_response.get("items", [])} + combined_model_ids = {m.get("id") for m in models} assert combined_model_ids.issubset(search_only_model_ids), ( f"Combined filter results should be a subset of search-only results. 
" diff --git a/tests/model_registry/model_catalog/search/utils.py b/tests/model_registry/model_catalog/search/utils.py index b51c87d66..fc60f352b 100644 --- a/tests/model_registry/model_catalog/search/utils.py +++ b/tests/model_registry/model_catalog/search/utils.py @@ -1,26 +1,27 @@ """Utility functions for model catalog search tests.""" from typing import Any -from simple_logger.logger import get_logger + +from kubernetes.dynamic import DynamicClient from ocp_resources.pod import Pod +from simple_logger.logger import get_logger from tests.model_registry.model_catalog.constants import ( + CATALOG_CONTAINER, + PERFORMANCE_DATA_DIR, + REDHAT_AI_CATALOG_ID, REDHAT_AI_CATALOG_NAME, REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME, - REDHAT_AI_CATALOG_ID, VALIDATED_CATALOG_ID, - CATALOG_CONTAINER, - PERFORMANCE_DATA_DIR, ) from tests.model_registry.model_catalog.db_constants import ( + FILTER_MODELS_BY_LICENSE_AND_LANGUAGE_DB_QUERY, + FILTER_MODELS_BY_LICENSE_DB_QUERY, SEARCH_MODELS_DB_QUERY, SEARCH_MODELS_WITH_SOURCE_ID_DB_QUERY, - FILTER_MODELS_BY_LICENSE_DB_QUERY, - FILTER_MODELS_BY_LICENSE_AND_LANGUAGE_DB_QUERY, ) from tests.model_registry.model_catalog.utils import execute_database_query, parse_psql_output from tests.model_registry.utils import execute_get_command -from kubernetes.dynamic import DynamicClient LOGGER = get_logger(name=__name__) @@ -86,7 +87,8 @@ def get_models_matching_search_from_database( catalog_id = VALIDATED_CATALOG_ID else: raise ValueError( - f"Unknown source_label: '{source_label}'. Supported labels: {REDHAT_AI_CATALOG_NAME}, {REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME}" # noqa: E501 + f"Unknown source_label: '{source_label}'. 
" + f"Supported labels: {REDHAT_AI_CATALOG_NAME}, {REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME}" ) # Use the extended query with source_id filtering from db_constants @@ -165,7 +167,7 @@ def _compare_api_and_database_results( # Get actual results from API api_models = api_response.get("items", []) - actual_model_ids = set(model.get("id") for model in api_models if model.get("id")) + actual_model_ids = {model.get("id") for model in api_models if model.get("id")} LOGGER.info(f"API returned {len(actual_model_ids)} models for {description}") # Compare results diff --git a/tests/model_registry/model_catalog/sorting/test_model_artifacts_sorting.py b/tests/model_registry/model_catalog/sorting/test_model_artifacts_sorting.py index d91300d32..40833f989 100644 --- a/tests/model_registry/model_catalog/sorting/test_model_artifacts_sorting.py +++ b/tests/model_registry/model_catalog/sorting/test_model_artifacts_sorting.py @@ -1,13 +1,15 @@ -import pytest -from typing import Self import random +from typing import Self + +import pytest from simple_logger.logger import get_logger + +from tests.model_registry.model_catalog.constants import VALIDATED_CATALOG_ID from tests.model_registry.model_catalog.sorting.utils import ( get_artifacts_with_sorting, validate_items_sorted_correctly, verify_custom_properties_sorted, ) -from tests.model_registry.model_catalog.constants import VALIDATED_CATALOG_ID LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/sorting/test_model_sorting.py b/tests/model_registry/model_catalog/sorting/test_model_sorting.py index 6fc371519..38be27a6e 100644 --- a/tests/model_registry/model_catalog/sorting/test_model_sorting.py +++ b/tests/model_registry/model_catalog/sorting/test_model_sorting.py @@ -1,16 +1,18 @@ -import pytest from typing import Self + +import pytest +from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger + from tests.model_registry.model_catalog.constants import ( 
REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME, VALIDATED_CATALOG_ID, ) -from tests.model_registry.model_catalog.utils import get_models_from_catalog_api from tests.model_registry.model_catalog.sorting.utils import ( - validate_accuracy_sorting_against_database, get_model_latencies, + validate_accuracy_sorting_against_database, ) -from kubernetes.dynamic import DynamicClient +from tests.model_registry.model_catalog.utils import get_models_from_catalog_api LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_catalog/sorting/test_sorting_functionality.py b/tests/model_registry/model_catalog/sorting/test_sorting_functionality.py index 2f11b4477..8fc78e76c 100644 --- a/tests/model_registry/model_catalog/sorting/test_sorting_functionality.py +++ b/tests/model_registry/model_catalog/sorting/test_sorting_functionality.py @@ -1,6 +1,8 @@ -import pytest from typing import Self + +import pytest from simple_logger.logger import get_logger + from tests.model_registry.model_catalog.sorting.utils import ( get_sources_with_sorting, validate_items_sorted_correctly, diff --git a/tests/model_registry/model_catalog/sorting/utils.py b/tests/model_registry/model_catalog/sorting/utils.py index 4010ad0c0..a240ec53b 100644 --- a/tests/model_registry/model_catalog/sorting/utils.py +++ b/tests/model_registry/model_catalog/sorting/utils.py @@ -1,17 +1,18 @@ from typing import Any +from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger -from tests.model_registry.utils import execute_get_command -from tests.model_registry.model_catalog.utils import ( - execute_database_query, - parse_psql_output, - get_models_from_catalog_api, -) + from tests.model_registry.model_catalog.db_constants import ( GET_MODELS_BY_ACCURACY_DB_QUERY, GET_MODELS_BY_ACCURACY_WITH_TASK_FILTER_DB_QUERY, ) -from kubernetes.dynamic import DynamicClient +from tests.model_registry.model_catalog.utils import ( + execute_database_query, + get_models_from_catalog_api, + 
parse_psql_output, +) +from tests.model_registry.utils import execute_get_command LOGGER = get_logger(name=__name__) @@ -403,7 +404,7 @@ def _verify_models_with_accuracy_sorted( return False else: # Only validate presence, not order - actual_names = set([name for _, name in models]) + actual_names = {name for _, name in models} expected_names = set(expected_models) if actual_names != expected_names: LOGGER.error("Models with accuracy do not match expected models from database") diff --git a/tests/model_registry/model_catalog/upgrade/test_model_catalog_upgrade.py b/tests/model_registry/model_catalog/upgrade/test_model_catalog_upgrade.py index 9c80b6317..e9eae1a31 100644 --- a/tests/model_registry/model_catalog/upgrade/test_model_catalog_upgrade.py +++ b/tests/model_registry/model_catalog/upgrade/test_model_catalog_upgrade.py @@ -1,17 +1,19 @@ +from collections.abc import Generator +from typing import Self + import pytest import yaml -from simple_logger.logger import get_logger -from typing import Self, Generator from kubernetes.dynamic import DynamicClient - from ocp_resources.config_map import ConfigMap from ocp_resources.resource import ResourceEditor -from tests.model_registry.constants import SAMPLE_MODEL_NAME1, CUSTOM_CATALOG_ID1 +from simple_logger.logger import get_logger + +from tests.model_registry.constants import CUSTOM_CATALOG_ID1, SAMPLE_MODEL_NAME1 from tests.model_registry.utils import ( - wait_for_model_catalog_pod_ready_after_deletion, - wait_for_model_catalog_api, get_catalog_str, get_sample_yaml_str, + wait_for_model_catalog_api, + wait_for_model_catalog_pod_ready_after_deletion, ) LOGGER = get_logger(name=__name__) @@ -52,7 +54,7 @@ def post_upgrade_config_map_update( catalog_config_map: ConfigMap, admin_client: DynamicClient, model_registry_namespace: str, -) -> Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: """Fixture for updating catalog config map after post upgrade testing is done""" yield catalog_config_map # Only teardown is 
needed diff --git a/tests/model_registry/model_catalog/utils.py b/tests/model_registry/model_catalog/utils.py index 9a635f6c5..762177cda 100644 --- a/tests/model_registry/model_catalog/utils.py +++ b/tests/model_registry/model_catalog/utils.py @@ -1,11 +1,11 @@ from typing import Any +from kubernetes.dynamic import DynamicClient +from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from ocp_resources.pod import Pod from tests.model_registry.model_catalog.constants import HF_MODELS from tests.model_registry.utils import execute_get_command -from kubernetes.dynamic import DynamicClient LOGGER = get_logger(name=__name__) @@ -195,7 +195,7 @@ def get_models_from_catalog_api( return execute_get_command(url=base_url, headers=model_registry_rest_headers, params=params) -def get_hf_catalog_str(ids: list[str], excluded_models: list[str] = None) -> str: +def get_hf_catalog_str(ids: list[str], excluded_models: list[str] | None = None) -> str: """ Generate a HuggingFace catalog configuration string in YAML format. Similar to get_catalog_str() but for HuggingFace catalogs. 
diff --git a/tests/model_registry/model_registry/async_job/conftest.py b/tests/model_registry/model_registry/async_job/conftest.py index c5e9e8ac6..8caaf70dc 100644 --- a/tests/model_registry/model_registry/async_job/conftest.py +++ b/tests/model_registry/model_registry/async_job/conftest.py @@ -1,38 +1,35 @@ -from typing import Any, Generator import json +from collections.abc import Generator +from typing import Any import pytest +import shortuuid from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError +from model_registry import ModelRegistry as ModelRegistryClient +from model_registry.types import RegisteredModel +from ocp_resources.config_map import ConfigMap from ocp_resources.job import Job +from ocp_resources.role_binding import RoleBinding +from ocp_resources.secret import Secret +from ocp_resources.service import Service +from ocp_resources.service_account import ServiceAccount +from pytest import FixtureRequest +from pytest_testconfig import py_config from tests.model_registry.model_registry.async_job.constants import ( ASYNC_JOB_ANNOTATIONS, ASYNC_JOB_LABELS, ASYNC_UPLOAD_JOB_NAME, MODEL_SYNC_CONFIG, + REPO_NAME, VOLUME_MOUNTS, ) - -import shortuuid -from pytest import FixtureRequest -from pytest_testconfig import py_config - -from ocp_resources.role_binding import RoleBinding -from ocp_resources.secret import Secret -from ocp_resources.service import Service -from ocp_resources.service_account import ServiceAccount -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry -from ocp_resources.config_map import ConfigMap -from model_registry.types import RegisteredModel -from model_registry import ModelRegistry as ModelRegistryClient - -from utilities.constants import OCIRegistry, MinIo, Protocols, Labels, ApiGroups -from utilities.general import b64_encoded_string from tests.model_registry.model_registry.async_job.utils import upload_test_model_to_minio_from_image -from 
tests.model_registry.utils import get_mr_service_by_label, get_endpoint_from_mr_service -from tests.model_registry.model_registry.async_job.constants import REPO_NAME -from utilities.general import get_s3_secret_dict +from tests.model_registry.utils import get_endpoint_from_mr_service, get_mr_service_by_label +from utilities.constants import ApiGroups, Labels, MinIo, OCIRegistry, Protocols +from utilities.general import b64_encoded_string, get_s3_secret_dict +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry @pytest.fixture(scope="class") @@ -136,7 +133,8 @@ def async_upload_image(admin_client: DynamicClient) -> str: if not config_map.exists: raise ResourceNotFoundError( - f"ConfigMap 'model-registry-operator-parameters' not found in namespace '{py_config['applications_namespace']}'" # noqa: E501 + f"ConfigMap 'model-registry-operator-parameters' not found in" + f" namespace '{py_config['applications_namespace']}'" ) try: @@ -276,7 +274,7 @@ def create_test_data_in_minio_from_image( @pytest.fixture(scope="class") def registered_model_from_image( request: FixtureRequest, model_registry_client: list[ModelRegistryClient] -) -> Generator[RegisteredModel, None, None]: +) -> Generator[RegisteredModel]: """Create a registered model for testing with KSERVE_MINIO_IMAGE data""" yield model_registry_client[0].register_model( name=request.param.get("model_name"), diff --git a/tests/model_registry/model_registry/async_job/constants.py b/tests/model_registry/model_registry/async_job/constants.py index a9650588b..e248e21c4 100644 --- a/tests/model_registry/model_registry/async_job/constants.py +++ b/tests/model_registry/model_registry/async_job/constants.py @@ -9,10 +9,8 @@ "component": "model-registry-job", "modelregistry.opendatahub.io/job-type": "async-upload", } - -ASYNC_JOB_ANNOTATIONS = { - "modelregistry.opendatahub.io/description": "Asynchronous job for uploading models to Model Registry and converting them to ModelCar format" # 
noqa: E501 -} +ASYNC_STR: str = "Asynchronous job for uploading models to Model Registry and converting them to ModelCar format" +ASYNC_JOB_ANNOTATIONS = {"modelregistry.opendatahub.io/description": ASYNC_STR} # Model sync parameters (from sample YAML) MODEL_SYNC_CONFIG = { diff --git a/tests/model_registry/model_registry/async_job/test_async_upload_e2e.py b/tests/model_registry/model_registry/async_job/test_async_upload_e2e.py index 4c09d7ffb..41dc19d82 100644 --- a/tests/model_registry/model_registry/async_job/test_async_upload_e2e.py +++ b/tests/model_registry/model_registry/async_job/test_async_upload_e2e.py @@ -1,23 +1,26 @@ -from typing import Self -import time import json +import time +from typing import Self import pytest from kubernetes.dynamic import DynamicClient -from ocp_resources.job import Job +from model_registry import ModelRegistry as ModelRegistryClient from model_registry.types import ArtifactState, RegisteredModelState +from ocp_resources.job import Job +from simple_logger.logger import get_logger + +from tests.model_registry.constants import MODEL_DICT from tests.model_registry.model_registry.async_job.constants import ( ASYNC_UPLOAD_JOB_NAME, + MODEL_SYNC_CONFIG, + REPO_NAME, + TAG, ) from tests.model_registry.model_registry.async_job.utils import ( get_latest_job_pod, ) -from tests.model_registry.constants import MODEL_DICT from utilities.constants import MinIo, OCIRegistry from utilities.registry_utils import pull_manifest_from_oci_registry -from model_registry import ModelRegistry as ModelRegistryClient -from simple_logger.logger import get_logger -from tests.model_registry.model_registry.async_job.constants import MODEL_SYNC_CONFIG, REPO_NAME, TAG LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_registry/async_job/utils.py b/tests/model_registry/model_registry/async_job/utils.py index 1c0b42010..392e07c97 100644 --- a/tests/model_registry/model_registry/async_job/utils.py +++ 
b/tests/model_registry/model_registry/async_job/utils.py @@ -2,10 +2,10 @@ from ocp_resources.job import Job from ocp_resources.pod import Pod from ocp_resources.service import Service -from utilities.constants import MinIo from simple_logger.logger import get_logger from timeout_sampler import TimeoutExpiredError +from utilities.constants import MinIo from utilities.general import collect_pod_information LOGGER = get_logger(name=__name__) @@ -49,7 +49,7 @@ def upload_test_model_to_minio_from_image( object_key: S3 object key path model_image: Container image containing the model """ - + mc_url = f"http://{minio_service.name}.{minio_service.namespace}.svc.cluster.local:{MinIo.Metadata.DEFAULT_PORT} " with Pod( client=admin_client, name="test-model-uploader-from-image", @@ -86,7 +86,7 @@ def upload_test_model_to_minio_from_image( f"echo 'Model file details:' && ls -la /upload-data/model.onnx && " f"echo 'Model file content preview:' && head -c 100 /upload-data/model.onnx && echo && " f"export MC_CONFIG_DIR=/upload-data/.mc && " - f"mc alias set testminio http://{minio_service.name}.{minio_service.namespace}.svc.cluster.local:{MinIo.Metadata.DEFAULT_PORT} " # noqa: E501 + f"mc alias set testminio {mc_url}" f"{MinIo.Credentials.ACCESS_KEY_VALUE} {MinIo.Credentials.SECRET_KEY_VALUE} && " f"mc mb --ignore-existing testminio/{MinIo.Buckets.MODELMESH_EXAMPLE_MODELS} && " f"mc cp /upload-data/model.onnx testminio/{MinIo.Buckets.MODELMESH_EXAMPLE_MODELS}/{object_key} && " @@ -115,7 +115,7 @@ def upload_test_model_to_minio_from_image( try: upload_logs = upload_pod.log() LOGGER.info(f"Upload logs: {upload_logs}") - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Could not retrieve upload logs: {e}") LOGGER.info( diff --git a/tests/model_registry/model_registry/conftest.py b/tests/model_registry/model_registry/conftest.py index e4ea83cdf..c06cec252 100644 --- a/tests/model_registry/model_registry/conftest.py +++ 
b/tests/model_registry/model_registry/conftest.py @@ -1,34 +1,32 @@ +import shlex import subprocess +from collections.abc import Generator +from typing import Any import pytest -import shlex -from pyhelper_utils.shell import run_command -from typing import Generator, Any, List, Dict - -from ocp_resources.pod import Pod +from kubernetes.dynamic import DynamicClient +from kubernetes.dynamic.exceptions import ResourceNotFoundError +from model_registry import ModelRegistry as ModelRegistryClient +from model_registry.types import RegisteredModel +from ocp_resources.deployment import Deployment from ocp_resources.namespace import Namespace -from ocp_resources.service_account import ServiceAccount +from ocp_resources.pod import Pod from ocp_resources.role import Role from ocp_resources.role_binding import RoleBinding -from ocp_resources.deployment import Deployment -from kubernetes.dynamic.exceptions import ResourceNotFoundError - +from ocp_resources.service_account import ServiceAccount +from pyhelper_utils.shell import run_command from pytest import FixtureRequest from simple_logger.logger import get_logger -from kubernetes.dynamic import DynamicClient -from model_registry.types import RegisteredModel - from tests.model_registry.constants import ( - MR_INSTANCE_NAME, MODEL_REGISTRY_POD_FILTER, + MR_INSTANCE_NAME, ) -from utilities.constants import Protocols from tests.model_registry.utils import ( get_endpoint_from_mr_service, get_mr_service_by_label, ) -from model_registry import ModelRegistry as ModelRegistryClient +from utilities.constants import Protocols from utilities.general import wait_for_pods_by_labels LOGGER = get_logger(name=__name__) @@ -83,7 +81,7 @@ def model_registry_client( @pytest.fixture(scope="class") def registered_model( request: FixtureRequest, model_registry_client: list[ModelRegistryClient] -) -> Generator[RegisteredModel, None, None]: +) -> Generator[RegisteredModel]: yield model_registry_client[0].register_model( 
name=request.param.get("model_name"), uri=request.param.get("model_uri"), @@ -131,19 +129,19 @@ def sa_token(service_account: ServiceAccount) -> str: try: cmd = f"oc create token {sa_name} -n {namespace} --duration={DEFAULT_TOKEN_DURATION}" LOGGER.debug(f"Executing command: {cmd}") - res, out, err = run_command(command=shlex.split(cmd), verify_stderr=False, check=True, timeout=30) + _, out, _ = run_command(command=shlex.split(cmd), verify_stderr=False, check=True, timeout=30) token = out.strip() if not token: raise ValueError("Retrieved token is empty after successful command execution.") LOGGER.info(f"Successfully retrieved token for SA '{sa_name}'") - return token + return token # noqa: TRY300 except Exception as e: # Catch all exceptions from the try block error_type = type(e).__name__ log_message = ( f"Failed during token retrieval for SA '{sa_name}' in namespace '{namespace}'. " - f"Error Type: {error_type}, Message: {str(e)}" + f"Error Type: {error_type}, Message: {e!s}" ) if isinstance(e, subprocess.CalledProcessError): # Add specific details for CalledProcessError @@ -160,7 +158,7 @@ def sa_token(service_account: ServiceAccount) -> str: command_not_found = e.filename if hasattr(e, "filename") and e.filename else shlex.split(cmd)[0] log_message += f". Command '{command_not_found}' not found. Is it installed and in PATH?" - LOGGER.error(log_message, exc_info=True) # exc_info=True adds stack trace to the log + LOGGER.error(log_message) raise @@ -169,14 +167,14 @@ def mr_access_role( admin_client: DynamicClient, model_registry_namespace: str, sa_namespace: Namespace, -) -> Generator[Role, None, None]: +) -> Generator[Role]: """ Creates the MR Access Role using direct constructor parameters and a context manager. 
""" role_name = f"registry-user-{MR_INSTANCE_NAME}-{sa_namespace.name[:8]}" LOGGER.info(f"Defining Role: {role_name} in namespace {model_registry_namespace}") - role_rules: List[Dict[str, Any]] = [ + role_rules: list[dict[str, Any]] = [ { "apiGroups": [""], "resources": ["services"], @@ -206,7 +204,7 @@ def mr_access_role_binding( model_registry_namespace: str, mr_access_role: Role, sa_namespace: Namespace, -) -> Generator[RoleBinding, None, None]: +) -> Generator[RoleBinding]: """ Creates the MR Access RoleBinding using direct constructor parameters and a context manager. """ diff --git a/tests/model_registry/model_registry/negative_tests/conftest.py b/tests/model_registry/model_registry/negative_tests/conftest.py index fbf1331d7..64b122ade 100644 --- a/tests/model_registry/model_registry/negative_tests/conftest.py +++ b/tests/model_registry/model_registry/negative_tests/conftest.py @@ -1,34 +1,32 @@ -import pytest -from typing import Generator, Any +from collections.abc import Generator +from typing import Any +import pytest from _pytest.config import Config -from pytest_testconfig import config as py_config - +from kubernetes.dynamic import DynamicClient from ocp_resources.data_science_cluster import DataScienceCluster +from ocp_resources.deployment import Deployment +from ocp_resources.namespace import Namespace +from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.pod import Pod from ocp_resources.secret import Secret -from ocp_resources.namespace import Namespace from ocp_resources.service import Service -from ocp_resources.persistent_volume_claim import PersistentVolumeClaim -from ocp_resources.deployment import Deployment - -from kubernetes.dynamic import DynamicClient - +from pytest_testconfig import config as py_config from tests.model_registry.constants import ( - MODEL_REGISTRY_DB_SECRET_STR_DATA, - MODEL_REGISTRY_DB_SECRET_ANNOTATIONS, DB_RESOURCE_NAME, + MODEL_REGISTRY_DB_SECRET_ANNOTATIONS, + 
MODEL_REGISTRY_DB_SECRET_STR_DATA, MR_INSTANCE_NAME, ) -from tests.model_registry.utils import get_model_registry_deployment_template_dict, get_model_registry_db_label_dict -from utilities.constants import MODEL_REGISTRY_CUSTOM_NAMESPACE -from utilities.general import wait_for_pods_by_labels -from utilities.infra import create_ns from tests.model_registry.model_registry.negative_tests.utils import ( - execute_mysql_command, create_mysql_credentials_file, + execute_mysql_command, ) +from tests.model_registry.utils import get_model_registry_db_label_dict, get_model_registry_deployment_template_dict +from utilities.constants import MODEL_REGISTRY_CUSTOM_NAMESPACE +from utilities.general import wait_for_pods_by_labels +from utilities.infra import create_ns DB_RESOURCES_NAME_NEGATIVE = "db-model-registry-negative" diff --git a/tests/model_registry/model_registry/negative_tests/test_db_migration.py b/tests/model_registry/model_registry/negative_tests/test_db_migration.py index 606270130..cbd90a678 100644 --- a/tests/model_registry/model_registry/negative_tests/test_db_migration.py +++ b/tests/model_registry/model_registry/negative_tests/test_db_migration.py @@ -1,13 +1,14 @@ -import pytest from typing import Self -from simple_logger.logger import get_logger -from pytest_testconfig import config as py_config +import pytest +from kubernetes.dynamic.client import DynamicClient from ocp_resources.pod import Pod +from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger + from tests.model_registry.constants import MR_INSTANCE_NAME -from kubernetes.dynamic.client import DynamicClient -from utilities.general import wait_for_container_status from tests.model_registry.utils import wait_for_new_running_mr_pod +from utilities.general import wait_for_container_status LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_registry/negative_tests/test_model_registry_creation_negative.py 
b/tests/model_registry/model_registry/negative_tests/test_model_registry_creation_negative.py index 60481d2e2..a7942faac 100644 --- a/tests/model_registry/model_registry/negative_tests/test_model_registry_creation_negative.py +++ b/tests/model_registry/model_registry/negative_tests/test_model_registry_creation_negative.py @@ -1,21 +1,22 @@ -import pytest from typing import Self -from simple_logger.logger import get_logger + +import pytest +from kubernetes.dynamic import DynamicClient +from kubernetes.dynamic.exceptions import ForbiddenError from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.deployment import Deployment -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry -from pytest_testconfig import config as py_config from ocp_resources.namespace import Namespace from ocp_resources.secret import Secret -from utilities.constants import Annotations +from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger + from tests.model_registry.constants import ( - MR_OPERATOR_NAME, - MR_INSTANCE_NAME, DB_RESOURCE_NAME, + MR_INSTANCE_NAME, + MR_OPERATOR_NAME, ) -from kubernetes.dynamic.exceptions import ForbiddenError -from kubernetes.dynamic import DynamicClient - +from utilities.constants import Annotations +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry LOGGER = get_logger(name=__name__) @@ -47,7 +48,7 @@ def test_registering_model_negative( "skipDBCreation": False, "username": model_registry_db_secret_negative_test.string_data["database-user"], } - with pytest.raises( + with pytest.raises( # noqa: SIM117 ForbiddenError, # UnprocessibleEntityError match=f"namespace must be {py_config['model_registry_namespace']}", ): diff --git a/tests/model_registry/model_registry/negative_tests/utils.py b/tests/model_registry/model_registry/negative_tests/utils.py index dfcfd786e..65dc67fda 100644 --- 
a/tests/model_registry/model_registry/negative_tests/utils.py +++ b/tests/model_registry/model_registry/negative_tests/utils.py @@ -1,6 +1,8 @@ +import base64 + from ocp_resources.pod import Pod + from tests.model_registry.constants import MODEL_REGISTRY_DB_SECRET_STR_DATA -import base64 def create_mysql_credentials_file(model_registry_db_instance_pod: Pod) -> None: diff --git a/tests/model_registry/model_registry/python_client/signing/conftest.py b/tests/model_registry/model_registry/python_client/signing/conftest.py index 765a7bf4b..5126a82df 100644 --- a/tests/model_registry/model_registry/python_client/signing/conftest.py +++ b/tests/model_registry/model_registry/python_client/signing/conftest.py @@ -1,33 +1,36 @@ """Fixtures for Model Registry Python Client Signing Tests.""" -from typing import Any, Generator import json -import requests +from collections.abc import Generator +from typing import Any + import pytest +import requests from kubernetes.dynamic import DynamicClient -from ocp_resources.subscription import Subscription -from ocp_resources.namespace import Namespace -from ocp_resources.deployment import Deployment from ocp_resources.config_map import ConfigMap +from ocp_resources.deployment import Deployment +from ocp_resources.namespace import Namespace +from ocp_resources.subscription import Subscription from ocp_utilities.operators import install_operator, uninstall_operator from pytest_testconfig import config as py_config from simple_logger.logger import get_logger -from timeout_sampler import TimeoutSampler, TimeoutExpiredError -from utilities.constants import Timeout, OPENSHIFT_OPERATORS -from utilities.infra import get_openshift_token -from utilities.resources.securesign import Securesign +from timeout_sampler import TimeoutExpiredError, TimeoutSampler + from tests.model_registry.model_registry.python_client.signing.constants import ( - SECURESIGN_NAMESPACE, - SECURESIGN_NAME, SECURESIGN_API_VERSION, + SECURESIGN_NAME, + SECURESIGN_NAMESPACE, 
TAS_CONNECTION_TYPE_NAME, ) from tests.model_registry.model_registry.python_client.signing.utils import ( + create_connection_type_field, get_organization_config, - is_securesign_ready, get_tas_service_urls, - create_connection_type_field, + is_securesign_ready, ) +from utilities.constants import OPENSHIFT_OPERATORS, Timeout +from utilities.infra import get_openshift_token +from utilities.resources.securesign import Securesign LOGGER = get_logger(name=__name__) @@ -61,7 +64,7 @@ def oidc_issuer_url(admin_client: DynamicClient, api_server_url: str) -> str: @pytest.fixture(scope="class") -def installed_tas_operator(admin_client: DynamicClient) -> Generator[None, Any, None]: +def installed_tas_operator(admin_client: DynamicClient) -> Generator[None, Any]: """Install Red Hat Trusted Artifact Signer (RHTAS/TAS) operator if not already installed. This fixture checks if TAS operator subscription exists in openshift-operators @@ -124,7 +127,7 @@ def installed_tas_operator(admin_client: DynamicClient) -> Generator[None, Any, @pytest.fixture(scope="class") def securesign_instance( admin_client: DynamicClient, installed_tas_operator: None, oidc_issuer_url: str -) -> Generator[Securesign, Any, None]: +) -> Generator[Securesign, Any]: """Create a Securesign instance with all Sigstore components in the trusted-artifact-signer namespace with the following components enabled: @@ -216,9 +219,7 @@ def securesign_instance( @pytest.fixture(scope="class") -def tas_connection_type( - admin_client: DynamicClient, securesign_instance: Securesign -) -> Generator[ConfigMap, Any, None]: +def tas_connection_type(admin_client: DynamicClient, securesign_instance: Securesign) -> Generator[ConfigMap, Any]: """Create ODH Connection Type ConfigMap for TAS (Trusted Artifact Signer). Provides TAS service endpoints for programmatic access to signing services. 
diff --git a/tests/model_registry/model_registry/python_client/signing/test_signing_infrastructure.py b/tests/model_registry/model_registry/python_client/signing/test_signing_infrastructure.py index dbeb1bb07..3ffabff4d 100644 --- a/tests/model_registry/model_registry/python_client/signing/test_signing_infrastructure.py +++ b/tests/model_registry/model_registry/python_client/signing/test_signing_infrastructure.py @@ -1,10 +1,12 @@ """Tests for TAS signing infrastructure setup and readiness.""" -from typing import Self import json +from typing import Self + import pytest from ocp_resources.config_map import ConfigMap from simple_logger.logger import get_logger + from utilities.resources.securesign import Securesign LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_registry/python_client/signing/utils.py b/tests/model_registry/model_registry/python_client/signing/utils.py index ca6ffccdf..8674102b5 100644 --- a/tests/model_registry/model_registry/python_client/signing/utils.py +++ b/tests/model_registry/model_registry/python_client/signing/utils.py @@ -1,8 +1,8 @@ """Utility functions for Model Registry Python Client Signing Tests.""" from tests.model_registry.model_registry.python_client.signing.constants import ( - SECURESIGN_ORGANIZATION_NAME, SECURESIGN_ORGANIZATION_EMAIL, + SECURESIGN_ORGANIZATION_NAME, ) diff --git a/tests/model_registry/model_registry/python_client/test_model_registry_creation.py b/tests/model_registry/model_registry/python_client/test_model_registry_creation.py index 6c5e4a2d4..1dc49e06b 100644 --- a/tests/model_registry/model_registry/python_client/test_model_registry_creation.py +++ b/tests/model_registry/model_registry/python_client/test_model_registry_creation.py @@ -1,19 +1,19 @@ +from typing import Any, Self + import pytest -from typing import Self, Any -from simple_logger.logger import get_logger +from model_registry import ModelRegistry as ModelRegistryClient +from model_registry.types import RegisteredModel # 
ocp_resources imports from ocp_resources.pod import Pod +from simple_logger.logger import get_logger +from tests.model_registry.constants import MODEL_DICT, MODEL_NAME from tests.model_registry.utils import ( execute_model_registry_get_command, - validate_no_grpc_container, validate_mlmd_removal_in_model_registry_pod_log, + validate_no_grpc_container, ) -from tests.model_registry.constants import MODEL_NAME, MODEL_DICT -from model_registry import ModelRegistry as ModelRegistryClient -from model_registry.types import RegisteredModel - LOGGER = get_logger(name=__name__) @@ -71,11 +71,12 @@ def test_model_registry_operator_env( model_registry_namespace: str, model_registry_operator_pod: Pod, ): - namespace_env = [] - for container in model_registry_operator_pod.instance.spec.containers: - for env in container.env: - if env.name == "REGISTRIES_NAMESPACE" and env.value == model_registry_namespace: - namespace_env.append({container.name: env}) + namespace_env = [ + {container.name: env} + for container in model_registry_operator_pod.instance.spec.containers + for env in container.env + if env.name == "REGISTRIES_NAMESPACE" and env.value == model_registry_namespace + ] if not namespace_env: pytest.fail("Missing environment variable REGISTRIES_NAMESPACE") diff --git a/tests/model_registry/model_registry/rbac/conftest.py b/tests/model_registry/model_registry/rbac/conftest.py index 96ab5ee20..a170ee2de 100644 --- a/tests/model_registry/model_registry/rbac/conftest.py +++ b/tests/model_registry/model_registry/rbac/conftest.py @@ -1,30 +1,28 @@ +from collections.abc import Generator from contextlib import ExitStack +from typing import Any import pytest -from typing import Generator, List, Any - from _pytest.fixtures import FixtureRequest -from simple_logger.logger import get_logger - +from kubernetes.dynamic import DynamicClient from ocp_resources.deployment import Deployment -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry +from 
ocp_resources.group import Group from ocp_resources.persistent_volume_claim import PersistentVolumeClaim +from ocp_resources.resource import ResourceEditor +from ocp_resources.role import Role +from ocp_resources.role_binding import RoleBinding from ocp_resources.secret import Secret from ocp_resources.service import Service -from ocp_resources.role_binding import RoleBinding -from ocp_resources.role import Role -from ocp_resources.group import Group - -from ocp_resources.resource import ResourceEditor -from kubernetes.dynamic import DynamicClient +from simple_logger.logger import get_logger -from tests.model_registry.model_registry.rbac.utils import create_role_binding -from utilities.user_utils import UserTestSession -from tests.model_registry.model_registry.rbac.group_utils import create_group from tests.model_registry.constants import ( - MR_INSTANCE_NAME, KUBERBACPROXY_STR, + MR_INSTANCE_NAME, ) +from tests.model_registry.model_registry.rbac.group_utils import create_group +from tests.model_registry.model_registry.rbac.utils import create_role_binding +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry +from utilities.user_utils import UserTestSession LOGGER = get_logger(name=__name__) @@ -33,7 +31,7 @@ def add_user_to_group( admin_client: DynamicClient, test_idp_user: UserTestSession, -) -> Generator[str, None, None]: +) -> Generator[str]: """ Fixture to create a group and add a test user to it. Uses create_group context manager to ensure proper cleanup. @@ -59,7 +57,7 @@ def model_registry_group_with_user( is_byoidc: bool, admin_client: DynamicClient, test_idp_user: UserTestSession, -) -> Generator[Group, None, None]: +) -> Generator[Group]: """ Fixture to manage a test user in a specified group. Adds the user to the group before the test, then removes them after. 
@@ -102,7 +100,7 @@ def created_role_binding_group( mr_access_role: Role, test_idp_user: UserTestSession, add_user_to_group: str, -) -> Generator[RoleBinding, None, None]: +) -> Generator[RoleBinding]: yield from create_role_binding( admin_client=admin_client, model_registry_namespace=model_registry_namespace, @@ -120,7 +118,7 @@ def created_role_binding_user( model_registry_namespace: str, mr_access_role: Role, user_credentials_rbac: dict[str, str], -) -> Generator[RoleBinding, None, None]: +) -> Generator[RoleBinding]: # Determine the username to use without mutating the shared fixture username = "mr-non-admin" if is_byoidc else user_credentials_rbac["username"] LOGGER.info(f"Using user {username}") @@ -140,7 +138,7 @@ def created_role_binding_user( @pytest.fixture(scope="class") def db_secret_parametrized( request: FixtureRequest, admin_client: DynamicClient, teardown_resources: bool -) -> Generator[List[Secret], Any, Any]: +) -> Generator[list[Secret], Any, Any]: """Create DB Secret parametrized""" with ExitStack() as stack: secrets = [ @@ -159,7 +157,7 @@ def db_secret_parametrized( @pytest.fixture(scope="class") def db_pvc_parametrized( request: FixtureRequest, admin_client: DynamicClient, teardown_resources: bool -) -> Generator[List[PersistentVolumeClaim], Any, Any]: +) -> Generator[list[PersistentVolumeClaim], Any, Any]: """Create DB PVC parametrized""" with ExitStack() as stack: pvc = [ @@ -178,7 +176,7 @@ def db_pvc_parametrized( @pytest.fixture(scope="class") def db_service_parametrized( request: FixtureRequest, admin_client: DynamicClient, teardown_resources: bool -) -> Generator[List[Service], Any, Any]: +) -> Generator[list[Service], Any, Any]: """Create DB Service parametrized""" with ExitStack() as stack: services = [ @@ -197,7 +195,7 @@ def db_service_parametrized( @pytest.fixture(scope="class") def db_deployment_parametrized( request: FixtureRequest, admin_client: DynamicClient, teardown_resources: bool -) -> Generator[List[Deployment], Any, 
Any]: +) -> Generator[list[Deployment], Any, Any]: """Create DB Deployment parametrized""" with ExitStack() as stack: deployments = [ @@ -220,7 +218,7 @@ def db_deployment_parametrized( @pytest.fixture(scope="class") def model_registry_instance_parametrized( request: FixtureRequest, admin_client: DynamicClient, teardown_resources: bool -) -> Generator[List[ModelRegistry], Any, Any]: +) -> Generator[list[ModelRegistry], Any, Any]: """Create Model Registry instance parametrized""" with ExitStack() as stack: model_registry_instances = [] diff --git a/tests/model_registry/model_registry/rbac/group_utils.py b/tests/model_registry/model_registry/rbac/group_utils.py index 1143c6aa4..4987fe0ae 100644 --- a/tests/model_registry/model_registry/rbac/group_utils.py +++ b/tests/model_registry/model_registry/rbac/group_utils.py @@ -1,8 +1,9 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import Generator -from simple_logger.logger import get_logger + from kubernetes.dynamic import DynamicClient from ocp_resources.group import Group +from simple_logger.logger import get_logger LOGGER = get_logger(name=__name__) @@ -13,7 +14,7 @@ def create_group( group_name: str, users: list[str] | None = None, wait_for_resource: bool = True, -) -> Generator[str, None, None]: +) -> Generator[str]: """ Factory function to create an OpenShift group with optional users. Uses context manager to ensure proper cleanup. 
diff --git a/tests/model_registry/model_registry/rbac/multiple_instance_utils.py b/tests/model_registry/model_registry/rbac/multiple_instance_utils.py index cf0eae96b..438872405 100644 --- a/tests/model_registry/model_registry/rbac/multiple_instance_utils.py +++ b/tests/model_registry/model_registry/rbac/multiple_instance_utils.py @@ -2,22 +2,22 @@ from pytest_testconfig import config as py_config from tests.model_registry.constants import ( - MODEL_REGISTRY_DB_SECRET_STR_DATA, DB_BASE_RESOURCES_NAME, - NUM_MR_INSTANCES, MODEL_REGISTRY_DB_SECRET_ANNOTATIONS, + MODEL_REGISTRY_DB_SECRET_STR_DATA, MR_INSTANCE_BASE_NAME, + NUM_MR_INSTANCES, ) from tests.model_registry.utils import ( + get_external_db_config, get_model_registry_db_label_dict, get_model_registry_deployment_template_dict, get_mr_standard_labels, - get_external_db_config, ) ns_name = py_config["model_registry_namespace"] -resource_names = [f"{DB_BASE_RESOURCES_NAME}{index}" for index in range(0, NUM_MR_INSTANCES)] +resource_names = [f"{DB_BASE_RESOURCES_NAME}{index}" for index in range(NUM_MR_INSTANCES)] db_secret_params = [ { @@ -85,7 +85,7 @@ "wait_for_resource": True, "kube_rbac_proxy": {}, } - for index in range(0, NUM_MR_INSTANCES) + for index in range(NUM_MR_INSTANCES) ] # Add this complete set of parameters as a pytest.param tuple to the list. 
diff --git a/tests/model_registry/model_registry/rbac/test_mr_rbac.py b/tests/model_registry/model_registry/rbac/test_mr_rbac.py index 7ec654451..ac771d4f9 100644 --- a/tests/model_registry/model_registry/rbac/test_mr_rbac.py +++ b/tests/model_registry/model_registry/rbac/test_mr_rbac.py @@ -8,36 +8,37 @@ - Role and RoleBinding management """ -import pytest -from typing import Self, Generator, List -from simple_logger.logger import get_logger +from collections.abc import Generator +from typing import Self +import pytest +from kubernetes.dynamic import DynamicClient from model_registry import ModelRegistry as ModelRegistryClient -from timeout_sampler import TimeoutSampler - +from mr_openapi.exceptions import ForbiddenException from ocp_resources.data_science_cluster import DataScienceCluster +from ocp_resources.deployment import Deployment from ocp_resources.group import Group +from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.role_binding import RoleBinding from ocp_resources.secret import Secret -from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.service import Service -from ocp_resources.deployment import Deployment +from simple_logger.logger import get_logger +from timeout_sampler import TimeoutSampler + +from tests.model_registry.constants import NUM_MR_INSTANCES from tests.model_registry.model_registry.rbac.multiple_instance_utils import MR_MULTIPROJECT_TEST_SCENARIO_PARAMS from tests.model_registry.model_registry.rbac.utils import ( - build_mr_client_args, - assert_positive_mr_registry, assert_forbidden_access, + assert_positive_mr_registry, + build_mr_client_args, + grant_mr_access, + revoke_mr_access, ) -from tests.model_registry.constants import NUM_MR_INSTANCES +from tests.model_registry.utils import get_endpoint_from_mr_service, get_mr_service_by_label, get_mr_user_token +from utilities.constants import Protocols from utilities.infra import get_openshift_token -from 
mr_openapi.exceptions import ForbiddenException -from utilities.user_utils import UserTestSession -from kubernetes.dynamic import DynamicClient from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry -from tests.model_registry.utils import get_mr_service_by_label, get_endpoint_from_mr_service, get_mr_user_token -from tests.model_registry.model_registry.rbac.utils import grant_mr_access, revoke_mr_access -from utilities.constants import Protocols - +from utilities.user_utils import UserTestSession LOGGER = get_logger(name=__name__) pytestmark = [pytest.mark.usefixtures("original_user", "test_idp_user")] @@ -84,7 +85,7 @@ def test_user_added_to_group( test_idp_user: UserTestSession, user_credentials_rbac: dict[str, str], model_registry_group_with_user: Group, - login_as_test_user: Generator[UserTestSession, None, None], + login_as_test_user: Generator[UserTestSession], ): """ This test verifies that: @@ -188,12 +189,12 @@ def test_user_permission_multi_project_parametrized( admin_client: DynamicClient, updated_dsc_component_state_scope_session: DataScienceCluster, model_registry_namespace: str, - db_secret_parametrized: List[Secret], - db_pvc_parametrized: List[PersistentVolumeClaim], - db_service_parametrized: List[Service], - db_deployment_parametrized: List[Deployment], + db_secret_parametrized: list[Secret], + db_pvc_parametrized: list[PersistentVolumeClaim], + db_service_parametrized: list[Service], + db_deployment_parametrized: list[Deployment], user_credentials_rbac: dict[str, str], - model_registry_instance_parametrized: List[ModelRegistry], + model_registry_instance_parametrized: list[ModelRegistry], login_as_test_user: None, ): """ diff --git a/tests/model_registry/model_registry/rbac/test_mr_rbac_sa.py b/tests/model_registry/model_registry/rbac/test_mr_rbac_sa.py index e60d30059..9bb44bda8 100644 --- a/tests/model_registry/model_registry/rbac/test_mr_rbac_sa.py +++ 
b/tests/model_registry/model_registry/rbac/test_mr_rbac_sa.py @@ -1,14 +1,16 @@ # AI Disclaimer: Google Gemini 2.5 pro has been used to generate a majority of this code, with human review and editing. -import pytest from typing import Any, Self -from simple_logger.logger import get_logger + +import pytest from model_registry import ModelRegistry as ModelRegistryClient -from tests.model_registry.model_registry.rbac.utils import build_mr_client_args -from utilities.infra import create_inference_token from mr_openapi.exceptions import ForbiddenException, UnauthorizedException from ocp_resources.service_account import ServiceAccount +from simple_logger.logger import get_logger from timeout_sampler import TimeoutSampler, retry +from tests.model_registry.model_registry.rbac.utils import build_mr_client_args +from utilities.infra import create_inference_token + LOGGER = get_logger(name=__name__) @@ -94,7 +96,7 @@ def test_service_account_access_granted( assert mr_client_success is not None, "Client initialization failed after granting permissions" LOGGER.info("Client instantiated successfully after granting permissions.") except Exception as e: - LOGGER.error(f"Failed to access Model Registry after granting permissions: {e}", exc_info=True) + LOGGER.error(f"Failed to access Model Registry after granting permissions: {e}") raise LOGGER.info("--- RBAC Test Completed Successfully ---") diff --git a/tests/model_registry/model_registry/rbac/utils.py b/tests/model_registry/model_registry/rbac/utils.py index 7672c128e..8104d78cf 100644 --- a/tests/model_registry/model_registry/rbac/utils.py +++ b/tests/model_registry/model_registry/rbac/utils.py @@ -1,17 +1,20 @@ -from typing import Any, Dict, Generator, List +import logging +from collections.abc import Generator +from typing import Any + from kubernetes.dynamic import DynamicClient +from model_registry import ModelRegistry as ModelRegistryClient +from mr_openapi.exceptions import ForbiddenException from ocp_resources.role 
import Role from ocp_resources.role_binding import RoleBinding + from utilities.constants import Protocols -import logging -from model_registry import ModelRegistry as ModelRegistryClient from utilities.infra import get_openshift_token -from mr_openapi.exceptions import ForbiddenException LOGGER = logging.getLogger(__name__) -def build_mr_client_args(rest_endpoint: str, token: str, author: str = "rbac-test") -> Dict[str, Any]: +def build_mr_client_args(rest_endpoint: str, token: str, author: str = "rbac-test") -> dict[str, Any]: """ Builds arguments for ModelRegistryClient based on REST endpoint and token. @@ -71,7 +74,7 @@ def create_role_binding( name: str, subjects_kind: str, subjects_name: str, -) -> Generator[RoleBinding, None, None]: +) -> Generator[RoleBinding]: with RoleBinding( client=admin_client, namespace=model_registry_namespace, @@ -88,7 +91,7 @@ def grant_mr_access( admin_client: DynamicClient, user: str, mr_instance_name: str, model_registry_namespace: str ) -> tuple[Role, RoleBinding]: """Grant a user access to a Model Registry instance.""" - role_rules: List[Dict[str, Any]] = [ + role_rules: list[dict[str, Any]] = [ { "apiGroups": [""], "resources": ["services"], diff --git a/tests/model_registry/model_registry/rest_api/conftest.py b/tests/model_registry/model_registry/rest_api/conftest.py index e377d0caf..94518154c 100644 --- a/tests/model_registry/model_registry/rest_api/conftest.py +++ b/tests/model_registry/model_registry/rest_api/conftest.py @@ -1,43 +1,44 @@ -from typing import Any, Generator +import copy import os -from kubernetes.dynamic import DynamicClient +import tempfile +from collections.abc import Generator +from typing import Any + import pytest -import copy +from kubernetes.dynamic import DynamicClient +from ocp_resources.config_map import ConfigMap +from ocp_resources.deployment import Deployment +from ocp_resources.resource import ResourceEditor +from ocp_resources.secret import Secret +from pytest_testconfig import config as 
py_config +from simple_logger.logger import get_logger -from tests.model_registry.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI, MODEL_REGISTER_DATA +from tests.model_registry.constants import ( + CA_CONFIGMAP_NAME, + CA_FILE_PATH, + CA_MOUNT_PATH, + DB_RESOURCE_NAME, + KUBERBACPROXY_STR, + SECURE_MR_NAME, +) +from tests.model_registry.model_registry.rest_api.constants import MODEL_REGISTER_DATA, MODEL_REGISTRY_BASE_URI from tests.model_registry.model_registry.rest_api.utils import ( - register_model_rest_api, execute_model_registry_patch_command, + generate_ca_and_server_cert, get_mr_deployment, + register_model_rest_api, ) -from utilities.general import generate_random_name, wait_for_pods_running -from ocp_resources.deployment import Deployment from tests.model_registry.utils import ( - get_model_registry_deployment_template_dict, - apply_db_args_and_volume_mounts, add_db_certs_volumes_to_deployment, - get_mr_standard_labels, + apply_db_args_and_volume_mounts, get_external_db_config, + get_model_registry_deployment_template_dict, + get_mr_standard_labels, ) - -from tests.model_registry.constants import ( - DB_RESOURCE_NAME, - CA_MOUNT_PATH, - CA_FILE_PATH, - CA_CONFIGMAP_NAME, - SECURE_MR_NAME, - KUBERBACPROXY_STR, -) -from ocp_resources.resource import ResourceEditor -from ocp_resources.secret import Secret -from ocp_resources.config_map import ConfigMap -from simple_logger.logger import get_logger -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry -from pytest_testconfig import config as py_config +from utilities.certificates_utils import create_ca_bundle_with_router_cert, create_k8s_secret from utilities.exceptions import MissingParameter -import tempfile -from tests.model_registry.model_registry.rest_api.utils import generate_ca_and_server_cert -from utilities.certificates_utils import create_k8s_secret, create_ca_bundle_with_router_cert +from utilities.general import generate_random_name, 
wait_for_pods_running +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry LOGGER = get_logger(name=__name__) @@ -184,7 +185,7 @@ def deploy_secure_db_mr( external_db_template_with_ca: dict[str, Any], patch_external_deployment_with_ssl_ca: Deployment, db_backend_under_test: str, -) -> Generator[ModelRegistry, None, None]: +) -> Generator[ModelRegistry]: """ Deploy a secure database and Model Registry instance. """ @@ -240,7 +241,7 @@ def ca_configmap_for_test( admin_client: DynamicClient, model_registry_namespace: str, external_db_ssl_artifact_paths: dict[str, Any], -) -> Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: """ Creates a test-specific ConfigMap for the CA bundle, using the generated CA cert. @@ -256,7 +257,7 @@ def ca_configmap_for_test( ca_content = f.read() if not ca_content: LOGGER.info("CA content is empty") - raise Exception("CA content is empty") + raise MissingParameter("CA content is empty") cm_name = "db-ca-configmap" with ConfigMap( client=admin_client, @@ -314,7 +315,7 @@ def patch_external_deployment_with_ssl_ca( @pytest.fixture(scope="class") -def external_db_ssl_artifact_paths() -> Generator[dict[str, str], None, None]: +def external_db_ssl_artifact_paths() -> Generator[dict[str, str]]: """ Generates external database SSL certificate and key files in a temporary directory and provides their paths. @@ -329,7 +330,7 @@ def external_db_ssl_secrets( admin_client: DynamicClient, model_registry_namespace: str, external_db_ssl_artifact_paths: dict[str, str], -) -> Generator[dict[str, Secret], None, None]: +) -> Generator[dict[str, Secret]]: """ Creates Kubernetes secrets for external database SSL artifacts. """ @@ -369,7 +370,7 @@ def external_db_ssl_secrets( @pytest.fixture(scope="function") -def model_data_for_test() -> Generator[dict[str, Any], None, None]: +def model_data_for_test() -> Generator[dict[str, Any]]: """ Generates a model data for the test. 
diff --git a/tests/model_registry/model_registry/rest_api/test_model_registry_rest_api.py b/tests/model_registry/model_registry/rest_api/test_model_registry_rest_api.py index 68308261f..6052d0191 100644 --- a/tests/model_registry/model_registry/rest_api/test_model_registry_rest_api.py +++ b/tests/model_registry/model_registry/rest_api/test_model_registry_rest_api.py @@ -1,31 +1,31 @@ -from typing import Self, Any +from typing import Any, Self + import pytest from kubernetes.dynamic import DynamicClient from ocp_resources.deployment import Deployment -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.pod import Pod from ocp_resources.secret import Secret from ocp_resources.service import Service +from simple_logger.logger import get_logger +from tests.model_registry.constants import MR_POSTGRES_DB_OBJECT from tests.model_registry.model_registry.rest_api.constants import ( - MODEL_REGISTER, + CUSTOM_PROPERTY, MODEL_ARTIFACT, - MODEL_VERSION, - MODEL_REGISTER_DATA, MODEL_ARTIFACT_DESCRIPTION, MODEL_FORMAT_NAME, MODEL_FORMAT_VERSION, + MODEL_REGISTER, + MODEL_REGISTER_DATA, + MODEL_VERSION, MODEL_VERSION_DESCRIPTION, + REGISTERED_MODEL_DESCRIPTION, STATE_ARCHIVED, STATE_LIVE, - CUSTOM_PROPERTY, - REGISTERED_MODEL_DESCRIPTION, ) -from tests.model_registry.constants import MR_POSTGRES_DB_OBJECT from tests.model_registry.model_registry.rest_api.utils import validate_resource_attributes -from simple_logger.logger import get_logger - +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry LOGGER = get_logger(name=__name__) CONNECTION_STRING: str = "/var/run/postgresql:5432 - accepting connections" diff --git a/tests/model_registry/model_registry/rest_api/test_model_registry_secure_db.py b/tests/model_registry/model_registry/rest_api/test_model_registry_secure_db.py index f0f99cf54..6b3ebb3aa 100644 --- 
a/tests/model_registry/model_registry/rest_api/test_model_registry_secure_db.py +++ b/tests/model_registry/model_registry/rest_api/test_model_registry_secure_db.py @@ -1,14 +1,14 @@ +from typing import Any, Self + import pytest import requests -from typing import Self, Any -from tests.model_registry.model_registry.rest_api.utils import register_model_rest_api, validate_resource_attributes -from tests.model_registry.utils import get_mr_service_by_label, get_endpoint_from_mr_service from kubernetes.dynamic import DynamicClient -from utilities.constants import Protocols -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry - from simple_logger.logger import get_logger +from tests.model_registry.model_registry.rest_api.utils import register_model_rest_api, validate_resource_attributes +from tests.model_registry.utils import get_endpoint_from_mr_service, get_mr_service_by_label +from utilities.constants import Protocols +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/model_registry/rest_api/test_multiple_mr.py b/tests/model_registry/model_registry/rest_api/test_multiple_mr.py index 83df161e6..1c346daa7 100644 --- a/tests/model_registry/model_registry/rest_api/test_multiple_mr.py +++ b/tests/model_registry/model_registry/rest_api/test_multiple_mr.py @@ -1,25 +1,24 @@ from typing import Self + import pytest from kubernetes.dynamic import DynamicClient - from ocp_resources.config_map import ConfigMap -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry from ocp_resources.pod import Pod +from simple_logger.logger import get_logger from tests.model_registry.constants import ( - MR_INSTANCE_BASE_NAME, - NUM_RESOURCES, DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM, + MR_INSTANCE_BASE_NAME, + NUM_RESOURCES, ) from tests.model_registry.model_registry.rest_api.utils import ( - 
validate_resource_attributes, get_register_model_data, register_model_rest_api, + validate_resource_attributes, ) -from simple_logger.logger import get_logger - from tests.model_registry.utils import get_model_catalog_pod +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry LOGGER = get_logger(name=__name__) @@ -47,7 +46,7 @@ def test_validate_multiple_model_registry( model_registry_instance: list[ModelRegistry], model_registry_namespace: str, ): - for num in range(0, NUM_RESOURCES["num_resources"]): + for num in range(NUM_RESOURCES["num_resources"]): mr = ModelRegistry( client=admin_client, name=f"{MR_INSTANCE_BASE_NAME}{num}", @@ -63,11 +62,12 @@ def test_validate_one_model_catalog_configmap( """ Validate that when multiple MR exists on a cluster, only two model catalog configmaps are created """ - config_map_names: list[str] = [] expected_number_config_maps: int = 2 - for config_map in list(ConfigMap.get(namespace=model_registry_namespace, client=admin_client)): - if config_map.name.startswith(tuple([DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM])): - config_map_names.append(config_map.name) + config_map_names = [ + config_map.name + for config_map in list(ConfigMap.get(namespace=model_registry_namespace, client=admin_client)) + if config_map.name.startswith((DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM)) + ] assert len(config_map_names) == expected_number_config_maps, ( f"Expected {expected_number_config_maps} model catalog sources, found: {config_map_names}" ) @@ -90,7 +90,7 @@ def test_validate_register_models_multiple_registries( self: Self, model_registry_rest_url: list[str], model_registry_rest_headers: dict[str, str] ): data = get_register_model_data(num_models=NUM_RESOURCES["num_resources"]) - for num in range(0, NUM_RESOURCES["num_resources"]): + for num in range(NUM_RESOURCES["num_resources"]): result = register_model_rest_api( model_registry_rest_url=model_registry_rest_url[num], 
model_registry_rest_headers=model_registry_rest_headers, diff --git a/tests/model_registry/model_registry/rest_api/utils.py b/tests/model_registry/model_registry/rest_api/utils.py index c5a6a5ee8..b1193fc7f 100644 --- a/tests/model_registry/model_registry/rest_api/utils.py +++ b/tests/model_registry/model_registry/rest_api/utils.py @@ -1,22 +1,21 @@ import copy -from typing import Any, Dict -import requests import json import os +from typing import Any +import requests from kubernetes.dynamic import DynamicClient +from ocp_resources.deployment import Deployment +from pyhelper_utils.shell import run_command from simple_logger.logger import get_logger -from ocp_resources.deployment import Deployment from tests.model_registry.exceptions import ( ModelRegistryResourceNotCreated, ModelRegistryResourceNotUpdated, ) -from tests.model_registry.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI, MODEL_REGISTER_DATA -from pyhelper_utils.shell import run_command +from tests.model_registry.model_registry.rest_api.constants import MODEL_REGISTER_DATA, MODEL_REGISTRY_BASE_URI from utilities.exceptions import ResourceValueMismatch - LOGGER = get_logger(name=__name__) @@ -107,7 +106,7 @@ def validate_resource_attributes( errors: list[dict[str, list[Any]]] if errors := [ {key: [f"Expected value: {expected_params[key]}, actual value: {actual_resource_data.get(key)}"]} - for key in expected_params.keys() + for key in expected_params if (not actual_resource_data.get(key) or actual_resource_data[key] != expected_params[key]) ]: raise ResourceValueMismatch(f"Resource: {resource_name} has mismatched data: {errors}") @@ -119,7 +118,7 @@ def generate_ca_and_server_cert( db_service_hostname: str = "db-model-registry.rhoai-model-registries.svc.cluster.local", ca_name: str = "Test CA", server_cn: str = "mysql-server", -) -> Dict[str, str]: +) -> dict[str, str]: """ Generates a CA and server certificate/key for the MySQL server. 
@@ -255,9 +254,9 @@ def sign_db_server_cert_with_ca_with_openssl( def get_register_model_data(num_models: int) -> list[dict[str, Any]]: model_data = [] - for num_model in range(0, num_models): + for num_model in range(num_models): copy_data = copy.deepcopy(MODEL_REGISTER_DATA) - for key, value in copy_data.items(): + for value in copy_data.values(): value["name"] = f"{value['name']}{num_model}" value["description"] = f"{value['description']}{num_model}" model_data.append(copy_data) diff --git a/tests/model_registry/model_registry/upgrade/conftest.py b/tests/model_registry/model_registry/upgrade/conftest.py index 8dc492225..69ff5bb2a 100644 --- a/tests/model_registry/model_registry/upgrade/conftest.py +++ b/tests/model_registry/model_registry/upgrade/conftest.py @@ -1,22 +1,24 @@ +from collections.abc import Generator +from typing import Any + import pytest -from typing import Any, Generator -from pytest import FixtureRequest -from model_registry.types import RegisteredModel +from class_generator.parsers.explain_parser import ResourceNotFoundError from kubernetes.dynamic import DynamicClient -from pytest import Config from model_registry import ModelRegistry as ModelRegistryClient -from class_generator.parsers.explain_parser import ResourceNotFoundError -from tests.model_registry.constants import MR_INSTANCE_BASE_NAME, KUBERBACPROXY_STR -from utilities.constants import Protocols -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry +from model_registry.types import RegisteredModel +from pytest import Config, FixtureRequest from simple_logger.logger import get_logger + +from tests.model_registry.constants import KUBERBACPROXY_STR, MR_INSTANCE_BASE_NAME from tests.model_registry.utils import ( - wait_for_default_resource_cleanedup, - get_mr_standard_labels, - get_mr_service_by_label, get_endpoint_from_mr_service, + get_mr_service_by_label, + get_mr_standard_labels, + wait_for_default_resource_cleanedup, ) +from utilities.constants 
import Protocols from utilities.general import wait_for_pods_running +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry LOGGER = get_logger(name=__name__) MR_DEFAULT_DB_NAME: str = f"{MR_INSTANCE_BASE_NAME}1" @@ -29,7 +31,7 @@ def model_registry_instance_default_db( teardown_resources: bool, model_registry_metadata_db_resources: dict[Any, Any], model_registry_namespace: str, -) -> Generator[ModelRegistry, None, None]: +) -> Generator[ModelRegistry]: """ Create model registry instance specifically with default postgres database. """ @@ -106,7 +108,7 @@ def model_registry_client_default_db( @pytest.fixture(scope="class") def registered_model_default_db( request: FixtureRequest, model_registry_client_default_db: ModelRegistryClient -) -> Generator[RegisteredModel, None, None]: +) -> Generator[RegisteredModel]: yield model_registry_client_default_db.register_model( name=request.param.get("model_name"), uri=request.param.get("model_uri"), diff --git a/tests/model_registry/model_registry/upgrade/test_model_registry_upgrade.py b/tests/model_registry/model_registry/upgrade/test_model_registry_upgrade.py index b49efae3c..a2c228fec 100644 --- a/tests/model_registry/model_registry/upgrade/test_model_registry_upgrade.py +++ b/tests/model_registry/model_registry/upgrade/test_model_registry_upgrade.py @@ -1,15 +1,14 @@ -import pytest -from typing import Self, Any - +from typing import Any, Self -from tests.model_registry.constants import MODEL_NAME, MODEL_DICT -from model_registry.types import RegisteredModel +import pytest from model_registry import ModelRegistry as ModelRegistryClient +from model_registry.types import RegisteredModel +from simple_logger.logger import get_logger +from tests.model_registry.constants import MODEL_DICT, MODEL_NAME from tests.model_registry.model_registry.upgrade.utils import validate_upgrade_model_registration from utilities.constants import ModelFormat from 
utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry -from simple_logger.logger import get_logger LOGGER = get_logger(name=__name__) MODEL_NAME_DEFAULT_DB: str = f"{MODEL_NAME}-default-db" diff --git a/tests/model_registry/model_registry/upgrade/utils.py b/tests/model_registry/model_registry/upgrade/utils.py index 3bfad1ed9..c65201716 100644 --- a/tests/model_registry/model_registry/upgrade/utils.py +++ b/tests/model_registry/model_registry/upgrade/utils.py @@ -1,8 +1,9 @@ import pytest -from tests.model_registry.utils import get_and_validate_registered_model from model_registry import ModelRegistry as ModelRegistryClient from model_registry.types import RegisteredModel +from tests.model_registry.utils import get_and_validate_registered_model + def validate_upgrade_model_registration( model_registry_client: ModelRegistryClient, model_name: str, registered_model: RegisteredModel = None diff --git a/tests/model_registry/scc/conftest.py b/tests/model_registry/scc/conftest.py index d6f71584c..1d8e9ea8c 100644 --- a/tests/model_registry/scc/conftest.py +++ b/tests/model_registry/scc/conftest.py @@ -1,15 +1,14 @@ import pytest from _pytest.fixtures import FixtureRequest - +from kubernetes.dynamic import DynamicClient +from ocp_resources.deployment import Deployment from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod -from ocp_resources.deployment import Deployment -from tests.model_registry.scc.utils import get_pod_by_deployment_name -from tests.model_registry.constants import MR_INSTANCE_NAME, MR_POSTGRES_DEPLOYMENT_NAME_STR - -from kubernetes.dynamic import DynamicClient from simple_logger.logger import get_logger +from tests.model_registry.constants import MR_INSTANCE_NAME, MR_POSTGRES_DEPLOYMENT_NAME_STR +from tests.model_registry.scc.utils import get_pod_by_deployment_name + LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/scc/test_model_catalog_scc.py 
b/tests/model_registry/scc/test_model_catalog_scc.py index c935ad4c6..dd05e87e1 100644 --- a/tests/model_registry/scc/test_model_catalog_scc.py +++ b/tests/model_registry/scc/test_model_catalog_scc.py @@ -1,16 +1,15 @@ -import pytest from typing import Self +import pytest +from ocp_resources.deployment import Deployment +from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from ocp_resources.pod import Pod -from ocp_resources.deployment import Deployment from tests.model_registry.scc.utils import ( validate_deployment_scc, validate_pod_scc, ) - LOGGER = get_logger(name=__name__) MODEL_CATALOG_STR = "model-catalog" diff --git a/tests/model_registry/scc/test_model_registry_scc.py b/tests/model_registry/scc/test_model_registry_scc.py index e173a72a2..e04a38c2d 100644 --- a/tests/model_registry/scc/test_model_registry_scc.py +++ b/tests/model_registry/scc/test_model_registry_scc.py @@ -1,15 +1,15 @@ -import pytest from typing import Self +import pytest +from ocp_resources.deployment import Deployment +from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from ocp_resources.pod import Pod -from ocp_resources.deployment import Deployment +from tests.model_registry.constants import MR_INSTANCE_NAME, MR_POSTGRES_DEPLOYMENT_NAME_STR from tests.model_registry.scc.utils import ( validate_deployment_scc, validate_pod_scc, ) -from tests.model_registry.constants import MR_INSTANCE_NAME, MR_POSTGRES_DEPLOYMENT_NAME_STR LOGGER = get_logger(name=__name__) diff --git a/tests/model_registry/scc/utils.py b/tests/model_registry/scc/utils.py index 89d3793d5..0e9da312d 100644 --- a/tests/model_registry/scc/utils.py +++ b/tests/model_registry/scc/utils.py @@ -1,11 +1,10 @@ from typing import Any -from simple_logger.logger import get_logger -from ocp_resources.pod import Pod +from kubernetes.dynamic import DynamicClient from ocp_resources.deployment import Deployment +from ocp_resources.pod import Pod from ocp_resources.resource import 
NamespacedResource -from kubernetes.dynamic import DynamicClient - +from simple_logger.logger import get_logger KEYS_TO_VALIDATE = ["runAsGroup", "runAsUser", "allowPrivilegeEscalation", "capabilities"] @@ -130,7 +129,7 @@ def validate_deployment_scc(deployment: Deployment) -> None: if not container_security_context: LOGGER.info(f"No container security context exists for {container.name}") else: - if not all([True for key in ["runAsGroup", "runAsUser"] if not container_security_context.get(key)]): + if not all(True for key in ["runAsGroup", "runAsUser"] if not container_security_context.get(key)): error.append({container.name: container.securityContext}) if error: diff --git a/tests/model_registry/utils.py b/tests/model_registry/utils.py index 31a3e3ce3..efa0c9722 100644 --- a/tests/model_registry/utils.py +++ b/tests/model_registry/utils.py @@ -1,38 +1,37 @@ import base64 import json import time -from typing import Any, List, Dict +from typing import Any import requests from kubernetes.dynamic import DynamicClient - +from kubernetes.dynamic.exceptions import ResourceNotFoundError +from model_registry import ModelRegistry as ModelRegistryClient +from model_registry.types import RegisteredModel from ocp_resources.config_map import ConfigMap from ocp_resources.deployment import Deployment from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.pod import Pod from ocp_resources.secret import Secret from ocp_resources.service import Service -from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry -from kubernetes.dynamic.exceptions import ResourceNotFoundError from simple_logger.logger import get_logger from timeout_sampler import retry + from tests.model_registry.constants import ( - MR_DB_IMAGE_DIGEST, - MODEL_REGISTRY_DB_SECRET_STR_DATA, - MODEL_REGISTRY_DB_SECRET_ANNOTATIONS, DB_BASE_RESOURCES_NAME, MARIADB_MY_CNF, - PORT_MAP, + MODEL_REGISTRY_DB_SECRET_ANNOTATIONS, + 
MODEL_REGISTRY_DB_SECRET_STR_DATA, MODEL_REGISTRY_POD_FILTER, + MR_DB_IMAGE_DIGEST, MR_POSTGRES_DB_OBJECT, + PORT_MAP, ) from tests.model_registry.exceptions import ModelRegistryResourceNotFoundError +from utilities.constants import Annotations, PodNotFound, Protocols, Timeout from utilities.exceptions import ProtocolNotSupportedError, TooManyServicesError -from utilities.constants import Protocols, Annotations, Timeout, PodNotFound -from model_registry import ModelRegistry as ModelRegistryClient -from model_registry.types import RegisteredModel - from utilities.general import wait_for_pods_running +from utilities.resources.model_registry_modelregistry_opendatahub_io import ModelRegistry from utilities.user_utils import get_byoidc_issuer_url ADDRESS_ANNOTATION_PREFIX: str = "routing.opendatahub.io/external-address-" @@ -48,8 +47,6 @@ class TransientUnauthorizedError(Exception): """Exception for transient 401 Unauthorized errors that should be retried.""" - pass - def get_mr_service_by_label(client: DynamicClient, namespace_name: str, mr_instance: ModelRegistry) -> Service: """ @@ -339,9 +336,8 @@ def wait_for_new_running_mr_pod( label_selector=MODEL_REGISTRY_POD_FILTER, ) ) - if pods and len(pods) == 1: - if pods[0].name != orig_pod_name and pods[0].status == Pod.Status.RUNNING: - return pods[0] + if pods and len(pods) == 1 and pods[0].name != orig_pod_name and pods[0].status == Pod.Status.RUNNING: + return pods[0] raise TimeoutError(f"Timeout waiting for pod {orig_pod_name} to be replaced") @@ -471,7 +467,7 @@ def get_and_validate_registered_model( model_registry_client: ModelRegistryClient, model_name: str, registered_model: RegisteredModel = None, -) -> List[str]: +) -> list[str]: """ Get and validate a registered model. 
""" @@ -488,12 +484,11 @@ def get_and_validate_registered_model( expected_attrs = { "name": model_name, } - errors = [ + return [ f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}" for attr, expected in expected_attrs.items() if getattr(model, attr) != expected ] - return errors def execute_model_registry_get_command(url: str, headers: dict[str, str], json_output: bool = True) -> dict[Any, Any]: @@ -536,7 +531,7 @@ def get_mr_service_objects( service_port_name = db_backend if db_backend == "postgres" else "mysql" service_uri = rf"{service_port_name}://{{.spec.clusterIP}}:{{.spec.ports[?(.name==\{service_port_name}\)].port}}" annotation = {"template.openshift.io/expose-uri": service_uri} - for num_service in range(0, num): + for num_service in range(num): name = f"{base_name}{num_service}" services.append( Service( @@ -574,7 +569,7 @@ def get_mr_configmap_objects( ) -> list[Service]: config_maps = [] if db_backend == "mariadb": - for num_config_map in range(0, num): + for num_config_map in range(num): name = f"{base_name}{num_config_map}" config_maps.append( ConfigMap( @@ -593,7 +588,7 @@ def get_mr_pvc_objects( base_name: str, namespace: str, client: DynamicClient, teardown_resources: bool, num: int ) -> list[PersistentVolumeClaim]: pvcs = [] - for num_pvc in range(0, num): + for num_pvc in range(num): name = f"{base_name}{num_pvc}" pvcs.append( PersistentVolumeClaim( @@ -613,7 +608,7 @@ def get_mr_secret_objects( base_name: str, namespace: str, client: DynamicClient, teardown_resources: bool, num: int ) -> list[Secret]: secrets = [] - for num_secret in range(0, num): + for num_secret in range(num): name = f"{base_name}{num_secret}" secrets.append( Secret( @@ -639,7 +634,7 @@ def get_mr_deployment_objects( ) -> list[Deployment]: deployments = [] - for num_deployment in range(0, num): + for num_deployment in range(num): name = f"{base_name}{num_deployment}" selectors = {"matchLabels": {"name": name}} if db_backend == "mariadb": @@ -688,7 
+683,7 @@ def get_model_registry_objects( db_backend: str, ) -> list[Any]: model_registry_objects = [] - for num_mr in range(0, num): + for num_mr in range(num): name = f"{base_name}{num_mr}" db_value = None @@ -796,9 +791,8 @@ def validate_mlmd_removal_in_model_registry_pod_log( container_name = container["name"] LOGGER.info(f"Checking {container_name}") log = pod_object.log(container=container_name) - if "rest" in container_name: - if embedmd_message not in log: - errors.append(f"Missing {embedmd_message} in {container_name} log") + if "rest" in container_name and embedmd_message not in log: + errors.append(f"Missing {embedmd_message} in {container_name} log") if "MLMD" in log: errors.append(f"MLMD reference found in {container_name} log") assert not errors, f"Log validation failed with error(s): {errors}" @@ -942,8 +936,8 @@ def get_model_str(model: str) -> str: libraryName: transformers artifacts: - uri: https://huggingface.co/{model}/resolve/main/consolidated.safetensors - createTimeSinceEpoch: \"{str(current_time - 10000)}\" - lastUpdateTimeSinceEpoch: \"{str(current_time)}\" + createTimeSinceEpoch: \"{current_time - 10000!s}\" + lastUpdateTimeSinceEpoch: \"{current_time!s}\" """ @@ -976,29 +970,26 @@ def get_mr_user_token(admin_client: DynamicClient, user_credentials_rbac: dict[s "scope": "openid", } - try: - LOGGER.info(f"Requesting token for user {user_credentials_rbac['username']} in byoidc environment") - response = requests.post( - url=url, - headers=headers, - data=data, - allow_redirects=True, - timeout=30, - verify=True, # Set to False if you need to skip SSL verification - ) - response.raise_for_status() - json_response = response.json() + LOGGER.info(f"Requesting token for user {user_credentials_rbac['username']} in byoidc environment") + response = requests.post( + url=url, + headers=headers, + data=data, + allow_redirects=True, + timeout=30, + verify=True, # Set to False if you need to skip SSL verification + ) + response.raise_for_status() + 
json_response = response.json() - # Validate that we got an access token - if "id_token" not in json_response: - LOGGER.error("Warning: No id_token in response") - raise AssertionError(f"No id_token in response: {json_response}") - return json_response["id_token"] - except Exception as e: - raise e + # Validate that we got an access token + if "id_token" not in json_response: + LOGGER.error("Warning: No id_token in response") + raise AssertionError(f"No id_token in response: {json_response}") + return json_response["id_token"] -def get_byoidc_user_credentials(client: DynamicClient, username: str = None) -> Dict[str, str]: +def get_byoidc_user_credentials(client: DynamicClient, username: str | None = None) -> dict[str, str]: """ Get user credentials from byoidc-credentials secret. diff --git a/tests/model_serving/conftest.py b/tests/model_serving/conftest.py index a3ae9d570..c582ef164 100644 --- a/tests/model_serving/conftest.py +++ b/tests/model_serving/conftest.py @@ -1,11 +1,12 @@ -from typing import Generator, Any +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any import pytest from _pytest.fixtures import FixtureRequest from kubernetes.dynamic import DynamicClient from ocp_resources.namespace import Namespace from ocp_resources.secret import Secret -from contextlib import contextmanager @pytest.fixture(scope="session") diff --git a/tests/model_serving/model_runtime/mlserver/basic_model_deployment/test_mlserver_basic_model_deployment.py b/tests/model_serving/model_runtime/mlserver/basic_model_deployment/test_mlserver_basic_model_deployment.py index ae4a23aa5..be1fa14c4 100644 --- a/tests/model_serving/model_runtime/mlserver/basic_model_deployment/test_mlserver_basic_model_deployment.py +++ b/tests/model_serving/model_runtime/mlserver/basic_model_deployment/test_mlserver_basic_model_deployment.py @@ -8,24 +8,21 @@ from typing import Any import pytest - from ocp_resources.inference_service import InferenceService 
from ocp_resources.pod import Pod -from utilities.constants import ModelFormat, Protocols - from tests.model_serving.model_runtime.mlserver.constant import ( MODEL_CONFIGS, RAW_DEPLOYMENT_TYPE, ) - from tests.model_serving.model_runtime.mlserver.utils import ( - validate_inference_request, - get_model_storage_uri_dict, - get_model_namespace_dict, get_deployment_config_dict, + get_model_namespace_dict, + get_model_storage_uri_dict, get_test_case_id, + validate_inference_request, ) +from utilities.constants import ModelFormat, Protocols pytestmark = pytest.mark.usefixtures("valid_aws_config") diff --git a/tests/model_serving/model_runtime/mlserver/conftest.py b/tests/model_serving/model_runtime/mlserver/conftest.py index b982369b2..3f73c2d6c 100644 --- a/tests/model_serving/model_runtime/mlserver/conftest.py +++ b/tests/model_serving/model_runtime/mlserver/conftest.py @@ -8,24 +8,23 @@ - Providing test utilities like snapshots and pod resources """ -from typing import Any, Generator, cast import copy +from collections.abc import Generator +from typing import Any, cast import pytest -from syrupy.extensions.json import JSONSnapshotExtension - from kubernetes.dynamic import DynamicClient -from ocp_resources.namespace import Namespace -from ocp_resources.serving_runtime import ServingRuntime from ocp_resources.inference_service import InferenceService +from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod from ocp_resources.secret import Secret from ocp_resources.service_account import ServiceAccount +from ocp_resources.serving_runtime import ServingRuntime +from syrupy.extensions.json import JSONSnapshotExtension from tests.model_serving.model_runtime.mlserver.constant import ( PREDICT_RESOURCES, ) - from utilities.constants import ( KServeDeploymentType, Labels, @@ -43,7 +42,7 @@ def mlserver_serving_runtime( admin_client: DynamicClient, model_namespace: Namespace, mlserver_runtime_image: str, -) -> Generator[ServingRuntime, None, None]: +) -> 
Generator[ServingRuntime]: """ Provides a ServingRuntime resource for MLServer using the pre-installed template. @@ -75,7 +74,7 @@ def mlserver_inference_service( mlserver_serving_runtime: ServingRuntime, s3_models_storage_uri: str, mlserver_model_service_account: ServiceAccount, -) -> Generator[InferenceService, None, None]: +) -> Generator[InferenceService]: """ Creates and yields a configured InferenceService instance for MLServer testing. @@ -127,9 +126,7 @@ def mlserver_inference_service( @pytest.fixture(scope="class") -def mlserver_model_service_account( - admin_client: DynamicClient, kserve_s3_secret: Secret -) -> Generator[ServiceAccount, None, None]: +def mlserver_model_service_account(admin_client: DynamicClient, kserve_s3_secret: Secret) -> Generator[ServiceAccount]: """ Creates and yields a ServiceAccount linked to the provided S3 secret for MLServer models. diff --git a/tests/model_serving/model_runtime/mlserver/constant.py b/tests/model_serving/model_runtime/mlserver/constant.py index 6987fe5fa..13bd0f317 100644 --- a/tests/model_serving/model_runtime/mlserver/constant.py +++ b/tests/model_serving/model_runtime/mlserver/constant.py @@ -5,7 +5,7 @@ and input queries used across MLServer runtime tests for different frameworks. 
""" -from typing import Any, Union +from typing import Any from utilities.constants import KServeDeploymentType, ModelFormat @@ -21,7 +21,7 @@ class OutputType: RAW_DEPLOYMENT_TYPE: str = "raw" MODEL_PATH_PREFIX: str = "mlserver/model_repository" -PREDICT_RESOURCES: dict[str, Union[list[dict[str, Union[str, dict[str, str]]]], dict[str, dict[str, str]]]] = { +PREDICT_RESOURCES: dict[str, list[dict[str, str | dict[str, str]]] | dict[str, dict[str, str]]] = { "volumes": [ {"name": "shared-memory", "emptyDir": {"medium": "Memory", "sizeLimit": "2Gi"}}, {"name": "tmp", "emptyDir": {}}, diff --git a/tests/model_serving/model_runtime/mlserver/utils.py b/tests/model_serving/model_runtime/mlserver/utils.py index dc704e08f..927d2941c 100644 --- a/tests/model_serving/model_runtime/mlserver/utils.py +++ b/tests/model_serving/model_runtime/mlserver/utils.py @@ -18,8 +18,8 @@ BASE_RAW_DEPLOYMENT_CONFIG, LOCALHOST_URL, MODEL_PATH_PREFIX, - OutputType, RAW_DEPLOYMENT_TYPE, + OutputType, ) from utilities.constants import KServeDeploymentType, Ports, Protocols diff --git a/tests/model_serving/model_runtime/model_validation/conftest.py b/tests/model_serving/model_runtime/model_validation/conftest.py index 76e4f55bb..36b269db7 100644 --- a/tests/model_serving/model_runtime/model_validation/conftest.py +++ b/tests/model_serving/model_runtime/model_validation/conftest.py @@ -1,5 +1,6 @@ import json -from typing import Any, Generator, List +from collections.abc import Generator +from typing import Any import pytest import yaml @@ -10,29 +11,24 @@ from ocp_resources.secret import Secret from ocp_resources.serving_runtime import ServingRuntime from pytest import FixtureRequest -from utilities.infra import get_pods_by_isvc_label +from simple_logger.logger import get_logger from tests.model_serving.model_runtime.model_validation.constant import ( ACCELERATOR_IDENTIFIER, - TEMPLATE_MAP, + BASE_RAW_DEPLOYMENT_CONFIG, PREDICT_RESOURCES, PULL_SECRET_ACCESS_TYPE, -) -from 
tests.model_serving.model_runtime.model_validation.constant import ( - BASE_RAW_DEPLOYMENT_CONFIG, -) -from tests.model_serving.model_runtime.model_validation.constant import PULL_SECRET_NAME -from tests.model_serving.model_runtime.model_validation.constant import ( + PULL_SECRET_NAME, + TEMPLATE_MAP, TIMEOUT_20MIN, ) from tests.model_serving.model_runtime.model_validation.utils import safe_k8s_name from tests.model_serving.model_runtime.vllm.utils import validate_supported_quantization_schema from utilities.constants import KServeDeploymentType, Labels, RuntimeTemplates from utilities.inference_utils import create_isvc +from utilities.infra import get_pods_by_isvc_label from utilities.serving_runtime import ServingRuntimeFromTemplate -from simple_logger.logger import get_logger - LOGGER = get_logger(name=__name__) @@ -43,7 +39,7 @@ def model_car_serving_runtime( model_namespace: Namespace, supported_accelerator_type: str, vllm_runtime_image: str, -) -> Generator[ServingRuntime, None, None]: +) -> Generator[ServingRuntime]: accelerator_type = supported_accelerator_type.lower() template_name = TEMPLATE_MAP.get(accelerator_type, RuntimeTemplates.VLLM_CUDA) @@ -106,11 +102,7 @@ def vllm_model_car_inference_service( isvc_kwargs["volumes_mounts"] = PREDICT_RESOURCES["volume_mounts"] if arguments := deployment_config.get("runtime_argument"): - arguments = [ - arg - for arg in arguments - if not (arg.startswith("--tensor-parallel-size") or arg.startswith("--quantization")) - ] + arguments = [arg for arg in arguments if not arg.startswith(("--tensor-parallel-size", "--quantization"))] arguments.append(f"--tensor-parallel-size={gpu_count}") if quantization := request.param.get("quantization"): validate_supported_quantization_schema(q_type=quantization) @@ -195,7 +187,7 @@ def build_raw_params( return param, test_id -def build_pytest_markers(deployment_type: str, execution_mode: str) -> List[Any]: +def build_pytest_markers(deployment_type: str, execution_mode: str) -> 
list[Any]: """ Build a list of pytest markers based on deployment type, execution mode. @@ -206,7 +198,7 @@ def build_pytest_markers(deployment_type: str, execution_mode: str) -> List[Any] Returns: List[Any]: List of pytest.mark objects to attach to the test """ - markers: List[pytest.MarkDecorator] = [] + markers: list[pytest.MarkDecorator] = [] if deployment_type == KServeDeploymentType.RAW_DEPLOYMENT: markers.append(pytest.mark.rawdeployment) @@ -233,7 +225,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: default_serving_config = yaml_config.get("default", {}) if not isinstance(model_car_data, list): - raise ValueError("Invalid format for `model-car` in YAML. Expected a list of objects.") + raise TypeError("Invalid format for `model-car` in YAML. Expected a list of objects.") if not metafunc.cls: return diff --git a/tests/model_serving/model_runtime/model_validation/constant.py b/tests/model_serving/model_runtime/model_validation/constant.py index 01c9858a7..f42c20456 100644 --- a/tests/model_serving/model_runtime/model_validation/constant.py +++ b/tests/model_serving/model_runtime/model_validation/constant.py @@ -1,6 +1,6 @@ -from typing import Union, Any -from utilities.constants import AcceleratorType, Labels, RuntimeTemplates -from utilities.constants import KServeDeploymentType +from typing import Any + +from utilities.constants import AcceleratorType, KServeDeploymentType, Labels, RuntimeTemplates # Configurations ACCELERATOR_IDENTIFIER: dict[str, str] = { @@ -20,7 +20,7 @@ } -PREDICT_RESOURCES: dict[str, Union[list[dict[str, Union[str, dict[str, str]]]], dict[str, dict[str, str]]]] = { +PREDICT_RESOURCES: dict[str, list[dict[str, str | dict[str, str]]] | dict[str, dict[str, str]]] = { "volumes": [ {"name": "shared-memory", "emptyDir": {"medium": "Memory", "sizeLimit": "16Gi"}}, {"name": "tmp", "emptyDir": {}}, @@ -50,7 +50,7 @@ {"text": "Explain the significance of the Great Wall of China in history and its impact on modern tourism."}, 
{"text": "Discuss the ethical implications of using artificial intelligence in healthcare decision-making."}, { - "text": "Summarize the main events of the Apollo 11 moon landing and its importance in space exploration history." # noqa: E122, E501 + "text": "Summarize the main events of the Apollo 11 moon landing and its importance in space exploration history." # noqa: E501 }, ] @@ -79,7 +79,7 @@ {"text": "Explain the significance of the Great Wall of China in history and its impact on modern tourism."}, {"text": "Discuss the ethical implications of using artificial intelligence in healthcare decision-making."}, { - "text": "Summarize the main events of the Apollo 11 moon landing and its importance in space exploration history." # noqa: E122, E501 + "text": "Summarize the main events of the Apollo 11 moon landing and its importance in space exploration history." # noqa: E501 }, ] diff --git a/tests/model_serving/model_runtime/model_validation/test_modelvalidation.py b/tests/model_serving/model_runtime/model_validation/test_modelvalidation.py index a449f93cc..b13365070 100644 --- a/tests/model_serving/model_runtime/model_validation/test_modelvalidation.py +++ b/tests/model_serving/model_runtime/model_validation/test_modelvalidation.py @@ -1,12 +1,14 @@ from typing import Any + import pytest -from simple_logger.logger import get_logger from ocp_resources.inference_service import InferenceService +from ocp_resources.pod import Pod +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.model_validation.constant import COMPLETION_QUERY from tests.model_serving.model_runtime.utils import ( validate_raw_openai_inference_request, ) -from tests.model_serving.model_runtime.model_validation.constant import COMPLETION_QUERY -from ocp_resources.pod import Pod LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/model_validation/utils.py b/tests/model_serving/model_runtime/model_validation/utils.py index 
3840494a0..c2c2fa92b 100644 --- a/tests/model_serving/model_runtime/model_validation/utils.py +++ b/tests/model_serving/model_runtime/model_validation/utils.py @@ -1,5 +1,6 @@ import re from typing import Any + from tests.model_serving.model_runtime.vllm.constant import VLLM_SUPPORTED_QUANTIZATION diff --git a/tests/model_serving/model_runtime/openvino/conftest.py b/tests/model_serving/model_runtime/openvino/conftest.py index d765e665c..fcb1c9f0c 100644 --- a/tests/model_serving/model_runtime/openvino/conftest.py +++ b/tests/model_serving/model_runtime/openvino/conftest.py @@ -8,23 +8,23 @@ - Providing test utilities like snapshots and pod resources """ -from typing import cast, Any, Generator import copy +from collections.abc import Generator +from typing import Any, cast import pytest -from syrupy.extensions.json import JSONSnapshotExtension - from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError -from ocp_resources.namespace import Namespace -from ocp_resources.serving_runtime import ServingRuntime from ocp_resources.inference_service import InferenceService +from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod from ocp_resources.secret import Secret from ocp_resources.service_account import ServiceAccount +from ocp_resources.serving_runtime import ServingRuntime +from simple_logger.logger import get_logger +from syrupy.extensions.json import JSONSnapshotExtension from tests.model_serving.model_runtime.openvino.constant import PREDICT_RESOURCES - from utilities.constants import ( KServeDeploymentType, Labels, @@ -34,9 +34,6 @@ from utilities.infra import get_pods_by_isvc_label from utilities.serving_runtime import ServingRuntimeFromTemplate -from simple_logger.logger import get_logger - - LOGGER = get_logger(name=__name__) @@ -45,7 +42,7 @@ def openvino_serving_runtime( request: pytest.FixtureRequest, admin_client: DynamicClient, model_namespace: Namespace, -) -> 
Generator[ServingRuntime, None, None]: +) -> Generator[ServingRuntime]: """ Provides a ServingRuntime resource for OpenVINO with the specified protocol and deployment type. diff --git a/tests/model_serving/model_runtime/openvino/constant.py b/tests/model_serving/model_runtime/openvino/constant.py index 877577c12..eef794272 100644 --- a/tests/model_serving/model_runtime/openvino/constant.py +++ b/tests/model_serving/model_runtime/openvino/constant.py @@ -5,14 +5,13 @@ and input queries used across OpenVINO runtime tests for different frameworks. """ -from typing import Any, Union - from pathlib import Path +from typing import Any from utilities.constants import ( KServeDeploymentType, - Protocols, ModelFormat, + Protocols, ) MODEL_PATH_PREFIX: str = "openvino/model_repository" @@ -25,7 +24,7 @@ REST_PROTOCOL_TYPE_DICT: dict[str, str] = {"protocol_type": Protocols.REST} -PREDICT_RESOURCES: dict[str, Union[list[dict[str, Union[str, dict[str, str]]]], dict[str, dict[str, str]]]] = { +PREDICT_RESOURCES: dict[str, list[dict[str, str | dict[str, str]]] | dict[str, dict[str, str]]] = { "volumes": [ {"name": "shared-memory", "emptyDir": {"medium": "Memory", "sizeLimit": "2Gi"}}, {"name": "tmp", "emptyDir": {}}, diff --git a/tests/model_serving/model_runtime/openvino/test_ovms_model_deployment.py b/tests/model_serving/model_runtime/openvino/test_ovms_model_deployment.py index 756da586e..71c9bf74a 100644 --- a/tests/model_serving/model_runtime/openvino/test_ovms_model_deployment.py +++ b/tests/model_serving/model_runtime/openvino/test_ovms_model_deployment.py @@ -8,27 +8,24 @@ from typing import Any import pytest -from simple_logger.logger import get_logger - from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod - -from utilities.constants import Protocols, ModelFormat +from simple_logger.logger import get_logger from tests.model_serving.model_runtime.openvino.constant import ( MODEL_CONFIGS, RAW_DEPLOYMENT_TYPE, 
REST_PROTOCOL_TYPE_DICT, ) - from tests.model_serving.model_runtime.openvino.utils import ( - validate_inference_request, - get_model_storage_uri_dict, - get_model_namespace_dict, get_deployment_config_dict, - get_test_case_id, get_input_query, + get_model_namespace_dict, + get_model_storage_uri_dict, + get_test_case_id, + validate_inference_request, ) +from utilities.constants import ModelFormat, Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/openvino/utils.py b/tests/model_serving/model_runtime/openvino/utils.py index 3acd2b22d..592b07974 100644 --- a/tests/model_serving/model_runtime/openvino/utils.py +++ b/tests/model_serving/model_runtime/openvino/utils.py @@ -8,20 +8,20 @@ - Validating responses against snapshots """ -from typing import Any, Dict - -import os import json +import os +from typing import Any + import portforward import requests from ocp_resources.inference_service import InferenceService from tests.model_serving.model_runtime.openvino.constant import ( - OPENVINO_REST_PORT, + BASE_RAW_DEPLOYMENT_CONFIG, LOCAL_HOST_URL, MODEL_PATH_PREFIX, + OPENVINO_REST_PORT, RAW_DEPLOYMENT_TYPE, - BASE_RAW_DEPLOYMENT_CONFIG, ) from utilities.constants import KServeDeploymentType @@ -198,7 +198,7 @@ def get_test_case_id(model_format_name: str, deployment_type: str, protocol_type return f"{model_format_name.strip()}-{deployment_type.strip()}-{protocol_type.strip()}-deployment" -def get_input_query(model_format_config: Dict[str, Any], protocol: str) -> Dict[str, Any]: +def get_input_query(model_format_config: dict[str, Any], protocol: str) -> dict[str, Any]: """ Get the input query for the given protocol from the model config. 
diff --git a/tests/model_serving/model_runtime/rhoai_upgrade/conftest.py b/tests/model_serving/model_runtime/rhoai_upgrade/conftest.py index 9cfe13bb4..35f8b9648 100644 --- a/tests/model_serving/model_runtime/rhoai_upgrade/conftest.py +++ b/tests/model_serving/model_runtime/rhoai_upgrade/conftest.py @@ -1,17 +1,16 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient -from pytest_testconfig import config as py_config - from ocp_resources.template import Template +from pytest_testconfig import config as py_config from tests.model_serving.model_runtime.rhoai_upgrade.constant import ( OVMS_SERVING_RUNTIME_TEMPLATE_DICT, - SERVING_RUNTIME_TEMPLATE_NAME, SERVING_RUNTIME_INSTANCE_NAME, + SERVING_RUNTIME_TEMPLATE_NAME, ) - from utilities.serving_runtime import ServingRuntimeFromTemplate diff --git a/tests/model_serving/model_runtime/rhoai_upgrade/constant.py b/tests/model_serving/model_runtime/rhoai_upgrade/constant.py index bf3b5ce78..20e0878c1 100644 --- a/tests/model_serving/model_runtime/rhoai_upgrade/constant.py +++ b/tests/model_serving/model_runtime/rhoai_upgrade/constant.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List, Union +from typing import Any SERVING_RUNTIME_TEMPLATE_NAME: str = "kserve-ovms-serving-runtime-template" SERVING_RUNTIME_INSTANCE_NAME: str = "kserve-ovms-serving-runtime-instance" @@ -7,34 +7,34 @@ "quay.io/modh/openvino_model_server@sha256:53b7fcf95de9b81e4c8652d0bf4e84e22d5b696827a5d951d863420c68b9cfe8" ) -OVMS_TEMPLATE_LABELS: Dict[str, str] = { +OVMS_TEMPLATE_LABELS: dict[str, str] = { "opendatahub.io/dashboard": "true", "opendatahub.io/ootb": "true", } -OVMS_TEMPLATE_ANNOTATIONS: Dict[str, str] = { +OVMS_TEMPLATE_ANNOTATIONS: dict[str, str] = { "tags": "kserve-ovms,servingruntime", "description": "OpenVino Model Serving Definition", "opendatahub.io/modelServingSupport": '["single"]', "opendatahub.io/apiProtocol": "REST", } 
-OVMS_RUNTIME_LABELS: Dict[str, str] = { +OVMS_RUNTIME_LABELS: dict[str, str] = { "opendatahub.io/dashboard": "true", } -OVMS_RUNTIME_ANNOTATIONS: Dict[str, str] = { +OVMS_RUNTIME_ANNOTATIONS: dict[str, str] = { "openshift.io/display-name": "OpenVINO Model Server", "opendatahub.io/recommended-accelerators": '["nvidia.com/gpu"]', "opendatahub.io/runtime-version": "v2025.1", } -OVMS_RUNTIME_PROMETHEUS_ANNOTATIONS: Dict[str, str] = { +OVMS_RUNTIME_PROMETHEUS_ANNOTATIONS: dict[str, str] = { "prometheus.io/port": "8888", "prometheus.io/path": "/metrics", } -OVMS_SUPPORTED_MODEL_FORMATS: List[Dict[str, Union[str, bool]]] = [ +OVMS_SUPPORTED_MODEL_FORMATS: list[dict[str, str | bool]] = [ {"name": "openvino_ir", "version": "opset13", "autoSelect": True}, {"name": "onnx", "version": "1"}, {"name": "tensorflow", "version": "1", "autoSelect": True}, @@ -43,7 +43,7 @@ {"name": "pytorch", "version": "2", "autoSelect": True}, ] -OVMS_CONTAINER_ARGS: List[str] = [ +OVMS_CONTAINER_ARGS: list[str] = [ "--model_name={{.Name}}", "--port=8001", "--rest_port=8888", @@ -55,7 +55,7 @@ "--metrics_enable", ] -OVMS_SERVING_RUNTIME_TEMPLATE_DICT: Dict[str, Any] = { +OVMS_SERVING_RUNTIME_TEMPLATE_DICT: dict[str, Any] = { "metadata": { "name": SERVING_RUNTIME_TEMPLATE_NAME, "labels": OVMS_TEMPLATE_LABELS, diff --git a/tests/model_serving/model_runtime/rhoai_upgrade/test_upgrade.py b/tests/model_serving/model_runtime/rhoai_upgrade/test_upgrade.py index 117190bda..dd0d0819e 100644 --- a/tests/model_serving/model_runtime/rhoai_upgrade/test_upgrade.py +++ b/tests/model_serving/model_runtime/rhoai_upgrade/test_upgrade.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from simple_logger.logger import get_logger diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/conftest.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/conftest.py index e5d81fc73..b8b3ddeab 100644 --- 
a/tests/model_serving/model_runtime/triton/basic_model_deployment/conftest.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/conftest.py @@ -1,30 +1,30 @@ -from typing import cast, Any, Generator, List, Dict import copy -import pytest +from collections.abc import Generator from contextlib import contextmanager +from typing import Any, cast -from kubernetes.dynamic.exceptions import ResourceNotFoundError -from syrupy.extensions.json import JSONSnapshotExtension -from pytest_testconfig import config as py_config - +import pytest from kubernetes.dynamic import DynamicClient -from ocp_resources.namespace import Namespace -from ocp_resources.serving_runtime import ServingRuntime +from kubernetes.dynamic.exceptions import ResourceNotFoundError from ocp_resources.inference_service import InferenceService +from ocp_resources.namespace import Namespace from ocp_resources.pod import Pod from ocp_resources.secret import Secret -from ocp_resources.template import Template from ocp_resources.service_account import ServiceAccount +from ocp_resources.serving_runtime import ServingRuntime +from ocp_resources.template import Template +from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger +from syrupy.extensions.json import JSONSnapshotExtension +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import ( + get_gpu_identifier, + get_template_name, +) from tests.model_serving.model_runtime.triton.constant import ( PREDICT_RESOURCES, RUNTIME_MAP, ) -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import ( - get_template_name, - get_gpu_identifier, -) - from utilities.constants import ( KServeDeploymentType, Protocols, @@ -33,15 +33,11 @@ from utilities.infra import get_pods_by_isvc_label from utilities.serving_runtime import ServingRuntimeFromTemplate -from simple_logger.logger import get_logger - LOGGER = get_logger(name=__name__) @pytest.fixture(scope="class") -def 
triton_grpc_serving_runtime_template( - admin_client: DynamicClient, triton_runtime_image: str -) -> Generator[Template, None, None]: +def triton_grpc_serving_runtime_template(admin_client: DynamicClient, triton_runtime_image: str) -> Generator[Template]: with create_triton_template( admin_client=admin_client, protocol=Protocols.GRPC, triton_runtime_image=triton_runtime_image ) as template: @@ -49,9 +45,7 @@ def triton_grpc_serving_runtime_template( @pytest.fixture(scope="class") -def triton_rest_serving_runtime_template( - admin_client: DynamicClient, triton_runtime_image: str -) -> Generator[Template, None, None]: +def triton_rest_serving_runtime_template(admin_client: DynamicClient, triton_runtime_image: str) -> Generator[Template]: with create_triton_template( admin_client=admin_client, protocol=Protocols.REST, triton_runtime_image=triton_runtime_image ) as template: @@ -101,7 +95,7 @@ def create_triton_serving_runtime(protocol: str, triton_runtime_image: str) -> d f"--{'allow-grpc' if protocol == Protocols.GRPC else 'allow-http'}=True", ] - kserve_container: List[Dict[str, Any]] = [ + kserve_container: list[dict[str, Any]] = [ { "name": "kserve-container", "image": triton_runtime_image, @@ -121,7 +115,7 @@ def create_triton_serving_runtime(protocol: str, triton_runtime_image: str) -> d } ] - supported_model_formats: List[Dict[str, Any]] = [ + supported_model_formats: list[dict[str, Any]] = [ {"name": "tensorrt", "version": "8", "autoSelect": True, "priority": 1}, {"name": "tensorflow", "version": "1", "autoSelect": True, "priority": 1}, {"name": "tensorflow", "version": "2", "autoSelect": True, "priority": 1}, @@ -158,7 +152,7 @@ def triton_serving_runtime( model_namespace: Namespace, protocol: str, supported_accelerator_type: str | None, -) -> Generator[ServingRuntime, None, None]: +) -> Generator[ServingRuntime]: template_name = get_template_name(protocol=protocol, accelerator_type=supported_accelerator_type) with ServingRuntimeFromTemplate( 
client=admin_client, @@ -220,9 +214,7 @@ def triton_inference_service( @pytest.fixture(scope="class") -def triton_model_service_account( - admin_client: DynamicClient, kserve_s3_secret: Secret -) -> Generator[ServiceAccount, None, None]: +def triton_model_service_account(admin_client: DynamicClient, kserve_s3_secret: Secret) -> Generator[ServiceAccount]: with ServiceAccount( client=admin_client, namespace=kserve_s3_secret.namespace, diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_dali_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_dali_model.py index d7bce46e6..d5366b5d6 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_dali_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_dali_model.py @@ -11,14 +11,14 @@ from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from utilities.constants import Protocols -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX_DALI, TRITON_GRPC_DALI_INPUT_PATH, TRITON_REST_DALI_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_fil_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_fil_model.py index e8827fc4b..97acda652 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_fil_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_fil_model.py @@ -11,14 +11,14 @@ from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from utilities.constants import Protocols 
-from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX, TRITON_GRPC_FIL_INPUT_PATH, TRITON_REST_FIL_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_keras_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_keras_model.py index e53413ba1..fa2d86f86 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_keras_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_keras_model.py @@ -11,14 +11,14 @@ from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from utilities.constants import Protocols -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX_KERAS, TRITON_GRPC_KERAS_INPUT_PATH, TRITON_REST_KERAS_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_onnx_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_onnx_model.py index 68c84d74c..04899eb07 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_onnx_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_onnx_model.py @@ -11,14 +11,14 @@ from ocp_resources.pod import Pod from simple_logger.logger 
import get_logger -from utilities.constants import Protocols -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX, TRITON_GRPC_ONNX_INPUT_PATH, TRITON_REST_ONNX_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_python_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_python_model.py index e7c317090..f5d692f5e 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_python_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_python_model.py @@ -11,14 +11,14 @@ from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from utilities.constants import Protocols -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX, TRITON_GRPC_PYTHON_INPUT_PATH, TRITON_REST_PYTHON_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_pytorch_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_pytorch_model.py index 1f195c244..bcc1fe44b 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_pytorch_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_pytorch_model.py @@ 
-11,14 +11,14 @@ from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from utilities.constants import Protocols -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX, TRITON_GRPC_PYTORCH_INPUT_PATH, TRITON_REST_PYTORCH_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_tensorflow_model.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_tensorflow_model.py index 91ae1ebf6..99be62df2 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/test_tensorflow_model.py +++ b/tests/model_serving/model_runtime/triton/basic_model_deployment/test_tensorflow_model.py @@ -13,14 +13,14 @@ from ocp_resources.pod import Pod from simple_logger.logger import get_logger -from utilities.constants import Protocols -from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import validate_inference_request, load_json +from tests.model_serving.model_runtime.triton.basic_model_deployment.utils import load_json, validate_inference_request from tests.model_serving.model_runtime.triton.constant import ( BASE_RAW_DEPLOYMENT_CONFIG, MODEL_PATH_PREFIX, TRITON_GRPC_TF_INPUT_PATH, TRITON_REST_TF_INPUT_PATH, ) +from utilities.constants import Protocols LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/triton/basic_model_deployment/utils.py b/tests/model_serving/model_runtime/triton/basic_model_deployment/utils.py index bf46cda5e..a420bb47a 100644 --- a/tests/model_serving/model_runtime/triton/basic_model_deployment/utils.py +++ 
b/tests/model_serving/model_runtime/triton/basic_model_deployment/utils.py @@ -18,15 +18,15 @@ import requests from ocp_resources.inference_service import InferenceService -from tests.model_serving.model_runtime.triton.constant import ACCELERATOR_IDENTIFIER, TEMPLATE_MAP from tests.model_serving.model_runtime.triton.constant import ( + ACCELERATOR_IDENTIFIER, LOCAL_HOST_URL, PROTO_FILE_PATH, - TRITON_REST_PORT, + TEMPLATE_MAP, TRITON_GRPC_PORT, + TRITON_REST_PORT, ) -from utilities.constants import KServeDeploymentType, Protocols -from utilities.constants import Labels, RuntimeTemplates +from utilities.constants import KServeDeploymentType, Labels, Protocols, RuntimeTemplates def send_rest_request(url: str, input_data: dict[str, Any]) -> Any: diff --git a/tests/model_serving/model_runtime/triton/constant.py b/tests/model_serving/model_runtime/triton/constant.py index f206dfb71..f6ee88ee2 100644 --- a/tests/model_serving/model_runtime/triton/constant.py +++ b/tests/model_serving/model_runtime/triton/constant.py @@ -1,11 +1,11 @@ import os -from typing import Any, Union +from typing import Any from utilities.constants import ( KServeDeploymentType, + Labels, Protocols, RuntimeTemplates, - Labels, Timeout, ) @@ -50,7 +50,7 @@ Protocols.GRPC: "triton-grpc-runtime", } -PREDICT_RESOURCES: dict[str, Union[list[dict[str, Union[str, dict[str, str]]]], dict[str, dict[str, str]]]] = { +PREDICT_RESOURCES: dict[str, list[dict[str, str | dict[str, str]]] | dict[str, dict[str, str]]] = { "volumes": [ {"name": "shared-memory", "emptyDir": {"medium": "Memory", "sizeLimit": "16Gi"}}, {"name": "tmp", "emptyDir": {}}, diff --git a/tests/model_serving/model_runtime/utils.py b/tests/model_serving/model_runtime/utils.py index c7d138ad9..e74d6b269 100644 --- a/tests/model_serving/model_runtime/utils.py +++ b/tests/model_serving/model_runtime/utils.py @@ -1,4 +1,7 @@ -from typing import Any, Iterable, Optional +import os +import subprocess +from collections.abc import Iterable +from typing 
import Any import portforward from ocp_resources.inference_service import InferenceService @@ -6,11 +9,11 @@ from tenacity import retry, stop_after_attempt, wait_exponential from tests.model_serving.model_runtime.model_validation.constant import ( + AUDIO_FILE_LOCAL_PATH, + AUDIO_FILE_URL, COMPLETION_QUERY, EMBEDDING_QUERY, OPENAI_ENDPOINT_NAME, - AUDIO_FILE_URL, - AUDIO_FILE_LOCAL_PATH, SPYRE_INFERENCE_SERVICE_PORT, ) from utilities.constants import Ports @@ -18,8 +21,6 @@ from utilities.plugins.constant import OpenAIEnpoints from utilities.plugins.openai_plugin import OpenAIClient from utilities.plugins.tgis_grpc_plugin import TGISGRPCPlugin -import subprocess -import os LOGGER = get_logger(name=__name__) @@ -99,10 +100,10 @@ def run_raw_inference( def run_embedding_inference( endpoint: str, model_name: str, - url: Optional[str] = None, - pod_name: Optional[str] = None, - isvc: Optional[InferenceService] = None, - port: Optional[int] = Ports.REST_PORT, + url: str | None = None, + pod_name: str | None = None, + isvc: InferenceService | None = None, + port: int | None = Ports.REST_PORT, embedding_query: list[dict[str, str]] = EMBEDDING_QUERY, ) -> tuple[Any, list[Any]]: LOGGER.info("Running embedding inference for model: %s on endpoint: %s", model_name, endpoint) @@ -149,10 +150,10 @@ def run_audio_inference( model_name: str, audio_file_path: str = AUDIO_FILE_LOCAL_PATH, audio_file_url: str = AUDIO_FILE_URL, - url: Optional[str] = None, - pod_name: Optional[str] = None, - isvc: Optional[InferenceService] = None, - port: Optional[int] = Ports.REST_PORT, + url: str | None = None, + pod_name: str | None = None, + isvc: InferenceService | None = None, + port: int | None = Ports.REST_PORT, ) -> tuple[Any, list[Any]]: LOGGER.info("Running audio inference for model: %s on endpoint: %s", model_name, endpoint) download_audio_file(audio_file_url=audio_file_url, destination_path=audio_file_path) @@ -263,7 +264,7 @@ def download_audio_file(audio_file_url: str = AUDIO_FILE_URL, 
destination_path: return cmd = ["curl", "-fSL", "-o", destination_path, audio_file_url] try: - subprocess.run(args=cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + subprocess.run(args=cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) # noqa: UP022 LOGGER.info("Audio file downloaded successfully to %s", destination_path) except subprocess.CalledProcessError as e: stderr = e.stderr.decode() if e.stderr else str(e) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_elyza_japanese_llama_2_7b_instruct.py b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_elyza_japanese_llama_2_7b_instruct.py index 20538ad34..dcf49bb55 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_elyza_japanese_llama_2_7b_instruct.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_elyza_japanese_llama_2_7b_instruct.py @@ -1,18 +1,21 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY_JAPANESE, + COMPLETION_QUERY_JAPANESE, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY_JAPANESE, - CHAT_QUERY_JAPANESE, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_2b_instruct_preview_4k_r240917a.py 
b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_2b_instruct_preview_4k_r240917a.py index 4836765e7..0e0300e50 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_2b_instruct_preview_4k_r240917a.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_2b_instruct_preview_4k_r240917a.py @@ -1,7 +1,8 @@ import pytest from simple_logger.logger import get_logger -from utilities.constants import KServeDeploymentType, Ports + from tests.model_serving.model_runtime.vllm.utils import run_raw_inference +from utilities.constants import KServeDeploymentType, Ports LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_redhat_lab.py b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_redhat_lab.py index f35ce352c..57b3a6bc3 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_redhat_lab.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_redhat_lab.py @@ -1,18 +1,21 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY, + COMPLETION_QUERY, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY, - CHAT_QUERY, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git 
a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_starter.py b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_starter.py index d9ae3891c..9bec1ecbb 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_starter.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_granite_7b_starter.py @@ -1,18 +1,21 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY, + COMPLETION_QUERY, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY, - CHAT_QUERY, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama31_8B_instruct.py b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama31_8B_instruct.py index 04d0cb5ef..be60e6481 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama31_8B_instruct.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama31_8B_instruct.py @@ -1,18 +1,22 @@ +# noqa: N999 +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants 
import KServeDeploymentType +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY, + COMPLETION_QUERY, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY, - CHAT_QUERY, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama3_8B_instruct.py b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama3_8B_instruct.py index 9446f4179..2071ea424 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama3_8B_instruct.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama3_8B_instruct.py @@ -1,18 +1,22 @@ +# noqa: N999 +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY, + COMPLETION_QUERY, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY, - CHAT_QUERY, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama_2_13b_chat.py 
b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama_2_13b_chat.py index e5317de95..3f2ed716f 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama_2_13b_chat.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_llama_2_13b_chat.py @@ -1,18 +1,21 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY, + COMPLETION_QUERY, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY, - CHAT_QUERY, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_merlinite_7b_lab.py b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_merlinite_7b_lab.py index 30d75f124..1a351aebe 100644 --- a/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_merlinite_7b_lab.py +++ b/tests/model_serving/model_runtime/vllm/basic_model_deployment/test_merlinite_7b_lab.py @@ -1,18 +1,21 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType +from simple_logger.logger import get_logger + +from 
tests.model_serving.model_runtime.vllm.constant import ( + BASE_RAW_DEPLOYMENT_CONFIG, + CHAT_QUERY, + COMPLETION_QUERY, +) from tests.model_serving.model_runtime.vllm.utils import ( validate_raw_openai_inference_request, validate_raw_tgis_inference_request, ) -from tests.model_serving.model_runtime.vllm.constant import ( - COMPLETION_QUERY, - CHAT_QUERY, - BASE_RAW_DEPLOYMENT_CONFIG, -) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/conftest.py b/tests/model_serving/model_runtime/vllm/conftest.py index a6229bab6..8e92366c6 100644 --- a/tests/model_serving/model_runtime/vllm/conftest.py +++ b/tests/model_serving/model_runtime/vllm/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient @@ -14,11 +15,10 @@ from tests.model_serving.model_runtime.vllm.constant import ACCELERATOR_IDENTIFIER, PREDICT_RESOURCES, TEMPLATE_MAP from tests.model_serving.model_runtime.vllm.utils import ( kserve_s3_endpoint_secret, - validate_supported_quantization_schema, skip_if_not_deployment_mode, + validate_supported_quantization_schema, ) -from utilities.constants import KServeDeploymentType -from utilities.constants import Labels, RuntimeTemplates +from utilities.constants import KServeDeploymentType, Labels, RuntimeTemplates from utilities.inference_utils import create_isvc from utilities.infra import get_pods_by_isvc_label from utilities.serving_runtime import ServingRuntimeFromTemplate @@ -33,7 +33,7 @@ def serving_runtime( model_namespace: Namespace, supported_accelerator_type: str, vllm_runtime_image: str, -) -> Generator[ServingRuntime, None, None]: +) -> Generator[ServingRuntime]: accelerator_type = supported_accelerator_type.lower() template_name = TEMPLATE_MAP.get(accelerator_type, RuntimeTemplates.VLLM_CUDA) with ServingRuntimeFromTemplate( @@ -82,11 +82,7 
@@ def vllm_inference_service( isvc_kwargs["volumes"] = PREDICT_RESOURCES["volumes"] isvc_kwargs["volumes_mounts"] = PREDICT_RESOURCES["volume_mounts"] if arguments := request.param.get("runtime_argument"): - arguments = [ - arg - for arg in arguments - if not (arg.startswith("--tensor-parallel-size") or arg.startswith("--quantization")) - ] + arguments = [arg for arg in arguments if not arg.startswith(("--tensor-parallel-size", "--quantization"))] arguments.append(f"--tensor-parallel-size={gpu_count}") if quantization := request.param.get("quantization"): validate_supported_quantization_schema(q_type=quantization) @@ -119,7 +115,7 @@ def kserve_endpoint_s3_secret( aws_secret_access_key: str, models_s3_bucket_region: str, models_s3_bucket_endpoint: str, -) -> Generator[Secret, None, None]: +) -> Generator[Secret]: with kserve_s3_endpoint_secret( admin_client=admin_client, name="models-bucket-secret", diff --git a/tests/model_serving/model_runtime/vllm/constant.py b/tests/model_serving/model_runtime/vllm/constant.py index a8efc7280..498a2428b 100644 --- a/tests/model_serving/model_runtime/vllm/constant.py +++ b/tests/model_serving/model_runtime/vllm/constant.py @@ -1,4 +1,5 @@ -from typing import Any, Union +from typing import Any + from utilities.constants import AcceleratorType, KServeDeploymentType, Labels, RuntimeTemplates OPENAI_ENDPOINT_NAME: str = "openai" @@ -18,7 +19,7 @@ AcceleratorType.GAUDI: RuntimeTemplates.VLLM_GAUDI, } -PREDICT_RESOURCES: dict[str, Union[list[dict[str, Union[str, dict[str, str]]]], dict[str, dict[str, str]]]] = { +PREDICT_RESOURCES: dict[str, list[dict[str, str | dict[str, str]]] | dict[str, dict[str, str]]] = { "volumes": [ {"name": "shared-memory", "emptyDir": {"medium": "Memory", "sizeLimit": "16Gi"}}, {"name": "tmp", "emptyDir": {}}, @@ -220,7 +221,7 @@ "text": "小説に登場させる魔法使いのキャラクターを考えています。主人公の師となるようなキャラクターの案を背景を含めて考えてください。" }, { - "text": "日本国内で観光に行きたいと思っています。東京、名古屋、大阪、京都、福岡の特徴を表にまとめてください。列名は「都道府県」「おすすめスポット」「おすすめグルメ」にしてください。" + 
"text": "日本国内で観光に行きたいと思っています。東京、名古屋、大阪、京都、福岡の特徴を表にまとめてください。列名は「都道府県」「おすすめスポット」「おすすめグルメ」にしてください。" # noqa: E501 }, ] @@ -239,7 +240,7 @@ }, { "role": "user", - "content": "ルービックキューブをセンター試験の会場で、休憩時間に回そうと思っています。このような行動をしたときに周囲の人たちが感じるであろう感情について、3パターン程度述べてください。", + "content": "ルービックキューブをセンター試験の会場で、休憩時間に回そうと思っています。このような行動をしたときに周囲の人たちが感じるであろう感情について、3パターン程度述べてください。", # noqa: E501 }, ], ] diff --git a/tests/model_serving/model_runtime/vllm/multimodal/test_granite_31_2b_vision.py b/tests/model_serving/model_runtime/vllm/multimodal/test_granite_31_2b_vision.py index 5be9d3b59..704dbe48b 100644 --- a/tests/model_serving/model_runtime/vllm/multimodal/test_granite_31_2b_vision.py +++ b/tests/model_serving/model_runtime/vllm/multimodal/test_granite_31_2b_vision.py @@ -1,19 +1,22 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import List, Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType, Ports +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import MULTI_IMAGE_QUERIES, OPENAI_ENDPOINT_NAME, THREE_IMAGE_QUERY from tests.model_serving.model_runtime.vllm.utils import ( run_raw_inference, validate_inference_output, ) -from tests.model_serving.model_runtime.vllm.constant import OPENAI_ENDPOINT_NAME, MULTI_IMAGE_QUERIES, THREE_IMAGE_QUERY +from utilities.constants import KServeDeploymentType, Ports LOGGER = get_logger(name=__name__) -SERVING_ARGUMENT: List[str] = ["--model=/mnt/models", "--uvicorn-log-level=debug", '--limit-mm-per-prompt={"image": 2}'] +SERVING_ARGUMENT: list[str] = ["--model=/mnt/models", "--uvicorn-log-level=debug", '--limit-mm-per-prompt={"image": 2}'] MODEL_PATH: str = "ibm-granite/granite-vision-3.1-2b-preview" diff --git 
a/tests/model_serving/model_runtime/vllm/quantization/test_openhermes-2_5_mistral-7b_awq.py b/tests/model_serving/model_runtime/vllm/quantization/test_openhermes-2_5_mistral-7b_awq.py index 66e9dc6d2..9262bef19 100644 --- a/tests/model_serving/model_runtime/vllm/quantization/test_openhermes-2_5_mistral-7b_awq.py +++ b/tests/model_serving/model_runtime/vllm/quantization/test_openhermes-2_5_mistral-7b_awq.py @@ -1,11 +1,13 @@ +# noqa: N999 import pytest from simple_logger.logger import get_logger -from utilities.constants import KServeDeploymentType, Ports + +from tests.model_serving.model_runtime.vllm.constant import VLLM_SUPPORTED_QUANTIZATION from tests.model_serving.model_runtime.vllm.utils import ( run_raw_inference, validate_inference_output, ) -from tests.model_serving.model_runtime.vllm.constant import VLLM_SUPPORTED_QUANTIZATION +from utilities.constants import KServeDeploymentType, Ports LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_draft.py b/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_draft.py index 67a10c6da..65c7c33dc 100644 --- a/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_draft.py +++ b/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_draft.py @@ -1,19 +1,22 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import List, Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType, Ports +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import OPENAI_ENDPOINT_NAME, TGIS_ENDPOINT_NAME from tests.model_serving.model_runtime.vllm.utils import ( run_raw_inference, validate_inference_output, ) -from 
tests.model_serving.model_runtime.vllm.constant import OPENAI_ENDPOINT_NAME, TGIS_ENDPOINT_NAME +from utilities.constants import KServeDeploymentType, Ports LOGGER = get_logger(name=__name__) TIMEOUT_20MIN: str = 20 * 60 -SERVING_ARGUMENT: List[str] = [ +SERVING_ARGUMENT: list[str] = [ "--model=/mnt/models/granite-7b-instruct", "--uvicorn-log-level=debug", "--dtype=float16", diff --git a/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_ngram.py b/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_ngram.py index f6c25df1f..ecdc90aef 100644 --- a/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_ngram.py +++ b/tests/model_serving/model_runtime/vllm/speculative_decoding/test_granite_7b_lab_ngram.py @@ -1,19 +1,22 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import List, Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType, Ports +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import OPENAI_ENDPOINT_NAME, TGIS_ENDPOINT_NAME from tests.model_serving.model_runtime.vllm.utils import ( run_raw_inference, validate_inference_output, ) -from tests.model_serving.model_runtime.vllm.constant import OPENAI_ENDPOINT_NAME, TGIS_ENDPOINT_NAME +from utilities.constants import KServeDeploymentType, Ports LOGGER = get_logger(name=__name__) -SERVING_ARGUMENT: List[str] = [ +SERVING_ARGUMENT: list[str] = [ "--model=/mnt/models", "--uvicorn-log-level=debug", "--dtype=float16", diff --git a/tests/model_serving/model_runtime/vllm/toolcalling/test_granite_3_2_8b_instruct_preview.py b/tests/model_serving/model_runtime/vllm/toolcalling/test_granite_3_2_8b_instruct_preview.py index 9c0d4ce3a..9e291f267 100644 --- 
a/tests/model_serving/model_runtime/vllm/toolcalling/test_granite_3_2_8b_instruct_preview.py +++ b/tests/model_serving/model_runtime/vllm/toolcalling/test_granite_3_2_8b_instruct_preview.py @@ -1,21 +1,24 @@ +from collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod -from utilities.constants import KServeDeploymentType -from tests.model_serving.model_runtime.vllm.utils import ( - validate_raw_openai_inference_request, - validate_raw_tgis_inference_request, -) +from simple_logger.logger import get_logger + from tests.model_serving.model_runtime.vllm.constant import ( - LIGHTSPEED_TOOL_QUERY, + COMPLETION_QUERY, LIGHTSPEED_TOOL, + LIGHTSPEED_TOOL_QUERY, + MATH_CHAT_QUERY, WEATHER_TOOL, WEATHER_TOOL_QUERY, - MATH_CHAT_QUERY, - COMPLETION_QUERY, ) +from tests.model_serving.model_runtime.vllm.utils import ( + validate_raw_openai_inference_request, + validate_raw_tgis_inference_request, +) +from utilities.constants import KServeDeploymentType LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_runtime/vllm/utils.py b/tests/model_serving/model_runtime/vllm/utils.py index 4efde5201..60cd5f402 100644 --- a/tests/model_serving/model_runtime/vllm/utils.py +++ b/tests/model_serving/model_runtime/vllm/utils.py @@ -1,24 +1,27 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import Generator, Any +from typing import Any + +import portforward +import pytest from kubernetes.dynamic import DynamicClient -from ocp_resources.secret import Secret from ocp_resources.inference_service import InferenceService +from ocp_resources.secret import Secret from simple_logger.logger import get_logger -from tests.model_serving.model_runtime.vllm.constant import CHAT_QUERY, COMPLETION_QUERY from tenacity import retry, stop_after_attempt, 
wait_exponential +from tests.model_serving.model_runtime.vllm.constant import ( + CHAT_QUERY, + COMPLETION_QUERY, + OPENAI_ENDPOINT_NAME, + TGIS_ENDPOINT_NAME, + VLLM_SUPPORTED_QUANTIZATION, +) from utilities.constants import Ports from utilities.exceptions import NotSupportedError from utilities.plugins.constant import OpenAIEnpoints from utilities.plugins.openai_plugin import OpenAIClient from utilities.plugins.tgis_grpc_plugin import TGISGRPCPlugin -from tests.model_serving.model_runtime.vllm.constant import VLLM_SUPPORTED_QUANTIZATION -from tests.model_serving.model_runtime.vllm.constant import ( - OPENAI_ENDPOINT_NAME, - TGIS_ENDPOINT_NAME, -) -import portforward -import pytest LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_server/conftest.py b/tests/model_serving/model_server/conftest.py index d5e0dad98..2cde06a35 100644 --- a/tests/model_serving/model_server/conftest.py +++ b/tests/model_serving/model_server/conftest.py @@ -1,11 +1,15 @@ +from collections.abc import Generator from contextlib import ExitStack -from typing import Any, Generator, Dict +from typing import Any import pytest import yaml from _pytest.fixtures import FixtureRequest from kubernetes.dynamic import DynamicClient +from kubernetes.dynamic.exceptions import ResourceNotFoundError +from ocp_resources.cluster_service_version import ClusterServiceVersion from ocp_resources.config_map import ConfigMap +from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.inference_service import InferenceService from ocp_resources.namespace import Namespace from ocp_resources.persistent_volume_claim import PersistentVolumeClaim @@ -14,31 +18,17 @@ from ocp_resources.service_account import ServiceAccount from ocp_resources.serving_runtime import ServingRuntime from ocp_resources.storage_class import StorageClass -from ocp_resources.data_science_cluster import DataScienceCluster -from ocp_resources.cluster_service_version import ClusterServiceVersion 
-from kubernetes.dynamic.exceptions import ResourceNotFoundError -from utilities.kueue_utils import ( - create_local_queue, - create_cluster_queue, - create_resource_flavor, - LocalQueue, - ClusterQueue, - ResourceFlavor, - wait_for_kueue_crds_available, -) from pytest_testconfig import config as py_config from simple_logger.logger import get_logger from utilities.constants import ( + DscComponents, KServeDeploymentType, Labels, + ModelAndFormat, ModelFormat, RuntimeTemplates, StorageClassName, - DscComponents, -) -from utilities.constants import ( - ModelAndFormat, ) from utilities.data_science_cluster_utils import ( get_dsc_ready_condition, @@ -49,6 +39,15 @@ s3_endpoint_secret, update_configmap_data, ) +from utilities.kueue_utils import ( + ClusterQueue, + LocalQueue, + ResourceFlavor, + create_cluster_queue, + create_local_queue, + create_resource_flavor, + wait_for_kueue_crds_available, +) from utilities.serving_runtime import ServingRuntimeFromTemplate LOGGER = get_logger(name=__name__) @@ -328,7 +327,7 @@ def ovms_raw_inference_service( @pytest.fixture(scope="class") def user_workload_monitoring_config_map( admin_client: DynamicClient, cluster_monitoring_config: ConfigMap -) -> Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: uwm_namespace = "openshift-user-workload-monitoring" data = { @@ -393,6 +392,7 @@ def gpu_model_car_inference_service( ) -> Generator[InferenceService, Any, Any]: """Create a GPU-accelerated model car inference service.""" from copy import deepcopy + from tests.model_serving.model_runtime.openvino.constant import PREDICT_RESOURCES deployment_mode = request.param.get("deployment-mode", KServeDeploymentType.RAW_DEPLOYMENT) @@ -436,7 +436,7 @@ def _is_kueue_operator_installed(admin_client: DynamicClient) -> bool: if csv.name.startswith("kueue") and csv.status == csv.Status.SUCCEEDED: LOGGER.info(f"Found Kueue operator CSV: {csv.name}") return True - return False + return False # noqa: TRY300 except ResourceNotFoundError: 
return False @@ -444,7 +444,7 @@ def _is_kueue_operator_installed(admin_client: DynamicClient) -> bool: @pytest.fixture(scope="session") def ensure_kueue_unmanaged_in_dsc( admin_client: DynamicClient, dsc_resource: DataScienceCluster -) -> Generator[None, Any, None]: +) -> Generator[None, Any]: """Set DSC Kueue to Unmanaged and wait for CRDs to be available.""" try: if not _is_kueue_operator_installed(admin_client): @@ -487,7 +487,7 @@ def kueue_resource_groups( flavor_name: str, cpu_quota: int, memory_quota: str, -) -> list[Dict[str, Any]]: +) -> list[dict[str, Any]]: return [ { "coveredResources": ["cpu", "memory"], @@ -509,7 +509,7 @@ def kueue_cluster_queue_from_template( request: FixtureRequest, admin_client: DynamicClient, ensure_kueue_unmanaged_in_dsc, -) -> Generator[ClusterQueue, Any, None]: +) -> Generator[ClusterQueue, Any]: if request.param.get("name") is None: raise ValueError("name is required") with create_cluster_queue( @@ -528,7 +528,7 @@ def kueue_resource_flavor_from_template( request: FixtureRequest, admin_client: DynamicClient, ensure_kueue_unmanaged_in_dsc, -) -> Generator[ResourceFlavor, Any, None]: +) -> Generator[ResourceFlavor, Any]: if request.param.get("name") is None: raise ValueError("name is required") with create_resource_flavor( @@ -544,7 +544,7 @@ def kueue_local_queue_from_template( unprivileged_model_namespace: Namespace, admin_client: DynamicClient, ensure_kueue_unmanaged_in_dsc, -) -> Generator[LocalQueue, Any, None]: +) -> Generator[LocalQueue, Any]: if request.param.get("name") is None: raise ValueError("name is required") if request.param.get("cluster_queue") is None: diff --git a/tests/model_serving/model_server/kserve/authentication/conftest.py b/tests/model_serving/model_server/kserve/authentication/conftest.py index ebe0ddabb..171acb32e 100644 --- a/tests/model_serving/model_server/kserve/authentication/conftest.py +++ b/tests/model_serving/model_server/kserve/authentication/conftest.py @@ -1,4 +1,5 @@ -from typing 
import Any, Generator +from collections.abc import Generator +from typing import Any from urllib.parse import urlparse import pytest @@ -7,29 +8,29 @@ from ocp_resources.inference_service import InferenceService from ocp_resources.namespace import Namespace from ocp_resources.resource import ResourceEditor -from ocp_resources.role_binding import RoleBinding from ocp_resources.role import Role +from ocp_resources.role_binding import RoleBinding from ocp_resources.secret import Secret from ocp_resources.service_account import ServiceAccount from ocp_resources.serving_runtime import ServingRuntime -from utilities.inference_utils import create_isvc -from utilities.infra import ( - create_isvc_view_role, - get_pods_by_isvc_label, - create_inference_token, -) from utilities.constants import ( + Annotations, KServeDeploymentType, ModelFormat, ModelName, Protocols, RuntimeTemplates, ) +from utilities.inference_utils import create_isvc +from utilities.infra import ( + create_inference_token, + create_isvc_view_role, + get_pods_by_isvc_label, +) from utilities.jira import is_jira_open from utilities.logger import RedactedString from utilities.serving_runtime import ServingRuntimeFromTemplate -from utilities.constants import Annotations # HTTP/REST model serving diff --git a/tests/model_serving/model_server/kserve/authentication/test_kserve_token_authentication_raw.py b/tests/model_serving/model_server/kserve/authentication/test_kserve_token_authentication_raw.py index 3bef0517e..c52ba1a33 100644 --- a/tests/model_serving/model_server/kserve/authentication/test_kserve_token_authentication_raw.py +++ b/tests/model_serving/model_server/kserve/authentication/test_kserve_token_authentication_raw.py @@ -2,8 +2,7 @@ from ocp_resources.resource import ResourceEditor from tests.model_serving.model_server.utils import verify_inference_response -from utilities.constants import Protocols -from utilities.constants import Annotations +from utilities.constants import Annotations, Protocols 
from utilities.inference_utils import Inference, UserInference from utilities.infra import check_pod_status_in_time, get_pods_by_isvc_label from utilities.jira import is_jira_open diff --git a/tests/model_serving/model_server/kserve/components/conftest.py b/tests/model_serving/model_server/kserve/components/conftest.py index a426a92ca..14a9c1c66 100644 --- a/tests/model_serving/model_server/kserve/components/conftest.py +++ b/tests/model_serving/model_server/kserve/components/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest diff --git a/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/conftest.py b/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/conftest.py index 049e09d1c..27e3f3e31 100644 --- a/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/conftest.py +++ b/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest diff --git a/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/test_kserve_dsc_default_deployment_mode.py b/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/test_kserve_dsc_default_deployment_mode.py index 77678bb28..155d92d81 100644 --- a/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/test_kserve_dsc_default_deployment_mode.py +++ b/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/test_kserve_dsc_default_deployment_mode.py @@ -35,7 +35,7 @@ {"name": "dsc-raw"}, RUNTIME_PARAMS, { - **{"name": f"{ModelFormat.OPENVINO}-{KServeDeploymentType.RAW_DEPLOYMENT.lower()}"}, + "name": 
f"{ModelFormat.OPENVINO}-{KServeDeploymentType.RAW_DEPLOYMENT.lower()}", **INFERENCE_SERVICE_PARAMS, }, ) diff --git a/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/utils.py b/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/utils.py index 75888eebc..c2f4eb393 100644 --- a/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/utils.py +++ b/tests/model_serving/model_server/kserve/components/kserve_dsc_deployment_mode/utils.py @@ -1,5 +1,6 @@ import json -from typing import Any, Generator +from collections.abc import Generator +from typing import Any from ocp_resources.config_map import ConfigMap from ocp_resources.data_science_cluster import DataScienceCluster diff --git a/tests/model_serving/model_server/kserve/inference_graph/conftest.py b/tests/model_serving/model_server/kserve/inference_graph/conftest.py index f94bf1d56..bba062e35 100644 --- a/tests/model_serving/model_server/kserve/inference_graph/conftest.py +++ b/tests/model_serving/model_server/kserve/inference_graph/conftest.py @@ -1,7 +1,8 @@ import logging import time +from collections.abc import Generator from secrets import token_hex -from typing import Generator, Any +from typing import Any import pytest from _pytest.fixtures import FixtureRequest @@ -18,11 +19,11 @@ from ocp_resources.serving_runtime import ServingRuntime from pytest_testconfig import config as py_config -from utilities.constants import ModelFormat, KServeDeploymentType, ModelStoragePath, Annotations, Labels +from utilities.constants import Annotations, KServeDeploymentType, Labels, ModelFormat, ModelStoragePath from utilities.inference_utils import create_isvc from utilities.infra import ( - create_inference_token, create_inference_graph_view_role, + create_inference_token, get_services_by_isvc_label, ) @@ -79,7 +80,7 @@ def kserve_raw_headless_service_config( else: logger.warning(msg="No KServe controller deployment found") logger.info(msg="Waiting 
for KServe controller to process configuration change...") - time.sleep(60) # noqa + time.sleep(60) yield dsc_resource @@ -226,13 +227,14 @@ def service_account_with_access( dog_breed_inference_graph: InferenceGraph, bare_service_account: ServiceAccount, ) -> Generator[ServiceAccount, Any, Any]: - with create_inference_graph_view_role( - client=admin_client, - name=f"{dog_breed_inference_graph.name}-view", - namespace=unprivileged_model_namespace.name, - resource_names=[dog_breed_inference_graph.name], - ) as role: - with RoleBinding( + with ( + create_inference_graph_view_role( + client=admin_client, + name=f"{dog_breed_inference_graph.name}-view", + namespace=unprivileged_model_namespace.name, + resource_names=[dog_breed_inference_graph.name], + ) as role, + RoleBinding( client=admin_client, namespace=unprivileged_model_namespace.name, name=f"{bare_service_account.name}-view", @@ -240,8 +242,9 @@ def service_account_with_access( role_ref_kind=role.kind, subjects_kind=bare_service_account.kind, subjects_name=bare_service_account.name, - ): - yield bare_service_account + ), + ): + yield bare_service_account @pytest.fixture diff --git a/tests/model_serving/model_server/kserve/inference_graph/test_inference_graph_raw.py b/tests/model_serving/model_server/kserve/inference_graph/test_inference_graph_raw.py index 71e616d60..69b04b88b 100644 --- a/tests/model_serving/model_server/kserve/inference_graph/test_inference_graph_raw.py +++ b/tests/model_serving/model_server/kserve/inference_graph/test_inference_graph_raw.py @@ -1,8 +1,8 @@ import pytest from tests.model_serving.model_server.utils import verify_inference_response +from utilities.constants import KServeDeploymentType, ModelInferenceRuntime, Protocols from utilities.inference_utils import Inference -from utilities.constants import ModelInferenceRuntime, Protocols, KServeDeploymentType from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG diff --git 
a/tests/model_serving/model_server/kserve/inference_service_configuration/conftest.py b/tests/model_serving/model_server/kserve/inference_service_configuration/conftest.py index 2a4a2be7b..320cb5722 100644 --- a/tests/model_serving/model_server/kserve/inference_service_configuration/conftest.py +++ b/tests/model_serving/model_server/kserve/inference_service_configuration/conftest.py @@ -1,23 +1,23 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient from ocp_resources.inference_service import InferenceService -from ocp_resources.pod import Pod from ocp_resources.namespace import Namespace +from ocp_resources.pod import Pod from ocp_resources.serving_runtime import ServingRuntime -from utilities.inference_utils import create_isvc -from utilities.constants import KServeDeploymentType - from tests.model_serving.model_server.kserve.inference_service_configuration.constants import ( ISVC_ENV_VARS, - UPDATED_PULL_SECRET, ORIGINAL_PULL_SECRET, + UPDATED_PULL_SECRET, ) from tests.model_serving.model_server.kserve.inference_service_configuration.utils import ( update_inference_service, ) +from utilities.constants import KServeDeploymentType +from utilities.inference_utils import create_isvc from utilities.infra import get_pods_by_isvc_label diff --git a/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_pull_secret_updates.py b/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_pull_secret_updates.py index 5e0822284..c75d93916 100644 --- a/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_pull_secret_updates.py +++ b/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_pull_secret_updates.py @@ -1,10 +1,10 @@ import pytest -from tests.model_serving.model_server.kserve.inference_service_configuration.utils import verify_pull_secret from 
tests.model_serving.model_server.kserve.inference_service_configuration.constants import ( ORIGINAL_PULL_SECRET, UPDATED_PULL_SECRET, ) +from tests.model_serving.model_server.kserve.inference_service_configuration.utils import verify_pull_secret from utilities.constants import ModelName, ModelStorage, RuntimeTemplates diff --git a/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_replicas_update.py b/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_replicas_update.py index 22cdd53c5..3001a7d14 100644 --- a/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_replicas_update.py +++ b/tests/model_serving/model_server/kserve/inference_service_configuration/test_isvc_replicas_update.py @@ -16,7 +16,6 @@ from utilities.infra import get_pods_by_isvc_label from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG - LOGGER = get_logger(name=__name__) pytestmark = [pytest.mark.sanity, pytest.mark.usefixtures("valid_aws_config")] diff --git a/tests/model_serving/model_server/kserve/inference_service_configuration/utils.py b/tests/model_serving/model_server/kserve/inference_service_configuration/utils.py index 93cd9ee60..144252e04 100644 --- a/tests/model_serving/model_server/kserve/inference_service_configuration/utils.py +++ b/tests/model_serving/model_server/kserve/inference_service_configuration/utils.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator +from typing import Any from kubernetes.dynamic import DynamicClient from ocp_resources.inference_service import InferenceService @@ -17,7 +18,7 @@ @contextmanager def update_inference_service( client: DynamicClient, isvc: InferenceService, isvc_updated_dict: dict[str, Any], wait_for_new_pods: bool = True -) -> Generator[InferenceService, Any, None]: +) -> Generator[InferenceService, Any]: """ Update InferenceService object. 
@@ -101,9 +102,12 @@ def wait_for_new_running_inference_pods( client=isvc.client, isvc=isvc, ): - if pods and len(pods) == expected_num_pods: - if all(pod.name not in oring_pods_names and pod.status == pod.Status.RUNNING for pod in pods): - return + if ( + pods + and len(pods) == expected_num_pods + and all(pod.name not in oring_pods_names and pod.status == pod.Status.RUNNING for pod in pods) + ): + return except TimeoutError: LOGGER.error(f"Timeout waiting for pods {oring_pods_names} to be replaced") diff --git a/tests/model_serving/model_server/kserve/keda/conftest.py b/tests/model_serving/model_server/kserve/keda/conftest.py index 983facd6f..9bc2303ee 100644 --- a/tests/model_serving/model_server/kserve/keda/conftest.py +++ b/tests/model_serving/model_server/kserve/keda/conftest.py @@ -1,5 +1,6 @@ -from typing import Any, Generator import threading +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest @@ -10,24 +11,24 @@ from ocp_resources.service_account import ServiceAccount from ocp_resources.serving_runtime import ServingRuntime from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.constant import ACCELERATOR_IDENTIFIER, PREDICT_RESOURCES, TEMPLATE_MAP from tests.model_serving.model_runtime.vllm.utils import ( - validate_supported_quantization_schema, kserve_s3_endpoint_secret, + validate_supported_quantization_schema, +) +from tests.model_serving.model_server.utils import ( + run_concurrent_load_for_keda_scaling, ) -from tests.model_serving.model_runtime.vllm.constant import ACCELERATOR_IDENTIFIER, PREDICT_RESOURCES, TEMPLATE_MAP -from utilities.manifests.vllm import VLLM_INFERENCE_CONFIG - from utilities.constants import ( + THANOS_QUERIER_ADDRESS, KServeDeploymentType, - RuntimeTemplates, Labels, ModelAndFormat, - THANOS_QUERIER_ADDRESS, -) -from tests.model_serving.model_server.utils import ( - run_concurrent_load_for_keda_scaling, + RuntimeTemplates, 
) from utilities.inference_utils import create_isvc +from utilities.manifests.vllm import VLLM_INFERENCE_CONFIG from utilities.serving_runtime import ServingRuntimeFromTemplate LOGGER = get_logger(name=__name__) @@ -41,7 +42,7 @@ def keda_endpoint_s3_secret( aws_secret_access_key: str, models_s3_bucket_region: str, models_s3_bucket_endpoint: str, -) -> Generator[Secret, None, None]: +) -> Generator[Secret]: """Create S3 endpoint secret for KEDA GPU tests using model_namespace.""" with kserve_s3_endpoint_secret( admin_client=admin_client, @@ -117,7 +118,7 @@ def vllm_cuda_serving_runtime( model_namespace: Namespace, supported_accelerator_type: str, vllm_runtime_image: str, -) -> Generator[ServingRuntime, None, None]: +) -> Generator[ServingRuntime]: template_name = TEMPLATE_MAP.get(supported_accelerator_type.lower(), RuntimeTemplates.VLLM_CUDA) with ServingRuntimeFromTemplate( client=admin_client, @@ -167,11 +168,7 @@ def stressed_keda_vllm_inference_service( isvc_kwargs["volumes"] = PREDICT_RESOURCES["volumes"] isvc_kwargs["volumes_mounts"] = PREDICT_RESOURCES["volume_mounts"] if arguments := request.param.get("runtime_argument"): - arguments = [ - arg - for arg in arguments - if not (arg.startswith("--tensor-parallel-size") or arg.startswith("--quantization")) - ] + arguments = [arg for arg in arguments if not arg.startswith(("--tensor-parallel-size", "--quantization"))] arguments.append(f"--tensor-parallel-size={gpu_count}") if quantization := request.param.get("quantization"): validate_supported_quantization_schema(q_type=quantization) diff --git a/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_cpu.py b/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_cpu.py index 25366998f..e624fd936 100644 --- a/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_cpu.py +++ b/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_cpu.py @@ -1,22 +1,25 @@ +from collections.abc import Generator +from typing 
import Any + import pytest -from ocp_resources.resource import ResourceEditor -from simple_logger.logger import get_logger -from typing import Any, Generator from kubernetes.dynamic import DynamicClient -from ocp_resources.namespace import Namespace from ocp_resources.inference_service import InferenceService +from ocp_resources.namespace import Namespace +from ocp_resources.resource import ResourceEditor +from simple_logger.logger import get_logger + +from tests.model_serving.model_runtime.vllm.basic_model_deployment.test_granite_7b_starter import SERVING_ARGUMENT +from tests.model_serving.model_runtime.vllm.constant import BASE_RAW_DEPLOYMENT_CONFIG from tests.model_serving.model_server.utils import ( - verify_keda_scaledobject, - verify_final_pod_count, run_inference_multiple_times, + verify_final_pod_count, + verify_keda_scaledobject, ) -from tests.model_serving.model_runtime.vllm.constant import BASE_RAW_DEPLOYMENT_CONFIG -from tests.model_serving.model_runtime.vllm.basic_model_deployment.test_granite_7b_starter import SERVING_ARGUMENT -from utilities.constants import ModelFormat, ModelVersion, RunTimeConfigs, Protocols, Timeout -from utilities.monitoring import validate_metrics_field +from utilities.constants import ModelFormat, ModelVersion, Protocols, RunTimeConfigs, Timeout from utilities.inference_utils import Inference -from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG from utilities.jira import is_jira_open +from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG +from utilities.monitoring import validate_metrics_field LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_gpu.py b/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_gpu.py index d1f0bc45d..3ec6a0971 100644 --- a/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_gpu.py +++ b/tests/model_serving/model_server/kserve/keda/test_isvc_keda_scaling_gpu.py @@ -1,12 +1,15 @@ +from 
collections.abc import Generator +from typing import Any + import pytest -from simple_logger.logger import get_logger -from typing import Any, Generator from kubernetes.dynamic import DynamicClient -from ocp_resources.namespace import Namespace from ocp_resources.inference_service import InferenceService -from utilities.constants import KServeDeploymentType -from tests.model_serving.model_server.utils import verify_keda_scaledobject, verify_final_pod_count +from ocp_resources.namespace import Namespace +from simple_logger.logger import get_logger + from tests.model_serving.model_runtime.vllm.constant import BASE_RAW_DEPLOYMENT_CONFIG +from tests.model_serving.model_server.utils import verify_final_pod_count, verify_keda_scaledobject +from utilities.constants import KServeDeploymentType from utilities.monitoring import validate_metrics_field LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_server/kserve/kueue/conftest.py b/tests/model_serving/model_server/kserve/kueue/conftest.py index 1f113fdf5..1b067cb94 100644 --- a/tests/model_serving/model_server/kserve/kueue/conftest.py +++ b/tests/model_serving/model_server/kserve/kueue/conftest.py @@ -1,16 +1,17 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest from kubernetes.dynamic import DynamicClient -from utilities.constants import ModelAndFormat, KServeDeploymentType -from utilities.inference_utils import create_isvc -from utilities.serving_runtime import ServingRuntimeFromTemplate -from ocp_resources.secret import Secret from ocp_resources.inference_service import InferenceService -from ocp_resources.serving_runtime import ServingRuntime from ocp_resources.namespace import Namespace -from utilities.constants import RuntimeTemplates, ModelFormat +from ocp_resources.secret import Secret +from ocp_resources.serving_runtime import ServingRuntime + +from utilities.constants import 
KServeDeploymentType, ModelAndFormat, ModelFormat, RuntimeTemplates +from utilities.inference_utils import create_isvc +from utilities.serving_runtime import ServingRuntimeFromTemplate @pytest.fixture(scope="class") diff --git a/tests/model_serving/model_server/kserve/kueue/test_kueue_isvc_raw.py b/tests/model_serving/model_server/kserve/kueue/test_kueue_isvc_raw.py index 3fd8a157c..d08da0964 100644 --- a/tests/model_serving/model_server/kserve/kueue/test_kueue_isvc_raw.py +++ b/tests/model_serving/model_server/kserve/kueue/test_kueue_isvc_raw.py @@ -6,7 +6,8 @@ import pytest from ocp_resources.deployment import Deployment from timeout_sampler import TimeoutExpiredError, TimeoutSampler -from utilities.constants import RunTimeConfigs, KServeDeploymentType, ModelVersion, Labels + +from utilities.constants import KServeDeploymentType, Labels, ModelVersion, RunTimeConfigs from utilities.general import create_isvc_label_selector_str from utilities.kueue_utils import check_gated_pods_and_running_pods diff --git a/tests/model_serving/model_server/kserve/metrics/test_model_metrics.py b/tests/model_serving/model_server/kserve/metrics/test_model_metrics.py index e3e2f8c3f..45f245adc 100644 --- a/tests/model_serving/model_server/kserve/metrics/test_model_metrics.py +++ b/tests/model_serving/model_server/kserve/metrics/test_model_metrics.py @@ -1,4 +1,5 @@ import pytest +from timeout_sampler import TimeoutSampler from tests.model_serving.model_server.kserve.metrics.utils import validate_metrics_configuration from tests.model_serving.model_server.utils import ( @@ -12,7 +13,6 @@ Protocols, RuntimeTemplates, ) -from timeout_sampler import TimeoutSampler from utilities.inference_utils import Inference from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG from utilities.monitoring import get_metrics_value, validate_metrics_field diff --git a/tests/model_serving/model_server/kserve/model_car/test_gpu_ovms.py b/tests/model_serving/model_server/kserve/model_car/test_gpu_ovms.py 
index 528414598..233b6ec35 100644 --- a/tests/model_serving/model_server/kserve/model_car/test_gpu_ovms.py +++ b/tests/model_serving/model_server/kserve/model_car/test_gpu_ovms.py @@ -16,8 +16,8 @@ Protocols, RuntimeTemplates, ) -from utilities.infra import get_pods_by_isvc_label from utilities.inference_utils import Inference +from utilities.infra import get_pods_by_isvc_label from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG pytestmark = [ diff --git a/tests/model_serving/model_server/kserve/model_car/test_oci_image.py b/tests/model_serving/model_server/kserve/model_car/test_oci_image.py index c777e50ab..be9163df0 100644 --- a/tests/model_serving/model_server/kserve/model_car/test_oci_image.py +++ b/tests/model_serving/model_server/kserve/model_car/test_oci_image.py @@ -1,9 +1,9 @@ import pytest from tests.model_serving.model_server.utils import verify_inference_response -from utilities.infra import get_pods_by_isvc_label -from utilities.constants import ModelCarImage, ModelFormat, ModelName, Protocols, RuntimeTemplates, KServeDeploymentType +from utilities.constants import KServeDeploymentType, ModelCarImage, ModelFormat, ModelName, Protocols, RuntimeTemplates from utilities.inference_utils import Inference +from utilities.infra import get_pods_by_isvc_label from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG diff --git a/tests/model_serving/model_server/kserve/multi_node/conftest.py b/tests/model_serving/model_server/kserve/multi_node/conftest.py index 6eb6500b7..8765dadae 100644 --- a/tests/model_serving/model_server/kserve/multi_node/conftest.py +++ b/tests/model_serving/model_server/kserve/multi_node/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest @@ -17,7 +18,7 @@ from tests.model_serving.model_server.kserve.multi_node.utils import ( delete_multi_node_pod_by_role, ) -from utilities.constants import 
KServeDeploymentType, Labels, Protocols, Timeout, ModelCarImage +from utilities.constants import KServeDeploymentType, Labels, ModelCarImage, Protocols, Timeout from utilities.general import download_model_data from utilities.inference_utils import create_isvc from utilities.infra import ( @@ -30,7 +31,7 @@ @pytest.fixture(scope="session") def nvidia_gpu_nodes(nodes: list[Node]) -> list[Node]: - return [node for node in nodes if "nvidia.com/gpu.present" in node.labels.keys()] + return [node for node in nodes if "nvidia.com/gpu.present" in node.labels] @pytest.fixture(scope="session") @@ -233,7 +234,7 @@ def ray_tls_secret(unprivileged_client: DynamicClient, multi_node_inference_serv @pytest.fixture() def deleted_serving_runtime( multi_node_serving_runtime: ServingRuntime, -) -> Generator[None, Any, None]: +) -> Generator[None, Any]: multi_node_serving_runtime.clean_up() yield diff --git a/tests/model_serving/model_server/kserve/multi_node/test_oci_multi_node.py b/tests/model_serving/model_server/kserve/multi_node/test_oci_multi_node.py index a7a44db25..c8ac25978 100644 --- a/tests/model_serving/model_server/kserve/multi_node/test_oci_multi_node.py +++ b/tests/model_serving/model_server/kserve/multi_node/test_oci_multi_node.py @@ -2,8 +2,8 @@ from simple_logger.logger import get_logger from tests.model_serving.model_server.utils import verify_inference_response -from utilities.manifests.vllm import VLLM_INFERENCE_CONFIG from utilities.constants import Protocols +from utilities.manifests.vllm import VLLM_INFERENCE_CONFIG pytestmark = [ pytest.mark.rawdeployment, diff --git a/tests/model_serving/model_server/kserve/multi_node/utils.py b/tests/model_serving/model_server/kserve/multi_node/utils.py index 52c447f56..190ec865d 100644 --- a/tests/model_serving/model_server/kserve/multi_node/utils.py +++ b/tests/model_serving/model_server/kserve/multi_node/utils.py @@ -12,7 +12,6 @@ from utilities.constants import Timeout from utilities.infra import get_pods_by_isvc_label - 
LOGGER = get_logger(name=__name__) @@ -72,7 +71,7 @@ def verify_nvidia_gpu_status(pod: Pod) -> None: def delete_multi_node_pod_by_role(client: DynamicClient, isvc: InferenceService, role: str) -> None: - f""" + """ Delete multi node pod by role Worker pods have {WORKER_POD_ROLE} str in their name, head pod does not have an identifier in the name. @@ -89,10 +88,12 @@ def delete_multi_node_pod_by_role(client: DynamicClient, isvc: InferenceService, pods = get_pods_by_isvc_label(client=client, isvc=isvc) for pod in pods: - if role == WORKER_POD_ROLE and WORKER_POD_ROLE in pod.name: - pod.delete(wait=True) - - elif role == HEAD_POD_ROLE and WORKER_POD_ROLE not in pod.name: + if ( + role == WORKER_POD_ROLE + and WORKER_POD_ROLE in pod.name + or role == HEAD_POD_ROLE + and WORKER_POD_ROLE not in pod.name + ): pod.delete(wait=True) diff --git a/tests/model_serving/model_server/kserve/negative/conftest.py b/tests/model_serving/model_server/kserve/negative/conftest.py index 1aa8f6281..d5241f698 100644 --- a/tests/model_serving/model_server/kserve/negative/conftest.py +++ b/tests/model_serving/model_server/kserve/negative/conftest.py @@ -1,5 +1,5 @@ -from typing import Any, Generator - +from collections.abc import Generator +from typing import Any from urllib.parse import urlparse import pytest diff --git a/tests/model_serving/model_server/kserve/negative/test_invalid_inference_requests.py b/tests/model_serving/model_server/kserve/negative/test_invalid_inference_requests.py index fb79b32e8..00d721818 100644 --- a/tests/model_serving/model_server/kserve/negative/test_invalid_inference_requests.py +++ b/tests/model_serving/model_server/kserve/negative/test_invalid_inference_requests.py @@ -18,7 +18,6 @@ ) from utilities.infra import get_pods_by_isvc_label - pytestmark = pytest.mark.usefixtures("valid_aws_config") @@ -56,7 +55,7 @@ class TestUnsupportedContentType: - Model pod remains healthy (Running, no restarts) """ - VALID_INFERENCE_BODY: dict[str, Any] = { + 
VALID_INFERENCE_BODY: dict[str, Any] = { # noqa: RUF012 "inputs": [ { "name": "Input3", diff --git a/tests/model_serving/model_server/kserve/private_endpoint/conftest.py b/tests/model_serving/model_server/kserve/private_endpoint/conftest.py index ad0cccf71..292873a80 100644 --- a/tests/model_serving/model_server/kserve/private_endpoint/conftest.py +++ b/tests/model_serving/model_server/kserve/private_endpoint/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient diff --git a/tests/model_serving/model_server/kserve/private_endpoint/test_kserve_private_endpoint.py b/tests/model_serving/model_server/kserve/private_endpoint/test_kserve_private_endpoint.py index 162981a69..ce5aeef42 100644 --- a/tests/model_serving/model_server/kserve/private_endpoint/test_kserve_private_endpoint.py +++ b/tests/model_serving/model_server/kserve/private_endpoint/test_kserve_private_endpoint.py @@ -1,9 +1,10 @@ from typing import Self import pytest -from simple_logger.logger import get_logger from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod +from simple_logger.logger import get_logger + from tests.model_serving.model_server.kserve.private_endpoint.utils import curl_from_pod from utilities.constants import CurlOutput, ModelEndpoint, Protocols, RuntimeTemplates diff --git a/tests/model_serving/model_server/kserve/private_endpoint/utils.py b/tests/model_serving/model_server/kserve/private_endpoint/utils.py index e2f0f8ec3..67bd49761 100644 --- a/tests/model_serving/model_server/kserve/private_endpoint/utils.py +++ b/tests/model_serving/model_server/kserve/private_endpoint/utils.py @@ -1,11 +1,12 @@ import shlex -from typing import Any, Generator -from urllib.parse import urlparse +from collections.abc import Generator from contextlib import contextmanager +from typing import Any +from urllib.parse import urlparse 
-from ocp_resources.pod import Pod from kubernetes.dynamic.client import DynamicClient from ocp_resources.inference_service import InferenceService +from ocp_resources.pod import Pod from simple_logger.logger import get_logger from utilities.constants import Protocols diff --git a/tests/model_serving/model_server/kserve/raw_deployment/test_kserve_raw_routes_reconciliation.py b/tests/model_serving/model_server/kserve/raw_deployment/test_kserve_raw_routes_reconciliation.py index 394f52a53..5302fde6f 100644 --- a/tests/model_serving/model_server/kserve/raw_deployment/test_kserve_raw_routes_reconciliation.py +++ b/tests/model_serving/model_server/kserve/raw_deployment/test_kserve_raw_routes_reconciliation.py @@ -1,12 +1,11 @@ import pytest -from tests.model_serving.model_server.utils import verify_inference_response from tests.model_serving.model_server.kserve.raw_deployment.utils import assert_ingress_status_changed +from tests.model_serving.model_server.utils import verify_inference_response from utilities.constants import ModelFormat, ModelVersion, Protocols, RunTimeConfigs from utilities.inference_utils import Inference from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG - pytestmark = [pytest.mark.rawdeployment, pytest.mark.usefixtures("valid_aws_config")] diff --git a/tests/model_serving/model_server/kserve/raw_deployment/utils.py b/tests/model_serving/model_server/kserve/raw_deployment/utils.py index 84a7671af..4a3c2a72d 100644 --- a/tests/model_serving/model_server/kserve/raw_deployment/utils.py +++ b/tests/model_serving/model_server/kserve/raw_deployment/utils.py @@ -1,6 +1,7 @@ from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError from ocp_resources.inference_service import InferenceService + from utilities.constants import Timeout from utilities.infra import get_model_route diff --git a/tests/model_serving/model_server/kserve/routes/conftest.py 
b/tests/model_serving/model_server/kserve/routes/conftest.py index ea8f4de17..6824589b2 100644 --- a/tests/model_serving/model_server/kserve/routes/conftest.py +++ b/tests/model_serving/model_server/kserve/routes/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest @@ -48,11 +49,8 @@ def patched_s3_caikit_kserve_isvc_visibility_label( sleep=1, func=lambda: s3_models_inference_service.instance.status.url, ): - if sample: - if visibility == Labels.Kserve.EXPOSED and isvc_orig_url == sample: - break - - elif sample != isvc_orig_url: + if sample: # noqa: SIM102 + if visibility == Labels.Kserve.EXPOSED and isvc_orig_url == sample or sample != isvc_orig_url: break yield s3_models_inference_service diff --git a/tests/model_serving/model_server/kserve/routes/test_raw_deployment.py b/tests/model_serving/model_server/kserve/routes/test_raw_deployment.py index 52be8d9ba..cdd85220d 100644 --- a/tests/model_serving/model_server/kserve/routes/test_raw_deployment.py +++ b/tests/model_serving/model_server/kserve/routes/test_raw_deployment.py @@ -6,10 +6,10 @@ KServeDeploymentType, Labels, ModelFormat, + ModelInferenceRuntime, ModelStoragePath, OpenshiftRouteTimeout, Protocols, - ModelInferenceRuntime, RuntimeTemplates, ) from utilities.exceptions import ( diff --git a/tests/model_serving/model_server/kserve/stop_resume/conftest.py b/tests/model_serving/model_server/kserve/stop_resume/conftest.py index 8674d7c90..38576384f 100644 --- a/tests/model_serving/model_server/kserve/stop_resume/conftest.py +++ b/tests/model_serving/model_server/kserve/stop_resume/conftest.py @@ -1,8 +1,10 @@ -from typing import Generator, Any +from collections.abc import Generator +from typing import Any import pytest from ocp_resources.inference_service import InferenceService from ocp_resources.resource import ResourceEditor + from utilities.constants import Annotations diff 
--git a/tests/model_serving/model_server/kserve/stop_resume/test_raw_stop_resume_model.py b/tests/model_serving/model_server/kserve/stop_resume/test_raw_stop_resume_model.py index 5fc9e4965..d86cf16e2 100644 --- a/tests/model_serving/model_server/kserve/stop_resume/test_raw_stop_resume_model.py +++ b/tests/model_serving/model_server/kserve/stop_resume/test_raw_stop_resume_model.py @@ -1,5 +1,6 @@ import pytest +from tests.model_serving.model_server.kserve.stop_resume.utils import consistently_verify_no_pods_exist from tests.model_serving.model_server.utils import verify_inference_response from utilities.constants import ( ModelFormat, @@ -9,7 +10,6 @@ ) from utilities.inference_utils import Inference from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG -from tests.model_serving.model_server.kserve.stop_resume.utils import consistently_verify_no_pods_exist pytestmark = [pytest.mark.usefixtures("valid_aws_config")] diff --git a/tests/model_serving/model_server/kserve/stop_resume/utils.py b/tests/model_serving/model_server/kserve/stop_resume/utils.py index dcdcd67f0..7c0d2e0db 100644 --- a/tests/model_serving/model_server/kserve/stop_resume/utils.py +++ b/tests/model_serving/model_server/kserve/stop_resume/utils.py @@ -1,6 +1,7 @@ """Utilities for stop/resume model testing.""" import time + from kubernetes.dynamic.client import DynamicClient from ocp_resources.inference_service import InferenceService from timeout_sampler import TimeoutExpiredError @@ -32,7 +33,7 @@ def consistently_verify_no_pods_exist( # Nested timeout samplers can cause false negatives if the internal sampler has # a timeout that is greater than the external sampler. # So we iterate and sleep here instead. 
- time.sleep(interval) # noqa: FCN001 + time.sleep(interval) except TimeoutExpiredError: return False return True diff --git a/tests/model_serving/model_server/kserve/storage/constants.py b/tests/model_serving/model_server/kserve/storage/constants.py index 6b8760b8a..fa8402b30 100644 --- a/tests/model_serving/model_server/kserve/storage/constants.py +++ b/tests/model_serving/model_server/kserve/storage/constants.py @@ -2,7 +2,6 @@ from utilities.constants import ModelFormat - KSERVE_OVMS_SERVING_RUNTIME_PARAMS: dict[str, Any] = { "name": "ovms-runtime", "template-name": "kserve-ovms", diff --git a/tests/model_serving/model_server/kserve/storage/minio/conftest.py b/tests/model_serving/model_server/kserve/storage/minio/conftest.py index b07b60287..5708f0953 100644 --- a/tests/model_serving/model_server/kserve/storage/minio/conftest.py +++ b/tests/model_serving/model_server/kserve/storage/minio/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from _pytest.fixtures import FixtureRequest diff --git a/tests/model_serving/model_server/kserve/storage/pvc/conftest.py b/tests/model_serving/model_server/kserve/storage/pvc/conftest.py index b7e4d4739..d8316c88f 100644 --- a/tests/model_serving/model_server/kserve/storage/pvc/conftest.py +++ b/tests/model_serving/model_server/kserve/storage/pvc/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from kubernetes.dynamic import DynamicClient diff --git a/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_rwx.py b/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_rwx.py index bf621fab4..8dd0b80a6 100644 --- a/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_rwx.py +++ b/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_rwx.py @@ -1,5 +1,4 @@ import shlex -from utilities.constants 
import Containers, KServeDeploymentType, StorageClassName import pytest @@ -7,6 +6,7 @@ INFERENCE_SERVICE_PARAMS, KSERVE_OVMS_SERVING_RUNTIME_PARAMS, ) +from utilities.constants import Containers, KServeDeploymentType, StorageClassName POD_LS_SPLIT_COMMAND: list[str] = shlex.split("ls /mnt/models") diff --git a/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_write_access.py b/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_write_access.py index 710fa9762..0fe92ce82 100644 --- a/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_write_access.py +++ b/tests/model_serving/model_server/kserve/storage/pvc/test_kserve_pvc_write_access.py @@ -1,7 +1,7 @@ import shlex -from ocp_resources.pod import ExecOnPodError import pytest +from ocp_resources.pod import ExecOnPodError from tests.model_serving.model_server.kserve.storage.constants import ( INFERENCE_SERVICE_PARAMS, @@ -84,7 +84,6 @@ def test_isvc_read_only_annotation_false(self, unprivileged_client, patched_read indirect=True, ) def test_isvc_read_only_annotation_true(self, unprivileged_client, patched_read_only_isvc): - """ """ new_pod = get_pods_by_isvc_label( client=unprivileged_client, isvc=patched_read_only_isvc, diff --git a/tests/model_serving/model_server/llmd/conftest.py b/tests/model_serving/model_server/llmd/conftest.py index 24caea785..df26d2fe3 100644 --- a/tests/model_serving/model_server/llmd/conftest.py +++ b/tests/model_serving/model_server/llmd/conftest.py @@ -1,5 +1,5 @@ +from collections.abc import Generator from contextlib import ExitStack -from typing import Generator import pytest import yaml @@ -20,17 +20,17 @@ PREFIX_CACHE_HASH_SEED, ROUTER_SCHEDULER_CONFIG_ESTIMATED_PREFIX_CACHE, ) -from utilities.constants import Timeout, ResourceLimits -from utilities.infra import s3_endpoint_secret, create_inference_token -from utilities.logger import RedactedString -from utilities.llmd_utils import create_llmisvc, create_llmd_gateway +from 
utilities.constants import ResourceLimits, Timeout +from utilities.infra import create_inference_token, s3_endpoint_secret from utilities.llmd_constants import ( - ModelStorage, ContainerImages, - ModelNames, LLMDDefaults, LLMDGateway, + ModelNames, + ModelStorage, ) +from utilities.llmd_utils import create_llmd_gateway, create_llmisvc +from utilities.logger import RedactedString # ********************************* @@ -45,7 +45,7 @@ def llmd_s3_secret( models_s3_bucket_name: str, models_s3_bucket_region: str, models_s3_bucket_endpoint: str, -) -> Generator[Secret, None, None]: +) -> Generator[Secret]: """Create a Kubernetes secret with S3 credentials for LLMD model storage.""" with s3_endpoint_secret( client=admin_client, @@ -61,9 +61,7 @@ def llmd_s3_secret( @pytest.fixture(scope="class") -def llmd_s3_service_account( - admin_client: DynamicClient, llmd_s3_secret: Secret -) -> Generator[ServiceAccount, None, None]: +def llmd_s3_service_account(admin_client: DynamicClient, llmd_s3_secret: Secret) -> Generator[ServiceAccount]: """Create a service account linked to the S3 secret for LLMD pods.""" with ServiceAccount( client=admin_client, @@ -87,7 +85,7 @@ def gateway_namespace() -> str: def shared_llmd_gateway( admin_client: DynamicClient, gateway_namespace: str, -) -> Generator[Gateway, None, None]: +) -> Generator[Gateway]: """Shared LLMD gateway for all tests.""" with create_llmd_gateway( client=admin_client, @@ -236,7 +234,7 @@ def llmd_inference_service( request: FixtureRequest, admin_client: DynamicClient, unprivileged_model_namespace: Namespace, -) -> Generator[LLMInferenceService, None, None]: +) -> Generator[LLMInferenceService]: """Basic LLMInferenceService fixture for OCI storage with CPU runtime. This is the most commonly used fixture for basic LLMD tests. 
It uses @@ -330,7 +328,7 @@ def llmd_inference_service_s3( admin_client: DynamicClient, unprivileged_model_namespace: Namespace, llmd_s3_service_account: ServiceAccount, -) -> Generator[LLMInferenceService, None, None]: +) -> Generator[LLMInferenceService]: """Create an LLMInferenceService that loads models from S3 storage.""" if isinstance(request.param, str): name_suffix = request.param @@ -377,7 +375,7 @@ def llmd_inference_service_gpu( admin_client: DynamicClient, unprivileged_model_namespace: Namespace, llmd_s3_service_account: ServiceAccount, -) -> Generator[LLMInferenceService, None, None]: +) -> Generator[LLMInferenceService]: """Create an LLMInferenceService with GPU resources for accelerated inference.""" if isinstance(request.param, str): name_suffix = request.param @@ -457,7 +455,7 @@ def singlenode_estimated_prefix_cache( llmd_s3_secret: Secret, llmd_s3_service_account: ServiceAccount, llmd_gateway: Gateway, -) -> Generator[LLMInferenceService, None, None]: +) -> Generator[LLMInferenceService]: """LLMInferenceService fixture for single-node estimated prefix cache test.""" llmisvc_name = "singlenode-estimated-prefix-cache" diff --git a/tests/model_serving/model_server/llmd/kueue/test_kueue_llmisvc_raw.py b/tests/model_serving/model_server/llmd/kueue/test_kueue_llmisvc_raw.py index 19a5e78a0..c86f84ad2 100644 --- a/tests/model_serving/model_server/llmd/kueue/test_kueue_llmisvc_raw.py +++ b/tests/model_serving/model_server/llmd/kueue/test_kueue_llmisvc_raw.py @@ -1,16 +1,16 @@ import pytest -from timeout_sampler import TimeoutExpiredError, TimeoutSampler from ocp_resources.deployment import Deployment +from timeout_sampler import TimeoutExpiredError, TimeoutSampler from tests.model_serving.model_server.llmd.utils import ( verify_gateway_status, verify_llm_service_status, ) +from utilities.constants import Labels, Protocols +from utilities.exceptions import UnexpectedResourceCountError from utilities.kueue_utils import check_gated_pods_and_running_pods 
from utilities.llmd_utils import verify_inference_response_llmd from utilities.manifests.tinyllama import TINYLLAMA_INFERENCE_CONFIG -from utilities.constants import Protocols, Labels -from utilities.exceptions import UnexpectedResourceCountError pytestmark = [ pytest.mark.rawdeployment, diff --git a/tests/model_serving/model_server/llmd/test_llmd_auth.py b/tests/model_serving/model_server/llmd/test_llmd_auth.py index 958bec465..0b8124cae 100644 --- a/tests/model_serving/model_server/llmd/test_llmd_auth.py +++ b/tests/model_serving/model_server/llmd/test_llmd_auth.py @@ -1,8 +1,8 @@ import pytest from tests.model_serving.model_server.llmd.utils import ( - verify_llm_service_status, verify_gateway_status, + verify_llm_service_status, ) from utilities.constants import Protocols from utilities.llmd_utils import verify_inference_response_llmd diff --git a/tests/model_serving/model_server/llmd/utils.py b/tests/model_serving/model_server/llmd/utils.py index 7181cabb3..1ee8c96b3 100644 --- a/tests/model_serving/model_server/llmd/utils.py +++ b/tests/model_serving/model_server/llmd/utils.py @@ -15,14 +15,13 @@ from simple_logger.logger import get_logger from timeout_sampler import TimeoutSampler, retry +from tests.model_serving.model_server.llmd.constants import PREFIX_CACHE_BLOCK_SIZE from utilities.constants import Protocols from utilities.exceptions import PodContainersRestartError from utilities.llmd_utils import verify_inference_response_llmd from utilities.manifests.tinyllama import TINYLLAMA_INFERENCE_CONFIG from utilities.monitoring import get_metrics_value -from tests.model_serving.model_server.llmd.constants import PREFIX_CACHE_BLOCK_SIZE - LOGGER = get_logger(name=__name__) @@ -102,9 +101,10 @@ def verify_llmd_no_failed_pods( FailedPodsError: If any pods are in failed state TimeoutError: If pods don't become ready within timeout """ - from utilities.exceptions import FailedPodsError from ocp_resources.resource import Resource + from utilities.exceptions import 
FailedPodsError + LOGGER.info(f"Comprehensive health check for LLMInferenceService {llm_service.name}") container_wait_base_errors = ["InvalidImageName", "CrashLoopBackOff", "ImagePullBackOff", "ErrImagePull"] @@ -308,7 +308,7 @@ def send_prefix_cache_test_requests( authorized_user=True, ) successful_requests += 1 - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.error(f"Request {index + 1} failed: {e}") failed_requests += 1 diff --git a/tests/model_serving/model_server/maas_billing/conftest.py b/tests/model_serving/model_server/maas_billing/conftest.py index 346fd04ce..a0b45a034 100644 --- a/tests/model_serving/model_server/maas_billing/conftest.py +++ b/tests/model_serving/model_server/maas_billing/conftest.py @@ -1,55 +1,57 @@ -from typing import Generator, Dict, List, Any import base64 +from collections.abc import Generator +from typing import Any + import pytest import requests -from simple_logger.logger import get_logger -from utilities.plugins.constant import OpenAIEnpoints -from ocp_resources.service_account import ServiceAccount from kubernetes.dynamic import DynamicClient -from ocp_resources.namespace import Namespace -from ocp_resources.llm_inference_service import LLMInferenceService -from ocp_resources.deployment import Deployment -from timeout_sampler import TimeoutSampler -from utilities.llmd_utils import create_llmisvc -from utilities.llmd_constants import ModelStorage, ContainerImages -from ocp_resources.gateway_gateway_networking_k8s_io import Gateway from ocp_resources.config_map import ConfigMap from ocp_resources.data_science_cluster import DataScienceCluster -from pytest_testconfig import config as py_config -from utilities.constants import ( - MAAS_GATEWAY_NAMESPACE, - MAAS_RATE_LIMIT_POLICY_NAME, - MAAS_TOKEN_RATE_LIMIT_POLICY_NAME, -) -from pytest import FixtureRequest - +from ocp_resources.deployment import Deployment +from ocp_resources.gateway_gateway_networking_k8s_io import Gateway from 
ocp_resources.infrastructure import Infrastructure +from ocp_resources.llm_inference_service import LLMInferenceService +from ocp_resources.namespace import Namespace from ocp_resources.oauth import OAuth from ocp_resources.resource import ResourceEditor -from utilities.general import generate_random_name -from utilities.user_utils import UserTestSession, wait_for_user_creation, create_htpasswd_file -from utilities.infra import login_with_user_password, get_openshift_token, create_ns, s3_endpoint_secret -from utilities.general import wait_for_oauth_openshift_deployment from ocp_resources.secret import Secret -from tests.model_serving.model_server.maas_billing.utils import get_total_tokens -from utilities.constants import DscComponents, MAAS_GATEWAY_NAME -from utilities.resources.rate_limit_policy import RateLimitPolicy -from utilities.resources.token_rate_limit_policy import TokenRateLimitPolicy +from ocp_resources.service_account import ServiceAccount +from pytest import FixtureRequest +from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger +from timeout_sampler import TimeoutSampler + from tests.model_serving.model_server.maas_billing.utils import ( - detect_scheme_via_llmisvc, - host_from_ingress_domain, - mint_token, - patch_llmisvc_with_maas_router, - create_maas_group, build_maas_headers, - get_maas_models_response, - verify_chat_completions, - maas_gateway_rate_limits_patched, + create_maas_group, + detect_scheme_via_llmisvc, endpoints_have_ready_addresses, gateway_probe_reaches_maas_api, + get_maas_models_response, + get_total_tokens, + host_from_ingress_domain, maas_gateway_listeners, + maas_gateway_rate_limits_patched, + mint_token, + patch_llmisvc_with_maas_router, revoke_token, + verify_chat_completions, ) +from utilities.constants import ( + MAAS_GATEWAY_NAME, + MAAS_GATEWAY_NAMESPACE, + MAAS_RATE_LIMIT_POLICY_NAME, + MAAS_TOKEN_RATE_LIMIT_POLICY_NAME, + DscComponents, +) +from utilities.general import 
generate_random_name, wait_for_oauth_openshift_deployment +from utilities.infra import create_ns, get_openshift_token, login_with_user_password, s3_endpoint_secret +from utilities.llmd_constants import ContainerImages, ModelStorage +from utilities.llmd_utils import create_llmisvc +from utilities.plugins.constant import OpenAIEnpoints +from utilities.resources.rate_limit_policy import RateLimitPolicy +from utilities.resources.token_rate_limit_policy import TokenRateLimitPolicy +from utilities.user_utils import UserTestSession, create_htpasswd_file, wait_for_user_creation LOGGER = get_logger(name=__name__) MODELS_INFO = OpenAIEnpoints.MODELS_INFO @@ -60,7 +62,7 @@ @pytest.fixture(scope="session") -def request_session_http() -> Generator[requests.Session, None, None]: +def request_session_http() -> Generator[requests.Session]: session = requests.Session() session.headers.update({"User-Agent": "odh-maas-billing-tests/1"}) session.verify = False @@ -214,7 +216,7 @@ def maas_user_credentials_both() -> dict[str, str]: @pytest.fixture(scope="session") def maas_htpasswd_files( maas_user_credentials_both: dict[str, str], -) -> Generator[tuple[str, str, str, str], None, None]: +) -> Generator[tuple[str, str, str, str]]: """ Create per-user htpasswd files for FREE and PREMIUM users and return their file paths + base64 contents. 
@@ -336,7 +338,7 @@ def maas_free_user_session( is_byoidc: bool, maas_rbac_idp_env: dict[str, str], admin_client: DynamicClient, -) -> Generator[UserTestSession, None, None]: +) -> Generator[UserTestSession]: if is_byoidc: pytest.skip("Working on OIDC support for tests that use htpasswd IDP for MaaS") else: @@ -383,7 +385,7 @@ def maas_premium_user_session( is_byoidc: bool, maas_rbac_idp_env: dict[str, str], admin_client: DynamicClient, -) -> Generator[UserTestSession, None, None]: +) -> Generator[UserTestSession]: if is_byoidc: pytest.skip("Working on OIDC support for tests that use htpasswd IDP for MaaS") else: @@ -427,7 +429,7 @@ def maas_premium_user_session( def maas_free_group( admin_client: DynamicClient, maas_free_user_session: UserTestSession, -) -> Generator[str, None, None]: +) -> Generator[str]: """Create a FREE-tier MaaS group and add the FREE test user to it.""" with create_maas_group( admin_client=admin_client, @@ -442,7 +444,7 @@ def maas_free_group( def maas_premium_group( admin_client: DynamicClient, maas_premium_user_session: UserTestSession, -) -> Generator[str, None, None]: +) -> Generator[str]: """Create a PREMIUM-tier MaaS group and add the PREMIUM test user to it.""" with create_maas_group( admin_client=admin_client, @@ -461,7 +463,7 @@ def ocp_token_for_actor( admin_client: DynamicClient, maas_free_user_session: UserTestSession, maas_premium_user_session: UserTestSession, -) -> Generator[str, None, None]: +) -> Generator[str]: """ Log in as the requested actor ('admin' / 'free' / 'premium') and yield the OpenShift token for that user. 
@@ -554,7 +556,7 @@ def maas_models_response_for_actor( @pytest.fixture(scope="class") def maas_models_for_actor( maas_models_response_for_actor: requests.Response, -) -> List[Dict]: +) -> list[dict]: models_list = maas_models_response_for_actor.json().get("data", []) assert models_list, "no models returned from /v1/models" @@ -578,9 +580,9 @@ def exercise_rate_limiter( scenario: dict, request_session_http: requests.Session, model_url: str, - maas_headers_for_actor: Dict[str, str], - maas_models_for_actor: List[Dict], -) -> List[int]: + maas_headers_for_actor: dict[str, str], + maas_models_for_actor: list[dict], +) -> list[int]: models_list = maas_models_for_actor @@ -588,7 +590,7 @@ def exercise_rate_limiter( max_tokens = scenario["max_tokens"] log_prefix = scenario["log_prefix"] - status_codes_list: List[int] = [] + status_codes_list: list[int] = [] for attempt_index in range(max_requests): LOGGER.info(f"{log_prefix}[{actor_label}]: attempt {attempt_index + 1}/{max_requests}") @@ -624,43 +626,45 @@ def maas_inference_service_tinyllama( maas_gateway_api: None, maas_request_ratelimit_policy: None, maas_token_ratelimit_policy: None, -) -> Generator[LLMInferenceService, None, None]: +) -> Generator[LLMInferenceService]: """ TinyLlama S3-backed LLMInferenceService wired through MaaS for tests. 
""" - with create_llmisvc( - client=admin_client, - name="llm-s3-tinyllama", - namespace=maas_unprivileged_model_namespace.name, - storage_uri=ModelStorage.TINYLLAMA_S3, - container_image=ContainerImages.VLLM_CPU, - container_resources={ - "limits": {"cpu": "2", "memory": "12Gi"}, - "requests": {"cpu": "1", "memory": "8Gi"}, - }, - service_account=maas_model_service_account.name, - wait=False, - timeout=900, - ) as llm_service: - with patch_llmisvc_with_maas_router( + with ( + create_llmisvc( + client=admin_client, + name="llm-s3-tinyllama", + namespace=maas_unprivileged_model_namespace.name, + storage_uri=ModelStorage.TINYLLAMA_S3, + container_image=ContainerImages.VLLM_CPU, + container_resources={ + "limits": {"cpu": "2", "memory": "12Gi"}, + "requests": {"cpu": "1", "memory": "8Gi"}, + }, + service_account=maas_model_service_account.name, + wait=False, + timeout=900, + ) as llm_service, + patch_llmisvc_with_maas_router( llm_service=llm_service, - ): - inst = llm_service.instance - storage_uri = inst.spec.model.uri - assert storage_uri == ModelStorage.TINYLLAMA_S3, f"Unexpected storage_uri on TinyLlama LLMI: {storage_uri}" - - llm_service.wait_for_condition( - condition="Ready", - status="True", - timeout=900, - ) + ), + ): + inst = llm_service.instance + storage_uri = inst.spec.model.uri + assert storage_uri == ModelStorage.TINYLLAMA_S3, f"Unexpected storage_uri on TinyLlama LLMI: {storage_uri}" - LOGGER.info( - f"MaaS: TinyLlama LLMI {llm_service.namespace}/{llm_service.name} " - f"Ready and patched (storage_uri={storage_uri})" - ) + llm_service.wait_for_condition( + condition="Ready", + status="True", + timeout=900, + ) + + LOGGER.info( + f"MaaS: TinyLlama LLMI {llm_service.namespace}/{llm_service.name} " + f"Ready and patched (storage_uri={storage_uri})" + ) - yield llm_service + yield llm_service @pytest.fixture(scope="class") @@ -683,7 +687,7 @@ def maas_gateway_rate_limits( maas_request_ratelimit_policy: None, maas_token_ratelimit_policy: None, 
maas_tier_mapping_cm, -) -> Generator[None, None, None]: +) -> Generator[None]: with maas_gateway_rate_limits_patched( admin_client=admin_client, namespace=MAAS_GATEWAY_NAMESPACE, @@ -704,7 +708,7 @@ def maas_controller_enabled_latest( maas_gateway_api: None, maas_request_ratelimit_policy: None, maas_token_ratelimit_policy: None, -) -> Generator[DataScienceCluster, None, None]: +) -> Generator[DataScienceCluster]: """ Ensure MaaS (KServe modelsAsService) is MANAGED for the session. Restore DSC to original state on teardown. @@ -810,7 +814,7 @@ def maas_api_gateway_reachable( def maas_gateway_api( admin_client: DynamicClient, maas_gateway_api_hostname: str, -) -> Generator[None, None, None]: +) -> Generator[None]: """ Ensure MaaS Gateway exists once per test session. @@ -861,7 +865,7 @@ def maas_request_ratelimit_policy( admin_client: DynamicClient, maas_gateway_api: None, maas_gateway_target_ref: dict, -) -> Generator[None, None, None]: +) -> Generator[None]: with RateLimitPolicy( client=admin_client, name=MAAS_RATE_LIMIT_POLICY_NAME, @@ -885,7 +889,7 @@ def maas_token_ratelimit_policy( admin_client: DynamicClient, maas_gateway_api: None, maas_gateway_target_ref: dict, -) -> Generator[None, None, None]: +) -> Generator[None]: with TokenRateLimitPolicy( client=admin_client, name=MAAS_TOKEN_RATE_LIMIT_POLICY_NAME, @@ -911,7 +915,7 @@ def ensure_working_maas_token_pre_revoke( maas_headers_for_actor, maas_models_response_for_actor, actor_label, -) -> List[dict]: +) -> list[dict]: models_list = maas_models_response_for_actor.json().get("data", []) verify_chat_completions( diff --git a/tests/model_serving/model_server/maas_billing/test_maas_endpoints.py b/tests/model_serving/model_server/maas_billing/test_maas_endpoints.py index 6131a59e1..8cb39af9e 100644 --- a/tests/model_serving/model_server/maas_billing/test_maas_endpoints.py +++ b/tests/model_serving/model_server/maas_billing/test_maas_endpoints.py @@ -1,6 +1,7 @@ -from simple_logger.logger import get_logger -import 
requests import pytest +import requests +from simple_logger.logger import get_logger + from tests.model_serving.model_server.maas_billing.utils import verify_chat_completions LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_server/maas_billing/test_maas_rbac_e2e.py b/tests/model_serving/model_server/maas_billing/test_maas_rbac_e2e.py index 7838caf3e..0d05d04a2 100644 --- a/tests/model_serving/model_server/maas_billing/test_maas_rbac_e2e.py +++ b/tests/model_serving/model_server/maas_billing/test_maas_rbac_e2e.py @@ -1,9 +1,10 @@ import pytest from simple_logger.logger import get_logger -from utilities.plugins.constant import OpenAIEnpoints + from tests.model_serving.model_server.maas_billing.utils import ( verify_chat_completions, ) +from utilities.plugins.constant import OpenAIEnpoints LOGGER = get_logger(name=__name__) diff --git a/tests/model_serving/model_server/maas_billing/test_maas_request_rate_limits.py b/tests/model_serving/model_server/maas_billing/test_maas_request_rate_limits.py index bf7a61df7..60336b660 100644 --- a/tests/model_serving/model_server/maas_billing/test_maas_request_rate_limits.py +++ b/tests/model_serving/model_server/maas_billing/test_maas_request_rate_limits.py @@ -1,6 +1,6 @@ -from typing import List import pytest from simple_logger.logger import get_logger + from tests.model_serving.model_server.maas_billing.utils import ( assert_mixed_200_and_429, ) @@ -50,7 +50,7 @@ def test_request_rate_limits( ocp_token_for_actor: str, actor_label: str, scenario: dict, - exercise_rate_limiter: List[int], + exercise_rate_limiter: list[int], ) -> None: _ = ocp_token_for_actor diff --git a/tests/model_serving/model_server/maas_billing/test_maas_token_rate_limits.py b/tests/model_serving/model_server/maas_billing/test_maas_token_rate_limits.py index 9e32e5648..b012f7a94 100644 --- a/tests/model_serving/model_server/maas_billing/test_maas_token_rate_limits.py +++ 
b/tests/model_serving/model_server/maas_billing/test_maas_token_rate_limits.py @@ -1,6 +1,6 @@ -from typing import List import pytest from simple_logger.logger import get_logger + from tests.model_serving.model_server.maas_billing.utils import ( assert_mixed_200_and_429, ) @@ -47,7 +47,7 @@ def test_token_rate_limits( ocp_token_for_actor: str, actor_label: str, scenario: dict, - exercise_rate_limiter: List[int], + exercise_rate_limiter: list[int], ) -> None: _ = ocp_token_for_actor diff --git a/tests/model_serving/model_server/maas_billing/utils.py b/tests/model_serving/model_server/maas_billing/utils.py index 947c9b970..30c99b2fa 100644 --- a/tests/model_serving/model_server/maas_billing/utils.py +++ b/tests/model_serving/model_server/maas_billing/utils.py @@ -1,29 +1,30 @@ -from typing import Any, Dict, Generator, List, Tuple import base64 -import requests +from collections.abc import Generator +from contextlib import contextmanager from json import JSONDecodeError +from typing import Any from urllib.parse import urlparse -from contextlib import contextmanager +import requests from kubernetes.dynamic import DynamicClient + +# from ocp_resources.gateway_gateway_networking_k8s_io import Gateway +from ocp_resources.endpoints import Endpoints from ocp_resources.group import Group from ocp_resources.ingress_config_openshift_io import Ingress as IngressConfig from ocp_resources.llm_inference_service import LLMInferenceService +from ocp_resources.resource import ResourceEditor from requests import Response from simple_logger.logger import get_logger -from utilities.llmd_utils import get_llm_inference_url -from utilities.plugins.constant import RestHeader, OpenAIEnpoints -from ocp_resources.resource import ResourceEditor -from utilities.resources.rate_limit_policy import RateLimitPolicy -from utilities.resources.token_rate_limit_policy import TokenRateLimitPolicy -# from ocp_resources.gateway_gateway_networking_k8s_io import Gateway -from ocp_resources.endpoints import 
Endpoints from utilities.constants import ( MAAS_GATEWAY_NAME, MAAS_GATEWAY_NAMESPACE, ) - +from utilities.llmd_utils import get_llm_inference_url +from utilities.plugins.constant import OpenAIEnpoints, RestHeader +from utilities.resources.rate_limit_policy import RateLimitPolicy +from utilities.resources.token_rate_limit_policy import TokenRateLimitPolicy LOGGER = get_logger(name=__name__) MODELS_INFO = OpenAIEnpoints.MODELS_INFO @@ -78,7 +79,7 @@ def detect_scheme_via_llmisvc(client, namespace: str = "llm") -> str: return "https" -def maas_auth_headers(token: str) -> Dict[str, str]: +def maas_auth_headers(token: str) -> dict[str, str]: """Authorization header only (used for /v1/tokens with OCP user token).""" return {"Authorization": f"Bearer {token}"} @@ -114,7 +115,7 @@ def create_maas_group( admin_client: DynamicClient, group_name: str, users: list[str] | None = None, -) -> Generator[Group, None, None]: +) -> Generator[Group]: """ Create an OpenShift Group with optional users and delete it on exit. """ @@ -159,7 +160,7 @@ def get_maas_models_response( @contextmanager def patch_llmisvc_with_maas_router( llm_service: LLMInferenceService, -) -> Generator[None, None, None]: +) -> Generator[None]: router_spec = { "gateway": {"refs": [{"name": MAAS_GATEWAY_NAME, "namespace": MAAS_GATEWAY_NAMESPACE}]}, "route": {}, @@ -253,7 +254,7 @@ def verify_chat_completions( def assert_mixed_200_and_429( *, actor_label: str, - status_codes_list: List[int], + status_codes_list: list[int], context: str, require_429: bool = True, ) -> None: @@ -286,7 +287,7 @@ def assert_mixed_200_and_429( ) -def maas_token_ratelimitpolicy_spec() -> Dict[str, Any]: +def maas_token_ratelimitpolicy_spec() -> dict[str, Any]: """ Deterministic TokenRateLimitPolicy limits for MaaS tests. 
@@ -312,7 +313,7 @@ def maas_token_ratelimitpolicy_spec() -> Dict[str, Any]: } -def maas_ratelimitpolicy_spec() -> Dict[str, Any]: +def maas_ratelimitpolicy_spec() -> dict[str, Any]: """ Deterministic RateLimitPolicy limits for MaaS tests. @@ -345,7 +346,7 @@ def maas_gateway_rate_limits_patched( namespace: str, token_policy_name: str, request_policy_name: str, -) -> Generator[None, None, None]: +) -> Generator[None]: """ Temporarily patch ONLY `spec.limits` of the Kuadrant TokenRateLimitPolicy and RateLimitPolicy for MaaS tests, and restore the original state afterwards. @@ -429,7 +430,7 @@ def get_total_tokens(resp: Response, *, fail_if_missing: bool = False) -> int | return None -def maas_gateway_listeners(hostname: str) -> List[Dict[str, Any]]: +def maas_gateway_listeners(hostname: str) -> list[dict[str, Any]]: return [ { "name": "http", @@ -475,7 +476,7 @@ def gateway_probe_reaches_maas_api( http_session: requests.Session, probe_url: str, request_timeout_seconds: int, -) -> Tuple[bool, int, str]: +) -> tuple[bool, int, str]: response = http_session.get(probe_url, timeout=request_timeout_seconds) status_code = response.status_code response_text = response.text diff --git a/tests/model_serving/model_server/upgrade/conftest.py b/tests/model_serving/model_server/upgrade/conftest.py index 0fe678460..a33b7fcbd 100644 --- a/tests/model_serving/model_server/upgrade/conftest.py +++ b/tests/model_serving/model_server/upgrade/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest import yaml @@ -34,7 +35,6 @@ from utilities.logger import RedactedString from utilities.serving_runtime import ServingRuntimeFromTemplate - LOGGER = get_logger(name=__name__) UPGRADE_NAMESPACE = "upgrade-model-server" @@ -511,7 +511,7 @@ def model_car_inference_service_fixture( def upgrade_user_workload_monitoring_config_map( admin_client: DynamicClient, cluster_monitoring_config: ConfigMap, -) -> 
Generator[ConfigMap, None, None]: +) -> Generator[ConfigMap]: """ Session-scoped user workload monitoring ConfigMap for upgrade tests. diff --git a/tests/model_serving/model_server/upgrade/test_upgrade.py b/tests/model_serving/model_server/upgrade/test_upgrade.py index e77bf53ae..bdc5bf6c2 100644 --- a/tests/model_serving/model_server/upgrade/test_upgrade.py +++ b/tests/model_serving/model_server/upgrade/test_upgrade.py @@ -9,7 +9,6 @@ from utilities.inference_utils import Inference from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG - pytestmark = [pytest.mark.rawdeployment, pytest.mark.usefixtures("valid_aws_config")] diff --git a/tests/model_serving/model_server/upgrade/test_upgrade_auth.py b/tests/model_serving/model_server/upgrade/test_upgrade_auth.py index d3d02bdf9..90ef18787 100644 --- a/tests/model_serving/model_server/upgrade/test_upgrade_auth.py +++ b/tests/model_serving/model_server/upgrade/test_upgrade_auth.py @@ -11,7 +11,6 @@ from utilities.inference_utils import Inference, UserInference from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG - pytestmark = [pytest.mark.rawdeployment, pytest.mark.usefixtures("valid_aws_config")] diff --git a/tests/model_serving/model_server/upgrade/test_upgrade_metrics.py b/tests/model_serving/model_server/upgrade/test_upgrade_metrics.py index d2d88df97..75cd0f251 100644 --- a/tests/model_serving/model_server/upgrade/test_upgrade_metrics.py +++ b/tests/model_serving/model_server/upgrade/test_upgrade_metrics.py @@ -16,7 +16,6 @@ from utilities.inference_utils import Inference from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG - pytestmark = [ pytest.mark.rawdeployment, pytest.mark.metrics, diff --git a/tests/model_serving/model_server/upgrade/test_upgrade_model_car.py b/tests/model_serving/model_server/upgrade/test_upgrade_model_car.py index 2910e1092..79c46cf35 100644 --- a/tests/model_serving/model_server/upgrade/test_upgrade_model_car.py +++ 
b/tests/model_serving/model_server/upgrade/test_upgrade_model_car.py @@ -12,7 +12,6 @@ from utilities.inference_utils import Inference from utilities.manifests.onnx import ONNX_INFERENCE_CONFIG - pytestmark = pytest.mark.rawdeployment diff --git a/tests/model_serving/model_server/upgrade/test_upgrade_private_endpoint.py b/tests/model_serving/model_server/upgrade/test_upgrade_private_endpoint.py index f2b939c85..118da72c7 100644 --- a/tests/model_serving/model_server/upgrade/test_upgrade_private_endpoint.py +++ b/tests/model_serving/model_server/upgrade/test_upgrade_private_endpoint.py @@ -13,7 +13,6 @@ from utilities.inference_utils import Inference from utilities.manifests.openvino import OPENVINO_KSERVE_INFERENCE_CONFIG - pytestmark = [pytest.mark.rawdeployment, pytest.mark.usefixtures("valid_aws_config")] diff --git a/tests/model_serving/model_server/utils.py b/tests/model_serving/model_server/utils.py index dd566ef3c..fc64f413a 100644 --- a/tests/model_serving/model_server/utils.py +++ b/tests/model_serving/model_server/utils.py @@ -2,25 +2,23 @@ import re from concurrent.futures import ThreadPoolExecutor, as_completed, wait from string import Template -from typing import Any, Optional +from typing import Any + from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError - from ocp_resources.inference_graph import InferenceGraph from ocp_resources.inference_service import InferenceService from ocp_resources.utils.constants import DEFAULT_CLUSTER_RETRY_EXCEPTIONS from simple_logger.logger import get_logger +from timeout_sampler import TimeoutExpiredError, TimeoutSampler, TimeoutWatch -from utilities.constants import KServeDeploymentType +from tests.model_serving.model_server.kserve.keda.utils import get_isvc_keda_scaledobject +from utilities.constants import KServeDeploymentType, Protocols, Timeout from utilities.exceptions import ( InferenceResponseError, ) -from utilities.constants import Timeout from 
utilities.inference_utils import UserInference from utilities.infra import get_pods_by_isvc_label -from tests.model_serving.model_server.kserve.keda.utils import get_isvc_keda_scaledobject -from utilities.constants import Protocols -from timeout_sampler import TimeoutWatch, TimeoutSampler, TimeoutExpiredError LOGGER = get_logger(name=__name__) @@ -30,13 +28,13 @@ def verify_inference_response( inference_config: dict[str, Any], inference_type: str, protocol: str, - model_name: Optional[str] = None, - inference_input: Optional[Any] = None, + model_name: str | None = None, + inference_input: Any | None = None, use_default_query: bool = False, - expected_response_text: Optional[str] = None, + expected_response_text: str | None = None, insecure: bool = False, - token: Optional[str] = None, - authorized_user: Optional[bool] = None, + token: str | None = None, + authorized_user: bool | None = None, ) -> None: """ Verify the inference response. @@ -148,7 +146,7 @@ def verify_inference_response( elif inference_type == inference.INFER or use_regex: formatted_res = json.dumps(res[inference.inference_response_text_key_name]).replace(" ", "") if use_regex: - assert re.search(expected_response_text, formatted_res), ( # type: ignore[arg-type] # noqa: E501 + assert re.search(expected_response_text, formatted_res), ( # type: ignore[arg-type] f"Expected: {expected_response_text} not found in: {formatted_res}" ) @@ -220,10 +218,7 @@ def run_inference_multiple_times( verify_inference_response(**infer_kwargs) if futures: - exceptions = [] - for result in as_completed(futures): - if _exception := result.exception(): - exceptions.append(_exception) + exceptions = [_exception for result in as_completed(futures) if (_exception := result.exception())] if exceptions: raise InferenceResponseError(f"Failed to run inference. 
Error: {exceptions}") diff --git a/tests/workbenches/conftest.py b/tests/workbenches/conftest.py index 89cf8b5ab..8b95c6dc8 100644 --- a/tests/workbenches/conftest.py +++ b/tests/workbenches/conftest.py @@ -1,24 +1,19 @@ -from typing import Generator +from collections.abc import Generator import pytest -from pytest_testconfig import config as py_config - -from simple_logger.logger import get_logger -from tests.workbenches.utils import get_username - from kubernetes.dynamic import DynamicClient - from ocp_resources.namespace import Namespace -from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.notebook import Notebook +from ocp_resources.persistent_volume_claim import PersistentVolumeClaim from ocp_resources.pod import Pod +from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger -from utilities.constants import Labels, Timeout +from tests.workbenches.utils import get_username from utilities import constants -from utilities.constants import INTERNAL_IMAGE_REGISTRY_PATH -from utilities.infra import get_product_version -from utilities.infra import check_internal_image_registry_available +from utilities.constants import INTERNAL_IMAGE_REGISTRY_PATH, Labels, Timeout from utilities.general import collect_pod_information +from utilities.infra import check_internal_image_registry_available, get_product_version LOGGER = get_logger(name=__name__) @@ -26,7 +21,7 @@ @pytest.fixture(scope="function") def users_persistent_volume_claim( request: pytest.FixtureRequest, unprivileged_model_namespace: Namespace, unprivileged_client: DynamicClient -) -> Generator[PersistentVolumeClaim, None, None]: +) -> Generator[PersistentVolumeClaim]: with PersistentVolumeClaim( client=unprivileged_client, name=request.param["name"], @@ -40,7 +35,7 @@ def users_persistent_volume_claim( @pytest.fixture(scope="function") -def minimal_image(admin_client: DynamicClient) -> Generator[str, None, None]: +def 
minimal_image(admin_client: DynamicClient) -> Generator[str]: """Provides a full image name of a minimal workbench image.""" image_name = "jupyter-minimal-notebook" if py_config.get("distribution") == "upstream" else "s2i-minimal-notebook" image_tag = py_config.get("workbench_image_tag") @@ -111,7 +106,7 @@ def default_notebook( request: pytest.FixtureRequest, admin_client: DynamicClient, notebook_image: str, -) -> Generator[Notebook, None, None]: +) -> Generator[Notebook]: """Returns a new Notebook CR for a given namespace, name, and image""" namespace = request.param["namespace"] name = request.param["name"] diff --git a/tests/workbenches/notebook-controller/test_custom_images.py b/tests/workbenches/notebook-controller/test_custom_images.py index feabe5b8e..86bc6b384 100644 --- a/tests/workbenches/notebook-controller/test_custom_images.py +++ b/tests/workbenches/notebook-controller/test_custom_images.py @@ -5,16 +5,14 @@ from time import time import pytest - -from ocp_resources.pod import Pod -from ocp_resources.pod import ExecOnPodError from ocp_resources.namespace import Namespace from ocp_resources.notebook import Notebook from ocp_resources.persistent_volume_claim import PersistentVolumeClaim +from ocp_resources.pod import ExecOnPodError, Pod +from simple_logger.logger import get_logger from utilities.constants import Timeout from utilities.general import collect_pod_information -from simple_logger.logger import get_logger LOGGER = get_logger(name=__name__) @@ -282,10 +280,10 @@ class TestCustomImageValidation: ) def test_custom_image_package_verification( self, - unprivileged_model_namespace: Namespace, # noqa: ARG002 - users_persistent_volume_claim: PersistentVolumeClaim, # noqa: ARG002 + unprivileged_model_namespace: Namespace, + users_persistent_volume_claim: PersistentVolumeClaim, default_notebook: Notebook, - notebook_image: str, # noqa: ARG002 + notebook_image: str, notebook_pod: Pod, packages_to_verify: list[str], ): diff --git 
a/tests/workbenches/notebook-controller/test_spawning.py b/tests/workbenches/notebook-controller/test_spawning.py index 288eab1b3..8c4933741 100644 --- a/tests/workbenches/notebook-controller/test_spawning.py +++ b/tests/workbenches/notebook-controller/test_spawning.py @@ -1,11 +1,9 @@ import pytest - from kubernetes.dynamic.client import DynamicClient - -from ocp_resources.pod import Pod from ocp_resources.namespace import Namespace from ocp_resources.notebook import Notebook from ocp_resources.persistent_volume_claim import PersistentVolumeClaim +from ocp_resources.pod import Pod class TestNotebook: diff --git a/utilities/certificates_utils.py b/utilities/certificates_utils.py index 24cf8b64b..9b9c7e14b 100644 --- a/utilities/certificates_utils.py +++ b/utilities/certificates_utils.py @@ -3,19 +3,18 @@ from functools import cache from kubernetes.dynamic import DynamicClient -from ocp_resources.secret import Secret from ocp_resources.config_map import ConfigMap +from ocp_resources.secret import Secret from pytest_testconfig import config as py_config from simple_logger.logger import get_logger from utilities.constants import ( ISTIO_CA_BUNDLE_FILENAME, - KServeDeploymentType, OPENSHIFT_CA_BUNDLE_FILENAME, + KServeDeploymentType, ) from utilities.infra import is_managed_cluster, is_self_managed_operator - LOGGER = get_logger(name=__name__) diff --git a/utilities/constants.py b/utilities/constants.py index 3116ff4f0..4681c81b4 100644 --- a/utilities/constants.py +++ b/utilities/constants.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict +from typing import Any from ocp_resources.resource import Resource @@ -105,7 +105,7 @@ class Protocols: GRPC: str = "grpc" REST: str = "rest" TCP: str = "TCP" - TCP_PROTOCOLS: set[str] = {HTTP, HTTPS} + TCP_PROTOCOLS: set[str] = {HTTP, HTTPS} # noqa: RUF012 ALL_SUPPORTED_PROTOCOLS: set[str] = TCP_PROTOCOLS.union({GRPC}) @@ -131,7 +131,7 @@ class AcceleratorType: GAUDI: str = "gaudi" SPYRE: str = 
"spyre" CPU_x86: str = "cpu_x86" - SUPPORTED_LISTS: list[str] = [NVIDIA, AMD, GAUDI, SPYRE, CPU_x86] + SUPPORTED_LISTS: list[str] = [NVIDIA, AMD, GAUDI, SPYRE, CPU_x86] # noqa: RUF012 class ApiGroups: @@ -185,7 +185,7 @@ class ConditionType: MODEL_MESH_SERVING_READY: str = "ModelMeshServingReady" LLAMA_STACK_OPERATOR_READY: str = "LlamaStackOperatorReady" - COMPONENT_MAPPING: dict[str, str] = { + COMPONENT_MAPPING: dict[str, str] = { # noqa: RUF012 MODELMESHSERVING: ConditionType.MODEL_MESH_SERVING_READY, KSERVE: ConditionType.KSERVE_READY, MODELREGISTRY: ConditionType.MODEL_REGISTRY_READY, @@ -282,7 +282,7 @@ class Containers: class RunTimeConfigs: - ONNX_OPSET13_RUNTIME_CONFIG: dict[str, Any] = { + ONNX_OPSET13_RUNTIME_CONFIG: dict[str, Any] = { # noqa: RUF012 "runtime-name": ModelInferenceRuntime.ONNX_RUNTIME, "model-format": {ModelFormat.ONNX: ModelVersion.OPSET13}, } @@ -291,7 +291,6 @@ class RunTimeConfigs: class ModelCarImage: MNIST_8_1: str = ( "oci://quay.io/mwaykole/test@sha256:cb7d25c43e52c755e85f5b59199346f30e03b7112ef38b74ed4597aec8748743" - # noqa: E501 ) GRANITE_8B_CODE_INSTRUCT: str = "oci://registry.redhat.io/rhelai1/modelcar-granite-8b-code-instruct:1.4" @@ -324,7 +323,7 @@ class Metadata: class PodConfig: REGISTRY_IMAGE: str = "ghcr.io/project-zot/zot:v2.1.8" - REGISTRY_BASE_CONFIG: dict[str, Any] = { + REGISTRY_BASE_CONFIG: dict[str, Any] = { # noqa: RUF012 "args": None, "labels": { "maistra.io/expose-route": "true", @@ -358,9 +357,8 @@ class Buckets: class PodConfig: KSERVE_MINIO_IMAGE: str = ( "quay.io/jooholee/model-minio@sha256:b9554be19a223830cf792d5de984ccc57fc140b954949f5ffc6560fab977ca7a" - # noqa: E501 ) - MINIO_BASE_LABELS_ANNOTATIONS: dict[str, Any] = { + MINIO_BASE_LABELS_ANNOTATIONS: dict[str, Any] = { # noqa: RUF012 "labels": { "maistra.io/expose-route": "true", }, @@ -369,32 +367,32 @@ class PodConfig: }, } - MINIO_BASE_CONFIG: dict[str, Any] = { + MINIO_BASE_CONFIG: dict[str, Any] = { # noqa: RUF012 "args": ["server", "/data1"], 
**MINIO_BASE_LABELS_ANNOTATIONS, } - MODEL_MESH_MINIO_CONFIG: dict[str, Any] = { + MODEL_MESH_MINIO_CONFIG: dict[str, Any] = { # noqa: RUF012 "image": "quay.io/trustyai_testing/modelmesh-minio-examples@sha256:d2ccbe92abf9aa5085b594b2cae6c65de2bf06306c30ff5207956eb949bb49da", # noqa: E501 **MINIO_BASE_CONFIG, } - QWEN_MINIO_CONFIG: dict[str, Any] = { + QWEN_MINIO_CONFIG: dict[str, Any] = { # noqa: RUF012 "image": "quay.io/trustyai_testing/hf-llm-minio@sha256:2404a37d578f2a9c7adb3971e26a7438fedbe7e2e59814f396bfa47cd5fe93bb", # noqa: E501 **MINIO_BASE_CONFIG, } - QWEN_HAP_BPIV2_MINIO_CONFIG: dict[str, Any] = { + QWEN_HAP_BPIV2_MINIO_CONFIG: dict[str, Any] = { # noqa: RUF012 "image": "quay.io/trustyai_testing/qwen2.5-0.5b-instruct-hap-bpiv2-minio@sha256:eac1ca56f62606e887c80b4a358b3061c8d67f0b071c367c0aa12163967d5b2b", # noqa: E501 **MINIO_BASE_CONFIG, } - KSERVE_MINIO_CONFIG: dict[str, Any] = { + KSERVE_MINIO_CONFIG: dict[str, Any] = { # noqa: RUF012 "image": KSERVE_MINIO_IMAGE, **MINIO_BASE_CONFIG, } - MODEL_REGISTRY_MINIO_CONFIG: dict[str, Any] = { + MODEL_REGISTRY_MINIO_CONFIG: dict[str, Any] = { # noqa: RUF012 "image": "quay.io/minio/minio@sha256:14cea493d9a34af32f524e538b8346cf79f3321eff8e708c1e2960462bd8936e", "args": ["server", "/data"], **MINIO_BASE_LABELS_ANNOTATIONS, @@ -433,7 +431,7 @@ class RunTimeConfig: MARIADB: str = "mariadb" MODEL_REGISTRY_CUSTOM_NAMESPACE: str = "model-registry-custom-ns" THANOS_QUERIER_ADDRESS = "https://thanos-querier.openshift-monitoring.svc:9092" -BUILTIN_DETECTOR_CONFIG: Dict[str, Any] = { +BUILTIN_DETECTOR_CONFIG: dict[str, Any] = { "regex": { "type": "text_contents", "service": { @@ -473,7 +471,7 @@ class OpenVINO: MODEL_SERVER: str = "quay.io/opendatahub/openvino_model_server@sha256:564664371d3a21b9e732a5c1b4b40bacad714a5144c0a9aaf675baec4a04b148" # noqa: E501 -CHAT_GENERATION_CONFIG: Dict[str, Any] = { +CHAT_GENERATION_CONFIG: dict[str, Any] = { "service": { "hostname": f"{QWEN_MODEL_NAME}-predictor", "port": 80, @@ -494,12 
+492,10 @@ class LLMdInferenceSimConfig: isvc_name: str = f"{LLM_D_INFERENCE_SIM_NAME}-isvc" -LLM_D_CHAT_GENERATION_CONFIG: Dict[str, Any] = { +LLM_D_CHAT_GENERATION_CONFIG: dict[str, Any] = { "service": {"hostname": f"{LLMdInferenceSimConfig.isvc_name}-predictor", "port": 80} } class PodNotFound(Exception): """Pod not found""" - - pass diff --git a/utilities/data_science_cluster_utils.py b/utilities/data_science_cluster_utils.py index 38911eb40..3049e9149 100644 --- a/utilities/data_science_cluster_utils.py +++ b/utilities/data_science_cluster_utils.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator +from typing import Any from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.resource import ResourceEditor diff --git a/utilities/database.py b/utilities/database.py index 705984be4..aaaef83fc 100644 --- a/utilities/database.py +++ b/utilities/database.py @@ -2,8 +2,8 @@ import os from sqlalchemy import Integer, String, create_engine -from sqlalchemy.orm import Mapped, Session, mapped_column -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column + from utilities.must_gather_collector import get_base_dir LOGGER = logging.getLogger(__name__) diff --git a/utilities/exceptions.py b/utilities/exceptions.py index f9c2d6f99..f44be948b 100644 --- a/utilities/exceptions.py +++ b/utilities/exceptions.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - from ocp_resources.service import Service @@ -28,16 +26,16 @@ class InferenceResponseError(Exception): class InvalidStorageArgumentError(Exception): def __init__( self, - storage_uri: Optional[str], - storage_key: Optional[str], - storage_path: Optional[str], + storage_uri: str | None, + storage_key: str | None, + storage_path: str | None, ): self.storage_uri = storage_uri self.storage_key = storage_key self.storage_path = 
storage_path def __str__(self) -> str: - msg = f""" + return f""" You've passed the following parameters: "storage_uri": {self.storage_uri} "storage_key": {self.storage_key} @@ -45,7 +43,6 @@ def __str__(self) -> str: In order to create a valid ISVC you need to specify either a storage_uri value or both a storage key and a storage path. """ - return msg class MetricValidationError(Exception): @@ -99,8 +96,6 @@ def __str__(self) -> str: class InvalidArgumentsError(Exception): """Raised when mutually exclusive or invalid argument combinations are passed.""" - pass - class ResourceNotReadyError(Exception): pass @@ -125,14 +120,10 @@ class UnexpectedResourceCountError(Exception): class ResourceValueMismatch(Exception): """Resource value mismatch""" - pass - class MissingParameter(Exception): """Raised required argument is not passed.""" - pass - class ExceptionUserLogin(Exception): pass diff --git a/utilities/general.py b/utilities/general.py index 8b0b9bbf6..1bfcb40d0 100644 --- a/utilities/general.py +++ b/utilities/general.py @@ -1,23 +1,22 @@ import base64 +import os import re -from typing import List, Tuple, Any import uuid -import os +from typing import Any from kubernetes.dynamic import DynamicClient -from kubernetes.dynamic.exceptions import ResourceNotFoundError, NotFoundError +from kubernetes.dynamic.exceptions import NotFoundError, ResourceNotFoundError +from ocp_resources.deployment import Deployment from ocp_resources.inference_graph import InferenceGraph from ocp_resources.inference_service import InferenceService from ocp_resources.pod import Pod +from ocp_resources.resource import Resource from simple_logger.logger import get_logger +from timeout_sampler import TimeoutExpiredError, TimeoutSampler, retry import utilities.infra -from utilities.constants import Annotations, KServeDeploymentType, MODELMESH_SERVING -from utilities.exceptions import UnexpectedResourceCountError, ResourceValueMismatch -from ocp_resources.resource import Resource -from 
timeout_sampler import retry -from timeout_sampler import TimeoutExpiredError, TimeoutSampler -from ocp_resources.deployment import Deployment +from utilities.constants import MODELMESH_SERVING, Annotations, KServeDeploymentType +from utilities.exceptions import ResourceValueMismatch, UnexpectedResourceCountError # Constants for image validation SHA256_DIGEST_PATTERN = r"@sha256:[a-f0-9]{64}$" @@ -190,7 +189,7 @@ def create_isvc_label_selector_str(isvc: InferenceService, resource_type: str, r raise ValueError(f"Unknown deployment mode {deployment_mode}") -def get_pod_images(pod: Pod) -> List[str]: +def get_pod_images(pod: Pod) -> list[str]: """Get all container images from a pod. Args: @@ -205,7 +204,7 @@ def get_pod_images(pod: Pod) -> List[str]: return containers -def validate_image_format(image: str) -> Tuple[bool, str]: +def validate_image_format(image: str) -> tuple[bool, str]: """Validate image format according to requirements. Args: @@ -511,6 +510,6 @@ def collect_pod_information(pod: Pod) -> None: for container in containers: file_path = os.path.join(base_dir_name, f"{pod.name}_{container}.log") with open(file_path, "w") as fd: - fd.write(pod.log(**{"container": container})) - except Exception: + fd.write(pod.log(container=container)) + except Exception: # noqa: BLE001 LOGGER.warning(f"For pod: {pod.name} information gathering failed.") diff --git a/utilities/guardrails.py b/utilities/guardrails.py index 13204f34d..93c086a44 100644 --- a/utilities/guardrails.py +++ b/utilities/guardrails.py @@ -1,11 +1,10 @@ -import requests import http -from typing import Dict +import requests from timeout_sampler import retry -def get_auth_headers(token: str) -> Dict[str, str]: +def get_auth_headers(token: str) -> dict[str, str]: return {"Content-Type": "application/json", "Authorization": f"Bearer {token}"} diff --git a/utilities/inference_utils.py b/utilities/inference_utils.py index fa8f20471..61feddd8f 100644 --- a/utilities/inference_utils.py +++ 
b/utilities/inference_utils.py @@ -1,13 +1,15 @@ import json import re import shlex +from collections.abc import Generator from contextlib import contextmanager from http import HTTPStatus from json import JSONDecodeError from string import Template -from typing import Any, Optional, Generator +from typing import Any from urllib.parse import urlparse +import portforward from kubernetes.dynamic import DynamicClient from ocp_resources.inference_graph import InferenceGraph from ocp_resources.inference_service import InferenceService @@ -15,29 +17,28 @@ from ocp_resources.service import Service from pyhelper_utils.shell import run_command from simple_logger.logger import get_logger -from timeout_sampler import TimeoutWatch, retry, TimeoutSampler +from timeout_sampler import TimeoutSampler, TimeoutWatch, retry -from utilities.exceptions import InferenceResponseError, InvalidStorageArgumentError -from utilities.infra import ( - get_inference_serving_runtime, - get_model_route, - get_pods_by_isvc_label, - get_services_by_isvc_label, - wait_for_inference_deployment_replicas, - verify_no_failed_pods, - get_pods_by_ig_label, -) from utilities.certificates_utils import get_ca_bundle from utilities.constants import ( + Annotations, + HTTPRequest, KServeDeploymentType, Labels, ModelName, Protocols, - HTTPRequest, - Annotations, Timeout, ) -import portforward +from utilities.exceptions import InferenceResponseError, InvalidStorageArgumentError +from utilities.infra import ( + get_inference_serving_runtime, + get_model_route, + get_pods_by_ig_label, + get_pods_by_isvc_label, + get_services_by_isvc_label, + verify_no_failed_pods, + wait_for_inference_deployment_replicas, +) LOGGER = get_logger(name=__name__) @@ -80,7 +81,7 @@ def get_deployment_type(self) -> str: return KServeDeploymentType.SERVERLESS else: - raise ValueError(f"Unknown inference service type: {self.inference_service.name}") + raise TypeError(f"Unknown inference service type: {self.inference_service.name}") def 
get_inference_url(self) -> str: """ @@ -125,15 +126,11 @@ def is_service_exposed(self) -> bool: return labels and labels.get(Labels.Kserve.NETWORKING_KSERVE_IO) == Labels.Kserve.EXPOSED if self.deployment_mode == KServeDeploymentType.SERVERLESS: - if labels and labels.get(Labels.Kserve.NETWORKING_KNATIVE_IO) == "cluster-local": - return False - else: - return True + return not bool(labels and labels.get(Labels.Kserve.NETWORKING_KNATIVE_IO) == "cluster-local") - if self.deployment_mode == KServeDeploymentType.MODEL_MESH: - if self.runtime: - _annotations = self.runtime.instance.metadata.annotations - return _annotations and _annotations.get("enable-route") == "true" + if self.deployment_mode == KServeDeploymentType.MODEL_MESH and self.runtime: + _annotations = self.runtime.instance.metadata.annotations + return _annotations and _annotations.get("enable-route") == "true" return False @@ -187,7 +184,7 @@ def get_runtime_config(self) -> dict[str, Any]: ) @property - def inference_response_text_key_name(self) -> Optional[str]: + def inference_response_text_key_name(self) -> str | None: """ Get inference response text key name from runtime config @@ -211,7 +208,7 @@ def inference_response_key_name(self) -> str: def get_inference_body( self, model_name: str, - inference_input: Optional[Any] = None, + inference_input: Any | None = None, use_default_query: bool = False, ) -> str: """ @@ -278,10 +275,10 @@ def get_inference_endpoint_url(self) -> str: def generate_command( self, model_name: str, - inference_input: Optional[Any] = None, + inference_input: Any | None = None, use_default_query: bool = False, insecure: bool = False, - token: Optional[str] = None, + token: str | None = None, ) -> str: """ Generate command to run inference @@ -347,10 +344,10 @@ def generate_command( def run_inference_flow( self, model_name: str, - inference_input: Optional[str] = None, + inference_input: str | None = None, use_default_query: bool = False, insecure: bool = False, - token: 
Optional[str] = None, + token: str | None = None, ) -> dict[str, Any]: """ Run inference full flow - generate command and run it @@ -407,10 +404,10 @@ def run_inference_flow( def run_inference( self, model_name: str, - inference_input: Optional[str] = None, + inference_input: str | None = None, use_default_query: bool = False, insecure: bool = False, - token: Optional[str] = None, + token: str | None = None, ) -> str: """ Run inference command @@ -532,10 +529,7 @@ def get_target_port(self, svc: Service) -> int: self.deployment_mode == KServeDeploymentType.MODEL_MESH and port.protocol.lower() == svc_protocol.lower() and port.name == self.protocol - ): - return svc_port - - elif ( + ) or ( self.deployment_mode in ( KServeDeploymentType.RAW_DEPLOYMENT, @@ -678,12 +672,9 @@ def create_isvc( if deployment_mode: _annotations = {Annotations.KserveIo.DEPLOYMENT_MODE: deployment_mode} - if enable_auth: - # model mesh auth is set in ServingRuntime - if deployment_mode == KServeDeploymentType.SERVERLESS: - _annotations[Annotations.KserveAuth.SECURITY] = "true" - elif deployment_mode == KServeDeploymentType.RAW_DEPLOYMENT: - _annotations[Annotations.KserveAuth.SECURITY] = "true" + # model mesh auth is set in ServingRuntime + if enable_auth and deployment_mode in {KServeDeploymentType.SERVERLESS, KServeDeploymentType.RAW_DEPLOYMENT}: + _annotations[Annotations.KserveAuth.SECURITY] = "true" # default to True if deployment_mode is Serverless (default behavior of Serverless) if was not provided by the user # model mesh external route is set in ServingRuntime @@ -812,9 +803,9 @@ def _is_model_loaded() -> bool: def _check_storage_arguments( - storage_uri: Optional[str], - storage_key: Optional[str], - storage_path: Optional[str], + storage_uri: str | None, + storage_key: str | None, + storage_path: str | None, ) -> None: """ Check if storage_uri, storage_key and storage_path are valid. 
diff --git a/utilities/infra.py b/utilities/infra.py index 56612f231..3afedc144 100644 --- a/utilities/infra.py +++ b/utilities/infra.py @@ -1,16 +1,17 @@ import json import os +import platform import re import shlex import stat import tarfile import zipfile +from collections.abc import Callable, Generator from contextlib import contextmanager from functools import cache -from typing import Any, Generator, Optional, Set, Callable +from typing import Any import kubernetes -import platform import pytest import requests import urllib3 @@ -21,12 +22,11 @@ NotFoundError, ResourceNotFoundError, ) - from ocp_resources.authentication_config_openshift_io import Authentication from ocp_resources.catalog_source import CatalogSource from ocp_resources.cluster_service_version import ClusterServiceVersion -from ocp_resources.config_map import ConfigMap from ocp_resources.config_imageregistry_operator_openshift_io import Config +from ocp_resources.config_map import ConfigMap from ocp_resources.console_cli_download import ConsoleCLIDownload from ocp_resources.data_science_cluster import DataScienceCluster from ocp_resources.deployment import Deployment @@ -47,6 +47,8 @@ from ocp_resources.service import Service from ocp_resources.service_account import ServiceAccount from ocp_resources.serving_runtime import ServingRuntime +from ocp_resources.subscription import Subscription +from ocp_resources.utils.constants import DEFAULT_CLUSTER_RETRY_EXCEPTIONS from ocp_utilities.exceptions import NodeNotReadyError, NodeUnschedulableError from ocp_utilities.infra import ( assert_nodes_in_healthy_condition, @@ -56,15 +58,11 @@ from pytest_testconfig import config as py_config from semver import Version from simple_logger.logger import get_logger - -from ocp_resources.subscription import Subscription -from utilities.constants import ApiGroups, Labels, Timeout, RHOAI_OPERATOR_NAMESPACE -from utilities.constants import KServeDeploymentType -from utilities.constants import Annotations -from 
utilities.exceptions import ClusterLoginError, FailedPodsError, ResourceNotReadyError, UnexpectedResourceCountError from timeout_sampler import TimeoutExpiredError, TimeoutSampler, TimeoutWatch, retry + import utilities.general -from ocp_resources.utils.constants import DEFAULT_CLUSTER_RETRY_EXCEPTIONS +from utilities.constants import RHOAI_OPERATOR_NAMESPACE, Annotations, ApiGroups, KServeDeploymentType, Labels, Timeout +from utilities.exceptions import ClusterLoginError, FailedPodsError, ResourceNotReadyError, UnexpectedResourceCountError from utilities.general import generate_random_name LOGGER = get_logger(name=__name__) @@ -365,7 +363,7 @@ def create_isvc_view_role( client: DynamicClient, isvc: InferenceService, name: str, - resource_names: Optional[list[str]] = None, + resource_names: list[str] | None = None, teardown: bool = True, ) -> Generator[Role, Any, Any]: """ @@ -408,7 +406,7 @@ def create_inference_graph_view_role( client: DynamicClient, namespace: str, name: str, - resource_names: Optional[list[str]] = None, + resource_names: list[str] | None = None, teardown: bool = True, ) -> Generator[Role, Any, Any]: """ @@ -467,10 +465,7 @@ def login_with_user_password(api_address: str, user: str, password: str | None = if err and err.lower().startswith("error"): raise ClusterLoginError(user=user) - if re.search(r"Login successful|Logged into", out): - return True - - return False + return bool(re.search(r"Login successful|Logged into", out)) @cache @@ -481,14 +476,13 @@ def is_self_managed_operator(client: DynamicClient) -> bool: if py_config["distribution"] == "upstream": return True - if CatalogSource( - client=client, - name="addon-managed-odh-catalog", - namespace=py_config["applications_namespace"], - ).exists: - return False - - return True + return not bool( + CatalogSource( + client=client, + name="addon-managed-odh-catalog", + namespace=py_config["applications_namespace"], + ).exists + ) @cache @@ -505,10 +499,9 @@ def is_managed_cluster(client: 
DynamicClient) -> bool: platform_statuses = infra.instance.status.platformStatus for entry in platform_statuses.values(): - if isinstance(entry, kubernetes.dynamic.resource.ResourceField): - if tags := entry.resourceTags: - LOGGER.info(f"Infrastructure {infra.name} resource tags: {tags}") - return any([tag["value"] == "true" for tag in tags if tag["key"] == "red-hat-managed"]) + if isinstance(entry, kubernetes.dynamic.resource.ResourceField) and (tags := entry.resourceTags): + LOGGER.info(f"Infrastructure {infra.name} resource tags: {tags}") + return any(tag["value"] == "true" for tag in tags if tag["key"] == "red-hat-managed") return False @@ -823,7 +816,7 @@ def verify_no_failed_pods( raise FailedPodsError(pods=failed_pods) -def check_pod_status_in_time(pod: Pod, status: Set[str], duration: int = Timeout.TIMEOUT_2MIN, wait: int = 1) -> None: +def check_pod_status_in_time(pod: Pod, status: set[str], duration: int = Timeout.TIMEOUT_2MIN, wait: int = 1) -> None: """ Checks if a pod status is maintained for a given duration. If not, an AssertionError is raised. @@ -846,9 +839,8 @@ def check_pod_status_in_time(pod: Pod, status: Set[str], duration: int = Timeout try: for sample in sampler: - if sample: - if sample.status.phase not in status: - raise AssertionError(f"Pod status is not the expected: {pod.status}") + if sample and sample.status.phase not in status: + raise AssertionError(f"Pod status is not the expected: {pod.status}") except TimeoutExpiredError: LOGGER.info(f"Pod status is {pod.status} as expected") @@ -1024,7 +1016,7 @@ def get_rhods_subscription() -> Subscription | None: if subscriptions: for subscription in subscriptions: LOGGER.info(f"Checking subscription {subscription.name}") - if subscription.name.startswith(tuple(["rhods-operator", "rhoai-operator"])): + if subscription.name.startswith(("rhods-operator", "rhoai-operator")): return subscription LOGGER.warning("No RHOAI subscription found. 
Potentially ODH cluster") @@ -1108,7 +1100,7 @@ def verify_cluster_sanity( wait_for_dsc_status_ready(dsc_resource=dsc_resource) except (ResourceNotReadyError, NodeUnschedulableError, NodeNotReadyError) as ex: - error_msg = f"Cluster sanity check failed: {str(ex)}" + error_msg = f"Cluster sanity check failed: {ex!s}" # return_code set to 99 to not collide with https://docs.pytest.org/en/stable/reference/exit-codes.html return_code = 99 @@ -1172,9 +1164,9 @@ def download_oc_console_cli(admin_client: DynamicClient, tmpdir: LocalPath) -> s local_file_name = os.path.join(tmpdir, oc_console_cli_download_link.split("/")[-1]) with requests.get(oc_console_cli_download_link, verify=False, stream=True) as created_request: created_request.raise_for_status() + content_iterator = created_request.iter_content(chunk_size=8192) with open(local_file_name, "wb") as file_downloaded: - for chunk in created_request.iter_content(chunk_size=8192): - file_downloaded.write(chunk) + file_downloaded.writelines(content_iterator) LOGGER.info("Extract the downloaded archive.") extracted_filenames = [] if oc_console_cli_download_link.endswith(".zip"): @@ -1207,8 +1199,8 @@ def check_internal_image_registry_available(admin_client: DynamicClient) -> bool is_available = management_state == "managed" LOGGER.info(f"Image registry management state: {management_state}, available: {is_available}") - return is_available - except (ResourceNotFoundError, Exception) as e: + return is_available # noqa: TRY300 + except (ResourceNotFoundError, Exception) as e: # noqa: BLE001 LOGGER.warning(f"Failed to check image registry config: {e}") return False diff --git a/utilities/jira.py b/utilities/jira.py index ed7212c71..fe8aedab5 100644 --- a/utilities/jira.py +++ b/utilities/jira.py @@ -50,10 +50,11 @@ def is_jira_open(jira_id: str, admin_client: DynamicClient) -> bool: else: # Check if the operator version in ClusterServiceVersion is greater than the jira fix version - jira_fix_versions: list[Version] = [] - for 
fix_version in jira_fields.fixVersions: - if _fix_version := re.search(r"\d+\.\d+(?:\.\d+)?", fix_version.name): - jira_fix_versions.append(Version(_fix_version.group())) + jira_fix_versions: list[Version] = [ + Version(_fix_version.group()) + for fix_version in jira_fields.fixVersions + if (_fix_version := re.search(r"\d+\.\d+(?:\.\d+)?", fix_version.name)) + ] if not jira_fix_versions: raise ValueError(f"Jira {jira_id}: status is {jira_status} but does not have fix version(s)") @@ -68,7 +69,7 @@ def is_jira_open(jira_id: str, admin_client: DynamicClient) -> bool: raise MissingResourceError("Operator ClusterServiceVersion not found") csv_version = Version(version=operator_version) - if all([csv_version < fix_version for fix_version in jira_fix_versions]): + if all(csv_version < fix_version for fix_version in jira_fix_versions): LOGGER.info( f"Bug is open: Jira {jira_id}: status is {jira_status}, " f"fix versions {jira_fix_versions}, operator version is {operator_version}" diff --git a/utilities/kueue_utils.py b/utilities/kueue_utils.py index 40008c3a8..668b68e36 100644 --- a/utilities/kueue_utils.py +++ b/utilities/kueue_utils.py @@ -1,10 +1,13 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import Optional, Dict, Any, List, Generator +from typing import Any + from kubernetes.dynamic import DynamicClient -from ocp_resources.resource import NamespacedResource, Resource, MissingRequiredArgumentError from ocp_resources.pod import Pod +from ocp_resources.resource import MissingRequiredArgumentError, NamespacedResource, Resource from simple_logger.logger import get_logger from timeout_sampler import retry + from utilities.constants import Timeout LOGGER = get_logger(name=__name__) @@ -61,8 +64,8 @@ class ClusterQueue(Resource): def __init__( self, - namespace_selector: Dict[str, Any] | None = None, - resource_groups: List[Dict[str, Any]] | None = None, + namespace_selector: dict[str, Any] | None = None, + resource_groups: 
list[dict[str, Any]] | None = None, **kwargs: Any, ): """ @@ -134,8 +137,8 @@ def create_local_queue( def create_cluster_queue( client: DynamicClient, name: str, - resource_groups: List[Dict[str, Any]], - namespace_selector: Optional[Dict[str, Any]] = None, + resource_groups: list[dict[str, Any]], + namespace_selector: dict[str, Any] | None = None, teardown: bool = True, ) -> Generator[ClusterQueue, Any, Any]: """ @@ -166,14 +169,11 @@ def check_gated_pods_and_running_pods( for pod in pods: if pod.instance.status.phase == "Running": running_pods += 1 - elif pod.instance.status.phase == "Pending": - if all( - condition.type == "PodScheduled" - and condition.status == "False" - and condition.reason == "SchedulingGated" - for condition in pod.instance.status.conditions - ): - gated_pods += 1 + elif pod.instance.status.phase == "Pending" and all( + condition.type == "PodScheduled" and condition.status == "False" and condition.reason == "SchedulingGated" + for condition in pod.instance.status.conditions + ): + gated_pods += 1 return running_pods, gated_pods diff --git a/utilities/llmd_constants.py b/utilities/llmd_constants.py index 5d8f878d5..2707e7233 100644 --- a/utilities/llmd_constants.py +++ b/utilities/llmd_constants.py @@ -1,10 +1,14 @@ """LLMD-specific constants that extend the shared constants.""" from utilities.constants import ( - ModelName, ContainerImages as SharedContainerImages, - ModelStorage as SharedModelStorage, +) +from utilities.constants import ( Labels, + ModelName, +) +from utilities.constants import ( + ModelStorage as SharedModelStorage, ) diff --git a/utilities/llmd_utils.py b/utilities/llmd_utils.py index b7a7f71ae..fc73e75ea 100644 --- a/utilities/llmd_utils.py +++ b/utilities/llmd_utils.py @@ -3,16 +3,17 @@ import json import re import shlex +from collections.abc import Generator from contextlib import contextmanager from string import Template -from typing import Any, Dict, Generator, Optional +from typing import Any from 
kubernetes.dynamic import DynamicClient from ocp_resources.gateway import Gateway from ocp_resources.llm_inference_service import LLMInferenceService from pyhelper_utils.shell import run_command from simple_logger.logger import get_logger -from timeout_sampler import retry, TimeoutWatch +from timeout_sampler import TimeoutWatch, retry from utilities.certificates_utils import get_ca_bundle from utilities.constants import HTTPRequest, Timeout @@ -20,9 +21,9 @@ from utilities.infra import get_services_by_isvc_label from utilities.llmd_constants import ( ContainerImages, + KServeGateway, LLMDGateway, LLMEndpoint, - KServeGateway, ) LOGGER = get_logger(name=__name__) @@ -34,12 +35,12 @@ def create_llmd_gateway( name: str = LLMDGateway.DEFAULT_NAME, namespace: str = LLMDGateway.DEFAULT_NAMESPACE, gateway_class_name: str = LLMDGateway.DEFAULT_CLASS, - listeners: Optional[list[Dict[str, Any]]] = None, - infrastructure: Optional[Dict[str, Any]] = None, + listeners: list[dict[str, Any]] | None = None, + infrastructure: dict[str, Any] | None = None, wait_for_condition: bool = True, timeout: int = 300, teardown: bool = True, -) -> Generator[Gateway, None, None]: +) -> Generator[Gateway]: """ Context manager to create and manage a Gateway resource using ocp_resources. 
@@ -79,7 +80,7 @@ def create_llmd_gateway( if existing_gateway.exists: LOGGER.info(f"Cleaning up existing Gateway {name} in namespace {namespace}") existing_gateway.delete(wait=True, timeout=Timeout.TIMEOUT_2MIN) - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.debug(f"No existing Gateway to clean up: {e}") gateway_body = { "apiVersion": f"{KServeGateway.API_GROUP}/v1", @@ -113,7 +114,7 @@ def create_llmd_gateway( yield gateway -def _get_llm_config_references(enable_prefill_decode: bool = False, disable_scheduler: bool = False) -> Dict[str, str]: +def _get_llm_config_references(enable_prefill_decode: bool = False, disable_scheduler: bool = False) -> dict[str, str]: """ Get LLMInferenceServiceConfig references based on configuration type. @@ -151,31 +152,31 @@ def create_llmisvc( client: DynamicClient, name: str, namespace: str, - storage_uri: Optional[str] = None, - storage_key: Optional[str] = None, - storage_path: Optional[str] = None, + storage_uri: str | None = None, + storage_key: str | None = None, + storage_path: str | None = None, replicas: int = 1, wait: bool = True, enable_auth: bool = False, - router_config: Optional[Dict[str, Any]] = None, - container_image: Optional[str] = None, - container_resources: Optional[Dict[str, Any]] = None, - container_env: Optional[list[Dict[str, str]]] = None, - liveness_probe: Optional[Dict[str, Any]] = None, - readiness_probe: Optional[Dict[str, Any]] = None, - image_pull_secrets: Optional[list[str]] = None, - service_account: Optional[str] = None, - volumes: Optional[list[Dict[str, Any]]] = None, - volume_mounts: Optional[list[Dict[str, Any]]] = None, - annotations: Optional[Dict[str, str]] = None, - labels: Optional[Dict[str, str]] = None, + router_config: dict[str, Any] | None = None, + container_image: str | None = None, + container_resources: dict[str, Any] | None = None, + container_env: list[dict[str, str]] | None = None, + liveness_probe: dict[str, Any] | None = None, + readiness_probe: 
dict[str, Any] | None = None, + image_pull_secrets: list[str] | None = None, + service_account: str | None = None, + volumes: list[dict[str, Any]] | None = None, + volume_mounts: list[dict[str, Any]] | None = None, + annotations: dict[str, str] | None = None, + labels: dict[str, str] | None = None, timeout: int = Timeout.TIMEOUT_15MIN, teardown: bool = True, - model_name: Optional[str] = None, - prefill_config: Optional[Dict[str, Any]] = None, + model_name: str | None = None, + prefill_config: dict[str, Any] | None = None, disable_scheduler: bool = False, enable_prefill_decode: bool = False, -) -> Generator[LLMInferenceService, Any, None]: +) -> Generator[LLMInferenceService, Any]: """ Create LLMInferenceService object following the pattern of create_isvc. @@ -255,7 +256,7 @@ def create_llmisvc( {"name": "VLLM_ADDITIONAL_ARGS", "value": "--ssl-ciphers ECDHE+AESGCM:DHE+AESGCM"}, {"name": "VLLM_CPU_KVCACHE_SPACE", "value": "4"}, ]) - template_config: Dict[str, Any] = {"configRef": config_refs["template_ref"]} + template_config: dict[str, Any] = {"configRef": config_refs["template_ref"]} if any([ container_image, @@ -266,7 +267,7 @@ def create_llmisvc( volumes, service_account, ]): - main_container: Dict[str, Any] = {"name": "main"} + main_container: dict[str, Any] = {"name": "main"} if container_image: main_container["image"] = container_image @@ -390,7 +391,7 @@ def get_llm_inference_url(llm_service: LLMInferenceService) -> str: internal_url = f"http://{services[0].name}.{llm_service.namespace}.svc.cluster.local" LOGGER.debug(f"Using service discovery URL for {llm_service.name}: {internal_url}") return internal_url - except Exception as e: + except Exception as e: # noqa: BLE001 LOGGER.warning(f"Could not get service for LLMInferenceService {llm_service.name}: {e}") fallback_url = f"http://{llm_service.name}.{llm_service.namespace}.svc.cluster.local" LOGGER.debug(f"Using fallback URL for {llm_service.name}: {fallback_url}") @@ -399,16 +400,16 @@ def 
get_llm_inference_url(llm_service: LLMInferenceService) -> str: def verify_inference_response_llmd( llm_service: LLMInferenceService, - inference_config: Dict[str, Any], + inference_config: dict[str, Any], inference_type: str, protocol: str, - model_name: Optional[str] = None, - inference_input: Optional[Any] = None, + model_name: str | None = None, + inference_input: Any | None = None, use_default_query: bool = False, - expected_response_text: Optional[str] = None, + expected_response_text: str | None = None, insecure: bool = False, - token: Optional[str] = None, - authorized_user: Optional[bool] = None, + token: str | None = None, + authorized_user: bool | None = None, ) -> None: """ Verify the LLM inference response following the pattern of verify_inference_response. @@ -472,7 +473,7 @@ class LLMUserInference: def __init__( self, llm_service: LLMInferenceService, - inference_config: Dict[str, Any], + inference_config: dict[str, Any], inference_type: str, protocol: str, ) -> None: @@ -482,7 +483,7 @@ def __init__( self.protocol = protocol self.runtime_config = self.get_runtime_config() - def get_runtime_config(self) -> Dict[str, Any]: + def get_runtime_config(self) -> dict[str, Any]: """Get runtime config from inference config based on inference type and protocol.""" if inference_type_config := self.inference_config.get(self.inference_type): protocol = "http" if self.protocol.lower() in ["http", "https"] else self.protocol @@ -494,7 +495,7 @@ def get_runtime_config(self) -> Dict[str, Any]: raise ValueError(f"Inference type {self.inference_type} not supported in config") @property - def inference_response_text_key_name(self) -> Optional[str]: + def inference_response_text_key_name(self) -> str | None: """Get inference response text key name from runtime config.""" return self.runtime_config.get("response_fields_map", {}).get("response_output") @@ -506,7 +507,7 @@ def inference_response_key_name(self) -> str: def get_inference_body( self, model_name: str, - 
inference_input: Optional[Any] = None, + inference_input: Any | None = None, use_default_query: bool = False, ) -> str: """Get inference body for LLM request.""" @@ -553,10 +554,10 @@ def get_inference_body( def generate_command( self, model_name: str, - inference_input: Optional[str] = None, + inference_input: str | None = None, use_default_query: bool = False, insecure: bool = False, - token: Optional[str] = None, + token: str | None = None, ) -> str: """Generate curl command string for LLM inference.""" base_url = get_llm_inference_url(llm_service=self.llm_service) @@ -587,7 +588,7 @@ def generate_command( cmd += f" --cacert {ca_bundle}" else: cmd += " --insecure" - except Exception: + except Exception: # noqa: BLE001 cmd += " --insecure" cmd += f" --max-time {LLMEndpoint.DEFAULT_TIMEOUT} {endpoint_url}" @@ -597,10 +598,10 @@ def generate_command( def run_inference( self, model_name: str, - inference_input: Optional[str] = None, + inference_input: str | None = None, use_default_query: bool = False, insecure: bool = False, - token: Optional[str] = None, + token: str | None = None, ) -> str: """Run inference command and return raw output.""" cmd = self.generate_command( @@ -619,11 +620,11 @@ def run_inference( def run_inference_flow( self, model_name: str, - inference_input: Optional[str] = None, + inference_input: str | None = None, use_default_query: bool = False, insecure: bool = False, - token: Optional[str] = None, - ) -> Dict[str, Any]: + token: str | None = None, + ) -> dict[str, Any]: """Run LLM inference using the same high-level flow as inference_utils.""" out = self.run_inference( model_name=model_name, @@ -635,7 +636,7 @@ def run_inference_flow( return {"output": out} -def _validate_unauthorized_response(res: Dict[str, Any], token: Optional[str], inference: LLMUserInference) -> None: +def _validate_unauthorized_response(res: dict[str, Any], token: str | None, inference: LLMUserInference) -> None: """Validate response for unauthorized users.""" 
auth_header = "x-ext-auth-reason" @@ -657,11 +658,11 @@ def _validate_unauthorized_response(res: Dict[str, Any], token: Optional[str], i def _validate_authorized_response( - res: Dict[str, Any], + res: dict[str, Any], inference: LLMUserInference, - inference_config: Dict[str, Any], + inference_config: dict[str, Any], inference_type: str, - expected_response_text: Optional[str], + expected_response_text: str | None, use_default_query: bool, model_name: str, ) -> None: diff --git a/utilities/logger.py b/utilities/logger.py index 376be0d30..bbfbe98d4 100644 --- a/utilities/logger.py +++ b/utilities/logger.py @@ -1,8 +1,8 @@ import logging +import multiprocessing import shutil from logging.handlers import QueueHandler, QueueListener, RotatingFileHandler -import multiprocessing -from typing import Optional, Any +from typing import Any from simple_logger.logger import DuplicateFilter, WrapperLogFormatter @@ -14,7 +14,7 @@ class RedactedString(str): Used to redact the representation of a sensitive string. 
""" - def __new__(cls, *, value: object) -> "RedactedString": + def __new__(cls, *, value: object) -> "RedactedString": # noqa: PYI034 return super().__new__(cls, value) def __repr__(self) -> str: @@ -123,7 +123,7 @@ def setup_logging( return log_listener -def separator(symbol_: str, val: Optional[str] = None) -> str: +def separator(symbol_: str, val: str | None = None) -> str: terminal_width = shutil.get_terminal_size(fallback=(120, 40))[0] if not val: return f"{symbol_ * terminal_width}" diff --git a/utilities/mariadb_utils.py b/utilities/mariadb_utils.py index d2ac892c6..20b3e63b2 100644 --- a/utilities/mariadb_utils.py +++ b/utilities/mariadb_utils.py @@ -1,16 +1,16 @@ from kubernetes.dynamic import DynamicClient -from timeout_sampler import TimeoutSampler - from ocp_resources.deployment import Deployment from ocp_resources.maria_db import MariaDB from ocp_resources.mariadb_operator import MariadbOperator from ocp_resources.pod import Pod +from timeout_sampler import TimeoutSampler + from utilities.constants import Timeout def wait_for_mariadb_pods(client: DynamicClient, mariadb: MariaDB, timeout: int = Timeout.TIMEOUT_5MIN) -> None: def _get_mariadb_pods() -> list[Pod]: - _pods = [ + return [ _pod for _pod in Pod.get( client=client, @@ -18,7 +18,6 @@ def _get_mariadb_pods() -> list[Pod]: label_selector=f"app.kubernetes.io/instance={mariadb.name}", ) ] - return _pods sampler = TimeoutSampler(wait_timeout=timeout, sleep=1, func=lambda: bool(_get_mariadb_pods())) diff --git a/utilities/minio.py b/utilities/minio.py index e8ccb8269..42dee551c 100644 --- a/utilities/minio.py +++ b/utilities/minio.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator +from typing import Any from kubernetes.dynamic import DynamicClient from ocp_resources.secret import Secret @@ -32,7 +33,7 @@ def create_minio_data_connection_secret( aws_access_key=MinIo.Credentials.ACCESS_KEY_VALUE, 
aws_secret_access_key=MinIo.Credentials.SECRET_KEY_VALUE, # pragma: allowlist secret aws_s3_bucket=aws_s3_bucket, - aws_s3_endpoint=f"{Protocols.HTTP}://{minio_service.instance.spec.clusterIP}:{str(MinIo.Metadata.DEFAULT_PORT)}", # noqa: E501 + aws_s3_endpoint=f"{Protocols.HTTP}://{minio_service.instance.spec.clusterIP}:{MinIo.Metadata.DEFAULT_PORT!s}", aws_s3_region="us-south", ) with Secret( diff --git a/utilities/monitoring.py b/utilities/monitoring.py index f3513ecd7..ee9b11a5b 100644 --- a/utilities/monitoring.py +++ b/utilities/monitoring.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from ocp_resources.prometheus import Prometheus from simple_logger.logger import get_logger diff --git a/utilities/must_gather_collector.py b/utilities/must_gather_collector.py index 17c075841..2ff311280 100644 --- a/utilities/must_gather_collector.py +++ b/utilities/must_gather_collector.py @@ -2,10 +2,11 @@ import shlex import shutil -from pytest_testconfig import config as py_config -from pytest import Item from pyhelper_utils.shell import run_command +from pytest import Item +from pytest_testconfig import config as py_config from simple_logger.logger import get_logger + from utilities.exceptions import InvalidArgumentsError from utilities.infra import get_rhods_operator_installed_csv diff --git a/utilities/operator_utils.py b/utilities/operator_utils.py index 53ffd4879..c3de1b3e7 100644 --- a/utilities/operator_utils.py +++ b/utilities/operator_utils.py @@ -1,12 +1,10 @@ from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError, ResourceNotUniqueError -from simple_logger.logger import get_logger - from ocp_resources.cluster_service_version import ClusterServiceVersion -from utilities.infra import get_product_version from pytest_testconfig import config as py_config +from simple_logger.logger import get_logger -from typing import List, Dict +from 
utilities.infra import get_product_version LOGGER = get_logger(name=__name__) @@ -29,7 +27,7 @@ def get_cluster_service_version(client: DynamicClient, prefix: str, namespace: s return matching_csvs[0] -def get_csv_related_images(admin_client: DynamicClient, csv_name: str | None = None) -> List[Dict[str, str]]: +def get_csv_related_images(admin_client: DynamicClient, csv_name: str | None = None) -> list[dict[str, str]]: """Get relatedImages from the CSV. Args: diff --git a/utilities/plugins/constant.py b/utilities/plugins/constant.py index 91c046575..323231e79 100644 --- a/utilities/plugins/constant.py +++ b/utilities/plugins/constant.py @@ -8,4 +8,4 @@ class OpenAIEnpoints: class RestHeader: - HEADERS: dict[str, str] = {"Content-Type": "application/json"} + HEADERS: dict[str, str] = {"Content-Type": "application/json"} # noqa: RUF012 diff --git a/utilities/plugins/openai_plugin.py b/utilities/plugins/openai_plugin.py index 7109dca70..f46548dc2 100644 --- a/utilities/plugins/openai_plugin.py +++ b/utilities/plugins/openai_plugin.py @@ -1,14 +1,16 @@ import json +from typing import Any + import requests import urllib3 +from simple_logger.logger import get_logger from tenacity import retry, stop_after_attempt, wait_exponential -from typing import Any, Optional from urllib3.exceptions import InsecureRequestWarning + from utilities.plugins.constant import OpenAIEnpoints, RestHeader -from simple_logger.logger import get_logger urllib3.disable_warnings(category=InsecureRequestWarning) -requests.packages +requests.packages # noqa: B018 LOGGER = get_logger(name=__name__) MAX_RETRIES = 5 @@ -40,7 +42,7 @@ def __init__(self, host: Any, streaming: bool = False, model_name: Any = None) - self.request_func = self.streaming_request_http if streaming else self.request_http @retry(stop=stop_after_attempt(MAX_RETRIES), wait=wait_exponential(min=1, max=6)) - def request_http(self, endpoint: str, query: dict[str, str], extra_param: Optional[dict[str, Any]] = None) -> Any: + def 
request_http(self, endpoint: str, query: dict[str, str], extra_param: dict[str, Any] | None = None) -> Any: """ Sends a HTTP POST request to the specified endpoint and processes the response. @@ -71,7 +73,7 @@ def request_http(self, endpoint: str, query: dict[str, str], extra_param: Option @retry(stop=stop_after_attempt(MAX_RETRIES), wait=wait_exponential(min=1, max=6)) def streaming_request_http( - self, endpoint: str, query: dict[str, Any], extra_param: Optional[dict[str, Any]] = None + self, endpoint: str, query: dict[str, Any], extra_param: dict[str, Any] | None = None ) -> str: """ Sends a streaming HTTP POST request to the specified endpoint and processes the streamed response. @@ -134,7 +136,7 @@ def get_request_http(host: str, endpoint: str) -> Any: keys_to_remove = ["created", "id"] if data: data = OpenAIClient._remove_keys(data, keys_to_remove) - return data + return data # noqa: TRY300 except (requests.exceptions.RequestException, json.JSONDecodeError): LOGGER.exception("Request error") @@ -180,7 +182,7 @@ def _construct_request_data( self, endpoint: str, query: dict[str, Any], - extra_param: Optional[dict[str, Any]] = None, + extra_param: dict[str, Any] | None = None, streaming: bool = False, ) -> dict[str, Any]: """ diff --git a/utilities/plugins/tgis_grpc/generation_pb2.py b/utilities/plugins/tgis_grpc/generation_pb2.py index aa31b4a78..25e4f094d 100644 --- a/utilities/plugins/tgis_grpc/generation_pb2.py +++ b/utilities/plugins/tgis_grpc/generation_pb2.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# NO CHECKED-IN PROTOBUF GENCODE # source: generation.proto diff --git a/utilities/plugins/tgis_grpc/generation_pb2_grpc.py b/utilities/plugins/tgis_grpc/generation_pb2_grpc.py index a4d9abce6..9f00bdb17 100644 --- a/utilities/plugins/tgis_grpc/generation_pb2_grpc.py +++ b/utilities/plugins/tgis_grpc/generation_pb2_grpc.py @@ -2,6 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc + from utilities.plugins.tgis_grpc import generation_pb2 as generation__pb2 GRPC_GENERATED_VERSION = "1.68.1" @@ -25,7 +26,7 @@ ) -class GenerationServiceStub(object): +class GenerationServiceStub: """Missing associated documentation comment in .proto file.""" def __init__(self, channel): # type: ignore @@ -60,7 +61,7 @@ def __init__(self, channel): # type: ignore ) -class GenerationServiceServicer(object): +class GenerationServiceServicer: """Missing associated documentation comment in .proto file.""" def Generate(self, request, context): # type: ignore @@ -117,7 +118,7 @@ def add_GenerationServiceServicer_to_server(servicer, server): # type: ignore # This class is part of an EXPERIMENTAL API. 
-class GenerationService(object): +class GenerationService: """Missing associated documentation comment in .proto file.""" @staticmethod diff --git a/utilities/plugins/tgis_grpc_plugin.py b/utilities/plugins/tgis_grpc_plugin.py index 18ec7ac97..e4f5e536f 100644 --- a/utilities/plugins/tgis_grpc_plugin.py +++ b/utilities/plugins/tgis_grpc_plugin.py @@ -1,11 +1,12 @@ -import grpc import socket import ssl import sys -from utilities.plugins.tgis_grpc import generation_pb2_grpc -from typing import Any, Optional +from typing import Any + +import grpc from simple_logger.logger import get_logger +from utilities.plugins.tgis_grpc import generation_pb2_grpc LOGGER = get_logger(name=__name__) @@ -28,9 +29,9 @@ def __init__(self, host: str, model_name: str, streaming: bool = False, use_tls: self.request_func = self.make_grpc_request_stream if streaming else self.make_grpc_request def _get_server_certificate(self, port: int) -> str: - if sys.version_info >= (3, 10): + if sys.version_info >= (3, 10): # noqa: UP036 return ssl.get_server_certificate((self.host, port)) - ssl.SSLContext + ssl.SSLContext # noqa: B018 context = ssl.SSLContext() with ( socket.create_connection((self.host, port)) as sock, @@ -39,7 +40,7 @@ def _get_server_certificate(self, port: int) -> str: cert_der = ssock.getpeercert(binary_form=True) return ssl.DER_cert_to_PEM_cert(cert_der) - def _channel_credentials(self) -> Optional[grpc.ChannelCredentials]: + def _channel_credentials(self) -> grpc.ChannelCredentials | None: if self.use_tls: cert = self._get_server_certificate(port=443).encode() return grpc.ssl_channel_credentials(root_certificates=cert) @@ -65,15 +66,15 @@ def make_grpc_request(self, query: dict[str, Any]) -> Any: try: response = stub.Generate(request=request) LOGGER.info(response) - response = response.responses[0] - return { - "input_tokens": response.input_token_count, - "stop_reason": response.stop_reason, - "output_text": response.text, - "output_tokens": response.generated_token_count, + 
res = response.responses[0] + return { # noqa: TRY300 + "input_tokens": res.input_token_count, + "stop_reason": res.stop_reason, + "output_text": res.text, + "output_tokens": res.generated_token_count, } except grpc.RpcError as err: - self._handle_grpc_error(err) + LOGGER.error("gRPC Error: %s", err.details()) def make_grpc_request_stream(self, query: dict[str, Any]) -> Any: channel = self._create_channel() @@ -103,7 +104,7 @@ def make_grpc_request_stream(self, query: dict[str, Any]) -> Any: "output_tokens": resp.generated_token_count, } except grpc.RpcError as err: - self._handle_grpc_error(err) + LOGGER.error("gRPC Error: %s", err.details()) def get_model_info(self) -> list[str]: # type: ignore channel = self._create_channel() @@ -112,11 +113,6 @@ def get_model_info(self) -> list[str]: # type: ignore request = generation_pb2_grpc.generation__pb2.ModelInfoRequest() # type: ignore LOGGER.info(request) try: - response = stub.ModelInfo(request=request) - return response + return stub.ModelInfo(request=request) except grpc.RpcError as err: - self._handle_grpc_error(err) - - def _handle_grpc_error(self, err: grpc.RpcError) -> None: - """Handle gRPC errors.""" - LOGGER.error("gRPC Error: %s", err.details()) + LOGGER.error("gRPC Error: %s", err.details()) diff --git a/utilities/resources/rate_limit_policy.py b/utilities/resources/rate_limit_policy.py index d1eb5b4e5..98bd8809a 100644 --- a/utilities/resources/rate_limit_policy.py +++ b/utilities/resources/rate_limit_policy.py @@ -2,8 +2,10 @@ from typing import Any -from ocp_resources.resource import NamespacedResource + from ocp_resources.exceptions import MissingRequiredArgumentError +from ocp_resources.resource import NamespacedResource + from utilities.constants import ApiGroups diff --git a/utilities/resources/token_rate_limit_policy.py b/utilities/resources/token_rate_limit_policy.py index 6a9365f5e..db55347f2 100644 --- a/utilities/resources/token_rate_limit_policy.py +++ 
b/utilities/resources/token_rate_limit_policy.py @@ -2,8 +2,10 @@ from typing import Any -from ocp_resources.resource import NamespacedResource + from ocp_resources.exceptions import MissingRequiredArgumentError +from ocp_resources.resource import NamespacedResource + from utilities.constants import ApiGroups diff --git a/utilities/serving_runtime.py b/utilities/serving_runtime.py index 77cf310c2..0c99873bd 100644 --- a/utilities/serving_runtime.py +++ b/utilities/serving_runtime.py @@ -1,12 +1,14 @@ import copy from typing import Any + from kubernetes.dynamic import DynamicClient from kubernetes.dynamic.exceptions import ResourceNotFoundError from ocp_resources.serving_runtime import ServingRuntime from ocp_resources.template import Template -from utilities.constants import ApiGroups, PortNames, Protocols, vLLM_CONFIG from pytest_testconfig import config as py_config +from utilities.constants import ApiGroups, PortNames, Protocols, vLLM_CONFIG + class ServingRuntimeFromTemplate(ServingRuntime): def __init__( @@ -181,7 +183,7 @@ def update_model_dict(self) -> dict[str, Any]: container["image"] = self.runtime_image # Support single entrypoint for TGIS and OpenAI - if self.support_tgis_open_ai_endpoints: + if self.support_tgis_open_ai_endpoints: # noqa: SIM102 if "vllm" in self.template_name and self.runtime_image is not None and self.deployment_type is not None: is_grpc = "grpc" in self.deployment_type.lower() is_raw = "raw" in self.deployment_type.lower() diff --git a/utilities/user_utils.py b/utilities/user_utils.py index 981f130c6..72df9ab01 100644 --- a/utilities/user_utils.py +++ b/utilities/user_utils.py @@ -1,7 +1,9 @@ +import base64 import logging import shlex import tempfile from dataclasses import dataclass +from pathlib import Path import requests from kubernetes.dynamic import DynamicClient @@ -10,9 +12,7 @@ from timeout_sampler import retry from utilities.exceptions import ExceptionUserLogin -from utilities.infra import login_with_user_password, 
get_cluster_authentication -import base64 -from pathlib import Path +from utilities.infra import get_cluster_authentication, login_with_user_password LOGGER = logging.getLogger(__name__) SLEEP_TIME = 5 @@ -61,12 +61,10 @@ def create_htpasswd_file(username: str, password: str) -> tuple[Path, str]: """ with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: temp_path = Path(temp_file.name).resolve() # Get absolute path - run_command( - command=shlex.split(f"htpasswd -c -b {str(temp_path.absolute())} {username} {password}"), check=True - ) + run_command(command=shlex.split(f"htpasswd -c -b {temp_path.absolute()!s} {username} {password}"), check=True) # Read the htpasswd file content and encode it - temp_file.seek(0) # noqa: FCN001 - TextIOWrapper.seek() doesn't accept keyword arguments + temp_file.seek(0) htpasswd_content = temp_file.read() htpasswd_b64 = base64.b64encode(htpasswd_content.encode()).decode() @@ -111,26 +109,23 @@ def get_oidc_tokens(admin_client: DynamicClient, username: str, password: str) - "scope": "openid", } - try: - LOGGER.info(f"Requesting token for user {username} in byoidc environment") - response = requests.post( - url=url, - headers=headers, - data=data, - allow_redirects=True, - timeout=30, - verify=True, # Set to False if you need to skip SSL verification - ) - response.raise_for_status() - json_response = response.json() - - # Validate that we got an access token - if "id_token" not in json_response or "refresh_token" not in json_response: - LOGGER.error("Warning: No id_token or refresh_token in response") - raise AssertionError(f"No id_token or refresh_token in response: {json_response}") - return json_response["id_token"], json_response["refresh_token"] - except Exception as e: - raise e + LOGGER.info(f"Requesting token for user {username} in byoidc environment") + response = requests.post( + url=url, + headers=headers, + data=data, + allow_redirects=True, + timeout=30, + verify=True, # Set to False if you need to skip 
SSL verification + ) + response.raise_for_status() + json_response = response.json() + + # Validate that we got an access token + if "id_token" not in json_response or "refresh_token" not in json_response: + LOGGER.error("Warning: No id_token or refresh_token in response") + raise AssertionError(f"No id_token or refresh_token in response: {json_response}") + return json_response["id_token"], json_response["refresh_token"] def get_byoidc_issuer_url(admin_client: DynamicClient) -> str: