diff --git a/.github/workflows/sdk-tests.yml b/.github/workflows/sdk-tests.yml index 0646e5f8b59..5009d379e14 100644 --- a/.github/workflows/sdk-tests.yml +++ b/.github/workflows/sdk-tests.yml @@ -516,6 +516,32 @@ jobs: flags: prowler-py${{ matrix.python-version }}-vercel files: ./vercel_coverage.xml + # External Provider (dynamic loading) + - name: Check if External Provider files changed + if: steps.check-changes.outputs.any_changed == 'true' + id: changed-external + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + with: + files: | + ./prowler/providers/common/** + ./prowler/config/** + ./prowler/lib/** + ./tests/providers/external/** + ./poetry.lock + + - name: Run External Provider tests + if: steps.changed-external.outputs.any_changed == 'true' + run: poetry run pytest -n auto --cov=./prowler/providers/common --cov=./prowler/config --cov=./prowler/lib --cov-report=xml:external_coverage.xml tests/providers/external + + - name: Upload External Provider coverage to Codecov + if: steps.changed-external.outputs.any_changed == 'true' + uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de # v5.5.2 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + flags: prowler-py${{ matrix.python-version }}-external + files: ./external_coverage.xml + # Lib - name: Check if Lib files changed if: steps.check-changes.outputs.any_changed == 'true' diff --git a/prowler/CHANGELOG.md b/prowler/CHANGELOG.md index 3945d1621ab..5c059287dd6 100644 --- a/prowler/CHANGELOG.md +++ b/prowler/CHANGELOG.md @@ -6,6 +6,7 @@ All notable changes to the **Prowler SDK** are documented in this file. ### 🚀 Added +- Support for external/custom providers, checks, and compliance frameworks without modifying core code [(#10700)](https://github.com/prowler-cloud/prowler/pull/10700) - `bedrock_guardrails_configured` check for AWS provider [(#10844)](https://github.com/prowler-cloud/prowler/pull/10844) - Universal compliance pipeline integrated into the CLI: `--list-compliance` and `--list-compliance-requirements` show universal frameworks, and CSV plus OCSF outputs are generated for any framework declaring a `TableConfig` [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301) - ASD Essential Eight Maturity Model compliance framework for AWS (Maturity Level One, Nov 2023) [(#10808)](https://github.com/prowler-cloud/prowler/pull/10808) @@ -70,6 +71,10 @@ All notable changes to the **Prowler SDK** are documented in this file. - Google Workspace check reports now store the actual domain or account resource subject instead of `provider.identity` [(#10901)](https://github.com/prowler-cloud/prowler/pull/10901) - `entra_users_mfa_capable` evaluating disabled guest accounts; CIS 5.2.3.4 only targets enabled member users [(#10785)](https://github.com/prowler-cloud/prowler/pull/10785) +### 🐞 Fixed + +- `load_and_validate_config_file` now unwraps namespaced config for every built-in and external provider, and no longer leaks the full file as the provider's config when the file is namespaced [(#10700)](https://github.com/prowler-cloud/prowler/pull/10700) + --- ## [5.24.3] (Prowler v5.24.3) diff --git a/prowler/__main__.py b/prowler/__main__.py index e10c9c745b2..8fc7b5f6d24 100644 --- a/prowler/__main__.py +++ b/prowler/__main__.py @@ -10,7 +10,6 @@ from colorama import init as colorama_init from prowler.config.config import ( - EXTERNAL_TOOL_PROVIDERS, cloud_api_base_url, csv_file_suffix, get_available_compliance_frameworks, @@ -207,9 +206,10 @@ def prowler(): # We treat the compliance framework as another output format if compliance_framework: args.output_formats.extend(compliance_framework) - # If no input compliance framework, set all, unless a specific service or check is input - # Skip for IAC and LLM providers that don't use compliance frameworks - elif default_execution and provider not in ["iac", "llm"]: + # If no input compliance framework, set all, unless a specific service or check is input. + # Skip for tool-wrapper providers (iac, llm, image, and any external plug-in + # declaring `is_external_tool_provider = True`) — they don't use compliance frameworks. + elif default_execution and not Provider.is_tool_wrapper_provider(provider): args.output_formats.extend(get_available_compliance_frameworks(provider)) # Set Logger configuration @@ -247,7 +247,7 @@ def prowler(): universal_frameworks = {} # Skip compliance frameworks for external-tool providers - if provider not in EXTERNAL_TOOL_PROVIDERS: + if not Provider.is_tool_wrapper_provider(provider): bulk_compliance_frameworks = Compliance.get_bulk(provider) # Complete checks metadata with the compliance framework specification bulk_checks_metadata = update_checks_metadata_with_compliance( @@ -315,7 +315,7 @@ def prowler(): sys.exit() # Skip service and check loading for external-tool providers - if provider not in EXTERNAL_TOOL_PROVIDERS: + if not Provider.is_tool_wrapper_provider(provider): # Import custom checks from folder if checks_folder: custom_checks = parse_checks_from_folder(global_provider, checks_folder) @@ -426,6 +426,9 @@ def prowler(): output_options = VercelOutputOptions( args, bulk_checks_metadata, global_provider.identity ) + else: + # Dynamic fallback: any external/custom provider + output_options = global_provider.get_output_options(args, bulk_checks_metadata) # Run the quick inventory for the provider if available if hasattr(args, "quick_inventory") and args.quick_inventory: @@ -435,7 +438,7 @@ def prowler(): # Execute checks findings = [] - if provider in EXTERNAL_TOOL_PROVIDERS: + if Provider.is_tool_wrapper_provider(provider): # For external-tool providers, run the scan directly if provider == "llm": @@ -445,12 +448,19 @@ def streaming_callback(findings_batch): findings = global_provider.run_scan(streaming_callback=streaming_callback) else: - # Original behavior for IAC and Image - try: + if provider == "image": + try: + findings = global_provider.run() + except ImageBaseException as error: + logger.critical(f"{error}") + sys.exit(1) + else: + # IAC and external tool-wrapper providers registered via entry + # points. Unexpected failures propagate to the outer except + # Exception backstop further down in this file — keeping the + # branch free of an Image-specific catch that would otherwise + # mislead plug-in authors reading this code. findings = global_provider.run() - except ImageBaseException as error: - logger.critical(f"{error}") - sys.exit(1) # Note: External tool providers don't support granular progress tracking since # they run external tools as a black box and return all findings at once. # Progress tracking would just be 0% → 100%. @@ -1343,6 +1353,30 @@ def streaming_callback(findings_batch): ) generated_outputs["compliance"].append(generic_compliance) generic_compliance.batch_write_data_to_file() + else: + # Dynamic fallback: any external/custom provider + try: + global_provider.generate_compliance_output( + finding_outputs, + bulk_compliance_frameworks, + input_compliance_frameworks, + output_options, + generated_outputs, + ) + except NotImplementedError: + # Last resort: generic compliance + for compliance_name in input_compliance_frameworks: + filename = ( + f"{output_options.output_directory}/compliance/" + f"{output_options.output_filename}_{compliance_name}.csv" + ) + generic_compliance = GenericCompliance( + findings=finding_outputs, + compliance=bulk_compliance_frameworks[compliance_name], + file_path=filename, + ) + generated_outputs["compliance"].append(generic_compliance) + generic_compliance.batch_write_data_to_file() # AWS Security Hub Integration if provider == "aws": diff --git a/prowler/config/config.py b/prowler/config/config.py index aa81e811366..d651211801c 100644 --- a/prowler/config/config.py +++ b/prowler/config/config.py @@ -1,3 +1,4 @@ +import importlib.metadata import os import pathlib from datetime import datetime, timezone @@ -82,13 +83,38 @@ class Provider(str, Enum): actual_directory = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) +def _get_ep_compliance_dirs() -> dict: + """Discover compliance directories from entry points. Returns {provider: path}.""" + dirs = {} + for ep in importlib.metadata.entry_points(group="prowler.compliance"): + try: + module = ep.load() + if hasattr(module, "__path__"): + dirs[ep.name] = module.__path__[0] + elif hasattr(module, "__file__"): + dirs[ep.name] = os.path.dirname(module.__file__) + except Exception as error: + logger.warning( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + return dirs + + def get_available_compliance_frameworks(provider=None): available_compliance_frameworks = [] - providers = [p.value for p in Provider] + # Built-in compliance + compliance_base = f"{actual_directory}/../compliance" if provider: providers = [provider] - for current_provider in providers: - compliance_dir = f"{actual_directory}/../compliance/{current_provider}" + else: + # Scan compliance directory for all provider subdirectories + providers = [] + if os.path.isdir(compliance_base): + for entry in os.scandir(compliance_base): + if entry.is_dir(): + providers.append(entry.name) + for prov in providers: + compliance_dir = f"{compliance_base}/{prov}" if not os.path.isdir(compliance_dir): continue with os.scandir(compliance_dir) as files: @@ -97,7 +123,8 @@ def get_available_compliance_frameworks(provider=None): available_compliance_frameworks.append( file.name.removesuffix(".json") ) - # Also scan top-level compliance/ for multi-provider (universal) JSONs. + # Built-in multi-provider frameworks at top-level compliance/ directory. + # Placed before external entry points so built-ins win on name collisions. # When a specific provider was requested, only include the framework if it # declares support for that provider; otherwise include all universal frameworks. compliance_root = f"{actual_directory}/../compliance" @@ -114,6 +141,18 @@ def get_available_compliance_frameworks(provider=None): continue if name not in available_compliance_frameworks: available_compliance_frameworks.append(name) + # External compliance via entry points. + # Multi-provider support for external plug-ins is tracked in PROWLER-1444. + ep_dirs = _get_ep_compliance_dirs() + for prov, path in ep_dirs.items(): + if provider and prov != provider: + continue + if os.path.isdir(path): + for file in os.scandir(path): + if file.is_file() and file.name.endswith(".json"): + name = file.name.removesuffix(".json") + if name not in available_compliance_frameworks: + available_compliance_frameworks.append(name) return available_compliance_frameworks @@ -225,18 +264,26 @@ def load_and_validate_config_file(provider: str, config_file_path: str) -> dict: with open(config_file_path, "r", encoding=encoding_format_utf_8) as f: config_file = yaml.safe_load(f) - # Not to introduce a breaking change, allow the old format config file without any provider keys - # and a new format with a key for each provider to include their configuration values within. - if any( - key in config_file - for key in ["aws", "gcp", "azure", "kubernetes", "m365"] + # Namespaced format: each provider has its own top-level key. + # Works for every built-in and every external plugin without a hardcoded list. + # Flat legacy format is AWS-only (historical, pre-multicloud). We identify it + # by the absence of nested-dict top-level values (namespaced files always + # have dict values; the legacy AWS format only has primitives/lists). + if ( + isinstance(config_file, dict) + and provider in config_file + and isinstance(config_file[provider], dict) + ): + config = config_file.get(provider, {}) or {} + elif ( + isinstance(config_file, dict) + and config_file + and provider == "aws" + and not any(isinstance(v, dict) for v in config_file.values()) ): - config = config_file.get(provider, {}) + config = config_file else: - config = config_file if config_file else {} - # Not to break Azure, K8s and GCP does not support or use the old config format - if provider in ["azure", "gcp", "kubernetes", "m365"]: - config = {} + config = {} return config diff --git a/prowler/lib/check/check.py b/prowler/lib/check/check.py index b15cf8bfbe1..90956c79efc 100644 --- a/prowler/lib/check/check.py +++ b/prowler/lib/check/check.py @@ -1,4 +1,6 @@ import importlib +import importlib.metadata +import importlib.util import json import os import re @@ -19,6 +21,7 @@ from prowler.lib.logger import logger from prowler.lib.outputs.outputs import report from prowler.lib.utils.utils import open_file, parse_json_file, print_boxes +from prowler.providers.common.builtin import is_builtin_provider from prowler.providers.common.models import Audit_Metadata @@ -385,6 +388,45 @@ def import_check(check_path: str) -> ModuleType: return lib +def _resolve_check_module( + provider_type: str, service: str, check_name: str +) -> ModuleType: + """Resolve and import a check module. + + Built-in wins on CheckID collision. Plug-ins are first-class extenders + (they can add new checks under new CheckIDs) but cannot override + existing built-ins — a security tool prefers fail-loud predictability + over silent overrides. CheckMetadata.get_bulk() applies the same + precedence on the metadata side (first-write-wins) and emits a warning + when a plug-in tries to override, so the user knows their plug-in + duplicate is being ignored and can rename it. + + Gates the built-in branch on `is_builtin_provider(provider_type)` — + calling `find_spec` on `prowler.providers.{provider_type}.services...` + directly would propagate `ModuleNotFoundError` for external providers + (their parent package `prowler.providers.{provider_type}` does not + exist) instead of returning None. The leaf helper encapsulates the + safe lookup, so external providers go straight to entry points. For + built-ins we still use `find_spec` to distinguish "check doesn't + exist" from "check exists but failed to import" (broken transitive + dep, etc.). + """ + # Built-in first — built-in wins on CheckID collision + if is_builtin_provider(provider_type): + builtin_path = f"prowler.providers.{provider_type}.services.{service}.{check_name}.{check_name}" + if importlib.util.find_spec(builtin_path) is not None: + return import_check(builtin_path) + + # Entry point lookup — only consulted when the built-in truly doesn't exist + for ep in importlib.metadata.entry_points(group=f"prowler.checks.{provider_type}"): + if ep.name == check_name: + return importlib.import_module(ep.value) + + raise ModuleNotFoundError( + f"Check '{check_name}' not found for provider '{provider_type}'" + ) + + def run_fixer(check_findings: list) -> int: """ Run the fixer for the check if it exists and there are any FAIL findings @@ -525,9 +567,10 @@ def execute_checks( service = check_name.split("_")[0] try: try: - # Import check module - check_module_path = f"prowler.providers.{global_provider.type}.services.{service}.{check_name}.{check_name}" - lib = import_check(check_module_path) + # Import check module (built-in or entry point) + lib = _resolve_check_module( + global_provider.type, service, check_name + ) # Recover functions from check check_to_execute = getattr(lib, check_name) check = check_to_execute() @@ -605,9 +648,10 @@ def execute_checks( ) try: try: - # Import check module - check_module_path = f"prowler.providers.{global_provider.type}.services.{service}.{check_name}.{check_name}" - lib = import_check(check_module_path) + # Import check module (built-in or entry point) + lib = _resolve_check_module( + global_provider.type, service, check_name + ) # Recover functions from check check_to_execute = getattr(lib, check_name) check = check_to_execute() @@ -745,6 +789,9 @@ def execute( is_finding_muted_args["tenancy_id"] = ( global_provider.identity.tenancy_id ) + else: + # External/custom provider — delegate identity args + is_finding_muted_args = global_provider.get_mutelist_finding_args() for finding in check_findings: if global_provider.type == "cloudflare": is_finding_muted_args["account_id"] = finding.account_id diff --git a/prowler/lib/check/checks_loader.py b/prowler/lib/check/checks_loader.py index 9ef672df6b3..77840084ffd 100644 --- a/prowler/lib/check/checks_loader.py +++ b/prowler/lib/check/checks_loader.py @@ -2,10 +2,10 @@ from colorama import Fore, Style -from prowler.config.config import EXTERNAL_TOOL_PROVIDERS from prowler.lib.check.check import parse_checks_from_file from prowler.lib.check.compliance_models import Compliance from prowler.lib.check.models import CheckMetadata, Severity +from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider from prowler.lib.logger import logger @@ -26,8 +26,13 @@ def load_checks_to_execute( ) -> set: """Generate the list of checks to execute based on the cloud provider and the input arguments given""" try: - # Bypass check loading for providers that use external tools directly - if provider in EXTERNAL_TOOL_PROVIDERS: + # Bypass check loading for tool-wrapper providers — they delegate + # scanning to an external tool and have no checks to recover. + # Single source of truth across __main__, the CheckMetadata validators, + # check discovery and this loader, covering both built-in tool wrappers + # (iac/llm/image) and external plug-ins that declare + # `is_external_tool_provider = True` via the contract. + if is_tool_wrapper_provider(provider): return set() # Local subsets diff --git a/prowler/lib/check/compliance_models.py b/prowler/lib/check/compliance_models.py index 226af6ec36e..ef11986662d 100644 --- a/prowler/lib/check/compliance_models.py +++ b/prowler/lib/check/compliance_models.py @@ -1,3 +1,4 @@ +import importlib.metadata import json import os import sys @@ -434,26 +435,55 @@ def get_bulk(provider: str) -> dict: """Bulk load all compliance frameworks specification into a dict""" try: bulk_compliance_frameworks = {} + # Built-in compliance from prowler/compliance/{provider}/ available_compliance_framework_modules = list_compliance_modules() for compliance_framework in available_compliance_framework_modules: if provider in compliance_framework.name: compliance_specification_dir_path = ( f"{compliance_framework.module_finder.path}/{provider}" ) - # for compliance_framework in available_compliance_framework_modules: for filename in os.listdir(compliance_specification_dir_path): file_path = os.path.join( compliance_specification_dir_path, filename ) - # Check if it is a file and ti size is greater than 0 if os.path.isfile(file_path) and os.stat(file_path).st_size > 0: - # Open Compliance file in JSON - # cis_v1.4_aws.json --> cis_v1.4_aws compliance_framework_name = filename.split(".json")[0] - # Store the compliance info bulk_compliance_frameworks[compliance_framework_name] = ( load_compliance_framework(file_path) ) + + # External compliance via entry points + for ep in importlib.metadata.entry_points(group="prowler.compliance"): + if ep.name == provider: + try: + module = ep.load() + compliance_dir = ( + module.__path__[0] + if hasattr(module, "__path__") + else os.path.dirname(module.__file__) + ) + for filename in os.listdir(compliance_dir): + if filename.endswith(".json"): + file_path = os.path.join(compliance_dir, filename) + if ( + os.path.isfile(file_path) + and os.stat(file_path).st_size > 0 + ): + compliance_framework_name = filename.split(".json")[ + 0 + ] + if ( + compliance_framework_name + not in bulk_compliance_frameworks + ): + bulk_compliance_frameworks[ + compliance_framework_name + ] = load_compliance_framework(file_path) + except Exception as error: + logger.warning( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + except Exception as e: logger.error(f"{e.__class__.__name__}[{e.__traceback__.tb_lineno}] -- {e}") diff --git a/prowler/lib/check/models.py b/prowler/lib/check/models.py index f5f0ff87e81..5eb8ae50da1 100644 --- a/prowler/lib/check/models.py +++ b/prowler/lib/check/models.py @@ -11,10 +11,10 @@ from pydantic.v1 import BaseModel, Field, ValidationError, validator from pydantic.v1.error_wrappers import ErrorWrapper -from prowler.config.config import EXTERNAL_TOOL_PROVIDERS, Provider from prowler.lib.check.compliance_models import Compliance from prowler.lib.check.utils import recover_checks_from_provider from prowler.lib.logger import logger +from prowler.providers.common.provider import Provider as ProviderABC # Valid ResourceGroup values as defined in the RFC VALID_RESOURCE_GROUPS = frozenset( @@ -259,7 +259,7 @@ def valid_category(cls, value, values): # noqa: F841 ) if ( value_lower not in VALID_CATEGORIES - and values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS + and not ProviderABC.is_tool_wrapper_provider(values.get("Provider")) ): raise ValueError( f"Invalid category: '{value_lower}'. Must be one of: {', '.join(sorted(VALID_CATEGORIES))}." @@ -288,7 +288,9 @@ def validate_service_name(cls, service_name, values): # noqa: F841 raise ValueError("ServiceName must be a non-empty string") check_id = values.get("CheckID") - if check_id and values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if check_id and not ProviderABC.is_tool_wrapper_provider( + values.get("Provider") + ): service_from_check_id = check_id.split("_")[0] if service_name != service_from_check_id: raise ValueError( @@ -304,7 +306,9 @@ def valid_check_id(cls, check_id, values): # noqa: F841 if not check_id: raise ValueError("CheckID must be a non-empty string") - if check_id and values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if check_id and not ProviderABC.is_tool_wrapper_provider( + values.get("Provider") + ): if "-" in check_id: raise ValueError( f"CheckID {check_id} contains a hyphen, which is not allowed" @@ -313,8 +317,9 @@ def valid_check_id(cls, check_id, values): # noqa: F841 return check_id @validator("CheckTitle", pre=True, always=True) + @classmethod def validate_check_title(cls, check_title, values): # noqa: F841 - if values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if not ProviderABC.is_tool_wrapper_provider(values.get("Provider")): if len(check_title) > 150: raise ValueError( f"CheckTitle must not exceed 150 characters, got {len(check_title)} characters" @@ -326,14 +331,18 @@ def validate_check_title(cls, check_title, values): # noqa: F841 return check_title @validator("RelatedUrl", pre=True, always=True) + @classmethod def validate_related_url(cls, related_url, values): # noqa: F841 - if related_url and values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if related_url and not ProviderABC.is_tool_wrapper_provider( + values.get("Provider") + ): raise ValueError("RelatedUrl must be empty. This field is deprecated.") return related_url @validator("Remediation") + @classmethod def validate_recommendation_url(cls, remediation, values): # noqa: F841 - if values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if not ProviderABC.is_tool_wrapper_provider(values.get("Provider")): url = remediation.Recommendation.Url if url and not url.startswith("https://hub.prowler.com/"): raise ValueError( @@ -346,7 +355,7 @@ def validate_check_type(cls, check_type, values): # noqa: F841 provider = values.get("Provider", "").lower() # Non-AWS providers must have an empty CheckType list - if provider != "aws" and provider not in EXTERNAL_TOOL_PROVIDERS: + if provider != "aws" and not ProviderABC.is_tool_wrapper_provider(provider): if check_type: raise ValueError( f"CheckType must be empty for non-AWS providers. Got {check_type} for provider '{provider}'." @@ -371,8 +380,9 @@ def validate_check_type(cls, check_type, values): # noqa: F841 return check_type @validator("Description", pre=True, always=True) + @classmethod def validate_description(cls, description, values): # noqa: F841 - if values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if not ProviderABC.is_tool_wrapper_provider(values.get("Provider")): if len(description) > 400: raise ValueError( f"Description must not exceed 400 characters, got {len(description)} characters" @@ -380,8 +390,9 @@ def validate_description(cls, description, values): # noqa: F841 return description @validator("Risk", pre=True, always=True) + @classmethod def validate_risk(cls, risk, values): # noqa: F841 - if values.get("Provider") not in EXTERNAL_TOOL_PROVIDERS: + if not ProviderABC.is_tool_wrapper_provider(values.get("Provider")): if len(risk) > 400: raise ValueError( f"Risk must not exceed 400 characters, got {len(risk)} characters" @@ -433,6 +444,20 @@ def get_bulk(provider: str) -> dict[str, "CheckMetadata"]: metadata_file = f"{check_path}/{check_name}.metadata.json" # Load metadata check_metadata = load_check_metadata(metadata_file) + # Built-in wins on CheckID collision. Plug-in entry points are + # appended after built-ins by `recover_checks_from_provider`, so + # a duplicate CheckID here means an entry-point check is trying + # to override a built-in. Ignore the override (the built-in + # metadata stays) and surface it via a warning — matching the + # precedence enforced by `_resolve_check_module`. + if check_metadata.CheckID in bulk_check_metadata: + logger.warning( + f"Plug-in check metadata '{check_metadata.CheckID}' " + f"(loaded from '{metadata_file}') is being IGNORED — " + f"a built-in with the same CheckID exists. To use your " + f"plug-in, register it under a different CheckID." + ) + continue bulk_check_metadata[check_metadata.CheckID] = check_metadata return bulk_check_metadata @@ -470,7 +495,7 @@ def list( # If the bulk checks metadata is not provided, get it if not bulk_checks_metadata: bulk_checks_metadata = {} - available_providers = [p.value for p in Provider] + available_providers = ProviderABC.get_available_providers() for provider_name in available_providers: bulk_checks_metadata.update(CheckMetadata.get_bulk(provider_name)) if provider: @@ -495,7 +520,7 @@ def list( # Loaded here, as it is not always needed if not bulk_compliance_frameworks: bulk_compliance_frameworks = {} - available_providers = [p.value for p in Provider] + available_providers = ProviderABC.get_available_providers() for provider in available_providers: bulk_compliance_frameworks = Compliance.get_bulk(provider=provider) checks_from_compliance_framework = ( diff --git a/prowler/lib/check/tool_wrapper.py b/prowler/lib/check/tool_wrapper.py new file mode 100644 index 00000000000..f00afd35d60 --- /dev/null +++ b/prowler/lib/check/tool_wrapper.py @@ -0,0 +1,57 @@ +"""Standalone helper for tool-wrapper provider detection. + +A provider is a "tool wrapper" if it delegates scanning to an external tool +(Trivy, promptfoo, etc.) instead of running checks/services through the +standard Prowler engine. This module is the single source of truth for that +classification across the codebase. + +Kept as a leaf module with no Prowler imports beyond the leaf +`external_tool_providers` so it can be referenced from `prowler.lib.check.*` +and `prowler.providers.common.provider` without forming an import cycle. +""" + +import importlib.metadata + +from prowler.lib.check.external_tool_providers import EXTERNAL_TOOL_PROVIDERS + +# Module-level cache for entry-point classes consulted by this helper. +# Independent of `Provider._ep_providers` to keep this module leaf — the cost +# of a duplicate cache entry is negligible (one class object per external +# provider, loaded lazily on first lookup). +_ep_class_cache: dict = {} + + +def _load_ep_class(provider: str): + """Return the entry-point provider class for `provider`, or None. + + Caches the result in `_ep_class_cache`. Errors during entry-point loading + are swallowed (returning None) so a broken plug-in never crashes the + is-tool-wrapper check; it just falls through to "not a tool wrapper". + """ + if provider in _ep_class_cache: + return _ep_class_cache[provider] + for ep in importlib.metadata.entry_points(group="prowler.providers"): + if ep.name == provider: + try: + cls = ep.load() + except Exception: + cls = None + _ep_class_cache[provider] = cls + return cls + _ep_class_cache[provider] = None + return None + + +def is_tool_wrapper_provider(provider: str) -> bool: + """Return True if the provider delegates scanning to an external tool. + + Combines the built-in `EXTERNAL_TOOL_PROVIDERS` frozenset (fast path for + iac/llm/image) with the `is_external_tool_provider` class attribute of + external plug-ins registered via entry points. This is the single source + of truth consulted by `__main__`, the `CheckMetadata` validators, the + check-loading utilities, and the checks loader. + """ + if provider in EXTERNAL_TOOL_PROVIDERS: + return True + cls = _load_ep_class(provider) + return bool(cls and getattr(cls, "is_external_tool_provider", False)) diff --git a/prowler/lib/check/utils.py b/prowler/lib/check/utils.py index 0e4807078f6..9c9a9c05238 100644 --- a/prowler/lib/check/utils.py +++ b/prowler/lib/check/utils.py @@ -1,9 +1,43 @@ import importlib +import importlib.metadata +import importlib.util +import os import sys from pkgutil import walk_packages -from prowler.lib.check.external_tool_providers import EXTERNAL_TOOL_PROVIDERS +from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider from prowler.lib.logger import logger +from prowler.providers.common.builtin import is_builtin_provider + + +def _recover_ep_checks(provider: str, service: str = None) -> list[tuple]: + """Discover external checks registered via entry points for a provider. + + External plugins follow the same layout as built-ins: + `{plugin_root}.services.{service}.{check}.{check}` + + When `service` is provided, only entry points whose dotted path contains + `.services.{service}.` are included — mirroring how built-in discovery + filters by the `prowler.providers.{provider}.services.{service}` package. + + Uses find_spec to locate the check module without importing it, + avoiding service client initialization at discovery time. + """ + checks = [] + for ep in importlib.metadata.entry_points(group=f"prowler.checks.{provider}"): + try: + if service and f".services.{service}." not in ep.value: + continue + + spec = importlib.util.find_spec(ep.value) + if spec and spec.origin: + check_path = os.path.dirname(spec.origin) + checks.append((ep.name, check_path)) + except Exception as error: + logger.warning( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + return checks def recover_checks_from_provider( @@ -15,29 +49,55 @@ def recover_checks_from_provider( Returns a list of tuples with the following format (check_name, check_path) """ try: - # Bypass check loading for providers that use external tools directly - if provider in EXTERNAL_TOOL_PROVIDERS: + # Bypass check loading for tool-wrapper providers — they delegate + # scanning to an external tool and have no checks to recover. + # Single source of truth: combines the EXTERNAL_TOOL_PROVIDERS + # frozenset (built-ins) with the per-provider `is_external_tool_provider` + # class attribute (so external plug-ins opt in via the contract). + if is_tool_wrapper_provider(provider): return [] checks = [] - modules = list_modules(provider, service) - for module_name in modules: - # Format: "prowler.providers.{provider}.services.{service}.{check_name}.{check_name}" - check_module_name = module_name.name - # We need to exclude common shared libraries in services - if ( - check_module_name.count(".") == 6 - and ".lib." not in check_module_name - and (not check_module_name.endswith("_fixer") or include_fixers) - ): - check_path = module_name.module_finder.path - # Check name is the last part of the check_module_name - check_name = check_module_name.split(".")[-1] - check_info = (check_name, check_path) - checks.append(check_info) - except ModuleNotFoundError: - logger.critical(f"Service {service} was not found for the {provider} provider.") - sys.exit(1) + # Built-in checks from prowler.providers.{provider}.services. Gate + # the built-in branch on `is_builtin_provider(provider)` — calling + # `find_spec` directly on `prowler.providers.{provider}.services` + # would propagate `ModuleNotFoundError` when the parent package + # `prowler.providers.{provider}` does not exist (i.e. the provider + # is external), instead of returning None. The leaf helper + # encapsulates the safe lookup, so we only run the built-in + # discovery when the provider actually ships with the SDK; for + # external providers we go straight to entry points. + if is_builtin_provider(provider): + modules = list_modules(provider, service) + for module_name in modules: + # Format: "prowler.providers.{provider}.services.{service}.{check_name}.{check_name}" + check_module_name = module_name.name + # We need to exclude common shared libraries in services + if ( + check_module_name.count(".") == 6 + and ".lib." not in check_module_name + and (not check_module_name.endswith("_fixer") or include_fixers) + ): + check_path = module_name.module_finder.path + check_name = check_module_name.split(".")[-1] + check_info = (check_name, check_path) + checks.append(check_info) + + # External checks registered via entry points — always consulted, with + # optional service filter. Previously gated by `if not service:`, which + # prevented external providers from being usable with --service. + checks.extend(_recover_ep_checks(provider, service)) + + # A service was requested but nothing matched in either built-ins or + # entry points — surface this as a clear error instead of silently + # returning an empty list. + if service and not checks: + logger.critical( + f"Service '{service}' was not found for the '{provider}' provider " + f"(neither as a built-in nor via external entry points)." + ) + sys.exit(1) + except Exception as e: logger.critical(f"{e.__class__.__name__}[{e.__traceback__.tb_lineno}]: {e}") sys.exit(1) @@ -64,8 +124,9 @@ def recover_checks_from_service(service_list: list, provider: str) -> set: Returns a set of checks from the given services """ try: - # Bypass check loading for providers that use external tools directly - if provider in EXTERNAL_TOOL_PROVIDERS: + # Bypass check loading for tool-wrapper providers — symmetric with + # `recover_checks_from_provider` above, using the same source of truth. + if is_tool_wrapper_provider(provider): return set() checks = set() diff --git a/prowler/lib/cli/parser.py b/prowler/lib/cli/parser.py index 24aa02083db..fc75e9618d1 100644 --- a/prowler/lib/cli/parser.py +++ b/prowler/lib/cli/parser.py @@ -20,19 +20,58 @@ validate_provider_arguments, validate_sarif_usage, ) +from prowler.providers.common.provider import Provider class ProwlerArgumentParser: # Set the default parser def __init__(self): + # Discover any providers not in the hardcoded list below + # TODO - First step to support current providers and the new external provider implementation + known_providers = { + "aws", + "azure", + "gcp", + "kubernetes", + "m365", + "github", + "googleworkspace", + "cloudflare", + "oraclecloud", + "openstack", + "alibabacloud", + "iac", + "llm", + "image", + "nhn", + "mongodbatlas", + "vercel", + } + all_providers = set(Provider.get_available_providers()) + new_providers = sorted(all_providers - known_providers) + + # Build extra strings for dynamically discovered providers + extra_providers_csv = "" + extra_providers_text = "" + if new_providers: + providers_help = Provider.get_providers_help_text() + extra_providers_csv = "," + ",".join(new_providers) + extra_lines = [] + for name in new_providers: + help_text = providers_help.get(name, "") + if help_text: + extra_lines.append(f" {name:<20}{help_text}") + if extra_lines: + extra_providers_text = "\n" + "\n".join(extra_lines) + # CLI Arguments self.parser = argparse.ArgumentParser( prog="prowler", formatter_class=RawTextHelpFormatter, - usage="prowler [-h] [--version] {aws,azure,gcp,kubernetes,m365,github,googleworkspace,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack,vercel,dashboard,iac,image,llm} ...", - epilog=""" + usage=f"prowler [-h] [--version] {{aws,azure,gcp,kubernetes,m365,github,googleworkspace,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack,vercel,dashboard,iac,image,llm{extra_providers_csv}}} ...", + epilog=f""" Available Cloud Providers: - {aws,azure,gcp,kubernetes,m365,github,googleworkspace,iac,llm,image,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack,vercel} + {{aws,azure,gcp,kubernetes,m365,github,googleworkspace,nhn,mongodbatlas,oraclecloud,alibabacloud,cloudflare,openstack,vercel,dashboard,iac,image,llm{extra_providers_csv}}} aws AWS Provider azure Azure Provider gcp GCP Provider @@ -49,13 +88,13 @@ def __init__(self): image Container Image Provider nhn NHN Provider (Unofficial) mongodbatlas MongoDB Atlas Provider - vercel Vercel Provider + vercel Vercel Provider{extra_providers_text} Available components: dashboard Local dashboard To see the different available options on a specific component, run: - prowler {provider|dashboard} -h|--help + prowler {{provider|dashboard}} -h|--help Detailed documentation at https://docs.prowler.com """, @@ -114,8 +153,10 @@ def parse(self, args=None) -> argparse.Namespace: and (sys.argv[1] not in ("-v", "--version")) ): # Since the provider is always the second argument, we are checking if - # a flag, starting by "-", is supplied - if "-" in sys.argv[1]: + # a flag is supplied. Use startswith("-") instead of "in" to avoid + # matching external provider names that contain hyphens + # (e.g. "local-acme-snowflake"). + if sys.argv[1].startswith("-"): sys.argv = self.__set_default_provider__(sys.argv) # Provider aliases mapping diff --git a/prowler/lib/outputs/compliance/compliance.py b/prowler/lib/outputs/compliance/compliance.py index 7fe23305508..b3885b57c97 100644 --- a/prowler/lib/outputs/compliance/compliance.py +++ b/prowler/lib/outputs/compliance/compliance.py @@ -243,14 +243,32 @@ def display_compliance_table( compliance_overview, ) else: - get_generic_compliance_table( - findings, - bulk_checks_metadata, - compliance_framework, - output_filename, - output_directory, - compliance_overview, - ) + # Try provider-specific table first, fall back to generic + from prowler.providers.common.provider import Provider + + provider = Provider.get_global_provider() + handled = False + if provider is not None: + try: + handled = provider.display_compliance_table( + findings, + bulk_checks_metadata, + compliance_framework, + output_filename, + output_directory, + compliance_overview, + ) + except NotImplementedError: + handled = False + if not handled: + get_generic_compliance_table( + findings, + bulk_checks_metadata, + compliance_framework, + output_filename, + output_directory, + compliance_overview, + ) except Exception as error: logger.critical( f"{error.__class__.__name__}:{error.__traceback__.tb_lineno} -- {error}" diff --git a/prowler/lib/outputs/finding.py b/prowler/lib/outputs/finding.py index 95a634c85ff..e119d5ce153 100644 --- a/prowler/lib/outputs/finding.py +++ b/prowler/lib/outputs/finding.py @@ -474,6 +474,11 @@ def generate_output( check_output, "fixed_version", "" ) + else: + # Dynamic fallback: any external/custom provider + provider_data = provider.get_finding_output_data(check_output) + output_data.update(provider_data) + # check_output Unique ID # TODO: move this to a function # TODO: in Azure, GCP and K8s there are findings without resource_name diff --git a/prowler/lib/outputs/html/html.py b/prowler/lib/outputs/html/html.py index d547ea3b950..1cd097815bd 100644 --- a/prowler/lib/outputs/html/html.py +++ b/prowler/lib/outputs/html/html.py @@ -1417,11 +1417,13 @@ def get_assessment_summary(provider: Provider) -> str: # Azure_provider --> azure # Kubernetes_provider --> kubernetes - # Dynamically get the Provider quick inventory handler - provider_html_assessment_summary_function = ( - f"get_{provider.type}_assessment_summary" - ) - return getattr(HTML, provider_html_assessment_summary_function)(provider) + # Try static method first, fall back to provider method + method_name = f"get_{provider.type}_assessment_summary" + if hasattr(HTML, method_name): + return getattr(HTML, method_name)(provider) + else: + # Dynamic fallback: any external/custom provider + return provider.get_html_assessment_summary() except Exception as error: logger.error( f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}" diff --git a/prowler/lib/outputs/outputs.py b/prowler/lib/outputs/outputs.py index 1e2f6f2058d..7bb9a396da8 100644 --- a/prowler/lib/outputs/outputs.py +++ b/prowler/lib/outputs/outputs.py @@ -7,39 +7,46 @@ from prowler.lib.outputs.finding import Finding -def stdout_report(finding, color, verbose, status, fix): +def stdout_report(finding, color, verbose, status, fix, provider=None): if finding.check_metadata.Provider == "aws": details = finding.region - if finding.check_metadata.Provider == "azure": + elif finding.check_metadata.Provider == "azure": details = finding.location - if finding.check_metadata.Provider == "gcp": + elif finding.check_metadata.Provider == "gcp": details = finding.location.lower() - if finding.check_metadata.Provider == "kubernetes": + elif finding.check_metadata.Provider == "kubernetes": details = finding.namespace.lower() - if finding.check_metadata.Provider == "github": + elif finding.check_metadata.Provider == "github": details = finding.owner - if finding.check_metadata.Provider == "m365": + elif finding.check_metadata.Provider == "m365": details = finding.location - if finding.check_metadata.Provider == "mongodbatlas": + elif finding.check_metadata.Provider == "mongodbatlas": details = finding.location - if finding.check_metadata.Provider == "nhn": + elif finding.check_metadata.Provider == "nhn": details = finding.location - if finding.check_metadata.Provider == "llm": + elif finding.check_metadata.Provider == "llm": details = finding.check_metadata.CheckID - if finding.check_metadata.Provider == "iac": + elif finding.check_metadata.Provider == "iac": details = finding.check_metadata.CheckID - if finding.check_metadata.Provider == "oraclecloud": + elif finding.check_metadata.Provider == "oraclecloud": details = finding.region - if finding.check_metadata.Provider == "alibabacloud": + elif finding.check_metadata.Provider == "alibabacloud": details = finding.region - if finding.check_metadata.Provider == "openstack": + elif finding.check_metadata.Provider == "openstack": details = finding.region - if finding.check_metadata.Provider == "cloudflare": + elif finding.check_metadata.Provider == "cloudflare": details = finding.zone_name - if finding.check_metadata.Provider == "googleworkspace": + elif finding.check_metadata.Provider == "googleworkspace": details = finding.location - if finding.check_metadata.Provider == "vercel": + elif finding.check_metadata.Provider == "vercel": details = finding.region + else: + # Dynamic fallback: any external/custom provider + if provider is None: + from prowler.providers.common.provider import Provider + + provider = Provider.get_global_provider() + details = provider.get_stdout_detail(finding) if (verbose or fix) and (not status or finding.status in status): if finding.muted: @@ -59,12 +66,15 @@ def report(check_findings, provider, output_options): if hasattr(output_options, "verbose"): verbose = output_options.verbose if check_findings: - # TO-DO Generic Function if provider.type == "aws": check_findings.sort(key=lambda x: x.region) - - if provider.type == "azure": + elif provider.type == "azure": check_findings.sort(key=lambda x: x.subscription) + else: + # Dynamic fallback: any external/custom provider + sort_key = provider.get_finding_sort_key() + if sort_key and isinstance(sort_key, str): + check_findings.sort(key=lambda x: getattr(x, sort_key, "")) for finding in check_findings: # Print findings by stdout @@ -75,12 +85,16 @@ def report(check_findings, provider, output_options): if hasattr(output_options, "fixer"): fixer = output_options.fixer color = set_report_color(finding.status, finding.muted) + # Pass the local `provider` through so the dynamic else inside + # `stdout_report` does not have to consult the global singleton + # — defeating the whole purpose of the new parameter. stdout_report( finding, color, verbose, status, fixer, + provider=provider, ) else: # No service resources in the whole account diff --git a/prowler/lib/outputs/summary_table.py b/prowler/lib/outputs/summary_table.py index 1f30063c8f1..06d520ab327 100644 --- a/prowler/lib/outputs/summary_table.py +++ b/prowler/lib/outputs/summary_table.py @@ -108,6 +108,9 @@ def display_summary_table( ) else: audited_entities = provider.identity.username or "Personal Account" + else: + # Dynamic fallback: any external/custom provider + entity_type, audited_entities = provider.get_summary_entity() # Check if there are findings and that they are not all MANUAL if findings and not all(finding.status == "MANUAL" for finding in findings): diff --git a/prowler/lib/scan/scan.py b/prowler/lib/scan/scan.py index 2ce7263e2be..4bef660d33c 100644 --- a/prowler/lib/scan/scan.py +++ b/prowler/lib/scan/scan.py @@ -4,8 +4,8 @@ from typing import Generator from prowler.lib.check.check import ( + _resolve_check_module, execute, - import_check, list_services, update_audit_metadata, ) @@ -426,9 +426,14 @@ def scan( # Recover service from check name service = get_service_name_from_check_name(check_name) try: - # Import check module - check_module_path = f"prowler.providers.{self._provider.type}.services.{service}.{check_name}.{check_name}" - lib = import_check(check_module_path) + # Import check module (built-in or entry point) — + # delegates to `_resolve_check_module` so external + # providers registered via entry points are resolved + # correctly (their checks do not live under + # `prowler.providers.{type}.services...`). + lib = _resolve_check_module( + self._provider.type, service, check_name + ) # Recover functions from check check_to_execute = getattr(lib, check_name) check = check_to_execute() diff --git a/prowler/providers/common/arguments.py b/prowler/providers/common/arguments.py index 9745bab2cac..9e23274a903 100644 --- a/prowler/providers/common/arguments.py +++ b/prowler/providers/common/arguments.py @@ -16,18 +16,41 @@ def init_providers_parser(self): # We need to call the arguments parser for each provider providers = Provider.get_available_providers() for provider in providers: - try: - getattr( - import_module( - f"{providers_path}.{provider}.{provider_arguments_lib_path}" - ), - init_provider_arguments_function, - )(self) - except Exception as error: - logger.critical( - f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" - ) - sys.exit(1) + # Discriminate built-in vs external upfront via find_spec, so an + # ImportError from a transitive dependency missing inside a built-in + # arguments module surfaces clearly instead of being silently + # re-routed to the entry-point path (which only has external providers). + if Provider.is_builtin(provider): + try: + getattr( + import_module( + f"{providers_path}.{provider}.{provider_arguments_lib_path}" + ), + init_provider_arguments_function, + )(self) + except ImportError as e: + logger.critical( + f"Failed to load arguments for built-in provider '{provider}'. " + f"Missing dependency: {e}. " + f"Ensure all required dependencies are installed." + ) + logger.debug("Full traceback:", exc_info=True) + sys.exit(1) + except Exception as error: + logger.critical( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + sys.exit(1) + else: + # External provider — init_parser classmethod via entry point + cls = Provider._load_ep_provider(provider) + if cls and hasattr(cls, "init_parser"): + try: + cls.init_parser(self) + except Exception as error: + logger.warning( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) def validate_provider_arguments(arguments: Namespace) -> tuple[bool, str]: diff --git a/prowler/providers/common/builtin.py b/prowler/providers/common/builtin.py new file mode 100644 index 00000000000..d60b5483d94 --- /dev/null +++ b/prowler/providers/common/builtin.py @@ -0,0 +1,29 @@ +"""Leaf helper for built-in provider detection. + +Lives in its own module — with no imports back into `prowler.lib.check` — so +that callers in `prowler.lib.check.*` can ask "is this provider built-in?" +without creating an import cycle through `prowler.providers.common.provider` +(which transitively imports `prowler.config.config` and from there +`prowler.lib.check.compliance_models` / `prowler.lib.check.external_tool_providers`). + +Same rationale as `prowler.lib.check.tool_wrapper`: extracting the predicate +to a leaf module is the canonical way to break the cycle in this codebase. +""" + +import importlib.util + + +def is_builtin_provider(provider: str) -> bool: + """Return True if the provider's own package ships with the SDK. + + Wraps `importlib.util.find_spec` in `try/except (ImportError, ValueError)` + because `find_spec` propagates `ModuleNotFoundError` when a parent package + in the dotted path does not exist (instead of returning `None`). The + try/except is what makes the call safe for external providers, whose + package does not live under `prowler.providers.{provider}`. + """ + try: + spec = importlib.util.find_spec(f"prowler.providers.{provider}") + return spec is not None + except (ImportError, ValueError): + return False diff --git a/prowler/providers/common/provider.py b/prowler/providers/common/provider.py index b0486d94103..cea9f091121 100644 --- a/prowler/providers/common/provider.py +++ b/prowler/providers/common/provider.py @@ -1,4 +1,6 @@ import importlib +import importlib.metadata +import importlib.util import os import pkgutil import sys @@ -136,6 +138,108 @@ def get_checks_to_execute_by_audit_resources(self) -> set: """ return set() + # --- Dynamic provider contract methods (not @abstractmethod for incremental migration) --- + + _cli_help_text: str = "" + + @classmethod + def from_cli_args(cls, arguments: Namespace, fixer_config: dict) -> "Provider": + """Instantiate the provider from CLI arguments and return the instance. + + The caller wires the returned instance into the global provider slot + via Provider.set_global_provider(). Implementations that already call + set_global_provider(self) from __init__ are also supported — the call + site tolerates a None return in that case. + """ + raise NotImplementedError(f"{cls.__name__} has not implemented from_cli_args()") + + def get_output_options(self, arguments, _bulk_checks_metadata): + """Create the provider-specific OutputOptions.""" + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented get_output_options()" + ) + + def get_stdout_detail(self, _finding) -> str: + """Return the detail string for stdout reporting (region, location, etc.).""" + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented get_stdout_detail()" + ) + + def get_finding_sort_key(self) -> Optional[str]: + """Return the attribute name to sort findings by, or None for no sorting.""" + return None + + def get_summary_entity(self) -> tuple: + """Return (entity_type, audited_entities) for the summary table.""" + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented get_summary_entity()" + ) + + def get_finding_output_data(self, _check_output) -> dict: + """Return provider-specific fields for Finding.generate_output().""" + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented get_finding_output_data()" + ) + + def get_html_assessment_summary(self) -> str: + """Return the HTML assessment summary card for this provider.""" + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented get_html_assessment_summary()" + ) + + def generate_compliance_output( + self, + _findings, + _bulk_compliance_frameworks, + _input_compliance_frameworks, + _output_options, + _generated_outputs, + ) -> None: + """Generate compliance CSV output for this provider's frameworks.""" + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented generate_compliance_output()" + ) + + def get_mutelist_finding_args(self) -> dict: + """Return extra kwargs for mutelist.is_finding_muted() besides 'finding'. + + External providers must return a dict with the identity key their + Mutelist subclass expects, e.g. ``{"account_id": self.identity.account_id}``. + The ``finding`` kwarg is added automatically by the caller. + """ + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented get_mutelist_finding_args()" + ) + + def display_compliance_table( + self, + _findings: list, + _bulk_checks_metadata: dict, + _compliance_framework: str, + _output_filename: str, + _output_directory: str, + _compliance_overview: bool, + ) -> bool: + """Render a custom compliance table in the terminal. + + External providers can override this to display a detailed + compliance table (e.g., per-section breakdown). Return True + if the table was rendered, False to fall back to the generic table. + """ + raise NotImplementedError( + f"{self.__class__.__name__} has not implemented display_compliance_table()" + ) + + # Class-level flag: True for providers that delegate scanning to an external + # tool (e.g. Trivy, promptfoo) and bypass standard check/service loading and + # metadata validation. Subclasses override as `is_external_tool_provider = True`. + # Kept as a class attribute (not a property) so it can be read from the class + # without instantiation — the metadata validators in lib.check.models need to + # decide whether to relax validation before any provider instance exists. + is_external_tool_provider: bool = False + + # --- End dynamic provider contract methods --- + @staticmethod def get_excluded_regions_from_env() -> set: """Parse the PROWLER_AWS_DISALLOWED_REGIONS environment variable. @@ -159,20 +263,70 @@ def set_global_provider(global_provider: "Provider") -> None: @staticmethod def init_global_provider(arguments: Namespace) -> None: try: - provider_class_path = ( - f"{providers_path}.{arguments.provider}.{arguments.provider}_provider" - ) - provider_class_name = f"{arguments.provider.capitalize()}Provider" - provider_class = getattr( - import_module(provider_class_path), provider_class_name + # Discriminate built-in vs external upfront via find_spec, so an + # ImportError from a transitive dependency missing inside a + # built-in's own import chain surfaces clearly instead of being + # silently re-routed to the entry-point path. + provider_class = None + if Provider.is_builtin(arguments.provider): + # Built-in wins on provider-name collision. Plug-ins are + # first-class extenders (they can register new provider + # names) but cannot override existing built-ins — a security + # tool prefers fail-loud predictability over silent + # overrides. Surface the override so the user knows their + # plug-in is being ignored and can rename it. + if Provider._load_ep_provider(arguments.provider) is not None: + logger.warning( + f"Plug-in provider '{arguments.provider}' registered " + f"via entry points is being IGNORED — a built-in with " + f"the same name exists. To use your plug-in, register " + f"it under a different name." + ) + provider_class_path = f"{providers_path}.{arguments.provider}.{arguments.provider}_provider" + provider_class_name = f"{arguments.provider.capitalize()}Provider" + try: + provider_class = getattr( + import_module(provider_class_path), provider_class_name + ) + except ImportError as e: + logger.critical( + f"Failed to load built-in provider '{arguments.provider}'. " + f"Missing dependency: {e}. " + f"Ensure all required dependencies are installed." + ) + logger.debug("Full traceback:", exc_info=True) + sys.exit(1) + except AttributeError: + # Module exists but doesn't define the expected class — + # treat as external and try entry points. + provider_class = Provider._load_ep_provider(arguments.provider) + else: + provider_class = Provider._load_ep_provider(arguments.provider) + + if provider_class is None: + raise ImportError( + f"Provider '{arguments.provider}' not found as built-in or entry point" + ) + + # Kept for downstream forks that may extend the dispatch below + # with their own custom built-in branches and reference this name. + # The upstream chain dispatches by `arguments.provider` directly. + provider_class_name = ( + f"{arguments.provider.capitalize()}Provider" # noqa: F841 ) fixer_config = load_and_validate_config_file( arguments.provider, arguments.fixer_config ) + # Dispatch by exact provider name (equality, not substring) so + # external plug-ins whose names contain a built-in substring + # (e.g. `awsx`, `azure_gov`, `iac_v2`) cannot be silently routed + # to the wrong built-in branch. Anything that doesn't match a + # built-in falls through to the dynamic else and uses the + # contract's `from_cli_args`. if not isinstance(Provider._global, provider_class): - if "aws" in provider_class_name.lower(): + if arguments.provider == "aws": excluded_regions = ( set(arguments.excluded_region) if getattr(arguments, "excluded_region", None) @@ -196,7 +350,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "azure" in provider_class_name.lower(): + elif arguments.provider == "azure": provider_class( az_cli_auth=arguments.az_cli_auth, sp_env_auth=arguments.sp_env_auth, @@ -209,7 +363,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "gcp" in provider_class_name.lower(): + elif arguments.provider == "gcp": provider_class( retries_max_attempts=arguments.gcp_retries_max_attempts, organization_id=arguments.organization_id, @@ -223,7 +377,7 @@ def init_global_provider(arguments: Namespace) -> None: fixer_config=fixer_config, skip_api_check=arguments.skip_api_check, ) - elif "kubernetes" in provider_class_name.lower(): + elif arguments.provider == "kubernetes": provider_class( kubeconfig_file=arguments.kubeconfig_file, context=arguments.context, @@ -233,7 +387,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "m365" in provider_class_name.lower(): + elif arguments.provider == "m365": provider_class( region=arguments.region, config_path=arguments.config_file, @@ -247,7 +401,7 @@ def init_global_provider(arguments: Namespace) -> None: init_modules=arguments.init_modules, fixer_config=fixer_config, ) - elif "nhn" in provider_class_name.lower(): + elif arguments.provider == "nhn": provider_class( username=arguments.nhn_username, password=arguments.nhn_password, @@ -256,7 +410,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "github" in provider_class_name.lower(): + elif arguments.provider == "github": orgs = [] repos = [] @@ -288,13 +442,13 @@ def init_global_provider(arguments: Namespace) -> None: exclude_workflows=getattr(arguments, "exclude_workflows", []), fixer_config=fixer_config, ) - elif "googleworkspace" in provider_class_name.lower(): + elif arguments.provider == "googleworkspace": provider_class( config_path=arguments.config_file, mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "cloudflare" in provider_class_name.lower(): + elif arguments.provider == "cloudflare": provider_class( filter_zones=arguments.region, filter_accounts=arguments.account_id, @@ -302,7 +456,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "iac" in provider_class_name.lower(): + elif arguments.provider == "iac": provider_class( scan_path=arguments.scan_path, scan_repository_url=arguments.scan_repository_url, @@ -315,13 +469,13 @@ def init_global_provider(arguments: Namespace) -> None: oauth_app_token=arguments.oauth_app_token, provider_uid=arguments.provider_uid, ) - elif "llm" in provider_class_name.lower(): + elif arguments.provider == "llm": provider_class( max_concurrency=arguments.max_concurrency, config_path=arguments.config_file, fixer_config=fixer_config, ) - elif "image" in provider_class_name.lower(): + elif arguments.provider == "image": provider_class( images=arguments.images, image_list_file=arguments.image_list_file, @@ -339,7 +493,7 @@ def init_global_provider(arguments: Namespace) -> None: registry_insecure=arguments.registry_insecure, registry_list_images=arguments.registry_list_images, ) - elif "mongodbatlas" in provider_class_name.lower(): + elif arguments.provider == "mongodbatlas": provider_class( atlas_public_key=arguments.atlas_public_key, atlas_private_key=arguments.atlas_private_key, @@ -348,7 +502,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "oraclecloud" in provider_class_name.lower(): + elif arguments.provider == "oraclecloud": provider_class( oci_config_file=arguments.oci_config_file, profile=arguments.profile, @@ -359,7 +513,7 @@ def init_global_provider(arguments: Namespace) -> None: fixer_config=fixer_config, use_instance_principal=arguments.use_instance_principal, ) - elif "openstack" in provider_class_name.lower(): + elif arguments.provider == "openstack": provider_class( clouds_yaml_file=getattr(arguments, "clouds_yaml_file", None), clouds_yaml_content=getattr( @@ -384,7 +538,7 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "alibabacloud" in provider_class_name.lower(): + elif arguments.provider == "alibabacloud": provider_class( role_arn=arguments.role_arn, role_session_name=arguments.role_session_name, @@ -396,13 +550,25 @@ def init_global_provider(arguments: Namespace) -> None: mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) - elif "vercel" in provider_class_name.lower(): + elif arguments.provider == "vercel": provider_class( projects=getattr(arguments, "project", None), config_path=arguments.config_file, mutelist_path=arguments.mutelist_file, fixer_config=fixer_config, ) + else: + # Dynamic fallback: any external/custom provider. + # Honor the from_cli_args type hint (-> Provider): if the + # implementation returns an instance, wire it as the global + # provider here. Implementations that call + # set_global_provider(self) from __init__ return None and + # remain supported (the condition below is a no-op for them). + provider_instance = provider_class.from_cli_args( + arguments, fixer_config + ) + if provider_instance is not None: + Provider.set_global_provider(provider_instance) except TypeError as error: logger.critical( @@ -415,17 +581,102 @@ def init_global_provider(arguments: Namespace) -> None: ) sys.exit(1) + # Cache for entry-point provider classes {name: class} + _ep_providers: dict = {} + @staticmethod def get_available_providers() -> list[str]: """get_available_providers returns a list of the available providers""" - providers = [] - # Dynamically import the package based on its string path + providers = set() + # Built-in providers from local package prowler_providers = importlib.import_module(providers_path) - # Iterate over all modules found in the prowler_providers package for _, provider, ispkg in pkgutil.iter_modules(prowler_providers.__path__): if provider != "common" and ispkg: - providers.append(provider) - return providers + providers.add(provider) + # External providers registered via entry points + for ep in importlib.metadata.entry_points(group="prowler.providers"): + providers.add(ep.name) + return sorted(providers) + + @staticmethod + def is_tool_wrapper_provider(provider: str) -> bool: + """Return True if the provider delegates scanning to an external tool. + + Delegates to `prowler.lib.check.tool_wrapper.is_tool_wrapper_provider`, + the leaf module that holds the actual logic. Kept on `Provider` as a + convenience entry point for callers that already import `Provider`. + """ + from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider as _impl + + return _impl(provider) + + @staticmethod + def is_builtin(provider: str) -> bool: + """Return True if the provider's own package is importable as a built-in. + + Delegates to `prowler.providers.common.builtin.is_builtin_provider`, + the leaf module that holds the actual check. Kept on `Provider` as a + convenience entry point for callers that already import `Provider`. + Call sites in `prowler.lib.check.*` should import from the leaf + directly to avoid the import cycle through this module. + """ + from prowler.providers.common.builtin import is_builtin_provider as _impl + + return _impl(provider) + + @staticmethod + def _load_ep_provider(name: str): + """Load an external provider class from entry points, with cache. + + Caches both hits and misses so repeated lookups for unknown names do + not re-iterate entry_points(). Symmetric with + tool_wrapper._ep_class_cache. + """ + if name in Provider._ep_providers: + return Provider._ep_providers[name] + for ep in importlib.metadata.entry_points(group="prowler.providers"): + if ep.name == name: + try: + cls = ep.load() + Provider._ep_providers[name] = cls + return cls + except Exception as error: + logger.warning( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + Provider._ep_providers[name] = None + return None + + @staticmethod + def get_providers_help_text() -> dict: + """Returns a dict of {provider_name: cli_help_text} for all available providers.""" + help_text = {} + for name in Provider.get_available_providers(): + try: + # Try built-in first + module_path = f"{providers_path}.{name}.{name}_provider" + module = import_module(module_path) + cls = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if ( + isinstance(attr, type) + and issubclass(attr, Provider) + and attr is not Provider + ): + cls = attr + break + help_text[name] = getattr(cls, "_cli_help_text", "") if cls else "" + except ImportError: + # External provider — load via entry point + cls = Provider._load_ep_provider(name) + help_text[name] = getattr(cls, "_cli_help_text", "") if cls else "" + except Exception as error: + logger.warning( + f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}" + ) + help_text[name] = "" + return help_text @staticmethod def update_provider_config(audit_config: dict, variable: str, value: str): diff --git a/tests/config/config_test.py b/tests/config/config_test.py index f7349d20045..0fe00d792f0 100644 --- a/tests/config/config_test.py +++ b/tests/config/config_test.py @@ -17,7 +17,7 @@ MOCK_PROWLER_MASTER_VERSION = "3.4.0" -def mock_prowler_get_latest_release(_, **kwargs): +def mock_prowler_get_latest_release(_, **_kwargs): """Mock requests.get() to get the Prowler latest release""" response = Response() response._content = b'[{"name":"3.3.0"}]' @@ -463,6 +463,32 @@ def test_get_available_compliance_frameworks_does_not_mutate_provider_param(self all_frameworks = get_available_compliance_frameworks() assert "csa_ccm_4.0" in all_frameworks + @mock.patch("prowler.config.config._get_ep_compliance_dirs") + def test_get_available_compliance_frameworks_dedupes_ep_collisions_with_builtins( + self, mock_dirs + ): + """Entry-point compliance frameworks that collide with a built-in + name must appear only once in the available frameworks list. + Built-in wins silently — same policy as the universal frameworks + loop and as Compliance.get_bulk.""" + import json + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + # cis_2.0_aws ships as a built-in under prowler/compliance/aws/ + json_path = os.path.join(tmpdir, "cis_2.0_aws.json") + with open(json_path, "w") as f: + json.dump({"Framework": "CIS", "Provider": "aws"}, f) + + mock_dirs.return_value = {"aws": tmpdir} + + frameworks = get_available_compliance_frameworks("aws") + + assert frameworks.count("cis_2.0_aws") == 1, ( + f"Expected cis_2.0_aws to appear exactly once, got " + f"{frameworks.count('cis_2.0_aws')} occurrences in: {frameworks}" + ) + def test_load_and_validate_config_file_aws(self): path = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) config_test_file = f"{path}/fixtures/config.yaml" @@ -500,6 +526,32 @@ def test_load_and_validate_config_file_old_format(self): assert load_and_validate_config_file("azure", config_test_file) == {} assert load_and_validate_config_file("kubernetes", config_test_file) == {} + def test_load_and_validate_config_file_namespaced_non_listed_provider(self): + path = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) + config_test_file = f"{path}/fixtures/config_namespaced_external.yaml" + # github is a built-in not in the legacy hardcoded list; namespaced format must unwrap it. + assert load_and_validate_config_file("github", config_test_file) == { + "token": "abc", + "org": "prowler-cloud", + } + + def test_load_and_validate_config_file_namespaced_external_provider(self): + path = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) + config_test_file = f"{path}/fixtures/config_namespaced_external.yaml" + # External plug-in provider: namespaced format must unwrap its block. + assert load_and_validate_config_file("custom_plugin", config_test_file) == { + "setting": "value", + "nested": {"key": 42}, + } + + def test_load_and_validate_config_file_namespaced_missing_provider(self): + path = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) + config_test_file = f"{path}/fixtures/config_namespaced_external.yaml" + # Provider with no section in a namespaced file must return empty config, + # not the full file (prevents cross-provider config leakage). + assert load_and_validate_config_file("aws", config_test_file) == {} + assert load_and_validate_config_file("gcp", config_test_file) == {} + def test_load_and_validate_config_file_invalid_config_file_path(self, caplog): provider = "aws" config_file_path = "invalid/path/to/fixer_config.yaml" diff --git a/tests/config/fixtures/config_namespaced_external.yaml b/tests/config/fixtures/config_namespaced_external.yaml new file mode 100644 index 00000000000..ec9f75c698a --- /dev/null +++ b/tests/config/fixtures/config_namespaced_external.yaml @@ -0,0 +1,8 @@ +# Namespaced config covering a non-listed built-in (github) and an external plugin. +github: + token: abc + org: prowler-cloud +custom_plugin: + setting: value + nested: + key: 42 diff --git a/tests/lib/check/models_test.py b/tests/lib/check/models_test.py index 71fd1f718b5..f17e359441c 100644 --- a/tests/lib/check/models_test.py +++ b/tests/lib/check/models_test.py @@ -95,6 +95,38 @@ def test_get_bulk(self, mock_recover_checks, mock_load_metadata): "/path/to/accessanalyzer_enabled/accessanalyzer_enabled.metadata.json" ) + @mock.patch("prowler.lib.check.models.logger") + @mock.patch("prowler.lib.check.models.load_check_metadata") + @mock.patch("prowler.lib.check.models.recover_checks_from_provider") + def test_get_bulk_builtin_wins_on_check_id_collision( + self, mock_recover_checks, mock_load_metadata, mock_logger + ): + """Regression guard: when an entry-point plug-in re-registers a + built-in CheckID, the BUILT-IN metadata wins (first-write-wins) and + the plug-in is IGNORED. The override is surfaced via a warning so + the user knows their plug-in duplicate is being skipped and can + rename it. Matches the precedence in `_resolve_check_module`. See + PR #10700 review (HugoPBrito).""" + # Built-in first, plug-in last (matches recover_checks_from_provider order) + mock_recover_checks.return_value = [ + ("accessanalyzer_enabled", "/builtin/accessanalyzer_enabled"), + ("accessanalyzer_enabled", "/plugin/accessanalyzer_enabled"), + ] + + builtin_metadata = mock.MagicMock(CheckID="accessanalyzer_enabled") + plugin_metadata = mock.MagicMock(CheckID="accessanalyzer_enabled") + mock_load_metadata.side_effect = [builtin_metadata, plugin_metadata] + + result = CheckMetadata.get_bulk(provider="aws") + + # Built-in wins (first-write-wins on CheckID), plug-in is ignored + assert result["accessanalyzer_enabled"] is builtin_metadata + # Override is surfaced via warning naming the plug-in metadata file + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args.args[0] + assert "accessanalyzer_enabled" in warning_msg + assert "/plugin/accessanalyzer_enabled" in warning_msg + @mock.patch("prowler.lib.check.models.load_check_metadata") @mock.patch("prowler.lib.check.models.recover_checks_from_provider") def test_list(self, mock_recover_checks, mock_load_metadata): diff --git a/tests/lib/check/tool_wrapper_test.py b/tests/lib/check/tool_wrapper_test.py new file mode 100644 index 00000000000..ed3b897d8db --- /dev/null +++ b/tests/lib/check/tool_wrapper_test.py @@ -0,0 +1,110 @@ +"""Unit tests for prowler.lib.check.tool_wrapper. + +Covers the leaf helper directly (Provider.is_tool_wrapper_provider delegates +to it). Tests the frozenset fast path, the entry-point fallback for external +plug-ins, the broken-plug-in path, the no-match path, and the module-level +cache. +""" + +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def _clear_ep_class_cache(): + """Reset the leaf module's cache between tests so they stay independent.""" + from prowler.lib.check import tool_wrapper + + tool_wrapper._ep_class_cache.clear() + yield + tool_wrapper._ep_class_cache.clear() + + +def _make_entry_point(name, cls): + """Create a mock entry point whose `load()` returns `cls`.""" + ep = MagicMock() + ep.name = name + ep.load.return_value = cls + return ep + + +class TestIsToolWrapperProvider: + """is_tool_wrapper_provider: frozenset + entry-point fallback.""" + + @pytest.mark.parametrize("name", ["iac", "llm", "image"]) + def test_returns_true_for_builtin_tool_wrappers(self, name): + from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider + + assert is_tool_wrapper_provider(name) is True + + @pytest.mark.parametrize("name", ["aws", "azure", "gcp", "github", "kubernetes"]) + def test_returns_false_for_regular_builtins(self, name): + from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider + + assert is_tool_wrapper_provider(name) is False + + @patch("prowler.lib.check.tool_wrapper.importlib.metadata.entry_points") + def test_returns_true_for_external_plugin_with_flag(self, mock_eps): + from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider + + cls = MagicMock(is_external_tool_provider=True) + mock_eps.return_value = [_make_entry_point("custom_wrapper", cls)] + + assert is_tool_wrapper_provider("custom_wrapper") is True + + @patch("prowler.lib.check.tool_wrapper.importlib.metadata.entry_points") + def test_returns_false_for_external_plugin_without_flag(self, mock_eps): + from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider + + cls = MagicMock(is_external_tool_provider=False) + mock_eps.return_value = [_make_entry_point("vanilla_external", cls)] + + assert is_tool_wrapper_provider("vanilla_external") is False + + @patch("prowler.lib.check.tool_wrapper.importlib.metadata.entry_points") + def test_returns_false_for_unknown_provider(self, mock_eps): + from prowler.lib.check.tool_wrapper import is_tool_wrapper_provider + + mock_eps.return_value = [] + + assert is_tool_wrapper_provider("does-not-exist") is False + + +class TestLoadEpClass: + """_load_ep_class: cache, broken plug-ins, no-match.""" + + @patch("prowler.lib.check.tool_wrapper.importlib.metadata.entry_points") + def test_caches_result_across_calls(self, mock_eps): + from prowler.lib.check.tool_wrapper import _load_ep_class + + cls = MagicMock(is_external_tool_provider=True) + mock_eps.return_value = [_make_entry_point("cached_one", cls)] + + first = _load_ep_class("cached_one") + second = _load_ep_class("cached_one") + + assert first is cls + assert second is cls + # entry_points consulted only on the first call + assert mock_eps.call_count == 1 + + @patch("prowler.lib.check.tool_wrapper.importlib.metadata.entry_points") + def test_returns_none_for_broken_plugin(self, mock_eps): + from prowler.lib.check.tool_wrapper import _load_ep_class + + broken_ep = MagicMock() + broken_ep.name = "broken" + broken_ep.load.side_effect = ImportError("plug-in is broken") + mock_eps.return_value = [broken_ep] + + assert _load_ep_class("broken") is None + + @patch("prowler.lib.check.tool_wrapper.importlib.metadata.entry_points") + def test_returns_none_when_no_entry_point_matches(self, mock_eps): + from prowler.lib.check.tool_wrapper import _load_ep_class + + cls = MagicMock() + mock_eps.return_value = [_make_entry_point("other_provider", cls)] + + assert _load_ep_class("missing_provider") is None diff --git a/tests/lib/scan/scan_test.py b/tests/lib/scan/scan_test.py index b0fe82b7b20..8037668b105 100644 --- a/tests/lib/scan/scan_test.py +++ b/tests/lib/scan/scan_test.py @@ -51,7 +51,7 @@ def mock_provider(): def mock_execute(): with mock.patch("prowler.lib.scan.scan.execute", autospec=True) as mock_exec: findings = [finding] - mock_exec.side_effect = lambda *args, **kwargs: findings + mock_exec.side_effect = lambda *_args, **_kwargs: findings yield mock_exec @@ -264,10 +264,10 @@ def test_init_with_no_checks( @patch("prowler.lib.scan.scan.update_checks_metadata_with_compliance") @patch("prowler.lib.scan.scan.Compliance.get_bulk") @patch("prowler.lib.scan.scan.CheckMetadata.get_bulk") - @patch("prowler.lib.scan.scan.import_check") + @patch("prowler.lib.scan.scan._resolve_check_module") def test_scan( self, - mock_import_check, + mock_resolve_check_module, mock_get_bulk, mock_compliance_get_bulk, mock_update_checks_metadata, @@ -285,7 +285,7 @@ def test_scan( mock_check_instance.CheckTitle = "Check if IAM Access Analyzer is enabled" mock_check_instance.Categories = [] - mock_import_check.return_value = MagicMock( + mock_resolve_check_module.return_value = MagicMock( accessanalyzer_enabled=mock_check_class ) diff --git a/tests/providers/external/__init__.py b/tests/providers/external/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/providers/external/test_dynamic_provider_loading.py b/tests/providers/external/test_dynamic_provider_loading.py new file mode 100644 index 00000000000..9f45648c94a --- /dev/null +++ b/tests/providers/external/test_dynamic_provider_loading.py @@ -0,0 +1,1909 @@ +""" +Tests for dynamic provider loading via entry points. + +Covers: provider discovery, check discovery, check execution, +CLI argument registration, compliance frameworks, parser integration, +and all dispatch fallbacks for external providers. +""" + +from argparse import Namespace +from unittest.mock import MagicMock, patch + +import pytest + +from prowler.providers.common.provider import Provider + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_entry_point(name, value, group): + """Create a mock entry point.""" + ep = MagicMock() + ep.name = name + ep.value = value + ep.group = group + return ep + + +class FakeExternalProvider(Provider): + """Minimal Provider subclass for testing the dynamic contract.""" + + _type = "fakeexternal" + _cli_help_text = "Fake External Provider" + + def __init__(self): + Provider.set_global_provider(self) + + @property + def type(self): + return self._type + + @property + def identity(self): + return MagicMock(host_id="fake-host-1") + + @property + def session(self): + return MagicMock() + + @property + def audit_config(self): + return {} + + def setup_session(self): + return MagicMock() + + def print_credentials(self): + pass + + @classmethod + def from_cli_args(cls, _arguments, _fixer_config): + cls() + + def get_output_options(self, _arguments, _bulk_checks_metadata): + return MagicMock(output_directory="/tmp", output_filename="fake") + + def get_stdout_detail(self, finding): + return "fake-detail" + + def get_finding_sort_key(self): + return "region" + + def get_summary_entity(self): + return ("Fake Host", "fake-host-1") + + def get_finding_output_data(self, check_output): + return { + "auth_method": "fake", + "account_uid": "fake-account", + "account_name": "fake", + "resource_name": "fake-resource", + "resource_uid": "fake-uid", + "region": "local", + } + + def get_mutelist_finding_args(self): + return {"host_id": self.identity.host_id} + + def display_compliance_table( + self, + findings, + _bulk_checks_metadata, + _compliance_framework, + _output_filename, + output_directory, # referenced via name elsewhere in tests + _compliance_overview, + ): + return True + + def get_html_assessment_summary(self): + return "