diff --git a/.dockerignore b/.dockerignore index 5b34fe7..55a3fdb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,9 @@ node_modules .venv .ijwb -.idea \ No newline at end of file +.idea +temp +csaf_analysis +bazel-* +.git +container_data \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2454fd2..b754c1e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -35,8 +35,12 @@ jobs: bazel test //apollo/tests:test_csaf_processing --test_output=all bazel test //apollo/tests:test_api_keys --test_output=all bazel test //apollo/tests:test_auth --test_output=all + bazel test //apollo/tests:test_api_updateinfo --test_output=all bazel test //apollo/tests:test_validation --test_output=all bazel test //apollo/tests:test_admin_routes_supported_products --test_output=all + bazel test //apollo/tests:test_api_osv --test_output=all + bazel test //apollo/tests:test_database_service --test_output=all + bazel test //apollo/tests:test_rh_matcher_activities --test_output=all - name: Integration Tests run: ./build/scripts/test.bash diff --git a/apollo/db/__init__.py b/apollo/db/__init__.py index 2b98c4c..6f98127 100644 --- a/apollo/db/__init__.py +++ b/apollo/db/__init__.py @@ -201,6 +201,7 @@ class SupportedProductsRhMirror(Model): match_major_version = fields.IntField() match_minor_version = fields.IntField(null=True) match_arch = fields.CharField(max_length=255) + active = fields.BooleanField(default=True) rpm_repomds: fields.ReverseRelation["SupportedProductsRpmRepomd"] rpm_rh_overrides: fields.ReverseRelation["SupportedProductsRpmRhOverride"] @@ -303,7 +304,7 @@ class AdvisoryPackage(Model): module_stream = fields.TextField(null=True) module_version = fields.TextField(null=True) repo_name = fields.TextField() - package_name = fields.TextField() + _package_name = fields.TextField(source_field='package_name') product_name = fields.TextField() supported_products_rh_mirror = fields.ForeignKeyField( "models.SupportedProductsRhMirror", @@ -318,6 +319,30 @@ class Meta: table = "advisory_packages" unique_together = ("advisory_id", "nevra") + def __init__(self, **kwargs): + if 'package_name' in kwargs: + kwargs['_package_name'] = self._clean_package_name( + kwargs.pop('package_name') + ) + super().__init__(**kwargs) + + @property + def package_name(self): + return self._clean_package_name(self._package_name) + + @package_name.setter + def package_name(self, value): + self._package_name = self._clean_package_name(value) + + def _clean_package_name(self, value): + if isinstance(value, str) and value.startswith('module.'): + return value.replace('module.', '') + return value + + async def save(self, *args, **kwargs): + self._package_name = self._clean_package_name(self._package_name) + await super().save(*args, **kwargs) + class AdvisoryCVE(Model): id = fields.BigIntField(pk=True) diff --git a/apollo/migrations/20251104111759_add_mirror_active_field.sql b/apollo/migrations/20251104111759_add_mirror_active_field.sql new file mode 100644 index 0000000..6c12b1b --- /dev/null +++ b/apollo/migrations/20251104111759_add_mirror_active_field.sql @@ -0,0 +1,11 @@ +-- migrate:up +alter table supported_products_rh_mirrors +add column active boolean not null default true; + +create index supported_products_rh_mirrors_active_idx +on supported_products_rh_mirrors(active); + + +-- migrate:down +drop index if exists supported_products_rh_mirrors_active_idx; +alter table supported_products_rh_mirrors drop column active; diff --git 
a/apollo/publishing_tools/apollo_tree.py b/apollo/publishing_tools/apollo_tree.py index dc4c084..ae8410e 100644 --- a/apollo/publishing_tools/apollo_tree.py +++ b/apollo/publishing_tools/apollo_tree.py @@ -6,6 +6,7 @@ import logging import hashlib import gzip +import re from dataclasses import dataclass import time from urllib.parse import quote @@ -13,6 +14,8 @@ import aiohttp +from apollo.server.routes.api_updateinfo import PRODUCT_SLUG_MAP + logging.basicConfig(level=logging.INFO) logger = logging.getLogger("apollo_tree") @@ -21,6 +24,31 @@ "rpm": "http://linux.duke.edu/metadata/rpm" } +PRODUCT_NAME_TO_SLUG = {v: k for k, v in PRODUCT_SLUG_MAP.items()} +API_BASE_URL = "https://apollo.build.resf.org/api/v3/updateinfo" + +def get_product_slug(product_name: str) -> str: + """ + Convert product name to API slug for v2 endpoint. + Strips version numbers and architecture placeholders before lookup. + + Examples: + "Rocky Linux 9 $arch" -> "Rocky Linux" + "Rocky Linux 10.5" -> "Rocky Linux" + "Rocky Linux SIG Cloud" -> "Rocky Linux SIG Cloud" + """ + clean_name = product_name.replace("$arch", "").strip() + + clean_name = re.sub(r'\s+\d+(\.\d+)?$', '', clean_name).strip() + + slug = PRODUCT_NAME_TO_SLUG.get(clean_name) + if not slug: + raise ValueError( + f"Unknown product: {clean_name}. " + f"Valid products: {', '.join(PRODUCT_NAME_TO_SLUG.keys())}" + ) + return slug + @dataclass class Repository: @@ -142,16 +170,37 @@ async def fetch_updateinfo_from_apollo( repo: dict, product_name: str, api_base: str = None, + major_version: int = None, + minor_version: int = None, ) -> str: - pname_arch = product_name.replace("$arch", repo["arch"]) + """ + Fetch updateinfo.xml from Apollo API. + + Args: + repo: Repository dict with 'name' and 'arch' keys + product_name: Product name + api_base: Optional API base URL override + major_version: Required for api_version=2 + minor_version: Optional for api_version=2 + """ if not api_base: - api_base = "https://apollo.build.resf.org/api/v3/updateinfo" - api_url = f"{api_base}/{quote(pname_arch)}/{quote(repo['name'])}/updateinfo.xml" - api_url += f"?req_arch={repo['arch']}" + api_base = API_BASE_URL + + if major_version: + product_slug = get_product_slug(product_name) + api_url = f"{api_base}/{product_slug}/{major_version}/{quote(repo['name'])}/updateinfo.xml" + api_params = {'arch': repo['arch']} + if minor_version is not None: + api_params['minor_version'] = minor_version + logger.info("Using v2 endpoint: %s with params %s", api_url, api_params) + else: + pname_arch = product_name.replace("$arch", repo["arch"]) + api_url = f"{api_base}/{quote(pname_arch)}/{quote(repo['name'])}/updateinfo.xml" + api_params = {'req_arch': repo['arch']} + logger.info("Using legacy endpoint: %s with params %s", api_url, api_params) - logger.info("Fetching updateinfo from %s", api_url) async with aiohttp.ClientSession() as session: - async with session.get(api_url) as resp: + async with session.get(api_url, params=api_params) as resp: if resp.status != 200 and resp.status != 404: logger.warning( "Failed to fetch updateinfo from %s, skipping", api_url @@ -303,6 +352,9 @@ async def run_apollo_tree( ignore: list[str], ignore_arch: list[str], product_name: str, + major_version: int = None, + minor_version: int = None, + api_base: str = None, ): if manual: raise Exception("Manual mode not implemented yet") @@ -320,6 +372,9 @@ async def run_apollo_tree( updateinfo = await fetch_updateinfo_from_apollo( repo, product_name, + api_base=api_base, + major_version=major_version, + 
minor_version=minor_version, ) if not updateinfo: logger.warning("No updateinfo found for %s", repo["name"]) @@ -394,13 +449,30 @@ async def run_apollo_tree( "-n", "--product-name", required=True, - help="Product name", + help="Product name (e.g., 'Rocky Linux', 'Rocky Linux 8 $arch')", + ) + parser.add_argument( + "--major-version", + type=int, + help="Major version (required for --api-version 2)", + ) + parser.add_argument( + "--minor-version", + type=int, + help="Minor version filter (optional, only with --api-version 2)", + ) + parser.add_argument( + "--api-base", + help="API base URL (default: https://apollo.build.resf.org/api/v3/updateinfo)", ) p_args = parser.parse_args() if p_args.auto_scan and p_args.manual: parser.error("Cannot use --auto-scan and --manual together") + if p_args.minor_version and not p_args.major_version: + parser.error("--minor-version can only be used with --major-version") + if p_args.manual and not p_args.repos: parser.error("Must specify repos to publish in manual mode") @@ -416,5 +488,8 @@ async def run_apollo_tree( [y for x in p_args.ignore for y in x], [y for x in p_args.ignore_arch for y in x], p_args.product_name, + p_args.major_version, + p_args.minor_version, + p_args.api_base, ) ) diff --git a/apollo/rhcsaf/__init__.py b/apollo/rhcsaf/__init__.py index 762cedf..95175c4 100644 --- a/apollo/rhcsaf/__init__.py +++ b/apollo/rhcsaf/__init__.py @@ -4,24 +4,65 @@ from common.logger import Logger from apollo.rpm_helpers import parse_nevra -# Initialize Info before Logger for this module - logger = Logger() +EUS_CPE_PRODUCTS = frozenset([ + "rhel_eus", # Extended Update Support + "rhel_e4s", # Update Services for SAP Solutions + "rhel_aus", # Advanced Update Support (IBM Power) + "rhel_tus", # Telecommunications Update Service +]) + +EUS_PRODUCT_NAME_KEYWORDS = frozenset([ + "e4s", + "eus", + "aus", + "tus", + "extended update support", + "update services for sap", + "advanced update support", + "telecommunications update service", +]) + +def _is_eus_product(product_name: str, cpe: str) -> bool: + """ + Detects if a product is EUS-related based on product name and CPE. + + Args: + product_name: Full product name (e.g., "Red Hat Enterprise Linux AppStream E4S (v.9.0)") + cpe: CPE string (e.g., "cpe:/a:redhat:rhel_e4s:9.0::appstream") + + Returns: + True if product is EUS/E4S/AUS/TUS, False otherwise + """ + if cpe: + parts = cpe.split(":") + if len(parts) > 3: + cpe_product = parts[3] + if cpe_product in EUS_CPE_PRODUCTS: + return True + + if product_name: + name_lower = product_name.lower() + for keyword in EUS_PRODUCT_NAME_KEYWORDS: + if keyword in name_lower: + return True + + return False + + def extract_rhel_affected_products_for_db(csaf: dict) -> set: """ Extracts all needed info for red_hat_advisory_affected_products table from CSAF product_tree. Expands 'noarch' to all main arches and maps names to user-friendly values. 
Returns a set of tuples: (variant, name, major_version, minor_version, arch) """ - # Maps architecture short names to user-friendly product names arch_name_map = { "aarch64": "Red Hat Enterprise Linux for ARM 64", "x86_64": "Red Hat Enterprise Linux for x86_64", "s390x": "Red Hat Enterprise Linux for IBM z Systems", "ppc64le": "Red Hat Enterprise Linux for Power, little endian", } - # List of main architectures to expand 'noarch' main_arches = list(arch_name_map.keys()) affected_products = set() product_tree = csaf.get("product_tree", {}) @@ -29,25 +70,20 @@ def extract_rhel_affected_products_for_db(csaf: dict) -> set: logger.warning("No product tree found in CSAF document") return affected_products - # Iterate over all vendor branches in the product tree for vendor_branch in product_tree.get("branches", []): - # Find the product_family branch for RHEL family_branch = None arches = set() for branch in vendor_branch.get("branches", []): if branch.get("category") == "product_family" and branch.get("name") == "Red Hat Enterprise Linux": family_branch = branch - # Collect all architecture branches at the same level as product_family elif branch.get("category") == "architecture": arch = branch.get("name") if arch: arches.add(arch) - # If 'noarch' is present, expand to all main architectures if "noarch" in arches: arches = set(main_arches) if not family_branch: continue - # Find the product_name branch for CPE/version info prod_name = None cpe = None for branch in family_branch.get("branches", []): @@ -59,24 +95,24 @@ def extract_rhel_affected_products_for_db(csaf: dict) -> set: if not prod_name or not cpe: continue - # Parses the CPE string to extract major and minor version numbers + if _is_eus_product(prod_name, cpe): + logger.debug(f"Skipping EUS product: {prod_name}") + continue + # Example CPE: "cpe:/a:redhat:enterprise_linux:9::appstream" - parts = cpe.split(":") # Split the CPE string by colon + parts = cpe.split(":") major = None minor = None if len(parts) > 4: - version = parts[4] # The version is typically the 5th field (index 4) + version = parts[4] if version: if "." in version: - # If the version contains a dot, split into major and minor major, minor = version.split(".", 1) major = int(major) minor = int(minor) else: - # If no dot, only major version is present major = int(version) - # For each architecture, add a tuple with product info to the set for arch in arches: name = arch_name_map.get(arch) if name is None: @@ -84,26 +120,142 @@ def extract_rhel_affected_products_for_db(csaf: dict) -> set: continue if major: affected_products.add(( - family_branch.get("name"), # variant (e.g., "Red Hat Enterprise Linux") - name, # user-friendly architecture name - major, # major version number - minor, # minor version number (may be None) - arch # architecture short name + family_branch.get("name"), + name, + major, + minor, + arch )) logger.debug(f"Number of affected products: {len(affected_products)}") return affected_products + +def _traverse_for_eus(branches, product_eus_map=None): + """ + Recursively traverse CSAF branches to build EUS product map. 
+ + Args: + branches: List of CSAF branch dictionaries to traverse + product_eus_map: Optional dict to accumulate results + + Returns: + Dict mapping product_id to boolean indicating if product is EUS + """ + if product_eus_map is None: + product_eus_map = {} + + for branch in branches: + category = branch.get("category") + + if category == "product_name": + prod = branch.get("product", {}) + product_id = prod.get("product_id") + + if product_id: + product_name = prod.get("name", "") + cpe = prod.get("product_identification_helper", {}).get("cpe", "") + is_eus = _is_eus_product(product_name, cpe) + product_eus_map[product_id] = is_eus + + if "branches" in branch: + _traverse_for_eus(branch["branches"], product_eus_map) + + return product_eus_map + + +def _extract_packages_from_branches(branches, product_eus_map, packages=None): + """ + Recursively traverse CSAF branches to extract package NEVRAs. + + Args: + branches: List of CSAF branch dictionaries to traverse + product_eus_map: Dict mapping product_id to EUS status + packages: Optional set to accumulate results + + Returns: + Set of NEVRA strings + """ + if packages is None: + packages = set() + + for branch in branches: + category = branch.get("category") + + if category == "product_version": + prod = branch.get("product", {}) + product_id = prod.get("product_id") + purl = prod.get("product_identification_helper", {}).get("purl") + + if not product_id: + continue + + if purl and not purl.startswith("pkg:rpm/"): + continue + + # Product IDs for packages can have format: "AppStream-9.0.0.Z.E4S:package-nevra" + # or just "package-nevra" for packages in product_version entries + skip_eus = False + for eus_prod_id, is_eus in product_eus_map.items(): + if is_eus and (":" in product_id and product_id.startswith(eus_prod_id + ":")): + skip_eus = True + break + + if skip_eus: + continue + + # Format: "package-epoch:version-release.arch" or "package-epoch:version-release.arch::module:stream" + packages.add(product_id.split("::")[0]) + + if "branches" in branch: + _extract_packages_from_branches(branch["branches"], product_eus_map, packages) + + return packages + + +def _extract_packages_from_product_tree(csaf: dict) -> set: + """ + Extracts fixed packages from CSAF product_tree using product_id fields. + Handles both regular and modular packages by extracting NEVRAs directly from product_id. + Filters out EUS products. + + Args: + csaf: CSAF document dict + + Returns: + Set of NEVRA strings + """ + product_tree = csaf.get("product_tree", {}) + + if not product_tree: + return set() + + product_eus_map = {} + for vendor_branch in product_tree.get("branches", []): + product_eus_map = _traverse_for_eus(vendor_branch.get("branches", []), product_eus_map) + + packages = set() + for vendor_branch in product_tree.get("branches", []): + packages = _extract_packages_from_branches(vendor_branch.get("branches", []), product_eus_map, packages) + + return packages + + def red_hat_advisory_scraper(csaf: dict): # At the time of writing there are ~254 advisories that do not have any vulnerabilities. 
if not csaf.get("vulnerabilities"): logger.warning("No vulnerabilities found in CSAF document") return None - # red_hat_advisories table values - red_hat_issued_at = csaf["document"]["tracking"]["initial_release_date"] # "2025-02-24T03:42:46+00:00" - red_hat_updated_at = csaf["document"]["tracking"]["current_release_date"] # "2025-04-17T12:08:56+00:00" - name = csaf["document"]["tracking"]["id"] # "RHSA-2025:1234" - red_hat_synopsis = csaf["document"]["title"] # "Red Hat Bug Fix Advisory: Red Hat Quay v3.13.4 bug fix release" + name = csaf["document"]["tracking"]["id"] + + red_hat_affected_products = extract_rhel_affected_products_for_db(csaf) + if not red_hat_affected_products: + logger.info(f"Skipping advisory {name}: all products are EUS-only") + return None + + red_hat_issued_at = csaf["document"]["tracking"]["initial_release_date"] + red_hat_updated_at = csaf["document"]["tracking"]["current_release_date"] + red_hat_synopsis = csaf["document"]["title"] red_hat_description = None topic = None for item in csaf["document"]["notes"]: @@ -112,59 +264,31 @@ def red_hat_advisory_scraper(csaf: dict): elif item["category"] == "summary": topic = item["text"] kind_lookup = {"RHSA": "Security", "RHBA": "Bug Fix", "RHEA": "Enhancement"} - kind = kind_lookup[name.split("-")[0]] # "RHSA-2025:1234" --> "Security" - severity = csaf["document"]["aggregate_severity"]["text"] # "Important" + kind = kind_lookup[name.split("-")[0]] + severity = csaf["document"]["aggregate_severity"]["text"] - # To maintain consistency with the existing database, we need to replace the + # To maintain consistency with the existing database, replace # "Red Hat [KIND] Advisory:" prefixes with the severity level. red_hat_synopsis = red_hat_synopsis.replace("Red Hat Bug Fix Advisory: ", f"{severity}:") red_hat_synopsis = red_hat_synopsis.replace("Red Hat Security Advisory:", f"{severity}:") red_hat_synopsis = red_hat_synopsis.replace("Red Hat Enhancement Advisory: ", f"{severity}:") - # red_hat_advisory_packages table values - red_hat_fixed_packages = set() + red_hat_fixed_packages = _extract_packages_from_product_tree(csaf) + red_hat_cve_set = set() red_hat_bugzilla_set = set() - product_id_suffix_list = ( - ".aarch64", - ".i386", - ".i686", - ".noarch", - ".ppc", - ".ppc64", - ".ppc64le", - ".s390", - ".s390x", - ".src", - ".x86_64" - ) # TODO: find a better way to filter product IDs. This is a workaround for the fact that - # the product IDs in the CSAF documents also contain artifacts like container images - # and we only are interested in RPMs. + for vulnerability in csaf["vulnerabilities"]: - for product_id in vulnerability["product_status"]["fixed"]: - if product_id.endswith(product_id_suffix_list): - # These IDs are in the format product:package_nevra - # ie- AppStream-9.4.0.Z.EUS:rsync-0:3.2.3-19.el9_4.1.aarch64" - split_on_colon = product_id.split(":") - product = split_on_colon[0] - package_nevra = ":".join(split_on_colon[-2:]) - red_hat_fixed_packages.add(package_nevra) - - # red_hat_advisory_cves table values. Many older advisories do not have CVEs and so we need to handle that. 
cve_id = vulnerability.get("cve", None) cve_cvss3_scoring_vector = vulnerability.get("scores", [{}])[0].get("cvss_v3", {}).get("vectorString", None) cve_cvss3_base_score = vulnerability.get("scores", [{}])[0].get("cvss_v3", {}).get("baseScore", None) cve_cwe = vulnerability.get("cwe", {}).get("id", None) red_hat_cve_set.add((cve_id, cve_cvss3_scoring_vector, cve_cvss3_base_score, cve_cwe)) - # red_hat_advisory_bugzilla_bugs table values for bug_id in vulnerability.get("ids", []): if bug_id.get("system_name") == "Red Hat Bugzilla ID": red_hat_bugzilla_set.add(bug_id["text"]) - # red_hat_advisory_affected_products table values - red_hat_affected_products = extract_rhel_affected_products_for_db(csaf) - return { "red_hat_issued_at": str(red_hat_issued_at), "red_hat_updated_at": str(red_hat_updated_at), diff --git a/apollo/rhworker/poll_rh_activities.py b/apollo/rhworker/poll_rh_activities.py index e592136..85a4380 100644 --- a/apollo/rhworker/poll_rh_activities.py +++ b/apollo/rhworker/poll_rh_activities.py @@ -651,8 +651,11 @@ async def fetch_csv_with_dates(session, url): releases = await fetch_csv_with_dates(session, base_url + "releases.csv") deletions = await fetch_csv_with_dates(session, base_url + "deletions.csv") - # Merge changes and releases, keeping the most recent timestamp for each advisory - all_advisories = {**changes, **releases} + # Merge changes and releases, prioritizing changes.csv for updated timestamps + # changes.csv contains the most recent modification time for each advisory + # releases.csv contains original publication dates + # We want changes.csv to take precedence to catch updates to existing advisories + all_advisories = {**releases, **changes} # Remove deletions for advisory_id in deletions: all_advisories.pop(advisory_id, None) diff --git a/apollo/rpmworker/rh_matcher_activities.py b/apollo/rpmworker/rh_matcher_activities.py index eb5f95b..190526f 100644 --- a/apollo/rpmworker/rh_matcher_activities.py +++ b/apollo/rpmworker/rh_matcher_activities.py @@ -277,7 +277,7 @@ async def get_supported_products_with_rh_mirrors(filter_major_versions: Optional Filtering now happens at the mirror level within match_rh_repos activity. 
""" logger = Logger() - rh_mirrors = await SupportedProductsRhMirror.all().prefetch_related( + rh_mirrors = await SupportedProductsRhMirror.filter(active=True).prefetch_related( "rpm_repomds", ) ret = [] @@ -495,14 +495,24 @@ async def clone_advisory( "{http://linux.duke.edu/metadata/common}format" ).find("{http://linux.duke.edu/metadata/rpm}sourcerpm") - # This means we're checking a source RPM + package_name = None if advisory_nvra.endswith(".src.rpm" ) or advisory_nvra.endswith(".src"): source_nvra = repomd.NVRA_RE.search(advisory_nvra) - package_name = source_nvra.group(1) - else: + if source_nvra: + package_name = source_nvra.group(1) + elif source_rpm is not None and source_rpm.text: source_nvra = repomd.NVRA_RE.search(source_rpm.text) - package_name = source_nvra.group(1) + if source_nvra: + package_name = source_nvra.group(1) + + if not package_name: + logger.warning( + "Could not extract package_name for %s in advisory %s, skipping package", + nevra, + advisory.name, + ) + continue checksum_tree = pkg.find( "{http://linux.duke.edu/metadata/common}checksum" @@ -790,6 +800,9 @@ async def match_rh_repos(params) -> None: all_advisories = {} for mirror in supported_product.rh_mirrors: + if not mirror.active: + logger.debug(f"Skipping inactive mirror {mirror.name}") + continue # Apply major version filtering if specified if filter_major_versions is not None and int(mirror.match_major_version) not in filter_major_versions: logger.debug(f"Skipping mirror {mirror.name} with major version {mirror.match_major_version} due to filtering") @@ -836,23 +849,20 @@ async def match_rh_repos(params) -> None: @activity.defn async def block_remaining_rh_advisories(supported_product_id: int) -> None: - supported_product = await SupportedProduct.filter( - id=supported_product_id - ).first().prefetch_related("rh_mirrors") - for mirror in supported_product.rh_mirrors: - mirrors = await SupportedProductsRhMirror.filter( - supported_product_id=supported_product_id + mirrors = await SupportedProductsRhMirror.filter( + supported_product_id=supported_product_id, + active=True + ) + for mirror in mirrors: + advisories = await get_matching_rh_advisories(mirror) + await SupportedProductsRhBlock.bulk_create( + [ + SupportedProductsRhBlock( + **{ + "supported_products_rh_mirror_id": mirror.id, + "red_hat_advsiory_id": advisory.id, + } + ) for advisory in advisories + ], + ignore_conflicts=True ) - for mirror in mirrors: - advisories = await get_matching_rh_advisories(mirror) - await SupportedProductsRhBlock.bulk_create( - [ - SupportedProductsRhBlock( - **{ - "supported_products_rh_mirror_id": mirror.id, - "red_hat_advsiory_id": advisory.id, - } - ) for advisory in advisories - ], - ignore_conflicts=True - ) diff --git a/apollo/schema.sql b/apollo/schema.sql index 686021e..46c385a 100644 --- a/apollo/schema.sql +++ b/apollo/schema.sql @@ -623,7 +623,8 @@ CREATE TABLE public.supported_products_rh_mirrors ( match_variant text NOT NULL, match_major_version numeric NOT NULL, match_minor_version numeric, - match_arch text NOT NULL + match_arch text NOT NULL, + active boolean DEFAULT true NOT NULL ); @@ -1507,6 +1508,13 @@ CREATE INDEX supported_products_rh_mirrors_match_variant_idx ON public.supported CREATE INDEX supported_products_rh_mirrors_supported_product_idx ON public.supported_products_rh_mirrors USING btree (supported_product_id); +-- +-- Name: supported_products_rh_mirrors_active_idx; Type: INDEX; Schema: public; Owner: - +-- + +CREATE INDEX supported_products_rh_mirrors_active_idx ON 
public.supported_products_rh_mirrors USING btree (active); + + -- -- Name: supported_products_rpm_repomds_arch_idx; Type: INDEX; Schema: public; Owner: - -- diff --git a/apollo/server/routes/admin_supported_products.py b/apollo/server/routes/admin_supported_products.py index 730ef5b..e59e269 100644 --- a/apollo/server/routes/admin_supported_products.py +++ b/apollo/server/routes/admin_supported_products.py @@ -42,7 +42,6 @@ async def get_entity_or_error_response( ) -> Union[Model, Response]: """Get an entity by ID or filters, or return error template response.""" - # Build the query if entity_id is not None: query = model_class.get_or_none(id=entity_id) elif filters: @@ -50,7 +49,6 @@ async def get_entity_or_error_response( else: raise ValueError("Either entity_id or filters must be provided") - # Add prefetch_related if specified if prefetch_related: query = query.prefetch_related(*prefetch_related) @@ -158,7 +156,6 @@ async def admin_supported_products( params=params, ) - # Get statistics for each product for product in products.items: mirrors_count = await SupportedProductsRhMirror.filter(supported_product=product).count() repomds_count = await SupportedProductsRpmRepomd.filter( @@ -195,7 +192,6 @@ async def export_all_configs( production_only: Optional[bool] = Query(None) ): """Export configurations for all supported products as JSON with optional filtering""" - # Build query with filters query = SupportedProductsRhMirror.all() if major_version is not None: query = query.filter(match_major_version=major_version) @@ -207,7 +203,6 @@ async def export_all_configs( "rpm_repomds" ).all() - # Filter repositories by production status if specified config_data = [] for mirror in mirrors: mirror_data = await _get_mirror_config_data(mirror) @@ -251,14 +246,11 @@ async def _import_configuration(import_data: List[Dict[str, Any]], replace_exist mirror_data = config["mirror"] repositories_data = config["repositories"] - # Find or create product product = await SupportedProduct.get_or_none(name=product_data["name"]) if not product: - # For import, we should require products to exist already skipped_count += 1 continue - # Check if mirror already exists existing_mirror = await SupportedProductsRhMirror.get_or_none( supported_product=product, name=mirror_data["name"], @@ -273,24 +265,24 @@ async def _import_configuration(import_data: List[Dict[str, Any]], replace_exist continue if existing_mirror and replace_existing: - # Delete existing repositories await SupportedProductsRpmRepomd.filter(supported_products_rh_mirror=existing_mirror).delete() mirror = existing_mirror + mirror.active = mirror_data.get("active", True) + await mirror.save() updated_count += 1 else: - # Create new mirror mirror = SupportedProductsRhMirror( supported_product=product, name=mirror_data["name"], match_variant=mirror_data["match_variant"], match_major_version=mirror_data["match_major_version"], match_minor_version=mirror_data.get("match_minor_version"), - match_arch=mirror_data["match_arch"] + match_arch=mirror_data["match_arch"], + active=mirror_data.get("active", True) ) await mirror.save() created_count += 1 - # Create repositories for repo_data in repositories_data: repo = SupportedProductsRpmRepomd( supported_products_rh_mirror=mirror, @@ -336,10 +328,8 @@ async def import_configurations( status_code=302 ) - # Validate import data validation_errors = await _validate_import_data(import_data) if validation_errors: - # Limit the number of errors shown to avoid overwhelming the user max_errors = 20 if len(validation_errors) > 
max_errors: shown_errors = validation_errors[:max_errors] @@ -352,7 +342,6 @@ async def import_configurations( status_code=302 ) - # Import the data try: results = await _import_configuration(import_data, replace_existing) success_message = f"Import completed: {results['created']} created, {results['updated']} updated, {results['skipped']} skipped" @@ -374,13 +363,16 @@ async def admin_supported_product(request: Request, product_id: int): SupportedProduct, f"Supported product with id {product_id}", entity_id=product_id, - prefetch_related=["rh_mirrors", "rh_mirrors__rpm_repomds", "code"] + prefetch_related=["code"] ) if isinstance(product, Response): return product - # Get detailed statistics for each mirror - for mirror in product.rh_mirrors: + mirrors = await SupportedProductsRhMirror.filter( + supported_product=product + ).order_by("-active", "-match_major_version", "name").prefetch_related("rpm_repomds").all() + + for mirror in mirrors: repomds_count = await SupportedProductsRpmRepomd.filter( supported_products_rh_mirror=mirror ).count() @@ -401,6 +393,7 @@ async def admin_supported_product(request: Request, product_id: int): "admin_supported_product.jinja", { "request": request, "product": product, + "mirrors": mirrors, } ) @@ -419,10 +412,7 @@ async def admin_supported_product_delete( } ) - # Check for existing mirrors (which would contain blocks, overrides, and repomds) mirrors_count = await SupportedProductsRhMirror.filter(supported_product=product).count() - - # Check for existing advisory packages and affected products packages_count = await AdvisoryPackage.filter(supported_product=product).count() affected_products_count = await AdvisoryAffectedProduct.filter(supported_product=product).count() @@ -485,13 +475,16 @@ async def admin_supported_product_mirror_new_post( if isinstance(product, Response): return product - # Validation using centralized validation utility + form_data_raw = await request.form() + active_value = "true" if "true" in form_data_raw.getlist("active") else "false" + form_data = { "name": name, "match_variant": match_variant, "match_major_version": match_major_version, "match_minor_version": match_minor_version, "match_arch": match_arch, + "active": active_value, } try: @@ -521,6 +514,7 @@ async def admin_supported_product_mirror_new_post( match_major_version=match_major_version, match_minor_version=match_minor_version, match_arch=validated_arch, + active=(active_value == "true"), ) await mirror.save() @@ -581,7 +575,9 @@ async def admin_supported_product_mirror_post( if isinstance(mirror, Response): return mirror - # Validation using centralized validation utility + form_data = await request.form() + active_value = "true" if "true" in form_data.getlist("active") else "false" + try: validated_name = FieldValidator.validate_name( name, @@ -606,9 +602,9 @@ async def admin_supported_product_mirror_post( mirror.match_major_version = match_major_version mirror.match_minor_version = match_minor_version mirror.match_arch = validated_arch + mirror.active = (active_value == "true") await mirror.save() - # Re-fetch the mirror with all required relations after saving mirror = await SupportedProductsRhMirror.get_or_none( id=mirror_id, supported_product_id=product_id @@ -646,7 +642,6 @@ async def admin_supported_product_mirror_delete( if isinstance(mirror, Response): return mirror - # Check for existing blocks and overrides using shared logic blocks_count, overrides_count = await check_mirror_dependencies(mirror) if blocks_count > 0 or overrides_count > 0: @@ -674,7 +669,6 
@@ async def admin_supported_product_mirrors_bulk_delete( """Bulk delete multiple mirrors by calling individual delete logic for each mirror""" base_url = f"/admin/supported-products/{product_id}" - # Parse and validate mirror IDs if not mirror_ids or not mirror_ids.strip(): return create_error_redirect(base_url, "No mirror IDs provided") @@ -682,14 +676,12 @@ async def admin_supported_product_mirrors_bulk_delete( mirror_id_list = [int(id_str.strip()) for id_str in mirror_ids.split(",") if id_str.strip()] if not mirror_id_list: return create_error_redirect(base_url, "No valid mirror IDs provided") - # Validate all IDs are positive for id_val in mirror_id_list: if id_val <= 0: return create_error_redirect(base_url, "All mirror IDs must be positive numbers") except ValueError: return create_error_redirect(base_url, "Invalid mirror IDs: must be comma-separated numbers") - # Get all mirrors to delete mirrors = await SupportedProductsRhMirror.filter( id__in=mirror_id_list, supported_product_id=product_id ).prefetch_related("supported_product") @@ -697,28 +689,23 @@ async def admin_supported_product_mirrors_bulk_delete( if not mirrors: return create_error_redirect(base_url, "No mirrors found with provided IDs") - # Process each mirror individually using existing single delete logic successful_deletes = [] failed_deletes = [] for mirror in mirrors: - # Check for existing blocks and overrides using shared logic blocks_count, overrides_count = await check_mirror_dependencies(mirror) if blocks_count > 0 or overrides_count > 0: - # Mirror has dependencies, cannot delete error_parts = format_dependency_error_parts(blocks_count, overrides_count) error_reason = f"{' and '.join(error_parts)}" failed_deletes.append({"name": mirror.name, "reason": error_reason}) else: - # Mirror can be deleted try: await mirror.delete() successful_deletes.append(mirror.name) except Exception as e: failed_deletes.append({"name": mirror.name, "reason": f"deletion failed: {str(e)}"}) - # Build result message with clear formatting message_parts = [] if successful_deletes: @@ -732,15 +719,12 @@ async def admin_supported_product_mirrors_bulk_delete( message = ". 
".join(message_parts) - # If we had any successful deletes, treat as success even if some failed if successful_deletes: return create_success_redirect(base_url, message) else: - # All deletes failed return create_error_redirect(base_url, message) -# Repository (repomd) management routes @router.get("/{product_id}/mirrors/{mirror_id}/repomds/new", response_class=HTMLResponse) async def admin_supported_product_mirror_repomd_new( request: Request, @@ -787,24 +771,16 @@ async def admin_supported_product_mirror_repomd_new_post( if isinstance(mirror, Response): return mirror - # Validation using centralized validation utility - form_data = { - "production": production, - "arch": arch, - "url": url, - "debug_url": debug_url, - "source_url": source_url, - "repo_name": repo_name, - } - - validated_data, validation_errors = FormValidator.validate_repomd_form(form_data) + validated_data, validation_errors, form_data = _validate_repomd_form( + production, arch, url, debug_url, source_url, repo_name + ) if validation_errors: return templates.TemplateResponse( "admin_supported_product_repomd_new.jinja", { "request": request, "mirror": mirror, - "error": validation_errors[0], # Show first error + "error": validation_errors[0], "form_data": form_data } ) @@ -879,24 +855,16 @@ async def admin_supported_product_mirror_repomd_post( if isinstance(repomd, Response): return repomd - # Validation using centralized validation utility - form_data = { - "production": production, - "arch": arch, - "url": url, - "debug_url": debug_url, - "source_url": source_url, - "repo_name": repo_name, - } - - validated_data, validation_errors = FormValidator.validate_repomd_form(form_data) + validated_data, validation_errors, form_data = _validate_repomd_form( + production, arch, url, debug_url, source_url, repo_name + ) if validation_errors: return templates.TemplateResponse( "admin_supported_product_repomd.jinja", { "request": request, "repomd": repomd, - "error": validation_errors[0], # Show first error + "error": validation_errors[0], } ) @@ -938,7 +906,6 @@ async def admin_supported_product_mirror_repomd_delete( if isinstance(repomd, Response): return repomd - # Check for existing advisory packages using this repository packages_count = await AdvisoryPackage.filter( supported_products_rh_mirror=repomd.supported_products_rh_mirror, repo_name=repomd.repo_name @@ -956,7 +923,6 @@ async def admin_supported_product_mirror_repomd_delete( return RedirectResponse(f"/admin/supported-products/{product_id}/mirrors/{mirror_id}", status_code=302) -# Blocks management routes @router.get("/{product_id}/mirrors/{mirror_id}/blocks", response_class=HTMLResponse) async def admin_supported_product_mirror_blocks( request: Request, @@ -978,17 +944,14 @@ async def admin_supported_product_mirror_blocks( if isinstance(mirror, Response): return mirror - # Build query for blocked advisories query = SupportedProductsRhBlock.filter( supported_products_rh_mirror=mirror ).prefetch_related("red_hat_advisory") - # Apply search if provided if search: query = query.filter(red_hat_advisory__name__icontains=search) - # Set page size and get paginated results - params.size = min(params.size or 50, 100) # Default 50, max 100 + params.size = min(params.size or 50, 100) blocks = await paginate( query.order_by("-red_hat_advisory__red_hat_issued_at"), params=params ) @@ -1025,7 +988,6 @@ async def admin_supported_product_mirror_block_new( if isinstance(mirror, Response): return mirror - # Get advisories that are not already blocked existing_blocks = await 
SupportedProductsRhBlock.filter( supported_products_rh_mirror=mirror ).values_list("red_hat_advisory_id", flat=True) @@ -1034,11 +996,9 @@ async def admin_supported_product_mirror_block_new( if search: query = query.filter(name__icontains=search) - # Set page size and get paginated results - params.size = min(params.size or 50, 100) # Default 50, max 100 + params.size = min(params.size or 50, 100) advisories = await paginate(query.order_by("-red_hat_issued_at"), params=params) - # Calculate total pages for pagination component advisories_pages = ( math.ceil(advisories.total / advisories.size) if advisories.size > 0 else 1 ) @@ -1080,7 +1040,6 @@ async def admin_supported_product_mirror_block_new_post( if isinstance(advisory, Response): return advisory - # Check if block already exists existing_block = await SupportedProductsRhBlock.get_or_none( supported_products_rh_mirror=mirror, red_hat_advisory=advisory ) @@ -1130,7 +1089,6 @@ async def admin_supported_product_mirror_block_delete( ) -# Overrides management routes (similar structure to blocks) @router.get("/{product_id}/mirrors/{mirror_id}/overrides/new", response_class=HTMLResponse) async def admin_supported_product_mirror_override_new( request: Request, @@ -1149,7 +1107,6 @@ async def admin_supported_product_mirror_override_new( if isinstance(mirror, Response): return mirror - # Get advisories that don't already have overrides existing_overrides = await SupportedProductsRpmRhOverride.filter( supported_products_rh_mirror=mirror ).values_list("red_hat_advisory_id", flat=True) @@ -1158,11 +1115,9 @@ async def admin_supported_product_mirror_override_new( if search: query = query.filter(name__icontains=search) - # Set page size and get paginated results - params.size = min(params.size or 50, 100) # Default 50, max 100 + params.size = min(params.size or 50, 100) advisories = await paginate(query.order_by("-red_hat_issued_at"), params=params) - # Calculate total pages for pagination component advisories_pages = ( math.ceil(advisories.total / advisories.size) if advisories.size > 0 else 1 ) @@ -1204,7 +1159,6 @@ async def admin_supported_product_mirror_override_new_post( if isinstance(advisory, Response): return advisory - # Check if override already exists existing_override = await SupportedProductsRpmRhOverride.get_or_none( supported_products_rh_mirror=mirror, red_hat_advisory=advisory @@ -1274,6 +1228,7 @@ async def _get_mirror_config_data(mirror: SupportedProductsRhMirror) -> Dict[str "match_major_version": mirror.match_major_version, "match_minor_version": mirror.match_minor_version, "match_arch": mirror.match_arch, + "active": mirror.active, "created_at": mirror.created_at.isoformat(), "updated_at": mirror.updated_at.isoformat() if mirror.updated_at else None, }, @@ -1293,9 +1248,31 @@ async def _get_mirror_config_data(mirror: SupportedProductsRhMirror) -> Dict[str ] } +def _validate_repomd_form( + production: bool, + arch: str, + url: str, + debug_url: str, + source_url: str, + repo_name: str +) -> Tuple[Dict[str, Any], List[str], Dict[str, Any]]: + """Validate repomd form data and return validated data, errors, and original form data.""" + form_data = { + "production": production, + "arch": arch, + "url": url, + "debug_url": debug_url, + "source_url": source_url, + "repo_name": repo_name, + } + validated_data, validation_errors = FormValidator.validate_repomd_form(form_data) + return validated_data, validation_errors, form_data + def _json_serializer(obj): """Custom JSON serializer for non-standard types""" if isinstance(obj, Decimal): 
+ if obj % 1 == 0: + return int(obj) return float(obj) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") @@ -1348,7 +1325,6 @@ async def export_product_config( media_type="text/plain" ) - # Build query with filters query = SupportedProductsRhMirror.filter(supported_product_id=product_id) if major_version is not None: query = query.filter(match_major_version=major_version) @@ -1360,7 +1336,6 @@ async def export_product_config( "rpm_repomds" ).all() - # Filter repositories by production status if specified config_data = [] for mirror in mirrors: mirror_data = await _get_mirror_config_data(mirror) @@ -1411,7 +1386,6 @@ async def admin_supported_product_mirror_overrides( if isinstance(mirror, Response): return mirror - # Build query for overrides with search query = SupportedProductsRpmRhOverride.filter( supported_products_rh_mirror_id=mirror_id ).prefetch_related("red_hat_advisory") @@ -1419,11 +1393,9 @@ async def admin_supported_product_mirror_overrides( if search: query = query.filter(red_hat_advisory__name__icontains=search) - # Apply ordering and pagination query = query.order_by("-created_at") overrides = await paginate(query, params) - # Calculate total pages for pagination component overrides_pages = ( math.ceil(overrides.total / overrides.size) if overrides.size > 0 else 1 ) diff --git a/apollo/server/routes/admin_workflows.py b/apollo/server/routes/admin_workflows.py index ef319dc..cfb26ec 100644 --- a/apollo/server/routes/admin_workflows.py +++ b/apollo/server/routes/admin_workflows.py @@ -21,7 +21,8 @@ async def admin_workflows(request: Request, user: User = Depends(admin_user_sche """Render admin workflows page for manual workflow triggering""" db_service = DatabaseService() env_info = await db_service.get_environment_info() - + index_state = await db_service.get_last_indexed_at() + return templates.TemplateResponse( "admin_workflows.jinja", { "request": request, @@ -29,6 +30,8 @@ async def admin_workflows(request: Request, user: User = Depends(admin_user_sche "env_name": env_info["environment"], "is_production": env_info["is_production"], "reset_allowed": env_info["reset_allowed"], + "last_indexed_at": index_state.get("last_indexed_at_iso"), + "last_indexed_exists": index_state.get("exists", False), } ) @@ -92,6 +95,39 @@ async def trigger_poll_rhcsaf( return RedirectResponse(url="/admin/workflows", status_code=303) +@router.post("/workflows/update-index-timestamp") +async def update_index_timestamp( + request: Request, + new_timestamp: str = Form(...), + user: User = Depends(admin_user_scheme) +): + """Update the last_indexed_at timestamp in red_hat_index_state""" + try: + # Parse the timestamp + timestamp_dt = datetime.fromisoformat(new_timestamp.replace("Z", "+00:00")) + + db_service = DatabaseService() + result = await db_service.update_last_indexed_at(timestamp_dt, user.email) + + Logger().info(f"Admin user {user.email} updated last_indexed_at to {new_timestamp}") + + # Store success message in session + request.session["workflow_message"] = result["message"] + request.session["workflow_type"] = "success" + + except ValueError as e: + Logger().error(f"Invalid timestamp format: {str(e)}") + request.session["workflow_message"] = f"Invalid timestamp format: {str(e)}" + request.session["workflow_type"] = "error" + + except Exception as e: + Logger().error(f"Error updating last_indexed_at: {str(e)}") + request.session["workflow_message"] = f"Error updating timestamp: {str(e)}" + request.session["workflow_type"] = "error" + + return 
RedirectResponse(url="/admin/workflows", status_code=303) + + @router.get("/workflows/database/preview-reset") async def preview_database_reset( request: Request, diff --git a/apollo/server/routes/api_osv.py b/apollo/server/routes/api_osv.py index f0022ee..debf89a 100644 --- a/apollo/server/routes/api_osv.py +++ b/apollo/server/routes/api_osv.py @@ -143,7 +143,6 @@ def to_osv_advisory(ui_url: str, advisory: Advisory) -> OSVAdvisory: for pkg in affected_packages: x = pkg[0] nevra = pkg[1] - # Only process "src" packages if nevra.group(5) != "src": continue if x.nevra in processed_nvra: @@ -198,11 +197,9 @@ def to_osv_advisory(ui_url: str, advisory: Advisory) -> OSVAdvisory: if advisory.red_hat_advisory: osv_credits.append(OSVCredit(name="Red Hat")) - # Calculate severity by finding the highest CVSS score highest_cvss_base_score = 0.0 final_score_vector = None for x in advisory.cves: - # Convert cvss3_scoring_vector to a float base_score = x.cvss3_base_score if base_score and base_score != "UNKNOWN": base_score = float(base_score) @@ -255,15 +252,14 @@ async def get_advisories_osv( cve, synopsis, severity, - kind="Security", + kind=None, fetch_related=True, ) - count = fetch_adv[0] advisories = fetch_adv[1] ui_url = await get_setting(UI_URL) - osv_advisories = [to_osv_advisory(ui_url, x) for x in advisories] - page = create_page(osv_advisories, count, params) + osv_advisories = [to_osv_advisory(ui_url, adv) for adv in advisories if adv.cves] + page = create_page(osv_advisories, len(osv_advisories), params) state = await RedHatIndexState.first() page.last_updated_at = ( @@ -282,7 +278,7 @@ async def get_advisories_osv( ) async def get_advisory_osv(advisory_id: str): advisory = ( - await Advisory.filter(name=advisory_id, kind="Security") + await Advisory.filter(name=advisory_id) .prefetch_related( "packages", "cves", @@ -295,7 +291,7 @@ async def get_advisory_osv(advisory_id: str): .get_or_none() ) - if not advisory: + if not advisory or not advisory.cves: raise HTTPException(404) ui_url = await get_setting(UI_URL) diff --git a/apollo/server/routes/api_updateinfo.py b/apollo/server/routes/api_updateinfo.py index dcc9838..1ef34a2 100644 --- a/apollo/server/routes/api_updateinfo.py +++ b/apollo/server/routes/api_updateinfo.py @@ -5,8 +5,10 @@ from fastapi import APIRouter, Response from slugify import slugify -from apollo.db import AdvisoryAffectedProduct +from apollo.db import AdvisoryAffectedProduct, SupportedProduct +from tortoise.exceptions import DoesNotExist from apollo.server.settings import COMPANY_NAME, MANAGING_EDITOR, UI_URL, get_setting +from apollo.server.validation import Architecture from apollo.rpmworker.repomd import NEVRA_RE, NVRA_RE, EPOCH_RE @@ -15,63 +17,119 @@ router = APIRouter(tags=["updateinfo"]) -@router.get("/{product_name}/{repo}/updateinfo.xml") -async def get_updateinfo( - product_name: str, - repo: str, - req_arch: Optional[str] = None, -): - filters = { - "name": product_name, - "advisory__packages__repo_name": repo, - } - if req_arch: - filters["arch"] = req_arch - - affected_products = await AdvisoryAffectedProduct.filter( - **filters - ).prefetch_related( - "advisory", - "advisory__cves", - "advisory__fixes", - "advisory__packages", - "supported_product", - ).all() - if not affected_products: - raise RenderErrorTemplateException("No advisories found", 404) - - ui_url = await get_setting(UI_URL) - managing_editor = await get_setting(MANAGING_EDITOR) - company_name = await get_setting(COMPANY_NAME) - +PRODUCT_SLUG_MAP = { + "rocky-linux": "Rocky Linux", + 
"rocky-linux-sig-cloud": "Rocky Linux SIG Cloud", +} + + +def resolve_product_slug(slug: str) -> Optional[str]: + """Convert product slug to supported_product.name""" + return PRODUCT_SLUG_MAP.get(slug.lower()) + + +def get_source_package_name(pkg) -> str: + """ + Extract source package name from package for grouping with source RPM. + + Returns a consistent key for grouping binary packages with their source RPM. + For module packages, includes module context for proper identification. + """ + if pkg.module_name: + return f"{pkg.module_name}:{pkg.package_name}:{pkg.module_stream}" + return pkg.package_name + + +def build_source_rpm_mapping(packages: list) -> dict: + """ + Build mapping from source package name to source RPM filename. + + Groups packages by source package name, then finds the source RPM + (arch=="src") within each group. + + Returns: + dict: Mapping of source_package_name -> source_rpm_filename + """ + pkg_name_map = {} + for pkg in packages: + name = get_source_package_name(pkg) + if name not in pkg_name_map: + pkg_name_map[name] = [] + pkg_name_map[name].append(pkg) + + pkg_src_rpm = {} + for name, pkgs in pkg_name_map.items(): + if name in pkg_src_rpm: + continue + + for pkg in pkgs: + nvra_no_epoch = EPOCH_RE.sub("", pkg.nevra) + nvra = NVRA_RE.search(nvra_no_epoch) + if nvra: + nvr_name = nvra.group(1) + nvr_arch = nvra.group(4) + + if pkg.package_name == nvr_name and nvr_arch == "src": + src_rpm = nvra_no_epoch + if not src_rpm.endswith(".rpm"): + src_rpm += ".rpm" + pkg_src_rpm[name] = src_rpm + break + + return pkg_src_rpm + + +def generate_updateinfo_xml( + affected_products: list, + repo_name: str, + product_arch: str, + ui_url: str, + managing_editor: str, + company_name: str, + supported_product_id: int = None, + product_name_for_packages: str = None, +) -> str: + """ + Generate updateinfo.xml from affected products. + + Args: + affected_products: List of AdvisoryAffectedProduct records with prefetched + advisory, cves, fixes, packages, supported_product + repo_name: Repository name for package filtering + product_arch: Architecture for package filtering + ui_url: Base URL for UI references + managing_editor: Editor email for XML header + company_name: Company name for copyright + supported_product_id: Optional supported_product_id for FK-based filtering (v2) + product_name_for_packages: Optional product_name for legacy filtering (v1) + + Returns: + XML string in updateinfo.xml format + + Note: Either supported_product_id (v2) or product_name_for_packages (v1) must be provided. 
+ """ advisories = {} for affected_product in affected_products: advisory = affected_product.advisory if advisory.name not in advisories: advisories[advisory.name] = { - "advisory": - advisory, - "arch": - affected_product.arch, - "major_version": - affected_product.major_version, - "minor_version": - affected_product.minor_version, - "supported_product_name": - affected_product.supported_product.name, + "advisory": advisory, + "arch": affected_product.arch, + "major_version": affected_product.major_version, + "minor_version": affected_product.minor_version, + "supported_product_name": affected_product.supported_product.name, } tree = ET.Element("updates") for _, adv in advisories.items(): advisory = adv["advisory"] - product_arch = adv["arch"] + adv_arch = adv["arch"] major_version = adv["major_version"] minor_version = adv["minor_version"] supported_product_name = adv["supported_product_name"] update = ET.SubElement(tree, "update") - # Set update attributes update.set("from", managing_editor) update.set("status", "final") @@ -84,47 +142,31 @@ async def get_updateinfo( update.set("version", "2") - # Add id ET.SubElement(update, "id").text = advisory.name - - # Add title ET.SubElement(update, "title").text = advisory.synopsis - # Add time time_format = "%Y-%m-%d %H:%M:%S" issued = ET.SubElement(update, "issued") issued.set("date", advisory.published_at.strftime(time_format)) updated = ET.SubElement(update, "updated") updated.set("date", advisory.updated_at.strftime(time_format)) - # Add rights now = datetime.datetime.utcnow() ET.SubElement( update, "rights" ).text = f"Copyright {now.year} {company_name}" - # Add release name release_name = f"{supported_product_name} {major_version}" if minor_version: release_name += f".{minor_version}" ET.SubElement(update, "release").text = release_name - # Add pushcount ET.SubElement(update, "pushcount").text = "1" - - # Add severity ET.SubElement(update, "severity").text = advisory.severity - - # Add summary ET.SubElement(update, "summary").text = advisory.topic - - # Add description ET.SubElement(update, "description").text = advisory.description - - # Add solution ET.SubElement(update, "solution").text = "" - # Add references references = ET.SubElement(update, "references") for cve in advisory.cves: reference = ET.SubElement(references, "reference") @@ -143,15 +185,13 @@ async def get_updateinfo( reference.set("type", "bugzilla") reference.set("title", fix.description) - # Add UI self reference reference = ET.SubElement(references, "reference") reference.set("href", f"{ui_url}/{advisory.name}") reference.set("id", advisory.name) reference.set("type", "self") reference.set("title", advisory.name) - # Add packages - packages = ET.SubElement(update, "pkglist") + packages_element = ET.SubElement(update, "pkglist") suffixes_to_skip = [ "-debuginfo", @@ -160,47 +200,26 @@ async def get_updateinfo( "-debugsource-common", ] - pkg_name_map = {} - for pkg in advisory.packages: - name = pkg.package_name - if pkg.module_name: - name = f"{pkg.module_name}:{pkg.package_name}:{pkg.module_stream}" - if name not in pkg_name_map: - pkg_name_map[name] = [] - - pkg_name_map[name].append(pkg) - - pkg_src_rpm = {} - for top_pkg in advisory.packages: - name = top_pkg.package_name - if top_pkg.module_name: - name = f"{top_pkg.module_name}:{top_pkg.package_name}:{top_pkg.module_stream}" - if name not in pkg_src_rpm: - for pkg in pkg_name_map[name]: - nvra_no_epoch = EPOCH_RE.sub("", pkg.nevra) - nvra = NVRA_RE.search(nvra_no_epoch) - if nvra: - nvr_name = nvra.group(1) - 
nvr_arch = nvra.group(4) - if pkg.package_name == nvr_name and nvr_arch == "src": - src_rpm = nvra_no_epoch - if not src_rpm.endswith(".rpm"): - src_rpm += ".rpm" - pkg_src_rpm[name] = src_rpm - - # Collection list, may be more than one if module RPMs are involved + if supported_product_id is not None: + # v2: Filter by FK (normalized relational data) + filtered_packages = [ + pkg for pkg in advisory.packages + if pkg.supported_product_id == supported_product_id and pkg.repo_name == repo_name + ] + else: + # v1: Filter by product_name (legacy denormalized field) + filtered_packages = [ + pkg for pkg in advisory.packages + if pkg.product_name == product_name_for_packages and pkg.repo_name == repo_name + ] + + pkg_src_rpm = build_source_rpm_mapping(filtered_packages) + collections = {} no_default_collection = False - default_collection_short = slugify(f"{product_name}-{repo}-rpms") - - # Check if this is an actual module advisory, if so we need to split the - # collections, and module RPMs need to go into their own collection based on - # module name, while non-module RPMs go into the main collection (if any) - for pkg in advisory.packages: - if pkg.product_name != product_name: - continue - if pkg.repo_name != repo: - continue + default_collection_short = slugify(f"{product_name_for_packages}-{repo_name}-rpms") + + for pkg in filtered_packages: if pkg.module_name: collection_short = f"{default_collection_short}__{pkg.module_name}" if collection_short not in collections: @@ -228,11 +247,9 @@ async def get_updateinfo( collections_added = 0 for collection_short, info in collections.items(): - # Create collection collection = ET.Element("collection") collection.set("short", collection_short) - # Set short to name as well ET.SubElement(collection, "name").text = collection_short if "module_name" in info: @@ -241,7 +258,7 @@ async def get_updateinfo( module_element.set("stream", info["module_stream"]) module_element.set("version", info["module_version"]) module_element.set("context", info["module_context"]) - module_element.set("arch", product_arch) + module_element.set("arch", adv_arch) added_pkg_count = 0 for pkg in info["packages"]: @@ -266,9 +283,7 @@ async def get_updateinfo( else: continue - p_name = pkg.package_name - if pkg.module_name: - p_name = f"{pkg.module_name}:{pkg.package_name}:{pkg.module_stream}" + p_name = get_source_package_name(pkg) if p_name not in pkg_src_rpm: continue @@ -294,11 +309,9 @@ async def get_updateinfo( package.set("release", release) package.set("src", pkg_src_rpm[p_name]) - # Add filename element ET.SubElement(package, "filename").text = EPOCH_RE.sub("", pkg.nevra) - # Add checksum ET.SubElement( package, "sum", type=pkg.checksum_type ).text = pkg.checksum @@ -306,7 +319,7 @@ async def get_updateinfo( added_pkg_count += 1 if added_pkg_count > 0: - packages.append(collection) + packages_element.append(collection) collections_added += 1 if collections_added == 0: @@ -320,4 +333,147 @@ async def get_updateinfo( short_empty_elements=True, ) + return xml_str + + +@router.get("/{product_name}/{repo}/updateinfo.xml") +async def get_updateinfo( + product_name: str, + repo: str, + req_arch: Optional[str] = None, +): + filters = { + "name": product_name, + "advisory__packages__repo_name": repo, + } + if req_arch: + filters["arch"] = req_arch + + affected_products = await AdvisoryAffectedProduct.filter( + **filters + ).prefetch_related( + "advisory", + "advisory__cves", + "advisory__fixes", + "advisory__packages", + "supported_product", + ).all() + if not 
affected_products: + raise RenderErrorTemplateException("No advisories found", 404) + + ui_url = await get_setting(UI_URL) + managing_editor = await get_setting(MANAGING_EDITOR) + company_name = await get_setting(COMPANY_NAME) + + product_arch = affected_products[0].arch + + xml_str = generate_updateinfo_xml( + affected_products=affected_products, + repo_name=repo, + product_arch=product_arch, + ui_url=ui_url, + managing_editor=managing_editor, + company_name=company_name, + product_name_for_packages=product_name, + ) + + return Response(content=xml_str, media_type="application/xml") + + +@router.get("/{product}/{major_version}/{repo}/updateinfo.xml") +async def get_updateinfo_v2( + product: str, + major_version: int, + repo: str, + arch: str, + minor_version: Optional[int] = None, +): + """ + Get updateinfo.xml for a product major version and repository. + + This endpoint aggregates all advisories for the specified major version, + including all minor versions, unless minor_version is specified. + + Architecture filtering is REQUIRED because: + - Each advisory contains packages for multiple architectures + - Repository structure is architecture-specific + - DNF/YUM expects arch-specific updateinfo.xml files + + Args: + product: Product slug (e.g., 'rocky-linux', 'rocky-linux-sig-cloud') + major_version: Major version number (e.g., 8, 9, 10) + repo: Repository name (e.g., 'BaseOS', 'AppStream') + arch: Architecture (REQUIRED: 'x86_64', 'aarch64', 'ppc64le', 's390x') + minor_version: Optional minor version filter (e.g., 6 for 8.6) + + Returns: + updateinfo.xml file + + Raises: + 400: Invalid architecture or missing required parameter + 404: No advisories found or invalid product + """ + product_name = resolve_product_slug(product) + if not product_name: + raise RenderErrorTemplateException( + f"Unknown product: {product}. Valid: {', '.join(PRODUCT_SLUG_MAP.keys())}", + 404 + ) + + try: + supported_product = await SupportedProduct.get(name=product_name) + except DoesNotExist: + raise RenderErrorTemplateException(f"Product not found: {product_name}", 404) + + # Validate architecture using centralized validation + try: + Architecture(arch) + except ValueError: + valid_arches = [a.value for a in Architecture] + raise RenderErrorTemplateException( + f"Invalid architecture: {arch}. 
Must be one of {', '.join(valid_arches)}", + 400 + ) + + filters = { + "supported_product_id": supported_product.id, + "major_version": major_version, + "arch": arch, + "advisory__packages__repo_name": repo, + "advisory__packages__supported_product_id": supported_product.id, + } + + if minor_version is not None: + filters["minor_version"] = minor_version + + affected_products = await AdvisoryAffectedProduct.filter( + **filters + ).prefetch_related( + "advisory", + "advisory__cves", + "advisory__fixes", + "advisory__packages", + "supported_product", + ).all() + + if not affected_products: + raise RenderErrorTemplateException( + f"No advisories found for {product_name} {major_version} {repo} {arch}", + 404 + ) + + ui_url = await get_setting(UI_URL) + managing_editor = await get_setting(MANAGING_EDITOR) + company_name = await get_setting(COMPANY_NAME) + + xml_str = generate_updateinfo_xml( + affected_products=affected_products, + repo_name=repo, + product_arch=arch, + ui_url=ui_url, + managing_editor=managing_editor, + company_name=company_name, + supported_product_id=supported_product.id, + ) + return Response(content=xml_str, media_type="application/xml") diff --git a/apollo/server/services/database_service.py b/apollo/server/services/database_service.py index 78d6fb0..0a66800 100644 --- a/apollo/server/services/database_service.py +++ b/apollo/server/services/database_service.py @@ -123,4 +123,67 @@ async def get_environment_info(self) -> Dict[str, str]: "environment": env_name, "is_production": self.is_production_environment(), "reset_allowed": not self.is_production_environment() - } \ No newline at end of file + } + + async def get_last_indexed_at(self) -> Dict[str, Any]: + """ + Get the current last_indexed_at timestamp from red_hat_index_state + + Returns: + Dictionary with timestamp information + """ + index_state = await RedHatIndexState.first() + + if not index_state or not index_state.last_indexed_at: + return { + "last_indexed_at": None, + "last_indexed_at_iso": None, + "exists": False + } + + return { + "last_indexed_at": index_state.last_indexed_at, + "last_indexed_at_iso": index_state.last_indexed_at.isoformat(), + "exists": True + } + + async def update_last_indexed_at(self, new_timestamp: datetime, user_email: str) -> Dict[str, Any]: + """ + Update the last_indexed_at timestamp in red_hat_index_state + + Args: + new_timestamp: New timestamp to set + user_email: Email of user making the change (for logging) + + Returns: + Dictionary with operation results + + Raises: + ValueError: If timestamp is invalid + """ + logger = Logger() + + try: + # Get or create index state + index_state = await RedHatIndexState.first() + + old_timestamp = None + if index_state: + old_timestamp = index_state.last_indexed_at + index_state.last_indexed_at = new_timestamp + await index_state.save() + logger.info(f"Updated last_indexed_at by {user_email}: {old_timestamp} -> {new_timestamp}") + else: + await RedHatIndexState.create(last_indexed_at=new_timestamp) + logger.info(f"Created last_indexed_at by {user_email}: {new_timestamp}") + + return { + "success": True, + "old_timestamp": old_timestamp.isoformat() if old_timestamp else None, + "new_timestamp": new_timestamp.isoformat(), + "message": f"Successfully updated last_indexed_at to {new_timestamp.isoformat()}" + } + + except Exception as e: + logger.error(f"Failed to update last_indexed_at: {e}") + raise RuntimeError(f"Failed to update timestamp: {e}") from e \ No newline at end of file diff --git a/apollo/server/services/workflow_service.py 
b/apollo/server/services/workflow_service.py index 18511de..56d5103 100644 --- a/apollo/server/services/workflow_service.py +++ b/apollo/server/services/workflow_service.py @@ -206,9 +206,9 @@ async def _validate_major_versions(self, major_versions: List[int]) -> None: # Import here to avoid circular imports from apollo.db import SupportedProductsRhMirror - + # Get available major versions from RH mirrors - rh_mirrors = await SupportedProductsRhMirror.all() + rh_mirrors = await SupportedProductsRhMirror.filter(active=True) available_versions = {int(mirror.match_major_version) for mirror in rh_mirrors} # Check if all requested major versions are available diff --git a/apollo/server/templates/admin_supported_product.jinja b/apollo/server/templates/admin_supported_product.jinja index a04e2c8..4c3092c 100644 --- a/apollo/server/templates/admin_supported_product.jinja +++ b/apollo/server/templates/admin_supported_product.jinja @@ -41,13 +41,13 @@
-

Red Hat Mirrors ({{ product.rh_mirrors|length }})

+

Red Hat Mirrors ({{ mirrors|length }})

Add New Mirror
- {% if product.rh_mirrors %}
+ {% if mirrors %}
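Note on the new versioned route added in apollo/server/routes/api_updateinfo.py above: get_updateinfo_v2 serves /{product}/{major_version}/{repo}/updateinfo.xml, requires an arch query parameter, and accepts an optional minor_version. A minimal client-side sketch of calling it follows; the base URL, product slug, and repo below are example values and assumptions, not project defaults.

```python
# Illustrative only: fetch updateinfo.xml from the v2 route described above.
# Base URL, product slug, and repo are example values (assumptions).
from typing import Optional
from urllib.parse import quote, urlencode
from urllib.request import urlopen


def fetch_updateinfo_v2(base: str, product_slug: str, major_version: int,
                        repo: str, arch: str,
                        minor_version: Optional[int] = None) -> bytes:
    params = {"arch": arch}
    if minor_version is not None:
        params["minor_version"] = minor_version
    url = (f"{base}/{quote(product_slug)}/{major_version}/"
           f"{quote(repo)}/updateinfo.xml?{urlencode(params)}")
    with urlopen(url) as resp:  # urllib raises HTTPError for 400/404 responses
        return resp.read()


# Example call (hypothetical host):
# fetch_updateinfo_v2("https://apollo.example.org/api/v3/updateinfo",
#                     "rocky-linux", 9, "BaseOS", "x86_64", minor_version=6)
```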
@@ -60,7 +60,7 @@
{% endif %}
- {% if product.rh_mirrors %}
+ {% if mirrors %}
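The hunks above (and the row loop below) switch this template from product.rh_mirrors to a mirrors context variable that carries the new active flag plus per-mirror stats. A rough sketch of how the admin route could assemble that context, assuming the usual Tortoise ORM accessors; the helper name, the FK filter field, and the stats shape are assumptions, not the actual route code.

```python
# Sketch only: build the "mirrors" list this template iterates over.
# The helper name and stats layout are assumed; mirror.active comes from the
# new BooleanField on SupportedProductsRhMirror.
from types import SimpleNamespace

from apollo.db import SupportedProductsRhMirror


async def build_mirrors_context(product_id: int) -> list:
    mirrors = await SupportedProductsRhMirror.filter(
        supported_product_id=product_id,  # assumes the standard FK *_id accessor
    ).prefetch_related("rpm_repomds").all()

    for mirror in mirrors:
        # The template reads mirror.stats.repomds for the "N repos" column.
        mirror.stats = SimpleNamespace(repomds=len(list(mirror.rpm_repomds)))
    return mirrors
```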
@@ -80,6 +80,9 @@
+
@@ -92,7 +95,7 @@
- {% for mirror in product.rh_mirrors %}
+ {% for mirror in mirrors %}
+
Architecture
+ Status
+ Repositories
{{ mirror.match_arch }}
+ {% if mirror.active %}
+ Active
+ {% else %}
+ Inactive
+ {% endif %}
+
{{ mirror.stats.repomds }} repos
diff --git a/apollo/server/templates/admin_supported_product_mirror.jinja b/apollo/server/templates/admin_supported_product_mirror.jinja
index 58298ed..a16e35c 100644
--- a/apollo/server/templates/admin_supported_product_mirror.jinja
+++ b/apollo/server/templates/admin_supported_product_mirror.jinja
@@ -119,6 +119,22 @@
+
+ Status +
+ + +
+
+ Only active mirrors are used for advisory processing. Deactivate a mirror to exclude it from workflows without deleting it.
+
+
+
diff --git a/apollo/server/templates/admin_supported_product_mirror_new.jinja b/apollo/server/templates/admin_supported_product_mirror_new.jinja
index 003f826..1b9e039 100644
--- a/apollo/server/templates/admin_supported_product_mirror_new.jinja
+++ b/apollo/server/templates/admin_supported_product_mirror_new.jinja
@@ -108,6 +108,22 @@
+
+ Status +
+ + +
+
+ Only active mirrors are used for advisory processing. Deactivate a mirror to exclude it from workflows without deleting it.
+
+
+
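The edit and create forms above both add an active checkbox. An unchecked HTML checkbox is simply omitted from the submitted form data, which is why the parsing added in validation.py and exercised by TestActiveFieldCheckboxParsing further down normalizes presence or absence into a boolean. A small self-contained sketch of that logic; the fake form class is only for the example, standing in for the framework's form data object.

```python
# Sketch of the "active" checkbox normalization used by the admin forms.
# A checked checkbox submits the value "true"; an unchecked one sends nothing.
class _FakeForm:
    """Stand-in for the framework's form data object, for this example only."""

    def __init__(self, values):
        self._values = values

    def getlist(self, key):
        return self._values.get(key, [])


def parse_active(form_data) -> bool:
    # getlist("active") is ["true"] when checked and [] when unchecked/missing.
    active_value = "true" if "true" in form_data.getlist("active") else "false"
    return active_value == "true"


assert parse_active(_FakeForm({"active": ["true"]})) is True
assert parse_active(_FakeForm({})) is False
assert parse_active(_FakeForm({"active": ["on"]})) is False  # unexpected values -> False
```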
@@ -80,6 +80,44 @@
+ +
+
+
+

Update CSAF Index Timestamp

+

Set the last_indexed_at timestamp to control which advisories are processed by the Poll RHCSAF workflow.

+ + {% if last_indexed_exists %} +

+ Current last_indexed_at: {{ last_indexed_at }} +

+ {% else %} +

+ No timestamp set; the workflow will process all advisories.

+ {% endif %} + + +
+ + + +
+ The workflow will process advisories with timestamps after this date.
+ Time will be set to 00:00:00 UTC. +
+
+ + + +
+
+
+ {% if reset_allowed %}
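The form above feeds DatabaseService.update_last_indexed_at, added in apollo/server/services/database_service.py earlier in this diff. A hedged sketch of the route-side glue, assuming the date is posted as YYYY-MM-DD and pinned to 00:00:00 UTC as the help text states; the function name and the parsing are illustrative, not the actual route code.

```python
# Hypothetical glue for the "Update CSAF Index Timestamp" form.
# Only DatabaseService.update_last_indexed_at() is from this diff; the rest is
# an assumption about how the posted date could be handled.
from datetime import datetime, timezone

from apollo.server.services.database_service import DatabaseService


async def handle_update_last_indexed(date_str: str, user_email: str) -> dict:
    # "2025-07-01" -> 2025-07-01 00:00:00+00:00, matching the form's note that
    # the time is set to 00:00:00 UTC.
    new_timestamp = datetime.strptime(date_str, "%Y-%m-%d").replace(
        tzinfo=timezone.utc
    )
    # Returns {"success": True, "old_timestamp": ..., "new_timestamp": ...,
    # "message": ...}; raises RuntimeError if the database update fails.
    return await DatabaseService().update_last_indexed_at(new_timestamp, user_email)
```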
@@ -114,7 +152,7 @@
- -{% if reset_allowed %} -{% endif %} {% endblock %} \ No newline at end of file diff --git a/apollo/server/validation.py b/apollo/server/validation.py index ec078b2..224b843 100644 --- a/apollo/server/validation.py +++ b/apollo/server/validation.py @@ -57,16 +57,12 @@ def __init__( class ValidationPatterns: """Regex patterns for common validations.""" - # URL validation - must start with http:// or https:// URL_PATTERN = re.compile(r"^https?://.+") - # Name patterns - alphanumeric with common special characters and spaces - NAME_PATTERN = re.compile(r"^[a-zA-Z0-9._\s-]+$") + NAME_PATTERN = re.compile(r"^[a-zA-Z0-9._\s()\-]+$") - # Architecture validation ARCH_PATTERN = re.compile(r"^(x86_64|aarch64|i386|i686|ppc64|ppc64le|s390x|riscv64|noarch)$") - # Repository name - more permissive for repo naming conventions REPO_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9._-]+$") @@ -107,7 +103,7 @@ def validate_name(name: str, min_length: int = 3, field_name: str = "name") -> s if not ValidationPatterns.NAME_PATTERN.match(trimmed_name): raise ValidationError( - f"{field_name.title()} can only contain letters, numbers, spaces, dots, hyphens, and underscores", + f"{field_name.title()} can only contain letters, numbers, spaces, dots, hyphens, underscores, and parentheses", ValidationErrorType.INVALID_FORMAT, field_name, ) @@ -176,7 +172,6 @@ def validate_architecture(arch: str, field_name: str = "architecture") -> str: trimmed_arch = arch.strip() - # Check if it's a valid architecture enum value try: Architecture(trimmed_arch) except ValueError: @@ -277,27 +272,23 @@ def validate_config_structure(config: Any, config_index: int) -> List[str]: errors.append(f"Config {config_index}: Must be a dictionary") return errors - # Validate required top-level keys required_keys = ["product", "mirror", "repositories"] for key in required_keys: if key not in config: errors.append(f"Config {config_index}: Missing required key '{key}'") - # Validate product data structure if "product" in config: product_errors = ConfigValidator.validate_product_config( config["product"], config_index ) errors.extend(product_errors) - # Validate mirror data structure if "mirror" in config: mirror_errors = ConfigValidator.validate_mirror_config( config["mirror"], config_index ) errors.extend(mirror_errors) - # Validate repositories data structure if "repositories" in config: repo_errors = ConfigValidator.validate_repositories_config( config["repositories"], config_index @@ -324,7 +315,6 @@ def validate_product_config(product: Any, config_index: int) -> List[str]: errors.append(f"Config {config_index}: Product must be a dictionary") return errors - # Validate required product fields required_fields = ["name", "variant", "vendor"] for field in required_fields: if field not in product or not product[field]: @@ -332,7 +322,6 @@ def validate_product_config(product: Any, config_index: int) -> List[str]: f"Config {config_index}: Product missing required field '{field}'" ) - # Validate product name format if present if product.get("name"): try: FieldValidator.validate_name( @@ -361,7 +350,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]: errors.append(f"Config {config_index}: Mirror must be a dictionary") return errors - # Validate required mirror fields required_fields = ["name", "match_variant", "match_major_version", "match_arch"] for field in required_fields: if field not in mirror or mirror[field] is None: @@ -369,7 +357,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]: f"Config 
{config_index}: Mirror missing required field '{field}'" ) - # Validate mirror name format if present if mirror.get("name"): try: FieldValidator.validate_name( @@ -378,7 +365,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]: except ValidationError as e: errors.append(f"Config {config_index}: Mirror name '{mirror['name']}' - {e.message}") - # Validate architecture if present if mirror.get("match_arch"): try: FieldValidator.validate_architecture( @@ -387,7 +373,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]: except ValidationError as e: errors.append(f"Config {config_index}: Mirror architecture '{mirror['match_arch']}' - {e.message}") - # Validate major version is numeric if present if mirror.get("match_major_version") is not None: if ( not isinstance(mirror["match_major_version"], int) @@ -397,7 +382,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]: f"Config {config_index}: Mirror match_major_version must be a non-negative integer" ) - # Validate minor version is numeric if present if mirror.get("match_minor_version") is not None: if ( not isinstance(mirror["match_minor_version"], int) @@ -407,6 +391,13 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]: f"Config {config_index}: Mirror match_minor_version must be a non-negative integer" ) + # Validate active field if present + if "active" in mirror and mirror["active"] is not None: + if not isinstance(mirror["active"], bool): + errors.append( + f"Config {config_index}: Mirror active must be a boolean value" + ) + return errors @staticmethod @@ -458,7 +449,6 @@ def validate_repository_config( ) return errors - # Validate required repository fields required_fields = ["repo_name", "arch", "production", "url"] for field in required_fields: if field not in repo or repo[field] is None: @@ -466,7 +456,6 @@ def validate_repository_config( f"Config {config_index}, Repo {repo_index}: Missing required field '{field}'" ) - # Validate repository name format if present if repo.get("repo_name"): try: FieldValidator.validate_repo_name( @@ -475,7 +464,6 @@ def validate_repository_config( except ValidationError as e: errors.append(f"Config {config_index}, Repo {repo_index}: Repository name '{repo['repo_name']}' - {e.message}") - # Validate architecture if present if repo.get("arch"): try: FieldValidator.validate_architecture( @@ -484,10 +472,9 @@ def validate_repository_config( except ValidationError as e: errors.append(f"Config {config_index}, Repo {repo_index}: Architecture '{repo['arch']}' - {e.message}") - # Validate URLs if present for url_field in ["url", "debug_url", "source_url"]: url_value = repo.get(url_field) - if url_value: # Only validate if not empty + if url_value: try: FieldValidator.validate_url( url_value, @@ -499,7 +486,6 @@ def validate_repository_config( f"Config {config_index}, Repo {repo_index}: {url_field.replace('_', ' ').title()} '{url_value}' - {e.message}" ) - # Validate production is boolean if present if "production" in repo and repo["production"] is not None: if not isinstance(repo["production"], bool): errors.append( @@ -528,7 +514,6 @@ def validate_mirror_form( validated_data = {} errors = [] - # Validate name try: validated_data["name"] = FieldValidator.validate_name( form_data.get("name", ""), min_length=3, field_name="mirror name" @@ -536,7 +521,6 @@ def validate_mirror_form( except ValidationError as e: errors.append(e.message) - # Validate architecture try: validated_data["match_arch"] = 
FieldValidator.validate_architecture( form_data.get("match_arch", ""), field_name="architecture" @@ -544,11 +528,13 @@ def validate_mirror_form( except ValidationError as e: errors.append(e.message) - # Copy other fields as-is for now (they have different validation requirements) for field in ["match_variant", "match_major_version", "match_minor_version"]: if field in form_data: validated_data[field] = form_data[field] + # Handle active checkbox (form data comes as string "true"/"false" or may be missing) + validated_data["active"] = form_data.get("active", "true") == "true" + return validated_data, errors @staticmethod @@ -567,7 +553,6 @@ def validate_repomd_form( validated_data = {} errors = [] - # Validate repository name try: validated_data["repo_name"] = FieldValidator.validate_repo_name( form_data.get("repo_name", ""), @@ -577,7 +562,6 @@ def validate_repomd_form( except ValidationError as e: errors.append(e.message) - # Validate main URL (required) try: validated_data["url"] = FieldValidator.validate_url( form_data.get("url", ""), field_name="repository URL", required=True @@ -585,7 +569,6 @@ def validate_repomd_form( except ValidationError as e: errors.append(e.message) - # Validate debug URL (optional) try: debug_url = FieldValidator.validate_url( form_data.get("debug_url", ""), field_name="debug URL", required=False @@ -594,7 +577,6 @@ def validate_repomd_form( except ValidationError as e: errors.append(e.message) - # Validate source URL (optional) try: source_url = FieldValidator.validate_url( form_data.get("source_url", ""), field_name="source URL", required=False @@ -603,7 +585,6 @@ def validate_repomd_form( except ValidationError as e: errors.append(e.message) - # Validate architecture try: validated_data["arch"] = FieldValidator.validate_architecture( form_data.get("arch", ""), field_name="architecture" @@ -611,7 +592,6 @@ def validate_repomd_form( except ValidationError as e: errors.append(e.message) - # Copy production flag validated_data["production"] = form_data.get("production", False) return validated_data, errors diff --git a/apollo/tests/BUILD.bazel b/apollo/tests/BUILD.bazel index b658f79..806782f 100644 --- a/apollo/tests/BUILD.bazel +++ b/apollo/tests/BUILD.bazel @@ -45,6 +45,15 @@ py_test( ], ) +py_test( + name = "test_api_updateinfo", + srcs = ["test_api_updateinfo.py"], + deps = [ + "//apollo/server:server_lib", + "//apollo/db:db_lib", + ], +) + py_test( name = "test_validation", @@ -61,3 +70,30 @@ py_test( "//apollo/server:server_lib", ], ) + +py_test( + name = "test_api_osv", + srcs = ["test_api_osv.py"], + deps = [ + "//apollo/server:server_lib", + ], +) + +py_test( + name = "test_database_service", + srcs = ["test_database_service.py"], + deps = [ + "//apollo/server:server_lib", + "//apollo/db:db_lib", + "//common:common_lib", + ], +) + +py_test( + name = "test_rh_matcher_activities", + srcs = ["test_rh_matcher_activities.py"], + deps = [ + "//apollo/rpmworker:rpmworker_lib", + "//apollo/rpm_helpers:rpm_helpers_lib", + ], +) diff --git a/apollo/tests/test_admin_routes_supported_products.py b/apollo/tests/test_admin_routes_supported_products.py index 32ad971..006391b 100644 --- a/apollo/tests/test_admin_routes_supported_products.py +++ b/apollo/tests/test_admin_routes_supported_products.py @@ -168,8 +168,8 @@ def test_json_serializer_decimal_integer(self): """Test JSON serializer with integer Decimal.""" decimal_val = Decimal("42") result = _json_serializer(decimal_val) - self.assertEqual(result, 42.0) - self.assertIsInstance(result, float) + 
self.assertEqual(result, 42) + self.assertIsInstance(result, int) def test_json_serializer_unsupported_type(self): """Test JSON serializer with unsupported type.""" @@ -211,10 +211,9 @@ def test_format_export_data_with_decimal(self): result = _format_export_data(data) - # Should be valid JSON with Decimals converted to floats parsed = json.loads(result) self.assertEqual(parsed[0]["price"], 19.99) - self.assertEqual(parsed[1]["price"], 99.0) + self.assertEqual(parsed[1]["price"], 99) def test_format_export_data_empty(self): """Test formatting empty export data.""" @@ -442,6 +441,215 @@ def test_validate_import_data_not_list(self): self.assertIn("must be a list", errors[0]) +class TestActiveFieldCheckboxParsing(unittest.TestCase): + """Test the checkbox parsing logic for the active field.""" + + def test_checkbox_checked_returns_true(self): + """Test that checked checkbox results in active_value='true'.""" + from unittest.mock import Mock + form_data = Mock() + form_data.getlist.return_value = ["true"] + + active_value = "true" if "true" in form_data.getlist("active") else "false" + + self.assertEqual(active_value, "true") + + def test_checkbox_unchecked_returns_false(self): + """Test that unchecked checkbox results in active_value='false'.""" + from unittest.mock import Mock + form_data = Mock() + form_data.getlist.return_value = [] + + active_value = "true" if "true" in form_data.getlist("active") else "false" + + self.assertEqual(active_value, "false") + + def test_checkbox_missing_field_defaults_to_false(self): + """Test that missing active field defaults to false.""" + from unittest.mock import Mock + form_data = Mock() + form_data.getlist.return_value = [] + + active_value = "true" if "true" in form_data.getlist("active") else "false" + + self.assertEqual(active_value, "false") + + def test_checkbox_boolean_conversion(self): + """Test that string active_value converts correctly to boolean.""" + active_value_true = "true" + active_value_false = "false" + + self.assertTrue(active_value_true == "true") + self.assertFalse(active_value_false == "true") + + def test_checkbox_with_unexpected_value(self): + """Test that unexpected values default to false.""" + from unittest.mock import Mock + form_data = Mock() + form_data.getlist.return_value = ["yes", "1", "on"] + + active_value = "true" if "true" in form_data.getlist("active") else "false" + + self.assertEqual(active_value, "false") + + +class TestMirrorConfigDataWithActiveField(unittest.TestCase): + """Test mirror configuration export includes active field.""" + + def test_exported_config_includes_active_true(self): + """Test that exported mirror config includes active=true.""" + mirror = Mock() + mirror.id = 1 + mirror.name = "Active Mirror" + mirror.match_variant = "Red Hat Enterprise Linux" + mirror.match_major_version = 9 + mirror.match_minor_version = None + mirror.match_arch = "x86_64" + mirror.active = True + mirror.created_at = datetime(2024, 1, 1, 10, 0, 0) + mirror.updated_at = None + + # Mock supported product + mirror.supported_product = Mock() + mirror.supported_product.id = 1 + mirror.supported_product.name = "Rocky Linux" + mirror.supported_product.variant = "Rocky Linux" + mirror.supported_product.vendor = "RESF" + + # Mock repositories + mirror.rpm_repomds = [] + + result = asyncio.run(_get_mirror_config_data(mirror)) + + # Verify active field is included and true + self.assertIn("active", result["mirror"]) + self.assertTrue(result["mirror"]["active"]) + + def test_exported_config_includes_active_false(self): + """Test 
that exported mirror config includes active=false.""" + mirror = Mock() + mirror.id = 2 + mirror.name = "Inactive Mirror" + mirror.match_variant = "Red Hat Enterprise Linux" + mirror.match_major_version = 8 + mirror.match_minor_version = None + mirror.match_arch = "x86_64" + mirror.active = False + mirror.created_at = datetime(2024, 1, 1, 10, 0, 0) + mirror.updated_at = None + + # Mock supported product + mirror.supported_product = Mock() + mirror.supported_product.id = 1 + mirror.supported_product.name = "Rocky Linux" + mirror.supported_product.variant = "Rocky Linux" + mirror.supported_product.vendor = "RESF" + + # Mock repositories + mirror.rpm_repomds = [] + + result = asyncio.run(_get_mirror_config_data(mirror)) + + # Verify active field is included and false + self.assertIn("active", result["mirror"]) + self.assertFalse(result["mirror"]["active"]) + + +class TestImportWithActiveField(unittest.TestCase): + """Test import validation with active field.""" + + def test_import_data_with_active_true_is_valid(self): + """Test that import data with active=true is valid.""" + valid_data = [ + { + "product": { + "name": "Rocky Linux", + "variant": "Rocky Linux", + "vendor": "RESF", + }, + "mirror": { + "name": "Rocky Linux 9 x86_64", + "match_variant": "Red Hat Enterprise Linux", + "match_major_version": 9, + "match_arch": "x86_64", + "active": True, + }, + "repositories": [ + { + "repo_name": "BaseOS", + "arch": "x86_64", + "production": True, + "url": "https://example.com/repo", + } + ], + } + ] + + errors = asyncio.run(_validate_import_data(valid_data)) + self.assertEqual(errors, []) + + def test_import_data_with_active_false_is_valid(self): + """Test that import data with active=false is valid.""" + valid_data = [ + { + "product": { + "name": "Rocky Linux", + "variant": "Rocky Linux", + "vendor": "RESF", + }, + "mirror": { + "name": "Rocky Linux 8 x86_64", + "match_variant": "Red Hat Enterprise Linux", + "match_major_version": 8, + "match_arch": "x86_64", + "active": False, + }, + "repositories": [ + { + "repo_name": "BaseOS", + "arch": "x86_64", + "production": True, + "url": "https://example.com/repo", + } + ], + } + ] + + errors = asyncio.run(_validate_import_data(valid_data)) + self.assertEqual(errors, []) + + def test_import_data_without_active_field_is_valid(self): + """Test that import data without active field is valid (backwards compatibility).""" + valid_data = [ + { + "product": { + "name": "Rocky Linux", + "variant": "Rocky Linux", + "vendor": "RESF", + }, + "mirror": { + "name": "Rocky Linux 9 x86_64", + "match_variant": "Red Hat Enterprise Linux", + "match_major_version": 9, + "match_arch": "x86_64", + # No active field - should default to true + }, + "repositories": [ + { + "repo_name": "BaseOS", + "arch": "x86_64", + "production": True, + "url": "https://example.com/repo", + } + ], + } + ] + + # Should still be valid - active field is optional for backwards compatibility + errors = asyncio.run(_validate_import_data(valid_data)) + self.assertEqual(errors, []) + + if __name__ == "__main__": # Run with verbose output unittest.main(verbosity=2) diff --git a/apollo/tests/test_api_osv.py b/apollo/tests/test_api_osv.py new file mode 100644 index 0000000..6422c3d --- /dev/null +++ b/apollo/tests/test_api_osv.py @@ -0,0 +1,248 @@ +""" +Tests for OSV API CVE filtering functionality +""" + +import unittest +import datetime +from unittest.mock import Mock + +from apollo.server.routes.api_osv import to_osv_advisory + + +class MockSupportedProduct: + """Mock SupportedProduct model""" + + 
def __init__(self, variant="Rocky Linux", vendor="Rocky Enterprise Software Foundation"): + self.variant = variant + self.vendor = vendor + + +class MockSupportedProductsRhMirror: + """Mock SupportedProductsRhMirror model""" + + def __init__(self, match_major_version=9): + self.match_major_version = match_major_version + + +class MockPackage: + """Mock Package model""" + + def __init__( + self, + nevra, + product_name="Rocky Linux 9", + repo_name="BaseOS", + supported_product=None, + supported_products_rh_mirror=None, + ): + self.nevra = nevra + self.product_name = product_name + self.repo_name = repo_name + self.supported_product = supported_product or MockSupportedProduct() + self.supported_products_rh_mirror = supported_products_rh_mirror + + +class MockCVE: + """Mock CVE model""" + + def __init__( + self, + cve="CVE-2024-1234", + cvss3_base_score="7.5", + cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:N/A:N", + ): + self.cve = cve + self.cvss3_base_score = cvss3_base_score + self.cvss3_scoring_vector = cvss3_scoring_vector + + +class MockFix: + """Mock Fix model""" + + def __init__(self, source="https://bugzilla.redhat.com/show_bug.cgi?id=1234567"): + self.source = source + + +class MockAdvisory: + """Mock Advisory model""" + + def __init__( + self, + name="RLSA-2024:1234", + synopsis="Important: test security update", + description="A security update for test package", + published_at=None, + updated_at=None, + packages=None, + cves=None, + fixes=None, + red_hat_advisory=None, + ): + self.name = name + self.synopsis = synopsis + self.description = description + self.published_at = published_at or datetime.datetime.now( + datetime.timezone.utc + ) + self.updated_at = updated_at or datetime.datetime.now(datetime.timezone.utc) + self.packages = packages or [] + self.cves = cves or [] + self.fixes = fixes or [] + self.red_hat_advisory = red_hat_advisory + + +class TestOSVCVEFiltering(unittest.TestCase): + """Test CVE filtering logic in OSV API""" + + def test_advisory_with_cve_has_upstream_references(self): + """Test that advisories with CVEs have upstream references populated""" + packages = [ + MockPackage( + nevra="pcs-0:0.11.8-2.el9_5.src", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + cves = [MockCVE(cve="CVE-2024-1234")] + + advisory = MockAdvisory(packages=packages, cves=cves) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + self.assertIsNotNone(result.upstream) + self.assertEqual(len(result.upstream), 1) + self.assertIn("CVE-2024-1234", result.upstream) + + def test_advisory_with_multiple_cves(self): + """Test that advisories with multiple CVEs include all in upstream""" + packages = [ + MockPackage( + nevra="openssl-1:3.0.7-28.el9_5.src", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + cves = [ + MockCVE(cve="CVE-2024-1111"), + MockCVE(cve="CVE-2024-2222"), + MockCVE(cve="CVE-2024-3333"), + ] + + advisory = MockAdvisory(packages=packages, cves=cves) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + self.assertIsNotNone(result.upstream) + self.assertEqual(len(result.upstream), 3) + self.assertIn("CVE-2024-1111", result.upstream) + self.assertIn("CVE-2024-2222", result.upstream) + self.assertIn("CVE-2024-3333", result.upstream) + + def test_advisory_without_cves_has_empty_upstream(self): + """Test that advisories without CVEs have empty upstream list""" + packages = [ + MockPackage( + nevra="kernel-0:5.14.0-427.el9.src", + 
supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + + advisory = MockAdvisory(packages=packages, cves=[]) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + self.assertIsNotNone(result.upstream) + self.assertEqual(len(result.upstream), 0) + + def test_source_packages_only(self): + """Test that only source packages are processed, not binary packages""" + packages = [ + MockPackage( + nevra="httpd-0:2.4.57-8.el9.src", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + MockPackage( + nevra="httpd-0:2.4.57-8.el9.x86_64", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + MockPackage( + nevra="httpd-0:2.4.57-8.el9.aarch64", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + cves = [MockCVE()] + + advisory = MockAdvisory(packages=packages, cves=cves) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + # Should only have 1 affected package (the source package) + self.assertEqual(len(result.affected), 1) + self.assertEqual(result.affected[0].package.name, "httpd") + + def test_severity_from_highest_cvss(self): + """Test that severity uses the highest CVSS score from multiple CVEs""" + packages = [ + MockPackage( + nevra="vim-2:9.0.1592-1.el9.src", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + cves = [ + MockCVE( + cve="CVE-2024-1111", + cvss3_base_score="5.5", + cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:L/I:N/A:N", + ), + MockCVE( + cve="CVE-2024-2222", + cvss3_base_score="9.8", + cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", + ), + MockCVE( + cve="CVE-2024-3333", + cvss3_base_score="7.5", + cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:N/A:N", + ), + ] + + advisory = MockAdvisory(packages=packages, cves=cves) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + self.assertIsNotNone(result.severity) + self.assertEqual(len(result.severity), 1) + self.assertEqual(result.severity[0].type, "CVSS_V3") + self.assertEqual( + result.severity[0].score, "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H" + ) + + def test_ecosystem_format(self): + """Test that ecosystem field is formatted correctly""" + packages = [ + MockPackage( + nevra="bash-0:5.1.8-9.el9.src", + product_name="Rocky Linux 9", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + cves = [MockCVE()] + + advisory = MockAdvisory(packages=packages, cves=cves) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + self.assertEqual(len(result.affected), 1) + self.assertEqual(result.affected[0].package.ecosystem, "Rocky Linux:9") + + def test_version_format_with_epoch(self): + """Test that fixed version includes epoch in epoch:version-release format""" + packages = [ + MockPackage( + nevra="systemd-0:252-38.el9_5.src", + supported_products_rh_mirror=MockSupportedProductsRhMirror(9), + ), + ] + cves = [MockCVE()] + + advisory = MockAdvisory(packages=packages, cves=cves) + result = to_osv_advisory("https://errata.rockylinux.org", advisory) + + fixed_version = result.affected[0].ranges[0].events[1].fixed + self.assertEqual(fixed_version, "0:252-38.el9_5") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/apollo/tests/test_api_updateinfo.py b/apollo/tests/test_api_updateinfo.py new file mode 100644 index 0000000..9b9b536 --- /dev/null +++ b/apollo/tests/test_api_updateinfo.py @@ -0,0 +1,181 @@ +""" +Tests for updateinfo API endpoints and helper functions 
+""" + +import unittest +import sys +import os +from unittest.mock import Mock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from apollo.server.routes.api_updateinfo import ( + resolve_product_slug, + get_source_package_name, + build_source_rpm_mapping, + PRODUCT_SLUG_MAP, +) + + +class TestProductSlugResolution(unittest.TestCase): + """Test product slug resolution""" + + def test_valid_slug_rocky_linux(self): + """Test resolving rocky-linux slug""" + result = resolve_product_slug("rocky-linux") + self.assertEqual(result, "Rocky Linux") + + def test_valid_slug_case_insensitive(self): + """Test slug resolution is case-insensitive""" + result = resolve_product_slug("ROCKY-LINUX") + self.assertEqual(result, "Rocky Linux") + + def test_invalid_slug(self): + """Test invalid slug returns None""" + result = resolve_product_slug("invalid-product") + self.assertIsNone(result) + + def test_sig_cloud_slug(self): + """Test resolving sig-cloud slug""" + result = resolve_product_slug("rocky-linux-sig-cloud") + self.assertEqual(result, "Rocky Linux SIG Cloud") + + +class TestGetSourcePackageName(unittest.TestCase): + """Test get_source_package_name helper""" + + def test_regular_package(self): + """Test regular package without module""" + pkg = Mock() + pkg.package_name = "kernel" + pkg.module_name = None + + result = get_source_package_name(pkg) + self.assertEqual(result, "kernel") + + def test_module_package_without_prefix(self): + """Test module package without module. prefix""" + pkg = Mock() + pkg.package_name = "python-markupsafe" + pkg.module_name = "python27" + pkg.module_stream = "el8" + + result = get_source_package_name(pkg) + self.assertEqual(result, "python27:python-markupsafe:el8") + + def test_module_package_cleaned_by_orm(self): + """Test module package (ORM strips module. prefix automatically)""" + pkg = Mock() + pkg.package_name = "python-markupsafe" # ORM already cleaned + pkg.module_name = "python27" + pkg.module_stream = "el8" + + result = get_source_package_name(pkg) + self.assertEqual(result, "python27:python-markupsafe:el8") + + def test_regular_package_cleaned_by_orm(self): + """Test regular package (ORM strips module. 
prefix automatically)""" + pkg = Mock() + pkg.package_name = "delve" # ORM already cleaned + pkg.module_name = None + + result = get_source_package_name(pkg) + self.assertEqual(result, "delve") + + +class TestBuildSourceRpmMapping(unittest.TestCase): + """Test build_source_rpm_mapping helper""" + + def test_simple_source_rpm_mapping(self): + """Test mapping for simple package""" + src_pkg = Mock() + src_pkg.package_name = "kernel" + src_pkg.module_name = None + src_pkg.nevra = "kernel-4.18.0-425.el8.src.rpm" + + bin_pkg = Mock() + bin_pkg.package_name = "kernel" + bin_pkg.module_name = None + bin_pkg.nevra = "kernel-4.18.0-425.el8.x86_64.rpm" + + packages = [src_pkg, bin_pkg] + + result = build_source_rpm_mapping(packages) + + self.assertEqual(result, {"kernel": "kernel-4.18.0-425.el8.src.rpm"}) + + def test_multi_binary_source_rpm_mapping(self): + """Test mapping when multiple binaries share one source""" + src_pkg = Mock() + src_pkg.package_name = "python-markupsafe" + src_pkg.module_name = None + src_pkg.nevra = "python-markupsafe-0.23-19.el8.src.rpm" + + bin_pkg1 = Mock() + bin_pkg1.package_name = "python-markupsafe" + bin_pkg1.module_name = None + bin_pkg1.nevra = "python2-markupsafe-0.23-19.el8.x86_64.rpm" + + bin_pkg2 = Mock() + bin_pkg2.package_name = "python-markupsafe" + bin_pkg2.module_name = None + bin_pkg2.nevra = "python3-markupsafe-0.23-19.el8.x86_64.rpm" + + packages = [src_pkg, bin_pkg1, bin_pkg2] + + result = build_source_rpm_mapping(packages) + + self.assertEqual(result, {"python-markupsafe": "python-markupsafe-0.23-19.el8.src.rpm"}) + + def test_module_package_cleaned_by_orm(self): + """Test module package (ORM strips module. prefix automatically)""" + src_pkg = Mock() + src_pkg.package_name = "python-markupsafe" # ORM already cleaned + src_pkg.module_name = "python27" + src_pkg.module_stream = "el8" + src_pkg.nevra = "python-markupsafe-0.23-19.module+el8.5.0+706+735ec4b3.src.rpm" + + bin_pkg = Mock() + bin_pkg.package_name = "python-markupsafe" # ORM already cleaned + bin_pkg.module_name = "python27" + bin_pkg.module_stream = "el8" + bin_pkg.nevra = "python2-markupsafe-0.23-19.module+el8.5.0+706+735ec4b3.x86_64.rpm" + + packages = [src_pkg, bin_pkg] + + result = build_source_rpm_mapping(packages) + + expected_key = "python27:python-markupsafe:el8" + self.assertIn(expected_key, result) + self.assertEqual(result[expected_key], "python-markupsafe-0.23-19.module+el8.5.0+706+735ec4b3.src.rpm") + + def test_no_source_rpm(self): + """Test when no source RPM is present""" + bin_pkg = Mock() + bin_pkg.package_name = "kernel" + bin_pkg.module_name = None + bin_pkg.nevra = "kernel-4.18.0-425.el8.x86_64.rpm" + + packages = [bin_pkg] + + result = build_source_rpm_mapping(packages) + + self.assertEqual(result, {}) + + +class TestProductSlugMapping(unittest.TestCase): + """Test product slug mapping configuration""" + + def test_slug_map_contains_rocky_linux(self): + """Test slug map contains rocky-linux""" + self.assertIn("rocky-linux", PRODUCT_SLUG_MAP) + self.assertEqual(PRODUCT_SLUG_MAP["rocky-linux"], "Rocky Linux") + + def test_slug_map_contains_sig_cloud(self): + """Test slug map contains sig-cloud""" + self.assertIn("rocky-linux-sig-cloud", PRODUCT_SLUG_MAP) + self.assertEqual(PRODUCT_SLUG_MAP["rocky-linux-sig-cloud"], "Rocky Linux SIG Cloud") + + +if __name__ == "__main__": + unittest.main() diff --git a/apollo/tests/test_csaf_processing.py b/apollo/tests/test_csaf_processing.py index dbd0f91..aa31c10 100644 --- a/apollo/tests/test_csaf_processing.py +++ 
b/apollo/tests/test_csaf_processing.py @@ -22,17 +22,10 @@ ) class TestCsafProcessing(unittest.IsolatedAsyncioTestCase): - @classmethod - async def asyncSetUp(cls): - # Initialize test database for all tests in this class + async def asyncSetUp(self): + # Initialize test database before each test await initialize_test_db() - - @classmethod - async def asyncTearDown(cls): - # Close database connections when tests are done - await close_test_db() - def setUp(self): # Create sample CSAF data matching schema requirements self.sample_csaf = { "document": { @@ -69,10 +62,35 @@ def setUp(self): "name": "Red Hat Enterprise Linux 9", "product": { "name": "Red Hat Enterprise Linux 9", + "product_id": "AppStream-9.4.0.Z.MAIN", "product_identification_helper": { - "cpe": "cpe:/o:redhat:enterprise_linux:9.4" + "cpe": "cpe:/o:redhat:enterprise_linux:9::appstream" + } + }, + "branches": [ + { + "category": "product_version", + "name": "rsync-0:3.2.3-19.el9_4.1.x86_64", + "product": { + "name": "rsync-0:3.2.3-19.el9_4.1.x86_64", + "product_id": "rsync-0:3.2.3-19.el9_4.1.x86_64", + "product_identification_helper": { + "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=x86_64" + } + } + }, + { + "category": "product_version", + "name": "rsync-0:3.2.3-19.el9_4.1.src", + "product": { + "name": "rsync-0:3.2.3-19.el9_4.1.src", + "product_id": "rsync-0:3.2.3-19.el9_4.1.src", + "product_identification_helper": { + "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=src" + } + } } - } + ] } ] }, @@ -95,8 +113,8 @@ def setUp(self): ], "product_status": { "fixed": [ - "AppStream-9.4.0.Z.EUS:rsync-0:3.2.3-19.el9_4.1.x86_64", - "AppStream-9.4.0.Z.EUS:rsync-0:3.2.3-19.el9_4.1.src" + "AppStream-9.4.0.Z.MAIN:rsync-0:3.2.3-19.el9_4.1.x86_64", + "AppStream-9.4.0.Z.MAIN:rsync-0:3.2.3-19.el9_4.1.src" ] }, "scores": [{ @@ -117,28 +135,31 @@ def setUp(self): } ] } - + # Create a temporary file with the sample data self.test_file = pathlib.Path("test_csaf.json") with open(self.test_file, "w") as f: json.dump(self.sample_csaf, f) - async def tearDown(self): - # Clean up database and temporary files after each test + async def asyncTearDown(self): + # Clean up database entries and temporary files after each test await RedHatAdvisory.all().delete() await RedHatAdvisoryPackage.all().delete() await RedHatAdvisoryCVE.all().delete() - await RedHatAdvisoryBugzillaBug.all().delete() + await RedHatAdvisoryBugzillaBug.all().delete() await RedHatAdvisoryAffectedProduct.all().delete() - - # Clean up temporary file + + # Close database connections + await close_test_db() + + # Clean up temporary files self.test_file.unlink(missing_ok=True) pathlib.Path("invalid_csaf.json").unlink(missing_ok=True) async def test_new_advisory_creation(self): # Test creating a new advisory with a real test database result = await process_csaf_file(self.sample_csaf, "test.json") - + # Verify advisory was created correctly advisory = await RedHatAdvisory.get_or_none(name="RHSA-2025:1234") self.assertIsNotNone(advisory) @@ -176,7 +197,8 @@ async def test_new_advisory_creation(self): self.assertEqual(products[0].variant, "Red Hat Enterprise Linux") self.assertEqual(products[0].arch, "x86_64") self.assertEqual(products[0].major_version, 9) - self.assertEqual(products[0].minor_version, 4) + # Minor version is None because CPE doesn't include minor version + self.assertIsNone(products[0].minor_version) async def test_advisory_update(self): # First create an advisory with different values @@ -224,12 +246,13 @@ async def test_no_vulnerabilities(self): 
self.assertEqual(count, 0) async def test_no_fixed_packages(self): - # Test CSAF with vulnerabilities but no fixed packages + # Test CSAF with vulnerabilities but no fixed packages in product_tree csaf = self.sample_csaf.copy() - csaf["vulnerabilities"][0]["product_status"]["fixed"] = [] + # Remove product_version entries from product_tree to simulate no fixed packages + csaf["product_tree"]["branches"][0]["branches"][0]["branches"][0].pop("branches", None) result = await process_csaf_file(csaf, "test.json") self.assertIsNone(result) - + # Verify nothing was created count = await RedHatAdvisory.all().count() self.assertEqual(count, 0) @@ -239,4 +262,7 @@ async def test_db_exception(self, mock_get_or_none): # Simulate a database error mock_get_or_none.side_effect = Exception("DB error") with self.assertRaises(Exception): - await process_csaf_file(self.sample_csaf, "test.json") \ No newline at end of file + await process_csaf_file(self.sample_csaf, "test.json") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/apollo/tests/test_database_service.py b/apollo/tests/test_database_service.py new file mode 100644 index 0000000..0a8c215 --- /dev/null +++ b/apollo/tests/test_database_service.py @@ -0,0 +1,226 @@ +""" +Tests for DatabaseService functionality +Tests utility functions for database operations including timestamp management +""" + +import unittest +import asyncio +from datetime import datetime, timezone +from unittest.mock import Mock, AsyncMock, patch +import os + +# Add the project root to the Python path +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from apollo.server.services.database_service import DatabaseService + + +class TestEnvironmentDetection(unittest.TestCase): + """Test environment detection functionality.""" + + def test_is_production_when_env_is_production(self): + """Test production detection when ENV=production.""" + with patch.dict(os.environ, {"ENV": "production"}): + service = DatabaseService() + self.assertTrue(service.is_production_environment()) + + def test_is_not_production_when_env_is_development(self): + """Test production detection when ENV=development.""" + with patch.dict(os.environ, {"ENV": "development"}): + service = DatabaseService() + self.assertFalse(service.is_production_environment()) + + def test_is_not_production_when_env_not_set(self): + """Test production detection when ENV is not set.""" + with patch.dict(os.environ, {}, clear=True): + service = DatabaseService() + self.assertFalse(service.is_production_environment()) + + def test_is_not_production_with_staging_env(self): + """Test production detection with staging environment.""" + with patch.dict(os.environ, {"ENV": "staging"}): + service = DatabaseService() + self.assertFalse(service.is_production_environment()) + + def test_get_environment_info_production(self): + """Test getting environment info for production.""" + with patch.dict(os.environ, {"ENV": "production"}): + service = DatabaseService() + result = asyncio.run(service.get_environment_info()) + + self.assertEqual(result["environment"], "production") + self.assertTrue(result["is_production"]) + self.assertFalse(result["reset_allowed"]) + + def test_get_environment_info_development(self): + """Test getting environment info for development.""" + with patch.dict(os.environ, {"ENV": "development"}): + service = DatabaseService() + result = asyncio.run(service.get_environment_info()) + + self.assertEqual(result["environment"], "development") + 
self.assertFalse(result["is_production"]) + self.assertTrue(result["reset_allowed"]) + + +class TestLastIndexedAtOperations(unittest.TestCase): + """Test last_indexed_at timestamp operations.""" + + def test_get_last_indexed_at_when_exists(self): + """Test getting last_indexed_at when record exists.""" + mock_index_state = Mock() + test_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc) + mock_index_state.last_indexed_at = test_time + + with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state: + mock_state.first = AsyncMock(return_value=mock_index_state) + + service = DatabaseService() + result = asyncio.run(service.get_last_indexed_at()) + + self.assertEqual(result["last_indexed_at"], test_time) + self.assertEqual(result["last_indexed_at_iso"], "2025-07-01T00:00:00+00:00") + self.assertTrue(result["exists"]) + + def test_get_last_indexed_at_when_not_exists(self): + """Test getting last_indexed_at when no record exists.""" + with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state: + mock_state.first = AsyncMock(return_value=None) + + service = DatabaseService() + result = asyncio.run(service.get_last_indexed_at()) + + self.assertIsNone(result["last_indexed_at"]) + self.assertIsNone(result["last_indexed_at_iso"]) + self.assertFalse(result["exists"]) + + def test_get_last_indexed_at_when_timestamp_is_none(self): + """Test getting last_indexed_at when timestamp field is None.""" + mock_index_state = Mock() + mock_index_state.last_indexed_at = None + + with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state: + mock_state.first = AsyncMock(return_value=mock_index_state) + + service = DatabaseService() + result = asyncio.run(service.get_last_indexed_at()) + + self.assertIsNone(result["last_indexed_at"]) + self.assertIsNone(result["last_indexed_at_iso"]) + self.assertFalse(result["exists"]) + + def test_update_last_indexed_at_existing_record(self): + """Test updating last_indexed_at for existing record.""" + old_time = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc) + new_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc) + + mock_index_state = Mock() + mock_index_state.last_indexed_at = old_time + mock_index_state.save = AsyncMock() + + with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state, \ + patch("apollo.server.services.database_service.Logger"): + mock_state.first = AsyncMock(return_value=mock_index_state) + + service = DatabaseService() + result = asyncio.run(service.update_last_indexed_at(new_time, "admin@example.com")) + + self.assertTrue(result["success"]) + self.assertEqual(result["old_timestamp"], "2025-06-01T00:00:00+00:00") + self.assertEqual(result["new_timestamp"], "2025-07-01T00:00:00+00:00") + self.assertIn("Successfully updated", result["message"]) + + # Verify save was called + mock_index_state.save.assert_called_once() + # Verify timestamp was updated + self.assertEqual(mock_index_state.last_indexed_at, new_time) + + def test_update_last_indexed_at_create_new_record(self): + """Test updating last_indexed_at when no record exists (creates new).""" + new_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc) + + with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state, \ + patch("apollo.server.services.database_service.Logger"): + mock_state.first = AsyncMock(return_value=None) + mock_state.create = AsyncMock() + + service = DatabaseService() + result = asyncio.run(service.update_last_indexed_at(new_time, 
"admin@example.com")) + + self.assertTrue(result["success"]) + self.assertIsNone(result["old_timestamp"]) + self.assertEqual(result["new_timestamp"], "2025-07-01T00:00:00+00:00") + self.assertIn("Successfully updated", result["message"]) + + # Verify create was called with correct timestamp + mock_state.create.assert_called_once_with(last_indexed_at=new_time) + + def test_update_last_indexed_at_handles_exception(self): + """Test that update_last_indexed_at handles database exceptions.""" + new_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc) + + with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state, \ + patch("apollo.server.services.database_service.Logger"): + mock_state.first = AsyncMock(side_effect=Exception("Database error")) + + service = DatabaseService() + + with self.assertRaises(RuntimeError) as cm: + asyncio.run(service.update_last_indexed_at(new_time, "admin@example.com")) + + self.assertIn("Failed to update timestamp", str(cm.exception)) + + +class TestPartialResetValidation(unittest.TestCase): + """Test partial reset validation logic.""" + + def test_preview_partial_reset_blocks_in_production(self): + """Test that preview_partial_reset raises error in production.""" + with patch.dict(os.environ, {"ENV": "production"}): + service = DatabaseService() + cutoff_date = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc) + + with self.assertRaises(ValueError) as cm: + asyncio.run(service.preview_partial_reset(cutoff_date)) + + self.assertIn("production environment", str(cm.exception)) + + def test_preview_partial_reset_rejects_future_date(self): + """Test that preview_partial_reset rejects future dates.""" + with patch.dict(os.environ, {"ENV": "development"}): + service = DatabaseService() + future_date = datetime(2099, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + + with self.assertRaises(ValueError) as cm: + asyncio.run(service.preview_partial_reset(future_date)) + + self.assertIn("must be in the past", str(cm.exception)) + + def test_perform_partial_reset_blocks_in_production(self): + """Test that perform_partial_reset raises error in production.""" + with patch.dict(os.environ, {"ENV": "production"}): + service = DatabaseService() + cutoff_date = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc) + + with self.assertRaises(ValueError) as cm: + asyncio.run(service.perform_partial_reset(cutoff_date, "admin@example.com")) + + self.assertIn("production environment", str(cm.exception)) + + def test_perform_partial_reset_rejects_future_date(self): + """Test that perform_partial_reset rejects future dates.""" + with patch.dict(os.environ, {"ENV": "development"}): + service = DatabaseService() + future_date = datetime(2099, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + + with self.assertRaises(ValueError) as cm: + asyncio.run(service.perform_partial_reset(future_date, "admin@example.com")) + + self.assertIn("must be in the past", str(cm.exception)) + + +if __name__ == "__main__": + # Run with verbose output + unittest.main(verbosity=2) diff --git a/apollo/tests/test_rh_matcher_activities.py b/apollo/tests/test_rh_matcher_activities.py new file mode 100644 index 0000000..c41adef --- /dev/null +++ b/apollo/tests/test_rh_matcher_activities.py @@ -0,0 +1,131 @@ +""" +Tests for RH matcher activities package_name extraction logic +""" + +import unittest +import sys +import os +from unittest.mock import Mock, MagicMock, patch +from xml.etree import ElementTree as ET + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from apollo.rpmworker import 
repomd + + +class TestPackageNameExtraction(unittest.TestCase): + """Test package_name extraction from source RPMs""" + + def setUp(self): + """Set up test fixtures""" + self.test_advisory_nvra = "libarchive-3.3.3-5.el8.src.rpm" + self.test_binary_nvra = "libarchive-0:3.3.3-5.el8.x86_64.rpm" + self.test_debuginfo_nvra = "libarchive-debuginfo-0:3.3.3-5.el8.aarch64.rpm" + + def test_nvra_regex_matches_source_rpm(self): + """Test NVRA_RE regex matches source RPM correctly""" + match = repomd.NVRA_RE.search(self.test_advisory_nvra) + self.assertIsNotNone(match) + self.assertEqual(match.group(1), "libarchive") + + def test_nvra_regex_matches_binary_rpm(self): + """Test NVRA_RE regex matches binary RPM name""" + source_rpm_text = "libarchive-3.3.3-5.el8.src.rpm" + match = repomd.NVRA_RE.search(source_rpm_text) + self.assertIsNotNone(match) + self.assertEqual(match.group(1), "libarchive") + + def test_nvra_regex_handles_module_packages(self): + """Test NVRA_RE regex extracts package name from module packages""" + module_source_rpm = "postgresql-12.5-1.module+el8.3.0+6656+95b1e5d5.src.rpm" + match = repomd.NVRA_RE.search(module_source_rpm) + self.assertIsNotNone(match) + self.assertEqual(match.group(1), "postgresql") + + def test_nvra_regex_no_match_returns_none(self): + """Test NVRA_RE regex returns None for invalid format""" + invalid_nvra = "not-a-valid-package-name" + match = repomd.NVRA_RE.search(invalid_nvra) + self.assertIsNone(match) + + def test_source_rpm_element_handling(self): + """Test handling of missing source_rpm XML element""" + xml_with_sourcerpm = """ + + + libarchive-3.3.3-5.el8.src.rpm + + + """ + xml_without_sourcerpm = """ + + + + + """ + + root_with = ET.fromstring(xml_with_sourcerpm) + source_rpm_with = root_with.find("format").find("{http://linux.duke.edu/metadata/rpm}sourcerpm") + self.assertIsNotNone(source_rpm_with) + + root_without = ET.fromstring(xml_without_sourcerpm) + source_rpm_without = root_without.find("format").find("{http://linux.duke.edu/metadata/rpm}sourcerpm") + self.assertIsNone(source_rpm_without) + + def test_package_name_extraction_workflow(self): + """Test complete workflow of package_name extraction with various scenarios""" + test_cases = [ + { + "name": "Valid source RPM", + "advisory_nvra": "libarchive-3.3.3-5.el8.src.rpm", + "is_source": True, + "source_rpm_text": None, + "expected": "libarchive" + }, + { + "name": "Valid binary RPM with source", + "advisory_nvra": "libarchive-0:3.3.3-5.el8.x86_64", + "is_source": False, + "source_rpm_text": "libarchive-3.3.3-5.el8.src.rpm", + "expected": "libarchive" + }, + { + "name": "Binary RPM with missing source", + "advisory_nvra": "libarchive-debuginfo-0:3.3.3-5.el8.aarch64", + "is_source": False, + "source_rpm_text": None, + "expected": None + }, + { + "name": "Invalid source RPM format", + "advisory_nvra": "invalid-format", + "is_source": True, + "source_rpm_text": None, + "expected": None + }, + ] + + for test_case in test_cases: + with self.subTest(test_case=test_case["name"]): + advisory_nvra = test_case["advisory_nvra"] + is_source = test_case["is_source"] + source_rpm_text = test_case["source_rpm_text"] + expected = test_case["expected"] + + package_name = None + + if advisory_nvra.endswith(".src.rpm") or advisory_nvra.endswith(".src"): + source_nvra = repomd.NVRA_RE.search(advisory_nvra) + if source_nvra: + package_name = source_nvra.group(1) + elif source_rpm_text: + source_nvra = repomd.NVRA_RE.search(source_rpm_text) + if source_nvra: + package_name = source_nvra.group(1) + + 
self.assertEqual(package_name, expected, + f"Failed for {test_case['name']}: expected {expected}, got {package_name}") + + +if __name__ == "__main__": + unittest.main() diff --git a/apollo/tests/test_rhcsaf.py b/apollo/tests/test_rhcsaf.py index 1c62f0a..2b4fc62 100644 --- a/apollo/tests/test_rhcsaf.py +++ b/apollo/tests/test_rhcsaf.py @@ -52,7 +52,29 @@ def setUp(self): "product_identification_helper": { "cpe": "cpe:/o:redhat:enterprise_linux:9.4" } - } + }, + "branches": [ + { + "category": "product_version", + "name": "rsync-0:3.2.3-19.el9_4.1.x86_64", + "product": { + "product_id": "rsync-0:3.2.3-19.el9_4.1.x86_64", + "product_identification_helper": { + "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=x86_64" + } + } + }, + { + "category": "product_version", + "name": "rsync-0:3.2.3-19.el9_4.1.src", + "product": { + "product_id": "rsync-0:3.2.3-19.el9_4.1.src", + "product_identification_helper": { + "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=src" + } + } + } + ] } ] }, @@ -252,4 +274,227 @@ def test_major_only_version(self): self.assertIn( ("Red Hat Enterprise Linux", "Red Hat Enterprise Linux for x86_64", 9, None, "x86_64"), result - ) \ No newline at end of file + ) + + +class TestEUSDetection(unittest.TestCase): + """Test EUS product detection and filtering""" + + def setUp(self): + with patch('common.logger.Logger') as mock_logger_class: + mock_logger_class.return_value = MagicMock() + from apollo.rhcsaf import _is_eus_product + self._is_eus_product = _is_eus_product + + def test_detect_eus_via_cpe(self): + """Test EUS detection via CPE product field""" + # EUS CPE products + self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_eus:9.4::appstream")) + self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_e4s:9.0::appstream")) + self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_aus:8.2::appstream")) + self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_tus:8.8::appstream")) + + # Non-EUS CPE product + self.assertFalse(self._is_eus_product("Some Product", "cpe:/a:redhat:enterprise_linux:9::appstream")) + + def test_detect_eus_via_name(self): + """Test EUS detection via product name keywords""" + self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream EUS (v.9.4)", "")) + self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream E4S (v.9.0)", "")) + self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream AUS (v.8.2)", "")) + self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream TUS (v.8.8)", "")) + + # Non-EUS product name + self.assertFalse(self._is_eus_product("Red Hat Enterprise Linux AppStream", "")) + + def test_eus_filtering_in_affected_products(self): + """Test that EUS products are filtered from affected products""" + csaf = { + "product_tree": { + "branches": [ + { + "branches": [ + { + "category": "product_family", + "name": "Red Hat Enterprise Linux", + "branches": [ + { + "category": "product_name", + "product": { + "name": "Red Hat Enterprise Linux AppStream EUS (v.9.4)", + "product_identification_helper": { + "cpe": "cpe:/a:redhat:rhel_eus:9.4::appstream" + } + } + } + ] + }, + { + "category": "architecture", + "name": "x86_64" + } + ] + } + ] + } + } + + result = extract_rhel_affected_products_for_db(csaf) + # Should be empty because the only product is EUS + self.assertEqual(len(result), 0) + + +class TestModularPackages(unittest.TestCase): + """Test modular package extraction""" + + def 
test_extract_modular_packages(self): + """Test extraction of modular packages with ::module:stream suffix""" + csaf = { + "document": { + "tracking": { + "initial_release_date": "2025-07-28T00:00:00+00:00", + "current_release_date": "2025-07-28T00:00:00+00:00", + "id": "RHSA-2025:12008" + }, + "title": "Red Hat Security Advisory: Important: redis:7 security update", + "aggregate_severity": {"text": "Important"}, + "notes": [ + {"category": "general", "text": "Test description"}, + {"category": "summary", "text": "Test topic"} + ] + }, + "product_tree": { + "branches": [ + { + "branches": [ + { + "category": "product_family", + "name": "Red Hat Enterprise Linux", + "branches": [ + { + "category": "product_name", + "name": "Red Hat Enterprise Linux 9", + "product": { + "name": "Red Hat Enterprise Linux 9", + "product_identification_helper": { + "cpe": "cpe:/o:redhat:enterprise_linux:9::appstream" + } + }, + "branches": [ + { + "category": "product_version", + "name": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.x86_64::redis:7", + "product": { + "product_id": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.x86_64::redis:7", + "product_identification_helper": { + "purl": "pkg:rpm/redhat/redis@7.2.10-1.module+el9.6.0+23332+115a3b01?arch=x86_64&rpmmod=redis:7:9060020250716081121:9" + } + } + }, + { + "category": "product_version", + "name": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.src::redis:7", + "product": { + "product_id": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.src::redis:7", + "product_identification_helper": { + "purl": "pkg:rpm/redhat/redis@7.2.10-1.module+el9.6.0+23332+115a3b01?arch=src&rpmmod=redis:7:9060020250716081121:9" + } + } + } + ] + } + ] + }, + { + "category": "architecture", + "name": "x86_64" + } + ] + } + ] + }, + "vulnerabilities": [ + { + "cve": "CVE-2025-12345", + "ids": [{"system_name": "Red Hat Bugzilla ID", "text": "123456"}], + "product_status": {"fixed": []}, + "scores": [{"cvss_v3": {"vectorString": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", "baseScore": 9.8}}], + "cwe": {"id": "CWE-79"} + } + ] + } + + result = red_hat_advisory_scraper(csaf) + + # Check that modular packages were extracted with ::module:stream stripped + self.assertIn("redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.x86_64", result["red_hat_fixed_packages"]) + self.assertIn("redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.src", result["red_hat_fixed_packages"]) + + # Verify epoch is preserved + for pkg in result["red_hat_fixed_packages"]: + if "redis" in pkg: + self.assertIn("0:", pkg, "Epoch should be preserved in NEVRA") + + +class TestEUSAdvisoryFiltering(unittest.TestCase): + """Test that EUS-only advisories are filtered out""" + + def test_eus_only_advisory_returns_none(self): + """Test that advisory with only EUS products returns None""" + csaf = { + "document": { + "tracking": { + "initial_release_date": "2025-01-01T00:00:00+00:00", + "current_release_date": "2025-01-01T00:00:00+00:00", + "id": "RHSA-2025:9756" + }, + "title": "Red Hat Security Advisory: Important: package security update", + "aggregate_severity": {"text": "Important"}, + "notes": [ + {"category": "general", "text": "EUS advisory"}, + {"category": "summary", "text": "EUS topic"} + ] + }, + "product_tree": { + "branches": [ + { + "branches": [ + { + "category": "product_family", + "name": "Red Hat Enterprise Linux", + "branches": [ + { + "category": "product_name", + "name": "Red Hat Enterprise Linux AppStream EUS (v.9.4)", + "product": { + "name": "Red Hat Enterprise Linux AppStream EUS (v.9.4)", + 
"product_identification_helper": { + "cpe": "cpe:/a:redhat:rhel_eus:9.4::appstream" + } + } + } + ] + }, + { + "category": "architecture", + "name": "x86_64" + } + ] + } + ] + }, + "vulnerabilities": [ + { + "cve": "CVE-2025-99999", + "ids": [{"system_name": "Red Hat Bugzilla ID", "text": "999999"}], + "product_status": {"fixed": []}, + "scores": [{"cvss_v3": {"vectorString": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", "baseScore": 9.8}}], + "cwe": {"id": "CWE-79"} + } + ] + } + + result = red_hat_advisory_scraper(csaf) + + # Advisory should be filtered out (return None) because all products are EUS + self.assertIsNone(result) \ No newline at end of file diff --git a/scripts/generate_rocky_config.py b/scripts/generate_rocky_config.py index 1ae0438..1aafea8 100644 --- a/scripts/generate_rocky_config.py +++ b/scripts/generate_rocky_config.py @@ -602,7 +602,7 @@ def parse_repomd_path(repomd_url: str, base_url: str) -> Dict[str, str]: def build_mirror_config( - version: str, arch: str, name_suffix: Optional[str] = None + version: str, arch: str, name_suffix: Optional[str] = None, mirror_name_base: Optional[str] = None ) -> Dict[str, Any]: """ Build a mirror configuration dictionary. @@ -611,15 +611,19 @@ def build_mirror_config( version: Rocky Linux version arch: Architecture name_suffix: Optional suffix for mirror name + mirror_name_base: Optional custom base for mirror name (e.g., "Rocky Linux 9") Returns: Mirror configuration dictionary """ - # Build mirror name with optional suffix - if name_suffix is not None and name_suffix != "": - mirror_name = f"Rocky Linux {version} {name_suffix} {arch}" + # Build mirror name with optional custom base or suffix + if not mirror_name_base: + mirror_name_base = f"Rocky Linux {version}" + + if name_suffix: + mirror_name = f"{mirror_name_base} {name_suffix} {arch}" else: - mirror_name = f"Rocky Linux {version} {arch}" + mirror_name = f"{mirror_name_base} {arch}" # Parse version to extract major and minor components if version != UNKNOWN_VALUE and "." in version: @@ -690,6 +694,7 @@ def generate_rocky_config( include_source: bool = True, architectures: List[str] = None, name_suffix: Optional[str] = None, + mirror_name_base: Optional[str] = None, ) -> List[Dict[str, Any]]: """ Generate Rocky Linux configuration by discovering repository structure. 
@@ -702,6 +707,7 @@ def generate_rocky_config( include_source: Whether to include source repository URLs (default: True) architectures: List of architectures to include (default: auto-detect) name_suffix: Optional suffix to add to mirror names (e.g., "test", "staging") + mirror_name_base: Optional custom base for mirror name (e.g., "Rocky Linux 9") Returns: List of configuration dictionaries ready for JSON export @@ -730,10 +736,12 @@ def generate_rocky_config( continue # Skip if version filter specified and doesn't match + # Supports both exact version match (e.g., "9.5") and major version match (e.g., "9") if ( version and metadata["version"] != version and metadata["version"] != UNKNOWN_VALUE + and metadata["version"].split(".")[0] != version ): continue @@ -773,7 +781,7 @@ def generate_rocky_config( if not detected_version: detected_version = UNKNOWN_VALUE - mirror_config = build_mirror_config(detected_version, arch, name_suffix) + mirror_config = build_mirror_config(detected_version, arch, name_suffix, mirror_name_base) # Group repos by name and type repo_groups = {} @@ -828,6 +836,8 @@ def main(): %(prog)s https://mirror.example.com/pub/rocky/ --output rocky_config.json %(prog)s https://mirror.example.com/pub/rocky/ --name-suffix test --version 9.6 %(prog)s https://staging.example.com/pub/rocky/ --name-suffix staging --arch riscv64 + %(prog)s https://mirror.example.com/pub/rocky/ --mirror-name-base "Rocky Linux 9" --version 9.6 + %(prog)s https://mirror.example.com/pub/rocky/ --mirror-name-base "Rocky Linux 9 (Legacy)" --version 9 """, ) @@ -880,6 +890,11 @@ def main(): help="Optional suffix to add to mirror names (e.g., 'test', 'staging')", ) + parser.add_argument( + "--mirror-name-base", + help="Optional custom base for mirror name (e.g., 'Rocky Linux 9' instead of 'Rocky Linux 9.6')", + ) + parser.add_argument("--output", "-o", help="Output file path (default: stdout)") parser.add_argument( @@ -926,6 +941,7 @@ def main(): include_source=not args.no_source, architectures=args.arch, name_suffix=args.name_suffix, + mirror_name_base=args.mirror_name_base, ) if not config: