diff --git a/.github/actions/test-setup/action.yaml b/.github/actions/test-setup/action.yaml new file mode 100644 index 00000000..bc8b44ab --- /dev/null +++ b/.github/actions/test-setup/action.yaml @@ -0,0 +1,95 @@ +name: 'Test Setup' +description: 'Common setup for detector test workflows' + +inputs: + component_name: + description: 'Name of the component being tested (for caching)' + required: true + requirements_files: + description: 'Space-separated list of requirements files to install' + required: true + precommit_paths: + description: 'Space-separated list of paths to check with pre-commit' + required: false + default: '' + python_version: + description: 'Python version to use' + required: false + default: '3.11' + needs_system_deps: + description: 'Whether to install system dependencies (build-essential, wget)' + required: false + default: 'false' + +runs: + using: 'composite' + steps: + - name: Set up Python ${{ inputs.python_version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python_version }} + + - name: Generate cache key + id: cache-key + shell: bash + run: | + # Create a hash of all requirements files for cache key + CACHE_KEY="${{ runner.os }}-pip-${{ inputs.component_name }}-" + for file in ${{ inputs.requirements_files }}; do + if [ -f "$file" ]; then + CACHE_KEY="${CACHE_KEY}$(sha256sum $file | cut -d' ' -f1)-" + fi + done + echo "cache_key=${CACHE_KEY%%-}" >> $GITHUB_OUTPUT + echo "cache_restore_keys=${{ runner.os }}-pip-${{ inputs.component_name }}-" >> $GITHUB_OUTPUT + + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ steps.cache-key.outputs.cache_key }} + restore-keys: | + ${{ steps.cache-key.outputs.cache_restore_keys }} + ${{ runner.os }}-pip- + + - name: Install system dependencies + if: inputs.needs_system_deps == 'true' + shell: bash + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential \ + wget + + - name: Install dependencies + shell: bash + run: | + python -m pip install --upgrade pip + # Install test dependencies + pip install pytest-cov + # Install component-specific requirements + for file in ${{ inputs.requirements_files }}; do + if [ -f "$file" ]; then + echo "Installing requirements from $file" + pip install -r "$file" + else + echo "Warning: Requirements file $file not found" + fi + done + + - name: Set Python path + shell: bash + run: | + echo "PYTHONPATH=$GITHUB_WORKSPACE/detectors/huggingface:$GITHUB_WORKSPACE/detectors/llm_judge:$GITHUB_WORKSPACE/detectors:$GITHUB_WORKSPACE" >> $GITHUB_ENV + + - name: Lint with pre-commit (if available) + if: inputs.precommit_paths != '' + shell: bash + run: | + if [ -f .pre-commit-config.yaml ]; then + # Run pre-commit on specified paths + find ${{ inputs.precommit_paths }} -name '*.py' 2>/dev/null | xargs -r pre-commit run --files + else + echo "No pre-commit config found, skipping linting" + fi + continue-on-error: true \ No newline at end of file diff --git a/.github/workflows/build-and-push.yaml b/.github/workflows/build-and-push.yaml index bf026f8a..06b588ed 100644 --- a/.github/workflows/build-and-push.yaml +++ b/.github/workflows/build-and-push.yaml @@ -1,25 +1,46 @@ name: Build and Push - Detectors on: + # Trigger on successful test completion + workflow_run: + workflows: + - "Tier 1 - Built-in detectors unit tests" + - "Tier 1 - Hugging Face Runtime unit tests" + - "Tier 1 - LLM Judge unit tests" + types: + - completed + + # Direct triggers (tests will run in parallel) push: branches: - main + - incubation + - stable tags: - v* paths: - 'detectors/*' - '.github/workflows/*' - pull_request_target: + pull_request: paths: - 'detectors/*' types: [labeled, opened, synchronize, reopened] jobs: # Ensure that tests pass before publishing a new image. build-and-push-ci: + # Only run if: + # 1. Running in the trustyai-explainability/guardrails-detectors repository, AND + # 2. Tests completed successfully on target branches (from workflow_run trigger), OR + # 3. Direct push/PR trigger (tests will run in parallel) + if: | + github.repository == 'trustyai-explainability/guardrails-detectors' && + ((github.event_name == 'workflow_run' && + github.event.workflow_run.conclusion == 'success' && + contains(fromJSON('["main", "incubation", "stable"]'), github.event.workflow_run.head_branch)) || + (github.event_name != 'workflow_run')) runs-on: ubuntu-latest permissions: contents: read pull-requests: write - security-events: write env: PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }} GITHUB_REF_NAME: ${{ github.ref_name }} @@ -44,12 +65,15 @@ jobs: mode: minimum count: 1 labels: "ok-to-test, lgtm, approved" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 if: env.BUILD_CONTEXT == 'ci' with: ref: ${{ github.event.pull_request.head.sha }} - - uses: actions/checkout@v3 + persist-credentials: false + - uses: actions/checkout@v4 if: env.BUILD_CONTEXT == 'main' || env.BUILD_CONTEXT == 'tag' + with: + persist-credentials: false # # Print variables for debugging - name: Log reference variables @@ -129,51 +153,3 @@ jobs: 📦 [Huggingface PR image](https://quay.io/repository/trustyai/guardrails-detector-huggingface-runtime-ci?tab=tags): `quay.io/trustyai/guardrails-detector-huggingface-runtime-ci:$PR_HEAD_SHA` 📦 [Built-in PR image](https://quay.io/trustyai/guardrails-detector-built-in-ci?tab=tags): `quay.io/trustyai/guardrails-detector-built-in-ci:$PR_HEAD_SHA` 📦 [LLM Judge PR image](https://quay.io/trustyai/guardrails-detector-llm-judge-ci?tab=tags): `quay.io/trustyai/guardrails-detector-llm-judge-ci:$PR_HEAD_SHA` - - name: Trivy scan - uses: aquasecurity/trivy-action@0.28.0 - with: - scan-type: 'image' - image-ref: "${{ env.IMAGE_NAME }}:${{ env.TAG }}" - format: 'sarif' - output: 'trivy-results.sarif' - severity: 'MEDIUM,HIGH,CRITICAL' - exit-code: '0' - ignore-unfixed: false - vuln-type: 'os,library' - - name: Trivy scan, built-in image - uses: aquasecurity/trivy-action@0.28.0 - with: - scan-type: 'image' - image-ref: "${{ env.BUILTIN_IMAGE_NAME }}:${{ env.TAG }}" - format: 'sarif' - output: 'trivy-results-built-in.sarif' - severity: 'MEDIUM,HIGH,CRITICAL' - exit-code: '0' - ignore-unfixed: false - vuln-type: 'os,library' - - name: Trivy scan, LLM Judge image - uses: aquasecurity/trivy-action@0.28.0 - with: - scan-type: 'image' - image-ref: "${{ env.LLM_JUDGE_IMAGE_NAME }}:${{ env.TAG }}" - format: 'sarif' - output: 'trivy-results-llm-judge.sarif' - severity: 'MEDIUM,HIGH,CRITICAL' - exit-code: '0' - ignore-unfixed: false - vuln-type: 'os,library' - - name: Update Security tab - Huggingface - uses: github/codeql-action/upload-sarif@v3 - with: - sarif_file: 'trivy-results.sarif' - category: huggingface - - name: Update Security tab - Built-in - uses: github/codeql-action/upload-sarif@v3 - with: - sarif_file: 'trivy-results-built-in.sarif' - category: built-in - - name: Update Security tab - LLM Judge - uses: github/codeql-action/upload-sarif@v3 - with: - sarif_file: 'trivy-results-llm-judge.sarif' - category: llm-judge \ No newline at end of file diff --git a/.github/workflows/security-scan.yaml b/.github/workflows/security-scan.yaml new file mode 100644 index 00000000..8f61991a --- /dev/null +++ b/.github/workflows/security-scan.yaml @@ -0,0 +1,159 @@ +name: Tier 1 - Security scan + +on: + push: + branches: [ main, incubation, stable ] + paths: + - 'detectors/**' + - 'requirements*.txt' + - '*.py' + - '.github/workflows/security-scan.yaml' + + pull_request: + branches: [ main, incubation, stable ] + paths: + - 'detectors/**' + - 'requirements*.txt' + - '*.py' + - '.github/workflows/security-scan.yaml' + + # Manual trigger for security scans + workflow_dispatch: + + # Scheduled security scans + schedule: + - cron: '0 2 * * 1' # Weekly on Mondays at 2 AM UTC + +jobs: + filesystem-security-scan: + runs-on: ubuntu-latest + + permissions: + contents: read + security-events: write + + strategy: + matrix: + component: + - name: "builtin-detectors" + path: "detectors/built_in" + - name: "huggingface-runtime" + path: "detectors/huggingface" + - name: "llm-judge" + path: "detectors/llm_judge" + - name: "common" + path: "detectors/common" + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Log scan parameters + run: | + echo "Scanning filesystem path: ${{ matrix.component.path }}" + echo "Component: ${{ matrix.component.name }}" + + - name: Run Trivy vulnerability scanner (filesystem) + uses: aquasecurity/trivy-action@0.28.0 + with: + scan-type: 'fs' + scan-ref: '${{ matrix.component.path }}' + format: 'sarif' + output: 'trivy-results-${{ matrix.component.name }}.sarif' + severity: 'MEDIUM,HIGH,CRITICAL' + exit-code: '0' + scanners: 'vuln,secret' + + - name: Run Trivy configuration scanner + uses: aquasecurity/trivy-action@0.28.0 + with: + scan-type: 'config' + scan-ref: '${{ matrix.component.path }}' + hide-progress: false + format: 'sarif' + output: 'trivy-config-${{ matrix.component.name }}.sarif' + exit-code: '0' + continue-on-error: true + + - name: Upload vulnerability scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: 'trivy-results-${{ matrix.component.name }}.sarif' + category: '${{ matrix.component.name }}-vulnerabilities' + + - name: Upload configuration scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v3 + if: hashFiles(format('trivy-config-{0}.sarif', matrix.component.name)) != '' + with: + sarif_file: 'trivy-config-${{ matrix.component.name }}.sarif' + category: '${{ matrix.component.name }}-config' + + - name: Generate human-readable vulnerability report + uses: aquasecurity/trivy-action@0.28.0 + with: + scan-type: 'fs' + scan-ref: '${{ matrix.component.path }}' + format: 'table' + output: 'trivy-report-${{ matrix.component.name }}.txt' + severity: 'HIGH,CRITICAL' + exit-code: '0' + scanners: 'vuln,secret' + + - name: Upload scan artifacts + uses: actions/upload-artifact@v4 + with: + name: security-scan-${{ matrix.component.name }} + path: | + trivy-results-${{ matrix.component.name }}.sarif + trivy-config-${{ matrix.component.name }}.sarif + trivy-report-${{ matrix.component.name }}.txt + retention-days: 30 + + # Scan the entire repository root for additional security issues + repository-security-scan: + runs-on: ubuntu-latest + + permissions: + contents: read + security-events: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Run Trivy repository scan + uses: aquasecurity/trivy-action@0.28.0 + with: + scan-type: 'fs' + scan-ref: '.' + format: 'sarif' + output: 'trivy-repository-results.sarif' + severity: 'HIGH,CRITICAL' + exit-code: '0' + scanners: 'vuln,secret' + + - name: Upload repository scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: 'trivy-repository-results.sarif' + category: 'repository-wide-security' + + - name: Generate repository security report + uses: aquasecurity/trivy-action@0.28.0 + with: + scan-type: 'fs' + scan-ref: '.' + format: 'table' + output: 'trivy-repository-report.txt' + severity: 'HIGH,CRITICAL' + exit-code: '0' + scanners: 'vuln,secret' + + - name: Upload repository scan artifacts + uses: actions/upload-artifact@v4 + with: + name: security-scan-repository + path: | + trivy-repository-results.sarif + trivy-repository-report.txt + retention-days: 30 \ No newline at end of file diff --git a/.github/workflows/test-builtin-detectors.yaml b/.github/workflows/test-builtin-detectors.yaml new file mode 100644 index 00000000..bd81190f --- /dev/null +++ b/.github/workflows/test-builtin-detectors.yaml @@ -0,0 +1,50 @@ +name: Tier 1 - Built-in detectors unit tests + +on: + push: + branches: [ main, incubation, stable ] + paths: + - 'detectors/built_in/**' + - 'detectors/common/**' + - 'tests/detectors/builtIn/**' + - 'tests/conftest.py' + - '.github/workflows/test-builtin-detectors.yaml' + pull_request: + branches: [ main, incubation, stable ] + paths: + - 'detectors/built_in/**' + - 'detectors/common/**' + - 'tests/detectors/builtIn/**' + - 'tests/conftest.py' + - '.github/workflows/test-builtin-detectors.yaml' + +jobs: + test-builtin-detectors: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + + permissions: + contents: read + checks: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Common test setup + uses: ./.github/actions/test-setup + with: + component_name: 'builtin' + requirements_files: 'detectors/common/requirements.txt detectors/common/requirements-dev.txt detectors/built_in/requirements.txt' + precommit_paths: 'detectors/built_in tests/detectors/builtIn detectors/common' + python_version: ${{ matrix.python-version }} + needs_system_deps: 'false' + + - name: Run Built-in Detector Tests + run: | + pytest tests/detectors/builtIn/ \ + --cov=detectors.built_in \ + --cov-report=term-missing \ + -v \ No newline at end of file diff --git a/.github/workflows/test-huggingface-runtime.yaml b/.github/workflows/test-huggingface-runtime.yaml new file mode 100644 index 00000000..d1d62dc2 --- /dev/null +++ b/.github/workflows/test-huggingface-runtime.yaml @@ -0,0 +1,91 @@ +name: Tier 1 - Hugging Face Runtime unit tests + +on: + push: + branches: [ main, incubation, stable ] + paths: + - 'detectors/huggingface/**' + - 'detectors/common/**' + - 'tests/detectors/huggingface/**' + - 'tests/dummy_models/**' + - 'tests/conftest.py' + - '.github/workflows/test-huggingface-runtime.yaml' + pull_request: + branches: [ main, incubation, stable ] + paths: + - 'detectors/huggingface/**' + - 'detectors/common/**' + - 'tests/detectors/huggingface/**' + - 'tests/dummy_models/**' + - 'tests/conftest.py' + - '.github/workflows/test-huggingface-runtime.yaml' + +jobs: + test-huggingface-runtime: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + + permissions: + contents: read + checks: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Common test setup + uses: ./.github/actions/test-setup + with: + component_name: 'huggingface' + requirements_files: 'detectors/common/requirements.txt detectors/common/requirements-dev.txt detectors/huggingface/requirements.txt' + precommit_paths: 'detectors/huggingface tests/detectors/huggingface detectors/common' + python_version: ${{ matrix.python-version }} + needs_system_deps: 'true' + + - name: Verify dummy models are available + run: | + ls -la tests/dummy_models/ + echo "Checking for required test models..." + if [ ! -d "tests/dummy_models/bert" ]; then + echo "Warning: BERT dummy models not found" + fi + if [ ! -d "tests/dummy_models/gpt2" ]; then + echo "Warning: GPT2 dummy models not found" + fi + + - name: Run Hugging Face Runtime Tests + timeout-minutes: 20 + env: + HF_HOME: /tmp/huggingface + TRANSFORMERS_CACHE: /tmp/transformers_cache + TOKENIZERS_PARALLELISM: false + run: | + pytest tests/detectors/huggingface/ \ + --cov=detectors.huggingface \ + --cov-report=term-missing \ + -v \ + --tb=short + + - name: Test model loading capabilities + timeout-minutes: 10 + env: + HF_HOME: /tmp/huggingface + TRANSFORMERS_CACHE: /tmp/transformers_cache + TOKENIZERS_PARALLELISM: false + MODEL_DIR: tests/dummy_models/bert/BertForSequenceClassification + run: | + python -c " + try: + from detectors.huggingface.detector import Detector + print('Detector import successful') + + # Test basic initialization + detector = Detector() + print('Detector initialization successful') + except Exception as e: + print(f'Error testing HF detector: {e}') + exit(1) + " + echo "Model loading verification complete" \ No newline at end of file diff --git a/.github/workflows/test-llm-judge.yaml b/.github/workflows/test-llm-judge.yaml new file mode 100644 index 00000000..9e8c5eea --- /dev/null +++ b/.github/workflows/test-llm-judge.yaml @@ -0,0 +1,84 @@ +name: Tier 1 - LLM Judge unit tests + +on: + push: + branches: [ main, incubation, stable ] + paths: + - 'detectors/llm_judge/**' + - 'detectors/common/**' + - 'tests/detectors/llm_judge/**' + - 'tests/conftest.py' + - '.github/workflows/test-llm-judge.yaml' + pull_request: + branches: [ main, incubation, stable ] + paths: + - 'detectors/llm_judge/**' + - 'detectors/common/**' + - 'tests/detectors/llm_judge/**' + - 'tests/conftest.py' + - '.github/workflows/test-llm-judge.yaml' + +jobs: + test-llm-judge: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + + permissions: + contents: read + checks: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Common test setup + uses: ./.github/actions/test-setup + with: + component_name: 'llm-judge' + requirements_files: 'detectors/common/requirements.txt detectors/common/requirements-dev.txt detectors/llm_judge/requirements.txt' + precommit_paths: 'detectors/llm_judge tests/detectors/llm_judge detectors/common' + python_version: ${{ matrix.python-version }} + needs_system_deps: 'true' + + - name: Verify vllm-judge installation + run: | + python -c " + try: + import vllm_judge + print('vllm-judge import successful') + print(f'vllm-judge version: {vllm_judge.__version__}') + except Exception as e: + print(f'Error importing vllm-judge: {e}') + print('This may be expected if running without GPU resources') + " + + - name: Run LLM Judge Tests + timeout-minutes: 15 + run: | + pytest tests/detectors/llm_judge/ \ + --cov=detectors.llm_judge \ + --cov-report=term-missing \ + -v \ + --tb=short + + - name: Test LLM Judge detector initialization + timeout-minutes: 5 + run: | + python -c " + try: + from detectors.llm_judge.detector import LLMJudgeDetector + print('LLMJudgeDetector import successful') + + # Test basic initialization (may fail without proper model access) + try: + detector = LLMJudgeDetector() + print('LLMJudgeDetector initialization successful') + except Exception as init_e: + print(f'Note: LLMJudgeDetector initialization failed (may require specific models): {init_e}') + except Exception as e: + print(f'Error testing LLM Judge detector: {e}') + exit(1) + " + echo "LLM Judge detector verification complete" \ No newline at end of file diff --git a/detectors/Dockerfile.hf b/detectors/Dockerfile.hf index 04424f30..c97bf568 100644 --- a/detectors/Dockerfile.hf +++ b/detectors/Dockerfile.hf @@ -100,6 +100,8 @@ COPY ./common /app/detectors/common COPY ./huggingface/detector.py /app/detectors/huggingface/ RUN mkdir /common; cp /app/detectors/common/log_conf.yaml /common/ COPY ./huggingface/app.py /app +ENV PROMETHEUS_MULTIPROC_DIR="/tmp/prometheus_multiproc_dir" +RUN mkdir -p $PROMETHEUS_MULTIPROC_DIR && chmod 777 $PROMETHEUS_MULTIPROC_DIR EXPOSE 8000 CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"] diff --git a/detectors/built_in/app.py b/detectors/built_in/app.py index 662d2cfc..a0efda90 100644 --- a/detectors/built_in/app.py +++ b/detectors/built_in/app.py @@ -1,3 +1,4 @@ +import json import logging from fastapi import HTTPException, Request @@ -72,5 +73,11 @@ def get_registry(): raise TypeError(f"Detector {detector_type} is not a valid BaseDetectorRegistry") result[detector_type] = {} for detector_name, detector_fn in detector_registry.get_registry().items(): - result[detector_type][detector_name] = detector_fn.__doc__ + docstring = detector_fn.__doc__ + try: + # Try to parse as JSON + parsed = json.loads(docstring) + result[detector_type][detector_name] = parsed + except Exception: + result[detector_type][detector_name] = docstring return result \ No newline at end of file diff --git a/detectors/built_in/base_detector_registry.py b/detectors/built_in/base_detector_registry.py index 170e0d74..60e4adb9 100644 --- a/detectors/built_in/base_detector_registry.py +++ b/detectors/built_in/base_detector_registry.py @@ -30,8 +30,13 @@ def throw_internal_detector_error(self, function_name: str, logger: logging.Logg def get_detection_functions_from_params(self, params: dict): """Parse the request parameters to extract and normalize detection functions as iterable list""" - if self.registry_name in params and isinstance(params[self.registry_name], (list, str)): + if self.registry_name in params and isinstance(params[self.registry_name], (list, str, dict)): funcs = params[self.registry_name] - return [funcs] if isinstance(funcs, str) else funcs + if isinstance(funcs, str): + return [funcs] + elif isinstance(funcs, dict): + return list(funcs.keys()) + else: + return funcs else: return [] \ No newline at end of file diff --git a/detectors/built_in/custom_detectors/custom_detectors.py b/detectors/built_in/custom_detectors/custom_detectors.py index f2e5dbd9..9763a159 100644 --- a/detectors/built_in/custom_detectors/custom_detectors.py +++ b/detectors/built_in/custom_detectors/custom_detectors.py @@ -5,26 +5,51 @@ See [docs/custom_detectors.md](../../docs/custom_detectors.md) for more details. """ +# example boolean-returning function +def over_100_characters(text: str) -> bool: + return len(text)>100 + +# example dict-returning function +def contains_word(text: str) -> dict: + detection = "apple" in text.lower() + if detection: + detection_position = text.lower().find("apple") + return { + "start":detection_position, # start position of detection in text + "end": detection_position+5, # end position of detection in text + "text": text, # "the flagged text, or some arbitrary message to return to the user" + "detection_type": "content_check", #detection_type -> use these fields to define your detector taxonomy as you see fit + "detection": "forbidden_word: apple", ##detection -> use these fields to define your detector taxonomy as you see fit + "score": 1.0 # the score/severity/probability of the detection + } + else: + return {} + +def _this_function_will_not_be_exposed(): + pass + +def function_that_needs_headers(text: str, headers: dict) -> bool: + return headers['magic-key'] != "123" + +def function_that_needs_kwargs(text: str, **kwargs: dict) -> bool: + return kwargs['magic-key'] != "123" + +# === CUSTOM METRICS ===== import time -def slow_func(text: str) -> bool: - time.sleep(.25) - return False - from prometheus_client import Counter - prompt_rejection_counter = Counter( - "trustyai_guardrails_system_prompt_rejections", + "system_prompt_rejections", "Number of rejections by the system prompt", ) - -@use_instruments(instruments=[prompt_rejection_counter]) +@use_instruments(instruments=[prompt_rejection_counter]) def has_metrics(text: str) -> bool: if "sorry" in text: prompt_rejection_counter.inc() return False - + + background_metric = Counter( - "trustyai_guardrails_background_metric", + "background_metric", "Runs some logic in the background without blocking the /detections call" ) @use_instruments(instruments=[background_metric]) @@ -32,5 +57,5 @@ def has_metrics(text: str) -> bool: def background_function(text: str) -> bool: time.sleep(.25) if "sorry" in text: - background_metric.inc() - return False + background_metric.inc() + return False \ No newline at end of file diff --git a/detectors/built_in/custom_detectors_wrapper.py b/detectors/built_in/custom_detectors_wrapper.py index 42a18548..a2451983 100644 --- a/detectors/built_in/custom_detectors_wrapper.py +++ b/detectors/built_in/custom_detectors_wrapper.py @@ -5,6 +5,7 @@ import functools import os import sys +import traceback from concurrent.futures import ThreadPoolExecutor from fastapi import HTTPException @@ -66,14 +67,20 @@ def get_underlying_function(func): return func -def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict) -> Optional[ContentAnalysisResponse]: +def custom_func_wrapper(func: Callable, func_name: str, s: str, headers: dict, func_kwargs: dict=None) -> Optional[ContentAnalysisResponse]: """Convert a some f(text)->bool into a Detector response""" sig = inspect.signature(func) try: if headers is not None: - result = func(s, headers) + if func_kwargs is None: + result = func(s, headers=headers) + else: + result = func(s, headers=headers, **func_kwargs) else: - result = func(s) + if func_kwargs is None: + result = func(s) + else: + result = func(s, **func_kwargs) except Exception as e: logging.error(f"Error when computing custom detector function {func_name}: {e}") @@ -171,7 +178,13 @@ def __init__(self): self.registry = {name: obj for name, obj in inspect.getmembers(custom_detectors, inspect.isfunction) if not name.startswith("_") and name not in forbidden_names} - self.function_needs_headers = {name: "headers" in inspect.signature(obj).parameters for name, obj in self.registry.items() } + + self.function_needs_headers = {} + self.function_needs_kwargs = {} + for name, obj in self.registry.items(): + self.function_needs_headers[name] = "headers" in inspect.signature(obj).parameters + self.function_needs_kwargs[name] = "kwargs" in inspect.signature(obj).parameters + # check if functions have requested user prometheus metrics for name, func in self.registry.items(): @@ -190,8 +203,14 @@ def handle_request(self, content: str, detector_params: dict, headers: dict, **k if self.registry.get(custom_function_name): try: func_headers = headers if self.function_needs_headers.get(custom_function_name) else None + + if self.function_needs_kwargs.get(custom_function_name)and isinstance(detector_params[self.registry_name][custom_function_name], dict): + func_kwargs = detector_params[self.registry_name][custom_function_name] + else: + func_kwargs = None + with self.instrument_runtime(custom_function_name): - result = custom_func_wrapper(self.registry[custom_function_name], custom_function_name, content, func_headers) + result = custom_func_wrapper(self.registry[custom_function_name], custom_function_name, content, func_headers, func_kwargs) is_detection = result is not None self.increment_detector_instruments(custom_function_name, is_detection) if is_detection: diff --git a/detectors/built_in/requirements.txt b/detectors/built_in/requirements.txt index 985781e0..483be2ca 100644 --- a/detectors/built_in/requirements.txt +++ b/detectors/built_in/requirements.txt @@ -1,3 +1,4 @@ +urllib3==2.6.2 markdown==3.8.2 jsonschema==4.24.0 xmlschema==4.1.0 diff --git a/detectors/common/requirements-dev.txt b/detectors/common/requirements-dev.txt index a50e66ce..6509a192 100644 --- a/detectors/common/requirements-dev.txt +++ b/detectors/common/requirements-dev.txt @@ -4,3 +4,4 @@ pre-commit==3.8.0 pytest==8.3.2 tls-test-tools protobuf==6.33.0 +torch==2.9.1 diff --git a/detectors/common/requirements.txt b/detectors/common/requirements.txt index 62b3e843..20fcede0 100644 --- a/detectors/common/requirements.txt +++ b/detectors/common/requirements.txt @@ -1,4 +1,5 @@ -fastapi==0.112.0 +urllib3==2.6.2 +fastapi==0.121.2 uvicorn==0.30.5 httpx==0.27.0 prometheus_client >= 0.18.0 diff --git a/detectors/huggingface/app.py b/detectors/huggingface/app.py index dd511a27..92d4cdc9 100644 --- a/detectors/huggingface/app.py +++ b/detectors/huggingface/app.py @@ -3,6 +3,9 @@ from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import run_in_threadpool +from prometheus_client import generate_latest, CONTENT_TYPE_LATEST, CollectorRegistry, multiprocess +from starlette.responses import Response + from detectors.common.app import DetectorBaseAPI as FastAPI from detectors.huggingface.detector import Detector from detectors.common.scheme import ( @@ -24,9 +27,14 @@ async def lifespan(app: FastAPI): detector.close() app.cleanup_detector() - app = FastAPI(lifespan=lifespan, dependencies=[]) -Instrumentator().instrument(app).expose(app) + +@app.get("/metrics") +def metrics(): + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + data = generate_latest(registry) + return Response(data, media_type=CONTENT_TYPE_LATEST) @app.post( diff --git a/detectors/huggingface/requirements.txt b/detectors/huggingface/requirements.txt index 9bb03df2..c9cc6d1b 100644 --- a/detectors/huggingface/requirements.txt +++ b/detectors/huggingface/requirements.txt @@ -1,3 +1,4 @@ +urllib3==2.6.2 transformers==4.57.1 sentencepiece==0.2.1 -tiktoken==0.12.0 \ No newline at end of file +tiktoken==0.12.0 diff --git a/tests/detectors/builtIn/test_custom.py b/tests/detectors/builtIn/test_custom.py index ab92e732..c8a1e7ab 100644 --- a/tests/detectors/builtIn/test_custom.py +++ b/tests/detectors/builtIn/test_custom.py @@ -35,6 +35,11 @@ class TestCustomDetectors: @pytest.fixture def client(self): from detectors.built_in.app import app + + # clear the metric registry at the start of each test, but AFTER the multiprocessing metrics is set up + import prometheus_client + prometheus_client.REGISTRY._names_to_collectors.clear() + from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry app.set_detector(CustomDetectorRegistry(), "custom") return TestClient(app) @@ -96,6 +101,30 @@ def test_custom_detectors_need_header(self, client): texts = [d["text"] for d in resp.json()[0]] assert msg in texts + def test_custom_detectors_need_kwargs(self, client): + msg = "What is an banana?" + payload1 = { + "contents": [msg], + "detector_params": {"custom": {"function_that_needs_kwargs": {"magic-key": "123"}}} + } + payload2 = { + "contents": [msg], + "detector_params": {"custom": {"function_that_needs_kwargs": {"magic-key": "345"}}} + } + + # shouldn't flag + resp = client.post("/api/v1/text/contents", json=payload1) + assert resp.status_code == 200 + texts = [d["text"] for d in resp.json()[0]] + assert msg not in texts + + # should flag + resp = client.post("/api/v1/text/contents", json=payload2) + assert resp.status_code == 200 + texts = [d["text"] for d in resp.json()[0]] + assert msg in texts + + def test_unsafe_code(self, client): write_code_to_custom_detectors(UNSAFE_CODE) from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry diff --git a/tests/detectors/builtIn/test_filetype.py b/tests/detectors/builtIn/test_filetype.py index f65785e1..73698a2d 100644 --- a/tests/detectors/builtIn/test_filetype.py +++ b/tests/detectors/builtIn/test_filetype.py @@ -289,18 +289,17 @@ def test_multiple_filetype_valid_and_invalid(self, client: TestClient): # === ERROR HANDLING & INVALID DETECTOR TYPES ================================================= def test_unregistered_detector_kind_ignored(self, client: TestClient): - """Test that requesting an unregistered detector kind is silently ignored""" + """Test that requesting an unregistered detector kind returns 400 error""" payload = { "contents": ['{"a": 1}'], "detector_params": {"nonexistent_detector": ["some_value"]} } resp = client.post("/api/v1/text/contents", json=payload) - assert resp.status_code == 200 - # Should return empty list since nonexistent_detector is not registered - assert resp.json()[0] == [] + # Should return 400 error since nonexistent_detector is not registered + assert resp.status_code == 400 def test_mixed_valid_invalid_detector_kinds(self, client: TestClient): - """Test mixing valid and invalid detector kinds""" + """Test mixing valid and invalid detector kinds returns 400 error""" payload = { "contents": ['{a: 1, b: 2}'], "detector_params": { @@ -309,10 +308,8 @@ def test_mixed_valid_invalid_detector_kinds(self, client: TestClient): } } resp = client.post("/api/v1/text/contents", json=payload) - assert resp.status_code == 200 - detections = resp.json()[0] - # Should only process the valid detector kind - assert detections[0]["detection"] == "invalid_json" + # Should return 400 error for the unregistered detector + assert resp.status_code == 400 def test_empty_detector_params(self, client: TestClient): """Test with empty detector_params""" diff --git a/tests/detectors/huggingface/test_client_integration.py b/tests/detectors/huggingface/test_client_integration.py new file mode 100644 index 00000000..2dedcff4 --- /dev/null +++ b/tests/detectors/huggingface/test_client_integration.py @@ -0,0 +1,70 @@ +""" +Integration tests for HF detector FastAPI lifespan context +""" + +import os +import sys +import pytest +from fastapi.testclient import TestClient + +# Set up paths and MODEL_DIR before app import +_current_dir = os.path.dirname(__file__) +_tests_dir = os.path.dirname(os.path.dirname(_current_dir)) +_project_root = os.path.dirname(_tests_dir) +_detectors_path = os.path.join(_project_root, "detectors") +_huggingface_path = os.path.join(_detectors_path, "huggingface") + +for path in [_huggingface_path, _detectors_path, _project_root]: + if path not in sys.path: + sys.path.insert(0, path) + +os.environ["MODEL_DIR"] = os.path.join( + _tests_dir, "dummy_models", "bert/BertForSequenceClassification" +) + +from app import app # noqa: E402 + + +class TestLifespanIntegration: + """Test FastAPI lifespan context for HF detector.""" + + @pytest.fixture + def client(self): + """Create test client with lifespan-initialized detector.""" + with TestClient(app) as test_client: + yield test_client + + def test_lifespan_loads_detector(self, client): + """Verify lifespan initializes detector on startup.""" + detectors = app.get_all_detectors() + assert len(detectors) > 0 + detector = list(detectors.values())[0] + assert detector.model is not None + assert detector.tokenizer is not None + + def test_lifespan_handles_requests(self, client): + """Verify requests work through lifespan-initialized app.""" + response = client.post( + "/api/v1/text/contents", + json={"contents": ["Test message"], "detector_params": {}}, + ) + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_multiple_requests(self, client): + """Verify detector handles multiple requests without state leakage.""" + for i in range(10): + response = client.post( + "/api/v1/text/contents", + json={"contents": [f"Request {i}"], "detector_params": {}}, + ) + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_lifespan_cleanup(self, client): + """Verify lifespan cleanup runs on shutdown.""" + assert len(app.get_all_detectors()) > 0 + + client.__exit__(None, None, None) + + assert len(app.get_all_detectors()) == 0 diff --git a/tests/detectors/huggingface/test_method_initialize_model.py b/tests/detectors/huggingface/test_method_initialize_model.py index 74cf3315..e50aa164 100644 --- a/tests/detectors/huggingface/test_method_initialize_model.py +++ b/tests/detectors/huggingface/test_method_initialize_model.py @@ -3,7 +3,7 @@ import pytest # local imports -from detectors.huggingface.scheme import ContentAnalysisResponse +from detectors.common.scheme import ContentAnalysisResponse from detectors.huggingface.detector import Detector diff --git a/tests/detectors/huggingface/test_method_process_causal_lm.py b/tests/detectors/huggingface/test_method_process_causal_lm.py index 20b24639..6d4a5fba 100644 --- a/tests/detectors/huggingface/test_method_process_causal_lm.py +++ b/tests/detectors/huggingface/test_method_process_causal_lm.py @@ -60,12 +60,9 @@ def validate_results(self, results, input_text, detector): "detection", "detection_type", "score", - "sequence_classification", - "sequence_probability", - "token_classifications", - "token_probabilities", "text", "evidences", + "metadata", ] for field in expected_fields: @@ -79,16 +76,12 @@ def validate_results(self, results, input_text, detector): assert isinstance(result.detection, str) assert isinstance(result.detection_type, str) assert isinstance(result.score, float) - assert isinstance(result.sequence_classification, str) - assert isinstance(result.sequence_probability, float) assert isinstance(result.text, str) assert isinstance(result.evidences, list) assert 0 <= result.start <= len(input_text) assert 0 <= result.end <= len(input_text) assert 0.0 <= result.score <= 1.0 - assert 0.0 <= result.sequence_probability <= 1.0 - assert result.sequence_classification in detector.risk_names def test_process_causal_lm_single_short_input(self, detector_instance): text = "This is a test." diff --git a/tests/detectors/huggingface/test_method_process_sequence_classification.py b/tests/detectors/huggingface/test_method_process_sequence_classification.py index bfb566a0..a323fa83 100644 --- a/tests/detectors/huggingface/test_method_process_sequence_classification.py +++ b/tests/detectors/huggingface/test_method_process_sequence_classification.py @@ -40,12 +40,9 @@ def validate_results(self, results, input_text): "detection", "detection_type", "score", - "sequence_classification", - "sequence_probability", - "token_classifications", - "token_probabilities", "text", "evidences", + "metadata", ] for field in expected_fields: @@ -59,12 +56,6 @@ def validate_results(self, results, input_text): assert isinstance(result.detection, str), "detection should be string" assert isinstance(result.detection_type, str), "detection_type should be string" assert isinstance(result.score, float), "score should be float" - assert isinstance( - result.sequence_classification, str - ), "sequence_classification should be string" - assert isinstance( - result.sequence_probability, float - ), "sequence_probability should be float" assert isinstance(result.text, str), "text should be string" assert isinstance(result.evidences, list), "evidences should be list" @@ -73,9 +64,6 @@ def validate_results(self, results, input_text): ), "start should be within text bounds" assert 0 <= result.end <= len(input_text), "end should be within text bounds" assert 0.0 <= result.score <= 1.0, "score should be between 0 and 1" - assert ( - 0.0 <= result.sequence_probability <= 1.0 - ), "sequence_probability should be between 0 and 1" return result diff --git a/tests/detectors/huggingface/test_method_run.py b/tests/detectors/huggingface/test_method_run.py index fbe7eb0f..9813a796 100644 --- a/tests/detectors/huggingface/test_method_run.py +++ b/tests/detectors/huggingface/test_method_run.py @@ -5,8 +5,8 @@ from unittest.mock import Mock, patch # relative imports -from detectors.huggingface.detector import Detector, ContentAnalysisResponse -from scheme import ContentAnalysisHttpRequest +from detectors.huggingface.detector import Detector +from detectors.common.scheme import ContentAnalysisResponse, ContentAnalysisHttpRequest @pytest.fixture @@ -60,58 +60,63 @@ def detector_causal_lm(self): detector.is_causal_lm = True detector.is_sequence_classifier = False detector.risk_names = ["harm", "bias"] + detector.function_name = "test_causal_lm" + detector.instruments = {} # Initialize empty instruments dict return detector def test_run_sequence_classifier_single_short_input(self, detector_sequence): - request = ContentAnalysisHttpRequest(contents=["Test content"]) + request = ContentAnalysisHttpRequest(contents=["Test content"], detector_params=None) results = detector_sequence.run(request) assert len(results) == 1 assert isinstance(results[0][0], ContentAnalysisResponse) - assert results[0][0].detection_type == "sequence_classification" + # detection_type is the label from the model (e.g., "LABEL_1", not "sequence_classification") + assert results[0][0].detection_type in detector_sequence.model.config.id2label.values() def test_run_sequence_classifier_single_long_input(self, detector_sequence): request = ContentAnalysisHttpRequest( contents=[ "This is a long content. " * 1_000, - ] + ], + detector_params=None ) results = detector_sequence.run(request) assert len(results) == 1 assert isinstance(results[0][0], ContentAnalysisResponse) - assert results[0][0].detection_type == "sequence_classification" + assert results[0][0].detection_type in detector_sequence.model.config.id2label.values() def test_run_sequence_classifier_empty_input(self, detector_sequence): - request = ContentAnalysisHttpRequest(contents=[""]) + request = ContentAnalysisHttpRequest(contents=[""], detector_params=None) results = detector_sequence.run(request) assert len(results) == 1 assert isinstance(results[0][0], ContentAnalysisResponse) - assert results[0][0].detection_type == "sequence_classification" + assert results[0][0].detection_type in detector_sequence.model.config.id2label.values() def test_run_sequence_classifier_multiple_contents(self, detector_sequence): - request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"]) + request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"], detector_params=None) results = detector_sequence.run(request) assert len(results) == 2 for content_analysis in results: assert len(content_analysis) == 1 assert isinstance(content_analysis[0], ContentAnalysisResponse) - assert content_analysis[0].detection_type == "sequence_classification" + assert content_analysis[0].detection_type in detector_sequence.model.config.id2label.values() def test_run_unsupported_model(self): detector = Detector.__new__(Detector) detector.is_causal_lm = False detector.is_sequence_classifier = False + detector.function_name = "test_detector" - request = ContentAnalysisHttpRequest(contents=["Test content"]) + request = ContentAnalysisHttpRequest(contents=["Test content"], detector_params=None) with pytest.raises(ValueError, match="Unsupported model type for analysis"): detector.run(request) def test_run_causal_lm_single_short_input(self, detector_causal_lm): - request = ContentAnalysisHttpRequest(contents=["Test content"]) + request = ContentAnalysisHttpRequest(contents=["Test content"], detector_params=None) results = detector_causal_lm.run(request) assert len(results) == 1 @@ -122,7 +127,8 @@ def test_run_causal_lm_single_long_input(self, detector_causal_lm): request = ContentAnalysisHttpRequest( contents=[ "This is a long content. " * 1_000, - ] + ], + detector_params=None ) results = detector_causal_lm.run(request) @@ -131,7 +137,7 @@ def test_run_causal_lm_single_long_input(self, detector_causal_lm): assert results[0][0].detection_type == "causal_lm" def test_run_causal_lm_empty_input(self, detector_causal_lm): - request = ContentAnalysisHttpRequest(contents=[""]) + request = ContentAnalysisHttpRequest(contents=[""], detector_params=None) results = detector_causal_lm.run(request) assert len(results) == 1 @@ -139,7 +145,7 @@ def test_run_causal_lm_empty_input(self, detector_causal_lm): assert results[0][0].detection_type == "causal_lm" def tes_run_causal_lm_multiple_contents(self, detector_causal_lm): - request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"]) + request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"], detector_params=None) results = detector_causal_lm.run(request) assert len(results) == 2 diff --git a/tests/detectors/huggingface/test_metrics.py b/tests/detectors/huggingface/test_metrics.py index 1bd40339..a37b374f 100644 --- a/tests/detectors/huggingface/test_metrics.py +++ b/tests/detectors/huggingface/test_metrics.py @@ -1,18 +1,14 @@ -import random import os import time -from collections import namedtuple -from unittest import mock -from unittest.mock import Mock, MagicMock +from unittest.mock import Mock import pytest import torch from starlette.testclient import TestClient +from prometheus_client import REGISTRY -from detectors.common.app import METRIC_PREFIX -from detectors.huggingface.detector import Detector -from detectors.huggingface.app import app - +# DO NOT IMPORT THIS VALUE, if we import common.app before the test fixtures we can break prometheus multiprocessing +METRIC_PREFIX = "trustyai_guardrails" def send_request(client: TestClient, detect: bool, slow: bool = False): payload = { @@ -30,8 +26,10 @@ def send_request(client: TestClient, detect: bool, slow: bool = False): def get_metric_dict(client: TestClient): - metrics = client.get("/metrics") - metrics = metrics.content.decode().split("\n") + # In test mode with TestClient, we're running in a single process, + # so multiprocess mode doesn't work. Use the default REGISTRY directly. + from prometheus_client import generate_latest, REGISTRY + metrics = generate_latest(REGISTRY).decode().split("\n") metric_dict = {} for m in metrics: @@ -41,39 +39,54 @@ def get_metric_dict(client: TestClient): return metric_dict +@pytest.fixture(scope="session") +def client(prometheus_multiproc_dir): + # Clear any existing metrics from the REGISTRY before importing the app + # This is needed because even in multiprocess mode, metrics are registered to REGISTRY + collectors_to_unregister = [ + c for c in list(REGISTRY._collector_to_names.keys()) + if hasattr(c, '_name') and 'trustyai_guardrails' in c._name + ] + for collector in collectors_to_unregister: + try: + REGISTRY.unregister(collector) + except Exception: + pass + + current_dir = os.path.dirname(__file__) + parent_dir = os.path.dirname(os.path.dirname(current_dir)) + os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models", "bert/BertForSequenceClassification") + + from detectors.huggingface.app import app + from detectors.huggingface.detector import Detector + detector = Detector() + + # patch the model to allow for control over detections - long messages will flag + def detection_fn(*args, **kwargs): + output = Mock() + if kwargs["input_ids"].shape[-1] > 10: + output.logits = torch.tensor([[0.0, 1.0]]) + else: + output.logits = torch.tensor([[1.0, 0.0]]) + + if kwargs["input_ids"].shape[-1] > 100: + time.sleep(.25) + return output + + class ModelMock: + def __init__(self): + self.config = Mock() + self.config.id2label = detector.model.config.id2label + self.config.problem_type = detector.model.config.problem_type + def __call__(self, *args, **kwargs): + return detection_fn(*args, **kwargs) + + detector.model = ModelMock() + app.set_detector(detector, detector.registry_name) + detector.set_instruments(app.state.instruments) + return TestClient(app) + class TestMetrics: - @pytest.fixture - def client(self): - current_dir = os.path.dirname(__file__) - parent_dir = os.path.dirname(os.path.dirname(current_dir)) - os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models", "bert/BertForSequenceClassification") - - detector = Detector() - - # patch the model to allow for control over detections - long messages will flag - def detection_fn(*args, **kwargs): - output = Mock() - if kwargs["input_ids"].shape[-1] > 10: - output.logits = torch.tensor([[0.0, 1.0]]) - else: - output.logits = torch.tensor([[1.0, 0.0]]) - - if kwargs["input_ids"].shape[-1] > 100: - time.sleep(.25) - return output - - class ModelMock: - def __init__(self): - self.config = Mock() - self.config.id2label = detector.model.config.id2label - self.config.problem_type = detector.model.config.problem_type - def __call__(self, *args, **kwargs): - return detection_fn(*args, **kwargs) - - detector.model = ModelMock() - app.set_detector(detector, detector.registry_name) - detector.set_instruments(app.state.instruments) - return TestClient(app)