diff --git a/.dockerignore b/.dockerignore index 5b34fe7..55a3fdb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,9 @@ node_modules .venv .ijwb -.idea \ No newline at end of file +.idea +temp +csaf_analysis +bazel-* +.git +container_data \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2454fd2..394e28c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -37,6 +37,8 @@ jobs: bazel test //apollo/tests:test_auth --test_output=all bazel test //apollo/tests:test_validation --test_output=all bazel test //apollo/tests:test_admin_routes_supported_products --test_output=all + bazel test //apollo/tests:test_api_osv --test_output=all + bazel test //apollo/tests:test_database_service --test_output=all - name: Integration Tests run: ./build/scripts/test.bash diff --git a/apollo/db/__init__.py b/apollo/db/__init__.py index 2b98c4c..1fb0f25 100644 --- a/apollo/db/__init__.py +++ b/apollo/db/__init__.py @@ -201,6 +201,7 @@ class SupportedProductsRhMirror(Model): match_major_version = fields.IntField() match_minor_version = fields.IntField(null=True) match_arch = fields.CharField(max_length=255) + active = fields.BooleanField(default=True) rpm_repomds: fields.ReverseRelation["SupportedProductsRpmRepomd"] rpm_rh_overrides: fields.ReverseRelation["SupportedProductsRpmRhOverride"] diff --git a/apollo/migrations/20251104111759_add_mirror_active_field.sql b/apollo/migrations/20251104111759_add_mirror_active_field.sql new file mode 100644 index 0000000..6c12b1b --- /dev/null +++ b/apollo/migrations/20251104111759_add_mirror_active_field.sql @@ -0,0 +1,11 @@ +-- migrate:up +alter table supported_products_rh_mirrors +add column active boolean not null default true; + +create index supported_products_rh_mirrors_active_idx +on supported_products_rh_mirrors(active); + + +-- migrate:down +drop index if exists supported_products_rh_mirrors_active_idx; +alter table supported_products_rh_mirrors drop column active; diff --git a/apollo/rhcsaf/__init__.py b/apollo/rhcsaf/__init__.py index 762cedf..95175c4 100644 --- a/apollo/rhcsaf/__init__.py +++ b/apollo/rhcsaf/__init__.py @@ -4,24 +4,65 @@ from common.logger import Logger from apollo.rpm_helpers import parse_nevra -# Initialize Info before Logger for this module - logger = Logger() +EUS_CPE_PRODUCTS = frozenset([ + "rhel_eus", # Extended Update Support + "rhel_e4s", # Update Services for SAP Solutions + "rhel_aus", # Advanced Update Support (IBM Power) + "rhel_tus", # Telecommunications Update Service +]) + +EUS_PRODUCT_NAME_KEYWORDS = frozenset([ + "e4s", + "eus", + "aus", + "tus", + "extended update support", + "update services for sap", + "advanced update support", + "telecommunications update service", +]) + +def _is_eus_product(product_name: str, cpe: str) -> bool: + """ + Detects if a product is EUS-related based on product name and CPE. + + Args: + product_name: Full product name (e.g., "Red Hat Enterprise Linux AppStream E4S (v.9.0)") + cpe: CPE string (e.g., "cpe:/a:redhat:rhel_e4s:9.0::appstream") + + Returns: + True if product is EUS/E4S/AUS/TUS, False otherwise + """ + if cpe: + parts = cpe.split(":") + if len(parts) > 3: + cpe_product = parts[3] + if cpe_product in EUS_CPE_PRODUCTS: + return True + + if product_name: + name_lower = product_name.lower() + for keyword in EUS_PRODUCT_NAME_KEYWORDS: + if keyword in name_lower: + return True + + return False + + def extract_rhel_affected_products_for_db(csaf: dict) -> set: """ Extracts all needed info for red_hat_advisory_affected_products table from CSAF product_tree. Expands 'noarch' to all main arches and maps names to user-friendly values. Returns a set of tuples: (variant, name, major_version, minor_version, arch) """ - # Maps architecture short names to user-friendly product names arch_name_map = { "aarch64": "Red Hat Enterprise Linux for ARM 64", "x86_64": "Red Hat Enterprise Linux for x86_64", "s390x": "Red Hat Enterprise Linux for IBM z Systems", "ppc64le": "Red Hat Enterprise Linux for Power, little endian", } - # List of main architectures to expand 'noarch' main_arches = list(arch_name_map.keys()) affected_products = set() product_tree = csaf.get("product_tree", {}) @@ -29,25 +70,20 @@ def extract_rhel_affected_products_for_db(csaf: dict) -> set: logger.warning("No product tree found in CSAF document") return affected_products - # Iterate over all vendor branches in the product tree for vendor_branch in product_tree.get("branches", []): - # Find the product_family branch for RHEL family_branch = None arches = set() for branch in vendor_branch.get("branches", []): if branch.get("category") == "product_family" and branch.get("name") == "Red Hat Enterprise Linux": family_branch = branch - # Collect all architecture branches at the same level as product_family elif branch.get("category") == "architecture": arch = branch.get("name") if arch: arches.add(arch) - # If 'noarch' is present, expand to all main architectures if "noarch" in arches: arches = set(main_arches) if not family_branch: continue - # Find the product_name branch for CPE/version info prod_name = None cpe = None for branch in family_branch.get("branches", []): @@ -59,24 +95,24 @@ def extract_rhel_affected_products_for_db(csaf: dict) -> set: if not prod_name or not cpe: continue - # Parses the CPE string to extract major and minor version numbers + if _is_eus_product(prod_name, cpe): + logger.debug(f"Skipping EUS product: {prod_name}") + continue + # Example CPE: "cpe:/a:redhat:enterprise_linux:9::appstream" - parts = cpe.split(":") # Split the CPE string by colon + parts = cpe.split(":") major = None minor = None if len(parts) > 4: - version = parts[4] # The version is typically the 5th field (index 4) + version = parts[4] if version: if "." in version: - # If the version contains a dot, split into major and minor major, minor = version.split(".", 1) major = int(major) minor = int(minor) else: - # If no dot, only major version is present major = int(version) - # For each architecture, add a tuple with product info to the set for arch in arches: name = arch_name_map.get(arch) if name is None: @@ -84,26 +120,142 @@ def extract_rhel_affected_products_for_db(csaf: dict) -> set: continue if major: affected_products.add(( - family_branch.get("name"), # variant (e.g., "Red Hat Enterprise Linux") - name, # user-friendly architecture name - major, # major version number - minor, # minor version number (may be None) - arch # architecture short name + family_branch.get("name"), + name, + major, + minor, + arch )) logger.debug(f"Number of affected products: {len(affected_products)}") return affected_products + +def _traverse_for_eus(branches, product_eus_map=None): + """ + Recursively traverse CSAF branches to build EUS product map. + + Args: + branches: List of CSAF branch dictionaries to traverse + product_eus_map: Optional dict to accumulate results + + Returns: + Dict mapping product_id to boolean indicating if product is EUS + """ + if product_eus_map is None: + product_eus_map = {} + + for branch in branches: + category = branch.get("category") + + if category == "product_name": + prod = branch.get("product", {}) + product_id = prod.get("product_id") + + if product_id: + product_name = prod.get("name", "") + cpe = prod.get("product_identification_helper", {}).get("cpe", "") + is_eus = _is_eus_product(product_name, cpe) + product_eus_map[product_id] = is_eus + + if "branches" in branch: + _traverse_for_eus(branch["branches"], product_eus_map) + + return product_eus_map + + +def _extract_packages_from_branches(branches, product_eus_map, packages=None): + """ + Recursively traverse CSAF branches to extract package NEVRAs. + + Args: + branches: List of CSAF branch dictionaries to traverse + product_eus_map: Dict mapping product_id to EUS status + packages: Optional set to accumulate results + + Returns: + Set of NEVRA strings + """ + if packages is None: + packages = set() + + for branch in branches: + category = branch.get("category") + + if category == "product_version": + prod = branch.get("product", {}) + product_id = prod.get("product_id") + purl = prod.get("product_identification_helper", {}).get("purl") + + if not product_id: + continue + + if purl and not purl.startswith("pkg:rpm/"): + continue + + # Product IDs for packages can have format: "AppStream-9.0.0.Z.E4S:package-nevra" + # or just "package-nevra" for packages in product_version entries + skip_eus = False + for eus_prod_id, is_eus in product_eus_map.items(): + if is_eus and (":" in product_id and product_id.startswith(eus_prod_id + ":")): + skip_eus = True + break + + if skip_eus: + continue + + # Format: "package-epoch:version-release.arch" or "package-epoch:version-release.arch::module:stream" + packages.add(product_id.split("::")[0]) + + if "branches" in branch: + _extract_packages_from_branches(branch["branches"], product_eus_map, packages) + + return packages + + +def _extract_packages_from_product_tree(csaf: dict) -> set: + """ + Extracts fixed packages from CSAF product_tree using product_id fields. + Handles both regular and modular packages by extracting NEVRAs directly from product_id. + Filters out EUS products. + + Args: + csaf: CSAF document dict + + Returns: + Set of NEVRA strings + """ + product_tree = csaf.get("product_tree", {}) + + if not product_tree: + return set() + + product_eus_map = {} + for vendor_branch in product_tree.get("branches", []): + product_eus_map = _traverse_for_eus(vendor_branch.get("branches", []), product_eus_map) + + packages = set() + for vendor_branch in product_tree.get("branches", []): + packages = _extract_packages_from_branches(vendor_branch.get("branches", []), product_eus_map, packages) + + return packages + + def red_hat_advisory_scraper(csaf: dict): # At the time of writing there are ~254 advisories that do not have any vulnerabilities. if not csaf.get("vulnerabilities"): logger.warning("No vulnerabilities found in CSAF document") return None - # red_hat_advisories table values - red_hat_issued_at = csaf["document"]["tracking"]["initial_release_date"] # "2025-02-24T03:42:46+00:00" - red_hat_updated_at = csaf["document"]["tracking"]["current_release_date"] # "2025-04-17T12:08:56+00:00" - name = csaf["document"]["tracking"]["id"] # "RHSA-2025:1234" - red_hat_synopsis = csaf["document"]["title"] # "Red Hat Bug Fix Advisory: Red Hat Quay v3.13.4 bug fix release" + name = csaf["document"]["tracking"]["id"] + + red_hat_affected_products = extract_rhel_affected_products_for_db(csaf) + if not red_hat_affected_products: + logger.info(f"Skipping advisory {name}: all products are EUS-only") + return None + + red_hat_issued_at = csaf["document"]["tracking"]["initial_release_date"] + red_hat_updated_at = csaf["document"]["tracking"]["current_release_date"] + red_hat_synopsis = csaf["document"]["title"] red_hat_description = None topic = None for item in csaf["document"]["notes"]: @@ -112,59 +264,31 @@ def red_hat_advisory_scraper(csaf: dict): elif item["category"] == "summary": topic = item["text"] kind_lookup = {"RHSA": "Security", "RHBA": "Bug Fix", "RHEA": "Enhancement"} - kind = kind_lookup[name.split("-")[0]] # "RHSA-2025:1234" --> "Security" - severity = csaf["document"]["aggregate_severity"]["text"] # "Important" + kind = kind_lookup[name.split("-")[0]] + severity = csaf["document"]["aggregate_severity"]["text"] - # To maintain consistency with the existing database, we need to replace the + # To maintain consistency with the existing database, replace # "Red Hat [KIND] Advisory:" prefixes with the severity level. red_hat_synopsis = red_hat_synopsis.replace("Red Hat Bug Fix Advisory: ", f"{severity}:") red_hat_synopsis = red_hat_synopsis.replace("Red Hat Security Advisory:", f"{severity}:") red_hat_synopsis = red_hat_synopsis.replace("Red Hat Enhancement Advisory: ", f"{severity}:") - # red_hat_advisory_packages table values - red_hat_fixed_packages = set() + red_hat_fixed_packages = _extract_packages_from_product_tree(csaf) + red_hat_cve_set = set() red_hat_bugzilla_set = set() - product_id_suffix_list = ( - ".aarch64", - ".i386", - ".i686", - ".noarch", - ".ppc", - ".ppc64", - ".ppc64le", - ".s390", - ".s390x", - ".src", - ".x86_64" - ) # TODO: find a better way to filter product IDs. This is a workaround for the fact that - # the product IDs in the CSAF documents also contain artifacts like container images - # and we only are interested in RPMs. + for vulnerability in csaf["vulnerabilities"]: - for product_id in vulnerability["product_status"]["fixed"]: - if product_id.endswith(product_id_suffix_list): - # These IDs are in the format product:package_nevra - # ie- AppStream-9.4.0.Z.EUS:rsync-0:3.2.3-19.el9_4.1.aarch64" - split_on_colon = product_id.split(":") - product = split_on_colon[0] - package_nevra = ":".join(split_on_colon[-2:]) - red_hat_fixed_packages.add(package_nevra) - - # red_hat_advisory_cves table values. Many older advisories do not have CVEs and so we need to handle that. cve_id = vulnerability.get("cve", None) cve_cvss3_scoring_vector = vulnerability.get("scores", [{}])[0].get("cvss_v3", {}).get("vectorString", None) cve_cvss3_base_score = vulnerability.get("scores", [{}])[0].get("cvss_v3", {}).get("baseScore", None) cve_cwe = vulnerability.get("cwe", {}).get("id", None) red_hat_cve_set.add((cve_id, cve_cvss3_scoring_vector, cve_cvss3_base_score, cve_cwe)) - # red_hat_advisory_bugzilla_bugs table values for bug_id in vulnerability.get("ids", []): if bug_id.get("system_name") == "Red Hat Bugzilla ID": red_hat_bugzilla_set.add(bug_id["text"]) - # red_hat_advisory_affected_products table values - red_hat_affected_products = extract_rhel_affected_products_for_db(csaf) - return { "red_hat_issued_at": str(red_hat_issued_at), "red_hat_updated_at": str(red_hat_updated_at), diff --git a/apollo/rhworker/poll_rh_activities.py b/apollo/rhworker/poll_rh_activities.py index e592136..85a4380 100644 --- a/apollo/rhworker/poll_rh_activities.py +++ b/apollo/rhworker/poll_rh_activities.py @@ -651,8 +651,11 @@ async def fetch_csv_with_dates(session, url): releases = await fetch_csv_with_dates(session, base_url + "releases.csv") deletions = await fetch_csv_with_dates(session, base_url + "deletions.csv") - # Merge changes and releases, keeping the most recent timestamp for each advisory - all_advisories = {**changes, **releases} + # Merge changes and releases, prioritizing changes.csv for updated timestamps + # changes.csv contains the most recent modification time for each advisory + # releases.csv contains original publication dates + # We want changes.csv to take precedence to catch updates to existing advisories + all_advisories = {**releases, **changes} # Remove deletions for advisory_id in deletions: all_advisories.pop(advisory_id, None) diff --git a/apollo/rpmworker/rh_matcher_activities.py b/apollo/rpmworker/rh_matcher_activities.py index eb5f95b..af31b37 100644 --- a/apollo/rpmworker/rh_matcher_activities.py +++ b/apollo/rpmworker/rh_matcher_activities.py @@ -277,7 +277,7 @@ async def get_supported_products_with_rh_mirrors(filter_major_versions: Optional Filtering now happens at the mirror level within match_rh_repos activity. """ logger = Logger() - rh_mirrors = await SupportedProductsRhMirror.all().prefetch_related( + rh_mirrors = await SupportedProductsRhMirror.filter(active=True).prefetch_related( "rpm_repomds", ) ret = [] @@ -790,6 +790,9 @@ async def match_rh_repos(params) -> None: all_advisories = {} for mirror in supported_product.rh_mirrors: + if not mirror.active: + logger.debug(f"Skipping inactive mirror {mirror.name}") + continue # Apply major version filtering if specified if filter_major_versions is not None and int(mirror.match_major_version) not in filter_major_versions: logger.debug(f"Skipping mirror {mirror.name} with major version {mirror.match_major_version} due to filtering") @@ -836,23 +839,20 @@ async def match_rh_repos(params) -> None: @activity.defn async def block_remaining_rh_advisories(supported_product_id: int) -> None: - supported_product = await SupportedProduct.filter( - id=supported_product_id - ).first().prefetch_related("rh_mirrors") - for mirror in supported_product.rh_mirrors: - mirrors = await SupportedProductsRhMirror.filter( - supported_product_id=supported_product_id + mirrors = await SupportedProductsRhMirror.filter( + supported_product_id=supported_product_id, + active=True + ) + for mirror in mirrors: + advisories = await get_matching_rh_advisories(mirror) + await SupportedProductsRhBlock.bulk_create( + [ + SupportedProductsRhBlock( + **{ + "supported_products_rh_mirror_id": mirror.id, + "red_hat_advsiory_id": advisory.id, + } + ) for advisory in advisories + ], + ignore_conflicts=True ) - for mirror in mirrors: - advisories = await get_matching_rh_advisories(mirror) - await SupportedProductsRhBlock.bulk_create( - [ - SupportedProductsRhBlock( - **{ - "supported_products_rh_mirror_id": mirror.id, - "red_hat_advsiory_id": advisory.id, - } - ) for advisory in advisories - ], - ignore_conflicts=True - ) diff --git a/apollo/schema.sql b/apollo/schema.sql index 686021e..46c385a 100644 --- a/apollo/schema.sql +++ b/apollo/schema.sql @@ -623,7 +623,8 @@ CREATE TABLE public.supported_products_rh_mirrors ( match_variant text NOT NULL, match_major_version numeric NOT NULL, match_minor_version numeric, - match_arch text NOT NULL + match_arch text NOT NULL, + active boolean DEFAULT true NOT NULL ); @@ -1507,6 +1508,13 @@ CREATE INDEX supported_products_rh_mirrors_match_variant_idx ON public.supported CREATE INDEX supported_products_rh_mirrors_supported_product_idx ON public.supported_products_rh_mirrors USING btree (supported_product_id); +-- +-- Name: supported_products_rh_mirrors_active_idx; Type: INDEX; Schema: public; Owner: - +-- + +CREATE INDEX supported_products_rh_mirrors_active_idx ON public.supported_products_rh_mirrors USING btree (active); + + -- -- Name: supported_products_rpm_repomds_arch_idx; Type: INDEX; Schema: public; Owner: - -- diff --git a/apollo/server/routes/admin_supported_products.py b/apollo/server/routes/admin_supported_products.py index 730ef5b..e59e269 100644 --- a/apollo/server/routes/admin_supported_products.py +++ b/apollo/server/routes/admin_supported_products.py @@ -42,7 +42,6 @@ async def get_entity_or_error_response( ) -> Union[Model, Response]: """Get an entity by ID or filters, or return error template response.""" - # Build the query if entity_id is not None: query = model_class.get_or_none(id=entity_id) elif filters: @@ -50,7 +49,6 @@ async def get_entity_or_error_response( else: raise ValueError("Either entity_id or filters must be provided") - # Add prefetch_related if specified if prefetch_related: query = query.prefetch_related(*prefetch_related) @@ -158,7 +156,6 @@ async def admin_supported_products( params=params, ) - # Get statistics for each product for product in products.items: mirrors_count = await SupportedProductsRhMirror.filter(supported_product=product).count() repomds_count = await SupportedProductsRpmRepomd.filter( @@ -195,7 +192,6 @@ async def export_all_configs( production_only: Optional[bool] = Query(None) ): """Export configurations for all supported products as JSON with optional filtering""" - # Build query with filters query = SupportedProductsRhMirror.all() if major_version is not None: query = query.filter(match_major_version=major_version) @@ -207,7 +203,6 @@ async def export_all_configs( "rpm_repomds" ).all() - # Filter repositories by production status if specified config_data = [] for mirror in mirrors: mirror_data = await _get_mirror_config_data(mirror) @@ -251,14 +246,11 @@ async def _import_configuration(import_data: List[Dict[str, Any]], replace_exist mirror_data = config["mirror"] repositories_data = config["repositories"] - # Find or create product product = await SupportedProduct.get_or_none(name=product_data["name"]) if not product: - # For import, we should require products to exist already skipped_count += 1 continue - # Check if mirror already exists existing_mirror = await SupportedProductsRhMirror.get_or_none( supported_product=product, name=mirror_data["name"], @@ -273,24 +265,24 @@ async def _import_configuration(import_data: List[Dict[str, Any]], replace_exist continue if existing_mirror and replace_existing: - # Delete existing repositories await SupportedProductsRpmRepomd.filter(supported_products_rh_mirror=existing_mirror).delete() mirror = existing_mirror + mirror.active = mirror_data.get("active", True) + await mirror.save() updated_count += 1 else: - # Create new mirror mirror = SupportedProductsRhMirror( supported_product=product, name=mirror_data["name"], match_variant=mirror_data["match_variant"], match_major_version=mirror_data["match_major_version"], match_minor_version=mirror_data.get("match_minor_version"), - match_arch=mirror_data["match_arch"] + match_arch=mirror_data["match_arch"], + active=mirror_data.get("active", True) ) await mirror.save() created_count += 1 - # Create repositories for repo_data in repositories_data: repo = SupportedProductsRpmRepomd( supported_products_rh_mirror=mirror, @@ -336,10 +328,8 @@ async def import_configurations( status_code=302 ) - # Validate import data validation_errors = await _validate_import_data(import_data) if validation_errors: - # Limit the number of errors shown to avoid overwhelming the user max_errors = 20 if len(validation_errors) > max_errors: shown_errors = validation_errors[:max_errors] @@ -352,7 +342,6 @@ async def import_configurations( status_code=302 ) - # Import the data try: results = await _import_configuration(import_data, replace_existing) success_message = f"Import completed: {results['created']} created, {results['updated']} updated, {results['skipped']} skipped" @@ -374,13 +363,16 @@ async def admin_supported_product(request: Request, product_id: int): SupportedProduct, f"Supported product with id {product_id}", entity_id=product_id, - prefetch_related=["rh_mirrors", "rh_mirrors__rpm_repomds", "code"] + prefetch_related=["code"] ) if isinstance(product, Response): return product - # Get detailed statistics for each mirror - for mirror in product.rh_mirrors: + mirrors = await SupportedProductsRhMirror.filter( + supported_product=product + ).order_by("-active", "-match_major_version", "name").prefetch_related("rpm_repomds").all() + + for mirror in mirrors: repomds_count = await SupportedProductsRpmRepomd.filter( supported_products_rh_mirror=mirror ).count() @@ -401,6 +393,7 @@ async def admin_supported_product(request: Request, product_id: int): "admin_supported_product.jinja", { "request": request, "product": product, + "mirrors": mirrors, } ) @@ -419,10 +412,7 @@ async def admin_supported_product_delete( } ) - # Check for existing mirrors (which would contain blocks, overrides, and repomds) mirrors_count = await SupportedProductsRhMirror.filter(supported_product=product).count() - - # Check for existing advisory packages and affected products packages_count = await AdvisoryPackage.filter(supported_product=product).count() affected_products_count = await AdvisoryAffectedProduct.filter(supported_product=product).count() @@ -485,13 +475,16 @@ async def admin_supported_product_mirror_new_post( if isinstance(product, Response): return product - # Validation using centralized validation utility + form_data_raw = await request.form() + active_value = "true" if "true" in form_data_raw.getlist("active") else "false" + form_data = { "name": name, "match_variant": match_variant, "match_major_version": match_major_version, "match_minor_version": match_minor_version, "match_arch": match_arch, + "active": active_value, } try: @@ -521,6 +514,7 @@ async def admin_supported_product_mirror_new_post( match_major_version=match_major_version, match_minor_version=match_minor_version, match_arch=validated_arch, + active=(active_value == "true"), ) await mirror.save() @@ -581,7 +575,9 @@ async def admin_supported_product_mirror_post( if isinstance(mirror, Response): return mirror - # Validation using centralized validation utility + form_data = await request.form() + active_value = "true" if "true" in form_data.getlist("active") else "false" + try: validated_name = FieldValidator.validate_name( name, @@ -606,9 +602,9 @@ async def admin_supported_product_mirror_post( mirror.match_major_version = match_major_version mirror.match_minor_version = match_minor_version mirror.match_arch = validated_arch + mirror.active = (active_value == "true") await mirror.save() - # Re-fetch the mirror with all required relations after saving mirror = await SupportedProductsRhMirror.get_or_none( id=mirror_id, supported_product_id=product_id @@ -646,7 +642,6 @@ async def admin_supported_product_mirror_delete( if isinstance(mirror, Response): return mirror - # Check for existing blocks and overrides using shared logic blocks_count, overrides_count = await check_mirror_dependencies(mirror) if blocks_count > 0 or overrides_count > 0: @@ -674,7 +669,6 @@ async def admin_supported_product_mirrors_bulk_delete( """Bulk delete multiple mirrors by calling individual delete logic for each mirror""" base_url = f"/admin/supported-products/{product_id}" - # Parse and validate mirror IDs if not mirror_ids or not mirror_ids.strip(): return create_error_redirect(base_url, "No mirror IDs provided") @@ -682,14 +676,12 @@ async def admin_supported_product_mirrors_bulk_delete( mirror_id_list = [int(id_str.strip()) for id_str in mirror_ids.split(",") if id_str.strip()] if not mirror_id_list: return create_error_redirect(base_url, "No valid mirror IDs provided") - # Validate all IDs are positive for id_val in mirror_id_list: if id_val <= 0: return create_error_redirect(base_url, "All mirror IDs must be positive numbers") except ValueError: return create_error_redirect(base_url, "Invalid mirror IDs: must be comma-separated numbers") - # Get all mirrors to delete mirrors = await SupportedProductsRhMirror.filter( id__in=mirror_id_list, supported_product_id=product_id ).prefetch_related("supported_product") @@ -697,28 +689,23 @@ async def admin_supported_product_mirrors_bulk_delete( if not mirrors: return create_error_redirect(base_url, "No mirrors found with provided IDs") - # Process each mirror individually using existing single delete logic successful_deletes = [] failed_deletes = [] for mirror in mirrors: - # Check for existing blocks and overrides using shared logic blocks_count, overrides_count = await check_mirror_dependencies(mirror) if blocks_count > 0 or overrides_count > 0: - # Mirror has dependencies, cannot delete error_parts = format_dependency_error_parts(blocks_count, overrides_count) error_reason = f"{' and '.join(error_parts)}" failed_deletes.append({"name": mirror.name, "reason": error_reason}) else: - # Mirror can be deleted try: await mirror.delete() successful_deletes.append(mirror.name) except Exception as e: failed_deletes.append({"name": mirror.name, "reason": f"deletion failed: {str(e)}"}) - # Build result message with clear formatting message_parts = [] if successful_deletes: @@ -732,15 +719,12 @@ async def admin_supported_product_mirrors_bulk_delete( message = ". ".join(message_parts) - # If we had any successful deletes, treat as success even if some failed if successful_deletes: return create_success_redirect(base_url, message) else: - # All deletes failed return create_error_redirect(base_url, message) -# Repository (repomd) management routes @router.get("/{product_id}/mirrors/{mirror_id}/repomds/new", response_class=HTMLResponse) async def admin_supported_product_mirror_repomd_new( request: Request, @@ -787,24 +771,16 @@ async def admin_supported_product_mirror_repomd_new_post( if isinstance(mirror, Response): return mirror - # Validation using centralized validation utility - form_data = { - "production": production, - "arch": arch, - "url": url, - "debug_url": debug_url, - "source_url": source_url, - "repo_name": repo_name, - } - - validated_data, validation_errors = FormValidator.validate_repomd_form(form_data) + validated_data, validation_errors, form_data = _validate_repomd_form( + production, arch, url, debug_url, source_url, repo_name + ) if validation_errors: return templates.TemplateResponse( "admin_supported_product_repomd_new.jinja", { "request": request, "mirror": mirror, - "error": validation_errors[0], # Show first error + "error": validation_errors[0], "form_data": form_data } ) @@ -879,24 +855,16 @@ async def admin_supported_product_mirror_repomd_post( if isinstance(repomd, Response): return repomd - # Validation using centralized validation utility - form_data = { - "production": production, - "arch": arch, - "url": url, - "debug_url": debug_url, - "source_url": source_url, - "repo_name": repo_name, - } - - validated_data, validation_errors = FormValidator.validate_repomd_form(form_data) + validated_data, validation_errors, form_data = _validate_repomd_form( + production, arch, url, debug_url, source_url, repo_name + ) if validation_errors: return templates.TemplateResponse( "admin_supported_product_repomd.jinja", { "request": request, "repomd": repomd, - "error": validation_errors[0], # Show first error + "error": validation_errors[0], } ) @@ -938,7 +906,6 @@ async def admin_supported_product_mirror_repomd_delete( if isinstance(repomd, Response): return repomd - # Check for existing advisory packages using this repository packages_count = await AdvisoryPackage.filter( supported_products_rh_mirror=repomd.supported_products_rh_mirror, repo_name=repomd.repo_name @@ -956,7 +923,6 @@ async def admin_supported_product_mirror_repomd_delete( return RedirectResponse(f"/admin/supported-products/{product_id}/mirrors/{mirror_id}", status_code=302) -# Blocks management routes @router.get("/{product_id}/mirrors/{mirror_id}/blocks", response_class=HTMLResponse) async def admin_supported_product_mirror_blocks( request: Request, @@ -978,17 +944,14 @@ async def admin_supported_product_mirror_blocks( if isinstance(mirror, Response): return mirror - # Build query for blocked advisories query = SupportedProductsRhBlock.filter( supported_products_rh_mirror=mirror ).prefetch_related("red_hat_advisory") - # Apply search if provided if search: query = query.filter(red_hat_advisory__name__icontains=search) - # Set page size and get paginated results - params.size = min(params.size or 50, 100) # Default 50, max 100 + params.size = min(params.size or 50, 100) blocks = await paginate( query.order_by("-red_hat_advisory__red_hat_issued_at"), params=params ) @@ -1025,7 +988,6 @@ async def admin_supported_product_mirror_block_new( if isinstance(mirror, Response): return mirror - # Get advisories that are not already blocked existing_blocks = await SupportedProductsRhBlock.filter( supported_products_rh_mirror=mirror ).values_list("red_hat_advisory_id", flat=True) @@ -1034,11 +996,9 @@ async def admin_supported_product_mirror_block_new( if search: query = query.filter(name__icontains=search) - # Set page size and get paginated results - params.size = min(params.size or 50, 100) # Default 50, max 100 + params.size = min(params.size or 50, 100) advisories = await paginate(query.order_by("-red_hat_issued_at"), params=params) - # Calculate total pages for pagination component advisories_pages = ( math.ceil(advisories.total / advisories.size) if advisories.size > 0 else 1 ) @@ -1080,7 +1040,6 @@ async def admin_supported_product_mirror_block_new_post( if isinstance(advisory, Response): return advisory - # Check if block already exists existing_block = await SupportedProductsRhBlock.get_or_none( supported_products_rh_mirror=mirror, red_hat_advisory=advisory ) @@ -1130,7 +1089,6 @@ async def admin_supported_product_mirror_block_delete( ) -# Overrides management routes (similar structure to blocks) @router.get("/{product_id}/mirrors/{mirror_id}/overrides/new", response_class=HTMLResponse) async def admin_supported_product_mirror_override_new( request: Request, @@ -1149,7 +1107,6 @@ async def admin_supported_product_mirror_override_new( if isinstance(mirror, Response): return mirror - # Get advisories that don't already have overrides existing_overrides = await SupportedProductsRpmRhOverride.filter( supported_products_rh_mirror=mirror ).values_list("red_hat_advisory_id", flat=True) @@ -1158,11 +1115,9 @@ async def admin_supported_product_mirror_override_new( if search: query = query.filter(name__icontains=search) - # Set page size and get paginated results - params.size = min(params.size or 50, 100) # Default 50, max 100 + params.size = min(params.size or 50, 100) advisories = await paginate(query.order_by("-red_hat_issued_at"), params=params) - # Calculate total pages for pagination component advisories_pages = ( math.ceil(advisories.total / advisories.size) if advisories.size > 0 else 1 ) @@ -1204,7 +1159,6 @@ async def admin_supported_product_mirror_override_new_post( if isinstance(advisory, Response): return advisory - # Check if override already exists existing_override = await SupportedProductsRpmRhOverride.get_or_none( supported_products_rh_mirror=mirror, red_hat_advisory=advisory @@ -1274,6 +1228,7 @@ async def _get_mirror_config_data(mirror: SupportedProductsRhMirror) -> Dict[str "match_major_version": mirror.match_major_version, "match_minor_version": mirror.match_minor_version, "match_arch": mirror.match_arch, + "active": mirror.active, "created_at": mirror.created_at.isoformat(), "updated_at": mirror.updated_at.isoformat() if mirror.updated_at else None, }, @@ -1293,9 +1248,31 @@ async def _get_mirror_config_data(mirror: SupportedProductsRhMirror) -> Dict[str ] } +def _validate_repomd_form( + production: bool, + arch: str, + url: str, + debug_url: str, + source_url: str, + repo_name: str +) -> Tuple[Dict[str, Any], List[str], Dict[str, Any]]: + """Validate repomd form data and return validated data, errors, and original form data.""" + form_data = { + "production": production, + "arch": arch, + "url": url, + "debug_url": debug_url, + "source_url": source_url, + "repo_name": repo_name, + } + validated_data, validation_errors = FormValidator.validate_repomd_form(form_data) + return validated_data, validation_errors, form_data + def _json_serializer(obj): """Custom JSON serializer for non-standard types""" if isinstance(obj, Decimal): + if obj % 1 == 0: + return int(obj) return float(obj) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") @@ -1348,7 +1325,6 @@ async def export_product_config( media_type="text/plain" ) - # Build query with filters query = SupportedProductsRhMirror.filter(supported_product_id=product_id) if major_version is not None: query = query.filter(match_major_version=major_version) @@ -1360,7 +1336,6 @@ async def export_product_config( "rpm_repomds" ).all() - # Filter repositories by production status if specified config_data = [] for mirror in mirrors: mirror_data = await _get_mirror_config_data(mirror) @@ -1411,7 +1386,6 @@ async def admin_supported_product_mirror_overrides( if isinstance(mirror, Response): return mirror - # Build query for overrides with search query = SupportedProductsRpmRhOverride.filter( supported_products_rh_mirror_id=mirror_id ).prefetch_related("red_hat_advisory") @@ -1419,11 +1393,9 @@ async def admin_supported_product_mirror_overrides( if search: query = query.filter(red_hat_advisory__name__icontains=search) - # Apply ordering and pagination query = query.order_by("-created_at") overrides = await paginate(query, params) - # Calculate total pages for pagination component overrides_pages = ( math.ceil(overrides.total / overrides.size) if overrides.size > 0 else 1 ) diff --git a/apollo/server/routes/admin_workflows.py b/apollo/server/routes/admin_workflows.py index ef319dc..cfb26ec 100644 --- a/apollo/server/routes/admin_workflows.py +++ b/apollo/server/routes/admin_workflows.py @@ -21,7 +21,8 @@ async def admin_workflows(request: Request, user: User = Depends(admin_user_sche """Render admin workflows page for manual workflow triggering""" db_service = DatabaseService() env_info = await db_service.get_environment_info() - + index_state = await db_service.get_last_indexed_at() + return templates.TemplateResponse( "admin_workflows.jinja", { "request": request, @@ -29,6 +30,8 @@ async def admin_workflows(request: Request, user: User = Depends(admin_user_sche "env_name": env_info["environment"], "is_production": env_info["is_production"], "reset_allowed": env_info["reset_allowed"], + "last_indexed_at": index_state.get("last_indexed_at_iso"), + "last_indexed_exists": index_state.get("exists", False), } ) @@ -92,6 +95,39 @@ async def trigger_poll_rhcsaf( return RedirectResponse(url="/admin/workflows", status_code=303) +@router.post("/workflows/update-index-timestamp") +async def update_index_timestamp( + request: Request, + new_timestamp: str = Form(...), + user: User = Depends(admin_user_scheme) +): + """Update the last_indexed_at timestamp in red_hat_index_state""" + try: + # Parse the timestamp + timestamp_dt = datetime.fromisoformat(new_timestamp.replace("Z", "+00:00")) + + db_service = DatabaseService() + result = await db_service.update_last_indexed_at(timestamp_dt, user.email) + + Logger().info(f"Admin user {user.email} updated last_indexed_at to {new_timestamp}") + + # Store success message in session + request.session["workflow_message"] = result["message"] + request.session["workflow_type"] = "success" + + except ValueError as e: + Logger().error(f"Invalid timestamp format: {str(e)}") + request.session["workflow_message"] = f"Invalid timestamp format: {str(e)}" + request.session["workflow_type"] = "error" + + except Exception as e: + Logger().error(f"Error updating last_indexed_at: {str(e)}") + request.session["workflow_message"] = f"Error updating timestamp: {str(e)}" + request.session["workflow_type"] = "error" + + return RedirectResponse(url="/admin/workflows", status_code=303) + + @router.get("/workflows/database/preview-reset") async def preview_database_reset( request: Request, diff --git a/apollo/server/routes/api_osv.py b/apollo/server/routes/api_osv.py index f0022ee..debf89a 100644 --- a/apollo/server/routes/api_osv.py +++ b/apollo/server/routes/api_osv.py @@ -143,7 +143,6 @@ def to_osv_advisory(ui_url: str, advisory: Advisory) -> OSVAdvisory: for pkg in affected_packages: x = pkg[0] nevra = pkg[1] - # Only process "src" packages if nevra.group(5) != "src": continue if x.nevra in processed_nvra: @@ -198,11 +197,9 @@ def to_osv_advisory(ui_url: str, advisory: Advisory) -> OSVAdvisory: if advisory.red_hat_advisory: osv_credits.append(OSVCredit(name="Red Hat")) - # Calculate severity by finding the highest CVSS score highest_cvss_base_score = 0.0 final_score_vector = None for x in advisory.cves: - # Convert cvss3_scoring_vector to a float base_score = x.cvss3_base_score if base_score and base_score != "UNKNOWN": base_score = float(base_score) @@ -255,15 +252,14 @@ async def get_advisories_osv( cve, synopsis, severity, - kind="Security", + kind=None, fetch_related=True, ) - count = fetch_adv[0] advisories = fetch_adv[1] ui_url = await get_setting(UI_URL) - osv_advisories = [to_osv_advisory(ui_url, x) for x in advisories] - page = create_page(osv_advisories, count, params) + osv_advisories = [to_osv_advisory(ui_url, adv) for adv in advisories if adv.cves] + page = create_page(osv_advisories, len(osv_advisories), params) state = await RedHatIndexState.first() page.last_updated_at = ( @@ -282,7 +278,7 @@ async def get_advisories_osv( ) async def get_advisory_osv(advisory_id: str): advisory = ( - await Advisory.filter(name=advisory_id, kind="Security") + await Advisory.filter(name=advisory_id) .prefetch_related( "packages", "cves", @@ -295,7 +291,7 @@ async def get_advisory_osv(advisory_id: str): .get_or_none() ) - if not advisory: + if not advisory or not advisory.cves: raise HTTPException(404) ui_url = await get_setting(UI_URL) diff --git a/apollo/server/routes/api_updateinfo.py b/apollo/server/routes/api_updateinfo.py index dcc9838..efb169a 100644 --- a/apollo/server/routes/api_updateinfo.py +++ b/apollo/server/routes/api_updateinfo.py @@ -1,11 +1,13 @@ import datetime -from typing import Optional +import logging +from typing import Optional, List from xml.etree import ElementTree as ET from fastapi import APIRouter, Response from slugify import slugify +from tortoise.exceptions import DoesNotExist -from apollo.db import AdvisoryAffectedProduct +from apollo.db import AdvisoryAffectedProduct, SupportedProduct from apollo.server.settings import COMPANY_NAME, MANAGING_EDITOR, UI_URL, get_setting from apollo.rpmworker.repomd import NEVRA_RE, NVRA_RE, EPOCH_RE @@ -13,6 +15,107 @@ from common.fastapi import RenderErrorTemplateException router = APIRouter(tags=["updateinfo"]) +logger = logging.getLogger(__name__) + +# Product slug to supported_product.name mapping +PRODUCT_SLUG_MAP = { + "rocky-linux": "Rocky Linux", + "rocky-linux-sig-cloud": "Rocky Linux SIG Cloud", +} + + +def build_source_rpm_mapping(packages): + """ + Build a mapping from package names to their source RPM filenames. + + This function handles both regular packages and module packages, where + module source packages have a "module." prefix that needs to be stripped + for matching with binary packages. + + Args: + packages: List of advisory package objects with package_name, module_name, + module_stream, and nevra attributes + + Returns: + dict: Mapping of package names (with module prefix if applicable) to + source RPM filenames. For example: + { + "xorg-x11-server": "xorg-x11-server-1.20.11-27.el8_10.src.rpm", + "go-toolset:delve:1.24": "delve-1.24.1-1.module+el8.10.0+1987+42f155bb.src.rpm" + } + """ + # First, create a map of package names to their package objects + # Strip "module." prefix from package_name to ensure binary and source packages + # are grouped together (binary: "delve", source: "module.delve" -> both map to "delve") + pkg_name_map = {} + for pkg in packages: + # Strip module. prefix for consistent grouping + base_pkg_name = pkg.package_name.removeprefix("module.") + + if pkg.module_name: + name = f"{pkg.module_name}:{base_pkg_name}:{pkg.module_stream}" + else: + name = base_pkg_name + + if name not in pkg_name_map: + pkg_name_map[name] = [] + pkg_name_map[name].append(pkg) + + # Build the source RPM mapping + pkg_src_rpm = {} + for top_pkg in packages: + # Use same key format as pkg_name_map + base_pkg_name = top_pkg.package_name.removeprefix("module.") + + if top_pkg.module_name: + name = f"{top_pkg.module_name}:{base_pkg_name}:{top_pkg.module_stream}" + else: + name = base_pkg_name + + if name not in pkg_src_rpm: + for pkg in pkg_name_map[name]: + nvra_no_epoch = EPOCH_RE.sub("", pkg.nevra) + nvra = NVRA_RE.search(nvra_no_epoch) + if nvra: + nvr_name = nvra.group(1) + nvr_arch = nvra.group(4) + + # FIX: Handle module packages where package_name has "module." prefix + # Binary packages: package_name = "delve" + # Source packages: package_name = "module.delve" + # We need to strip the prefix for comparison + pkg_name_to_match = pkg.package_name.removeprefix("module.") + + if pkg_name_to_match == nvr_name and nvr_arch == "src": + src_rpm = nvra_no_epoch + if not src_rpm.endswith(".rpm"): + src_rpm += ".rpm" + pkg_src_rpm[name] = src_rpm + logger.debug(f"Found source RPM for {name}: {src_rpm} (pkg.package_name={pkg.package_name}, nvr_name={nvr_name})") + break # Found the source RPM, no need to continue + + return pkg_src_rpm + + +def resolve_product_slug(slug: str) -> Optional[str]: + """ + Convert product slug to supported_product.name. + + Args: + slug: Product slug (e.g., 'rocky-linux', 'rocky-linux-sig-cloud') + + Returns: + Product name from supported_products table, or None if not found + + Examples: + >>> resolve_product_slug('rocky-linux') + 'Rocky Linux' + >>> resolve_product_slug('Rocky-Linux') # Case insensitive + 'Rocky Linux' + >>> resolve_product_slug('invalid') + None + """ + return PRODUCT_SLUG_MAP.get(slug.lower()) @router.get("/{product_name}/{repo}/updateinfo.xml") @@ -160,33 +263,8 @@ async def get_updateinfo( "-debugsource-common", ] - pkg_name_map = {} - for pkg in advisory.packages: - name = pkg.package_name - if pkg.module_name: - name = f"{pkg.module_name}:{pkg.package_name}:{pkg.module_stream}" - if name not in pkg_name_map: - pkg_name_map[name] = [] - - pkg_name_map[name].append(pkg) - - pkg_src_rpm = {} - for top_pkg in advisory.packages: - name = top_pkg.package_name - if top_pkg.module_name: - name = f"{top_pkg.module_name}:{top_pkg.package_name}:{top_pkg.module_stream}" - if name not in pkg_src_rpm: - for pkg in pkg_name_map[name]: - nvra_no_epoch = EPOCH_RE.sub("", pkg.nevra) - nvra = NVRA_RE.search(nvra_no_epoch) - if nvra: - nvr_name = nvra.group(1) - nvr_arch = nvra.group(4) - if pkg.package_name == nvr_name and nvr_arch == "src": - src_rpm = nvra_no_epoch - if not src_rpm.endswith(".rpm"): - src_rpm += ".rpm" - pkg_src_rpm[name] = src_rpm + # Build source RPM mapping using common function + pkg_src_rpm = build_source_rpm_mapping(advisory.packages) # Collection list, may be more than one if module RPMs are involved collections = {} @@ -321,3 +399,419 @@ async def get_updateinfo( ) return Response(content=xml_str, media_type="application/xml") + + +def generate_updateinfo_xml( + affected_products: List[AdvisoryAffectedProduct], + ui_url: str, + managing_editor: str, + company_name: str, + product_name_for_packages: Optional[str] = None, + repo: Optional[str] = None, + validate_product_consistency: bool = True, +) -> str: + """ + Generate updateinfo.xml from affected products. + + This function creates XML content compatible with DNF/YUM package managers. + It handles advisory deduplication, package filtering, module RPM handling, + and data integrity validation. + + Args: + affected_products: List of AdvisoryAffectedProduct records with prefetched + advisory, cves, fixes, packages, and supported_product + ui_url: Base URL for UI references + managing_editor: Editor email for XML header + company_name: Company name for copyright + product_name_for_packages: Product name to filter packages by. + If None, uses affected_product.name + repo: Repository name to filter packages by. Required for filtering. + validate_product_consistency: If True, validate that all packages + have matching supported_product_id to + prevent cross-product contamination + + Returns: + XML string in updateinfo.xml format + """ + # Deduplicate advisories by name + advisories = {} + for affected_product in affected_products: + advisory = affected_product.advisory + if advisory.name not in advisories: + advisories[advisory.name] = { + "advisory": advisory, + "arch": affected_product.arch, + "major_version": affected_product.major_version, + "minor_version": affected_product.minor_version, + "supported_product_name": affected_product.supported_product.name, + "supported_product_id": affected_product.supported_product_id, + } + + tree = ET.Element("updates") + + for _, adv in advisories.items(): + advisory = adv["advisory"] + product_arch = adv["arch"] + major_version = adv["major_version"] + minor_version = adv["minor_version"] + supported_product_name = adv["supported_product_name"] + supported_product_id = adv["supported_product_id"] + + update = ET.SubElement(tree, "update") + + # Set update attributes + update.set("from", managing_editor) + update.set("status", "final") + + if advisory.kind == "Security": + update.set("type", "security") + elif advisory.kind == "Bug Fix": + update.set("type", "bugfix") + elif advisory.kind == "Enhancement": + update.set("type", "enhancement") + + update.set("version", "2") + + # Add id + ET.SubElement(update, "id").text = advisory.name + + # Add title + ET.SubElement(update, "title").text = advisory.synopsis + + # Add time + time_format = "%Y-%m-%d %H:%M:%S" + issued = ET.SubElement(update, "issued") + issued.set("date", advisory.published_at.strftime(time_format)) + updated = ET.SubElement(update, "updated") + updated.set("date", advisory.updated_at.strftime(time_format)) + + # Add rights + now = datetime.datetime.utcnow() + ET.SubElement( + update, "rights" + ).text = f"Copyright {now.year} {company_name}" + + # Add release name + release_name = f"{supported_product_name} {major_version}" + if minor_version: + release_name += f".{minor_version}" + ET.SubElement(update, "release").text = release_name + + # Add pushcount + ET.SubElement(update, "pushcount").text = "1" + + # Add severity + ET.SubElement(update, "severity").text = advisory.severity + + # Add summary + ET.SubElement(update, "summary").text = advisory.topic + + # Add description + ET.SubElement(update, "description").text = advisory.description + + # Add solution + ET.SubElement(update, "solution").text = "" + + # Add references + references = ET.SubElement(update, "references") + for cve in advisory.cves: + reference = ET.SubElement(references, "reference") + reference.set( + "href", + f"https://cve.mitre.org/cgi-bin/cvename.cgi?name={cve.cve}", + ) + reference.set("id", cve.cve) + reference.set("type", "cve") + reference.set("title", cve.cve) + + for fix in advisory.fixes: + reference = ET.SubElement(references, "reference") + reference.set("href", fix.source) + reference.set("id", fix.ticket_id) + reference.set("type", "bugzilla") + reference.set("title", fix.description) + + # Add UI self reference + reference = ET.SubElement(references, "reference") + reference.set("href", f"{ui_url}/{advisory.name}") + reference.set("id", advisory.name) + reference.set("type", "self") + reference.set("title", advisory.name) + + # Add packages + packages = ET.SubElement(update, "pkglist") + + suffixes_to_skip = [ + "-debuginfo", + "-debugsource", + "-debuginfo-common", + "-debugsource-common", + ] + + # Build source RPM mapping using common function + pkg_src_rpm = build_source_rpm_mapping(advisory.packages) + + # Determine the product name to use for package filtering + filter_product_name = product_name_for_packages + if filter_product_name is None: + # Use the first affected_product's name as fallback + filter_product_name = affected_products[0].name if affected_products else None + + # Collection list, may be more than one if module RPMs are involved + collections = {} + no_default_collection = False + default_collection_short = slugify(f"{filter_product_name}-{repo}-rpms") if repo else slugify(f"{filter_product_name}-rpms") + + # Check if this is an actual module advisory, if so we need to split the + # collections, and module RPMs need to go into their own collection based on + # module name, while non-module RPMs go into the main collection (if any) + for pkg in advisory.packages: + # DATA INTEGRITY CHECK: Validate supported_product_id consistency + if validate_product_consistency: + if pkg.supported_product_id != supported_product_id: + logger.error( + f"Data integrity violation detected for advisory {advisory.name}: " + f"Package {pkg.nevra} (id={pkg.id}) has supported_product_id={pkg.supported_product_id} " + f"but affected_product has supported_product_id={supported_product_id}. " + f"Skipping this package to prevent cross-product contamination." + ) + continue # Skip this package - don't include it in updateinfo + + # Filter by product name + if filter_product_name and pkg.product_name != filter_product_name: + continue + + # Filter by repository + if repo and pkg.repo_name != repo: + continue + + if pkg.module_name: + collection_short = f"{default_collection_short}__{pkg.module_name}" + if collection_short not in collections: + collections[collection_short] = { + "packages": [], + "module_context": pkg.module_context, + "module_name": pkg.module_name, + "module_stream": pkg.module_stream, + "module_version": pkg.module_version, + } + no_default_collection = True + collections[collection_short]["packages"].append(pkg) + else: + if no_default_collection: + continue + if default_collection_short not in collections: + collections[default_collection_short] = { + "packages": [], + } + collections[default_collection_short]["packages"].append(pkg) + + if no_default_collection and default_collection_short in collections: + del collections[default_collection_short] + + collections_added = 0 + + for collection_short, info in collections.items(): + # Create collection + collection = ET.Element("collection") + collection.set("short", collection_short) + + # Set short to name as well + ET.SubElement(collection, "name").text = collection_short + + if "module_name" in info: + module_element = ET.SubElement(collection, "module") + module_element.set("name", info["module_name"]) + module_element.set("stream", info["module_stream"]) + module_element.set("version", info["module_version"]) + module_element.set("context", info["module_context"]) + module_element.set("arch", product_arch) + + added_pkg_count = 0 + for pkg in info["packages"]: + if pkg.nevra.endswith(".src.rpm"): + continue + + name = pkg.package_name + epoch = "0" + if NEVRA_RE.match(pkg.nevra): + nevra = NEVRA_RE.search(pkg.nevra) + name = nevra.group(1) + epoch = nevra.group(2) + version = nevra.group(3) + release = nevra.group(4) + arch = nevra.group(5) + elif NVRA_RE.match(pkg.nevra): + nvra = NVRA_RE.search(pkg.nevra) + name = nvra.group(1) + version = nvra.group(2) + release = nvra.group(3) + arch = nvra.group(4) + else: + continue + + p_name = pkg.package_name + if pkg.module_name: + p_name = f"{pkg.module_name}:{pkg.package_name}:{pkg.module_stream}" + + if p_name not in pkg_src_rpm: + continue + if arch != product_arch and arch != "noarch": + if product_arch != "x86_64": + continue + if product_arch == "x86_64" and arch != "i686": + continue + + skip = False + for suffix in suffixes_to_skip: + if name.endswith(suffix): + skip = True + break + if skip: + continue + + package = ET.SubElement(collection, "package") + package.set("name", name) + package.set("arch", arch) + package.set("epoch", epoch) + package.set("version", version) + package.set("release", release) + package.set("src", pkg_src_rpm[p_name]) + + # Add filename element + ET.SubElement(package, + "filename").text = EPOCH_RE.sub("", pkg.nevra) + + # Add checksum + ET.SubElement( + package, "sum", type=pkg.checksum_type + ).text = pkg.checksum + + added_pkg_count += 1 + + if added_pkg_count > 0: + packages.append(collection) + collections_added += 1 + + if collections_added == 0: + tree.remove(update) + + ET.indent(tree) + xml_str = ET.tostring( + tree, + encoding="unicode", + method="xml", + short_empty_elements=True, + ) + + return xml_str + + +@router.get("/{product}/{major_version}/{repo}/updateinfo.xml") +async def get_updateinfo_v2( + product: str, + major_version: int, + repo: str, + arch: str, + minor_version: Optional[int] = None, +): + """ + Get updateinfo.xml for a product major version and repository (v2 API). + + This endpoint aggregates all advisories for the specified major version, + including all minor versions, unless minor_version is specified. + + Architecture filtering is REQUIRED because: + - Each advisory contains packages for multiple architectures + - Repository structure is architecture-specific + - DNF/YUM expects arch-specific updateinfo.xml files + + Args: + product: Product slug (e.g., 'rocky-linux', 'rocky-linux-sig-cloud') + major_version: Major version number (e.g., 8, 9, 10) + repo: Repository name (e.g., 'BaseOS', 'AppStream') + arch: Architecture (REQUIRED: 'x86_64', 'aarch64', 'ppc64le', 's390x') + minor_version: Optional minor version filter (e.g., 6 for 8.6) + + Returns: + updateinfo.xml file + + Raises: + 400: Invalid architecture or missing required parameter + 404: No advisories found or invalid product + """ + # Resolve product slug to name + product_name = resolve_product_slug(product) + if not product_name: + raise RenderErrorTemplateException( + f"Unknown product: {product}. Valid products: {', '.join(PRODUCT_SLUG_MAP.keys())}", + 404 + ) + + # Get the supported_product record + try: + supported_product = await SupportedProduct.get(name=product_name) + except DoesNotExist: + raise RenderErrorTemplateException( + f"Product not found in database: {product_name}", + 404 + ) + + # Validate architecture + valid_arches = ["x86_64", "aarch64", "ppc64le", "s390x"] + if arch not in valid_arches: + raise RenderErrorTemplateException( + f"Invalid architecture: {arch}. Must be one of {', '.join(valid_arches)}", + 400 + ) + + # Build filters using explicit supported_product_id + # This prevents cross-contamination between products + filters = { + "supported_product_id": supported_product.id, # Explicit FK - prevents cross-product contamination + "major_version": major_version, + "arch": arch, # REQUIRED filter + "advisory__packages__repo_name": repo, + "advisory__packages__supported_product_id": supported_product.id, # Double-check packages match + } + + if minor_version is not None: + filters["minor_version"] = minor_version + + # Query with prefetch + affected_products = await AdvisoryAffectedProduct.filter( + **filters + ).prefetch_related( + "advisory", + "advisory__cves", + "advisory__fixes", + "advisory__packages", + "supported_product", + ).all() + + if not affected_products: + raise RenderErrorTemplateException( + f"No advisories found for {product_name} {major_version} {repo} {arch}", + 404 + ) + + ui_url = await get_setting(UI_URL) + managing_editor = await get_setting(MANAGING_EDITOR) + company_name_value = await get_setting(COMPANY_NAME) + + # Generate the XML using the shared function + # For v2 API, we use a generic product name format for package filtering + # since we're aggregating across minor versions + product_name_for_packages = f"{product_name} {major_version} {arch}" + + xml_str = generate_updateinfo_xml( + affected_products=affected_products, + ui_url=ui_url, + managing_editor=managing_editor, + company_name=company_name_value, + product_name_for_packages=product_name_for_packages, + repo=repo, + validate_product_consistency=True, + ) + + return Response(content=xml_str, media_type="application/xml") diff --git a/apollo/server/services/database_service.py b/apollo/server/services/database_service.py index 78d6fb0..0a66800 100644 --- a/apollo/server/services/database_service.py +++ b/apollo/server/services/database_service.py @@ -123,4 +123,67 @@ async def get_environment_info(self) -> Dict[str, str]: "environment": env_name, "is_production": self.is_production_environment(), "reset_allowed": not self.is_production_environment() - } \ No newline at end of file + } + + async def get_last_indexed_at(self) -> Dict[str, Any]: + """ + Get the current last_indexed_at timestamp from red_hat_index_state + + Returns: + Dictionary with timestamp information + """ + index_state = await RedHatIndexState.first() + + if not index_state or not index_state.last_indexed_at: + return { + "last_indexed_at": None, + "last_indexed_at_iso": None, + "exists": False + } + + return { + "last_indexed_at": index_state.last_indexed_at, + "last_indexed_at_iso": index_state.last_indexed_at.isoformat(), + "exists": True + } + + async def update_last_indexed_at(self, new_timestamp: datetime, user_email: str) -> Dict[str, Any]: + """ + Update the last_indexed_at timestamp in red_hat_index_state + + Args: + new_timestamp: New timestamp to set + user_email: Email of user making the change (for logging) + + Returns: + Dictionary with operation results + + Raises: + ValueError: If timestamp is invalid + """ + logger = Logger() + + try: + # Get or create index state + index_state = await RedHatIndexState.first() + + old_timestamp = None + if index_state: + old_timestamp = index_state.last_indexed_at + index_state.last_indexed_at = new_timestamp + await index_state.save() + logger.info(f"Updated last_indexed_at by {user_email}: {old_timestamp} -> {new_timestamp}") + else: + await RedHatIndexState.create(last_indexed_at=new_timestamp) + logger.info(f"Created last_indexed_at by {user_email}: {new_timestamp}") + + return { + "success": True, + "old_timestamp": old_timestamp.isoformat() if old_timestamp else None, + "new_timestamp": new_timestamp.isoformat(), + "message": f"Successfully updated last_indexed_at to {new_timestamp.isoformat()}" + } + + except Exception as e: + logger.error(f"Failed to update last_indexed_at: {e}") + raise RuntimeError(f"Failed to update timestamp: {e}") from e \ No newline at end of file diff --git a/apollo/server/services/workflow_service.py b/apollo/server/services/workflow_service.py index 18511de..56d5103 100644 --- a/apollo/server/services/workflow_service.py +++ b/apollo/server/services/workflow_service.py @@ -206,9 +206,9 @@ async def _validate_major_versions(self, major_versions: List[int]) -> None: # Import here to avoid circular imports from apollo.db import SupportedProductsRhMirror - + # Get available major versions from RH mirrors - rh_mirrors = await SupportedProductsRhMirror.all() + rh_mirrors = await SupportedProductsRhMirror.filter(active=True) available_versions = {int(mirror.match_major_version) for mirror in rh_mirrors} # Check if all requested major versions are available diff --git a/apollo/server/templates/admin_supported_product.jinja b/apollo/server/templates/admin_supported_product.jinja index a04e2c8..4c3092c 100644 --- a/apollo/server/templates/admin_supported_product.jinja +++ b/apollo/server/templates/admin_supported_product.jinja @@ -41,13 +41,13 @@
| Architecture | ++ Status + | Repositories | @@ -92,7 +95,7 @@ - {% for mirror in product.rh_mirrors %} + {% for mirror in mirrors %}|
|---|---|---|---|
| {{ mirror.match_arch }} | ++ {% if mirror.active %} + Active + {% else %} + Inactive + {% endif %} + |
{{ mirror.stats.repomds }} repos
diff --git a/apollo/server/templates/admin_supported_product_mirror.jinja b/apollo/server/templates/admin_supported_product_mirror.jinja
index 58298ed..a16e35c 100644
--- a/apollo/server/templates/admin_supported_product_mirror.jinja
+++ b/apollo/server/templates/admin_supported_product_mirror.jinja
@@ -119,6 +119,22 @@
+
+
+
+
diff --git a/apollo/server/templates/admin_supported_product_mirror_new.jinja b/apollo/server/templates/admin_supported_product_mirror_new.jinja
index 003f826..1b9e039 100644
--- a/apollo/server/templates/admin_supported_product_mirror_new.jinja
+++ b/apollo/server/templates/admin_supported_product_mirror_new.jinja
@@ -108,6 +108,22 @@
+
+
+
+
+
+
{% if reset_allowed %}
+
+
+
+ Update CSAF Index Timestamp+Set the last_indexed_at timestamp to control which advisories are processed by the Poll RHCSAF workflow. + + {% if last_indexed_exists %} ++ Current last_indexed_at: {{ last_indexed_at }} + + {% else %} ++ No timestamp set - workflow will process all advisories + + {% endif %} + + +
@@ -114,7 +152,7 @@
-
-{% if reset_allowed %}
-{% endif %}
{% endblock %}
\ No newline at end of file
diff --git a/apollo/server/validation.py b/apollo/server/validation.py
index ec078b2..224b843 100644
--- a/apollo/server/validation.py
+++ b/apollo/server/validation.py
@@ -57,16 +57,12 @@ def __init__(
class ValidationPatterns:
"""Regex patterns for common validations."""
- # URL validation - must start with http:// or https://
URL_PATTERN = re.compile(r"^https?://.+")
- # Name patterns - alphanumeric with common special characters and spaces
- NAME_PATTERN = re.compile(r"^[a-zA-Z0-9._\s-]+$")
+ NAME_PATTERN = re.compile(r"^[a-zA-Z0-9._\s()\-]+$")
- # Architecture validation
ARCH_PATTERN = re.compile(r"^(x86_64|aarch64|i386|i686|ppc64|ppc64le|s390x|riscv64|noarch)$")
- # Repository name - more permissive for repo naming conventions
REPO_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9._-]+$")
@@ -107,7 +103,7 @@ def validate_name(name: str, min_length: int = 3, field_name: str = "name") -> s
if not ValidationPatterns.NAME_PATTERN.match(trimmed_name):
raise ValidationError(
- f"{field_name.title()} can only contain letters, numbers, spaces, dots, hyphens, and underscores",
+ f"{field_name.title()} can only contain letters, numbers, spaces, dots, hyphens, underscores, and parentheses",
ValidationErrorType.INVALID_FORMAT,
field_name,
)
@@ -176,7 +172,6 @@ def validate_architecture(arch: str, field_name: str = "architecture") -> str:
trimmed_arch = arch.strip()
- # Check if it's a valid architecture enum value
try:
Architecture(trimmed_arch)
except ValueError:
@@ -277,27 +272,23 @@ def validate_config_structure(config: Any, config_index: int) -> List[str]:
errors.append(f"Config {config_index}: Must be a dictionary")
return errors
- # Validate required top-level keys
required_keys = ["product", "mirror", "repositories"]
for key in required_keys:
if key not in config:
errors.append(f"Config {config_index}: Missing required key '{key}'")
- # Validate product data structure
if "product" in config:
product_errors = ConfigValidator.validate_product_config(
config["product"], config_index
)
errors.extend(product_errors)
- # Validate mirror data structure
if "mirror" in config:
mirror_errors = ConfigValidator.validate_mirror_config(
config["mirror"], config_index
)
errors.extend(mirror_errors)
- # Validate repositories data structure
if "repositories" in config:
repo_errors = ConfigValidator.validate_repositories_config(
config["repositories"], config_index
@@ -324,7 +315,6 @@ def validate_product_config(product: Any, config_index: int) -> List[str]:
errors.append(f"Config {config_index}: Product must be a dictionary")
return errors
- # Validate required product fields
required_fields = ["name", "variant", "vendor"]
for field in required_fields:
if field not in product or not product[field]:
@@ -332,7 +322,6 @@ def validate_product_config(product: Any, config_index: int) -> List[str]:
f"Config {config_index}: Product missing required field '{field}'"
)
- # Validate product name format if present
if product.get("name"):
try:
FieldValidator.validate_name(
@@ -361,7 +350,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]:
errors.append(f"Config {config_index}: Mirror must be a dictionary")
return errors
- # Validate required mirror fields
required_fields = ["name", "match_variant", "match_major_version", "match_arch"]
for field in required_fields:
if field not in mirror or mirror[field] is None:
@@ -369,7 +357,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]:
f"Config {config_index}: Mirror missing required field '{field}'"
)
- # Validate mirror name format if present
if mirror.get("name"):
try:
FieldValidator.validate_name(
@@ -378,7 +365,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]:
except ValidationError as e:
errors.append(f"Config {config_index}: Mirror name '{mirror['name']}' - {e.message}")
- # Validate architecture if present
if mirror.get("match_arch"):
try:
FieldValidator.validate_architecture(
@@ -387,7 +373,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]:
except ValidationError as e:
errors.append(f"Config {config_index}: Mirror architecture '{mirror['match_arch']}' - {e.message}")
- # Validate major version is numeric if present
if mirror.get("match_major_version") is not None:
if (
not isinstance(mirror["match_major_version"], int)
@@ -397,7 +382,6 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]:
f"Config {config_index}: Mirror match_major_version must be a non-negative integer"
)
- # Validate minor version is numeric if present
if mirror.get("match_minor_version") is not None:
if (
not isinstance(mirror["match_minor_version"], int)
@@ -407,6 +391,13 @@ def validate_mirror_config(mirror: Any, config_index: int) -> List[str]:
f"Config {config_index}: Mirror match_minor_version must be a non-negative integer"
)
+ # Validate active field if present
+ if "active" in mirror and mirror["active"] is not None:
+ if not isinstance(mirror["active"], bool):
+ errors.append(
+ f"Config {config_index}: Mirror active must be a boolean value"
+ )
+
return errors
@staticmethod
@@ -458,7 +449,6 @@ def validate_repository_config(
)
return errors
- # Validate required repository fields
required_fields = ["repo_name", "arch", "production", "url"]
for field in required_fields:
if field not in repo or repo[field] is None:
@@ -466,7 +456,6 @@ def validate_repository_config(
f"Config {config_index}, Repo {repo_index}: Missing required field '{field}'"
)
- # Validate repository name format if present
if repo.get("repo_name"):
try:
FieldValidator.validate_repo_name(
@@ -475,7 +464,6 @@ def validate_repository_config(
except ValidationError as e:
errors.append(f"Config {config_index}, Repo {repo_index}: Repository name '{repo['repo_name']}' - {e.message}")
- # Validate architecture if present
if repo.get("arch"):
try:
FieldValidator.validate_architecture(
@@ -484,10 +472,9 @@ def validate_repository_config(
except ValidationError as e:
errors.append(f"Config {config_index}, Repo {repo_index}: Architecture '{repo['arch']}' - {e.message}")
- # Validate URLs if present
for url_field in ["url", "debug_url", "source_url"]:
url_value = repo.get(url_field)
- if url_value: # Only validate if not empty
+ if url_value:
try:
FieldValidator.validate_url(
url_value,
@@ -499,7 +486,6 @@ def validate_repository_config(
f"Config {config_index}, Repo {repo_index}: {url_field.replace('_', ' ').title()} '{url_value}' - {e.message}"
)
- # Validate production is boolean if present
if "production" in repo and repo["production"] is not None:
if not isinstance(repo["production"], bool):
errors.append(
@@ -528,7 +514,6 @@ def validate_mirror_form(
validated_data = {}
errors = []
- # Validate name
try:
validated_data["name"] = FieldValidator.validate_name(
form_data.get("name", ""), min_length=3, field_name="mirror name"
@@ -536,7 +521,6 @@ def validate_mirror_form(
except ValidationError as e:
errors.append(e.message)
- # Validate architecture
try:
validated_data["match_arch"] = FieldValidator.validate_architecture(
form_data.get("match_arch", ""), field_name="architecture"
@@ -544,11 +528,13 @@ def validate_mirror_form(
except ValidationError as e:
errors.append(e.message)
- # Copy other fields as-is for now (they have different validation requirements)
for field in ["match_variant", "match_major_version", "match_minor_version"]:
if field in form_data:
validated_data[field] = form_data[field]
+ # Handle active checkbox (form data comes as string "true"/"false" or may be missing)
+ validated_data["active"] = form_data.get("active", "true") == "true"
+
return validated_data, errors
@staticmethod
@@ -567,7 +553,6 @@ def validate_repomd_form(
validated_data = {}
errors = []
- # Validate repository name
try:
validated_data["repo_name"] = FieldValidator.validate_repo_name(
form_data.get("repo_name", ""),
@@ -577,7 +562,6 @@ def validate_repomd_form(
except ValidationError as e:
errors.append(e.message)
- # Validate main URL (required)
try:
validated_data["url"] = FieldValidator.validate_url(
form_data.get("url", ""), field_name="repository URL", required=True
@@ -585,7 +569,6 @@ def validate_repomd_form(
except ValidationError as e:
errors.append(e.message)
- # Validate debug URL (optional)
try:
debug_url = FieldValidator.validate_url(
form_data.get("debug_url", ""), field_name="debug URL", required=False
@@ -594,7 +577,6 @@ def validate_repomd_form(
except ValidationError as e:
errors.append(e.message)
- # Validate source URL (optional)
try:
source_url = FieldValidator.validate_url(
form_data.get("source_url", ""), field_name="source URL", required=False
@@ -603,7 +585,6 @@ def validate_repomd_form(
except ValidationError as e:
errors.append(e.message)
- # Validate architecture
try:
validated_data["arch"] = FieldValidator.validate_architecture(
form_data.get("arch", ""), field_name="architecture"
@@ -611,7 +592,6 @@ def validate_repomd_form(
except ValidationError as e:
errors.append(e.message)
- # Copy production flag
validated_data["production"] = form_data.get("production", False)
return validated_data, errors
diff --git a/apollo/tests/BUILD.bazel b/apollo/tests/BUILD.bazel
index b658f79..66398ad 100644
--- a/apollo/tests/BUILD.bazel
+++ b/apollo/tests/BUILD.bazel
@@ -61,3 +61,29 @@ py_test(
"//apollo/server:server_lib",
],
)
+
+py_test(
+ name = "test_api_osv",
+ srcs = ["test_api_osv.py"],
+ deps = [
+ "//apollo/server:server_lib",
+ ],
+)
+
+py_test(
+ name = "test_api_updateinfo",
+ srcs = ["test_api_updateinfo.py"],
+ deps = [
+ "//apollo/server:server_lib",
+ ],
+)
+
+py_test(
+ name = "test_database_service",
+ srcs = ["test_database_service.py"],
+ deps = [
+ "//apollo/server:server_lib",
+ "//apollo/db:db_lib",
+ "//common:common_lib",
+ ],
+)
diff --git a/apollo/tests/test_admin_routes_supported_products.py b/apollo/tests/test_admin_routes_supported_products.py
index 32ad971..006391b 100644
--- a/apollo/tests/test_admin_routes_supported_products.py
+++ b/apollo/tests/test_admin_routes_supported_products.py
@@ -168,8 +168,8 @@ def test_json_serializer_decimal_integer(self):
"""Test JSON serializer with integer Decimal."""
decimal_val = Decimal("42")
result = _json_serializer(decimal_val)
- self.assertEqual(result, 42.0)
- self.assertIsInstance(result, float)
+ self.assertEqual(result, 42)
+ self.assertIsInstance(result, int)
def test_json_serializer_unsupported_type(self):
"""Test JSON serializer with unsupported type."""
@@ -211,10 +211,9 @@ def test_format_export_data_with_decimal(self):
result = _format_export_data(data)
- # Should be valid JSON with Decimals converted to floats
parsed = json.loads(result)
self.assertEqual(parsed[0]["price"], 19.99)
- self.assertEqual(parsed[1]["price"], 99.0)
+ self.assertEqual(parsed[1]["price"], 99)
def test_format_export_data_empty(self):
"""Test formatting empty export data."""
@@ -442,6 +441,215 @@ def test_validate_import_data_not_list(self):
self.assertIn("must be a list", errors[0])
+class TestActiveFieldCheckboxParsing(unittest.TestCase):
+ """Test the checkbox parsing logic for the active field."""
+
+ def test_checkbox_checked_returns_true(self):
+ """Test that checked checkbox results in active_value='true'."""
+ from unittest.mock import Mock
+ form_data = Mock()
+ form_data.getlist.return_value = ["true"]
+
+ active_value = "true" if "true" in form_data.getlist("active") else "false"
+
+ self.assertEqual(active_value, "true")
+
+ def test_checkbox_unchecked_returns_false(self):
+ """Test that unchecked checkbox results in active_value='false'."""
+ from unittest.mock import Mock
+ form_data = Mock()
+ form_data.getlist.return_value = []
+
+ active_value = "true" if "true" in form_data.getlist("active") else "false"
+
+ self.assertEqual(active_value, "false")
+
+ def test_checkbox_missing_field_defaults_to_false(self):
+ """Test that missing active field defaults to false."""
+ from unittest.mock import Mock
+ form_data = Mock()
+ form_data.getlist.return_value = []
+
+ active_value = "true" if "true" in form_data.getlist("active") else "false"
+
+ self.assertEqual(active_value, "false")
+
+ def test_checkbox_boolean_conversion(self):
+ """Test that string active_value converts correctly to boolean."""
+ active_value_true = "true"
+ active_value_false = "false"
+
+ self.assertTrue(active_value_true == "true")
+ self.assertFalse(active_value_false == "true")
+
+ def test_checkbox_with_unexpected_value(self):
+ """Test that unexpected values default to false."""
+ from unittest.mock import Mock
+ form_data = Mock()
+ form_data.getlist.return_value = ["yes", "1", "on"]
+
+ active_value = "true" if "true" in form_data.getlist("active") else "false"
+
+ self.assertEqual(active_value, "false")
+
+
+class TestMirrorConfigDataWithActiveField(unittest.TestCase):
+ """Test mirror configuration export includes active field."""
+
+ def test_exported_config_includes_active_true(self):
+ """Test that exported mirror config includes active=true."""
+ mirror = Mock()
+ mirror.id = 1
+ mirror.name = "Active Mirror"
+ mirror.match_variant = "Red Hat Enterprise Linux"
+ mirror.match_major_version = 9
+ mirror.match_minor_version = None
+ mirror.match_arch = "x86_64"
+ mirror.active = True
+ mirror.created_at = datetime(2024, 1, 1, 10, 0, 0)
+ mirror.updated_at = None
+
+ # Mock supported product
+ mirror.supported_product = Mock()
+ mirror.supported_product.id = 1
+ mirror.supported_product.name = "Rocky Linux"
+ mirror.supported_product.variant = "Rocky Linux"
+ mirror.supported_product.vendor = "RESF"
+
+ # Mock repositories
+ mirror.rpm_repomds = []
+
+ result = asyncio.run(_get_mirror_config_data(mirror))
+
+ # Verify active field is included and true
+ self.assertIn("active", result["mirror"])
+ self.assertTrue(result["mirror"]["active"])
+
+ def test_exported_config_includes_active_false(self):
+ """Test that exported mirror config includes active=false."""
+ mirror = Mock()
+ mirror.id = 2
+ mirror.name = "Inactive Mirror"
+ mirror.match_variant = "Red Hat Enterprise Linux"
+ mirror.match_major_version = 8
+ mirror.match_minor_version = None
+ mirror.match_arch = "x86_64"
+ mirror.active = False
+ mirror.created_at = datetime(2024, 1, 1, 10, 0, 0)
+ mirror.updated_at = None
+
+ # Mock supported product
+ mirror.supported_product = Mock()
+ mirror.supported_product.id = 1
+ mirror.supported_product.name = "Rocky Linux"
+ mirror.supported_product.variant = "Rocky Linux"
+ mirror.supported_product.vendor = "RESF"
+
+ # Mock repositories
+ mirror.rpm_repomds = []
+
+ result = asyncio.run(_get_mirror_config_data(mirror))
+
+ # Verify active field is included and false
+ self.assertIn("active", result["mirror"])
+ self.assertFalse(result["mirror"]["active"])
+
+
+class TestImportWithActiveField(unittest.TestCase):
+ """Test import validation with active field."""
+
+ def test_import_data_with_active_true_is_valid(self):
+ """Test that import data with active=true is valid."""
+ valid_data = [
+ {
+ "product": {
+ "name": "Rocky Linux",
+ "variant": "Rocky Linux",
+ "vendor": "RESF",
+ },
+ "mirror": {
+ "name": "Rocky Linux 9 x86_64",
+ "match_variant": "Red Hat Enterprise Linux",
+ "match_major_version": 9,
+ "match_arch": "x86_64",
+ "active": True,
+ },
+ "repositories": [
+ {
+ "repo_name": "BaseOS",
+ "arch": "x86_64",
+ "production": True,
+ "url": "https://example.com/repo",
+ }
+ ],
+ }
+ ]
+
+ errors = asyncio.run(_validate_import_data(valid_data))
+ self.assertEqual(errors, [])
+
+ def test_import_data_with_active_false_is_valid(self):
+ """Test that import data with active=false is valid."""
+ valid_data = [
+ {
+ "product": {
+ "name": "Rocky Linux",
+ "variant": "Rocky Linux",
+ "vendor": "RESF",
+ },
+ "mirror": {
+ "name": "Rocky Linux 8 x86_64",
+ "match_variant": "Red Hat Enterprise Linux",
+ "match_major_version": 8,
+ "match_arch": "x86_64",
+ "active": False,
+ },
+ "repositories": [
+ {
+ "repo_name": "BaseOS",
+ "arch": "x86_64",
+ "production": True,
+ "url": "https://example.com/repo",
+ }
+ ],
+ }
+ ]
+
+ errors = asyncio.run(_validate_import_data(valid_data))
+ self.assertEqual(errors, [])
+
+ def test_import_data_without_active_field_is_valid(self):
+ """Test that import data without active field is valid (backwards compatibility)."""
+ valid_data = [
+ {
+ "product": {
+ "name": "Rocky Linux",
+ "variant": "Rocky Linux",
+ "vendor": "RESF",
+ },
+ "mirror": {
+ "name": "Rocky Linux 9 x86_64",
+ "match_variant": "Red Hat Enterprise Linux",
+ "match_major_version": 9,
+ "match_arch": "x86_64",
+ # No active field - should default to true
+ },
+ "repositories": [
+ {
+ "repo_name": "BaseOS",
+ "arch": "x86_64",
+ "production": True,
+ "url": "https://example.com/repo",
+ }
+ ],
+ }
+ ]
+
+ # Should still be valid - active field is optional for backwards compatibility
+ errors = asyncio.run(_validate_import_data(valid_data))
+ self.assertEqual(errors, [])
+
+
if __name__ == "__main__":
# Run with verbose output
unittest.main(verbosity=2)
diff --git a/apollo/tests/test_api_osv.py b/apollo/tests/test_api_osv.py
new file mode 100644
index 0000000..6422c3d
--- /dev/null
+++ b/apollo/tests/test_api_osv.py
@@ -0,0 +1,248 @@
+"""
+Tests for OSV API CVE filtering functionality
+"""
+
+import unittest
+import datetime
+from unittest.mock import Mock
+
+from apollo.server.routes.api_osv import to_osv_advisory
+
+
+class MockSupportedProduct:
+ """Mock SupportedProduct model"""
+
+ def __init__(self, variant="Rocky Linux", vendor="Rocky Enterprise Software Foundation"):
+ self.variant = variant
+ self.vendor = vendor
+
+
+class MockSupportedProductsRhMirror:
+ """Mock SupportedProductsRhMirror model"""
+
+ def __init__(self, match_major_version=9):
+ self.match_major_version = match_major_version
+
+
+class MockPackage:
+ """Mock Package model"""
+
+ def __init__(
+ self,
+ nevra,
+ product_name="Rocky Linux 9",
+ repo_name="BaseOS",
+ supported_product=None,
+ supported_products_rh_mirror=None,
+ ):
+ self.nevra = nevra
+ self.product_name = product_name
+ self.repo_name = repo_name
+ self.supported_product = supported_product or MockSupportedProduct()
+ self.supported_products_rh_mirror = supported_products_rh_mirror
+
+
+class MockCVE:
+ """Mock CVE model"""
+
+ def __init__(
+ self,
+ cve="CVE-2024-1234",
+ cvss3_base_score="7.5",
+ cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:N/A:N",
+ ):
+ self.cve = cve
+ self.cvss3_base_score = cvss3_base_score
+ self.cvss3_scoring_vector = cvss3_scoring_vector
+
+
+class MockFix:
+ """Mock Fix model"""
+
+ def __init__(self, source="https://bugzilla.redhat.com/show_bug.cgi?id=1234567"):
+ self.source = source
+
+
+class MockAdvisory:
+ """Mock Advisory model"""
+
+ def __init__(
+ self,
+ name="RLSA-2024:1234",
+ synopsis="Important: test security update",
+ description="A security update for test package",
+ published_at=None,
+ updated_at=None,
+ packages=None,
+ cves=None,
+ fixes=None,
+ red_hat_advisory=None,
+ ):
+ self.name = name
+ self.synopsis = synopsis
+ self.description = description
+ self.published_at = published_at or datetime.datetime.now(
+ datetime.timezone.utc
+ )
+ self.updated_at = updated_at or datetime.datetime.now(datetime.timezone.utc)
+ self.packages = packages or []
+ self.cves = cves or []
+ self.fixes = fixes or []
+ self.red_hat_advisory = red_hat_advisory
+
+
+class TestOSVCVEFiltering(unittest.TestCase):
+ """Test CVE filtering logic in OSV API"""
+
+ def test_advisory_with_cve_has_upstream_references(self):
+ """Test that advisories with CVEs have upstream references populated"""
+ packages = [
+ MockPackage(
+ nevra="pcs-0:0.11.8-2.el9_5.src",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+ cves = [MockCVE(cve="CVE-2024-1234")]
+
+ advisory = MockAdvisory(packages=packages, cves=cves)
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ self.assertIsNotNone(result.upstream)
+ self.assertEqual(len(result.upstream), 1)
+ self.assertIn("CVE-2024-1234", result.upstream)
+
+ def test_advisory_with_multiple_cves(self):
+ """Test that advisories with multiple CVEs include all in upstream"""
+ packages = [
+ MockPackage(
+ nevra="openssl-1:3.0.7-28.el9_5.src",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+ cves = [
+ MockCVE(cve="CVE-2024-1111"),
+ MockCVE(cve="CVE-2024-2222"),
+ MockCVE(cve="CVE-2024-3333"),
+ ]
+
+ advisory = MockAdvisory(packages=packages, cves=cves)
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ self.assertIsNotNone(result.upstream)
+ self.assertEqual(len(result.upstream), 3)
+ self.assertIn("CVE-2024-1111", result.upstream)
+ self.assertIn("CVE-2024-2222", result.upstream)
+ self.assertIn("CVE-2024-3333", result.upstream)
+
+ def test_advisory_without_cves_has_empty_upstream(self):
+ """Test that advisories without CVEs have empty upstream list"""
+ packages = [
+ MockPackage(
+ nevra="kernel-0:5.14.0-427.el9.src",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+
+ advisory = MockAdvisory(packages=packages, cves=[])
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ self.assertIsNotNone(result.upstream)
+ self.assertEqual(len(result.upstream), 0)
+
+ def test_source_packages_only(self):
+ """Test that only source packages are processed, not binary packages"""
+ packages = [
+ MockPackage(
+ nevra="httpd-0:2.4.57-8.el9.src",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ MockPackage(
+ nevra="httpd-0:2.4.57-8.el9.x86_64",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ MockPackage(
+ nevra="httpd-0:2.4.57-8.el9.aarch64",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+ cves = [MockCVE()]
+
+ advisory = MockAdvisory(packages=packages, cves=cves)
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ # Should only have 1 affected package (the source package)
+ self.assertEqual(len(result.affected), 1)
+ self.assertEqual(result.affected[0].package.name, "httpd")
+
+ def test_severity_from_highest_cvss(self):
+ """Test that severity uses the highest CVSS score from multiple CVEs"""
+ packages = [
+ MockPackage(
+ nevra="vim-2:9.0.1592-1.el9.src",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+ cves = [
+ MockCVE(
+ cve="CVE-2024-1111",
+ cvss3_base_score="5.5",
+ cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:L/I:N/A:N",
+ ),
+ MockCVE(
+ cve="CVE-2024-2222",
+ cvss3_base_score="9.8",
+ cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H",
+ ),
+ MockCVE(
+ cve="CVE-2024-3333",
+ cvss3_base_score="7.5",
+ cvss3_scoring_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:N/A:N",
+ ),
+ ]
+
+ advisory = MockAdvisory(packages=packages, cves=cves)
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ self.assertIsNotNone(result.severity)
+ self.assertEqual(len(result.severity), 1)
+ self.assertEqual(result.severity[0].type, "CVSS_V3")
+ self.assertEqual(
+ result.severity[0].score, "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H"
+ )
+
+ def test_ecosystem_format(self):
+ """Test that ecosystem field is formatted correctly"""
+ packages = [
+ MockPackage(
+ nevra="bash-0:5.1.8-9.el9.src",
+ product_name="Rocky Linux 9",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+ cves = [MockCVE()]
+
+ advisory = MockAdvisory(packages=packages, cves=cves)
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ self.assertEqual(len(result.affected), 1)
+ self.assertEqual(result.affected[0].package.ecosystem, "Rocky Linux:9")
+
+ def test_version_format_with_epoch(self):
+ """Test that fixed version includes epoch in epoch:version-release format"""
+ packages = [
+ MockPackage(
+ nevra="systemd-0:252-38.el9_5.src",
+ supported_products_rh_mirror=MockSupportedProductsRhMirror(9),
+ ),
+ ]
+ cves = [MockCVE()]
+
+ advisory = MockAdvisory(packages=packages, cves=cves)
+ result = to_osv_advisory("https://errata.rockylinux.org", advisory)
+
+ fixed_version = result.affected[0].ranges[0].events[1].fixed
+ self.assertEqual(fixed_version, "0:252-38.el9_5")
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/apollo/tests/test_api_updateinfo.py b/apollo/tests/test_api_updateinfo.py
new file mode 100644
index 0000000..258bb45
--- /dev/null
+++ b/apollo/tests/test_api_updateinfo.py
@@ -0,0 +1,107 @@
+"""
+Unit tests for updateinfo API v2
+Tests product slug resolution, database queries, XML generation,
+and data integrity validation
+"""
+
+import unittest
+import sys
+import os
+
+# Add the project root to the Python path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
+
+from apollo.server.routes.api_updateinfo import (
+ resolve_product_slug,
+ PRODUCT_SLUG_MAP,
+)
+
+
+class TestProductSlugResolution(unittest.TestCase):
+ """Test product slug to product name resolution."""
+
+ def test_resolve_valid_slugs(self):
+ """Test that valid slugs resolve correctly."""
+ test_cases = [
+ ("rocky-linux", "Rocky Linux"),
+ ("rocky-linux-sig-cloud", "Rocky Linux SIG Cloud"),
+ ]
+
+ for slug, expected_name in test_cases:
+ with self.subTest(slug=slug):
+ result = resolve_product_slug(slug)
+ self.assertEqual(result, expected_name)
+
+ def test_resolve_case_insensitive(self):
+ """Test that slug resolution is case insensitive."""
+ test_cases = [
+ ("Rocky-Linux", "Rocky Linux"),
+ ("ROCKY-LINUX", "Rocky Linux"),
+ ("rocky-LINUX", "Rocky Linux"),
+ ("Rocky-Linux-SIG-Cloud", "Rocky Linux SIG Cloud"),
+ ]
+
+ for slug, expected_name in test_cases:
+ with self.subTest(slug=slug):
+ result = resolve_product_slug(slug)
+ self.assertEqual(result, expected_name)
+
+ def test_resolve_invalid_slug(self):
+ """Test that invalid slugs return None."""
+ test_cases = [
+ "invalid-slug",
+ "rocky",
+ "linux",
+ "centos-linux",
+ "",
+ "rocky_linux", # underscore instead of hyphen
+ ]
+
+ for slug in test_cases:
+ with self.subTest(slug=slug):
+ result = resolve_product_slug(slug)
+ self.assertIsNone(result)
+
+ def test_all_mapped_slugs_unique(self):
+ """Test that all product slugs map to unique names."""
+ product_names = list(PRODUCT_SLUG_MAP.values())
+ self.assertEqual(len(product_names), len(set(product_names)),
+ "Product names should be unique")
+
+ def test_slug_map_not_empty(self):
+ """Test that the slug map is not empty."""
+ self.assertGreater(len(PRODUCT_SLUG_MAP), 0,
+ "PRODUCT_SLUG_MAP should not be empty")
+
+
+class TestProductSlugFormat(unittest.TestCase):
+ """Test product slug formatting requirements."""
+
+ def test_slugs_are_lowercase(self):
+ """Test that all defined slugs are lowercase."""
+ for slug in PRODUCT_SLUG_MAP.keys():
+ with self.subTest(slug=slug):
+ self.assertEqual(slug, slug.lower(),
+ f"Slug '{slug}' should be lowercase")
+
+ def test_slugs_use_hyphens(self):
+ """Test that slugs use hyphens not underscores."""
+ for slug in PRODUCT_SLUG_MAP.keys():
+ with self.subTest(slug=slug):
+ self.assertNotIn("_", slug,
+ f"Slug '{slug}' should not contain underscores")
+ if len(slug) > 5: # Only check multi-word slugs
+ self.assertIn("-", slug,
+ f"Multi-word slug '{slug}' should contain hyphens")
+
+ def test_product_names_are_capitalized(self):
+ """Test that product names are properly capitalized."""
+ for product_name in PRODUCT_SLUG_MAP.values():
+ with self.subTest(product_name=product_name):
+ # Should start with capital letter
+ self.assertTrue(product_name[0].isupper(),
+ f"Product name '{product_name}' should start with capital letter")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/apollo/tests/test_csaf_processing.py b/apollo/tests/test_csaf_processing.py
index dbd0f91..aa31c10 100644
--- a/apollo/tests/test_csaf_processing.py
+++ b/apollo/tests/test_csaf_processing.py
@@ -22,17 +22,10 @@
)
class TestCsafProcessing(unittest.IsolatedAsyncioTestCase):
- @classmethod
- async def asyncSetUp(cls):
- # Initialize test database for all tests in this class
+ async def asyncSetUp(self):
+ # Initialize test database before each test
await initialize_test_db()
-
- @classmethod
- async def asyncTearDown(cls):
- # Close database connections when tests are done
- await close_test_db()
- def setUp(self):
# Create sample CSAF data matching schema requirements
self.sample_csaf = {
"document": {
@@ -69,10 +62,35 @@ def setUp(self):
"name": "Red Hat Enterprise Linux 9",
"product": {
"name": "Red Hat Enterprise Linux 9",
+ "product_id": "AppStream-9.4.0.Z.MAIN",
"product_identification_helper": {
- "cpe": "cpe:/o:redhat:enterprise_linux:9.4"
+ "cpe": "cpe:/o:redhat:enterprise_linux:9::appstream"
+ }
+ },
+ "branches": [
+ {
+ "category": "product_version",
+ "name": "rsync-0:3.2.3-19.el9_4.1.x86_64",
+ "product": {
+ "name": "rsync-0:3.2.3-19.el9_4.1.x86_64",
+ "product_id": "rsync-0:3.2.3-19.el9_4.1.x86_64",
+ "product_identification_helper": {
+ "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=x86_64"
+ }
+ }
+ },
+ {
+ "category": "product_version",
+ "name": "rsync-0:3.2.3-19.el9_4.1.src",
+ "product": {
+ "name": "rsync-0:3.2.3-19.el9_4.1.src",
+ "product_id": "rsync-0:3.2.3-19.el9_4.1.src",
+ "product_identification_helper": {
+ "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=src"
+ }
+ }
}
- }
+ ]
}
]
},
@@ -95,8 +113,8 @@ def setUp(self):
],
"product_status": {
"fixed": [
- "AppStream-9.4.0.Z.EUS:rsync-0:3.2.3-19.el9_4.1.x86_64",
- "AppStream-9.4.0.Z.EUS:rsync-0:3.2.3-19.el9_4.1.src"
+ "AppStream-9.4.0.Z.MAIN:rsync-0:3.2.3-19.el9_4.1.x86_64",
+ "AppStream-9.4.0.Z.MAIN:rsync-0:3.2.3-19.el9_4.1.src"
]
},
"scores": [{
@@ -117,28 +135,31 @@ def setUp(self):
}
]
}
-
+
# Create a temporary file with the sample data
self.test_file = pathlib.Path("test_csaf.json")
with open(self.test_file, "w") as f:
json.dump(self.sample_csaf, f)
- async def tearDown(self):
- # Clean up database and temporary files after each test
+ async def asyncTearDown(self):
+ # Clean up database entries and temporary files after each test
await RedHatAdvisory.all().delete()
await RedHatAdvisoryPackage.all().delete()
await RedHatAdvisoryCVE.all().delete()
- await RedHatAdvisoryBugzillaBug.all().delete()
+ await RedHatAdvisoryBugzillaBug.all().delete()
await RedHatAdvisoryAffectedProduct.all().delete()
-
- # Clean up temporary file
+
+ # Close database connections
+ await close_test_db()
+
+ # Clean up temporary files
self.test_file.unlink(missing_ok=True)
pathlib.Path("invalid_csaf.json").unlink(missing_ok=True)
async def test_new_advisory_creation(self):
# Test creating a new advisory with a real test database
result = await process_csaf_file(self.sample_csaf, "test.json")
-
+
# Verify advisory was created correctly
advisory = await RedHatAdvisory.get_or_none(name="RHSA-2025:1234")
self.assertIsNotNone(advisory)
@@ -176,7 +197,8 @@ async def test_new_advisory_creation(self):
self.assertEqual(products[0].variant, "Red Hat Enterprise Linux")
self.assertEqual(products[0].arch, "x86_64")
self.assertEqual(products[0].major_version, 9)
- self.assertEqual(products[0].minor_version, 4)
+ # Minor version is None because CPE doesn't include minor version
+ self.assertIsNone(products[0].minor_version)
async def test_advisory_update(self):
# First create an advisory with different values
@@ -224,12 +246,13 @@ async def test_no_vulnerabilities(self):
self.assertEqual(count, 0)
async def test_no_fixed_packages(self):
- # Test CSAF with vulnerabilities but no fixed packages
+ # Test CSAF with vulnerabilities but no fixed packages in product_tree
csaf = self.sample_csaf.copy()
- csaf["vulnerabilities"][0]["product_status"]["fixed"] = []
+ # Remove product_version entries from product_tree to simulate no fixed packages
+ csaf["product_tree"]["branches"][0]["branches"][0]["branches"][0].pop("branches", None)
result = await process_csaf_file(csaf, "test.json")
self.assertIsNone(result)
-
+
# Verify nothing was created
count = await RedHatAdvisory.all().count()
self.assertEqual(count, 0)
@@ -239,4 +262,7 @@ async def test_db_exception(self, mock_get_or_none):
# Simulate a database error
mock_get_or_none.side_effect = Exception("DB error")
with self.assertRaises(Exception):
- await process_csaf_file(self.sample_csaf, "test.json")
\ No newline at end of file
+ await process_csaf_file(self.sample_csaf, "test.json")
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/apollo/tests/test_database_service.py b/apollo/tests/test_database_service.py
new file mode 100644
index 0000000..0a8c215
--- /dev/null
+++ b/apollo/tests/test_database_service.py
@@ -0,0 +1,226 @@
+"""
+Tests for DatabaseService functionality
+Tests utility functions for database operations including timestamp management
+"""
+
+import unittest
+import asyncio
+from datetime import datetime, timezone
+from unittest.mock import Mock, AsyncMock, patch
+import os
+
+# Add the project root to the Python path
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
+
+from apollo.server.services.database_service import DatabaseService
+
+
+class TestEnvironmentDetection(unittest.TestCase):
+ """Test environment detection functionality."""
+
+ def test_is_production_when_env_is_production(self):
+ """Test production detection when ENV=production."""
+ with patch.dict(os.environ, {"ENV": "production"}):
+ service = DatabaseService()
+ self.assertTrue(service.is_production_environment())
+
+ def test_is_not_production_when_env_is_development(self):
+ """Test production detection when ENV=development."""
+ with patch.dict(os.environ, {"ENV": "development"}):
+ service = DatabaseService()
+ self.assertFalse(service.is_production_environment())
+
+ def test_is_not_production_when_env_not_set(self):
+ """Test production detection when ENV is not set."""
+ with patch.dict(os.environ, {}, clear=True):
+ service = DatabaseService()
+ self.assertFalse(service.is_production_environment())
+
+ def test_is_not_production_with_staging_env(self):
+ """Test production detection with staging environment."""
+ with patch.dict(os.environ, {"ENV": "staging"}):
+ service = DatabaseService()
+ self.assertFalse(service.is_production_environment())
+
+ def test_get_environment_info_production(self):
+ """Test getting environment info for production."""
+ with patch.dict(os.environ, {"ENV": "production"}):
+ service = DatabaseService()
+ result = asyncio.run(service.get_environment_info())
+
+ self.assertEqual(result["environment"], "production")
+ self.assertTrue(result["is_production"])
+ self.assertFalse(result["reset_allowed"])
+
+ def test_get_environment_info_development(self):
+ """Test getting environment info for development."""
+ with patch.dict(os.environ, {"ENV": "development"}):
+ service = DatabaseService()
+ result = asyncio.run(service.get_environment_info())
+
+ self.assertEqual(result["environment"], "development")
+ self.assertFalse(result["is_production"])
+ self.assertTrue(result["reset_allowed"])
+
+
+class TestLastIndexedAtOperations(unittest.TestCase):
+ """Test last_indexed_at timestamp operations."""
+
+ def test_get_last_indexed_at_when_exists(self):
+ """Test getting last_indexed_at when record exists."""
+ mock_index_state = Mock()
+ test_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc)
+ mock_index_state.last_indexed_at = test_time
+
+ with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state:
+ mock_state.first = AsyncMock(return_value=mock_index_state)
+
+ service = DatabaseService()
+ result = asyncio.run(service.get_last_indexed_at())
+
+ self.assertEqual(result["last_indexed_at"], test_time)
+ self.assertEqual(result["last_indexed_at_iso"], "2025-07-01T00:00:00+00:00")
+ self.assertTrue(result["exists"])
+
+ def test_get_last_indexed_at_when_not_exists(self):
+ """Test getting last_indexed_at when no record exists."""
+ with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state:
+ mock_state.first = AsyncMock(return_value=None)
+
+ service = DatabaseService()
+ result = asyncio.run(service.get_last_indexed_at())
+
+ self.assertIsNone(result["last_indexed_at"])
+ self.assertIsNone(result["last_indexed_at_iso"])
+ self.assertFalse(result["exists"])
+
+ def test_get_last_indexed_at_when_timestamp_is_none(self):
+ """Test getting last_indexed_at when timestamp field is None."""
+ mock_index_state = Mock()
+ mock_index_state.last_indexed_at = None
+
+ with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state:
+ mock_state.first = AsyncMock(return_value=mock_index_state)
+
+ service = DatabaseService()
+ result = asyncio.run(service.get_last_indexed_at())
+
+ self.assertIsNone(result["last_indexed_at"])
+ self.assertIsNone(result["last_indexed_at_iso"])
+ self.assertFalse(result["exists"])
+
+ def test_update_last_indexed_at_existing_record(self):
+ """Test updating last_indexed_at for existing record."""
+ old_time = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
+ new_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ mock_index_state = Mock()
+ mock_index_state.last_indexed_at = old_time
+ mock_index_state.save = AsyncMock()
+
+ with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state, \
+ patch("apollo.server.services.database_service.Logger"):
+ mock_state.first = AsyncMock(return_value=mock_index_state)
+
+ service = DatabaseService()
+ result = asyncio.run(service.update_last_indexed_at(new_time, "admin@example.com"))
+
+ self.assertTrue(result["success"])
+ self.assertEqual(result["old_timestamp"], "2025-06-01T00:00:00+00:00")
+ self.assertEqual(result["new_timestamp"], "2025-07-01T00:00:00+00:00")
+ self.assertIn("Successfully updated", result["message"])
+
+ # Verify save was called
+ mock_index_state.save.assert_called_once()
+ # Verify timestamp was updated
+ self.assertEqual(mock_index_state.last_indexed_at, new_time)
+
+ def test_update_last_indexed_at_create_new_record(self):
+ """Test updating last_indexed_at when no record exists (creates new)."""
+ new_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state, \
+ patch("apollo.server.services.database_service.Logger"):
+ mock_state.first = AsyncMock(return_value=None)
+ mock_state.create = AsyncMock()
+
+ service = DatabaseService()
+ result = asyncio.run(service.update_last_indexed_at(new_time, "admin@example.com"))
+
+ self.assertTrue(result["success"])
+ self.assertIsNone(result["old_timestamp"])
+ self.assertEqual(result["new_timestamp"], "2025-07-01T00:00:00+00:00")
+ self.assertIn("Successfully updated", result["message"])
+
+ # Verify create was called with correct timestamp
+ mock_state.create.assert_called_once_with(last_indexed_at=new_time)
+
+ def test_update_last_indexed_at_handles_exception(self):
+ """Test that update_last_indexed_at handles database exceptions."""
+ new_time = datetime(2025, 7, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ with patch("apollo.server.services.database_service.RedHatIndexState") as mock_state, \
+ patch("apollo.server.services.database_service.Logger"):
+ mock_state.first = AsyncMock(side_effect=Exception("Database error"))
+
+ service = DatabaseService()
+
+ with self.assertRaises(RuntimeError) as cm:
+ asyncio.run(service.update_last_indexed_at(new_time, "admin@example.com"))
+
+ self.assertIn("Failed to update timestamp", str(cm.exception))
+
+
+class TestPartialResetValidation(unittest.TestCase):
+ """Test partial reset validation logic."""
+
+ def test_preview_partial_reset_blocks_in_production(self):
+ """Test that preview_partial_reset raises error in production."""
+ with patch.dict(os.environ, {"ENV": "production"}):
+ service = DatabaseService()
+ cutoff_date = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ with self.assertRaises(ValueError) as cm:
+ asyncio.run(service.preview_partial_reset(cutoff_date))
+
+ self.assertIn("production environment", str(cm.exception))
+
+ def test_preview_partial_reset_rejects_future_date(self):
+ """Test that preview_partial_reset rejects future dates."""
+ with patch.dict(os.environ, {"ENV": "development"}):
+ service = DatabaseService()
+ future_date = datetime(2099, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ with self.assertRaises(ValueError) as cm:
+ asyncio.run(service.preview_partial_reset(future_date))
+
+ self.assertIn("must be in the past", str(cm.exception))
+
+ def test_perform_partial_reset_blocks_in_production(self):
+ """Test that perform_partial_reset raises error in production."""
+ with patch.dict(os.environ, {"ENV": "production"}):
+ service = DatabaseService()
+ cutoff_date = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ with self.assertRaises(ValueError) as cm:
+ asyncio.run(service.perform_partial_reset(cutoff_date, "admin@example.com"))
+
+ self.assertIn("production environment", str(cm.exception))
+
+ def test_perform_partial_reset_rejects_future_date(self):
+ """Test that perform_partial_reset rejects future dates."""
+ with patch.dict(os.environ, {"ENV": "development"}):
+ service = DatabaseService()
+ future_date = datetime(2099, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+
+ with self.assertRaises(ValueError) as cm:
+ asyncio.run(service.perform_partial_reset(future_date, "admin@example.com"))
+
+ self.assertIn("must be in the past", str(cm.exception))
+
+
+if __name__ == "__main__":
+ # Run with verbose output
+ unittest.main(verbosity=2)
diff --git a/apollo/tests/test_rhcsaf.py b/apollo/tests/test_rhcsaf.py
index 1c62f0a..2b4fc62 100644
--- a/apollo/tests/test_rhcsaf.py
+++ b/apollo/tests/test_rhcsaf.py
@@ -52,7 +52,29 @@ def setUp(self):
"product_identification_helper": {
"cpe": "cpe:/o:redhat:enterprise_linux:9.4"
}
- }
+ },
+ "branches": [
+ {
+ "category": "product_version",
+ "name": "rsync-0:3.2.3-19.el9_4.1.x86_64",
+ "product": {
+ "product_id": "rsync-0:3.2.3-19.el9_4.1.x86_64",
+ "product_identification_helper": {
+ "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=x86_64"
+ }
+ }
+ },
+ {
+ "category": "product_version",
+ "name": "rsync-0:3.2.3-19.el9_4.1.src",
+ "product": {
+ "product_id": "rsync-0:3.2.3-19.el9_4.1.src",
+ "product_identification_helper": {
+ "purl": "pkg:rpm/redhat/rsync@3.2.3-19.el9_4.1?arch=src"
+ }
+ }
+ }
+ ]
}
]
},
@@ -252,4 +274,227 @@ def test_major_only_version(self):
self.assertIn(
("Red Hat Enterprise Linux", "Red Hat Enterprise Linux for x86_64", 9, None, "x86_64"),
result
- )
\ No newline at end of file
+ )
+
+
+class TestEUSDetection(unittest.TestCase):
+ """Test EUS product detection and filtering"""
+
+ def setUp(self):
+ with patch('common.logger.Logger') as mock_logger_class:
+ mock_logger_class.return_value = MagicMock()
+ from apollo.rhcsaf import _is_eus_product
+ self._is_eus_product = _is_eus_product
+
+ def test_detect_eus_via_cpe(self):
+ """Test EUS detection via CPE product field"""
+ # EUS CPE products
+ self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_eus:9.4::appstream"))
+ self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_e4s:9.0::appstream"))
+ self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_aus:8.2::appstream"))
+ self.assertTrue(self._is_eus_product("Some Product", "cpe:/a:redhat:rhel_tus:8.8::appstream"))
+
+ # Non-EUS CPE product
+ self.assertFalse(self._is_eus_product("Some Product", "cpe:/a:redhat:enterprise_linux:9::appstream"))
+
+ def test_detect_eus_via_name(self):
+ """Test EUS detection via product name keywords"""
+ self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream EUS (v.9.4)", ""))
+ self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream E4S (v.9.0)", ""))
+ self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream AUS (v.8.2)", ""))
+ self.assertTrue(self._is_eus_product("Red Hat Enterprise Linux AppStream TUS (v.8.8)", ""))
+
+ # Non-EUS product name
+ self.assertFalse(self._is_eus_product("Red Hat Enterprise Linux AppStream", ""))
+
+ def test_eus_filtering_in_affected_products(self):
+ """Test that EUS products are filtered from affected products"""
+ csaf = {
+ "product_tree": {
+ "branches": [
+ {
+ "branches": [
+ {
+ "category": "product_family",
+ "name": "Red Hat Enterprise Linux",
+ "branches": [
+ {
+ "category": "product_name",
+ "product": {
+ "name": "Red Hat Enterprise Linux AppStream EUS (v.9.4)",
+ "product_identification_helper": {
+ "cpe": "cpe:/a:redhat:rhel_eus:9.4::appstream"
+ }
+ }
+ }
+ ]
+ },
+ {
+ "category": "architecture",
+ "name": "x86_64"
+ }
+ ]
+ }
+ ]
+ }
+ }
+
+ result = extract_rhel_affected_products_for_db(csaf)
+ # Should be empty because the only product is EUS
+ self.assertEqual(len(result), 0)
+
+
+class TestModularPackages(unittest.TestCase):
+ """Test modular package extraction"""
+
+ def test_extract_modular_packages(self):
+ """Test extraction of modular packages with ::module:stream suffix"""
+ csaf = {
+ "document": {
+ "tracking": {
+ "initial_release_date": "2025-07-28T00:00:00+00:00",
+ "current_release_date": "2025-07-28T00:00:00+00:00",
+ "id": "RHSA-2025:12008"
+ },
+ "title": "Red Hat Security Advisory: Important: redis:7 security update",
+ "aggregate_severity": {"text": "Important"},
+ "notes": [
+ {"category": "general", "text": "Test description"},
+ {"category": "summary", "text": "Test topic"}
+ ]
+ },
+ "product_tree": {
+ "branches": [
+ {
+ "branches": [
+ {
+ "category": "product_family",
+ "name": "Red Hat Enterprise Linux",
+ "branches": [
+ {
+ "category": "product_name",
+ "name": "Red Hat Enterprise Linux 9",
+ "product": {
+ "name": "Red Hat Enterprise Linux 9",
+ "product_identification_helper": {
+ "cpe": "cpe:/o:redhat:enterprise_linux:9::appstream"
+ }
+ },
+ "branches": [
+ {
+ "category": "product_version",
+ "name": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.x86_64::redis:7",
+ "product": {
+ "product_id": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.x86_64::redis:7",
+ "product_identification_helper": {
+ "purl": "pkg:rpm/redhat/redis@7.2.10-1.module+el9.6.0+23332+115a3b01?arch=x86_64&rpmmod=redis:7:9060020250716081121:9"
+ }
+ }
+ },
+ {
+ "category": "product_version",
+ "name": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.src::redis:7",
+ "product": {
+ "product_id": "redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.src::redis:7",
+ "product_identification_helper": {
+ "purl": "pkg:rpm/redhat/redis@7.2.10-1.module+el9.6.0+23332+115a3b01?arch=src&rpmmod=redis:7:9060020250716081121:9"
+ }
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "category": "architecture",
+ "name": "x86_64"
+ }
+ ]
+ }
+ ]
+ },
+ "vulnerabilities": [
+ {
+ "cve": "CVE-2025-12345",
+ "ids": [{"system_name": "Red Hat Bugzilla ID", "text": "123456"}],
+ "product_status": {"fixed": []},
+ "scores": [{"cvss_v3": {"vectorString": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", "baseScore": 9.8}}],
+ "cwe": {"id": "CWE-79"}
+ }
+ ]
+ }
+
+ result = red_hat_advisory_scraper(csaf)
+
+ # Check that modular packages were extracted with ::module:stream stripped
+ self.assertIn("redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.x86_64", result["red_hat_fixed_packages"])
+ self.assertIn("redis-0:7.2.10-1.module+el9.6.0+23332+115a3b01.src", result["red_hat_fixed_packages"])
+
+ # Verify epoch is preserved
+ for pkg in result["red_hat_fixed_packages"]:
+ if "redis" in pkg:
+ self.assertIn("0:", pkg, "Epoch should be preserved in NEVRA")
+
+
+class TestEUSAdvisoryFiltering(unittest.TestCase):
+ """Test that EUS-only advisories are filtered out"""
+
+ def test_eus_only_advisory_returns_none(self):
+ """Test that advisory with only EUS products returns None"""
+ csaf = {
+ "document": {
+ "tracking": {
+ "initial_release_date": "2025-01-01T00:00:00+00:00",
+ "current_release_date": "2025-01-01T00:00:00+00:00",
+ "id": "RHSA-2025:9756"
+ },
+ "title": "Red Hat Security Advisory: Important: package security update",
+ "aggregate_severity": {"text": "Important"},
+ "notes": [
+ {"category": "general", "text": "EUS advisory"},
+ {"category": "summary", "text": "EUS topic"}
+ ]
+ },
+ "product_tree": {
+ "branches": [
+ {
+ "branches": [
+ {
+ "category": "product_family",
+ "name": "Red Hat Enterprise Linux",
+ "branches": [
+ {
+ "category": "product_name",
+ "name": "Red Hat Enterprise Linux AppStream EUS (v.9.4)",
+ "product": {
+ "name": "Red Hat Enterprise Linux AppStream EUS (v.9.4)",
+ "product_identification_helper": {
+ "cpe": "cpe:/a:redhat:rhel_eus:9.4::appstream"
+ }
+ }
+ }
+ ]
+ },
+ {
+ "category": "architecture",
+ "name": "x86_64"
+ }
+ ]
+ }
+ ]
+ },
+ "vulnerabilities": [
+ {
+ "cve": "CVE-2025-99999",
+ "ids": [{"system_name": "Red Hat Bugzilla ID", "text": "999999"}],
+ "product_status": {"fixed": []},
+ "scores": [{"cvss_v3": {"vectorString": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", "baseScore": 9.8}}],
+ "cwe": {"id": "CWE-79"}
+ }
+ ]
+ }
+
+ result = red_hat_advisory_scraper(csaf)
+
+ # Advisory should be filtered out (return None) because all products are EUS
+ self.assertIsNone(result)
\ No newline at end of file
diff --git a/scripts/generate_rocky_config.py b/scripts/generate_rocky_config.py
index 1ae0438..1aafea8 100644
--- a/scripts/generate_rocky_config.py
+++ b/scripts/generate_rocky_config.py
@@ -602,7 +602,7 @@ def parse_repomd_path(repomd_url: str, base_url: str) -> Dict[str, str]:
def build_mirror_config(
- version: str, arch: str, name_suffix: Optional[str] = None
+ version: str, arch: str, name_suffix: Optional[str] = None, mirror_name_base: Optional[str] = None
) -> Dict[str, Any]:
"""
Build a mirror configuration dictionary.
@@ -611,15 +611,19 @@ def build_mirror_config(
version: Rocky Linux version
arch: Architecture
name_suffix: Optional suffix for mirror name
+ mirror_name_base: Optional custom base for mirror name (e.g., "Rocky Linux 9")
Returns:
Mirror configuration dictionary
"""
- # Build mirror name with optional suffix
- if name_suffix is not None and name_suffix != "":
- mirror_name = f"Rocky Linux {version} {name_suffix} {arch}"
+ # Build mirror name with optional custom base or suffix
+ if not mirror_name_base:
+ mirror_name_base = f"Rocky Linux {version}"
+
+ if name_suffix:
+ mirror_name = f"{mirror_name_base} {name_suffix} {arch}"
else:
- mirror_name = f"Rocky Linux {version} {arch}"
+ mirror_name = f"{mirror_name_base} {arch}"
# Parse version to extract major and minor components
if version != UNKNOWN_VALUE and "." in version:
@@ -690,6 +694,7 @@ def generate_rocky_config(
include_source: bool = True,
architectures: List[str] = None,
name_suffix: Optional[str] = None,
+ mirror_name_base: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Generate Rocky Linux configuration by discovering repository structure.
@@ -702,6 +707,7 @@ def generate_rocky_config(
include_source: Whether to include source repository URLs (default: True)
architectures: List of architectures to include (default: auto-detect)
name_suffix: Optional suffix to add to mirror names (e.g., "test", "staging")
+ mirror_name_base: Optional custom base for mirror name (e.g., "Rocky Linux 9")
Returns:
List of configuration dictionaries ready for JSON export
@@ -730,10 +736,12 @@ def generate_rocky_config(
continue
# Skip if version filter specified and doesn't match
+ # Supports both exact version match (e.g., "9.5") and major version match (e.g., "9")
if (
version
and metadata["version"] != version
and metadata["version"] != UNKNOWN_VALUE
+ and metadata["version"].split(".")[0] != version
):
continue
@@ -773,7 +781,7 @@ def generate_rocky_config(
if not detected_version:
detected_version = UNKNOWN_VALUE
- mirror_config = build_mirror_config(detected_version, arch, name_suffix)
+ mirror_config = build_mirror_config(detected_version, arch, name_suffix, mirror_name_base)
# Group repos by name and type
repo_groups = {}
@@ -828,6 +836,8 @@ def main():
%(prog)s https://mirror.example.com/pub/rocky/ --output rocky_config.json
%(prog)s https://mirror.example.com/pub/rocky/ --name-suffix test --version 9.6
%(prog)s https://staging.example.com/pub/rocky/ --name-suffix staging --arch riscv64
+ %(prog)s https://mirror.example.com/pub/rocky/ --mirror-name-base "Rocky Linux 9" --version 9.6
+ %(prog)s https://mirror.example.com/pub/rocky/ --mirror-name-base "Rocky Linux 9 (Legacy)" --version 9
""",
)
@@ -880,6 +890,11 @@ def main():
help="Optional suffix to add to mirror names (e.g., 'test', 'staging')",
)
+ parser.add_argument(
+ "--mirror-name-base",
+ help="Optional custom base for mirror name (e.g., 'Rocky Linux 9' instead of 'Rocky Linux 9.6')",
+ )
+
parser.add_argument("--output", "-o", help="Output file path (default: stdout)")
parser.add_argument(
@@ -926,6 +941,7 @@ def main():
include_source=not args.no_source,
architectures=args.arch,
name_suffix=args.name_suffix,
+ mirror_name_base=args.mirror_name_base,
)
if not config:
|