diff --git a/.markdownlint-cli2.yaml b/.markdownlint-cli2.yaml new file mode 100644 index 000000000..daa55d30f --- /dev/null +++ b/.markdownlint-cli2.yaml @@ -0,0 +1,21 @@ +config: + # Disable line-length rule — markdown prose doesn't benefit from hard wrapping + line-length: false + # Allow duplicate headings — API docs reuse "Request Example", "Response Example", etc. + no-duplicate-heading: false + # Allow emphasis as headings — docs use **bold** as sub-section labels + no-emphasis-as-heading: false + # Allow ordered list prefixes to continue across interruptions (1. 2. 3. not 1. 1. 1.) + ol-prefix: false + # Allow blank lines inside blockquotes — used for readability in docs + no-blanks-blockquote: false + # Allow fenced code blocks without language — not all snippets need syntax highlighting + fenced-code-language: false + # Allow files to start without an h1 — some files start with h2 or frontmatter + first-line-heading: false + # Allow heading level increments to skip (e.g. h2 -> h4) — docs structure varies + heading-increment: false + # Allow multiple top-level headings + single-title: false + # Allow indented code blocks alongside fenced — legacy docs use both styles + code-block-style: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f24718f79..4f781cc0c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,10 +31,18 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: + - id: check-added-large-files + args: ['--maxkb=500'] + - id: check-ast - id: check-json exclude: ^test/cdk/stacks/__baselines__/ + - id: check-toml - id: check-yaml - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable + - id: check-symlinks + - id: debug-statements - id: mixed-line-ending args: ['--fix=lf'] exclude: ^test/cdk/stacks/__baselines__/ @@ -46,14 +54,29 @@ repos: exclude: ^test/cdk/stacks/__baselines__/ - id: trailing-whitespace exclude: ^test/cdk/stacks/__baselines__/ - -- repo: https://github.com/codespell-project/codespell - rev: v2.4.2 + - id: name-tests-test + args: ['--pytest-test-first'] + files: ^test/.*\.py$ + exclude: (^test/cdk/|^test/integration/|__init__\.py$|README|conftest\.py$|integration_definitions\.py$|integration_test_utils\.py$|config_loader\.py$|list-integ-models\.py$|integration-setup-test\.py$) + - id: no-commit-to-branch + args: ['--branch', 'main', '--branch', 'mainline'] + +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.10.0 hooks: - - id: codespell - entry: codespell - args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*/public/*,*poetry.lock*,*coverage*,*models/*,*htmlcov*,*TIKTOKEN_CACHE/*,*test/cdk/stacks/__baselines__/*,*.jsonl', "-L=xdescribe,assertIn,afterAll"] - pass_filenames: false + - id: python-check-mock-methods + - id: python-no-eval + - id: python-no-log-warn + - id: python-use-type-annotations + +# NOTE: shellcheck-py fails to install in environments with custom CA certs +# (SSL: CERTIFICATE_VERIFY_FAILED). Install shellcheck locally and uncomment, +# or use `apt install shellcheck` / `brew install shellcheck` first. 
+# - repo: https://github.com/shellcheck-py/shellcheck-py +# rev: v0.11.0.1 +# hooks: +# - id: shellcheck +# args: ['--severity=warning'] - repo: https://github.com/pycqa/isort rev: 8.0.1 @@ -68,13 +91,37 @@ repos: - id: black exclude: ^test/cdk/stacks/__baselines__/ +# NOTE: docformatter is incompatible with black — it removes the blank line +# before module docstrings, which violates PEP 8 E302 and black's formatting. +# Uncomment if the upstream issue is resolved: +# https://github.com/PyCQA/docformatter/issues +# - repo: https://github.com/PyCQA/docformatter +# rev: v1.7.8 +# hooks: +# - id: docformatter +# args: +# - --in-place +# - --black +# - --wrap-summaries=120 +# - --wrap-descriptions=120 +# - --style=google +# exclude: ^test/cdk/stacks/__baselines__/ + +- repo: https://github.com/codespell-project/codespell + rev: v2.4.2 + hooks: + - id: codespell + entry: codespell + args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*/public/*,*poetry.lock*,*coverage*,*models/*,*htmlcov*,*TIKTOKEN_CACHE/*,*test/cdk/stacks/__baselines__/*,*.jsonl', "-L=xdescribe,assertIn,afterAll"] + pass_filenames: false + - repo: https://github.com/astral-sh/ruff-pre-commit rev: 'v0.15.11' hooks: - id: ruff-check args: - --exit-non-zero-on-fix - - --per-file-ignores=test/**/*.py:E402,test/**/*.py:PLC0415 + - --per-file-ignores=test/**/*.py:E402,test/**/*.py:PLC0415,test/**/*.py:D - --fix exclude: (\.ipynb$|^test/cdk/stacks/__baselines__/) @@ -120,6 +167,20 @@ repos: - --no-warn-ignored - --fix +- repo: https://github.com/DavidAnson/markdownlint-cli2 + rev: v0.17.2 + hooks: + - id: markdownlint-cli2 + args: ['--fix'] + exclude: (^node_modules/|^test/cdk/stacks/__baselines__/|CHANGELOG\.md$) + +- repo: https://github.com/oxipng/oxipng + rev: v10.1.0 + hooks: + - id: oxipng + args: ['-o', '4', '--strip', 'safe'] + exclude: ^(node_modules/|lib/core/layers/) + # - repo: https://github.com/Lucas-C/pre-commit-hooks-safety # rev: v1.3.2 # hooks: diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index ec98f2b76..fd356798d 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -2,4 +2,4 @@ This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact -opensource-codeofconduct@amazon.com with any additional questions or comments. +<opensource-codeofconduct@amazon.com> with any additional questions or comments. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d0649a77a..a25f517a2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -46,7 +46,7 @@ Looking at the existing issues is a great way to find something to contribute on This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact -opensource-codeofconduct@amazon.com with any additional questions or comments. +<opensource-codeofconduct@amazon.com> with any additional questions or comments. ## Security issue notifications diff --git a/SECURITY.md b/SECURITY.md index e5acff157..cda4f3b6e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -22,6 +22,7 @@ Instead, please report security issues by: ### 📝 What to Include Please include the following information: + - Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
- Full paths of source file(s) related to the manifestation of the issue - The location of the affected source code (tag/branch/commit or direct URL) @@ -33,21 +34,25 @@ Please include the following information: ## 🛡️ Security Measures in Place ### **Static Analysis** + - **CodeQL**: Automated security scanning on all pull requests - **Dependency Scanning**: Regular vulnerability detection - **License Compliance**: Automated license validation ### **Dependency Management** + - **Dependabot**: Automated security updates - **Pin Dependencies**: Critical dependencies pinned by hash - **Vulnerability Monitoring**: Continuous monitoring of known CVEs ### **CI/CD Security** + - **Least Privilege**: GitHub Actions use minimal required permissions - **Supply Chain Protection**: All third-party actions pinned by commit hash - **Secure Workflows**: No dangerous workflow patterns ### **Infrastructure Security** + - **Container Security**: Base images pinned to specific digests - **AWS IAM**: Least privilege access controls - **Encryption**: TLS 1.2+ for all communications diff --git a/bin/lisa.ts b/bin/lisa.ts old mode 100644 new mode 100755 diff --git a/cypress/README.md b/cypress/README.md index c53b6d0e7..f62f83b90 100644 --- a/cypress/README.md +++ b/cypress/README.md @@ -3,22 +3,27 @@ We maintain two suites of tests for our application: ### **Smoke Tests** + - *Isolation:* All network calls (including authentication) are fully **mocked out**. - *Purpose:* Quickly verify that core UI components and routes render without hitting any backend. - *Use Case:* Fast, lightweight sanity checks on every code change. ### **End‑to‑End (E2E) Tests** + - *Integration:* Execute complete user flows against a **live API** (or your local dev stack). - *Coverage:* Real authentication, data fetches, and error‑handling paths. - *Goal:* Ensure the entire system (frontend ↔ backend) works seamlessly together. --- + ## Test Setup In `cypress.e2e.config.ts` or `cypress.smoke.config.ts` the following environment variables need to be configured: + - `baseUrl` - set to either `http://localhost:3000/` or the URL of your dev stack (e.g. `https://<api-id>.execute-api.us-east-1.amazonaws.com/Prod/`). -#### Example setup for localhost: +#### Example setup for localhost + ``` e2e: { ...
@@ -27,7 +32,9 @@ e2e: { ``` # Running the tests + If you are running the e2e tests, you will need to add the test account password to your env prior to executing the tests: + ``` export TEST_ACCOUNT_PASSWORD=<password> @@ -35,6 +42,7 @@ npm run cypress:e2e:run ``` You should get output like: + ``` npm run cypress:e2e:run @@ -106,6 +114,7 @@ DevTools listening on ws://127.0.0.1:51352/devtools/browser/2f804c68-414e-4004-9 ``` ## Run tests interactively + ``` npm run cypress:e2e:open ``` @@ -113,6 +122,7 @@ npm run cypress:e2e:open # Linting To ensure that code is meeting the enforced code standards you can run the following command within the `cypress` directory: + ``` npm run lint:fix ``` diff --git a/lambda/api_tokens/domain_objects.py b/lambda/api_tokens/domain_objects.py index 3234ec5dd..e161373d8 100644 --- a/lambda/api_tokens/domain_objects.py +++ b/lambda/api_tokens/domain_objects.py @@ -19,12 +19,12 @@ def default_expiration() -> int: - """Calculate default token expiration (90 days from now)""" + """Calculate default token expiration (90 days from now).""" return now_seconds() + int(timedelta(days=90).total_seconds()) class CreateTokenAdminRequest(BaseModel): - """Admin request to create token for a user or system""" + """Admin request to create token for a user or system.""" tokenExpiration: int = Field( default_factory=default_expiration, description="Unix timestamp when token expires. Defaults to 90 days" ) @@ -43,7 +43,7 @@ def validate_expiration(cls, v: int) -> int: class CreateTokenUserRequest(BaseModel): - """User request to create their own token""" + """User request to create their own token.""" name: str = Field(description="Human-readable name for the token") tokenExpiration: int = Field( @@ -71,7 +71,7 @@ class CreateTokenResponse(BaseModel): class TokenInfo(BaseModel): - """Token information (without the actual token value)""" + """Token information (without the actual token value).""" tokenUUID: str tokenExpiration: int diff --git a/lambda/api_tokens/exception.py b/lambda/api_tokens/exception.py index 36775b997..dccd09317 100644 --- a/lambda/api_tokens/exception.py +++ b/lambda/api_tokens/exception.py @@ -26,12 +26,12 @@ class TokenNotFoundError(Exception): class UnauthorizedError(Exception): - """Raised when user is not authorized to perform an action""" + """Raised when user is not authorized to perform an action.""" pass class ForbiddenError(Exception): - """Raised when user lacks required permissions""" + """Raised when user lacks required permissions.""" pass diff --git a/lambda/api_tokens/handler.py b/lambda/api_tokens/handler.py index 66768b39f..ccead6efb 100644 --- a/lambda/api_tokens/handler.py +++ b/lambda/api_tokens/handler.py @@ -35,13 +35,13 @@ class CreateTokenAdminHandler: - """Admin creates token for any user or system""" + """Admin creates token for any user or system.""" def __init__(self, token_table: Any) -> None: self.token_table = token_table def _get_user_token(self, username: str) -> dict | None: - """Query for existing token by username using GSI""" + """Query for existing token by username using GSI.""" response = self.token_table.query( IndexName="username-index", KeyConditionExpression=Key("username").eq(username), Limit=1 ) @@ -97,13 +97,13 @@ def __call__( class CreateTokenUserHandler: - """User creates their own token""" + """User creates their own token.""" def __init__(self, token_table: Any) -> None: self.token_table = token_table def _get_user_token(self, username: str) -> dict | None: - """Query for existing token by username using GSI""" +
"""Query for existing token by username using GSI.""" response = self.token_table.query( IndexName="username-index", KeyConditionExpression=Key("username").eq(username), Limit=1 ) @@ -156,7 +156,7 @@ def __call__( class ListTokensHandler: - """List tokens - admins see all, users see only their own""" + """List tokens - admins see all, users see only their own.""" def __init__(self, token_table: Any) -> None: self.token_table = token_table @@ -203,7 +203,7 @@ def __call__(self, username: str, is_admin: bool) -> ListTokensResponse: class GetTokenHandler: - """Get specific token details""" + """Get specific token details.""" def __init__(self, token_table: Any) -> None: self.token_table = token_table @@ -263,7 +263,7 @@ def __call__(self, token_uuid: str, username: str, is_admin: bool) -> TokenInfo: class DeleteTokenHandler: - """Delete token - handles both modern and legacy tokens""" + """Delete token - handles both modern and legacy tokens.""" def __init__(self, token_table: Any) -> None: self.token_table = token_table diff --git a/lambda/chat_assistant_stacks/lambda_functions.py b/lambda/chat_assistant_stacks/lambda_functions.py index 26d3964ee..ad018de33 100644 --- a/lambda/chat_assistant_stacks/lambda_functions.py +++ b/lambda/chat_assistant_stacks/lambda_functions.py @@ -81,7 +81,10 @@ def list_stacks(event: dict, context: dict) -> dict: @api_wrapper @admin_only def get_stack(event: dict, context: dict) -> dict: - """Get a single Chat Assistant Stack by stackId. Admin only.""" + """Get a single Chat Assistant Stack by stackId. + + Admin only. + """ stack_id = _get_stack_id(event) try: response = table.get_item(Key={"stackId": stack_id}) @@ -97,7 +100,10 @@ def get_stack(event: dict, context: dict) -> dict: @api_wrapper @admin_only def create(event: dict, context: dict) -> dict: - """Create a new Chat Assistant Stack. Admin only.""" + """Create a new Chat Assistant Stack. + + Admin only. + """ body = json.loads(event["body"], parse_float=Decimal) model = ChatAssistantStackModel(**body) item = model.model_dump(exclude_none=True) @@ -112,7 +118,10 @@ def create(event: dict, context: dict) -> dict: @api_wrapper @admin_only def update(event: dict, context: dict) -> dict: - """Update an existing Chat Assistant Stack. Admin only.""" + """Update an existing Chat Assistant Stack. + + Admin only. + """ stack_id = _get_stack_id(event) body = json.loads(event["body"], parse_float=Decimal) body["stackId"] = stack_id @@ -134,7 +143,10 @@ def update(event: dict, context: dict) -> dict: @api_wrapper @admin_only def delete(event: dict, context: dict) -> dict: - """Delete a Chat Assistant Stack. Admin only.""" + """Delete a Chat Assistant Stack. + + Admin only. + """ stack_id = _get_stack_id(event) try: response = table.delete_item(Key={"stackId": stack_id}, ReturnValues="ALL_OLD") @@ -149,7 +161,10 @@ def delete(event: dict, context: dict) -> dict: @api_wrapper @admin_only def update_status(event: dict, context: dict) -> dict: - """Update isActive (activate/deactivate) for a stack. Admin only.""" + """Update isActive (activate/deactivate) for a stack. + + Admin only. 
+ """ stack_id = _get_stack_id(event) body = json.loads(event.get("body") or "{}", parse_float=Decimal) is_active = body.get("isActive") diff --git a/lambda/management_key.py b/lambda/management_key.py index 55ba3abc8..29ecdf4c9 100644 --- a/lambda/management_key.py +++ b/lambda/management_key.py @@ -34,8 +34,7 @@ def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - AWS Secrets Manager rotation handler for management key. + """AWS Secrets Manager rotation handler for management key. This function implements the standard AWS Secrets Manager rotation workflow: 1. createSecret: Generate new secret version @@ -74,9 +73,7 @@ def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: def create_secret(secret_arn: str, token: str) -> None: - """ - Create a new secret version with a new randomly generated password. - """ + """Create a new secret version with a new randomly generated password.""" logger.info(f"Creating new secret version for {secret_arn}") try: @@ -106,18 +103,18 @@ def create_secret(secret_arn: str, token: str) -> None: def set_secret(secret_arn: str, token: str) -> None: - """ - Set the secret in the service that the secret belongs to. - For management keys, this step is typically a no-op since the secret - is used by the application directly from Secrets Manager. + """Set the secret in the service that the secret belongs to. + + For management keys, this step is typically a no-op since the secret is used by the application directly from + Secrets Manager. """ logger.info(f"Setting secret for {secret_arn} - No action needed for management key") # No action needed for management keys as they are retrieved directly from Secrets Manager def test_secret(secret_arn: str, token: str) -> None: - """ - Test the new secret to ensure it's valid and can be used. + """Test the new secret to ensure it's valid and can be used. + For management keys, we verify the secret can be retrieved and has the expected format. """ logger.info(f"Testing secret for {secret_arn}") @@ -147,9 +144,7 @@ def test_secret(secret_arn: str, token: str) -> None: def finish_secret(secret_arn: str, token: str) -> None: - """ - Finish the rotation by marking the new secret as current. - """ + """Finish the rotation by marking the new secret as current.""" logger.info(f"Finishing secret rotation for {secret_arn}") try: @@ -181,9 +176,7 @@ def finish_secret(secret_arn: str, token: str) -> None: def publish_rotation_event(secret_arn: str, new_version: str, old_version: str | None) -> None: - """ - Publish a management key rotation event to EventBridge. - """ + """Publish a management key rotation event to EventBridge.""" event_bus_name = os.environ.get("EVENT_BUS_NAME") if not event_bus_name: logger.warning("EVENT_BUS_NAME environment variable not set, skipping event publication") diff --git a/lambda/mcp_server/lambda_functions.py b/lambda/mcp_server/lambda_functions.py index 4f680650a..2af843c86 100644 --- a/lambda/mcp_server/lambda_functions.py +++ b/lambda/mcp_server/lambda_functions.py @@ -620,9 +620,7 @@ def update_hosted_mcp_server(event: dict, context: dict) -> Any: @api_wrapper def list_bedrock_agents(event: dict, context: dict) -> dict[str, Any]: - """ - List admin-approved Bedrock agents visible to this user, merged with live AWS discovery. 
- """ + """List admin-approved Bedrock agents visible to this user, merged with live AWS discovery.""" _user_id, is_admin, groups = get_user_context(event) logger.info("Listing approved Bedrock agents for catalog") @@ -717,9 +715,7 @@ def delete_bedrock_agent_approval(event: dict, context: dict) -> dict[str, str]: @api_wrapper def invoke_bedrock_agent(event: dict, context: dict) -> dict[str, Any]: - """ - Invoke a Bedrock Agent via bedrock-agent-runtime and return aggregated text output. - """ + """Invoke a Bedrock Agent via bedrock-agent-runtime and return aggregated text output.""" user_id, is_admin_user, groups = get_user_context(event) body = json.loads(event.get("body") or "{}") request = InvokeBedrockAgentRequest(**body) diff --git a/lambda/mcp_server/models.py b/lambda/mcp_server/models.py index 5a0d7ce36..520ba2b52 100644 --- a/lambda/mcp_server/models.py +++ b/lambda/mcp_server/models.py @@ -42,8 +42,8 @@ class McpServerStatus(StrEnum): class McpServerModel(BaseModel): - """ - A Pydantic model representing a template for prompts. + """A Pydantic model representing a template for prompts. + Contains metadata and functionality to create new revisions. """ @@ -127,8 +127,8 @@ class AutoScalingConfigUpdate(BaseModel): class HostedMcpServerModel(BaseModel): - """ - A Pydantic model representing a hosted MCP server configuration. + """A Pydantic model representing a hosted MCP server configuration. + This model is used for creating MCP servers that are deployed on ECS Fargate. """ diff --git a/lambda/mcp_server/state_machine/update_mcp_server.py b/lambda/mcp_server/state_machine/update_mcp_server.py index 11a60461f..1760b27f9 100644 --- a/lambda/mcp_server/state_machine/update_mcp_server.py +++ b/lambda/mcp_server/state_machine/update_mcp_server.py @@ -158,8 +158,7 @@ def _get_metadata_update_handlers(server_config: dict[str, Any], server_id: str) def _process_metadata_updates( server_config: dict[str, Any], update_payload: dict[str, Any], server_id: str ) -> tuple[bool, dict[str, Any]]: - """ - Process metadata updates. + """Process metadata updates. Args: server_config: The server configuration dictionary to update @@ -301,8 +300,7 @@ def _update_mcp_connections_table_metadata( def handle_job_intake(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Handle initial UpdateMcpServer job submission. + """Handle initial UpdateMcpServer job submission. This handler will perform the following actions: 1. Determine if any metadata (description, groups, environment, etc.) changes are required @@ -743,8 +741,7 @@ def update_ecs_service(cluster_arn: str, service_arn: str, task_definition_arn: def handle_ecs_update(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Update ECS task definition with new environment variables and update service. + """Update ECS task definition with new environment variables and update service. This handler will: 1. Retrieve current task definition from ECS @@ -809,8 +806,7 @@ def handle_ecs_update(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_poll_ecs_deployment(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Monitor ECS service deployment progress. + """Monitor ECS service deployment progress. This handler will: 1. 
Check if ECS service deployment is complete @@ -903,8 +899,7 @@ def handle_poll_ecs_deployment(event: dict[str, Any], context: Any) -> dict[str, def handle_poll_capacity(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Poll ECS service to confirm if the capacity is done updating. + """Poll ECS service to confirm if the capacity is done updating. This handler will: 1. Get the ECS service's current status. If it is still updating, then exit with a @@ -950,8 +945,7 @@ def handle_poll_capacity(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_finish_update(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Finalize update in DDB. + """Finalize update in DDB. 1. If the server was enabled from the Stopped state, update MCP Connections table to ACTIVE, set status to InService in DDB diff --git a/lambda/mcp_workbench/lambda_functions.py b/lambda/mcp_workbench/lambda_functions.py index 2db827b47..b5610f6b4 100644 --- a/lambda/mcp_workbench/lambda_functions.py +++ b/lambda/mcp_workbench/lambda_functions.py @@ -293,7 +293,7 @@ def validate_syntax(event: dict, context: dict) -> dict[str, Any]: "validation_timestamp": iso_string(), } - logger.info(f"Validation completed. Valid: {result.is_valid}, " f"Errors: {len(result.syntax_errors)}") + logger.info(f"Validation completed. Valid: {result.is_valid}, Errors: {len(result.syntax_errors)}") return response diff --git a/lambda/mcp_workbench/mcp_mocks.py b/lambda/mcp_workbench/mcp_mocks.py index 2f0f040b2..fd9e2df8c 100644 --- a/lambda/mcp_workbench/mcp_mocks.py +++ b/lambda/mcp_workbench/mcp_mocks.py @@ -14,10 +14,9 @@ """Mock implementations of MCP Workbench core components for validation purposes. -These mocks are used by the syntax validator to allow user code to import -and use MCP Workbench constructs without needing the full MCP Workbench -package installed. They provide just enough functionality to validate -the structure and usage of MCP tools. +These mocks are used by the syntax validator to allow user code to import and use MCP Workbench constructs without +needing the full MCP Workbench package installed. They provide just enough functionality to validate the structure and +usage of MCP tools. """ from abc import ABC, abstractmethod @@ -27,17 +26,14 @@ class BaseTool(ABC): - """ - Mock BaseTool for validation purposes. + """Mock BaseTool for validation purposes. - This provides the same interface as the real BaseTool class, - allowing validation of class-based MCP tools without requiring - the full MCP Workbench package. + This provides the same interface as the real BaseTool class, allowing validation of class-based MCP tools without + requiring the full MCP Workbench package. """ def __init__(self, name: str, description: str): - """ - Initialize the tool with required metadata. + """Initialize the tool with required metadata. Args: name: The name of the tool @@ -48,8 +44,7 @@ def __init__(self, name: str, description: str): @abstractmethod async def execute(self) -> Callable[..., Any]: - """ - Returns a function to be executed as the tool. + """Returns a function to be executed as the tool. Returns: The function to be executed @@ -58,8 +53,7 @@ async def execute(self) -> Callable[..., Any]: def mcp_tool(name: str, description: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - """ - Mock mcp_tool decorator for validation purposes. + """Mock mcp_tool decorator for validation purposes. 
This provides the same interface as the real mcp_tool decorator, allowing validation of function-based MCP tools without requiring diff --git a/lambda/mcp_workbench/s3_event_handler.py b/lambda/mcp_workbench/s3_event_handler.py index 4f639b63f..679c85090 100644 --- a/lambda/mcp_workbench/s3_event_handler.py +++ b/lambda/mcp_workbench/s3_event_handler.py @@ -52,8 +52,7 @@ def _management_bearer_token() -> str | None: def trigger_workbench_rescan() -> dict[str, Any]: - """ - GET the workbench rescan route (same app management key as OIDC middleware when auth is on). + """GET the workbench rescan route (same app management key as OIDC middleware when auth is on). Waits briefly so rclone --poll-interval can surface new S3 keys in the tools mount. """ @@ -99,8 +98,7 @@ def trigger_workbench_rescan() -> dict[str, Any]: def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Handle S3 events from EventBridge: call MCP Workbench HTTP rescan (in-VPC). + """Handle S3 events from EventBridge: call MCP Workbench HTTP rescan (in-VPC). Reloads tools from the rclone-mounted bucket without restarting ECS tasks. """ @@ -138,9 +136,7 @@ def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: def validate_s3_event(event: dict[str, Any]) -> bool: - """ - Validate that the event is a proper S3 event from EventBridge. - """ + """Validate that the event is a proper S3 event from EventBridge.""" try: source = event.get("source") detail_type = event.get("detail-type") diff --git a/lambda/mcp_workbench/syntax_validator.py b/lambda/mcp_workbench/syntax_validator.py index af23a6586..ae6da3c2b 100644 --- a/lambda/mcp_workbench/syntax_validator.py +++ b/lambda/mcp_workbench/syntax_validator.py @@ -60,12 +60,9 @@ def exec_module(self, module: ModuleType) -> None: class _McpWorkbenchStubFinder(importlib.abc.MetaPathFinder): """Auto-stub any ``mcpworkbench.*`` import that hasn't already been mocked. - During Lambda-based validation we only have explicit mocks for - ``mcpworkbench.core.*``. Tools may import from other subpackages - (e.g. ``mcpworkbench.aws.*``) that don't exist in the Lambda - environment. This finder intercepts those imports and returns - lightweight stub modules so validation can proceed without - ImportErrors. + During Lambda-based validation we only have explicit mocks for ``mcpworkbench.core.*``. Tools may import from other + subpackages (e.g. ``mcpworkbench.aws.*``) that don't exist in the Lambda environment. This finder intercepts those + imports and returns lightweight stub modules so validation can proceed without ImportErrors. """ _PREFIX = "mcpworkbench." @@ -100,8 +97,7 @@ def __init__(self) -> None: self.max_code_size = 100_000 # 100KB def validate_code(self, code: str) -> ValidationResult: - """ - Validate Python code for syntax and required imports. + """Validate Python code for syntax and required imports. Args: code: Python code string to validate diff --git a/lambda/metrics/batch_job_metric.py b/lambda/metrics/batch_job_metric.py index de2cd7792..c9e305d8f 100644 --- a/lambda/metrics/batch_job_metric.py +++ b/lambda/metrics/batch_job_metric.py @@ -14,10 +14,9 @@ """Lambda handler for publishing CloudWatch metrics on Batch job state changes. -Captures SUBMITTED, RUNNING, SUCCEEDED, and FAILED state transitions from -EventBridge and publishes corresponding metrics to the LISA/BatchIngestion -namespace. This provides queue-level visibility regardless of how the -ingestion job was triggered (S3 event, scheduled, or manual upload). 
+Captures SUBMITTED, RUNNING, SUCCEEDED, and FAILED state transitions from EventBridge and publishes corresponding +metrics to the LISA/BatchIngestion namespace. This provides queue-level visibility regardless of how the ingestion job +was triggered (S3 event, scheduled, or manual upload). """ import json diff --git a/lambda/models/clients/litellm_client.py b/lambda/models/clients/litellm_client.py index 07afea04a..90d09057f 100644 --- a/lambda/models/clients/litellm_client.py +++ b/lambda/models/clients/litellm_client.py @@ -32,8 +32,7 @@ def __init__(self, base_uri: str, headers: Headers, verify: str | bool, timeout: self._verify = verify def list_models(self) -> list[dict[str, Any]]: - """ - Retrieve all models from the database. + """Retrieve all models from the database. Note, this is a superset of the models that are visible from the OpenAI API call to /models. If multiple models are defined with the same model name, only one will show in the OpenAI API call because of the model name, but @@ -50,14 +49,13 @@ def list_models(self) -> list[dict[str, Any]]: return models_list def add_model(self, model_name: str, litellm_params: dict[str, str]) -> dict[str, Any]: - """ - Add a new model configuration to the database. + """Add a new model configuration to the database. - The parameters for this method will be used for defining how LiteLLM accesses a model between both the model - and the litellm_params dictionary, and anything that is not LiteLLM-specific can be added to the - additional_metadata dictionary. Because LiteLLM uses this ID instead of other data, it is possible to add - two models with the same name, which causes ambiguous results when using the OpenAI API for listing models as - that only shows one model per model name. + The parameters for this method will be used for defining how LiteLLM accesses a model between both the model and + the litellm_params dictionary, and anything that is not LiteLLM-specific can be added to the additional_metadata + dictionary. Because LiteLLM uses this ID instead of other data, it is possible to add two models with the same + name, which causes ambiguous results when using the OpenAI API for listing models as that only shows one model + per model name. """ resp = requests.post( self._base_uri + "/model/new", @@ -72,11 +70,10 @@ def add_model(self, model_name: str, litellm_params: dict[str, str]) -> dict[str return resp.json() # type: ignore [no-any-return] def delete_model(self, identifier: str) -> None: - """ - Delete a model from the database. + """Delete a model from the database. - The identifier is the ID that LiteLLM generates on its end when creating a model, regardless of if the model - was defined in a static configuration file or if it was added dynamically. + The identifier is the ID that LiteLLM generates on its end when creating a model, regardless of if the model was + defined in a static configuration file or if it was added dynamically. """ requests.post( self._base_uri + "/model/delete", @@ -87,8 +84,7 @@ def delete_model(self, identifier: str) -> None: ) def get_model(self, identifier: str) -> dict[str, Any]: - """ - Get model metadata from the database. + """Get model metadata from the database. Due to what appears to be a bug in LiteLLM when accessing individual models from the /model/info route, we must list all models then filter out the one we want for this method call. 
@@ -100,8 +96,7 @@ def get_model(self, identifier: str) -> dict[str, Any]: return filtered_models[0] def create_guardrail(self, guardrail_config: dict[str, Any]) -> dict[str, Any]: - """ - Create a new guardrail configuration in LiteLLM. + """Create a new guardrail configuration in LiteLLM. Args: guardrail_config: Dictionary containing guardrail configuration including @@ -121,8 +116,7 @@ def create_guardrail(self, guardrail_config: dict[str, Any]) -> dict[str, Any]: return resp.json() # type: ignore [no-any-return] def update_guardrail(self, guardrail_id: str, guardrail_config: dict[str, Any]) -> dict[str, Any]: - """ - Update an existing guardrail configuration in LiteLLM. + """Update an existing guardrail configuration in LiteLLM. Args: guardrail_id: The LiteLLM guardrail ID to update @@ -142,8 +136,7 @@ def update_guardrail(self, guardrail_id: str, guardrail_config: dict[str, Any]) return resp.json() # type: ignore [no-any-return] def delete_guardrail(self, guardrail_id: str) -> None: - """ - Delete a guardrail configuration from LiteLLM. + """Delete a guardrail configuration from LiteLLM. Args: guardrail_id: The LiteLLM guardrail ID to delete @@ -157,8 +150,7 @@ def delete_guardrail(self, guardrail_id: str) -> None: resp.raise_for_status() def get_guardrail_info(self, guardrail_id: str) -> dict[str, Any]: - """ - Get information about a specific guardrail. + """Get information about a specific guardrail. Args: guardrail_id: The LiteLLM guardrail ID to retrieve @@ -176,8 +168,7 @@ def get_guardrail_info(self, guardrail_id: str) -> dict[str, Any]: return resp.json() # type: ignore [no-any-return] def apply_guardrail(self, guardrail_name: str, text: str) -> dict[str, Any]: - """ - Apply a guardrail to text content for validation. + """Apply a guardrail to text content for validation. 
Args: guardrail_name: Name of the guardrail to apply diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index 41ac8c9b4..4a6219e94 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -163,10 +163,10 @@ class LoadBalancerConfig(BaseModel): class ScheduleType(str, Enum): - """Defines supported schedule types for resource scheduling""" + """Defines supported schedule types for resource scheduling.""" def __str__(self) -> str: - """Returns string representation of the enum value""" + """Returns string representation of the enum value.""" return str(self.value) DAILY = "DAILY" @@ -174,7 +174,7 @@ def __str__(self) -> str: class DaySchedule(BaseModel): - """Defines start and stop times for a single day""" + """Defines start and stop times for a single day.""" startTime: str = Field(pattern=r"^([01]?[0-9]|2[0-3]):[0-5][0-9]$") stopTime: str = Field(pattern=r"^([01]?[0-9]|2[0-3]):[0-5][0-9]$") @@ -182,9 +182,8 @@ class DaySchedule(BaseModel): @field_validator("startTime", "stopTime") @classmethod def validate_time_format(cls, v: str) -> str: - """Validates time format is HH:MM""" + """Validates time format is HH:MM.""" try: - datetime.strptime(v, "%H:%M") except ValueError: raise ValueError("Time must be in HH:MM format") @@ -192,8 +191,7 @@ def validate_time_format(cls, v: str) -> str: @model_validator(mode="after") def validate_stop_after_start(self) -> Self: - """Validates that stop time is after start time and at least 2 hours later""" - + """Validates that stop time is after start time and at least 2 hours later.""" start_time = datetime.strptime(self.startTime, "%H:%M").time() stop_time = datetime.strptime(self.stopTime, "%H:%M").time() @@ -221,7 +219,7 @@ def validate_stop_after_start(self) -> Self: class WeeklySchedule(BaseModel): - """Defines schedule for each day of the week with one start/stop time per day""" + """Defines schedule for each day of the week with one start/stop time per day.""" monday: DaySchedule | None = None tuesday: DaySchedule | None = None @@ -233,7 +231,7 @@ class WeeklySchedule(BaseModel): @model_validator(mode="after") def validate_daily_schedules(self) -> Self: - """Validates that at least one day has a schedule configured""" + """Validates that at least one day has a schedule configured.""" days = [self.monday, self.tuesday, self.wednesday, self.thursday, self.friday, self.saturday, self.sunday] if not any(days): @@ -243,14 +241,14 @@ def validate_daily_schedules(self) -> Self: class NextScheduledAction(BaseModel): - """Defines the next scheduled action for a model""" + """Defines the next scheduled action for a model.""" action: str = Field(pattern=r"^(START|STOP)$") scheduledTime: str class ScheduleFailure(BaseModel): - """Defines schedule failure information""" + """Defines schedule failure information.""" timestamp: str error: str @@ -258,7 +256,7 @@ class ScheduleFailure(BaseModel): class BaseSchedulingConfig(BaseModel): - """Base configuration shared by all scheduling types""" + """Base configuration shared by all scheduling types.""" timezone: str = Field(default="UTC") @@ -280,7 +278,7 @@ class BaseSchedulingConfig(BaseModel): @field_validator("timezone") @classmethod def validate_timezone(cls, v: str) -> str: - """Validates timezone is a valid IANA timezone identifier""" + """Validates timezone is a valid IANA timezone identifier.""" try: ZoneInfo(v) except Exception: @@ -289,14 +287,14 @@ def validate_timezone(cls, v: str) -> str: class DailySchedulingConfig(BaseSchedulingConfig): - 
"""Configuration for daily schedules with different times per day""" + """Configuration for daily schedules with different times per day.""" scheduleType: Literal["DAILY"] = "DAILY" dailySchedule: WeeklySchedule @model_validator(mode="after") def validate_daily_schedule_exclusivity(self) -> Self: - """Validates that only dailySchedule is present for DAILY type""" + """Validates that only dailySchedule is present for DAILY type.""" # Check if any recurring schedule data was included if hasattr(self, "recurringSchedule"): raise ValueError("recurringSchedule not allowed for DAILY schedule type") @@ -304,14 +302,14 @@ def validate_daily_schedule_exclusivity(self) -> Self: class RecurringSchedulingConfig(BaseSchedulingConfig): - """Configuration for recurring schedules with same time every day""" + """Configuration for recurring schedules with same time every day.""" scheduleType: Literal["RECURRING"] = "RECURRING" recurringSchedule: DaySchedule @model_validator(mode="after") def validate_recurring_schedule_exclusivity(self) -> Self: - """Validates that only recurringSchedule is present for RECURRING type""" + """Validates that only recurringSchedule is present for RECURRING type.""" # Check if any daily schedule data was included if hasattr(self, "dailySchedule"): raise ValueError("dailySchedule not allowed for RECURRING schedule type") @@ -739,7 +737,7 @@ def validate_overlap(self) -> Self: """Validates overlap is not more than half of chunk size.""" if self.overlap > self.size / 2: raise ValueError( - f"chunk overlap ({self.overlap}) must be less than or equal to " f"half of chunk size ({self.size / 2})" + f"chunk overlap ({self.overlap}) must be less than or equal to half of chunk size ({self.size / 2})" ) return self @@ -1067,7 +1065,6 @@ def from_query_params(query_params: dict[str, str]) -> SortParams: Raises: ValidationError: If sortBy or sortOrder values are invalid """ - sort_by = CollectionSortBy.CREATED_AT if "sortBy" in query_params: try: @@ -1365,9 +1362,9 @@ class OpenSearchExistingClusterConfig(BaseModel): class RdsInstanceConfig(BaseModel): """Configuration schema for RDS Instances needed for LiteLLM scaling or PGVector RAG operations. - The optional fields can be omitted to create a new database instance, otherwise fill in all fields - to use an existing database instance. By default, IAM authentication is used. Set iamRdsAuth - to false in config to use password-based authentication. + The optional fields can be omitted to create a new database instance, otherwise fill in all fields to use an + existing database instance. By default, IAM authentication is used. Set iamRdsAuth to false in config to use + password-based authentication. 
""" username: str = Field(default="postgres", description="The username used for database connection.") diff --git a/lambda/models/handler/schedule_handlers.py b/lambda/models/handler/schedule_handlers.py index fda0a115f..5d14c5ff5 100644 --- a/lambda/models/handler/schedule_handlers.py +++ b/lambda/models/handler/schedule_handlers.py @@ -28,7 +28,7 @@ class ScheduleBaseHandler(BaseApiHandler): - """Base handler for schedule operations""" + """Base handler for schedule operations.""" def __init__( self, @@ -37,12 +37,12 @@ def __init__( model_table_resource: Any, guardrails_table_resource: Any, ): - """Initialize schedule handler""" + """Initialize schedule handler.""" super().__init__(autoscaling_client, stepfunctions_client, model_table_resource, guardrails_table_resource) class UpdateScheduleHandler(ScheduleBaseHandler): - """Handler class for UpdateSchedule requests""" + """Handler class for UpdateSchedule requests.""" def __call__( self, @@ -51,7 +51,7 @@ def __call__( user_groups: list[str] | None = None, is_admin: bool = False, ) -> UpdateScheduleResponse: - """Create or update a schedule for a model""" + """Create or update a schedule for a model.""" # Validate model exists, user access, and model status model_item = get_model_and_validate_status( self._model_table, model_id, user_groups=user_groups, is_admin=is_admin @@ -83,12 +83,12 @@ def __call__( class GetScheduleHandler(ScheduleBaseHandler): - """Handler class for GetSchedule requests""" + """Handler class for GetSchedule requests.""" def __call__( self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> GetScheduleResponse: - """Get current schedule configuration for a model""" + """Get current schedule configuration for a model.""" # Validate model exists and user access get_model_and_validate_access(self._model_table, model_id, user_groups, is_admin) @@ -108,12 +108,12 @@ def __call__( class DeleteScheduleHandler(ScheduleBaseHandler): - """Handler class for DeleteSchedule requests""" + """Handler class for DeleteSchedule requests.""" def __call__( self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> DeleteScheduleResponse: - """Delete a schedule for a model""" + """Delete a schedule for a model.""" # Validate model exists, user access, and model status get_model_and_validate_status(self._model_table, model_id, user_groups=user_groups, is_admin=is_admin) @@ -129,12 +129,12 @@ def __call__( class GetScheduleStatusHandler(ScheduleBaseHandler): - """Handler class for GetScheduleStatus requests""" + """Handler class for GetScheduleStatus requests.""" def __call__( self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> GetScheduleStatusResponse: - """Get current schedule status and next scheduled action for a model""" + """Get current schedule status and next scheduled action for a model.""" # Validate model exists and user access model_item = get_model_and_validate_access(self._model_table, model_id, user_groups, is_admin) diff --git a/lambda/models/handler/utils.py b/lambda/models/handler/utils.py index e93b458ad..c2f5f2411 100644 --- a/lambda/models/handler/utils.py +++ b/lambda/models/handler/utils.py @@ -41,8 +41,7 @@ def to_lisa_model(model_dict: dict[str, Any]) -> LISAModel: def get_model_and_validate_access( model_table: Any, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> dict[str, Any]: - """ - Get model from DynamoDB and validate user access + """Get model from DynamoDB and validate user 
access. Args: model_table: DynamoDB table resource @@ -82,8 +81,7 @@ def get_model_and_validate_status( user_groups: list[str] | None = None, is_admin: bool = False, ) -> dict[str, Any]: - """ - Get model from DynamoDB, validate user access, and check model status + """Get model from DynamoDB, validate user access, and check model status. Args: model_table: DynamoDB table resource diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index 4ca47f6ac..2356a6d36 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -319,8 +319,8 @@ async def update_context_window( ) -> UpdateContextWindowResponse: """Override the context window for a specific model. - Useful when automatic enrichment during model creation failed, or when - the stored value is incorrect and needs to be corrected. + Useful when automatic enrichment during model creation failed, or when the stored value is incorrect and needs to be + corrected. """ handler = UpdateContextWindowHandler( autoscaling_client=autoscaling, @@ -342,7 +342,7 @@ async def update_schedule( schedule_config: SchedulingConfig, request: Request, ) -> UpdateScheduleResponse: - """Endpoint to create or update a schedule for a model""" + """Endpoint to create or update a schedule for a model.""" admin_status, user_groups = get_admin_status_and_groups(request) update_schedule_handler = UpdateScheduleHandler( @@ -361,7 +361,7 @@ async def update_schedule( async def get_schedule( model_id: Annotated[str, Path(title="The unique model ID of the model to get schedule for")], request: Request ) -> GetScheduleResponse: - """Endpoint to get current schedule configuration for a model""" + """Endpoint to get current schedule configuration for a model.""" get_schedule_handler = GetScheduleHandler( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, @@ -389,7 +389,7 @@ async def get_schedule( async def delete_schedule( model_id: Annotated[str, Path(title="The unique model ID of the model to delete schedule for")], request: Request ) -> DeleteScheduleResponse: - """Endpoint to delete a schedule for a model""" + """Endpoint to delete a schedule for a model.""" admin_status, user_groups = get_admin_status_and_groups(request) delete_schedule_handler = DeleteScheduleHandler( @@ -407,7 +407,7 @@ async def get_schedule_status( model_id: Annotated[str, Path(title="The unique model ID of the model to get schedule status for")], request: Request, ) -> GetScheduleStatusResponse: - """Endpoint to get current schedule status and next scheduled action for a model""" + """Endpoint to get current schedule status and next scheduled action for a model.""" get_schedule_status_handler = GetScheduleStatusHandler( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, diff --git a/lambda/models/litellm_model_sync.py b/lambda/models/litellm_model_sync.py index 8dfe37685..9038d4183 100644 --- a/lambda/models/litellm_model_sync.py +++ b/lambda/models/litellm_model_sync.py @@ -14,11 +14,11 @@ """Lambda handler for syncing all models from DynamoDB to LiteLLM. -This Lambda is triggered when the LiteLLM PostgreSQL database is created or updated, -ensuring all models in the Models DynamoDB table are registered in LiteLLM. +This Lambda is triggered when the LiteLLM PostgreSQL database is created or updated, ensuring all models in the Models +DynamoDB table are registered in LiteLLM. 
-Note: This module intentionally does NOT import from models.state_machine.create_model -to avoid requiring GUARDRAILS_TABLE_NAME at module load time. +Note: This module intentionally does NOT import from models.state_machine.create_model to avoid requiring +GUARDRAILS_TABLE_NAME at module load time. """ import json diff --git a/lambda/models/model_api_key_cleanup.py b/lambda/models/model_api_key_cleanup.py index 0555f7d6f..f0a0030e4 100644 --- a/lambda/models/model_api_key_cleanup.py +++ b/lambda/models/model_api_key_cleanup.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Model API Key Cleanup Lambda +"""Model API Key Cleanup Lambda. This Lambda function removes the api_key field from existing Bedrock models that were created with the old LiteLLM version that required api_key = "ignored". # pragma: allowlist secret @@ -145,8 +144,7 @@ def get_database_connection() -> Any: def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Lambda handler for Bedrock model API key cleanup. + """Lambda handler for Bedrock model API key cleanup. Only processes models with modelName prefixed with "bedrock/". diff --git a/lambda/models/model_context_window_backfill.py b/lambda/models/model_context_window_backfill.py index 7743affae..5643bebb7 100644 --- a/lambda/models/model_context_window_backfill.py +++ b/lambda/models/model_context_window_backfill.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Context Window Backfill Lambda +"""Context Window Backfill Lambda. This Lambda function retroactively enriches existing model DynamoDB records with a context_window value. It runs once automatically during CDK deployment @@ -81,8 +80,8 @@ def _fetch_context_window_from_litellm( ) -> int | None: """Return max_input_tokens from LiteLLM for a non-LISA-managed model. - Falls back to list_models() filtered by model_id when litellm_id is absent - (pre-existing records created before litellm_id was stored). + Falls back to list_models() filtered by model_id when litellm_id is absent (pre-existing records created before + litellm_id was stored). """ if litellm_id: try: @@ -222,12 +221,10 @@ def _run_backfill() -> dict[str, Any]: def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - CloudFormation CustomResource handler for context window backfill. + """CloudFormation CustomResource handler for context window backfill. - Runs the backfill exactly once on Create. Update and Delete are no-ops. - The static PhysicalResourceId ensures CloudFormation never re-creates - or replaces this resource across subsequent deployments. + Runs the backfill exactly once on Create. Update and Delete are no-ops. The static PhysicalResourceId ensures + CloudFormation never re-creates or replaces this resource across subsequent deployments. """ request_type = event.get("RequestType", "") logger.info(f"context-window-backfill invoked: RequestType={request_type}") diff --git a/lambda/models/scheduling/__init__.py b/lambda/models/scheduling/__init__.py index beefa8459..682a95fe6 100644 --- a/lambda/models/scheduling/__init__.py +++ b/lambda/models/scheduling/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Scheduling module for LISA resource management""" +"""Scheduling module for LISA resource management.""" diff --git a/lambda/models/scheduling/schedule_management.py b/lambda/models/scheduling/schedule_management.py index 4ecbc2aa2..5adef85de 100644 --- a/lambda/models/scheduling/schedule_management.py +++ b/lambda/models/scheduling/schedule_management.py @@ -44,7 +44,7 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """Main Lambda handler for schedule management operations""" + """Main Lambda handler for schedule management operations.""" try: logger.info(f"Processing schedule management request: {json.dumps(event, default=str)}") @@ -71,7 +71,7 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: def update_schedule(event: dict[str, Any]) -> dict[str, Any]: - """Update an existing schedule for a model""" + """Update an existing schedule for a model.""" model_id = event["modelId"] schedule_config = event.get("scheduleConfig") auto_scaling_group = event.get("autoScalingGroup") @@ -127,7 +127,7 @@ def update_schedule(event: dict[str, Any]) -> dict[str, Any]: def delete_schedule(event: dict[str, Any]) -> dict[str, Any]: - """Delete a schedule for a model""" + """Delete a schedule for a model.""" model_id = event["modelId"] try: @@ -167,7 +167,7 @@ def delete_schedule(event: dict[str, Any]) -> dict[str, Any]: def get_schedule(event: dict[str, Any]) -> dict[str, Any]: - """Get current schedule configuration for a model""" + """Get current schedule configuration for a model.""" model_id = event["modelId"] try: @@ -200,7 +200,7 @@ def get_schedule(event: dict[str, Any]) -> dict[str, Any]: def create_scheduled_actions(model_id: str, auto_scaling_group: str, schedule_config: SchedulingConfig) -> list[str]: - """Create Auto Scaling scheduled actions based on schedule configuration""" + """Create Auto Scaling scheduled actions based on schedule configuration.""" scheduled_action_arns = [] if schedule_config.scheduleType == ScheduleType.RECURRING: @@ -223,7 +223,7 @@ def create_scheduled_actions(model_id: str, auto_scaling_group: str, schedule_co def create_scheduling_config(schedule_data: dict[str, Any]) -> SchedulingConfig: - """Create the appropriate scheduling config instance based on scheduleType""" + """Create the appropriate scheduling config instance based on scheduleType.""" schedule_type = schedule_data.get("scheduleType") if schedule_type == ScheduleType.DAILY: @@ -235,7 +235,7 @@ def create_scheduling_config(schedule_data: dict[str, Any]) -> SchedulingConfig: def get_existing_asg_capacity(auto_scaling_group: str) -> dict[str, int]: - """Get the existing Auto Scaling Group's current capacity configuration""" + """Get the existing Auto Scaling Group's current capacity configuration.""" try: response = autoscaling_client.describe_auto_scaling_groups(AutoScalingGroupNames=[auto_scaling_group]) @@ -256,7 +256,7 @@ def get_existing_asg_capacity(auto_scaling_group: str) -> dict[str, int]: def get_model_baseline_capacity(model_id: str) -> dict[str, int]: - """Get the baseline capacity configuration from the model's DynamoDB record""" + """Get the baseline capacity configuration from the model's DynamoDB record.""" try: response = model_table.get_item(Key={"model_id": model_id}) @@ -293,7 +293,7 @@ def get_model_baseline_capacity(model_id: str) -> dict[str, int]: def check_daily_immediate_scaling( auto_scaling_group: str, daily_schedule: WeeklySchedule, timezone_name: str, model_id: str ) -> None: - """Check current day and time, scale 
ASG down to 0 if outside scheduled window for daily schedules""" + """Check current day and time, scale ASG down to 0 if outside scheduled window for daily schedules.""" try: tz = ZoneInfo(timezone_name) now = datetime.now(tz) @@ -330,7 +330,7 @@ def check_daily_immediate_scaling( def scale_immediately(auto_scaling_group: str, day_schedule: DaySchedule, timezone_name: str, model_id: str) -> None: - """Check current time and immediately scale ASG up or down based on scheduled window""" + """Check current time and immediately scale ASG up or down based on scheduled window.""" try: tz = ZoneInfo(timezone_name) now = datetime.now(tz) @@ -393,7 +393,7 @@ def scale_immediately(auto_scaling_group: str, day_schedule: DaySchedule, timezo def create_recurring_scheduled_actions( model_id: str, auto_scaling_group: str, day_schedule: DaySchedule, timezone_name: str ) -> list[str]: - """Create scheduled actions for recurring schedule""" + """Create scheduled actions for recurring schedule.""" scheduled_action_arns = [] # Get baseline capacity config from model DDB record @@ -461,7 +461,7 @@ def create_recurring_scheduled_actions( def create_daily_scheduled_actions( model_id: str, auto_scaling_group: str, daily_schedule: WeeklySchedule, timezone_name: str ) -> list[str]: - """Create scheduled actions for daily schedule (different times each day with one start/stop time per day)""" + """Create scheduled actions for daily schedule (different times each day with one start/stop time per day).""" scheduled_action_arns = [] # Get baseline capacity config from DDB record @@ -536,19 +536,19 @@ def create_daily_scheduled_actions( def time_to_cron(time_str: str) -> str: - """Convert time string (HH:MM) to cron expression""" + """Convert time string (HH:MM) to cron expression.""" hour, minute = map(int, time_str.split(":")) return f"{minute} {hour} * * *" def time_to_cron_with_day(time_str: str, day_of_week: int) -> str: - """Convert time string (HH:MM) to cron expression with day""" + """Convert time string (HH:MM) to cron expression with day.""" hour, minute = map(int, time_str.split(":")) return f"{minute} {hour} * * {day_of_week}" def construct_scheduled_action_arn(auto_scaling_group: str, action_name: str) -> str: - """Construct ARN for a scheduled action""" + """Construct ARN for a scheduled action.""" region = os.environ.get("AWS_REGION", "us-east-1") account_id = os.environ.get("AWS_ACCOUNT_ID") @@ -567,7 +567,7 @@ def construct_scheduled_action_arn(auto_scaling_group: str, action_name: str) -> def delete_scheduled_actions(scheduled_action_arns: list[str]) -> None: - """Delete Auto Scaling scheduled actions by ARN""" + """Delete Auto Scaling scheduled actions by ARN.""" for arn in scheduled_action_arns: try: # Extract action name and ASG name from ARN @@ -589,7 +589,7 @@ def delete_scheduled_actions(scheduled_action_arns: list[str]) -> None: def cleanup_scheduled_actions(scheduled_action_arns: list[str]) -> None: - """Clean up scheduled actions (used for error recovery)""" + """Clean up scheduled actions (used for error recovery).""" for arn in scheduled_action_arns: try: action_name = arn.split(":scheduledActionName/")[-1] @@ -602,7 +602,7 @@ def cleanup_scheduled_actions(scheduled_action_arns: list[str]) -> None: def cleanup_scheduled_actions_by_name_pattern(auto_scaling_group: str, model_id: str) -> None: - """Delete all scheduled actions for a model by finding them via name pattern""" + """Delete all scheduled actions for a model by finding them via name pattern.""" try: # Get all scheduled actions for 
the Auto Scaling Group response = autoscaling_client.describe_scheduled_actions(AutoScalingGroupName=auto_scaling_group) @@ -627,7 +627,6 @@ def cleanup_scheduled_actions_by_name_pattern(auto_scaling_group: str, model_id: or action_name.startswith(f"{model_id}-saturday-") or action_name.startswith(f"{model_id}-sunday-") ): - try: autoscaling_client.delete_scheduled_action( AutoScalingGroupName=auto_scaling_group, ScheduledActionName=action_name @@ -655,7 +654,7 @@ def cleanup_scheduled_actions_by_name_pattern(auto_scaling_group: str, model_id: def calculate_next_scheduled_action(schedule_config: SchedulingConfig, timezone_name: str) -> dict[str, str] | None: - """Calculate the next scheduled action (START or STOP) based on the schedule configuration""" + """Calculate the next scheduled action (START or STOP) based on the schedule configuration.""" try: tz = ZoneInfo(timezone_name) now = datetime.now(tz) @@ -672,7 +671,7 @@ def calculate_next_scheduled_action(schedule_config: SchedulingConfig, timezone_ def _calculate_next_recurring_action(day_schedule: DaySchedule, now: datetime, tz: ZoneInfo) -> dict[str, str]: - """Calculate next action for recurring schedule""" + """Calculate next action for recurring schedule.""" # Parse schedule times start_hour, start_minute = map(int, day_schedule.startTime.split(":")) stop_hour, stop_minute = map(int, day_schedule.stopTime.split(":")) @@ -700,7 +699,7 @@ def _calculate_next_recurring_action(day_schedule: DaySchedule, now: datetime, t def _calculate_next_daily_action(daily_schedule: WeeklySchedule, now: datetime, tz: ZoneInfo) -> dict[str, str] | None: - """Calculate next action for daily schedule""" + """Calculate next action for daily schedule.""" current_weekday = now.weekday() day_schedules = [ @@ -738,7 +737,7 @@ def _calculate_next_daily_action(daily_schedule: WeeklySchedule, now: datetime, def _get_next_action_for_today(day_schedule: DaySchedule, now: datetime, tz: ZoneInfo) -> dict[str, str] | None: - """Get next action for today's schedule only""" + """Get next action for today's schedule only.""" today = now.date() # Parse schedule times @@ -764,7 +763,7 @@ def _get_next_action_for_today(day_schedule: DaySchedule, now: datetime, tz: Zon def merge_schedule_data(model_id: str, partial_update: dict[str, Any]) -> dict[str, Any]: - """Merge partial schedule update with existing schedule data""" + """Merge partial schedule update with existing schedule data.""" # Get existing schedule data from model_config.autoScalingConfig.scheduling existing_data = {} try: @@ -803,7 +802,7 @@ def merge_schedule_data(model_id: str, partial_update: dict[str, Any]) -> dict[s def get_existing_scheduled_action_arns(model_id: str) -> list[str]: - """Get existing scheduled action ARNs for a model""" + """Get existing scheduled action ARNs for a model.""" try: response = model_table.get_item(Key={"model_id": model_id}) @@ -825,7 +824,7 @@ def get_existing_scheduled_action_arns(model_id: str) -> list[str]: def update_model_schedule_record( model_id: str, scheduling_config: SchedulingConfig | None, scheduled_action_arns: list[str], enabled: bool ) -> None: - """Update model record in DynamoDB with schedule information""" + """Update model record in DynamoDB with schedule information.""" try: # Check if model_config.autoScalingConfig exists first response = model_table.get_item(Key={"model_id": model_id}) diff --git a/lambda/models/scheduling/schedule_monitoring.py b/lambda/models/scheduling/schedule_monitoring.py index 666e0beaa..ea3ef7944 100644 --- 
a/lambda/models/scheduling/schedule_monitoring.py +++ b/lambda/models/scheduling/schedule_monitoring.py @@ -36,7 +36,7 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """Main Lambda handler for CloudWatch Events from Auto Scaling Groups""" + """Main Lambda handler for CloudWatch Events from Auto Scaling Groups.""" logger.info(f"Processing event - RequestId: {context.aws_request_id}") try: @@ -59,7 +59,7 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_autoscaling_event(event: dict[str, Any]) -> dict[str, Any]: - """Handle Auto Scaling Group CloudWatch events""" + """Handle Auto Scaling Group CloudWatch events.""" try: detail = event.get("detail", {}) event_type = event.get("detail-type", "") @@ -86,7 +86,7 @@ def handle_autoscaling_event(event: dict[str, Any]) -> dict[str, Any]: def handle_successful_scaling(model_id: str, auto_scaling_group: str, detail: dict[str, Any]) -> dict[str, Any]: - """Handle successful Auto Scaling actions using ASG state""" + """Handle successful Auto Scaling actions using ASG state.""" try: # Check ASG state to determine model status try: @@ -138,7 +138,7 @@ def handle_successful_scaling(model_id: str, auto_scaling_group: str, detail: di def sync_model_status(event: dict[str, Any]) -> dict[str, Any]: - """Manually sync model status using ASG state""" + """Manually sync model status using ASG state.""" model_id = event.get("modelId") if not model_id: raise ValueError("modelId is required for sync_status operation") @@ -208,7 +208,7 @@ def sync_model_status(event: dict[str, Any]) -> dict[str, Any]: def find_model_by_asg_name(asg_name: str) -> str | None: - """Find model ID by looking up which model uses the given Auto Scaling Group""" + """Find model ID by looking up which model uses the given Auto Scaling Group.""" try: response = model_table.scan( FilterExpression="auto_scaling_group = :asg_name", @@ -227,7 +227,7 @@ def find_model_by_asg_name(asg_name: str) -> str | None: def update_model_status(model_id: str, new_status: ModelStatus, reason: str) -> None: - """Update model status in DynamoDB""" + """Update model status in DynamoDB.""" try: # Convert enum to string value for DynamoDB status_str = new_status.value if hasattr(new_status, "value") else str(new_status) @@ -250,7 +250,7 @@ def update_model_status(model_id: str, new_status: ModelStatus, reason: str) -> def get_model_info(model_id: str) -> dict[str, Any] | None: - """Get model information from DynamoDB""" + """Get model information from DynamoDB.""" try: response = model_table.get_item(Key={"model_id": model_id}) diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index 0454c636e..5045f3200 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -76,8 +76,7 @@ def get_container_path(inference_container_type: InferenceContainer) -> str: - """ - Get the LISA repository path for referencing container build scripts. + """Get the LISA repository path for referencing container build scripts. 
Paths are relative to /lib/serve/ecs-model/ """ @@ -91,7 +90,7 @@ def get_container_path(inference_container_type: InferenceContainer) -> str: def adjust_initial_capacity_for_schedule(prepared_event: dict[str, Any]) -> None: - """Adjust Auto Scaling Group initial capacity based on schedule configuration""" + """Adjust Auto Scaling Group initial capacity based on schedule configuration.""" try: # Check if scheduling is configured auto_scaling_config = prepared_event.get("autoScalingConfig", {}) or {} @@ -497,12 +496,10 @@ def handle_poll_create_stack(event: dict[str, Any], context: Any) -> dict[str, A def handle_poll_model_ready(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Poll ASG to confirm model instances are healthy before marking as InService. + """Poll ASG to confirm model instances are healthy before marking as InService. - This handler checks that the Auto Scaling Group has healthy instances running - before proceeding to add the model to LiteLLM. This ensures the model is actually - ready to serve requests, not just that the infrastructure was created. + This handler checks that the Auto Scaling Group has healthy instances running before proceeding to add the model to + LiteLLM. This ensures the model is actually ready to serve requests, not just that the infrastructure was created. """ output_dict = deepcopy(event) model_id = event.get("modelId", "unknown") @@ -713,7 +710,7 @@ def handle_add_guardrails_to_litellm(event: dict[str, Any], context: Any) -> dic # Transform guardrail config to LiteLLM format litellm_guardrail_config = { "guardrail": { - "guardrail_name": f'{guardrail_config["guardrailName"]}-{model_id}', + "guardrail_name": f"{guardrail_config['guardrailName']}-{model_id}", "litellm_params": { "guardrail": "bedrock", "mode": str(guardrail_config.get("mode", "pre_call")), @@ -869,8 +866,8 @@ def _fetch_context_window_from_s3(model_name: Any, model_type: str) -> int | Non def handle_enrich_context_window(event: dict[str, Any], context: Any) -> dict[str, Any]: """Enrich model DDB record with context window size. - Non-blocking — failure is logged as a warning but does not raise an exception - or affect the state machine's success path. + Non-blocking — failure is logged as a warning but does not raise an exception or affect the state machine's success + path. """ output_dict = deepcopy(event) model_id = event.get("modelId") @@ -919,8 +916,7 @@ def handle_enrich_context_window(event: dict[str, Any], context: Any) -> dict[st def handle_failure(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Handle failures from state machine. + """Handle failures from state machine. Possible causes of failures would be: 1. 
Docker Image failed to replicate into ECR in expected amount of time diff --git a/lambda/models/state_machine/schedule_handlers.py b/lambda/models/state_machine/schedule_handlers.py index 8efa510ac..2a141b0bc 100644 --- a/lambda/models/state_machine/schedule_handlers.py +++ b/lambda/models/state_machine/schedule_handlers.py @@ -33,7 +33,7 @@ def handle_schedule_creation(event: dict[str, Any], context: Any) -> dict[str, Any]: - """Create Auto Scaling scheduled actions for the model if scheduling is configured""" + """Create Auto Scaling scheduled actions for the model if scheduling is configured.""" logger.info(f"Processing schedule creation for model: {event.get('modelId')}") output_dict = event.copy() @@ -85,7 +85,7 @@ def handle_schedule_creation(event: dict[str, Any], context: Any) -> dict[str, A def handle_schedule_update(event: dict[str, Any], context: Any) -> dict[str, Any]: - """Update Auto Scaling scheduled actions when schedule configuration changes""" + """Update Auto Scaling scheduled actions when schedule configuration changes.""" logger.info(f"Processing schedule update for model: {event.get('modelId')}") output_dict = event.copy() @@ -127,7 +127,7 @@ def handle_schedule_update(event: dict[str, Any], context: Any) -> dict[str, Any def handle_cleanup_schedule(event: dict[str, Any], context: Any) -> dict[str, Any]: - """Clean up scheduled actions before deleting the model""" + """Clean up scheduled actions before deleting the model.""" logger.info(f"Cleaning up schedule for model: {event.get('modelId')}") output_dict = event.copy() @@ -146,7 +146,7 @@ def handle_cleanup_schedule(event: dict[str, Any], context: Any) -> dict[str, An def update_schedule_failure_status(model_id: str, error_message: str) -> None: - """Update model with schedule failure status using boolean flags""" + """Update model with schedule failure status using boolean flags.""" try: failure_info = {"timestamp": iso_string(), "error": error_message, "retryCount": 0} diff --git a/lambda/models/state_machine/update_model.py b/lambda/models/state_machine/update_model.py index 39c058d63..96536b74e 100644 --- a/lambda/models/state_machine/update_model.py +++ b/lambda/models/state_machine/update_model.py @@ -156,8 +156,7 @@ def _get_metadata_update_handlers(model_config: dict[str, Any], model_id: str) - def _process_metadata_updates( model_config: dict[str, Any], update_payload: dict[str, Any], model_id: str ) -> tuple[bool, dict[str, Any]]: - """ - Process metadata updates. + """Process metadata updates. Args: model_config: The model configuration dictionary to update @@ -186,8 +185,7 @@ def _process_metadata_updates( def handle_job_intake(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Handle initial UpdateModel job submission. + """Handle initial UpdateModel job submission. This handler will perform the following actions: 1. Determine if any metadata (streaming, or modelType) changes are required @@ -376,8 +374,7 @@ def handle_job_intake(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_poll_capacity(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Poll autoscaling and target group to confirm if the capacity is done updating. + """Poll autoscaling and target group to confirm if the capacity is done updating. This handler will: 1. Get the ASG's current status. 
If it is still updating, then exit with a boolean to indicate for more polling @@ -407,8 +404,7 @@ def handle_poll_capacity(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_finish_update(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Finalize update in DDB. + """Finalize update in DDB. 1. If the model was enabled from the Stopped state, add the model back to LiteLLM, set status to InService in DDB 2. If the model was disabled from the InService state, set status to Stopped @@ -506,8 +502,7 @@ def handle_finish_update(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_update_guardrails(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Update guardrails for a model in LiteLLM and DynamoDB. + """Update guardrails for a model in LiteLLM and DynamoDB. This handler will: 1. Process guardrails configuration updates from the event @@ -606,7 +601,7 @@ def handle_update_guardrails(event: dict[str, Any], context: Any) -> dict[str, A # Transform guardrail config to LiteLLM format for update litellm_guardrail_config = { "guardrail": { - "guardrail_name": f'{guardrail_config["guardrailName"]}-{model_id}', + "guardrail_name": f"{guardrail_config['guardrailName']}-{model_id}", "litellm_params": { "guardrail": "bedrock", "mode": str(guardrail_config.get("mode", "pre_call")), @@ -652,11 +647,10 @@ def handle_update_guardrails(event: dict[str, Any], context: Any) -> dict[str, A logger.info(f"Successfully updated guardrail: {guardrail_name}") else: - # Transform guardrail config to LiteLLM format litellm_guardrail_config = { "guardrail": { - "guardrail_name": f'{guardrail_config["guardrailName"]}-{model_id}', + "guardrail_name": f"{guardrail_config['guardrailName']}-{model_id}", "litellm_params": { "guardrail": "bedrock", "mode": str(guardrail_config.get("mode", "pre_call")), @@ -926,8 +920,7 @@ def update_ecs_service(cluster_arn: str, service_arn: str, task_definition_arn: def handle_ecs_update(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Update ECS task definition with new environment variables and update service. + """Update ECS task definition with new environment variables and update service. This handler will: 1. Retrieve current task definition from ECS @@ -989,8 +982,7 @@ def handle_ecs_update(event: dict[str, Any], context: Any) -> dict[str, Any]: def handle_poll_ecs_deployment(event: dict[str, Any], context: Any) -> dict[str, Any]: - """ - Monitor ECS service deployment progress. + """Monitor ECS service deployment progress. This handler will: 1. Check if ECS service deployment is complete diff --git a/lambda/prompt_templates/models.py b/lambda/prompt_templates/models.py index ef0811581..e6f554b85 100644 --- a/lambda/prompt_templates/models.py +++ b/lambda/prompt_templates/models.py @@ -28,8 +28,8 @@ class PromptTemplateType(StrEnum): class PromptTemplateModel(BaseModel): - """ - A Pydantic model representing a template for prompts. + """A Pydantic model representing a template for prompts. + Contains metadata and functionality to create new revisions. """ @@ -60,8 +60,7 @@ class PromptTemplateModel(BaseModel): body: str def new_revision(self, update: dict[str, Any]) -> "PromptTemplateModel": - """ - Create a new revision of the current prompt template. + """Create a new revision of the current prompt template. Args: update (Dict[str, Any]): A dictionary containing fields to update in the new revision. 
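The guardrail hunks above flip the f-string quoting from `f'{guardrail_config["guardrailName"]}-{model_id}'` to `f"{guardrail_config['guardrailName']}-{model_id}"`. Both forms build the same string; before Python 3.12 an f-string cannot reuse its own quote character inside the braces, so the inner quotes must alternate either way, and the patch simply standardizes on the double-quoted outer form. A minimal sketch of the equivalence (the sample guardrail name and model id are hypothetical):

```python
# Hypothetical sample values; the real ones come from the model's guardrail config.
guardrail_config = {"guardrailName": "pii-filter"}
model_id = "mistral-vllm"

old_style = f'{guardrail_config["guardrailName"]}-{model_id}'  # double quotes nested in single
new_style = f"{guardrail_config['guardrailName']}-{model_id}"  # single quotes nested in double

assert old_style == new_style == "pii-filter-mistral-vllm"
```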
diff --git a/lambda/repository/collection_repo.py b/lambda/repository/collection_repo.py index f46cf9a18..5753fba96 100644 --- a/lambda/repository/collection_repo.py +++ b/lambda/repository/collection_repo.py @@ -39,8 +39,7 @@ class CollectionRepository: """Collection repository for DynamoDB operations.""" def __init__(self, table_name: str | None = None) -> None: - """ - Initialize the Collection Repository. + """Initialize the Collection Repository. Args: table_name: Optional table name override for testing @@ -51,8 +50,7 @@ def __init__(self, table_name: str | None = None) -> None: logger.info(f"Initialized CollectionRepository with table: {table_name}") def create(self, collection: RagCollectionConfig) -> RagCollectionConfig: - """ - Create a new collection in DynamoDB. + """Create a new collection in DynamoDB. Args: collection: The collection configuration to create @@ -97,8 +95,7 @@ def create(self, collection: RagCollectionConfig) -> RagCollectionConfig: raise CollectionRepositoryError(f"Unexpected error creating collection: {str(e)}") def find_by_id(self, collection_id: str, repository_id: str) -> RagCollectionConfig | None: - """ - Find a collection by its ID and repository ID. + """Find a collection by its ID and repository ID. Args: collection_id: The collection ID @@ -136,8 +133,7 @@ def update( updates: dict[str, Any], expected_version: str | None = None, ) -> RagCollectionConfig: - """ - Update a collection with optimistic locking. + """Update a collection with optimistic locking. Args: collection_id: The collection ID @@ -213,8 +209,7 @@ def update( raise CollectionRepositoryError(f"Unexpected error updating collection: {str(e)}") def delete(self, collection_id: str, repository_id: str) -> bool: - """ - Delete a collection from DynamoDB. + """Delete a collection from DynamoDB. Args: collection_id: The collection ID @@ -256,8 +251,7 @@ def list_by_repository( sort_by: CollectionSortBy = CollectionSortBy.CREATED_AT, sort_order: SortOrder = SortOrder.DESC, ) -> tuple[list[RagCollectionConfig], dict[str, str] | None]: - """ - List collections for a repository with pagination, filtering, and sorting. + """List collections for a repository with pagination, filtering, and sorting. Args: repository_id: The repository ID @@ -333,8 +327,7 @@ def list_by_repository( raise CollectionRepositoryError(f"Failed to list collections: {str(e)}") def count_by_repository(self, repository_id: str, status: CollectionStatus | None = None) -> int: - """ - Count collections in a repository. + """Count collections in a repository. Args: repository_id: The repository ID @@ -400,8 +393,8 @@ def find_by_name(self, repository_id: str, collection_name: str) -> RagCollectio raise CollectionRepositoryError(f"Failed to find collection by name: {str(e)}") def find_collections_using_model(self, model_id: str) -> list[dict[str, str]]: - """ - Find all collections that use a specific embedding model. + """Find all collections that use a specific embedding model. + Excludes collections with status indicating they are deleted or archived. Args: diff --git a/lambda/repository/collection_service.py b/lambda/repository/collection_service.py index e4f001988..11f0d8371 100644 --- a/lambda/repository/collection_service.py +++ b/lambda/repository/collection_service.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- """Collection service for business logic.""" import heapq @@ -107,7 +106,6 @@ def create_collection( Raises: ValidationError: If collection name already exists in repository """ - # Check if collection name already exists in this repository existing = self.collection_repo.find_by_name(collection.repositoryId, collection.name) # type: ignore[arg-type] if existing: @@ -155,8 +153,8 @@ def list_collections( ) -> tuple[list[RagCollectionConfig], dict[str, str] | None]: """List collections with access control. - For Bedrock KB repositories, default collections are persisted to the database - and will be included in the query results automatically. + For Bedrock KB repositories, default collections are persisted to the database and will be included in the query + results automatically. For other repository types, a virtual default collection is generated if needed. """ @@ -195,7 +193,7 @@ def update_collection( Args: collection_id: Collection ID to update repository_id: Repository ID - request: RagCollectionConfig with fields to update + collection_data: RagCollectionConfig with fields to update username: Username for access control user_groups: User groups for access control is_admin: Whether user is admin @@ -349,7 +347,7 @@ def delete_collection( ingestion_job_repo.save(deletion_job) ingestion_service.create_delete_job(deletion_job) - logger.info(f"Submitted {deletion_type} deletion job {deletion_job.id} " f"for repository {repository_id}") + logger.info(f"Submitted {deletion_type} deletion job {deletion_job.id} for repository {repository_id}") response = { "jobId": deletion_job.id, @@ -456,8 +454,7 @@ def list_all_user_collections( filter_text: str | None = None, sort_params: SortParams | None = None, ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: - """ - List all collections user has access to across all repositories. + """List all collections user has access to across all repositories. This method orchestrates the complete workflow: 1. Get accessible repositories @@ -469,6 +466,7 @@ def list_all_user_collections( username: Username for access control user_groups: User groups for access control is_admin: Whether user is admin + is_rag_admin: Whether user is RAG admin page_size: Number of items per page pagination_token: Pagination token from previous request filter_text: Optional text filter for name/description @@ -532,8 +530,7 @@ def list_all_user_collections( def _get_accessible_repositories( self, username: str, user_groups: list[str], is_admin: bool ) -> list[dict[str, Any]]: - """ - Get all repositories user has access to. + """Get all repositories user has access to. Args: username: Username for access control @@ -554,8 +551,7 @@ def _get_accessible_repositories( return accessible def _has_repository_access(self, user_groups: list[str], repository: dict[str, Any]) -> bool: - """ - Check if user has access to repository based on groups. + """Check if user has access to repository based on groups. Args: user_groups: User groups for access control @@ -581,8 +577,7 @@ def _has_repository_access(self, user_groups: list[str], repository: dict[str, A def _enrich_with_repository_metadata( self, collections: list[RagCollectionConfig], repositories: list[dict[str, Any]] ) -> list[dict[str, Any]]: - """ - Enrich collections with repository metadata. + """Enrich collections with repository metadata. 
Args: collections: List of collection configurations repositories: List of repository configurations Returns: List of enriched collection dictionaries """ @@ -613,8 +608,7 @@ def _enrich_with_repository_metadata( return enriched def _estimate_total_collections(self, repositories: list[dict[str, Any]]) -> int: - """ - Estimate total number of collections across repositories. + """Estimate total number of collections across repositories. Args: repositories: List of repository configurations @@ -644,8 +638,7 @@ def _paginate_collections( filter_text: str | None, sort_params: SortParams, ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: - """ - Simple pagination strategy for small-to-medium deployments. + """Simple pagination strategy for small-to-medium deployments. Aggregates all collections in memory, applies filtering and sorting, then returns requested page. @@ -741,8 +734,7 @@ def _paginate_collections( return enriched, next_token def _matches_filter(self, collection: RagCollectionConfig, filter_text: str) -> bool: - """ - Check if collection matches text filter. + """Check if collection matches text filter. Args: collection: Collection to check @@ -766,8 +758,7 @@ def _matches_filter(self, collection: RagCollectionConfig, filter_text: str) -> def _sort_collections( self, collections: list[RagCollectionConfig], sort_params: SortParams ) -> list[RagCollectionConfig]: - """ - Sort collections by specified field and order. + """Sort collections by specified field and order. Args: collections: List of collections to sort @@ -796,8 +787,7 @@ def _paginate_large_collections( filter_text: str | None, sort_params: SortParams, ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: - """ - Scalable pagination strategy for large deployments. + """Scalable pagination strategy for large deployments. Uses incremental merge with per-repository cursors to handle 1000+ collections efficiently without loading all into memory. @@ -944,8 +934,7 @@ def _paginate_large_collections( def _merge_sorted_batches( self, batches: list[dict[str, Any]], sort_by: str, sort_order: str ) -> list[RagCollectionConfig]: - """ - Merge pre-sorted batches from multiple repositories using min-heap. + """Merge pre-sorted batches from multiple repositories using min-heap. Time Complexity: O(N log K) where N = total collections, K = number of repositories Space Complexity: O(N) for merged result @@ -1008,8 +997,7 @@ def _merge_sorted_batches( return merged def _get_sort_key(self, collection: RagCollectionConfig, sort_by: str) -> Any: - """ - Extract sort key from collection. + """Extract sort key from collection. Args: collection: Collection to extract key from diff --git a/lambda/repository/embeddings.py b/lambda/repository/embeddings.py index cdc440304..a299ac1eb 100644 --- a/lambda/repository/embeddings.py +++ b/lambda/repository/embeddings.py @@ -66,9 +66,7 @@ def _get_http_session() -> requests.Session: class RagEmbeddings(BaseModel): - """ - Handles document embeddings through LiteLLM using management credentials. - """ + """Handles document embeddings through LiteLLM using management credentials.""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -106,9 +104,10 @@ def __init__(self, model_name: str, id_token: str | None = None, **data: Any) -> raise def embed_documents(self, texts: list[str]) -> list[list[float]]: - """ - Generate embeddings for a list of documents, automatically batching - to stay within the embedding server's max batch size. + """Generate embeddings for a list of documents. + + Automatically batches to stay within the embedding server's max + batch size.
Uses input_type="passage" so litellm applies the correct model-specific prefix for document indexing (e.g. "passage: " for E5 models). @@ -158,7 +157,7 @@ def _embed_batch_with_retry(self, texts: list[str], input_type: str | None = Non if attempt < MAX_RETRIES: backoff = INITIAL_BACKOFF_SECONDS * (2 ** (attempt - 1)) logger.warning( - f"Embedding attempt {attempt}/{MAX_RETRIES} failed: {e}. " f"Retrying in {backoff:.1f}s..." + f"Embedding attempt {attempt}/{MAX_RETRIES} failed: {e}. Retrying in {backoff:.1f}s..." ) time.sleep(backoff) else: diff --git a/lambda/repository/ingestion_service.py b/lambda/repository/ingestion_service.py index f544038c9..433ac78eb 100644 --- a/lambda/repository/ingestion_service.py +++ b/lambda/repository/ingestion_service.py @@ -116,8 +116,7 @@ def _merge_metadata_for_ingestion( collection: dict | None, document_metadata: dict[str, Any] | None = None, ) -> dict[str, Any] | None: - """ - Merge metadata from repository, collection, and document sources for ingestion jobs. + """Merge metadata from repository, collection, and document sources for ingestion jobs. This ensures the ingestion job contains the complete merged metadata that will be applied to documents during ingestion, following the hierarchy: diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index e422271a4..03127996d 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -100,8 +100,7 @@ @api_wrapper def list_all(event: dict, context: dict) -> list[dict[str, Any]]: - """ - List all available repositories that the user has access to. + """List all available repositories that the user has access to. Args: event: Lambda event containing user authentication @@ -122,8 +121,7 @@ def list_all(event: dict, context: dict) -> list[dict[str, Any]]: @api_wrapper @admin_only def list_status(event: dict, context: dict) -> dict[str, Any]: - """ - Get all repository status. + """Get all repository status. Returns: List of repository status @@ -299,8 +297,8 @@ def get_repository(event: dict[str, Any], repository_id: str) -> dict[str, Any]: def create_bedrock_collection(event: dict, context: dict) -> dict[str, Any]: - """ - Create collections for a Bedrock Knowledge Base repository based on pipeline configurations. + """Create collections for a Bedrock Knowledge Base repository based on pipeline configurations. + This is called by the state machine during repository creation. Each pipeline configuration represents a data source and should have a corresponding collection. @@ -435,6 +433,7 @@ def create_bedrock_collection(event: dict, context: dict) -> dict[str, Any]: def create_default_collection(event: dict, context: dict) -> dict[str, Any]: """Persist the default collection for a non-Bedrock repository after stack creation. + Called by the state machine for OpenSearch/PGVector repositories. """ try: @@ -481,8 +480,7 @@ def create_default_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper @rag_admin_or_admin def create_collection(event: dict, context: dict) -> dict[str, Any]: - """ - Create a new collection within a vector store. + """Create a new collection within a vector store. Args: event (dict): The Lambda event object containing: @@ -549,8 +547,7 @@ def create_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper def get_collection(event: dict, context: dict) -> dict[str, Any]: - """ - Get a collection by ID within a vector store. + """Get a collection by ID within a vector store. 
Args: event (dict): The Lambda event object containing: @@ -614,8 +611,7 @@ def get_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper @rag_admin_or_admin def update_collection(event: dict, context: dict) -> dict[str, Any]: - """ - Update a collection within a vector store. + """Update a collection within a vector store. Args: event (dict): The Lambda event object containing: @@ -677,8 +673,7 @@ def update_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper @rag_admin_or_admin def delete_collection(event: dict, context: dict) -> dict[str, Any]: - """ - Delete a collection (regular or default) within a vector store. + """Delete a collection (regular or default) within a vector store. Path: /repository/{repositoryId}/collection/{collectionId} @@ -734,8 +729,7 @@ def delete_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper def list_collections(event: dict, context: dict) -> dict[str, Any]: - """ - List collections in a repository with pagination, filtering, and sorting. + """List collections in a repository with pagination, filtering, and sorting. Args: event (dict): The Lambda event object containing: @@ -848,8 +842,7 @@ def list_collections(event: dict, context: dict) -> dict[str, Any]: @api_wrapper def list_user_collections(event: dict, context: dict) -> dict[str, Any]: - """ - List all collections user has access to across all repositories. + """List all collections user has access to across all repositories. Args: event (dict): The Lambda event object containing: @@ -944,7 +937,10 @@ def list_user_collections(event: dict, context: dict) -> dict[str, Any]: def _ensure_document_ownership(event: dict[str, Any], docs: list[RagDocument]) -> None: - """Verify ownership of documents. Admins and RAG admins can delete any document.""" + """Verify ownership of documents. + + Admins and RAG admins can delete any document. + """ username = get_username(event) if not is_admin(event) and not is_rag_admin(event): for doc in docs: @@ -954,8 +950,10 @@ def _ensure_document_ownership(event: dict[str, Any], docs: list[RagDocument]) - @api_wrapper def delete_documents(event: dict, context: dict) -> dict[str, Any]: - """Purge all records related to the specified document from the RAG repository. If a documentId is supplied, a - single document will be removed. If a documentName is supplied, all documents with that name will be removed + """Purge all records related to the specified document from the RAG repository. + + If a documentId is supplied, a single document will be removed. 
If a documentName is supplied, all + documents with that name will be removed. Args: event (dict): The Lambda event object containing: @@ -1078,9 +1076,7 @@ def handle_deprecated_chunking_strategy(request: IngestDocumentRequest, query_pa # Create chunkingStrategy from legacy parameters request.chunkingStrategy = {"type": "fixed", "size": chunk_size, "overlap": chunk_overlap} - logger.info( - f"Migrated legacy parameters to chunkingStrategy: " f"size={chunk_size}, overlap={chunk_overlap}" - ) + logger.info(f"Migrated legacy parameters to chunkingStrategy: size={chunk_size}, overlap={chunk_overlap}") if "collectionId" in query_params: request.collectionId = query_params.get("collectionId") @@ -1161,6 +1157,7 @@ def get_document(event: dict, context: dict) -> dict[str, Any]: Args: event (dict): The Lambda event object containing: + context: Lambda context path_params: repositoryId - the repository documentId - the document @@ -1185,8 +1182,10 @@ def get_document(event: dict, context: dict) -> dict[str, Any]: @api_wrapper def download_document(event: dict, context: dict) -> str: """Generate a pre-signed S3 URL for downloading a file from the RAG ingested files. + Args: event (dict): The Lambda event object containing: + context: Lambda context path_params: repositoryId - the repository documentId - the document @@ -1276,7 +1275,6 @@ def list_docs(event: dict, context: dict) -> dict[str, Any]: Raises: KeyError: If collectionId is not provided in queryStringParameters """ - path_params = event.get("pathParameters", {}) or {} repository_id = path_params.get("repositoryId") @@ -1379,25 +1377,25 @@ def list_jobs(event: dict[str, Any], context: dict) -> dict[str, Any]: @api_wrapper @admin_only def create(event: dict, context: dict) -> Any: - """ - Create a new process execution using AWS Step Functions. This function is only accessible by administrators. + """Create a new process execution using AWS Step Functions. - For Bedrock Knowledge Base repositories, automatically adds a default pipeline configuration - if none is provided, using the datasource S3 bucket for event-driven ingestion. + This function is only accessible by administrators. + For Bedrock Knowledge Base repositories, automatically adds a default pipeline configuration + if none is provided, using the datasource S3 bucket for event-driven ingestion. Args: - event (dict): The Lambda event object containing: - - body: A JSON string with the process creation details containing VectorStoreConfig. - context (dict): The Lambda context object. + event (dict): The Lambda event object containing: + - body: A JSON string with the process creation details containing VectorStoreConfig. + context (dict): The Lambda context object. Returns: - Dict[str, str]: A dictionary containing: - - status: Success status message. - - executionArn: The ARN of the step function execution. + Dict[str, str]: A dictionary containing: + - status: Success status message. + - executionArn: The ARN of the step function execution. Raises: - ValueError: If the user is not an administrator. - ValidationError: If the request body is invalid. + ValueError: If the user is not an administrator. + ValidationError: If the request body is invalid. """ # Fetch the Step Function ARN from SSM Parameter Store parameter_name = os.environ["LISA_RAG_CREATE_STATE_MACHINE_ARN_PARAMETER"] @@ -1458,8 +1456,7 @@ def create(event: dict, context: dict) -> Any: @api_wrapper def get_repository_by_id(event: dict, context: dict) -> dict[str, Any]: - """ - Get a vector store configuration by ID.
+ """Get a vector store configuration by ID. Args: event (dict): The Lambda event object containing: @@ -1495,8 +1492,7 @@ def _get_pipeline_key(pipeline: dict) -> str: def _validate_immutable_pipeline_fields(current_pipelines: list, new_pipelines: list) -> None: - """ - Validate that immutable pipeline fields haven't changed for existing pipelines. + """Validate that immutable pipeline fields haven't changed for existing pipelines. Immutable fields: autoRemove, collectionId, s3Bucket, s3Prefix, trigger @@ -1552,27 +1548,27 @@ def _validate_immutable_pipeline_fields(current_pipelines: list, new_pipelines: @api_wrapper @rag_admin_or_admin def update_repository(event: dict, context: dict) -> dict[str, Any]: - """ - Update a vector store configuration. Accessible by administrators and RAG admins (with scoped access). + """Update a vector store configuration. - Admins can update all fields. RAG admins with group access can only update pipeline-related fields. - RAG admins cannot change allowedGroups or other repository-level settings. + Accessible by administrators and RAG admins (with scoped access). + Admins can update all fields. RAG admins with group access can only update pipeline-related fields. + RAG admins cannot change allowedGroups or other repository-level settings. - If the pipeline configuration has changed, this will trigger an infrastructure deployment - using the state machine, similar to repository creation. + If the pipeline configuration has changed, this will trigger an infrastructure deployment + using the state machine, similar to repository creation. Args: - event (dict): The Lambda event object containing: - - pathParameters.repositoryId: The repository ID to update - - body: JSON with fields to update (UpdateVectorStoreRequest) - context (dict): The Lambda context object + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The repository ID to update + - body: JSON with fields to update (UpdateVectorStoreRequest) + context (dict): The Lambda context object Returns: - Dict[str, Any]: The updated repository configuration with executionArn if deployment triggered + Dict[str, Any]: The updated repository configuration with executionArn if deployment triggered Raises: - ValidationError: If validation fails - HTTPException: If repository not found + ValidationError: If validation fails + HTTPException: If repository not found """ # Extract path parameters path_params = event.get("pathParameters", {}) @@ -1678,7 +1674,7 @@ def update_repository(event: dict, context: dict) -> dict[str, Any]: # If metadata provided but missing tags, preserve existing tags elif "tags" not in current_meta and "tags" in existing_meta: pipeline["metadata"]["tags"] = existing_meta["tags"] - logger.info(f"Preserved tags for collection {collection_id}: " f"{existing_meta['tags']}") + logger.info(f"Preserved tags for collection {collection_id}: {existing_meta['tags']}") # Check if pipeline configuration has changed # Use the converted pipelines from updates if available, otherwise use request.pipelines @@ -1764,9 +1760,9 @@ def update_repository(event: dict, context: dict) -> dict[str, Any]: @api_wrapper @admin_only def delete(event: dict, context: dict) -> Any: - """ - Delete a vector store process using AWS Step Functions. This function ensures - that the user is an administrator or owns the vector store being deleted. + """Delete a vector store process using AWS Step Functions. 
+ + This function ensures that the user is an administrator or owns the vector store being deleted. Also deletes all associated collections and their documents. Args: @@ -1855,8 +1851,7 @@ def _remove_legacy(repository_id: str) -> None: @api_wrapper def list_bedrock_knowledge_bases(event: dict, context: dict) -> dict[str, Any]: - """ - List all ACTIVE Bedrock Knowledge Bases in the AWS account. + """List all ACTIVE Bedrock Knowledge Bases in the AWS account. Marks KBs as unavailable if they're already associated with a repository. @@ -1903,8 +1898,7 @@ def list_bedrock_knowledge_bases(event: dict, context: dict) -> dict[str, Any]: kb_list.append(kb_dict) logger.info( - f"Found {len(active_kbs)} ACTIVE Knowledge Bases out of {len(all_kbs)} total, " - f"{len(used_kb_ids)} already in use" + f"Found {len(active_kbs)} ACTIVE Knowledge Bases out of {len(all_kbs)} total, {len(used_kb_ids)} already in use" ) return {"knowledgeBases": kb_list, "totalKnowledgeBases": len(kb_list)} @@ -1912,8 +1906,7 @@ def list_bedrock_knowledge_bases(event: dict, context: dict) -> dict[str, Any]: @api_wrapper def list_bedrock_data_sources(event: dict, context: dict) -> dict[str, Any]: - """ - List data sources for a specific Bedrock Knowledge Base. + """List data sources for a specific Bedrock Knowledge Base. Args: event: Lambda event containing: diff --git a/lambda/repository/metadata_generator.py b/lambda/repository/metadata_generator.py index fe3f149f2..855cc066d 100644 --- a/lambda/repository/metadata_generator.py +++ b/lambda/repository/metadata_generator.py @@ -63,8 +63,7 @@ def merge_metadata( document_metadata: dict[str, Any] | None = None, for_bedrock_kb: bool = False, ) -> dict[str, Any]: - """ - Merge metadata from repository, collection, and document sources. + """Merge metadata from repository, collection, and document sources. This is the core metadata merging logic used by both ingestion jobs and Bedrock KB. Follows the hierarchy: repository → collection → document (document has highest precedence). diff --git a/lambda/repository/pipeline_delete_documents.py b/lambda/repository/pipeline_delete_documents.py index ff4ec3b32..2fdcf5a7a 100644 --- a/lambda/repository/pipeline_delete_documents.py +++ b/lambda/repository/pipeline_delete_documents.py @@ -84,8 +84,7 @@ def drop_pgvector_collection(repository_id: str, collection_id: str) -> None: def pipeline_delete_collection(job: IngestionJob) -> None: - """ - Delete all documents in a collection. + """Delete all documents in a collection. Steps: 1. Drop vector store index for collection (if supported) @@ -135,8 +134,7 @@ def pipeline_delete_collection(job: IngestionJob) -> None: user_managed = [doc for doc in documents if doc.get("ingestion_type") == "existing"] logger.info( - f"Collection {job.collection_id}: " - f"lisa_managed={len(lisa_managed)}, user_managed={len(user_managed)}" + f"Collection {job.collection_id}: lisa_managed={len(lisa_managed)}, user_managed={len(user_managed)}" ) # Extract S3 paths for LISA-managed documents only @@ -194,8 +192,7 @@ def pipeline_delete_collection(job: IngestionJob) -> None: def pipeline_delete(job: IngestionJob) -> None: - """ - Route deletion job to appropriate handler based on job type. + """Route deletion job to appropriate handler based on job type. Args: job: Ingestion job with deletion details @@ -214,8 +211,7 @@ def pipeline_delete(job: IngestionJob) -> None: def pipeline_delete_document(job: IngestionJob) -> None: - """ - Delete a single document. + """Delete a single document. 
Args: job: Ingestion job with document deletion details @@ -257,8 +253,7 @@ def pipeline_delete_document(job: IngestionJob) -> None: def pipeline_delete_documents(job: IngestionJob) -> None: - """ - Delete multiple documents in batch (up to 100 at a time). + """Delete multiple documents in batch (up to 100 at a time). Processes documents from document_ids field containing list of document IDs. diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 10b9a51d3..42716d43e 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -56,8 +56,7 @@ def pipeline_ingest(job: IngestionJob) -> None: - """ - Ingest a single document or batch of documents. + """Ingest a single document or batch of documents. Routes to appropriate handler based on job type. """ @@ -77,11 +76,13 @@ def pipeline_ingest_document(job: IngestionJob) -> None: repository = vs_repo.find_repository_by_id(job.repository_id) if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): # Bedrock KB path: Copy document to KB bucket and track + if not job.collection_id: + raise ValueError("collection_id is required for Bedrock KB ingestion") # Get KB bucket for this collection (supports multiple config formats) try: kb_bucket = get_datasource_bucket_for_collection( repository=repository, - collection_id=job.collection_id, # type: ignore[arg-type] + collection_id=job.collection_id, ) except ValueError as e: error_msg = str(e) @@ -101,13 +102,16 @@ def pipeline_ingest_document(job: IngestionJob) -> None: if needs_copy: # Document uploaded to LISA bucket, needs to be copied to KB bucket logger.info( - f"Document {job.s3_path} uploaded to LISA bucket. " f"Copying to KB data source bucket {kb_bucket}" + f"Document {job.s3_path} uploaded to LISA bucket. Copying to KB data source bucket {kb_bucket}" ) # Check if document already exists (idempotent operation) existing_docs = list( rag_document_repository.find_by_source( - job.repository_id, job.collection_id, kb_s3_path, join_docs=False # type: ignore[arg-type] + job.repository_id, + job.collection_id, + kb_s3_path, + join_docs=False, ) ) @@ -159,7 +163,7 @@ def pipeline_ingest_document(job: IngestionJob) -> None: collection = None try: collection = collection_service.get_collection( - collection_id=job.collection_id, # type: ignore[arg-type] + collection_id=job.collection_id, repository_id=job.repository_id, username="system", user_groups=[], @@ -260,11 +264,10 @@ def pipeline_ingest_document(job: IngestionJob) -> None: def pipeline_ingest_documents(job: IngestionJob) -> None: - """ - Ingest multiple documents in batch (up to 100 at a time). + """Ingest multiple documents in batch (up to 100 at a time). - Processes documents from s3_paths field containing list of S3 paths. - If s3_paths is empty, triggers S3 bucket scan to discover existing documents. + Processes documents from s3_paths field containing list of S3 paths. If s3_paths is empty, triggers S3 bucket scan + to discover existing documents. """ try: logger.info(f"Starting batch ingestion for job {job.id}") @@ -341,8 +344,7 @@ def pipeline_ingest_documents(job: IngestionJob) -> None: def _handle_s3_discovery_scan(job: IngestionJob) -> None: - """ - Handle S3 bucket scanning for existing documents. + """Handle S3 bucket scanning for existing documents. Delegates to S3DocumentDiscoveryService for the actual work. 
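Several handlers rewrapped above describe batch processing of documents "up to 100 at a time". As a sketch of that batching pattern under the stated limit (the helper name `batched` and the sample paths are illustrative, not the module's actual code):

```python
from collections.abc import Iterator

MAX_BATCH_SIZE = 100  # the limit stated in the docstrings; the real constant lives in the ingestion code


def batched(items: list[str], size: int = MAX_BATCH_SIZE) -> Iterator[list[str]]:
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start : start + size]


# A job with 250 S3 paths would be processed as batches of 100, 100, and 50.
s3_paths = [f"s3://example-bucket/doc-{i}.pdf" for i in range(250)]
for batch in batched(s3_paths):
    print(f"processing {len(batch)} documents")
```

On Python 3.12 and later, the stdlib's `itertools.batched` offers an equivalent helper (yielding tuples rather than lists).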
diff --git a/lambda/repository/rag_document_repo.py b/lambda/repository/rag_document_repo.py index f5b418d45..cdee3f273 100644 --- a/lambda/repository/rag_document_repo.py +++ b/lambda/repository/rag_document_repo.py @@ -28,7 +28,7 @@ class RagDocumentRepository: - """RAG Document repository for DynamoDB""" + """RAG Document repository for DynamoDB.""" def __init__(self, document_table_name: str, sub_document_table_name: str): dynamodb = boto3.resource("dynamodb") @@ -137,6 +137,7 @@ def find_by_name( document_name (str): The name of the documents to retrieve repository_id (str): The repository id to list documents for collection_id (str): The collection id to list documents for + join_docs: Join document entries together if record is chunked Returns: list[RagDocument]: A list of document objects matching the specified name @@ -173,6 +174,8 @@ def find_by_source( Args: document_source (str): The name of the documents to retrieve repository_id (str): The repository id to list documents for + collection_id (str): The collection id to list documents for + join_docs: Join document entries together if record is chunked Returns: list[RagDocument]: A list of document objects matching the specified name @@ -243,10 +246,11 @@ def list_all( Args: repository_id: Repository ID - collection_id?: Collection ID + collection_id: Collection ID last_evaluated_key: last key for pagination limit: maximum returned items join_docs: whether to include subdoc ids with parent doc + Returns: List of documents """ @@ -288,9 +292,11 @@ def list_all( def count_documents(self, repository_id: str, collection_id: str | None = None) -> int: """Count total documents in a repository/collection. + Args: repository_id: Repository ID - collection_id?: Collection ID + collection_id: Collection ID + Returns: Total number of documents """ @@ -343,7 +349,7 @@ def _get_subdoc_ids(self, entries: list[RagSubDocument]) -> list[str]: """Map subdocuments from a document object. Args: - document: The document object containing subdocuments + entries: The document object containing subdocuments Returns: List of subdocument dictionaries diff --git a/lambda/repository/services/bedrock_kb_repository_service.py b/lambda/repository/services/bedrock_kb_repository_service.py index cea08b2b8..f17e115a2 100644 --- a/lambda/repository/services/bedrock_kb_repository_service.py +++ b/lambda/repository/services/bedrock_kb_repository_service.py @@ -43,8 +43,8 @@ class BedrockKBRepositoryService(RepositoryService): """Service for Bedrock Knowledge Base repository operations. - Bedrock KB manages its own ingestion, chunking, and embedding pipeline. - LISA only tracks documents and delegates actual operations to Bedrock. + Bedrock KB manages its own ingestion, chunking, and embedding pipeline. LISA only tracks documents and delegates + actual operations to Bedrock. """ def supports_custom_collections(self) -> bool: @@ -58,8 +58,8 @@ def should_create_default_collection(self) -> bool: def get_collection_id_from_config(self, pipeline_config: dict[str, Any]) -> str: """For Bedrock KB, collection ID is the data source ID. - Extracts the data source ID from the pipeline config's collectionId field, - which should match one of the data sources in bedrockKnowledgeBaseConfig. + Extracts the data source ID from the pipeline config's collectionId field, which should match one of the data + sources in bedrockKnowledgeBaseConfig. 
""" # The pipeline config should have a collectionId that matches a data source ID collection_id: str | None = pipeline_config.get("collectionId") @@ -143,7 +143,7 @@ def ingest_document( ) rag_document_repository.save(rag_document) - logger.info(f"Tracked document {kb_s3_path} for Bedrock KB. " f"KB will handle ingestion automatically.") + logger.info(f"Tracked document {kb_s3_path} for Bedrock KB. KB will handle ingestion automatically.") return rag_document def delete_document( @@ -180,8 +180,8 @@ def delete_collection( ) -> None: """Delete all LISA-managed documents from Bedrock KB collection. - Only deletes documents with ingestion_type MANUAL or AUTO. - Preserves user-managed documents (ingestion_type EXISTING). + Only deletes documents with ingestion_type MANUAL or AUTO. Preserves user-managed documents (ingestion_type + EXISTING). """ if not bedrock_agent_client: raise ValueError("Bedrock agent client required for KB operations") @@ -210,9 +210,7 @@ def delete_collection( ] user_managed = [doc for doc in documents if doc.get("ingestion_type") == IngestionType.EXISTING] - logger.info( - f"Collection {collection_id}: " f"lisa_managed={len(lisa_managed)}, user_managed={len(user_managed)}" - ) + logger.info(f"Collection {collection_id}: lisa_managed={len(lisa_managed)}, user_managed={len(user_managed)}") # Extract S3 paths for LISA-managed documents s3_paths = [doc.get("source", "") for doc in lisa_managed if doc.get("source")] @@ -305,8 +303,7 @@ def retrieve_documents( if error_code == "ValidationException" and "auto-paused" in error_message.lower(): logger.warning(f"Aurora DB is resuming from auto-pause for KB {kb_id}") raise ServiceUnavailableException( - "The knowledge base database is currently starting up. " - "Please retry your request in a few moments." + "The knowledge base database is currently starting up. Please retry your request in a few moments." ) logger.error(f"Bedrock retrieve failed for KB {kb_id}: {error_message}") @@ -481,9 +478,7 @@ def _validate_and_normalize_path(self, s3_path: str, expected_bucket: str) -> st source_bucket = s3_path.split("/")[2] if s3_path.startswith("s3://") else None if source_bucket != expected_bucket: - logger.warning( - f"Document {s3_path} not from KB bucket {expected_bucket}. " f"Normalizing to KB bucket path." - ) + logger.warning(f"Document {s3_path} not from KB bucket {expected_bucket}. Normalizing to KB bucket path.") # Normalize to KB bucket path return f"s3://{expected_bucket}/{os.path.basename(s3_path)}" diff --git a/lambda/repository/services/opensearch_repository_service.py b/lambda/repository/services/opensearch_repository_service.py index 4e9dabc1d..a5d89756b 100644 --- a/lambda/repository/services/opensearch_repository_service.py +++ b/lambda/repository/services/opensearch_repository_service.py @@ -39,8 +39,8 @@ class OpenSearchRepositoryService(VectorStoreRepositoryService): """Service for OpenSearch repository operations. - Inherits common vector store behavior from VectorStoreRepositoryService. - Only implements OpenSearch-specific index management. + Inherits common vector store behavior from VectorStoreRepositoryService. Only implements OpenSearch-specific index + management. 
""" def retrieve_documents( @@ -105,9 +105,7 @@ def retrieve_documents( if include_score and results: max_score = max(self._normalize_similarity_score(score) for _, score in results) if max_score < 0.3: - logger.warning( - f"All similarity scores < 0.3 for query '{query}' - " "possible embedding model mismatch" - ) + logger.warning(f"All similarity scores < 0.3 for query '{query}' - possible embedding model mismatch") return documents diff --git a/lambda/repository/services/pgvector_repository_service.py b/lambda/repository/services/pgvector_repository_service.py index aa65d46a1..64ec25f85 100644 --- a/lambda/repository/services/pgvector_repository_service.py +++ b/lambda/repository/services/pgvector_repository_service.py @@ -37,8 +37,8 @@ class PGVectorRepositoryService(VectorStoreRepositoryService): """Service for PGVector repository operations. - Inherits common vector store behavior from VectorStoreRepositoryService. - Only implements PGVector-specific collection management and score normalization. + Inherits common vector store behavior from VectorStoreRepositoryService. Only implements PGVector-specific + collection management and score normalization. """ def _drop_collection_index(self, collection_id: str) -> None: diff --git a/lambda/repository/services/repository_service.py b/lambda/repository/services/repository_service.py index 36bc21eca..bde3555f8 100644 --- a/lambda/repository/services/repository_service.py +++ b/lambda/repository/services/repository_service.py @@ -23,8 +23,8 @@ class RepositoryService(ABC): """Abstract base class defining repository-specific operations. - Each repository type (OpenSearch, PGVector, Bedrock KB) implements this - interface to provide type-specific behavior for document management. + Each repository type (OpenSearch, PGVector, Bedrock KB) implements this interface to provide type-specific behavior + for document management. """ def __init__(self, repository: dict[str, Any]): diff --git a/lambda/repository/services/repository_service_factory.py b/lambda/repository/services/repository_service_factory.py index a1f0e1828..86aa0c6c5 100644 --- a/lambda/repository/services/repository_service_factory.py +++ b/lambda/repository/services/repository_service_factory.py @@ -27,8 +27,7 @@ class RepositoryServiceFactory: """Factory for creating repository-specific service instances. - Encapsulates repository-specific behavior, eliminating the need for - conditional logic throughout the codebase. + Encapsulates repository-specific behavior, eliminating the need for conditional logic throughout the codebase. """ # Registry mapping repository types to service classes @@ -55,9 +54,7 @@ def create_service(cls, repository: dict[str, Any]) -> RepositoryService: service_class = cls._services.get(repo_type) if not service_class: - raise ValueError( - f"Unsupported repository type: {repo_type}. " f"Supported types: {list(cls._services.keys())}" - ) + raise ValueError(f"Unsupported repository type: {repo_type}. Supported types: {list(cls._services.keys())}") return service_class(repository) diff --git a/lambda/repository/services/vector_store_repository_service.py b/lambda/repository/services/vector_store_repository_service.py index 1d2998513..29c8b8100 100644 --- a/lambda/repository/services/vector_store_repository_service.py +++ b/lambda/repository/services/vector_store_repository_service.py @@ -14,8 +14,8 @@ """Base implementation for vector store-based repository services (OpenSearch, PGVector). 
-This class provides common functionality for repositories that use traditional -vector stores with chunking and embedding pipelines. +This class provides common functionality for repositories that use traditional vector stores with chunking and embedding +pipelines. """ import logging @@ -48,11 +48,11 @@ class VectorStoreRepositoryService(RepositoryService): """Base implementation for vector store-based repository services. - Provides common functionality for OpenSearch and PGVector repositories - that share similar ingestion, deletion, and retrieval patterns. + Provides common functionality for OpenSearch and PGVector repositories that share similar ingestion, deletion, and + retrieval patterns. - Subclasses only need to implement repository-specific operations like - index/collection dropping and score normalization. + Subclasses only need to implement repository-specific operations like index/collection dropping and score + normalization. """ def supports_custom_collections(self) -> bool: @@ -138,8 +138,7 @@ def delete_collection( ) -> None: """Delete collection from vector store. - Delegates to subclass-specific implementation for dropping - indexes/collections. + Delegates to subclass-specific implementation for dropping indexes/collections. """ self._drop_collection_index(collection_id) @@ -197,9 +196,7 @@ def retrieve_documents( if include_score and results: max_score = max(self._normalize_similarity_score(score) for _, score in results) if max_score < 0.3: - logger.warning( - f"All similarity scores < 0.3 for query '{query}' - " "possible embedding model mismatch" - ) + logger.warning(f"All similarity scores < 0.3 for query '{query}' - possible embedding model mismatch") return documents diff --git a/lambda/repository/state_machine/cleanup_repo_docs.py b/lambda/repository/state_machine/cleanup_repo_docs.py index 4edcbbebb..6239615e1 100644 --- a/lambda/repository/state_machine/cleanup_repo_docs.py +++ b/lambda/repository/state_machine/cleanup_repo_docs.py @@ -25,8 +25,7 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any] | Any: - """ - Remove LISA-managed documents from repository. + """Remove LISA-managed documents from repository. Only deletes documents with ingestion_type of MANUAL or AUTO. Preserves EXISTING documents (user-managed). diff --git a/lambda/repository/state_machine/list_modified_objects.py b/lambda/repository/state_machine/list_modified_objects.py index 87748c5f5..45c0a4216 100644 --- a/lambda/repository/state_machine/list_modified_objects.py +++ b/lambda/repository/state_machine/list_modified_objects.py @@ -27,8 +27,7 @@ def normalize_prefix(prefix: str) -> str: - """ - Normalize the S3 prefix by handling trailing slashes. + """Normalize the S3 prefix by handling trailing slashes. Args: prefix: S3 prefix to normalize @@ -50,8 +49,7 @@ def normalize_prefix(prefix: str) -> str: def validate_bucket_prefix(bucket: str, prefix: str) -> bool: - """ - Validate bucket and prefix parameters. + """Validate bucket and prefix parameters. Args: bucket: S3 bucket name @@ -77,8 +75,7 @@ def validate_bucket_prefix(bucket: str, prefix: str) -> bool: def handle_list_modified_objects(event: dict[str, Any], context: Any) -> dict[str, Any] | Any: - """ - Lists all objects in the specified S3 bucket and prefix that were modified in the last 24 hours. + """List all objects in the specified S3 bucket and prefix that were modified in the last 24 hours.
    Args:
        event: Event data containing bucket and prefix information
diff --git a/lambda/repository/state_machine/wait_for_collection_deletions.py b/lambda/repository/state_machine/wait_for_collection_deletions.py
index 0c4e96f51..950222f7d 100644
--- a/lambda/repository/state_machine/wait_for_collection_deletions.py
+++ b/lambda/repository/state_machine/wait_for_collection_deletions.py
@@ -23,8 +23,7 @@


 def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]:
-    """
-    Check if all collection deletion jobs for a repository are complete.
+    """Check if all collection deletion jobs for a repository are complete.

     Args:
         event: Event data containing repositoryId and stackName
@@ -47,7 +46,7 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]:
     all_complete = pending_count == 0

     logger.info(
-        f"Repository {repository_id}: " f"pending_collection_deletions={pending_count}, " f"all_complete={all_complete}"
+        f"Repository {repository_id}: pending_collection_deletions={pending_count}, all_complete={all_complete}"
     )

     return {
diff --git a/lambda/repository/vector_store_repo.py b/lambda/repository/vector_store_repo.py
index 37366642b..2acb5c637 100644
--- a/lambda/repository/vector_store_repo.py
+++ b/lambda/repository/vector_store_repo.py
@@ -26,7 +26,7 @@


 class VectorStoreRepository:
-    """Vector Store repository for DynamoDB"""
+    """Vector Store repository for DynamoDB."""

     def __init__(self, table_name: str | None = None) -> None:
         dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config)
@@ -61,7 +61,7 @@ def get_registered_repositories(self) -> list[dict]:
         return registered_repositories

     def get_repository_status(self) -> dict[str, str]:
-        """Get a list the status of all repositories"""
+        """Get the status of all repositories."""

         status: dict[str, str] = {}
         response = self.table.scan(
@@ -80,8 +80,7 @@

         return status

     def find_repository_by_id(self, repository_id: str, raw_config: bool = False) -> dict[str, Any]:
-        """
-        Find a repository by its ID.
+        """Find a repository by its ID.

         Args:
             repository_id: The ID of the repository to find.
@@ -119,8 +118,7 @@ def find_repository_by_id(self, repository_id: str, raw_config: bool = False) ->

         return config

     def update(self, repository_id: str, updates: dict[str, Any], status: str | None = None) -> dict[str, Any]:
-        """
-        Update a repository configuration.
+        """Update a repository configuration.

         Args:
             repository_id: The ID of the repository to update.
@@ -163,8 +161,7 @@

             raise ValueError(f"Failed to update repository: {repository_id}", e)

     def delete(self, repository_id: str) -> bool:
-        """
-        Delete a repository by its ID.
+        """Delete a repository by its ID.

         Args:
             repository_id: The ID of the repository to delete.
@@ -182,8 +179,8 @@

             raise ValueError(f"Failed to delete repository: {repository_id}", e)

     def find_repositories_using_model(self, model_id: str) -> list[dict]:
-        """
-        Find all repositories that use a specific model.
+        """Find all repositories that use a specific model.
+
         Excludes repositories with status indicating they are deleted or archived.
Args: diff --git a/lambda/session/lambda_functions.py b/lambda/session/lambda_functions.py index c2baa0ee6..d81ac005f 100644 --- a/lambda/session/lambda_functions.py +++ b/lambda/session/lambda_functions.py @@ -74,13 +74,12 @@ def _is_session_encryption_enabled() -> bool: """Check if session encryption is enabled via global configuration. - Returns + Returns: ------- bool True if session encryption is enabled, False otherwise. Defaults to False if configuration is not found or accessible. """ - try: logger.debug("Querying global configuration for session encryption setting") # Query the global configuration entry @@ -119,7 +118,7 @@ def _get_current_model_config(model_id: str) -> Any: model_id : str The model ID to fetch configuration for. - Returns + Returns: ------- Dict[str, Any] The current model configuration, or empty dict if not found. @@ -146,7 +145,7 @@ def _update_session_with_current_model_config( session_config : SessionConfigurationModel The session configuration containing model information. - Returns + Returns: ------- SessionConfigurationModel Updated configuration with current model settings. diff --git a/lambda/session/repository.py b/lambda/session/repository.py index f4a42eae6..c03e20ceb 100644 --- a/lambda/session/repository.py +++ b/lambda/session/repository.py @@ -14,9 +14,8 @@ """Session data-access helpers shared across Lambda packages. -No module-level AWS resource instantiation — all clients/resources are -passed in by the caller so this module is safe to import from any Lambda -regardless of which environment variables are present. +No module-level AWS resource instantiation — all clients/resources are passed in by the caller so this module is safe to +import from any Lambda regardless of which environment variables are present. """ import logging diff --git a/lambda/user_preferences/models.py b/lambda/user_preferences/models.py index a45f4f458..24390ffde 100644 --- a/lambda/user_preferences/models.py +++ b/lambda/user_preferences/models.py @@ -16,8 +16,8 @@ class UserPreferencesModel(BaseModel): - """ - A Pydantic model representing a template for prompts. + """A Pydantic model representing a template for prompts. + Contains metadata and functionality to create new revisions. """ diff --git a/lambda/utilities/audit_logging_utils.py b/lambda/utilities/audit_logging_utils.py index 9046cb730..e23a9aa52 100644 --- a/lambda/utilities/audit_logging_utils.py +++ b/lambda/utilities/audit_logging_utils.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Shared helpers for API Gateway audit logging. +"""Shared helpers for API Gateway audit logging. Strict opt-in behavior: - When disabled, callers must not emit audit logs. @@ -72,20 +71,18 @@ def audit_all() -> bool: def audit_include_json_body() -> bool: - """ - When false (default), callers must not emit AUDIT_API_GATEWAY_REQUEST_BODY. + """When false (default), callers must not emit AUDIT_API_GATEWAY_REQUEST_BODY. + CDK sets this only when audit logging is enabled and includeJsonBody is true. """ return _env_bool("LISA_AUDIT_INCLUDE_JSON_BODY") def log_audit_event(logger: logging.Logger, event_type: str, payload: dict[str, Any]) -> None: - """ - Emit audit data so it appears in CloudWatch log streams. + """Emit audit data so it appears in CloudWatch log streams. - Lambda's configured formatter (see ``setup_root_logging``) only prints the log - ``message`` string, not ``logging`` ``extra`` fields. 
This helper appends a - compact JSON object after the event type for search and Logs Insights (e.g. + Lambda's configured formatter (see ``setup_root_logging``) only prints the log ``message`` string, not ``logging`` + ``extra`` fields. This helper appends a compact JSON object after the event type for search and Logs Insights (e.g. parse the substring after the first space as JSON). """ record: dict[str, Any] = {"event_type": event_type, **payload} @@ -154,8 +151,7 @@ def _path_starts_with_prefix(path: str, prefix: str) -> bool: def get_matched_audit_prefix(path: str) -> str | None: - """ - Return the matched prefix (e.g. "/session") when auditing should apply. + """Return the matched prefix (e.g. "/session") when auditing should apply. Returns: - "ALL" when auditAll is enabled @@ -184,8 +180,7 @@ def should_audit_path(path: str) -> bool: def get_method_and_path_from_method_arn(method_arn: str) -> tuple[str, str]: - """ - Parse execute-api methodArn into (http_method, request_path). + """Parse execute-api methodArn into (http_method, request_path). Example: arn:aws:execute-api:us-east-1:123:abc123/prod/GET/repository/foo @@ -206,8 +201,7 @@ def get_method_and_path_from_method_arn(method_arn: str) -> tuple[str, str]: def sanitize_json_for_audit(value: Any) -> Any: - """ - Recursively redact sensitive keys from JSON values. + """Recursively redact sensitive keys from JSON values. This is intentionally permissive: it redacts by key name anywhere in the structure. """ @@ -226,8 +220,7 @@ def sanitize_json_for_audit(value: Any) -> Any: def sanitize_json_body_for_audit(body: Any) -> str: - """ - Convert body into a sanitized JSON string suitable for audit logging. + """Convert body into a sanitized JSON string suitable for audit logging. Returns placeholder strings for non-JSON or oversized bodies. """ diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py index a6e32182d..e0647b005 100644 --- a/lambda/utilities/auth.py +++ b/lambda/utilities/auth.py @@ -77,8 +77,8 @@ def get_user_context(event: dict[str, Any]) -> tuple[str, bool, list[str]]: def get_authorizer(event: Any) -> dict[str, Any]: """Return the API Gateway Lambda authorizer context dict. - This is a small shared helper so other parts of the codebase don't need to - re-implement the same defensive extraction logic. + This is a small shared helper so other parts of the codebase don't need to re-implement the same defensive + extraction logic. """ if not isinstance(event, dict): return {} @@ -86,8 +86,7 @@ def get_authorizer(event: Any) -> dict[str, Any]: def user_has_group_access(user_groups: list[str], allowed_groups: list[str]) -> bool: - """ - Check if user has access based on group membership. + """Check if user has access based on group membership. 
Args: user_groups: List of groups the user belongs to @@ -137,12 +136,12 @@ def get_management_key() -> str: # API token utility functions def generate_token() -> str: - """Generate cryptographically secure random token (64 bytes = 128 hex chars)""" + """Generate cryptographically secure random token (64 bytes = 128 hex chars).""" return secrets.token_hex(64) def hash_token(token: str) -> str: - """Create SHA-256 hash of token""" + """Create SHA-256 hash of token.""" return hashlib.sha256(token.encode()).hexdigest() diff --git a/lambda/utilities/auth_provider.py b/lambda/utilities/auth_provider.py index d48968378..b4888c82c 100644 --- a/lambda/utilities/auth_provider.py +++ b/lambda/utilities/auth_provider.py @@ -24,8 +24,8 @@ class AuthorizationProvider(ABC): """Abstract base class for authorization providers. - This abstraction allows swapping between different authorization backends - (e.g., OIDC group-based, BRASS bindle lock) without changing the consuming code. + This abstraction allows swapping between different authorization backends (e.g., OIDC group-based, BRASS bindle + lock) without changing the consuming code. """ @abstractmethod @@ -39,7 +39,7 @@ def check_admin_access(self, username: str, groups: list[str] | None = None) -> groups : list[str] | None Optional list of groups the user belongs to (used by group-based providers) - Returns + Returns: ------- bool True if user has admin access, False otherwise @@ -57,7 +57,7 @@ def check_rag_admin_access(self, username: str, groups: list[str] | None = None) groups : list[str] | None Optional list of groups the user belongs to (used by group-based providers) - Returns + Returns: ------- bool True if user has RAG admin access, False otherwise @@ -75,7 +75,7 @@ def check_app_access(self, username: str, groups: list[str] | None = None) -> bo groups : list[str] | None Optional list of groups the user belongs to (used by group-based providers) - Returns + Returns: ------- bool True if user has app access, False otherwise @@ -132,7 +132,7 @@ def check_admin_access(self, username: str, groups: list[str] | None = None) -> groups : list[str] | None List of groups the user belongs to - Returns + Returns: ------- bool True if user is in admin group, False otherwise @@ -168,7 +168,7 @@ def check_app_access(self, username: str, groups: list[str] | None = None) -> bo groups : list[str] | None List of groups the user belongs to - Returns + Returns: ------- bool True if user is in user group (or no user group configured), False otherwise @@ -195,7 +195,7 @@ def check_app_access(self, username: str, groups: list[str] | None = None) -> bo def get_authorization_provider() -> AuthorizationProvider: """Get the configured authorization provider instance. - Returns + Returns: ------- AuthorizationProvider The authorization provider instance (OIDC-based for LISA) diff --git a/lambda/utilities/aws_helpers.py b/lambda/utilities/aws_helpers.py index 1ae06e207..253dd58cc 100644 --- a/lambda/utilities/aws_helpers.py +++ b/lambda/utilities/aws_helpers.py @@ -42,8 +42,7 @@ @cache def get_cert_path(iam_client: Any) -> str | bool: - """ - Get certificate path for SSL validation against LISA Serve endpoint. + """Get certificate path for SSL validation against LISA Serve endpoint. This function retrieves IAM server certificates for SSL verification. For ACM certificates or when no certificate is specified, it returns @@ -54,12 +53,12 @@ def get_cert_path(iam_client: Any) -> str | bool: iam_client : Any Boto3 IAM client instance. 
- Returns + Returns: ------- Union[str, bool] Path to certificate file, or True to use default verification. - Example + Example: ------- >>> iam = boto3.client("iam") >>> cert_path = get_cert_path(iam) @@ -114,15 +113,14 @@ def get_cert_path(iam_client: Any) -> str | bool: @cache def get_rest_api_container_endpoint() -> str: - """ - Get REST API container base URI from SSM Parameter Store. + """Get REST API container base URI from SSM Parameter Store. - Returns + Returns: ------- str The REST API container endpoint URL. - Example + Example: ------- >>> endpoint = get_rest_api_container_endpoint() >>> endpoint @@ -134,15 +132,14 @@ def get_rest_api_container_endpoint() -> str: def _get_lambda_role_arn() -> str: - """ - Get the ARN of the Lambda execution role. + """Get the ARN of the Lambda execution role. - Returns + Returns: ------- str The full ARN of the Lambda execution role. - Example + Example: ------- >>> _get_lambda_role_arn() 'arn:aws:sts::123456789012:assumed-role/MyLambdaRole/MyFunction' @@ -153,15 +150,14 @@ def _get_lambda_role_arn() -> str: def get_lambda_role_name() -> str: - """ - Extract the role name from the Lambda execution role ARN. + """Extract the role name from the Lambda execution role ARN. - Returns + Returns: ------- str The name of the Lambda execution role without the full ARN. - Example + Example: ------- >>> get_lambda_role_name() 'MyLambdaRole' @@ -172,15 +168,14 @@ def get_lambda_role_name() -> str: def get_account_and_partition() -> tuple[str, str]: - """ - Get AWS account ID and partition from environment or ECR repository ARN. + """Get AWS account ID and partition from environment or ECR repository ARN. - Returns + Returns: ------- tuple[str, str] Tuple of (account_id, partition). - Example + Example: ------- >>> account_id, partition = get_account_and_partition() >>> account_id diff --git a/lambda/utilities/bedrock_agent_discovery.py b/lambda/utilities/bedrock_agent_discovery.py index 7eb4d71e0..950f3086c 100644 --- a/lambda/utilities/bedrock_agent_discovery.py +++ b/lambda/utilities/bedrock_agent_discovery.py @@ -203,8 +203,7 @@ def list_agent_aliases(agent_id: str, bedrock_agent_client: Any) -> list[Bedrock def discover_bedrock_agents(bedrock_agent_client: Any | None = None) -> list[BedrockAgentDiscoveryItem]: - """ - List agents in the account and attach alias hints for invocation. + """List agents in the account and attach alias hints for invocation. Only agents in PREPARED state are returned (consistent with ready-to-invoke agents). """ diff --git a/lambda/utilities/bedrock_kb.py b/lambda/utilities/bedrock_kb.py index 344fef29f..4ac630157 100644 --- a/lambda/utilities/bedrock_kb.py +++ b/lambda/utilities/bedrock_kb.py @@ -96,8 +96,7 @@ def discover_and_ingest_documents( s3_prefix: str = "", ingestion_type: IngestionType = IngestionType.EXISTING, ) -> S3DocumentDiscoveryResult: - """ - Discover and ingest existing documents from S3 bucket. + """Discover and ingest existing documents from S3 bucket. Scans S3 bucket, creates metadata.json files, creates RagDocument entries, and triggers Bedrock KB sync. @@ -195,8 +194,7 @@ def discover_and_ingest_documents( raise def _scan_s3_bucket(self, s3_bucket: str, s3_prefix: str) -> tuple[list[str], int]: - """ - Scan S3 bucket and return list of document keys. + """Scan S3 bucket and return list of document keys. 
Args: s3_bucket: S3 bucket name @@ -325,8 +323,7 @@ def get_datasource_bucket_for_collection( repository: dict[str, Any], collection_id: str, ) -> str: - """ - Get the S3 bucket for a specific collection/data source. + """Get the S3 bucket for a specific collection/data source. Supports multiple configuration formats: - Legacy: bedrockKnowledgeDatasourceS3Bucket (single bucket) @@ -403,9 +400,9 @@ def ingest_document_to_kb( job: IngestionJob, repository: dict[str, Any], ) -> None: - """ - Copy the source object into the KB datasource bucket and trigger ingestion. S3 will - kick off another IngestionJob to store the document in the collection DB + """Copy the source object into the KB datasource bucket and trigger ingestion. + + S3 will kick off another IngestionJob to store the document in the collection DB """ bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) @@ -558,8 +555,7 @@ def ingest_bedrock_s3_documents( batch_size: int = 100, metadata: dict[str, Any] | None = None, ) -> tuple[int, int]: - """ - Discover and create ingestion jobs for existing documents in S3 bucket. + """Discover and create ingestion jobs for existing documents in S3 bucket. Scans S3 bucket for documents and creates batch ingestion jobs. Skips metadata files and directories. @@ -574,6 +570,7 @@ def ingest_bedrock_s3_documents( embedding_model: Embedding model identifier s3_prefix: Optional S3 prefix to scan within bucket batch_size: Number of documents per batch job (default: 100) + metadata: Optional pre-merged metadata to include in the job Returns: Tuple of (discovered_count, skipped_count) @@ -651,8 +648,7 @@ def create_s3_scan_job( s3_prefix: str = "", metadata: dict[str, Any] | None = None, ) -> str: - """ - Create a batch ingestion job to scan and ingest existing S3 documents. + """Create a batch ingestion job to scan and ingest existing S3 documents. This creates a batch job with empty s3_paths that will be processed by pipeline_ingest_documents. The empty s3_paths signals that the S3 bucket diff --git a/lambda/utilities/bedrock_kb_discovery.py b/lambda/utilities/bedrock_kb_discovery.py index 39b286293..d82011965 100644 --- a/lambda/utilities/bedrock_kb_discovery.py +++ b/lambda/utilities/bedrock_kb_discovery.py @@ -14,9 +14,8 @@ """Discovery service for Bedrock Knowledge Base data sources. -This module provides functionality to discover and list Knowledge Bases and their -data sources from AWS Bedrock Agent APIs. It supports caching and pagination for -efficient resource discovery. +This module provides functionality to discover and list Knowledge Bases and their data sources from AWS Bedrock Agent +APIs. It supports caching and pagination for efficient resource discovery. """ import logging @@ -38,8 +37,7 @@ def list_knowledge_bases( bedrock_agent_client: Any | None = None, ) -> list[KnowledgeBaseMetadata]: - """ - List all Knowledge Bases accessible in the AWS account. + """List all Knowledge Bases accessible in the AWS account. Args: bedrock_agent_client: Optional boto3 bedrock-agent client @@ -80,10 +78,10 @@ def list_knowledge_bases( error_code = e.response.get("Error", {}).get("Code", "") if error_code == "AccessDeniedException": raise ValidationError( - "Access denied to list Knowledge Bases. " "Please check IAM permissions for bedrock:ListKnowledgeBases." + "Access denied to list Knowledge Bases. Please check IAM permissions for bedrock:ListKnowledgeBases." ) elif error_code == "ThrottlingException": - raise ValidationError("Rate limit exceeded while listing Knowledge Bases. 
" "Please try again later.") + raise ValidationError("Rate limit exceeded while listing Knowledge Bases. Please try again later.") else: raise ValidationError(f"Failed to list Knowledge Bases: {str(e)}") except Exception as e: @@ -95,8 +93,7 @@ def discover_kb_data_sources( kb_id: str, bedrock_agent_client: Any | None = None, ) -> list[DataSourceMetadata]: - """ - Discover all data sources in a Bedrock Knowledge Base. + """Discover all data sources in a Bedrock Knowledge Base. Args: kb_id: Knowledge Base ID @@ -156,7 +153,7 @@ def discover_kb_data_sources( error_code = e.response.get("Error", {}).get("Code", "") if error_code == "ResourceNotFoundException": raise ValidationError( - f"Knowledge Base '{kb_id}' not found. " f"Please verify the KB ID in the AWS Bedrock console." + f"Knowledge Base '{kb_id}' not found. Please verify the KB ID in the AWS Bedrock console." ) elif error_code == "AccessDeniedException": raise ValidationError( @@ -165,7 +162,7 @@ def discover_kb_data_sources( ) elif error_code == "ThrottlingException": raise ValidationError( - f"Rate limit exceeded while discovering data sources for KB '{kb_id}'. " f"Please try again later." + f"Rate limit exceeded while discovering data sources for KB '{kb_id}'. Please try again later." ) else: raise ValidationError(f"Failed to discover data sources: {str(e)}") @@ -211,7 +208,6 @@ def build_pipeline_configs_from_kb_config( Raises: ValidationError: If duplicate data source IDs or S3 URIs found """ - pipeline_configs = [] data_source_ids = set() s3_uris = set() @@ -279,8 +275,7 @@ def get_available_data_sources( repository_id: str | None = None, bedrock_agent_client: Any | None = None, ) -> list[DataSourceMetadata]: - """ - Get all data sources for a Knowledge Base. + """Get all data sources for a Knowledge Base. Args: kb_id: Knowledge Base ID diff --git a/lambda/utilities/bedrock_kb_validation.py b/lambda/utilities/bedrock_kb_validation.py index 7b529bcb5..bac1ab2da 100644 --- a/lambda/utilities/bedrock_kb_validation.py +++ b/lambda/utilities/bedrock_kb_validation.py @@ -25,8 +25,7 @@ def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Any | None = None) -> dict[str, Any]: - """ - Validate that a Bedrock Knowledge Base exists and is accessible. + """Validate that a Bedrock Knowledge Base exists and is accessible. Args: kb_id: Knowledge Base ID to validate @@ -53,12 +52,11 @@ def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Any | None = No if error_code == "ResourceNotFoundException": raise ValidationError( - f"Knowledge Base '{kb_id}' not found. " f"Please verify the KB ID in the AWS Bedrock console." + f"Knowledge Base '{kb_id}' not found. Please verify the KB ID in the AWS Bedrock console." ) elif error_code == "AccessDeniedException": raise ValidationError( - f"Access denied to Knowledge Base '{kb_id}'. " - f"Please check IAM permissions for bedrock:GetKnowledgeBase." + f"Access denied to Knowledge Base '{kb_id}'. Please check IAM permissions for bedrock:GetKnowledgeBase." ) else: raise ValidationError(f"Failed to validate Knowledge Base '{kb_id}': {str(e)}") @@ -69,8 +67,7 @@ def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Any | None = No def validate_data_source_exists( kb_id: str, data_source_id: str, bedrock_agent_client: Any | None = None ) -> dict[str, Any]: - """ - Validate that a data source exists in a Bedrock Knowledge Base. + """Validate that a data source exists in a Bedrock Knowledge Base. 
Args: kb_id: Knowledge Base ID @@ -90,7 +87,7 @@ def validate_data_source_exists( response = bedrock_agent_client.get_data_source(knowledgeBaseId=kb_id, dataSourceId=data_source_id) data_source_config = response.get("dataSource", {}) - logger.info(f"Validated Data Source {data_source_id} in KB {kb_id}: " f"{data_source_config.get('name')}") + logger.info(f"Validated Data Source {data_source_id} in KB {kb_id}: {data_source_config.get('name')}") return data_source_config # type: ignore[no-any-return] except ClientError as e: @@ -115,8 +112,7 @@ def validate_data_source_exists( def validate_bedrock_kb_repository( kb_id: str, data_source_id: str, bedrock_agent_client: Any | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Validate both Knowledge Base and Data Source exist. + """Validate both Knowledge Base and Data Source exist. Args: kb_id: Knowledge Base ID diff --git a/lambda/utilities/chunking_strategy_factory.py b/lambda/utilities/chunking_strategy_factory.py index 04fe681d7..93e3d3261 100644 --- a/lambda/utilities/chunking_strategy_factory.py +++ b/lambda/utilities/chunking_strategy_factory.py @@ -36,8 +36,7 @@ class ChunkingStrategyHandler(ABC): @abstractmethod def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy) -> list[Document]: - """ - Chunk documents according to the strategy. + """Chunk documents according to the strategy. Parameters ---------- @@ -46,7 +45,7 @@ def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy) -> l strategy : ChunkingStrategy The chunking strategy configuration - Returns + Returns: ------- list[Document] List of chunked documents @@ -58,8 +57,7 @@ class FixedSizeChunkingHandler(ChunkingStrategyHandler): """Handler for fixed-size chunking strategy.""" def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> list[Document]: - """ - Chunk documents using fixed-size strategy with RecursiveCharacterTextSplitter. + """Chunk documents using fixed-size strategy with RecursiveCharacterTextSplitter. Parameters ---------- @@ -68,7 +66,7 @@ def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy = DEF strategy : ChunkingStrategy The chunking strategy configuration (FixedChunkingStrategy) - Returns + Returns: ------- list[Document] List of chunked documents @@ -109,8 +107,7 @@ class NoneChunkingHandler(ChunkingStrategyHandler): """Handler for no-chunking strategy - returns documents as-is.""" def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy) -> list[Document]: - """ - Return documents without chunking. + """Return documents without chunking. Parameters ---------- @@ -119,7 +116,7 @@ def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy) -> l strategy : ChunkingStrategy The chunking strategy configuration (NoneChunkingStrategy) - Returns + Returns: ------- list[Document] Original list of documents unmodified @@ -138,8 +135,7 @@ class ChunkingStrategyFactory: @classmethod def chunk_documents(cls, docs: list[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> list[Document]: - """ - Chunk documents using the appropriate strategy handler. + """Chunk documents using the appropriate strategy handler. 
Parameters ---------- @@ -148,12 +144,12 @@ def chunk_documents(cls, docs: list[Document], strategy: ChunkingStrategy = DEFA strategy : ChunkingStrategy The chunking strategy configuration - Returns + Returns: ------- list[Document] List of chunked documents - Raises + Raises: ------ ValueError If the chunking strategy type is not supported @@ -172,8 +168,7 @@ def chunk_documents(cls, docs: list[Document], strategy: ChunkingStrategy = DEFA @classmethod def register_handler(cls, strategy_type: ChunkingStrategyType, handler: ChunkingStrategyHandler) -> None: - """ - Register a new chunking strategy handler. + """Register a new chunking strategy handler. This allows for extending the factory with additional chunking strategies. @@ -189,10 +184,9 @@ def register_handler(cls, strategy_type: ChunkingStrategyType, handler: Chunking @classmethod def get_supported_strategies(cls) -> list[ChunkingStrategyType]: - """ - Get list of supported chunking strategy types. + """Get list of supported chunking strategy types. - Returns + Returns: ------- list[ChunkingStrategyType] List of supported strategy types diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index 566b8c970..03d5b544f 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Common helper functions for RAG Lambdas. +"""Common helper functions for RAG Lambdas. DEPRECATED: This module is maintained for backward compatibility. New code should import from the specific utility modules: @@ -61,7 +60,7 @@ def filter(self, record: logging.LogRecord) -> bool: record : logging.LogRecord The log record. - Returns + Returns: ------- bool A boolean. diff --git a/lambda/utilities/constants.py b/lambda/utilities/constants.py index d1172ef62..4a2eed20f 100644 --- a/lambda/utilities/constants.py +++ b/lambda/utilities/constants.py @@ -38,8 +38,8 @@ XML_FILE, LOG_FILE, ] +"""Constants for pagination and time limits.""" -"""Constants for pagination and time limits""" DEFAULT_TIME_LIMIT_HOURS = 720 # 30 days DEFAULT_PAGE_SIZE = 10 MAX_PAGE_SIZE = 100 diff --git a/lambda/utilities/db_setup_iam_auth.py b/lambda/utilities/db_setup_iam_auth.py index 46c30c2ba..e7e2c3c05 100644 --- a/lambda/utilities/db_setup_iam_auth.py +++ b/lambda/utilities/db_setup_iam_auth.py @@ -199,9 +199,8 @@ def create_db_user(db_host: str, db_port: str, db_name: str, db_user: str, secre def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """Lambda handler for IAM database user setup. - Creates an IAM-authenticated PostgreSQL user. The bootstrap secret is kept - for CloudFormation compatibility (not deleted) even though it won't be used - for authentication after IAM auth is enabled. + Creates an IAM-authenticated PostgreSQL user. The bootstrap secret is kept for CloudFormation compatibility (not + deleted) even though it won't be used for authentication after IAM auth is enabled. """ logger.info(f"IAM auth setup Lambda invoked with event: {json.dumps(event)}") diff --git a/lambda/utilities/dict_helpers.py b/lambda/utilities/dict_helpers.py index 3db830ed8..8ea92d009 100644 --- a/lambda/utilities/dict_helpers.py +++ b/lambda/utilities/dict_helpers.py @@ -18,8 +18,7 @@ def merge_fields(source: dict, target: dict, fields: list[str]) -> dict: - """ - Merge specified fields from source dictionary to target dictionary. 
+ """Merge specified fields from source dictionary to target dictionary. Supports both top-level and nested fields using dot notation. @@ -32,12 +31,12 @@ def merge_fields(source: dict, target: dict, fields: list[str]) -> dict: fields : list[str] List of field names, can use dot notation for nested fields. - Returns + Returns: ------- dict Updated target dictionary. - Example + Example: ------- >>> source = {"user": {"name": "John", "age": 30}, "status": "active"} >>> target = {"id": "123"} @@ -82,8 +81,7 @@ def set_nested_value(obj: dict, path: list[str], value: Any) -> None: def get_property_path(data: dict[str, Any], property_path: str) -> Any | None: - """ - Get value from nested dictionary using dot-notation path. + """Get value from nested dictionary using dot-notation path. Parameters ---------- @@ -92,12 +90,12 @@ def get_property_path(data: dict[str, Any], property_path: str) -> Any | None: property_path : str Dot-notation path to the property (e.g., "user.address.city"). - Returns + Returns: ------- Optional[Any] The value at the specified path, or None if path doesn't exist. - Example + Example: ------- >>> data = {"user": {"address": {"city": "Seattle"}}} >>> get_property_path(data, "user.address.city") @@ -117,20 +115,19 @@ def get_property_path(data: dict[str, Any], property_path: str) -> Any | None: def get_item(response: Any) -> Any: - """ - Extract first item from DynamoDB query/scan response. + """Extract first item from DynamoDB query/scan response. Parameters ---------- response : Any DynamoDB query or scan response. - Returns + Returns: ------- Any First item from the response, or None if no items. - Example + Example: ------- >>> response = {"Items": [{"id": "123", "name": "John"}]} >>> get_item(response) diff --git a/lambda/utilities/event_parser.py b/lambda/utilities/event_parser.py index 741ad3b28..183c11a27 100644 --- a/lambda/utilities/event_parser.py +++ b/lambda/utilities/event_parser.py @@ -22,8 +22,7 @@ def sanitize_event_for_logging(event: dict[str, Any]) -> str: - """ - Sanitize Lambda event before logging. + """Sanitize Lambda event before logging. This function sanitizes the event by: 1. Redacting authorization headers @@ -35,17 +34,14 @@ def sanitize_event_for_logging(event: dict[str, Any]) -> str: event : Dict[str, Any] The Lambda event from API Gateway. - Returns + Returns: ------- str The sanitized event as a JSON-formatted string. - Example + Example: ------- - >>> event = { - ... "headers": {"Authorization": "Bearer token123"}, - ... "path": "/users/123" - ... } + >>> event = {"headers": {"Authorization": "Bearer token123"}, "path": "/users/123"} >>> sanitized = sanitize_event_for_logging(event) >>> "token123" in sanitized False @@ -103,20 +99,19 @@ def sanitize_event_for_logging(event: dict[str, Any]) -> str: def get_session_id(event: dict) -> str: - """ - Extract session ID from Lambda event path parameters. + """Extract session ID from Lambda event path parameters. Parameters ---------- event : dict Lambda event from API Gateway. - Returns + Returns: ------- str The session ID from path parameters. - Example + Example: ------- >>> event = {"pathParameters": {"sessionId": "sess-123"}} >>> get_session_id(event) @@ -127,26 +122,21 @@ def get_session_id(event: dict) -> str: def get_principal_id(event: dict) -> str: - """ - Extract principal ID from Lambda event authorizer context. + """Extract principal ID from Lambda event authorizer context. Parameters ---------- event : dict Lambda event from API Gateway. 
- Returns + Returns: ------- str The principal ID from authorizer context. - Example + Example: ------- - >>> event = { - ... "requestContext": { - ... "authorizer": {"principal": "user-123"} - ... } - ... } + >>> event = {"requestContext": {"authorizer": {"principal": "user-123"}}} >>> get_principal_id(event) 'user-123' """ @@ -155,20 +145,19 @@ def get_principal_id(event: dict) -> str: def get_bearer_token(event: dict) -> str | None: - """ - Extract Bearer token from Authorization header in Lambda event. + """Extract Bearer token from Authorization header in Lambda event. Parameters ---------- event : dict Lambda event from API Gateway. - Returns + Returns: ------- Optional[str] The token string if present and properly formatted, else None. - Example + Example: ------- >>> event = {"headers": {"Authorization": "Bearer abc123"}} >>> get_bearer_token(event) @@ -189,8 +178,7 @@ def get_bearer_token(event: dict) -> str | None: def get_id_token(event: dict) -> str: - """ - Extract ID token from Authorization header in Lambda event. + """Extract ID token from Authorization header in Lambda event. This function extracts the bearer token from the authorization header, removing the "Bearer" prefix if present. @@ -200,17 +188,17 @@ def get_id_token(event: dict) -> str: event : dict Lambda event from API Gateway. - Returns + Returns: ------- str The ID token without the "Bearer" prefix. - Raises + Raises: ------ ValueError If authorization header is missing. - Example + Example: ------- >>> event = {"headers": {"Authorization": "Bearer token123"}} >>> get_id_token(event) diff --git a/lambda/utilities/exceptions.py b/lambda/utilities/exceptions.py index 5f96af16d..ade752cc4 100644 --- a/lambda/utilities/exceptions.py +++ b/lambda/utilities/exceptions.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """Exceptions from handling RAG documents.""" diff --git a/lambda/utilities/fastapi_factory.py b/lambda/utilities/fastapi_factory.py index 12358536f..cddf2da8b 100644 --- a/lambda/utilities/fastapi_factory.py +++ b/lambda/utilities/fastapi_factory.py @@ -28,8 +28,7 @@ def create_fastapi_app() -> FastAPI: - """ - Create a FastAPI application with standard LISA configuration. + """Create a FastAPI application with standard LISA configuration. This factory function creates a FastAPI app with: - Standard FastAPI settings (redirect_slashes, lifespan, docs) diff --git a/lambda/utilities/fastapi_middleware/auth_decorators.py b/lambda/utilities/fastapi_middleware/auth_decorators.py index b760abc16..6e733940e 100644 --- a/lambda/utilities/fastapi_middleware/auth_decorators.py +++ b/lambda/utilities/fastapi_middleware/auth_decorators.py @@ -23,8 +23,7 @@ def require_admin(message: str = "User does not have permission to perform this action") -> Callable: - """ - Decorator for FastAPI route handlers that require admin access. + """Decorator for FastAPI route handlers that require admin access. Works with async FastAPI handlers that have a `request: Request` parameter. The decorator extracts the AWS event from the request scope and checks admin status. 
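The `event_parser.py` hunks above standardize the token-extraction docstrings; the behavior they document is compact enough to sketch in one place. A minimal sketch consistent with those doctests (the case-insensitive header lookup is an assumption for illustration, not something the diff shows):

```python
# Minimal sketch, not the project's implementation: mirrors the doctest
# behavior shown in the event_parser.py hunks above.
from typing import Any


def get_bearer_token(event: dict[str, Any]) -> str | None:
    """Return the bearer token from an API Gateway event, or None if absent."""
    headers = event.get("headers") or {}
    # Assumed: API Gateway may deliver the header in either case.
    auth = headers.get("Authorization") or headers.get("authorization")
    if not auth:
        return None
    # Strip the optional "Bearer " prefix, as in the doctests above.
    token = auth.removeprefix("Bearer ").strip()
    return token or None


assert get_bearer_token({"headers": {"Authorization": "Bearer abc123"}}) == "abc123"
assert get_bearer_token({"headers": {}}) is None
```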
diff --git a/lambda/utilities/fastapi_middleware/aws_api_gateway_middleware.py b/lambda/utilities/fastapi_middleware/aws_api_gateway_middleware.py index 42782c8c8..4b75e9e94 100644 --- a/lambda/utilities/fastapi_middleware/aws_api_gateway_middleware.py +++ b/lambda/utilities/fastapi_middleware/aws_api_gateway_middleware.py @@ -18,11 +18,9 @@ class AWSAPIGatewayMiddleware(BaseHTTPMiddleware): - """ - Handles the FastAPI path and root_path dynamically from the ASGI request data. + """Handles the FastAPI path and root_path dynamically from the ASGI request data. - Mangum injects the AWS event data which we can use to dynamically set the path - and root_path. + Mangum injects the AWS event data which we can use to dynamically set the path and root_path. https://github.com/jordaneremieff/mangum/issues/147 """ diff --git a/lambda/utilities/fastapi_middleware/exception_handlers.py b/lambda/utilities/fastapi_middleware/exception_handlers.py index 6acbdab0a..e586f1126 100644 --- a/lambda/utilities/fastapi_middleware/exception_handlers.py +++ b/lambda/utilities/fastapi_middleware/exception_handlers.py @@ -24,8 +24,7 @@ async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse: - """ - Handle all unhandled exceptions. + """Handle all unhandled exceptions. This handler catches any exceptions not handled by more specific handlers. It logs detailed error information internally but returns a generic message diff --git a/lambda/utilities/fastapi_middleware/input_validation_middleware.py b/lambda/utilities/fastapi_middleware/input_validation_middleware.py index 9a5008d8f..e0f05a2c2 100644 --- a/lambda/utilities/fastapi_middleware/input_validation_middleware.py +++ b/lambda/utilities/fastapi_middleware/input_validation_middleware.py @@ -29,8 +29,7 @@ def sanitize_input(data: str) -> str: - """ - Sanitize string input by removing or escaping dangerous characters. + """Sanitize string input by removing or escaping dangerous characters. This function: - Escapes HTML/XML special characters to prevent XSS @@ -57,8 +56,7 @@ def sanitize_input(data: str) -> str: class InputValidationMiddleware(BaseHTTPMiddleware): - """ - Middleware that validates and sanitizes all incoming requests. + """Middleware that validates and sanitizes all incoming requests. This middleware provides security protections against: - Null byte injection attacks @@ -70,8 +68,7 @@ class InputValidationMiddleware(BaseHTTPMiddleware): """ def __init__(self, app: ASGIApp, max_request_size: int = DEFAULT_MAX_REQUEST_SIZE) -> None: - """ - Initialize the input validation middleware. + """Initialize the input validation middleware. Args: app: The ASGI application @@ -82,8 +79,7 @@ def __init__(self, app: ASGIApp, max_request_size: int = DEFAULT_MAX_REQUEST_SIZ self.max_request_size = max_request_size def contains_null_bytes(self, data: str) -> bool: - """ - Check if a string contains null bytes. + r"""Check if a string contains null bytes. Null bytes (\\x00) can be used to bypass input validation or cause unexpected behavior in string processing. @@ -97,8 +93,7 @@ def contains_null_bytes(self, data: str) -> bool: return "\x00" in data async def check_request_size(self, request: Request) -> JSONResponse | None: - """ - Validate that the request body size does not exceed the configured limit. + """Validate that the request body size does not exceed the configured limit. 
Args: request: The incoming HTTP request @@ -125,7 +120,7 @@ async def check_request_size(self, request: Request) -> JSONResponse | None: content={ "error": "Payload Too Large", "message": ( - f"Request body size exceeds maximum allowed size " f"of {self.max_request_size} bytes" + f"Request body size exceeds maximum allowed size of {self.max_request_size} bytes" ), }, ) @@ -136,8 +131,7 @@ async def check_request_size(self, request: Request) -> JSONResponse | None: return None async def validate_query_params(self, request: Request) -> JSONResponse | None: - """ - Validate query parameters for null bytes. + """Validate query parameters for null bytes. Args: request: The incoming HTTP request @@ -165,8 +159,7 @@ async def validate_query_params(self, request: Request) -> JSONResponse | None: return None async def validate_path_params(self, request: Request) -> JSONResponse | None: - """ - Validate path parameters for null bytes. + """Validate path parameters for null bytes. Args: request: The incoming HTTP request @@ -193,8 +186,7 @@ async def validate_path_params(self, request: Request) -> JSONResponse | None: return None async def validate_request_body(self, request: Request) -> JSONResponse | None: - """ - Validate request body for null bytes. + """Validate request body for null bytes. This reads the request body and checks for null bytes. If found, returns an error response. Otherwise, the body is consumed and needs @@ -235,8 +227,7 @@ async def validate_request_body(self, request: Request) -> JSONResponse | None: return None async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - """ - Process the request through validation checks before passing to handlers. + """Process the request through validation checks before passing to handlers. Validation order: 1. HTTP method validation (returns 405 if invalid) diff --git a/lambda/utilities/fastapi_middleware/request_logging_middleware.py b/lambda/utilities/fastapi_middleware/request_logging_middleware.py index 9e96f009b..1b0ea105c 100644 --- a/lambda/utilities/fastapi_middleware/request_logging_middleware.py +++ b/lambda/utilities/fastapi_middleware/request_logging_middleware.py @@ -34,8 +34,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): - """ - Middleware that logs all incoming requests with sanitized data. + """Middleware that logs all incoming requests with sanitized data. This middleware provides: - Automatic logging of all requests (method, path, headers, params) @@ -51,8 +50,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): """ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - """ - Process the request, log details, and pass to next handler. + """Process the request, log details, and pass to next handler. Args: request: The incoming HTTP request @@ -128,8 +126,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - return response def _build_log_data(self, request: Request, event: dict[str, Any]) -> dict[str, Any]: - """ - Build sanitized log data from request and AWS event. + """Build sanitized log data from request and AWS event. Args: request: The FastAPI request object @@ -170,8 +167,7 @@ def _build_log_data(self, request: Request, event: dict[str, Any]) -> dict[str, return log_data def _sanitize_body(self, body: bytes) -> str: - """ - Sanitize request body for logging. + """Sanitize request body for logging. Attempts to parse as JSON and redact sensitive fields. If parsing fails, returns a placeholder. 
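The two middleware files above describe the same guard rails: reject null bytes early and cap request size before any handler runs. A minimal standalone sketch of those checks, using the same Starlette types the hunks reference (the 1 MB constant mirrors `DEFAULT_MAX_REQUEST_SIZE`; the exact wiring is illustrative, not the project's implementation):

```python
# Minimal sketch, not the project's implementation: the body-size and
# null-byte guards described by the input_validation_middleware.py hunks.
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

MAX_REQUEST_SIZE = 1024 * 1024  # assumed 1 MB, mirroring DEFAULT_MAX_REQUEST_SIZE


class InputGuardMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        # 413 when the declared Content-Length exceeds the limit.
        length = request.headers.get("content-length", "")
        if length.isdigit() and int(length) > MAX_REQUEST_SIZE:
            return JSONResponse(status_code=413, content={"error": "Payload Too Large"})
        # 400 when any query parameter smuggles a null byte.
        if any("\x00" in value for value in request.query_params.values()):
            return JSONResponse(status_code=400, content={"error": "Bad Request"})
        return await call_next(request)
```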
diff --git a/lambda/utilities/fastapi_middleware/security_headers_middleware.py b/lambda/utilities/fastapi_middleware/security_headers_middleware.py
index 6351da5e1..732fc26ab 100644
--- a/lambda/utilities/fastapi_middleware/security_headers_middleware.py
+++ b/lambda/utilities/fastapi_middleware/security_headers_middleware.py
@@ -18,8 +18,7 @@


 class SecurityHeadersMiddleware(BaseHTTPMiddleware):
-    """
-    Middleware that adds security headers to all HTTP responses.
+    """Middleware that adds security headers to all HTTP responses.

     Security headers included:
     - Strict-Transport-Security: Forces HTTPS connections
@@ -34,8 +33,7 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):

     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
-        """
-        Process the request and add security headers to the response.
+        """Process the request and add security headers to the response.

         Args:
             request: The incoming HTTP request
diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py
index 681e36a81..32441f474 100644
--- a/lambda/utilities/file_processing.py
+++ b/lambda/utilities/file_processing.py
@@ -38,8 +38,7 @@


 def _get_metadata(s3_uri: str, name: str, metadata: dict | None = None) -> dict:
-    """
-    Create metadata dictionary for a document.
+    """Create metadata dictionary for a document.

     Args:
         s3_uri: S3 URI of the document
@@ -83,7 +82,7 @@ def _extract_pdf_content(s3_object: dict) -> str:
     ----------
     s3_object (dict): an S3 object containing a PDF file body

-    Returns
+    Returns:
     -------
     str: The extracted text from the PDF file.
     """
@@ -110,7 +109,7 @@ def _extract_docx_content(s3_object: dict) -> str:
     ----------
     s3_object (dict): an S3 object containing a docx file body

-    Returns
+    Returns:
     -------
     str: The extracted text from the docx file.
     """
@@ -125,9 +124,9 @@ def _extract_docx_content(s3_object: dict) -> str:


 def _extract_text_content(s3_object: dict) -> str:
-    """
-    Extracts text content from an S3 object. Decode as
-    utf-8 to properly read special characters
+    """Extracts text content from an S3 object.
+
+    Decode as utf-8 to properly read special characters.

     Parameters
     ----------
@@ -144,12 +143,12 @@ def generate_chunks(ingestion_job: IngestionJob) -> list[Document]:
     ingestion_job : IngestionJob
         Ingestion job containing file information and chunking strategy

-    Returns
+    Returns:
     -------
     list[Document]
         List of document chunks for the processed file

-    Raises
+    Raises:
     ------
     RagUploadException
         If S3 path is invalid or file processing fails
diff --git a/lambda/utilities/header_sanitizer.py b/lambda/utilities/header_sanitizer.py
index 588c32d04..029d214ed 100644
--- a/lambda/utilities/header_sanitizer.py
+++ b/lambda/utilities/header_sanitizer.py
@@ -44,8 +44,7 @@


 def sanitize_headers(headers: dict[str, Any], event: dict[str, Any]) -> dict[str, Any]:
-    """
-    Sanitize HTTP headers using a allowlist approach.
+    """Sanitize HTTP headers using an allowlist approach.

     Only headers in the ALLOWED_HEADERS set are logged. This prevents log injection
     attacks by rejecting any unexpected or potentially malicious headers.
@@ -64,7 +63,7 @@
     >>> headers = {
     ...     "accept": "application/json",
     ...     "x-amzn-actiontrace": "injected-value",
-    ...     "x-forwarded-for": "1.2.3.4"
+    ...     "x-forwarded-for": "1.2.3.4",
     ...
} >>> event = {"requestContext": {"identity": {"sourceIp": "9.10.11.12"}}} >>> sanitized = sanitize_headers(headers, event) @@ -101,8 +100,7 @@ def sanitize_headers(headers: dict[str, Any], event: dict[str, Any]) -> dict[str def get_sanitized_headers_for_logging(event: dict[str, Any]) -> dict[str, Any]: - """ - Extract and sanitize headers from Lambda event for safe logging. + """Extract and sanitize headers from Lambda event for safe logging. This is a convenience function that extracts headers from the event and sanitizes them in one step. @@ -114,10 +112,7 @@ def get_sanitized_headers_for_logging(event: dict[str, Any]) -> dict[str, Any]: Dictionary of sanitized headers safe for logging Example: - >>> event = { - ... "headers": {"x-forwarded-for": "1.2.3.4"}, - ... "requestContext": {"identity": {"sourceIp": "5.6.7.8"}} - ... } + >>> event = {"headers": {"x-forwarded-for": "1.2.3.4"}, "requestContext": {"identity": {"sourceIp": "5.6.7.8"}}} >>> headers = get_sanitized_headers_for_logging(event) >>> headers["x-forwarded-for"] "5.6.7.8" diff --git a/lambda/utilities/healthcheck_validator.py b/lambda/utilities/healthcheck_validator.py index f342de9fd..db851cf9b 100644 --- a/lambda/utilities/healthcheck_validator.py +++ b/lambda/utilities/healthcheck_validator.py @@ -16,8 +16,7 @@ def validate_healthcheck_command(command: str | list[str]) -> None: - """ - Validate ECS healthcheck command format. + """Validate ECS healthcheck command format. This validation ensures the command format is compatible with ECS requirements to prevent deployment failures. It does NOT restrict command content - admins diff --git a/lambda/utilities/input_validation.py b/lambda/utilities/input_validation.py index 24951bec5..b25ed9393 100644 --- a/lambda/utilities/input_validation.py +++ b/lambda/utilities/input_validation.py @@ -32,8 +32,7 @@ def contains_null_bytes(data: str) -> bool: - """ - Check if a string contains null bytes. + r"""Check if a string contains null bytes. Null bytes (\\x00) can be used to bypass input validation or cause unexpected behavior in string processing. @@ -48,8 +47,7 @@ def contains_null_bytes(data: str) -> bool: def validate_input(max_request_size: int = DEFAULT_MAX_REQUEST_SIZE) -> Callable[[F], F]: - """ - Decorator to validate Lambda event input before processing. + """Decorator to validate Lambda event input before processing. This decorator provides security protections against: - Null byte injection attacks @@ -66,8 +64,7 @@ def validate_input(max_request_size: int = DEFAULT_MAX_REQUEST_SIZE) -> Callable def decorator(f: F) -> F: @functools.wraps(f) def wrapper(event: dict, context: dict) -> dict[str, str | int | dict[str, str]]: - """ - Validate Lambda event input. + """Validate Lambda event input. Validation order: 1. HTTP method validation (returns 405 if invalid) diff --git a/lambda/utilities/lambda_decorators.py b/lambda/utilities/lambda_decorators.py index ad79c4dfd..9893dd9fb 100644 --- a/lambda/utilities/lambda_decorators.py +++ b/lambda/utilities/lambda_decorators.py @@ -40,9 +40,7 @@ @overload -def api_wrapper(_func: LambdaHandler) -> LambdaHandler: - """Overload for decorator without parentheses.""" - ... +def api_wrapper(_func: LambdaHandler) -> LambdaHandler: ... @overload @@ -50,9 +48,7 @@ def api_wrapper( _func: None = None, *, max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, -) -> Callable[[LambdaHandler], LambdaHandler]: - """Overload for decorator with parameters.""" - ... +) -> Callable[[LambdaHandler], LambdaHandler]: ... 
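# Illustrative aside (comment only): the two @overload stubs above exist so
# api_wrapper type-checks in both its bare form (@api_wrapper) and its
# parameterized form (@api_wrapper(max_request_size=...)). The runtime
# dispatch that makes both forms work, sketched with only the names visible
# in this hunk:
#
#   def api_wrapper(_func=None, *, max_request_size=DEFAULT_MAX_REQUEST_SIZE):
#       def decorator(f):
#           @functools.wraps(f)
#           def wrapper(event, context):
#               return f(event, context)  # plus validation and response formatting
#           return wrapper
#       return decorator if _func is None else decorator(_func)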
def api_wrapper( @@ -60,8 +56,7 @@ def api_wrapper( *, max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, ) -> LambdaHandler | Callable[[LambdaHandler], LambdaHandler]: - """ - Wrap Lambda function with comprehensive API Gateway integration. + """Wrap Lambda function with comprehensive API Gateway integration. This decorator provides: - Input validation (null bytes, request size, HTTP methods) @@ -81,12 +76,12 @@ def api_wrapper( max_request_size : int Maximum allowed request body size in bytes (default: 1MB). - Returns + Returns: ------- LambdaHandler | Callable[[LambdaHandler], LambdaHandler] The wrapped function with API Gateway integration. - Example + Example: ------- >>> @api_wrapper ... def get_user(event: dict, context: dict) -> dict: @@ -156,8 +151,7 @@ def wrapper(event: dict[Any, Any], context: Any) -> dict[Any, Any]: def authorization_wrapper(f: LambdaHandler) -> LambdaHandler: - """ - Wrap Lambda authorizer function. + """Wrap Lambda authorizer function. This decorator sets up the Lambda context for authorizer functions without adding API Gateway response formatting. @@ -167,12 +161,12 @@ def authorization_wrapper(f: LambdaHandler) -> LambdaHandler: f : LambdaHandler The Lambda authorizer function to wrap. - Returns + Returns: ------- LambdaHandler The wrapped authorizer function. - Example + Example: ------- >>> @authorization_wrapper ... def authorizer(event: dict, context: dict) -> dict: @@ -190,15 +184,14 @@ def wrapper(event: dict[Any, Any], context: Any) -> Any: def get_lambda_context() -> Any: - """ - Get the current Lambda context from context variable. + """Get the current Lambda context from context variable. - Returns + Returns: ------- Any The Lambda context object. - Raises + Raises: ------ LookupError If called outside of a Lambda execution context. diff --git a/lambda/utilities/response_builder.py b/lambda/utilities/response_builder.py index dffd918e4..b4db7497a 100644 --- a/lambda/utilities/response_builder.py +++ b/lambda/utilities/response_builder.py @@ -27,15 +27,14 @@ class DecimalEncoder(json.JSONEncoder): """JSON encoder that handles Decimal, datetime, and Pydantic objects.""" def default(self, obj: Any) -> Any: - """ - Encode special types to JSON-serializable formats. + """Encode special types to JSON-serializable formats. Parameters ---------- obj : Any Object to encode. - Returns + Returns: ------- Any JSON-serializable representation. @@ -51,15 +50,14 @@ def default(self, obj: Any) -> Any: def _serialize_pydantic(obj: Any) -> Any: - """ - Recursively serialize Pydantic models to dictionaries. + """Recursively serialize Pydantic models to dictionaries. Parameters ---------- obj : Any Object to serialize. - Returns + Returns: ------- Any Serialized object. @@ -74,8 +72,7 @@ def _serialize_pydantic(obj: Any) -> Any: def generate_html_response(status_code: int, response_body: Any) -> dict[str, str | int | dict[str, str]]: - """ - Generate API Gateway response with security headers. + """Generate API Gateway response with security headers. This function creates a properly formatted API Gateway response with: - JSON-encoded body @@ -90,12 +87,12 @@ def generate_html_response(status_code: int, response_body: Any) -> dict[str, st response_body : Any Response body to be JSON-encoded. Can be dict, list, Pydantic model, or list of Pydantic models. - Returns + Returns: ------- Dict[str, Union[str, int, Dict[str, str]]] API Gateway response object. 
- Example + Example: ------- >>> generate_html_response(200, {"userId": "123", "name": "John"}) { @@ -123,8 +120,7 @@ def generate_html_response(status_code: int, response_body: Any) -> dict[str, st def generate_exception_response(e: Exception) -> dict[str, str | int | dict[str, str]]: - """ - Generate API Gateway error response from exception. + """Generate API Gateway error response from exception. This function maps exceptions to appropriate HTTP status codes and generates user-friendly error messages while logging detailed errors @@ -142,12 +138,12 @@ def generate_exception_response(e: Exception) -> dict[str, str | int | dict[str, e : Exception Exception that was caught. - Returns + Returns: ------- Dict[str, Union[str, int, Dict[str, str]]] API Gateway error response. - Example + Example: ------- >>> try: ... raise ValueError("Invalid user ID") diff --git a/lambda/utilities/session_encryption.py b/lambda/utilities/session_encryption.py index 1ec78ab64..6ac1d1127 100644 --- a/lambda/utilities/session_encryption.py +++ b/lambda/utilities/session_encryption.py @@ -65,8 +65,7 @@ def _get_kms_key_arn() -> str: def _generate_data_key(key_arn: str, encryption_context: dict[str, str] | None = None) -> tuple[bytes, bytes]: - """ - Generate a data key from KMS. + """Generate a data key from KMS. Args: key_arn: KMS key ARN @@ -86,8 +85,7 @@ def _generate_data_key(key_arn: str, encryption_context: dict[str, str] | None = def _decrypt_data_key(encrypted_data_key: bytes, encryption_context: dict[str, str] | None = None) -> bytes: - """ - Decrypt a data key using KMS. + """Decrypt a data key using KMS. Args: encrypted_data_key: Encrypted data key @@ -105,8 +103,7 @@ def _decrypt_data_key(encrypted_data_key: bytes, encryption_context: dict[str, s def _create_encryption_context(user_id: str, session_id: str) -> dict[str, str]: - """ - Create encryption context for KMS operations. + """Create encryption context for KMS operations. Args: user_id: User ID @@ -119,8 +116,7 @@ def _create_encryption_context(user_id: str, session_id: str) -> dict[str, str]: def encrypt_session_data(data: Any, user_id: str, session_id: str) -> str: - """ - Encrypt session data using KMS envelope encryption. + """Encrypt session data using KMS envelope encryption. Args: data: Data to encrypt (will be JSON serialized) @@ -162,8 +158,7 @@ def encrypt_session_data(data: Any, user_id: str, session_id: str) -> str: def decrypt_session_data(encrypted_data: str, user_id: str, session_id: str) -> Any: - """ - Decrypt session data using KMS envelope encryption. + """Decrypt session data using KMS envelope encryption. Args: encrypted_data: Base64 encoded encrypted data @@ -201,8 +196,7 @@ def decrypt_session_data(encrypted_data: str, user_id: str, session_id: str) -> def is_encrypted_data(data: str) -> bool: - """ - Check if a string appears to be encrypted session data. + """Check if a string appears to be encrypted session data. Args: data: String to check @@ -227,8 +221,7 @@ def is_encrypted_data(data: str) -> bool: def migrate_session_to_encrypted(session_data: dict[str, Any], user_id: str, session_id: str) -> dict[str, Any]: - """ - Migrate a session from unencrypted to encrypted format. + """Migrate a session from unencrypted to encrypted format. 
Args: session_data: Session data dictionary @@ -265,8 +258,7 @@ def migrate_session_to_encrypted(session_data: dict[str, Any], user_id: str, ses def decrypt_session_fields(session_data: dict[str, Any], user_id: str, session_id: str) -> dict[str, Any]: - """ - Decrypt encrypted fields in session data. + """Decrypt encrypted fields in session data. Args: session_data: Session data dictionary diff --git a/lib/docs/admin/api-overview.md b/lib/docs/admin/api-overview.md index 3fe9f7503..429c954ef 100644 --- a/lib/docs/admin/api-overview.md +++ b/lib/docs/admin/api-overview.md @@ -36,13 +36,13 @@ Serve ALB. The `/v2/serve/models` endpoint on the LISA Serve ALB allows users to list all models available for inference in the LISA system. -#### Request Example: +#### Request Example ```bash curl -s -H 'Authorization: Bearer ' -X GET https:///v2/serve/models ``` -#### Response Example: +#### Response Example ```json { @@ -70,7 +70,7 @@ curl -s -H 'Authorization: Bearer ' -X GET https:/// } ``` -#### Explanation of Response Fields: +#### Explanation of Response Fields These fields are all defined by the OpenAI API specification, which is documented [here](https://platform.openai.com/docs/api-reference/models/list). @@ -87,6 +87,7 @@ LISA provides RESTful API endpoints for programmatic access to user metrics data ### Base URL Structure All metrics endpoints are accessed through LISA's main API Gateway with the following structure: + ``` https://{API-GATEWAY-DOMAIN}/{STAGE}/metrics/users/ ``` @@ -104,9 +105,11 @@ All API endpoints require proper authentication through LISA's configured author **Description**: Retrieves comprehensive metrics data for a specific user, including session-level details and usage history. **Path Parameters**: + - `userId` (string, required): The unique identifier for the user whose metrics you want to retrieve **Example Request**: + ```bash curl -X GET \ 'https://your-api-gateway-domain/metrics/users/john.doe@company.com' \ @@ -114,6 +117,7 @@ curl -X GET \ ``` **Response Format**: + ```json { "statusCode": 200, @@ -148,6 +152,7 @@ curl -X GET \ ``` **Response Fields**: + - `totalPrompts`: Total number of prompts submitted by the user - `ragUsageCount`: Number of times the user utilized RAG features - `mcpToolCallsCount`: Total MCP tool calls made by the user @@ -158,6 +163,7 @@ curl -X GET \ - `lastSeen`: Timestamp of user's most recent interaction **Error Responses**: + - `400 Bad Request`: Missing or invalid userId parameter - `404 Not Found`: User not found in metrics database - `500 Internal Server Error`: Database or system error @@ -169,6 +175,7 @@ curl -X GET \ **Description**: Retrieves aggregated metrics across all users in the system, providing system-wide analytics and usage statistics. 
**Example Request**: + ```bash curl -X GET \ 'https://your-api-gateway-domain/metrics/users/all' \ @@ -176,6 +183,7 @@ curl -X GET \ ``` **Response Format**: + ```json { "statusCode": 200, @@ -200,6 +208,7 @@ curl -X GET \ ``` **Response Fields**: + - `totalUniqueUsers`: Count of unique users who have interacted with LISA - `totalPrompts`: Aggregate count of all prompts across users - `totalRagUsage`: Total number of RAG feature uses @@ -210,8 +219,8 @@ curl -X GET \ - `userGroups`: Distribution of users across organizational groups **Error Responses**: -- `500 Internal Server Error`: Database scan error or system failure +- `500 Internal Server Error`: Database scan error or system failure # Error Handling for API Requests @@ -223,9 +232,9 @@ Below is a list of common errors that can occur in the system, along with the HT ### ModelNotFoundError -* **Description**: Raised when a model that is requested for retrieval or deletion is not found in the system. -* **HTTP Status Code**: `404 Not Found` -* **Response Body**: +- **Description**: Raised when a model that is requested for retrieval or deletion is not found in the system. +- **HTTP Status Code**: `404 Not Found` +- **Response Body**: ```json { @@ -234,13 +243,13 @@ Below is a list of common errors that can occur in the system, along with the HT } ``` -* **Example Scenario**: When a client attempts to fetch details of a model that does not exist in the database, the `ModelNotFoundError` is raised. +- **Example Scenario**: When a client attempts to fetch details of a model that does not exist in the database, the `ModelNotFoundError` is raised. ### ModelAlreadyExistsError -* **Description:** Raised when a request to create a model is made, but the model already exists in the system. -* **HTTP Status Code**: `400` -* **Response Body**: +- **Description:** Raised when a request to create a model is made, but the model already exists in the system. +- **HTTP Status Code**: `400` +- **Response Body**: ```json { @@ -249,13 +258,13 @@ Below is a list of common errors that can occur in the system, along with the HT } ``` -* **Example Scenario:** A client attempts to create a model with an ID or name that already exists in the database. The system detects the conflict and raises the `ModelAlreadyExistsError`. +- **Example Scenario:** A client attempts to create a model with an ID or name that already exists in the database. The system detects the conflict and raises the `ModelAlreadyExistsError`. ### InvalidInputError (Hypothetical Example) -* **Description**: Raised when the input provided by the client for creating or updating a model is invalid or does not conform to expected formats. -* **HTTP Status Code**: `400 Bad Request` -* **Response Body**: +- **Description**: Raised when the input provided by the client for creating or updating a model is invalid or does not conform to expected formats. +- **HTTP Status Code**: `400 Bad Request` +- **Response Body**: ```json { @@ -264,16 +273,16 @@ Below is a list of common errors that can occur in the system, along with the HT } ``` -* **Example Scenario**: The client submits a malformed JSON body or omits required fields in a model creation request, triggering an `InvalidInputError`. +- **Example Scenario**: The client submits a malformed JSON body or omits required fields in a model creation request, triggering an `InvalidInputError`. ## Handling Validation Errors Validation errors are handled across the API via utility functions and model transformation logic. 
These errors typically occur when user inputs fail validation checks or when required data is missing from a request. -### Example Response for Validation Error: +### Example Response for Validation Error -* **HTTP Status Code**: `422 Unprocessable Entity` -* **Response Body**: +- **HTTP Status Code**: `422 Unprocessable Entity` +- **Response Body**: ```json { diff --git a/lib/docs/admin/architecture.md b/lib/docs/admin/architecture.md index c34efe3e7..ca4d97c93 100644 --- a/lib/docs/admin/architecture.md +++ b/lib/docs/admin/architecture.md @@ -14,8 +14,8 @@ LISA Serve and LISA MCP are standalone, core solutions with APIs for customers n * Model management API supports deploying, updating, and deleting third party and internally hosted models. * MCP API supports deploying, updating, deleting, and calling internally hosted MCP tools. - ## LISA Serve + ![LISA Serve Architecture](../assets/LisaServe.png) LISA Serve provides model self-hosting and integration with compatible external model providers. Serve supports text generation, image generation, video generation, and embedding models. Serve’s components are designed for scale and reliability. Serve can be accessed via LISA’s REST APIs, or through LISA’s chat @@ -34,6 +34,7 @@ Self-hosted model traffic is directed to model specific ALBs, which enable autos * LISA supports OpenAI's API spec, which means LISA can be easily configured with the Continue plugin for use with Jetbrains or VS Code integrated development environments (IDE). ### Model Management + ![LISA Model Management Architecture](../assets/LisaModelManagement.png) Use Model Management for managing the entire lifecycle of models configured or hosted with LISA. This includes creation, updating, @@ -68,8 +69,8 @@ security, networking, and infrastructure components are automatically deployed a * ECS Cluster: ECS cluster and task definitions are located in `ecs_model_deployer/src/lib/ecsCluster.ts`, with model containers specified in `ecs_model_deployer/src/lib/ecs-model.ts`. - ## LISA MCP + ![LISA MCP Architecture](../assets/LisaMcp.png) LISA MCP is a standalone product that provides scalable infrastructure for deploying and hosting Model Context Protocol (MCP) servers. It allows customers to self-host MCP servers for enterprise use. LISA MCP can be deployed independently of LISA Serve or configured to work seamlessly with LISA Serve and the Chat UI. @@ -77,11 +78,13 @@ LISA MCP is a standalone product that provides scalable infrastructure for deplo Each MCP server deployed via LISA MCP is provisioned on AWS Fargate via Amazon ECS, fronted by Application Load Balancers (ALBs) and Network Load Balancers (NLBs), and published through the existing API Gateway. This architecture allows chat sessions to securely invoke MCP tools without leaving your VPC. All routes remain protected by the same API Gateway Lambda authorizer patterns that guard the rest of LISA, ensuring API Keys, IDP lockdown, and JWT group enforcement continue to apply automatically.
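As a concrete illustration of that authorizer protection, here is a minimal sketch assuming a deployed server and valid credentials; the domain, stage, server ID, and JSON-RPC body below are placeholders rather than values from a real deployment:

```bash
# Hypothetical endpoint values — substitute your own API Gateway domain,
# deployment stage, and MCP server ID.
LISA_API="https://your-api-gateway-domain/prod"
SERVER_ID="docs-mcp"

# Requests without a bearer token the Lambda authorizer accepts (or a valid
# API key) are rejected before they ever reach the Fargate-hosted server.
curl -s -X POST "${LISA_API}/mcp/${SERVER_ID}" \
  -H "Authorization: Bearer ${LISA_BEARER_TOKEN}" \
  -H "Content-Type: application/json" \
  -d '{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}}'
```

In practice an MCP client library drives this exchange; the point here is only that every hop carries the same Authorization header the rest of LISA's API Gateway routes require.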
**Server Types:** LISA MCP supports all MCP server types: + * **STDIO servers:** Automatically wrapped with `mcp-proxy` and exposed over HTTP on port 8080 * **HTTP servers:** Direct HTTP endpoints using the configured port (default 8000) * **SSE servers:** Server-Sent Events endpoints for streaming responses **Networking Architecture:** The networking follows a layered approach: + * **API Gateway** receives MCP traffic on `/mcp/{serverId}` routes * **Network Load Balancer (NLB)** terminates the API Gateway VPC Link and forwards to the Application Load Balancer * **Application Load Balancer (ALB)** provides HTTP features including health checks, routing, and load balancing @@ -90,6 +93,7 @@ Each MCP server deployed via LISA MCP is provisioned on AWS Fargate via Amazon E **Lifecycle Management:** AWS Step Functions orchestrate the complete lifecycle of MCP servers, handling creation, update, deletion, start, and stop workflows. Each workflow provisions the required resources using CloudFormation templates, which manage infrastructure components like ECS Fargate services, load balancers, VPC Links, and auto-scaling configurations. **Key Features:** + * Turn-key hosting for STDIO, HTTP, or SSE MCP servers with a single API/UI workflow * Dynamic container builds from pre-built images or S3 artifacts synced at deploy time * Auto-scaling with configurable Fargate min/max capacity, custom metrics, and scaling targets per server @@ -105,8 +109,8 @@ Each MCP server deployed via LISA MCP is provisioned on AWS Fargate via Amazon E * Authentication: API Gateway enforces the same Lambda authorizer used across LISA (JWT validation + optional API key checks). The `{LISA_BEARER_TOKEN}` placeholder in connection details is automatically replaced with the user's bearer token at connection time. * Data Storage: Server metadata is stored in the `MCP_SERVERS_TABLE` DynamoDB table. When `DEPLOYMENT_PREFIX` is configured, completed servers are published to `McpConnectionsTable` so the chat application can surface them alongside externally hosted connections. - ## Chat UI + ![LISA Chatbot Architecture](../assets/LisaChat.png) LISA provides a configurable chat user interface (UI). The UI is hosted as a static website in Amazon S3, and is fronted diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index 5b9c1dac6..a9258d4fb 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -23,7 +23,6 @@ ## Deployment Steps - LISA uses npm scripts for build and deployment. Key commands: | Task | Command | @@ -37,7 +36,6 @@ LISA uses npm scripts for build and deployment. Key commands: The `npm run deploy` script runs the full pipeline: install dependencies, Docker checks, ECR login, model verification, build, and CDK deploy. Use `STACK= npm run deploy` to deploy specific stacks. 
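To make the deploy commands concrete, a minimal sketch follows; the stack name is illustrative only, and the real stack names come from your own deployment's CDK output:

```bash
# Full pipeline: install dependencies, Docker checks, ECR login,
# model verification, build, and CDK deploy.
npm run deploy

# Scope the deploy to a single stack (the stack name shown is hypothetical;
# list your stacks with `npx cdk list` to find the right one).
STACK=LISA-serve npm run deploy
```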
- ### Step 1: Clone the Repository Ensure you're working with the latest stable release of LISA: @@ -70,8 +68,7 @@ export CDK_DOCKER=finch # Optional, only required if not using docker as contain ### Step 3: Set Up Python and TypeScript Environments - -- ***NOTE** The code block below has two tabs for Debian & EL/AL2* +* ***NOTE** The code block below has two tabs for Debian & EL/AL2* Install system dependencies and set up both Python and TypeScript environments using the project's npm scripts: * ***NOTE** The code block below has two tabs for Debian & EL/AL2* @@ -117,7 +114,6 @@ npm run install:python npm install ``` - == MacOS ```bash @@ -156,7 +152,6 @@ npm run install:python npm install ``` - ::: ### Step 4: Configure LISA @@ -198,6 +193,7 @@ litellmConfig: > To include prompt/response content in LiteLLM logs (published by the `LISA Serve` ECS task to CloudWatch via `litellm.log`), enable LiteLLM logging callbacks and message logging in `config-custom.yaml`. > > 1. Add the following to `litellmConfig`: +> > ```yaml > litellmConfig: > litellm_settings: @@ -212,13 +208,14 @@ litellmConfig: > > 2. Ensure you are aware of the privacy/compliance implications: this causes request/response content to be logged. > -> LiteLLM Proxy logging reference: https://docs.litellm.ai/docs/proxy/logging +> LiteLLM Proxy logging reference: > [!IMPORTANT] > API Gateway audit logging (strict opt-in): > LISA can emit audit logs for API Gateway requests (who initiated the request, what action was taken, and a sanitized JSON body) only when enabled via `auditLoggingConfig` in `config-custom.yaml`. > > Example (opt-in to specific API prefixes): +> > ```yaml > auditLoggingConfig: > enabled: true @@ -227,6 +224,7 @@ litellmConfig: > ``` > > Example (`auditAll`): +> > ```yaml > auditLoggingConfig: > enabled: true @@ -300,7 +298,6 @@ This command verifies if the model's weights are already present in your S3 buck > dictated which models were deployed. > **NOTE** - > For air-gapped systems, before running `npm run model:check` you should manually download model artifacts and place them in a `models` directory at the project root, using the structure: `models/`. > **NOTE** @@ -341,9 +338,8 @@ This approach builds all necessary components in a commercial region with full i This generates: - * Lambda function zip files in `./dist/layers/*.zip` (from `build:archive`) - * Docker images exported as `./dist/images/*.tar` files (from `build-assets --include-images`) + * Docker images exported as `./dist/images/*.tar` files (from `build-assets --include-images`) #### Step 2: Transfer to ADC Region diff --git a/lib/docs/admin/getting-started.md b/lib/docs/admin/getting-started.md index 9b16eb369..1dcb8d16c 100644 --- a/lib/docs/admin/getting-started.md +++ b/lib/docs/admin/getting-started.md @@ -5,7 +5,6 @@ is an open-source, infrastructure-as-code product. Customers deploy it directly account. While LISA is specially designed for ADC regions that support government customers' most sensitive workloads, it is also compatible in any region. LISA is scalable and ready to support production use cases. - LISA accelerates GenAI adoption by offering built-in configurability with [Amazon Bedrock](https://aws.amazon.com/bedrock/) models, Knowledge Bases, and Guardrails. 
LISA also offers advanced capabilities like an optional enterprise-ready chat user interface (UI) with configurable features, authentication, resource access control, centralized model orchestration via LiteLLM, model self-hosting via Amazon ECS, retrieval augmented generation (RAG), APIs, and broad model context protocol (MCP) support and features. LISA is also compatible with OpenAI’s API specification, making it easily configurable with supporting solutions such as the Continue plugin for VSCode and JetBrains integrated development environments (IDEs). LISA's roadmap is customer-driven, with new capabilities launching monthly. Reach out to the product team to ask questions, provide feedback, and send feature requests via the "Contact Us" button above. @@ -87,7 +86,6 @@ built in model orchestration, added model flexibility, and model context protoco flexibility for different use cases. * Leverages AWS services that are FedRAMP High compliant. - *The below screenshot showcases LISA’s optional chat assistant user interface. On the left is the user’s Chat History. In the center, the user can start a new chat session and prompt a model. Up top, the user can select from four libraries: Model, Document, Prompt, and MCP Connections. As an Administrator, this user can also access the Administration menu. Here they configure application features and manage available models. See the next screenshot for more details.* ![LISA UI](../assets/LISA_UI.png) @@ -105,36 +103,36 @@ LISA Roles and Enterprise Groups control access to features and resources. `AdminGroup` and `UserGroup` properties in the configuration are used to control tiers of application access, not resource access. -- **AdminGroup**: The IDP group that distinguishes which users have access to create and manage restricted resource configuration within the UI, including: - - Activating application features - - Configuring models via Model Management - - Creating and deleting repositories, and configuring group access via RAG management - - Managing all collections across all repositories - - MCP server management - - MCP Workbench code editor +* **AdminGroup**: The IDP group that distinguishes which users have access to create and manage restricted resource configuration within the UI, including: + * Activating application features + * Configuring models via Model Management + * Creating and deleting repositories, and configuring group access via RAG management + * Managing all collections across all repositories + * MCP server management + * MCP Workbench code editor -- **RagAdminGroup** (optional): The IDP group for users who need to manage RAG content without full Admin privileges. This is especially useful in multi-tenant environments. RAG Admins can: - - Access the RAG Management page - - Create, update, and delete collections on repositories they have group access to - - Update ingestion pipelines on repositories they have group access to - - Delete documents in accessible repositories - - RAG Admins **cannot** create or delete repositories, change repository `allowedGroups`, or access any other Admin-only pages (Model Management, Configuration, MCP, API Tokens) +* **RagAdminGroup** (optional): The IDP group for users who need to manage RAG content without full Admin privileges. This is especially useful in multi-tenant environments.
RAG Admins can: + * Access the RAG Management page + * Create, update, and delete collections on repositories they have group access to + * Update ingestion pipelines on repositories they have group access to + * Delete documents in accessible repositories + * RAG Admins **cannot** create or delete repositories, change repository `allowedGroups`, or access any other Admin-only pages (Model Management, Configuration, MCP, API Tokens) -- **UserGroup** (optional): If provided, this is required when the IDP is used for multiple systems and you want to control which users in the IDP have access to LISA. +* **UserGroup** (optional): If provided, this is required when the IDP is used for multiple systems and you want to control which users in the IDP have access to LISA. -- **API Management** (v6.1+): A new role that allows users to manage their API tokens within LISA, but does not grant full Admin privileges. +* **API Management** (v6.1+): A new role that allows users to manage their API tokens within LISA, but does not grant full Admin privileges. ## Groups Access to resources can be constrained by Enterprise Groups, including: -- LISA models -- Prompt templates -- RAG repos -- RAG collections -- MCP Connections -- LISA MCP servers -- API tokens +* LISA models +* Prompt templates +* RAG repos +* RAG collections +* MCP Connections +* LISA MCP servers +* API tokens You can create or bring any number of Enterprise Groups in your IDP, which can then be used in LISA to lock down resources at creation/update. When you create/update any resource, you can assign 0, 1, or many Groups to that resource. @@ -143,15 +141,17 @@ You can create or bring any number of Enterprise Groups in your IDP, which can t For example, let's say your IDP has the following groups: **Team Red**, **Team White**, and **Team Blue**. 
Below shows how you can use Groups to lock down access to Models, and then RAG repos and their Collections: **Models:** -- Model 1: Teams Red and White -- Model 2: none (Global) -- Model 3: Team Blue + +* Model 1: Teams Red and White +* Model 2: none (Global) +* Model 3: Team Blue **RAG Repositories and Collections:** -- RAG Repo 1: Teams Red, White, Blue - - Collection A: Team Red - - Collection B: Team White - - Collection C: Teams White and Blue -- RAG Repo 2: none (Global) - - Collection X: Team Blue - - Collection Y: none (Global) + +* RAG Repo 1: Teams Red, White, Blue + * Collection A: Team Red + * Collection B: Team White + * Collection C: Teams White and Blue +* RAG Repo 2: none (Global) + * Collection X: Team Blue + * Collection Y: none (Global) diff --git a/lib/docs/assets/LISA_Cognito_Example.png b/lib/docs/assets/LISA_Cognito_Example.png index 67f32aac8..466bd3c97 100644 Binary files a/lib/docs/assets/LISA_Cognito_Example.png and b/lib/docs/assets/LISA_Cognito_Example.png differ diff --git a/lib/docs/assets/LISA_Config.png b/lib/docs/assets/LISA_Config.png index 59680aaab..3c972f7ae 100644 Binary files a/lib/docs/assets/LISA_Config.png and b/lib/docs/assets/LISA_Config.png differ diff --git a/lib/docs/assets/LISA_Model_Mgmt.png b/lib/docs/assets/LISA_Model_Mgmt.png index dae2c5bcd..ce75799d6 100644 Binary files a/lib/docs/assets/LISA_Model_Mgmt.png and b/lib/docs/assets/LISA_Model_Mgmt.png differ diff --git a/lib/docs/assets/LISA_UI.png b/lib/docs/assets/LISA_UI.png index 76add5af5..25097ce59 100644 Binary files a/lib/docs/assets/LISA_UI.png and b/lib/docs/assets/LISA_UI.png differ diff --git a/lib/docs/assets/LisaChat.png b/lib/docs/assets/LisaChat.png index 0aacd722f..909010899 100644 Binary files a/lib/docs/assets/LisaChat.png and b/lib/docs/assets/LisaChat.png differ diff --git a/lib/docs/assets/LisaMcp.png b/lib/docs/assets/LisaMcp.png index 52cbc168f..64f9f6e58 100644 Binary files a/lib/docs/assets/LisaMcp.png and b/lib/docs/assets/LisaMcp.png differ diff --git a/lib/docs/assets/LisaModelManagement.png b/lib/docs/assets/LisaModelManagement.png index 61d47c444..c9d905dd6 100644 Binary files a/lib/docs/assets/LisaModelManagement.png and b/lib/docs/assets/LisaModelManagement.png differ diff --git a/lib/docs/assets/LisaServe.png b/lib/docs/assets/LisaServe.png index e58b07e20..953b76cee 100644 Binary files a/lib/docs/assets/LisaServe.png and b/lib/docs/assets/LisaServe.png differ diff --git a/lib/docs/config/api-tokens.md b/lib/docs/config/api-tokens.md index 52eca6382..586455824 100644 --- a/lib/docs/config/api-tokens.md +++ b/lib/docs/config/api-tokens.md @@ -21,6 +21,7 @@ LISA's API Token Management system provides secure, programmatic access to LISA API tokens are stored in the DynamoDB `APITokenTable` with enhanced security and metadata: **Table Schema:** + - **Partition Key**: `token` (SHA-256 hash of the actual token) - **Attributes**: - `tokenUUID`: Unique identifier for the token @@ -71,6 +72,7 @@ authConfig: If using LISA's UI, the `Allow user managed API tokens` UI configuration must be enabled for users to manage tokens via the UI: Navigate to **Configuration** → **User Components** and enable: + - `Allow user managed API tokens` ### Role Configuration @@ -110,13 +112,14 @@ Administrators have full visibility and control over all tokens in the system. > [!WARNING] > Token is displayed **ONLY ONCE**. Copy the token immediately or download it. 
- - Copy the token immediately or download it - - Check "I have securely saved this token" - - Click **Close** +- Copy the token immediately or download it +- Check "I have securely saved this token" +- Click **Close** #### Viewing All Tokens The token table displays: + - `Token Name`: Descriptive name - `Username`: User token belongs to - `Created By`: Who created the token @@ -128,6 +131,7 @@ The token table displays: - `Token UUID`: Unique identifier **Features:** + - **Search**: Filter tokens by name, username, creator, or groups - **Pagination**: Navigate through pages of tokens - **Sort**: Click column headers to sort @@ -160,10 +164,12 @@ This opens a view which shows the user their API token. > [!NOTE] > Token will automatically inherit your current groups. + 3. Click **Create Token** 4. **Token Display Modal** (see Admin section above) **Limitations:** + - Users can create only ONE token - Users cannot create system tokens - User tokens inherit the user's group memberships at point of creation (admins can create tokens with custom groups on behalf of users) @@ -172,6 +178,7 @@ This opens a view which shows the user their API token. #### Viewing Your Token The interface shows: + - Your token details - Current status (Active/Expired) - Expiration date @@ -397,12 +404,14 @@ curl https:///v2/serve/chat/completions \ ### User Tokens **Characteristics:** + - One token per user - Automatically inherits user's groups (unless created by admin with custom groups) - Created by user (self-service) or admin - Cannot be created if user already has a token **Use Cases:** + - Personal development - Individual tool integration - User-specific automation @@ -410,11 +419,13 @@ curl https:///v2/serve/chat/completions \ ### System Tokens **Characteristics:** + - Admin-only creation - Customizable group assignments - Ideal for service accounts **Use Cases:** + - Multiple services needing separate tokens - Shared service accounts - Production automations @@ -459,6 +470,7 @@ Legacy tokens are no longer supported and will need to be recreated. ### Identifying Legacy Tokens Legacy tokens have: + - Token values stored as-is (not hashed) - No metadata (name, groups, etc.) - "Legacy" badge in UI diff --git a/lib/docs/config/claude-code-setup.md b/lib/docs/config/claude-code-setup.md index c850c68c4..6bb2041b2 100644 --- a/lib/docs/config/claude-code-setup.md +++ b/lib/docs/config/claude-code-setup.md @@ -3,10 +3,12 @@ This guide explains how to configure Claude Code with LISA Serve. ### References + - [Claude Code Documentation](https://code.claude.com/docs) - [Claude Code LLM Gateway Configuration](https://code.claude.com/docs/en/llm-gateway) ### Prerequisites + - LISA instance deployed and accessible - LISA Serve endpoint URL - LISA API key (See API Key Management) @@ -15,6 +17,7 @@ This guide explains how to configure Claude Code with LISA Serve ### Setup Steps 1. **Configure Claude Code Environment Variables**: + ```bash # Set the base URL to your LISA endpoint # Find it on CloudFormation in the LISA-lisa-serve- stack in the Outputs tab @@ -38,7 +41,6 @@ This guide explains how to configure Claude Code with LISA Serve export MAX_THINKING_TOKENS=8192 # set to 0 to disable thinking ``` - ## Verification After configuration, verify your setup: @@ -49,13 +51,16 @@ claude "hello world" ``` ### Testing in the VSCode extension + You have two options to test if the configuration is working: + 1. Open a new Claude Code tab (This should reload the environment variables depending on your configuration) 2.
Reload the VSCode window ## Troubleshooting ### LISA Endpoint Issues + - Verify endpoint is accessible: `curl https://your-lisa-endpoint.com/health` - Check API key is valid - Confirm model names match LISA configuration diff --git a/lib/docs/config/cloudwatch.md b/lib/docs/config/cloudwatch.md index 7760d5cfe..3e5e6b5e6 100644 --- a/lib/docs/config/cloudwatch.md +++ b/lib/docs/config/cloudwatch.md @@ -3,6 +3,7 @@ LISA offers Administrators insights into user engagement, feature adoption, and system utilization. Through Amazon CloudWatch, LISA automatically tracks user interactions, RAG usage, MCP tool calls, and group-level analytics, presenting this data through an integrated dashboard and API endpoints. ## Overview + LISA usage data enables administrators to understand user behavior patterns, monitor feature adoption, and make data-driven decisions about system optimization. LISA tracks three primary categories: - **User Engagement**: Total prompts and users @@ -22,70 +23,70 @@ The LISA Metrics Dashboard is automatically created during deployment. Administr 3. Locate the **LISA-Metrics** dashboard and click on it to open 4. The dashboard displays a 7-day view by default - ### Dashboard Widgets The dashboard contains 12 widgets organized to provide comprehensive visibility into system usage: #### Total Prompts -Displays the count of all user prompts over a given time period. This time-series graph shows overall system engagement and helps identify usage trends, peak hours, and growth patterns. It updates hourly with sum statistics. +Displays the count of all user prompts over a given time period. This time-series graph shows overall system engagement and helps identify usage trends, peak hours, and growth patterns. It updates hourly with sum statistics. #### Total RAG Usage -Tracks Retrieval Augmented Generation (RAG) feature adoption across all users. This indicates how frequently users leverage vector stores. Updates hourly with sum statistics. +Tracks Retrieval Augmented Generation (RAG) feature adoption across all users. This indicates how frequently users leverage vector stores. Updates hourly with sum statistics. #### Total MCP Tool Calls -Monitors Model Context Protocol (MCP) tool utilization system-wide. This graph reveals which external integrations are most valuable to users and helps identify opportunities for expanding MCP tool offerings. Aggregates tool calls across all available MCP servers. +Monitors Model Context Protocol (MCP) tool utilization system-wide. This graph reveals which external integrations are most valuable to users and helps identify opportunities for expanding MCP tool offerings. Aggregates tool calls across all available MCP servers. #### Prompts by User -Breaks down the prompt activity by individual users, enabling identification of power users and usage distribution patterns. +Breaks down the prompt activity by individual users, enabling identification of power users and usage distribution patterns. #### RAG Usage by User -Provides user-level RAG adoption usage, showing which users are actively leveraging vector stores. +Provides user-level RAG adoption usage, showing which users are actively leveraging vector stores. #### MCP Tool Calls by User -Displays MCP tool usage at the individual user level, helping administrators identify which users are most engaged with external integrations. +Displays MCP tool usage at the individual user level, helping administrators identify which users are most engaged with external integrations.
#### MCP Tool Calls by Tool -Shows utilization breakdown by specific MCP tools, enabling administrators to understand which integrations provide the most value and which might need promotion or improvement. +Shows utilization breakdown by specific MCP tools, enabling administrators to understand which integrations provide the most value and which might need promotion or improvement. #### Total User Count -Displays the count of users who have interacted with the system. This single-value widget updates daily and provides a high-level view of user base size and growth. +Displays the count of users who have interacted with the system. This single-value widget updates daily and provides a high-level view of user base size and growth. #### Groups by Membership Count -Pie chart visualization showing user distribution across organizational groups. Helps administrators understand team sizes, group engagement levels, and organizational adoption patterns. Updates daily. +Pie chart visualization showing user distribution across organizational groups. Helps administrators understand team sizes, group engagement levels, and organizational adoption patterns. Updates daily. #### Group Prompt Counts -Tracks prompt activity aggregated by organizational groups, enabling administrators to understand which teams or departments are most engaged with LISA. +Tracks prompt activity aggregated by organizational groups, enabling administrators to understand which teams or departments are most engaged with LISA. #### Group RAG Usage -Monitors RAG feature adoption at the group level. +Monitors RAG feature adoption at the group level. #### Group MCP Usage -Displays MCP tool utilization by organizational groups. +Displays MCP tool utilization by organizational groups. ## Data Storage - Usage data is stored in multiple locations: ### DynamoDB Storage + - **Usage Metrics Table**: Stores aggregate usage metrics including total prompts, RAG usage counts, MCP tool usage, and group memberships - **Session-Level Tracking**: Detailed per-session metrics to support accurate counting and prevent data duplication - **Real-Time Updates**: Session data updates immediately as users interact with LISA ### CloudWatch Metrics + - **Namespace**: All metrics are published under LISA/UsageMetrics - **Dimensions**: Supports filtering by UserId, GroupName, and ToolName @@ -94,65 +95,74 @@ Usage data is stored in multiple locations: CloudWatch provides flexible time range customization options for analyzing metrics across different periods: ### Quick Time Range Selection + 1. **Dashboard Level**: Use the time range selector at the top-right of the dashboard 2. **Available Options**: + - Last hour, 3 hours, 12 hours, last day, 3 days, week - Custom relative and absolute ranges from the range selector dropdown ### Widget-Level Customization + 1. Click the **three dots** menu on any individual widget 2. Select **Edit** to access advanced options 3. **Period Settings**: Adjust data point intervals (1 day to 1 minute) 4. **Statistic Options**: Choose between Sum, Average, Maximum, Minimum, etc. 5.
**Time Range Override**: Set widget-specific time ranges independent of dashboard settings - ## Data Availability and Timing Understanding when meaningful data becomes available is crucial for effective analysis: ### Update Frequencies + - **Daily Metrics**: User counts and group membership statistics refresh once daily - **Session Metrics**: Individual session data is updated in real-time, but may not immediately be visible on the CloudWatch widgets unless the period of the widget is changed or sufficient time passes for it to become visible (given default period and time range). ## Dashboard Management ### Customization and Overrides + CloudWatch allows extensive dashboard customization to meet specific organizational needs: 1. **Widget Modification**: Click any widget's menu to edit titles, colors, and display options 2. **Layout Changes**: Drag and drop widgets to reorganize dashboard layout ### Saving Changes + - **Manual Save**: A change to the dashboard can be saved using the **Save** button in the top right to preserve customizations. A change to a widget can be saved using the **Update widget** button on the bottom of the edit widget modal. - **Auto-Save**: Auto-save can be configured on the dashboard by clicking the **Autosave** button, which appears above the refresh button on the dashboard screen, and toggling Autosave on. ### Deployment Considerations + **Important**: Custom dashboard modifications are overwritten during LISA redeployments. To preserve customizations: 1. **Document Changes**: Keep records of custom configurations 2. **Copy Source**: From the CloudWatch dashboard select **Actions** and **View/edit source**. This will show you the source code of the dashboard, which can be copied and reused. 3. **Create Copy**: From the CloudWatch dashboard select **Actions** and **Save dashboard as**. Provide a unique name and a copy of the `LISA-Metrics` dashboard will be made. - ## Daily Metrics Management ### Automated Daily Processing + The system includes a dedicated Lambda function that runs daily to update metrics that change infrequently: - Unique User Counts - Group Membership Counts ### Manual Invocation + Administrators can manually trigger daily metrics updates when needed: #### AWS Console Method + 1. Navigate to **Lambda** in the AWS Management Console 2. Locate the **DailyMetricsLambda** function (prefixed with your deployment name) 3. Use the **Test** functionality to invoke the function (can be invoked with empty event JSON) 4. Monitor execution through CloudWatch Logs #### AWS CLI Method + ```bash aws --region {REGION} lambda invoke \ --function-name -DailyMetricsLambda \ @@ -161,7 +171,9 @@ aws --region {REGION} lambda invoke \ ``` ### Troubleshooting + If metrics appear inconsistent or incomplete: + 1. Check CloudWatch Logs for the metrics processing Lambda functions 2. Verify SQS queue processing for any backlogs 3.
Manually invoke the daily metrics Lambda if aggregate numbers seem outdated diff --git a/lib/docs/config/collection-management-api.md b/lib/docs/config/collection-management-api.md index 63a3bc482..afe9977ee 100644 --- a/lib/docs/config/collection-management-api.md +++ b/lib/docs/config/collection-management-api.md @@ -5,6 +5,7 @@ The Collection Management API provides endpoints for creating, reading, updating ## Base URL Structure All collection endpoints are accessed through LISA's main API Gateway with the following structure: + ``` https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/{repositoryId}/collection ``` @@ -22,6 +23,7 @@ Create a new collection within a vector store. **Endpoint:** `POST /repository/{repositoryId}/collection` **Path Parameters:** + - `repositoryId` (string, required): The parent vector store repository ID **Request Body:** @@ -67,6 +69,7 @@ Create a new collection within a vector store. **Chunking Strategy Types:** 1. **FIXED**: Fixed-size chunks with overlap + ```json { "type": "fixed", @@ -78,6 +81,7 @@ Create a new collection within a vector store. ``` 2. **SEMANTIC**: Semantic-based chunking + ```json { "type": "SEMANTIC", @@ -88,6 +92,7 @@ Create a new collection within a vector store. ``` 3. **RECURSIVE**: Recursive text splitting with custom separators + ```json { "type": "RECURSIVE", @@ -253,6 +258,7 @@ Retrieve a collection by ID within a vector store. **Endpoint:** `GET /repository/{repositoryId}/collection/{collectionId}` **Path Parameters:** + - `repositoryId` (string, required): The parent vector store repository ID - `collectionId` (string, required): The collection ID (UUID) @@ -358,6 +364,7 @@ Update a collection's configuration within a vector store. Supports partial upda **Endpoint:** `PUT /repository/{repositoryId}/collection/{collectionId}` **Path Parameters:** + - `repositoryId` (string, required): The parent vector store repository ID - `collectionId` (string, required): The collection ID (UUID) @@ -406,6 +413,7 @@ All fields are optional. Only include fields you want to update. **Immutable Fields:** The following fields cannot be modified after creation and will be ignored if included in the request: + - `collectionId` - `repositoryId` - `embeddingModel` @@ -587,6 +595,7 @@ When updating the chunking strategy on a collection that already has documents, ``` This warning indicates that: + 1. New documents uploaded after the change will use the new chunking strategy 2. Existing documents will keep their original chunking 3. You may want to re-ingest existing documents to apply the new strategy @@ -598,6 +607,7 @@ Delete a collection within a vector store. This operation requires admin access **Endpoint:** `DELETE /repository/{repositoryId}/collection/{collectionId}` **Path Parameters:** + - `repositoryId` (string, required): The parent vector store repository ID - `collectionId` (string, required): The collection ID (UUID) @@ -795,6 +805,7 @@ List collections in a repository with pagination, filtering, and sorting. **Endpoint:** `GET /repository/{repositoryId}/collections` **Path Parameters:** + - `repositoryId` (string, required): The parent vector store repository ID **Query Parameters:** @@ -1134,36 +1145,45 @@ getAllCollections('repo-123') **Filtering Examples:** 1. **Filter by name/description:** + ``` GET /repository/repo-123/collections?filter=legal ``` + Returns collections with "legal" in name or description 2. **Filter by status:** + ``` GET /repository/repo-123/collections?status=ACTIVE ``` + Returns only active collections 3. 
**Combined filters:** + ``` GET /repository/repo-123/collections?filter=legal&status=ACTIVE ``` + Returns active collections with "legal" in name or description **Sorting Examples:** 1. **Sort by name (ascending):** + ``` GET /repository/repo-123/collections?sortBy=name&sortOrder=asc ``` 2. **Sort by creation date (newest first):** + ``` GET /repository/repo-123/collections?sortBy=createdAt&sortOrder=desc ``` 3. **Sort by last update (oldest first):** + ``` GET /repository/repo-123/collections?sortBy=updatedAt&sortOrder=asc ``` @@ -1187,21 +1207,25 @@ Collections inherit configuration from their parent vector store: ## Validation Rules ### Collection Name + - Required for creation - Maximum 100 characters - Must be unique within repository - Allowed characters: alphanumeric, spaces, hyphens, underscores ### Allowed Groups + - Must be subset of parent repository's allowed groups - Empty array inherits from parent ### Chunking Strategy Parameters + - `chunkSize`: 100-10000 - `chunkOverlap`: 0 to chunkSize/2 - `separators`: non-empty array for RECURSIVE strategy ### Metadata Tags + - Maximum 50 tags per collection - Each tag maximum 50 characters - Allowed characters: alphanumeric, hyphens, underscores @@ -1209,11 +1233,13 @@ Collections inherit configuration from their parent vector store: ## Access Control ### Permission Levels + - **Read**: View collection configuration, query documents - **Write**: Upload documents, update collection metadata - **Admin**: Delete collection, modify access control ### Access Rules + 1. Admin users have full access to all collections across all repositories 2. RAG Admin users can create, update, and delete collections on repositories they have group access to; they cannot modify `allowedGroups` or repository-level settings 3. 
Non-admin users must have group membership intersection with collection's allowed groups @@ -1237,20 +1263,24 @@ Collections inherit configuration from their parent vector store: ### Common Errors **"Collection name must be unique within repository"** + - Solution: Choose a different name or check existing collections **"User does not have write access to repository"** + - Solution: Ensure user is in one of the repository's allowed groups or is an admin **"Allowed groups must be subset of parent repository groups"** + - Solution: Only specify groups that exist in the parent repository's allowed groups **"Chunk size must be between 100 and 10000"** + - Solution: Adjust chunk size to be within the valid range **"Cannot create collection: allowUserCollections is false"** -- Solution: Contact an administrator to enable user collections or have an admin create the collection +- Solution: Contact an administrator to enable user collections or have an admin create the collection ### List User Collections (Cross-Repository) @@ -1373,6 +1403,7 @@ The endpoint automatically selects an appropriate pagination strategy based on d The pagination token is a JSON string with two possible formats: **V1 Token (Simple Strategy):** + ```json { "version": "v1", @@ -1386,6 +1417,7 @@ The pagination token is a JSON string with two possible formats: ``` **V2 Token (Scalable Strategy):** + ```json { "version": "v2", diff --git a/lib/docs/config/guardrails.md b/lib/docs/config/guardrails.md index 74246bb06..589d80020 100644 --- a/lib/docs/config/guardrails.md +++ b/lib/docs/config/guardrails.md @@ -124,6 +124,7 @@ POST /{deploymentStage}/models ``` **Notes:** + - The guardrail key (e.g., "guardrail-1") is an internal identifier - `guardrailIdentifier` must match an existing AWS Bedrock Guardrail - Empty `allowedGroups` means the guardrail applies to all users @@ -188,6 +189,7 @@ PUT /{deploymentStage}/models/{modelId} ``` This operation: + 1. Removes the guardrail from LiteLLM 2. Deletes all associated guardrails from DynamoDB 3. Removes guardrail configurations from LiteLLM @@ -212,7 +214,7 @@ This operation: - **Allowed Groups** (optional): Add group names that should have this guardrail 7. Click **Add** to add groups, or press Enter after typing a group name 8. Repeat steps 6-8 to add multiple guardrails -9. Finish remaining configuration steps +9. Finish remaining configuration steps 10. Click **Create Model** to finalize ### Viewing Guardrails @@ -263,6 +265,7 @@ This operation: **Symptom**: Requests are not being filtered as expected **Possible Causes**: + 1. Guardrail doesn't exist in AWS Bedrock 2. Guardrail identifier is incorrect 3. Guardrail is not attached to model @@ -271,6 +274,7 @@ This operation: 5. AWS Bedrock Guardrail is not accessible from LISA VPC **Resolution**: + 1. Verify guardrail exists in AWS Bedrock Console 2. Check guardrail identifier configured in LISA matches AWS Bedrock Console 3. Verify user group memberships @@ -283,11 +287,13 @@ This operation: **Symptom**: Updated guardrail configuration not being applied **Possible Causes**: + 1. Model update did not complete successfully 2. Guardrail changes made in AWS Bedrock but version not updated in LISA 3. Cache issues with model configuration **Resolution**: + 1. Check model status (should be "In Service") 2. Verify `guardrailVersion` in LISA matches the version in AWS Bedrock 3. 
Check state machine execution logs @@ -298,11 +304,13 @@ This operation: **Symptom**: Error during model creation or update mentioning invalid guardrail **Possible Causes**: + 1. Guardrail doesn't exist in AWS Bedrock 2. Incorrect guardrail ID or ARN 3. Guardrail in different AWS region **Resolution**: + 1. Verify guardrail exists in AWS Bedrock Console in the correct region 2. Copy guardrail identifier directly from AWS Console @@ -311,10 +319,12 @@ This operation: **Symptom**: Requests take significantly longer with guardrails enabled **Possible Causes**: + 1. Too many guardrails configured 2. Complex guardrail rules in AWS Bedrock **Resolution**: + 1. Reduce number of guardrails where possible 2. Optimize guardrail rules in AWS Bedrock Console 3. Consider using only critical guardrails for performance-sensitive applications diff --git a/lib/docs/config/hosted-mcp.md b/lib/docs/config/hosted-mcp.md index d5003843a..bf8b9053f 100644 --- a/lib/docs/config/hosted-mcp.md +++ b/lib/docs/config/hosted-mcp.md @@ -223,31 +223,37 @@ Response (truncated): ## Troubleshooting ### Create API returns *“CREATE_MCP_SERVER_SFN_ARN not configured”* + - **Cause:** Environment variables were not set when the MCP API Lambda was deployed. - **Resolution:** Re-run `deploylisa` or manually set `CREATE_MCP_SERVER_SFN_ARN`, `DELETE_MCP_SERVER_SFN_ARN`, and `UPDATE_MCP_SERVER_SFN_ARN` on the MCP API Lambda, then retry. ### Error *“Server name conflicts with existing server”* + - **Cause:** Another record normalizes to the same alphanumeric identifier (e.g., `Docs-MCP` vs `docs_mcp`). - **Resolution:** Choose a different name or delete the prior server before re-creating it. ### Stack stuck in `CREATING` + - **Cause:** CloudFormation deployment failed (missing IAM roles, invalid container image, unreachable S3 path). - **Resolution:** Inspect the `CreateMcpServer` Step Functions execution, then open the CloudFormation stack events to identify the failing resource. Fix the underlying issue and re-run the create workflow. ### Hosted server is `IN_SERVICE` but unreachable + - **Cause:** Incorrect `port`, health check, or security group settings. - **Resolution:** Verify the ALB target group health, container logs, and that the application is listening on the expected port. For STDIO servers, ensure the `startCommand` launches an MCP-compatible process that `mcp-proxy` can wrap. ### Bearer token placeholder not replaced + - **Cause:** Custom headers still show `{LISA_BEARER_TOKEN}`. - **Resolution:** The placeholder is replaced at connection time. Make sure the consuming application sends an `Authorization` header when invoking the MCP connection. The API automatically replaces the placeholder right before returning connection details. ### Update API rejects payload + - **Cause:** The `UpdateHostedMcpServerRequest` validator requires at least one field; it also blocks simultaneous enable/disable and auto scaling changes. - **Resolution:** Split enable/disable operations from scaling updates, and include only the fields you intend to change. diff --git a/lib/docs/config/langfuse-tracing.md b/lib/docs/config/langfuse-tracing.md index 31cde6a5e..aa77b37c3 100644 --- a/lib/docs/config/langfuse-tracing.md +++ b/lib/docs/config/langfuse-tracing.md @@ -117,6 +117,7 @@ Access the Langfuse tracing interface to view collected traces. 
Non-streamed responses generate traces with the following structure: **Input:** + ```json { "args": [], @@ -128,6 +129,7 @@ Non-streamed responses generate traces with the following structure: ``` **Output:** + ```json { "status_code": 200, @@ -176,6 +178,7 @@ Non-streamed responses generate traces with the following structure: Streamed responses maintain identical input structure to non-streamed responses but the output differs. The default trace output for streamed responses is: + ```xml ``` @@ -207,10 +210,12 @@ For enhanced troubleshooting and integration support, Langfuse provides: ### Reference Documentation **LiteLLM Integration:** + - [Langfuse Logging with LiteLLM](https://docs.litellm.ai/docs/proxy/logging#langfuse) - [OpenTelemetry Integration with LiteLLM Proxy](https://litellm.vercel.app/docs/observability/langfuse_otel_integration#with-litellm-proxy) **Langfuse Documentation:** + - [LiteLLM Proxy with @observe Decorator](https://langfuse.com/guides/cookbook/integration_litellm_proxy) - [LiteLLM SDK Integration Guide](https://langfuse.com/integrations/frameworks/litellm-sdk) - [Python SDK Documentation](https://python.reference.langfuse.com/langfuse) diff --git a/lib/docs/config/mcp.md b/lib/docs/config/mcp.md index 7b811053c..08d6dd3ae 100644 --- a/lib/docs/config/mcp.md +++ b/lib/docs/config/mcp.md @@ -13,23 +13,27 @@ tools and perform the necessary steps to complete the task. 2. Optionally, enable **AWS Sessions** to allow users to connect AWS credentials per chat session for use by MCP tools that support it. See [AWS Sessions](#aws-sessions) below. 3. An Administrator must add an LLM to support LISA’s MCP tools capability. During model creation in Model Management, toggle **Tool Calls** to be active. - **Add MCP Connections** Administrators and non-admins can configure MCP server connections with LISA. Non-admins can add MCP server connections for personal use only. However, Administrators can create or update MCP server connections to be available for all LISA users. The steps below demonstrate how to configure LISA with an externally hosted MCP server on the [Pipedream](https://mcp.pipedream.com/) platform. To use Pipedream MCP servers, you must create an account. This will give you access to MCP server information to use in the below steps. 1. As an Administrator, under the **Libraries** menu select **MCP Connections**. 2. Add at least one MCP server to LISA by clicking on the **Create MCP Connection** button. + > **TIP:** > > Pipedream’s GitHub MCP server is straightforward to implement with LISA and does not require sharing GitHub credentials with Pipedream. With the GitHub MCP server, you can prompt the LLM about LISA or MLSpace GitHub repos hosted by AWS Labs. The server will then use its tools to find the repo in GitHub, and gather information. This information will be used by the LLM to answer your prompts. + 3. Fill out the required fields to create the MCP server connection: **Name**, **Description**, **URL**. 4. Confirm that this new MCP server connection is **Active**. To activate, toggle the `Active` option on. This is an Admin-only setting that determines if an MCP server connection is generally available for users to opt into using. + > **TIP:** > > LISA’s MCP server connections can either be private to a single user, or available to all users. Private MCP server connections created by non-admins are automatically set to `Active` and only available to that user. Admins can choose to deactivate a user’s private MCP server connection. + 5.
Click **Create Connection.** 6. Click on the new MCP server connection to view all the tools available on that server. + > **TIP 1:** > > MCP server tools are opt-out by default, meaning that an MCP server’s tools are all activated for use in LISA when a new server connection is created. Each user can then choose to personally deactivate specific tools by toggling **Use Tool** to be inactive. This is a user-specific setting. @@ -37,13 +41,15 @@ Administrators and non-admins can configure MCP server connections with LISA. Non-a > **TIP 2:** > > While individual tools may be active, users cannot access them unless they have opted in to use the parent MCP server connection. This is a user-specific setting. See Step 8 below. + 7. Use the breadcrumbs at the top of the page to return to the **MCP Connection** list. 8. To use an MCP server, each user must opt into the MCP server connection in **MCP Connections**. Non-admins can only opt into `Active` MCP server connections. They will not see `Inactive` MCP servers. To opt into using an MCP server connection, set the **Use Server** toggle to active in the **MCP Connections** list. + > **TIP:** > > MCP server connections are opt-in by default. Each user must individually opt in. This is a user-specific setting. -9. To confirm configuration, start a new chat session. At the bottom of the page the count of active MCP servers and tools available to this user will be displayed. +9. To confirm configuration, start a new chat session. At the bottom of the page the count of active MCP servers and tools available to this user will be displayed. **Using MCP tools with LISA** @@ -52,7 +58,6 @@ Administrators and non-admins can configure MCP server connections with LISA. Non-a 3. Prompt the model to do something supported by the MCP server’s tools configured above. 4. By default, you will be asked to confirm the execution of each tool’s planned external action. See LISA’s Autopilot Mode below for more information. You can also choose to opt out of future confirmation prompts each tool takes. - **MCP Settings: Autopilot Mode, Edit, and Delete** *Autopilot Mode* diff --git a/lib/docs/config/model-management-api.md b/lib/docs/config/model-management-api.md index 26e75142a..c90517ea1 100644 --- a/lib/docs/config/model-management-api.md +++ b/lib/docs/config/model-management-api.md @@ -6,13 +6,13 @@ This API is only accessible by administrators via the API Gateway and is used to The `/models` route allows admins to list all models managed by the system. This includes models that are being created, being deleted, already active, or in a failed state. Models can be deployed via ECS or managed externally through a LiteLLM configuration. -### Request Example: +### Request Example ```bash curl -s -H "Authorization: Bearer " -X GET https:///models ``` -### Response Example: +### Response Example ```json { @@ -85,7 +85,7 @@ curl -s -H "Authorization: Bearer " -X GET https:// } ``` -### Explanation of Response Fields: +### Explanation of Response Fields - `modelId`: A unique identifier for the model. - `modelName`: The name of the model, typically referencing the underlying service (Bedrock, SageMaker, etc.). @@ -100,13 +100,13 @@ LISA provides the `/models` endpoint for creating LISA-hosted ECS models and ext This API accepts the same model definition parameters that were accepted in the V2 model definitions within the config.yaml file, with one notable difference: the `containerConfig.image.path` field is now omitted because it corresponded with the `inferenceContainer` selection.
As a convenience, this path is no longer required. -### Request Example: +### Request Example ``` POST https:///models ``` -### Example Payload for ECS Model: +### Example Payload for ECS Model ```json { @@ -159,7 +159,7 @@ POST https:///models } ``` -### Creating a LiteLLM-Only Model: +### Creating a LiteLLM-Only Model ```json { @@ -170,7 +170,7 @@ POST https:///models } ``` -### Creating a Customer Internal Hosted Model: +### Creating a Customer Internal Hosted Model ```json { @@ -183,16 +183,16 @@ POST https:///models } ``` -### Explanation of Key Fields for Creation Payload: +### Explanation of Key Fields for Creation Payload - `modelId`: The unique identifier for the model. This is any name you would like it to be. - `modelName`: The name of the model as it appears in the system. For LISA self-hosted models, this must be the S3 Key to your model artifacts, otherwise this is the LiteLLM-compatible reference to a SageMaker Endpoint or Bedrock Foundation Model. Note: Bedrock and SageMaker resources must exist in the same region as your LISA deployment. If your LISA installation is in us-east-1, then all SageMaker and Bedrock calls will also happen in us-east-1. Configuration examples: - - LISA hosting: If your model artifacts are in `s3://${lisa_models_bucket}/path/to/model/weights`, then the `modelName` value here should be `path/to/model/weights` - - LiteLLM-only, Bedrock: If you want to use `amazon.titan-text-lite-v1`, your `modelName` value should be `bedrock/amazon.titan-text-lite-v1` - - LiteLLM-only, SageMaker: If you want to use a SageMaker Endpoint named `my-sm-endpoint`, then the `modelName` value should be `sagemaker/my-sm-endpoint`. + - LISA hosting: If your model artifacts are in `s3://${lisa_models_bucket}/path/to/model/weights`, then the `modelName` value here should be `path/to/model/weights` + - LiteLLM-only, Bedrock: If you want to use `amazon.titan-text-lite-v1`, your `modelName` value should be `bedrock/amazon.titan-text-lite-v1` + - LiteLLM-only, SageMaker: If you want to use a SageMaker Endpoint named `my-sm-endpoint`, then the `modelName` value should be `sagemaker/my-sm-endpoint`. - `modelType`: The type of model, such as text generation (textgen). - `streaming`: Whether the model supports streaming inference. - `hostingType`: Optional hosting selector. Use `INTERNAL_HOSTED` for customer internal load balancer endpoints. @@ -206,13 +206,13 @@ POST https:///models Admins can delete a model using the following endpoint. Deleting a model removes the infrastructure (ECS) or disconnects from LiteLLM. -### Request Example: +### Request Example ``` DELETE https:///models/{modelId} ``` -### Response Example: +### Response Example ```json { @@ -328,13 +328,13 @@ LISA supports two scheduling types. One type may be applied to each self-hosted Create or update a schedule for a specific model. This endpoint accepts the same payload for both creating new schedules and updating existing ones. 
-##### Request Example: +##### Request Example ``` PUT https://<apigw-url>/models/{modelId}/schedule ``` -##### Example Payload for Daily Schedule: +##### Example Payload for Daily Schedule ```json { @@ -361,7 +361,7 @@ PUT https:///models/{modelId}/schedule } ``` -##### Example Payload for Recurring Schedule: +##### Example Payload for Recurring Schedule ```json { @@ -374,7 +374,7 @@ PUT https:///models/{modelId}/schedule } ``` -##### Response Example: +##### Response Example ```json { @@ -384,7 +384,7 @@ PUT https:///models/{modelId}/schedule } ``` -##### Key Fields for Schedule Configuration: +##### Key Fields for Schedule Configuration - `scheduleType`: Either "DAILY" or "RECURRING" - `timezone`: IANA timezone identifier (e.g., "UTC", "America/New_York", "Europe/London") @@ -400,13 +400,13 @@ PUT https:///models/{modelId}/schedule Retrieve the current schedule configuration for a model. -##### Request Example: +##### Request Example ``` GET https://<apigw-url>/models/{modelId}/schedule ``` -##### Response Example: +##### Response Example ```json { @@ -439,13 +439,13 @@ GET https:///models/{modelId}/schedule Get detailed status information about a model's scheduling configuration and current state. -##### Request Example: +##### Request Example ``` GET https://<apigw-url>/models/{modelId}/schedule/status ``` -##### Response Example: +##### Response Example ```json { @@ -465,7 +465,7 @@ GET https:///models/{modelId}/schedule/status } ``` -##### Schedule Status Fields: +##### Schedule Status Fields - `scheduleEnabled`: Whether scheduling is currently active for the model - `scheduleConfigured`: Whether a schedule has been configured for the model @@ -481,13 +481,13 @@ GET https:///models/{modelId}/schedule/status Remove the schedule configuration for a model, disabling automatic start/stop functionality. -##### Request Example: +##### Request Example ``` DELETE https://<apigw-url>/models/{modelId}/schedule ``` -##### Response Example: +##### Response Example ```json { diff --git a/lib/docs/config/model-management-ui.md b/lib/docs/config/model-management-ui.md index ba5888834..f450d1daa 100644 --- a/lib/docs/config/model-management-ui.md +++ b/lib/docs/config/model-management-ui.md @@ -3,13 +3,13 @@ ## Configuring Models LISA's Model Management UI allows Administrators to configure models for use with LISA. LISA supports: + - third-party models hosted externally to LISA that are compatible with LiteLLM, - customer internal hosted models exposed by an internal AWS load balancer URL, and - self-hosted models running on LISA-managed Amazon ECS infrastructure. LISA's Model Management wizard walks Administrators through configuration steps. - ## Scaling Models ### Overview @@ -60,23 +60,27 @@ LISA supports two ALB-based scaling metrics.
The metric determines what signal t ### Scaling Recommendations #### Text Generation Models + - Use `TargetResponseTime` as the scaling metric - Set `targetValue` to your acceptable p90 latency (e.g., 10 seconds) - These models are compute-intensive with longer inference times, so latency-based scaling reacts to actual user impact - Consider a higher `defaultInstanceWarmup` (300–3600s) since large models take time to load #### Embedding Models + - Use `RequestCountPerTarget` as the scaling metric - Embedding requests are typically fast and uniform, so request volume is a reliable scaling signal - A lower `targetValue` (e.g., 20–50) keeps latency consistent under load #### Cost Optimization + - Set `minCapacity` to 0 or 1 depending on whether you need always-on availability - Use [Model Scheduling](#model-scheduling) to automatically stop models during off-hours - Monitor actual utilization through CloudWatch to right-size `maxCapacity` - Increase `cooldown` to avoid unnecessary scaling churn during bursty traffic #### High Availability + - Set `minCapacity` to at least 2 for production workloads to survive a single instance failure - Ensure `maxCapacity` provides enough headroom for peak traffic - Keep `defaultInstanceWarmup` accurate to avoid premature traffic routing to cold instances @@ -140,33 +144,40 @@ Updated models automatically become available in the Chat UI once updates comple ### Update Process Flow #### 1. Validation Phase + The system validates update requests against current model state: + - Ensures model is in `InService` or `Stopped` state - Validates configuration conflicts - Checks capacity constraints against existing auto-scaling groups - Verifies container configuration compatibility #### 2. State Machine Orchestration + Updates are processed through a multi-step state machine: **Step 1 - Job Intake**: + - Processes update payload - Determines required update types - Sets model status to `Updating` - Prepares infrastructure changes **Step 2 - ECS Updates** (if needed): + - Creates new task definition with updated container config - Updates ECS service - Monitors deployment progress - Handles rollback on failures **Step 3 - Capacity Updates** (if needed): + - Updates auto-scaling group parameters - Monitors instance health and availability - Waits for capacity stabilization **Step 4 - Finalization**: + - Updates model metadata in database - Restores model to `InService` status - Registers model with inference endpoint if needed @@ -174,15 +185,18 @@ Updates are processed through a multi-step state machine: #### 3. 
Safety Mechanisms **State Validation**: + - Models cannot be updated during transitional states - Updates requiring a container restart require explicit acknowledgment **Rollback Protection**: + - Failed deployments automatically scale down to prevent resource waste - ECS updates include deployment monitoring with timeout protection - Database state is preserved during failures **Resource Limits**: + - Polling timeouts prevent infinite waiting - Capacity changes validate against AWS account limits - Container updates respect ECS service constraints @@ -190,6 +204,7 @@ Updates are processed through a multi-step state machine: ### Performing Model Updates #### Prerequisites + - Administrator access to LISA Model Management - Target model is in `InService` or `Stopped` state - Understanding of update impact (restart requirements) @@ -219,16 +234,19 @@ Updates are processed through a multi-step state machine: #### Common Update Failures **Validation Errors**: + - Model in wrong state for updates - Configuration conflicts (e.g., min > max capacity) - Invalid container configurations **Deployment Issues**: + - ECS deployment timeouts - Health check failures - Resource constraints **Capacity Problems**: + - Auto-scaling group update failures - Instance launch issues - Load balancer target group problems @@ -256,6 +274,7 @@ LISA supports two scheduling types. One type may be applied to each self-hosted ### Schedule Configuration #### Prerequisites + - Administrator access to LISA Model Management - Target model must be a LISA-hosted model - Model must be in `InService` or `Stopped` state @@ -299,11 +318,13 @@ LISA supports two scheduling types. One type may be applied to each self-hosted #### Schedule Configuration Examples **Daily Schedule Example**: + - Monday-Friday: 09:00 to 17:00 (business hours) - Saturday: 10:00 to 14:00 (reduced hours) - Sunday: No schedule (model is in `Stopped` state) **Recurring Schedule Example**: + - Every day: 08:00 to 20:00 - Applies consistently across all days of the week @@ -355,18 +376,21 @@ Model Management displays schedule information: ### Schedule Behavior and Rules #### Time Format Requirements + - All times must be in 24-hour format (HH:MM) - Valid range: 00:00 to 23:59 - Start time must be before stop time within the same day - Stop time must be at least 2 hours after start time #### Schedule Execution + - **Automatic Actions**: Models are automatically started and stopped according to configured schedules - **Immediate Effect**: Schedule updates take effect immediately and recalculate next actions - **Manual Override**: Manual start/stop operations work independently and don't affect scheduling - **State Preservation**: Unscheduled days maintain the model's current state #### Timezone Handling + - Schedules respect the configured timezone for all calculations - Supports all IANA timezone identifiers (e.g., "UTC", "America/New_York", "Europe/London") - Automatically handles daylight saving time transitions @@ -385,11 +409,13 @@ The UI provides several indicators for schedule health: #### Common Scheduling Issues **Schedule Configuration Errors**: + - Invalid timezone selection - Stop time less than 2 hours after start time - No days configured for daily schedules **Schedule Execution Failures**: + - Model in invalid state during scheduled action - AWS service limits preventing scaling operations - Network connectivity issues affecting schedule execution @@ -405,18 +431,21 @@ The UI provides several indicators for schedule health: ### Best Practices #### 
Cost Optimization + - Configure schedules to match actual usage patterns - Use daily scheduling for models with varying weekday/weekend usage - Consider time zone alignment with the primary user base - Monitor actual vs. scheduled usage to refine schedules #### Operational Considerations + - Allow sufficient warmup time after scheduled starts before peak usage - Coordinate scheduled actions with maintenance windows - Test schedule configurations in non-production environments first - Document schedule configurations for operational handoff #### Schedule Design + - Use meaningful time buffers, with a minimum of 2 hours between starting and stopping within a single day - For operations spanning midnight, split schedules across consecutive days (e.g., Monday 21:00-23:59, Tuesday 00:00-03:00) - Plan for holiday and special event schedule modifications diff --git a/lib/docs/config/projects.md b/lib/docs/config/projects.md index d08e18b4f..d0353be1c 100644 --- a/lib/docs/config/projects.md +++ b/lib/docs/config/projects.md @@ -27,6 +27,7 @@ The `maxProjectsPerUser` limit can be adjusted based on your organization's need ### Switching Between Views When the Projects feature is activated, a toggle appears in the chat sidebar with two options: + - **History** — The default chronological view of all sessions - **Projects** — A folder-based view showing Projects and their assigned sessions diff --git a/lib/docs/config/rag-evaluation.md b/lib/docs/config/rag-evaluation.md index e508615c7..7b4ea59eb 100644 --- a/lib/docs/config/rag-evaluation.md +++ b/lib/docs/config/rag-evaluation.md @@ -21,6 +21,7 @@ The RAG evaluation suite measures how well your RAG system retrieves relevant do - DynamoDB (for token registration) - Bedrock (if evaluating Bedrock KB) 3. **Python Environment** with LISA SDK installed: + ```bash source .venv/bin/activate # Activate LISA venv ``` @@ -28,6 +29,7 @@ The RAG evaluation suite measures how well your RAG system retrieves relevant do ### Setup 1. **Create your config file:** + ```bash cd test/integration/rag/eval_datasets cp eval_config.example.yaml eval_config.yaml ``` @@ -41,6 +43,7 @@ The RAG evaluation suite measures how well your RAG system retrieves relevant do - Repository and collection IDs 3. **Create your golden dataset:** + ```bash cp golden-dataset.example.jsonl golden-dataset.jsonl ``` @@ -57,6 +60,7 @@ python -m lisapy.evaluation \ ``` **With verbose logging:** + ```bash python -m lisapy.evaluation \ --config test/integration/rag/eval_datasets/eval_config.yaml \ @@ -118,6 +122,7 @@ backends: You can evaluate just one backend by configuring only that section: **OpenSearch Only:** + ```yaml backends: lisa_api: @@ -126,6 +131,7 @@ backends: ``` **Bedrock KB Only:** + ```yaml backends: bedrock_kb: @@ -217,6 +223,7 @@ Authentication uses **AWS Secrets Manager** for management keys: 3. Authenticated requests use both `Api-Key` and `Authorization` headers **Required IAM Permissions:** + ```json { "Version": "2012-10-17", @@ -244,6 +251,7 @@ Authentication uses **AWS Secrets Manager** for management keys: Uses standard AWS SDK authentication (boto3 default credential chain). **Required IAM Permissions:** + ```json { "Effect": "Allow", @@ -259,18 +267,21 @@ Uses standard AWS SDK authentication (boto3 default credential chain). ### Metrics Explained **Precision@k:** + - Measures: What fraction of retrieved documents are relevant?
- Formula: (Relevant Retrieved) / k - Range: 0.0 to 1.0 (higher is better) - Example: If k=5 and 3 retrieved docs are relevant → Precision@5 = 0.6 **Recall@k:** + - Measures: What fraction of relevant documents were retrieved? - Formula: (Relevant Retrieved) / (Total Relevant) - Range: 0.0 to 1.0 (higher is better) - Example: If 3 relevant docs exist and 2 were retrieved → Recall = 0.67 **NDCG@k (Normalized Discounted Cumulative Gain):** + - Measures: Ranking quality (relevant docs should rank higher) - Penalizes relevant documents that appear lower in results - Range: 0.0 to 1.0 (higher is better) @@ -328,31 +339,40 @@ When evaluating multiple backends: ### Config Errors **Error:** `FileNotFoundError` + - **Fix:** Use absolute paths or run from repo root **Error:** `ValidationError: region field required` + - **Fix:** Ensure `region:` is set in your config file **Error:** `ValidationError: documents field required` + - **Fix:** Must define at least one document in `documents:` section ### Runtime Errors **Error:** `Repository not found` + - **Fix:** Verify `repo_id` matches an existing repository. List repos: + ```bash curl -H "Authorization: YOUR_TOKEN" \ https://YOUR-API-URL/repository ``` **Error:** `Bedrock knowledge base not found` + - **Fix:** Verify `knowledge_base_id` is correct. List KBs: + ```bash aws bedrock-agent list-knowledge-bases ``` **Error:** `S3 object not found` + - **Fix:** Documents must exist at `{s3_bucket}/{filename}`. Verify with: + ```bash aws s3 ls s3://your-bucket/ --recursive ``` diff --git a/lib/docs/config/repositories.md b/lib/docs/config/repositories.md index fd8c2b832..0066c36fc 100644 --- a/lib/docs/config/repositories.md +++ b/lib/docs/config/repositories.md @@ -254,6 +254,7 @@ curl -s -H 'Authorization: Bearer ' \ Administrators and RAG Admins access repository management through the Administration menu. The capabilities available depend on the user's role: **Administrators** have full access, including: + - Create, update, and delete repositories - Configure vector store implementation (OpenSearch, PGVector, Bedrock Knowledge Base) - Set default embedding models and chunking strategies @@ -262,6 +263,7 @@ Administrators and RAG Admins access repository management through the Administr - Enable or disable user-created collections **RAG Admins** have scoped access, via group membership, to the repositories they belong to: + - Create, update, and delete collections - Update ingestion pipelines - Cannot create or delete repositories, or modify `allowedGroups` diff --git a/lib/docs/config/security-group-overrides.md b/lib/docs/config/security-group-overrides.md index dab4d327f..6b270bab5 100644 --- a/lib/docs/config/security-group-overrides.md +++ b/lib/docs/config/security-group-overrides.md @@ -1,9 +1,11 @@ # Security Group Overrides User Guide ## Overview + This guide explains how to configure security group overrides in your environment. When you enable security group overrides, you'll need to override all security groups that are deployed across your stacks. ## Prerequisites + Before configuring security groups, you'll need the following values ready: - Private VPC ID `[PRIVATE_VPC_ID]` diff --git a/lib/docs/config/vllm_variables.md b/lib/docs/config/vllm_variables.md index a7fc016d2..d1445b0f7 100644 --- a/lib/docs/config/vllm_variables.md +++ b/lib/docs/config/vllm_variables.md @@ -1,6 +1,7 @@ # vLLM Environment Variables LISA Serve supports configuring vLLM model serving through environment variables. These variables allow you to control performance, memory usage, parallelization, and advanced features when deploying models with vLLM.
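As a hedged illustration only: if a model's container configuration accepts an environment map (the `containerConfig.environment` field names here are assumptions, not confirmed against LISA's schema), requesting multi-GPU tensor parallelism might look like the following. The notes below give the actual constraints.

```json
{
  "containerConfig": {
    "environment": {
      "VLLM_TENSOR_PARALLEL_SIZE": "4",
      "VLLM_LOGGING_LEVEL": "INFO"
    }
  }
}
```

Values are written as strings because ECS passes container environment variables as strings.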
+ - **NOTE:** Standard vLLM environment variables are supported and passed directly into the VLLM container. [See vLLM's documentation](https://docs.vllm.ai/en/latest/configuration/env_vars/) - Review your ECS instance type's specifications to determine if the model you want LISA Serve to host has the proper VRAM/RAM capacity. Instances that have multiple GPUs may require the VLLM_TENSOR_PARALLEL_SIZE environment variable to be set in order to utilize all GPUs. diff --git a/lib/docs/user/chat.md b/lib/docs/user/chat.md index 6b329e9ec..7447a5bae 100644 --- a/lib/docs/user/chat.md +++ b/lib/docs/user/chat.md @@ -13,12 +13,14 @@ When enabled by an administrator, the Chat UI shows a **Chat Assistants** sectio The Document Summarization feature enables efficient document processing through LISA's non-RAG context functionality. Users can streamline their workflow via an intuitive modal interface that facilitates document upload, LLM selection, and customized summarization template configuration. The system generates comprehensive document summaries tailored to specific requirements. #### Core Components + - Document upload interface - Environment-specific LLM integration - Configurable summarization templates with customizable parameters - Context-preserving file processing #### Operational Workflow + 1. Initiate summarization from active chat session 2. Upload target document for processing 3. Select appropriate LLM based on requirements @@ -28,12 +30,14 @@ The Document Summarization feature enables efficient document processing through 7. Review generated summary in chat interface #### Key Benefits + - Efficient information extraction and processing - Flexible summarization parameters for diverse use cases - Intuitive user interface optimized for accessibility - Enhanced contextual accuracy through preserved document integrity #### Administrative Configuration + LLM availability within the summarization modal requires summarization flagging and proper model configuration during initial setup. Selected LLMs must meet minimum requirements for: - Context window capacity diff --git a/lib/docs/user/model-library.md b/lib/docs/user/model-library.md index f70d47ff8..8c630692e 100644 --- a/lib/docs/user/model-library.md +++ b/lib/docs/user/model-library.md @@ -9,6 +9,7 @@ LISA's Model Library allows non-Administrators to view details about the models ### Activating / Deactivating the Model Library Administrators can activate / deactivate the Model Library. + 1. Select the `Administration` dropdown from the top navigation bar 2. Select `Configuration` 3. Under `Library Components` toggle `Show Model Library` diff --git a/lib/serve/ecs-model/embedding/instructor/src/entrypoint.sh b/lib/serve/ecs-model/embedding/instructor/src/entrypoint.sh old mode 100644 new mode 100755 diff --git a/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh b/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh old mode 100644 new mode 100755 diff --git a/lib/serve/ecs-model/metrics_publisher.py b/lib/serve/ecs-model/metrics_publisher.py old mode 100644 new mode 100755 index b08e52a8a..732d5b759 --- a/lib/serve/ecs-model/metrics_publisher.py +++ b/lib/serve/ecs-model/metrics_publisher.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -LISA Inference Metrics Publisher +"""LISA Inference Metrics Publisher.
Background daemon that scrapes Prometheus metrics from inference engine endpoints (vLLM, TGI, TEI) and publishes them to CloudWatch. diff --git a/lib/serve/ecs-model/textgen/tgi/src/entrypoint.sh b/lib/serve/ecs-model/textgen/tgi/src/entrypoint.sh old mode 100644 new mode 100755 diff --git a/lib/serve/ecs-model/vllm/src/entrypoint.sh b/lib/serve/ecs-model/vllm/src/entrypoint.sh old mode 100644 new mode 100755 diff --git a/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py b/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py index 4158d4753..2c03a8588 100644 --- a/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py +++ b/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py @@ -14,9 +14,8 @@ """Generic AWS API access via boto3 using the MCP workbench AWS session. -This sample exposes one tool that can call any boto3 client method (service + -operation + parameters). That matches IAM permissions of the connected -credentials. For production, consider restricting allowed services or operations. +This sample exposes one tool that can call any boto3 client method (service + operation + parameters). That matches IAM +permissions of the connected credentials. For production, consider restricting allowed services or operations. """ from __future__ import annotations diff --git a/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py b/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py index 7780c7f5c..28cb4057b 100644 --- a/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py +++ b/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -MCP Tool Creation Tutorial +"""MCP Tool Creation Tutorial. + ========================== This file demonstrates how to create MCP (Model Context Protocol) tools using two different approaches: @@ -36,8 +36,7 @@ class CalculatorTool(BaseTool): - """ - A simple calculator tool that performs basic arithmetic operations. + """A simple calculator tool that performs basic arithmetic operations. This class demonstrates the class-based approach to creating MCP tools: 1. Inherit from BaseTool @@ -47,8 +46,7 @@ class CalculatorTool(BaseTool): """ def __init__(self) -> None: - """ - Initialize the tool with metadata. + """Initialize the tool with metadata. The BaseTool constructor requires: - name: A unique identifier for the tool @@ -59,11 +57,10 @@ def __init__(self) -> None: ) async def execute(self) -> Callable: - """ - Return the callable function that implements the tool's functionality. + """Return the callable function that implements the tool's functionality. - This method is called by the MCP framework to get the actual function - that will be executed when the tool is invoked. + This method is called by the MCP framework to get the actual function that will be executed when the tool is + invoked. """ return self.calculate @@ -73,8 +70,7 @@ async def calculate( left_operand: Annotated[float, "The first number"], right_operand: Annotated[float, "The second number"], ) -> dict[str, float | str]: - """ - Execute the calculator operation. + """Execute the calculator operation. 
Parameter Type Annotations with Context: ======================================= @@ -115,100 +111,97 @@ async def calculate( # ============================================================================= # This is a simpler approach for straightforward tools that don't need # complex initialization or state management. - -""" -Here's how you would implement the same calculator using the @mcp_tool decorator: - -from mcpworkbench.core.decorators import mcp_tool -from typing import Annotated - -@mcp_tool( - name="simple_calculator", - description="A simple calculator using the decorator approach" -) -async def simple_calculator( - operator: Annotated[str, "The arithmetic operation: add, subtract, multiply, or divide"], - left_operand: Annotated[float, "The first number in the operation"], - right_operand: Annotated[float, "The second number in the operation"] -) -> dict: - ''' - Perform basic arithmetic operations using the decorator approach. - - The @mcp_tool decorator automatically: - 1. Registers the function as an MCP tool - 2. Extracts parameter information from type annotations - 3. Uses the Annotated descriptions for parameter documentation - 4. Handles the MCP protocol communication - - This approach is ideal for: - - Simple, stateless operations - - Quick prototyping - - Tools that don't need complex initialization - ''' - - if operator == "add": - result = left_operand + right_operand - elif operator == "subtract": - result = left_operand - right_operand - elif operator == "multiply": - result = left_operand * right_operand - elif operator == "divide": - if right_operand == 0: - raise ValueError("Cannot divide by zero") - result = left_operand / right_operand - else: - raise ValueError(f"Unknown operator: {operator}") - - return { - "operator": operator, - "left_operand": left_operand, - "right_operand": right_operand, - "result": result - } - +# +# Here's how you would implement the same calculator using the @mcp_tool decorator: +# +# from mcpworkbench.core.decorators import mcp_tool +# from typing import Annotated +# +# @mcp_tool( +# name="simple_calculator", +# description="A simple calculator using the decorator approach" +# ) +# async def simple_calculator( +# operator: Annotated[str, "The arithmetic operation: add, subtract, multiply, or divide"], +# left_operand: Annotated[float, "The first number in the operation"], +# right_operand: Annotated[float, "The second number in the operation"] +# ) -> dict: +# ''' +# Perform basic arithmetic operations using the decorator approach. +# +# The @mcp_tool decorator automatically: +# 1. Registers the function as an MCP tool +# 2. Extracts parameter information from type annotations +# 3. Uses the Annotated descriptions for parameter documentation +# 4. 
Handles the MCP protocol communication +# +# This approach is ideal for: +# - Simple, stateless operations +# - Quick prototyping +# - Tools that don't need complex initialization +# ''' +# +# if operator == "add": +# result = left_operand + right_operand +# elif operator == "subtract": +# result = left_operand - right_operand +# elif operator == "multiply": +# result = left_operand * right_operand +# elif operator == "divide": +# if right_operand == 0: +# raise ValueError("Cannot divide by zero") +# result = left_operand / right_operand +# else: +# raise ValueError(f"Unknown operator: {operator}") +# +# return { +# "operator": operator, +# "left_operand": left_operand, +# "right_operand": right_operand, +# "result": result +# } +# # Additional examples of Annotated usage for different parameter types: - -@mcp_tool(name="file_processor", description="Process files with various options") -async def process_files( - file_paths: Annotated[list[str], "List of file paths to process"], - max_size: Annotated[int, "Maximum file size in bytes (default: 1MB)"] = 1024*1024, - format: Annotated[str, "Output format: 'json', 'csv', or 'txt'"] = "json", - recursive: Annotated[bool, "Whether to process subdirectories recursively"] = False -): - ''' - Example showing different parameter types with Annotated descriptions. - - Key points about Annotated: - - Works with any Python type: str, int, float, bool, list, dict, etc. - - The description should be clear and specific - - Can include examples, constraints, or default behavior - - Helps AI models understand how to use your tool correctly - ''' - pass -""" +# +# @mcp_tool(name="file_processor", description="Process files with various options") +# async def process_files( +# file_paths: Annotated[list[str], "List of file paths to process"], +# max_size: Annotated[int, "Maximum file size in bytes (default: 1MB)"] = 1024*1024, +# format: Annotated[str, "Output format: 'json', 'csv', or 'txt'"] = "json", +# recursive: Annotated[bool, "Whether to process subdirectories recursively"] = False +# ): +# ''' +# Example showing different parameter types with Annotated descriptions. +# +# Key points about Annotated: +# - Works with any Python type: str, int, float, bool, list, dict, etc. 
+# - The description should be clear and specific +# - Can include examples, constraints, or default behavior +# - Helps AI models understand how to use your tool correctly +# ''' +# pass # ============================================================================= # CHOOSING BETWEEN THE TWO APPROACHES # ============================================================================= -""" -When to use Class-based approach: -- Complex tools with multiple related functions -- Tools that need initialization or configuration -- Tools that maintain state between calls -- Tools that need to share resources or connections -- When you want to group related functionality together - -When to use @mcp_tool decorator: -- Simple, stateless operations -- Quick prototyping and testing -- Single-purpose tools -- When you want minimal boilerplate code -- For functional programming style - -Both approaches support: -- Async/await operations -- Type annotations with Annotated for parameter descriptions -- Error handling and validation -- Return value serialization -- Integration with the MCP protocol -""" +# +# When to use Class-based approach: +# - Complex tools with multiple related functions +# - Tools that need initialization or configuration +# - Tools that maintain state between calls +# - Tools that need to share resources or connections +# - When you want to group related functionality together +# +# When to use @mcp_tool decorator: +# - Simple, stateless operations +# - Quick prototyping and testing +# - Single-purpose tools +# - When you want minimal boilerplate code +# - For functional programming style +# +# Both approaches support: +# - Async/await operations +# - Type annotations with Annotated for parameter descriptions +# - Error handling and validation +# - Return value serialization +# - Integration with the MCP protocol diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/__init__.py b/lib/serve/mcp-workbench/src/mcpworkbench/__init__.py index 95f90f251..d3bd7e1bf 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/__init__.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -MCP Workbench - A dynamic host for Python files used as MCP tools. +"""MCP Workbench - A dynamic host for Python files used as MCP tools. Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0. diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py b/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py index 72c8be790..4fc338159 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py @@ -102,8 +102,7 @@ async def execute(self, arguments: dict[str, Any]) -> Any: def create_adapter(tool_info: ToolInfo) -> ToolAdapter: - """ - Create the appropriate adapter for a tool. + """Create the appropriate adapter for a tool. Args: tool_info: Information about the tool to create an adapter for diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py index 98eb262f8..937f95c18 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" -AWS session management package for MCP Workbench. +"""AWS session management package for MCP Workbench. -This package contains helper types and utilities for managing short-lived -AWS session credentials on a per-(user, session) basis. +This package contains helper types and utilities for managing short-lived AWS session credentials on a per-(user, +session) basis. """ from .identity import CallerIdentity as CallerIdentity diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py index 81a6ce904..79e363361 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py @@ -32,12 +32,10 @@ def _get_identity_from_request(request: Request) -> tuple[str, str]: - """ - Extract (user_id, session_id) from the authenticated request. + """Extract (user_id, session_id) from the authenticated request. - user_id is derived from the JWT ``sub`` claim in the Authorization - header (already verified by OIDCHTTPBearer middleware). - session_id comes from the ``X-Session-Id`` header sent by the frontend. + user_id is derived from the JWT ``sub`` claim in the Authorization header (already verified by OIDCHTTPBearer + middleware). session_id comes from the ``X-Session-Id`` header sent by the frontend. """ # request.headers is case-insensitive; avoid converting to a plain dict hdrs = request.headers @@ -70,8 +68,7 @@ def _get_identity_from_request(request: Request) -> tuple[str, str]: @router.post("/connect", status_code=status.HTTP_200_OK) async def connect_aws(request: Request) -> dict[str, Any]: - """ - Accept AWS static credentials, validate them, and create a short-lived STS session. + """Accept AWS static credentials, validate them, and create a short-lived STS session. Request body: - accessKeyId: str diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py index e8ba5b81d..d1984c8ef 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py @@ -14,10 +14,9 @@ """Helpers for extracting caller identity inside MCP tool functions. -Tool functions call :func:`get_caller_identity` to obtain the current -:class:`CallerIdentity`. On first access within a request, the function -reads HTTP headers from the underlying MCP request context and caches the -result in a ``ContextVar``. +Tool functions call :func:`get_caller_identity` to obtain the current :class:`CallerIdentity`. On first access within a +request, the function reads HTTP headers from the underlying MCP request context and caches the result in a +``ContextVar``. No FastMCP middleware is required — identity is resolved lazily on demand. """ @@ -51,8 +50,7 @@ class CallerIdentityError(Exception): def decode_jwt_payload(token: str) -> dict: """Extract claims from a JWT payload via base64 decode (no signature check). - The OIDCHTTPBearer middleware already verified the signature, so this - is purely for reading claims. + The OIDCHTTPBearer middleware already verified the signature, so this is purely for reading claims. """ parts = token.split(".") if len(parts) < 2: @@ -68,8 +66,7 @@ def decode_jwt_payload(token: str) -> dict: def _extract_identity_from_headers(headers: dict[str, str]) -> CallerIdentity | None: """Try to build a :class:`CallerIdentity` from raw HTTP headers. - Returns ``None`` when either ``user_id`` or ``session_id`` cannot be - determined. 
+ Returns ``None`` when either ``user_id`` or ``session_id`` cannot be determined. """ user_id: str | None = headers.get("x-user-id") if not user_id: @@ -90,8 +87,8 @@ def _extract_identity_from_headers(headers: dict[str, str]) -> CallerIdentity | def _get_headers_from_request_ctx() -> dict[str, str]: """Read HTTP headers directly from the MCP low-level request context. - Falls back to FastMCP's ``get_http_headers()`` if the direct approach - fails. Returns an empty dict if neither method succeeds. + Falls back to FastMCP's ``get_http_headers()`` if the direct approach fails. Returns an empty dict if neither + method succeeds. """ # Approach 1: read directly from the MCP request_ctx ContextVar try: @@ -157,7 +154,7 @@ def _populate_identity_from_http() -> CallerIdentity | None: has_auth = "authorization" in headers has_session = "x-session-id" in headers logger.warning( - "identity: extraction failed — authorization present=%s, " "x-session-id present=%s, header keys=%s", + "identity: extraction failed — authorization present=%s, x-session-id present=%s, header keys=%s", has_auth, has_session, sorted(headers.keys()), @@ -168,12 +165,11 @@ def _populate_identity_from_http() -> CallerIdentity | None: def get_caller_identity() -> CallerIdentity: """Return the caller identity for the current MCP tool invocation. - On first call within a request, lazily reads HTTP headers from the - MCP request context and caches the result. Subsequent calls in the - same context return the cached value. + On first call within a request, lazily reads HTTP headers from the MCP request context and caches the result. + Subsequent calls in the same context return the cached value. - Raises :class:`CallerIdentityError` when identity cannot be determined - (required headers absent or not in an MCP request context). + Raises :class:`CallerIdentityError` when identity cannot be determined (required headers absent or not in an MCP + request context). """ identity = _current_identity.get() if identity is not None: @@ -186,7 +182,6 @@ def get_caller_identity() -> CallerIdentity: if identity is None: raise CallerIdentityError( - "Cannot determine caller identity. " - "Ensure the MCP connection sends Authorization and X-Session-Id headers." + "Cannot determine caller identity. Ensure the MCP connection sends Authorization and X-Session-Id headers." ) return identity diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py index c3c4e4b31..ad4f3f7ca 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py @@ -20,11 +20,9 @@ @dataclass class AwsSessionRecord: - """ - In-memory representation of a short-lived AWS session for a user/session. + """In-memory representation of a short-lived AWS session for a user/session. - The fields mirror the design in LISA_Auth.md, with expires_at stored as - an aware UTC datetime. + The fields mirror the design in LISA_Auth.md, with expires_at stored as an aware UTC datetime. """ user_id: str diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py index b926ecb63..37ff65322 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py @@ -21,11 +21,10 @@ @dataclass class InMemoryAwsSessionStore: - """ - Simple in-process implementation of an AWS session store. 
+ """Simple in-process implementation of an AWS session store. - This is suitable for a single MCP Workbench process. For multi-instance - deployments, a distributed store such as Redis should be used instead. + This is suitable for a single MCP Workbench process. For multi-instance deployments, a distributed store such as + Redis should be used instead. """ safety_margin_seconds: int = 0 @@ -38,9 +37,7 @@ def set_session(self, record: AwsSessionRecord) -> None: self._sessions[key] = record def get_session(self, user_id: str, session_id: str) -> AwsSessionRecord | None: - """ - Retrieve the session for a given user/session, or None if missing/expired. - """ + """Retrieve the session for a given user/session, or None if missing/expired.""" key = (user_id, session_id) record = self._sessions.get(key) if record is None: diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py index abc69fee8..fc58f8e57 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py @@ -29,10 +29,7 @@ class InvalidAwsCredentialsError(Exception): @dataclass class AwsStsClient: - """ - Thin wrapper around boto3 STS client for validating credentials and - creating short-lived session credentials. - """ + """Thin wrapper around boto3 STS client for validating credentials and creating short-lived session credentials.""" def _create_sts_client( self, @@ -61,8 +58,7 @@ def validate_static_credentials( session_token: str | None, region: str, ) -> tuple[str, str]: - """ - Validate credentials via GetCallerIdentity. + """Validate credentials via GetCallerIdentity. Returns (account_id, arn) on success, raises InvalidAwsCredentialsError on failure. """ @@ -90,8 +86,7 @@ def create_session_credentials( duration_seconds: int = 3600, safety_margin_seconds: int = 60, ) -> AwsSessionRecord: - """ - Produce an AwsSessionRecord from the provided credentials. + """Produce an AwsSessionRecord from the provided credentials. * **Long-term (IAM user) credentials** (no session_token): calls ``GetSessionToken`` to mint short-lived temporary credentials diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/cli.py b/lib/serve/mcp-workbench/src/mcpworkbench/cli.py index 99eab6476..3dad5a399 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/cli.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/cli.py @@ -89,7 +89,6 @@ def main( debug: bool, ) -> None: """MCP Workbench - A dynamic host for Python files used as MCP tools.""" - # Set logging level if debug: logging.getLogger().setLevel(logging.DEBUG) diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py index d10dcc85c..7d4e796b2 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py @@ -22,8 +22,7 @@ def mcp_tool(name: str, description: str) -> Callable[[F], F]: - """ - Decorator to mark a function as an MCP tool. + """Decorator to mark a function as an MCP tool. 
Args: name: The name of the tool diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py index 0e087d028..3bca1abc5 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py @@ -63,8 +63,7 @@ class BaseTool(ABC): """Abstract base class for MCP tools.""" def __init__(self, name: str, description: str): - """ - Initialize the tool with required metadata. + """Initialize the tool with required metadata. Args: name: The name of the tool @@ -75,8 +74,7 @@ def __init__(self, name: str, description: str): @abstractmethod async def execute(self) -> Callable[..., Any]: - """ - Returns an function to be executed as the tool. + """Returns a function to be executed as the tool. Returns: The function to be executed diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py index 40013c341..9af41b7de 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py @@ -44,8 +44,7 @@ class ToolDiscovery: """Discovers and loads tools from Python files.""" def __init__(self, tools_directory: str): - """ - Initialize the tool discovery. + """Initialize the tool discovery. Args: tools_directory: Path to directory containing tool files @@ -61,8 +60,7 @@ def __init__(self, tools_directory: str): raise ValueError(f"Tools directory is not a directory: {tools_directory}") def discover_tools(self) -> list[ToolInfo]: - """ - Discover all tools in the tools directory. + """Discover all tools in the tools directory. Returns: List of discovered tool information @@ -86,8 +84,7 @@ def discover_tools(self) -> list[ToolInfo]: return tools def rescan_tools(self) -> RescanResult: - """ - Rescan the tools directory and return changes. + """Rescan the tools directory and return changes. Returns: RescanResult with information about changes @@ -150,8 +147,7 @@ def _reload_modules(self) -> None: pass def _discover_tools_in_file(self, file_path: Path) -> list[ToolInfo]: - """ - Discover tools in a single Python file. + """Discover tools in a single Python file. Args: file_path: Path to the Python file @@ -203,7 +199,6 @@ def _find_class_based_tools(self, module: Any, file_path: Path, module_name: str # Check if it's a subclass of BaseTool (but not BaseTool itself) if issubclass(obj, BaseTool) and obj != BaseTool and not inspect.isabstract(obj): - try: # Try to instantiate the tool to get its metadata # We need to handle different constructor signatures diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py index b04a8aa56..a2cefa3b0 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py @@ -31,8 +31,7 @@ def __init__(self) -> None: self._lock = threading.RLock() def register_tool(self, tool_info: ToolInfo) -> None: - """ - Register a tool in the registry. + """Register a tool in the registry. Args: tool_info: Information about the tool to register @@ -42,8 +41,7 @@ def register_tool(self, tool_info: ToolInfo) -> None: logger.info(f"Registered tool: {tool_info.name}") def register_tools(self, tools: list[ToolInfo]) -> None: - """ - Register multiple tools in the registry. + """Register multiple tools in the registry.
Args: tools: List of tools to register @@ -54,8 +52,7 @@ def register_tools(self, tools: list[ToolInfo]) -> None: logger.info(f"Registered {len(tools)} tools") def unregister_tool(self, tool_name: str) -> bool: - """ - Unregister a tool from the registry. + """Unregister a tool from the registry. Args: tool_name: Name of the tool to unregister @@ -71,8 +68,7 @@ def unregister_tool(self, tool_name: str) -> bool: return False def get_tool(self, tool_name: str) -> ToolInfo | None: - """ - Get a tool by name. + """Get a tool by name. Args: tool_name: Name of the tool to retrieve @@ -84,8 +80,7 @@ def get_tool(self, tool_name: str) -> ToolInfo | None: return self._tools.get(tool_name) def list_tools(self) -> list[ToolInfo]: - """ - Get a list of all registered tools. + """Get a list of all registered tools. Returns: List of all registered tools @@ -94,8 +89,7 @@ def list_tools(self) -> list[ToolInfo]: return list(self._tools.values()) def list_tool_names(self) -> list[str]: - """ - Get a list of all registered tool names. + """Get a list of all registered tool names. Returns: List of all registered tool names @@ -110,12 +104,12 @@ def clear(self) -> None: logger.info("Cleared all tools from registry") def update_registry(self, new_tools: list[ToolInfo]) -> None: - """ - Update the registry with a new set of tools. + """Update the registry with a new set of tools. + This replaces all existing tools. Args: - new_tools: New list of tools to register + new_tools: New list of tools to register """ with self._lock: self._tools.clear() @@ -124,8 +118,7 @@ def update_registry(self, new_tools: list[ToolInfo]) -> None: logger.info(f"Updated registry with {len(new_tools)} tools") def get_tool_count(self) -> int: - """ - Get the number of registered tools. + """Get the number of registered tools. Returns: Number of registered tools @@ -134,8 +127,7 @@ def get_tool_count(self) -> int: return len(self._tools) def has_tool(self, tool_name: str) -> bool: - """ - Check if a tool is registered. + """Check if a tool is registered. Args: tool_name: Name of the tool to check diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py index 854bded9b..88181b859 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py @@ -158,7 +158,10 @@ def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> No self._jwks_client = get_jwks_client() async def dispatch(self, request: Request, call_next: Any) -> Response: - """Verify the provided bearer token or API Key. API Key will take precedence over the bearer token.""" + """Verify the provided bearer token or API Key. + + API Key will take precedence over the bearer token. + """ if request.method == "OPTIONS": return await call_next(request) @@ -198,9 +201,8 @@ async def dispatch(self, request: Request, call_next: Any) -> Response: class ApiTokenAuthorizer: """Class for checking API tokens against a DynamoDB table of API Tokens. - For the Token database, only a string value in the "token" field is required. Optionally, - customers may put a UNIX timestamp (in seconds) in a "tokenExpiration" field so that the - API key becomes invalid after a specified time. + For the Token database, only a string value in the "token" field is required. Optionally, customers may put a UNIX + timestamp (in seconds) in a "tokenExpiration" field so that the API key becomes invalid after a specified time. 
""" def __init__(self) -> None: diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py index 7a76b91fa..9dbd0446b 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py @@ -49,8 +49,7 @@ class MCPWorkbenchServer: """MCP Workbench server using pure FastMCP 2.0.""" def __init__(self, config: ServerConfig, tool_discovery: ToolDiscovery, tool_registry: ToolRegistry): - """ - Initialize the MCP Workbench server. + """Initialize the MCP Workbench server. Args: config: Server configuration @@ -159,7 +158,6 @@ async def rescan_endpoint(request: Request) -> JSONResponse: def _create_starlette_app(self) -> Starlette: """Create Starlette application with MCP and HTTP routes.""" - mcp_app = self.app.http_app(path="/", transport="streamable-http", stateless_http=True) async def health_check(request: Request) -> JSONResponse: diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py index 0be13a801..5a8014ae0 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py @@ -91,9 +91,8 @@ def _merge_vary_origin(headers: MutableHeaders) -> None: def wrap_asgi_with_cors_headers(app: ASGIApp, cors_config: CORSConfig) -> ASGIApp: """Outer ASGI wrapper: ensure CORS headers on every HTTP response when missing. - Starlette's outer ``ServerErrorMiddleware`` can emit error responses that bypass inner - ``CORSMiddleware``'s ``send`` wrapper, so browsers see 500 without - ``Access-Control-Allow-Origin`` and block the response body. + Starlette's outer ``ServerErrorMiddleware`` can emit error responses that bypass inner ``CORSMiddleware``'s ``send`` + wrapper, so browsers see 500 without ``Access-Control-Allow-Origin`` and block the response body. """ async def asgi(scope: Scope, receive: Receive, send: Send) -> None: diff --git a/lib/serve/mcp-workbench/test_install.py b/lib/serve/mcp-workbench/test_install.py old mode 100644 new mode 100755 index 77dac6ad6..e030a170e --- a/lib/serve/mcp-workbench/test_install.py +++ b/lib/serve/mcp-workbench/test_install.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Simple test to verify MCP Workbench installation. +"""Simple test to verify MCP Workbench installation. + Run this after installing to verify everything works. """ @@ -28,7 +28,6 @@ def test_cli_available() -> bool: """Test that the CLI command is available.""" - try: result = subprocess.run( [sys.executable, "-m", "mcpworkbench.cli", "--help"], capture_output=True, text=True, timeout=10 diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 6251a67d7..0f2b62f06 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -83,8 +83,7 @@ def _generate_presigned_video_url(key: str, content_type: str = "video/mp4") -> async def apply_guardrails_to_request(params: dict, model_id: str, jwt_data: dict) -> None: - """ - Apply guardrails to a chat completion request. + """Apply guardrails to a chat completion request. 
This function modifies the params dict in-place, adding applicable guardrails based on the user's group membership and the model's guardrail configuration. @@ -123,8 +122,7 @@ async def apply_guardrails_to_request(params: dict, model_id: str, jwt_data: dic def handle_guardrail_violation_response( response: Response, model_id: str, params: dict, is_streaming: bool ) -> Response | None: - """ - Handle guardrail violation errors in LiteLLM responses. + """Handle guardrail violation errors in LiteLLM responses. Checks if a 400 error response contains a guardrail violation and converts it into an appropriate format (streaming or non-streaming). @@ -171,8 +169,7 @@ def handle_guardrail_violation_response( def invalidate_model_cache(model_id: str | None = None) -> None: - """ - Manually invalidate model info cache. + """Manually invalidate model info cache. Note: This function is available for manual/programmatic cache clearing but is not automatically triggered. The cache relies on TTL expiration for normal operation. @@ -189,8 +186,7 @@ def invalidate_model_cache(model_id: str | None = None) -> None: def get_model_info(model_id: str, use_cache: bool = True) -> dict | None: - """ - Get model information from LiteLLM for a given model ID. + """Get model information from LiteLLM for a given model ID. Uses a TTL-based cache to reduce API calls while ensuring deleted/recreated models are eventually refreshed. @@ -264,8 +260,7 @@ def generate_response_with_guardrail_handling( request: Request, params: dict, ) -> Iterator[str]: - """ - Generate streaming responses with guardrail violation error handling and token usage capture. + """Generate streaming responses with guardrail violation error handling and token usage capture. In addition to guardrail handling, this generator watches for the SSE usage chunk (the chunk containing ``"usage": {...}``) emitted at the end of a streaming response. @@ -352,14 +347,12 @@ def generate_response_with_guardrail_handling( @router.api_route("/{api_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) async def litellm_passthrough(request: Request, api_path: str) -> Response: - """ - Pass requests directly to LiteLLM. LiteLLM and deployed models will respond here directly. + """Pass requests directly to LiteLLM. - Authentication is handled by auth_middleware. This function checks authorization - based on whether the route requires admin access. + LiteLLM and deployed models will respond here directly. Authentication is handled by auth_middleware. This + function checks authorization based on whether the route requires admin access. - Results are only streamed if the OpenAI-compatible request specifies streaming as part of the - input payload. + Results are only streamed if the OpenAI-compatible request specifies streaming as part of the input payload. """ litellm_path = f"{LITELLM_URL}/{api_path}" headers = dict(request.headers.items()) @@ -541,7 +534,6 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: # Check for anthropic specific headers and reset the max token parameter to None # so LiteLLM handles the max_token value. 
Only if it's not an Anthropic model if model_id and "anthropic-beta" in headers and "anthropic-version" in headers: - # Only nullify max_tokens if the model is NOT an Anthropic model if model_name and ".anthropic" not in model_name: if "max_tokens" in params: diff --git a/lib/serve/rest-api/src/auth.py b/lib/serve/rest-api/src/auth.py index 9bdf5d37e..c0228d49b 100644 --- a/lib/serve/rest-api/src/auth.py +++ b/lib/serve/rest-api/src/auth.py @@ -136,8 +136,7 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: def extract_user_groups_from_jwt(jwt_data: dict[str, Any] | None) -> list[str]: - """ - Extract user groups from JWT data using the JWT_GROUPS_PROP environment variable. + """Extract user groups from JWT data using the JWT_GROUPS_PROP environment variable. This follows the same property path traversal logic as is_user_in_group() function. @@ -146,7 +145,7 @@ def extract_user_groups_from_jwt(jwt_data: dict[str, Any] | None) -> list[str]: jwt_data : Optional[Dict[str, Any]] JWT data from authentication. None if user authenticated via API token. - Returns + Returns: ------- list[str] List of groups the user belongs to. Empty list if no JWT data or groups not found. @@ -226,9 +225,8 @@ async def id_token_is_valid(self, request: Request) -> dict[str, Any] | None: class ApiTokenAuthorizer: """Class for checking API tokens against a DynamoDB table of API Tokens. - For the Token database, only a string value in the "token" field is required. Optionally, - customers may put a UNIX timestamp (in seconds) in a "tokenExpiration" field so that the - API key becomes invalid after a specified time. + For the Token database, only a string value in the "token" field is required. Optionally, customers may put a UNIX + timestamp (in seconds) in a "tokenExpiration" field so that the API key becomes invalid after a specified time. """ def __init__(self) -> None: @@ -242,7 +240,6 @@ def _get_token_info(self, token_hash: str) -> Any: async def is_valid_api_token(self, headers: dict[str, str]) -> dict[str, Any] | None: """Return token info if API Token from request headers is valid, else None.""" - for header_name in AuthHeaders.values(): token = get_authorization_token(headers, header_name) @@ -348,8 +345,10 @@ async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | Non return jwt_data async def authenticate_request(self, request: Request) -> dict[str, Any] | None: - """Authenticate request and return JWT data if valid, else None. Invalid requests throw an exception""" + """Authenticate request and return JWT data if valid, else None. + Invalid requests throw an exception + """ logger.trace(f"Authenticating request: {request.method} {request.url.path}") # First try API tokens @@ -453,8 +452,7 @@ def _set_token_context(self, request: Request, token_info: dict[str, Any]) -> No def is_api_user(request: Request) -> bool: - """ - Check if the user authenticated with an API token. + """Check if the user authenticated with an API token. Args: request: The FastAPI request object @@ -466,8 +464,7 @@ def is_api_user(request: Request) -> bool: def get_user_context(request: Request) -> tuple[str, list[str]]: - """ - Get user information from the request. + """Get user information from the request. Works with both API token and JWT authentication. 
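Since `get_user_context` and `is_api_user` are meant to be consumed by route handlers, a brief sketch may help. This is illustrative only: the `/whoami` route and handler name are invented, and the import path depends on how the service packages `src/auth.py`.

```python
# Hypothetical handler showing how the auth helpers above can be consumed.
# Only get_user_context and is_api_user come from the auth module; the
# /whoami route is invented for illustration.
from fastapi import FastAPI, Request

from auth import get_user_context, is_api_user  # import path is deployment-specific

app = FastAPI()


@app.get("/whoami")  # hypothetical route, not part of LISA's API
async def whoami(request: Request) -> dict:
    username, groups = get_user_context(request)  # works for JWT and API-token auth
    return {
        "username": username,
        "groups": groups,
        "apiUser": is_api_user(request),  # True when an API token was used
    }
```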
diff --git a/lib/serve/rest-api/src/auth_provider.py b/lib/serve/rest-api/src/auth_provider.py index 6117740f9..cde82c937 100644 --- a/lib/serve/rest-api/src/auth_provider.py +++ b/lib/serve/rest-api/src/auth_provider.py @@ -24,8 +24,8 @@ class AuthorizationProvider(ABC): """Abstract base class for authorization providers. - This abstraction allows swapping between different authorization backends - (e.g., OIDC group-based, BRASS bindle lock) without changing the consuming code. + This abstraction allows swapping between different authorization backends (e.g., OIDC group-based, BRASS bindle + lock) without changing the consuming code. """ @abstractmethod @@ -39,7 +39,7 @@ def check_admin_access(self, username: str, groups: list[str] | None = None) -> groups : list[str] | None Optional list of groups the user belongs to (used by group-based providers) - Returns + Returns: ------- bool True if user has admin access, False otherwise @@ -57,7 +57,7 @@ def check_app_access(self, username: str, groups: list[str] | None = None) -> bo groups : list[str] | None Optional list of groups the user belongs to (used by group-based providers) - Returns + Returns: ------- bool True if user has app access, False otherwise @@ -75,7 +75,7 @@ def check_admin_access_jwt(self, jwt_data: dict[str, Any], jwt_groups_property: jwt_groups_property : str The property path to extract groups from JWT - Returns + Returns: ------- bool True if user has admin access, False otherwise @@ -93,7 +93,7 @@ def check_app_access_jwt(self, jwt_data: dict[str, Any], jwt_groups_property: st jwt_groups_property : str The property path to extract groups from JWT - Returns + Returns: ------- bool True if user has app access, False otherwise @@ -185,7 +185,7 @@ def check_app_access_jwt(self, jwt_data: dict[str, Any], jwt_groups_property: st def get_authorization_provider() -> AuthorizationProvider: """Get the configured authorization provider instance. - Returns + Returns: ------- AuthorizationProvider The authorization provider instance (OIDC-based for LISA) diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh old mode 100644 new mode 100755 diff --git a/lib/serve/rest-api/src/main.py b/lib/serve/rest-api/src/main.py index 8c24d2ee9..75e06b8ba 100644 --- a/lib/serve/rest-api/src/main.py +++ b/lib/serve/rest-api/src/main.py @@ -88,8 +88,7 @@ async def lifespan(app: FastAPI): # type: ignore async def rate_limit(request, call_next): # type: ignore """Per-user rate limiting middleware. - Runs after authentication (user identity is available) to enforce - per-API-key / per-user request rate limits. + Runs after authentication (user identity is available) to enforce per-API-key / per-user request rate limits. """ return await rate_limit_middleware(request, call_next) diff --git a/lib/serve/rest-api/src/middleware/auth_middleware.py b/lib/serve/rest-api/src/middleware/auth_middleware.py index 9b52412c3..1eb030f32 100644 --- a/lib/serve/rest-api/src/middleware/auth_middleware.py +++ b/lib/serve/rest-api/src/middleware/auth_middleware.py @@ -14,8 +14,8 @@ """Authentication middleware for the serve API. -This middleware handles authentication at the request level, validating tokens -and setting user context on request.state for downstream handlers. +This middleware handles authentication at the request level, validating tokens and setting user context on request.state +for downstream handlers. 
""" import os diff --git a/lib/serve/rest-api/src/middleware/exception_handlers.py b/lib/serve/rest-api/src/middleware/exception_handlers.py index 9f0f7ec3a..c1f10c31f 100644 --- a/lib/serve/rest-api/src/middleware/exception_handlers.py +++ b/lib/serve/rest-api/src/middleware/exception_handlers.py @@ -14,8 +14,8 @@ """Exception handlers for the serve API. -This module provides exception handlers that return proper HTTP status codes -with generic error messages, preventing internal details from being exposed. +This module provides exception handlers that return proper HTTP status codes with generic error messages, preventing +internal details from being exposed. """ import traceback diff --git a/lib/serve/rest-api/src/middleware/input_validation.py b/lib/serve/rest-api/src/middleware/input_validation.py index f535359b1..5d3835a30 100644 --- a/lib/serve/rest-api/src/middleware/input_validation.py +++ b/lib/serve/rest-api/src/middleware/input_validation.py @@ -33,8 +33,7 @@ def contains_null_bytes(data: str) -> bool: - """ - Check if a string contains null bytes. + r"""Check if a string contains null bytes. Null bytes (\\x00) can be used to bypass input validation or cause unexpected behavior in string processing. @@ -51,8 +50,7 @@ def contains_null_bytes(data: str) -> bool: async def validate_input_middleware( request: Request, call_next: Callable[[Request], Any], max_request_size: int = DEFAULT_MAX_REQUEST_SIZE ) -> Response: - """ - Middleware to validate request input before processing. + """Middleware to validate request input before processing. This middleware provides security protections against: - Null byte injection attacks diff --git a/lib/serve/rest-api/src/middleware/rate_limit_middleware.py b/lib/serve/rest-api/src/middleware/rate_limit_middleware.py index 7b2d18b46..5e7c011dc 100644 --- a/lib/serve/rest-api/src/middleware/rate_limit_middleware.py +++ b/lib/serve/rest-api/src/middleware/rate_limit_middleware.py @@ -160,7 +160,10 @@ def _get_user_limits(user_key: str) -> tuple[float, float, float]: def _prune_stale_buckets() -> None: - """Remove buckets that haven't been touched recently. Must hold ``_lock``.""" + """Remove buckets that haven't been touched recently. + + Must hold ``_lock``. + """ now = time.monotonic() stale_keys = [k for k, b in _buckets.items() if (now - b.last_refill) > _STALE_SECONDS] for k in stale_keys: @@ -170,8 +173,7 @@ def _prune_stale_buckets() -> None: async def _check_rate_limit(user_key: str) -> tuple[bool, float]: """Check whether *user_key* is within its rate limit. - Returns ``(allowed, retry_after_seconds)``. - Uses per-user overrides from ``RATE_LIMIT_OVERRIDES`` when available. + Returns ``(allowed, retry_after_seconds)``. Uses per-user overrides from ``RATE_LIMIT_OVERRIDES`` when available. """ max_tokens, refill_rate, _ = _get_user_limits(user_key) @@ -195,8 +197,7 @@ async def _check_rate_limit(user_key: str) -> tuple[bool, float]: def _get_user_key(request: Request) -> str | None: """Derive a rate-limit key from the authenticated request. - Returns ``None`` for requests that should bypass rate limiting - (management tokens, unauthenticated/public paths). + Returns ``None`` for requests that should bypass rate limiting (management tokens, unauthenticated/public paths). 
""" if not getattr(request.state, "authenticated", False): return None @@ -240,8 +241,7 @@ def _get_user_key(request: Request) -> str | None: async def rate_limit_middleware(request: Request, call_next: Callable[[Request], Response]) -> Response: """Per-user rate limiting middleware. - Must run **after** authentication middleware so that ``request.state`` - contains the caller identity. + Must run **after** authentication middleware so that ``request.state`` contains the caller identity. """ if not RATE_LIMIT_ENABLED: return await call_next(request) diff --git a/lib/serve/rest-api/src/middleware/request_middleware.py b/lib/serve/rest-api/src/middleware/request_middleware.py index 81fbf81d8..40a0588d4 100644 --- a/lib/serve/rest-api/src/middleware/request_middleware.py +++ b/lib/serve/rest-api/src/middleware/request_middleware.py @@ -36,7 +36,7 @@ async def process_request_middleware(request: Request, call_next: Callable[[Requ call_next : Callable The next middleware or route handler - Returns + Returns: ------- Response The response with added request ID header diff --git a/lib/serve/rest-api/src/services/text_processing.py b/lib/serve/rest-api/src/services/text_processing.py index ff670b864..f243ba218 100644 --- a/lib/serve/rest-api/src/services/text_processing.py +++ b/lib/serve/rest-api/src/services/text_processing.py @@ -25,7 +25,7 @@ def render_context_from_messages(messages_list: list[dict[str, str]]) -> str: messages_list : List[Dict[str, str]] List of messages with 'content' field - Returns + Returns: ------- str Concatenated message content @@ -43,12 +43,12 @@ def parse_model_provider_from_string(model_string: str) -> tuple[str, str]: model_string : str Combined model string in format "model_name (provider_name)" - Returns + Returns: ------- Tuple[str, str] Model name and provider name - Raises + Raises: ------ ValueError If string format is invalid @@ -79,7 +79,7 @@ def map_openai_params_to_lisa(request_data: dict) -> dict: request_data : dict OpenAI-format request data - Returns + Returns: ------- dict Mapped parameters for LISA diff --git a/lib/serve/rest-api/src/utils/generate_litellm_config.py b/lib/serve/rest-api/src/utils/generate_litellm_config.py index 4349cd774..bb069a61b 100644 --- a/lib/serve/rest-api/src/utils/generate_litellm_config.py +++ b/lib/serve/rest-api/src/utils/generate_litellm_config.py @@ -16,11 +16,10 @@ Models are managed dynamically via LiteLLM's DB (store_model_in_db=True). -Uses LiteLLM's ConfigYAML pydantic model for top-level structure validation. -Note: general_settings is kept as a dict because LiteLLM's ConfigGeneralSettings -model does not include all runtime-accepted fields (e.g., proxy_batch_write_at, -disable_error_logs, allow_requests_on_db_unavailable). These fields are consumed -by LiteLLM at runtime via dict access but are not in the pydantic schema. +Uses LiteLLM's ConfigYAML pydantic model for top-level structure validation. Note: general_settings is kept as a dict +because LiteLLM's ConfigGeneralSettings model does not include all runtime-accepted fields (e.g., proxy_batch_write_at, +disable_error_logs, allow_requests_on_db_unavailable). These fields are consumed by LiteLLM at runtime via dict access +but are not in the pydantic schema. 
""" import json @@ -66,7 +65,7 @@ def _build_general_settings(db_key: str, db_params: dict[str, str], use_iam_auth if not use_iam_auth: username, password = _get_database_credentials(db_params) settings["database_url"] = ( - f"postgresql://{username}:{password}" f"@{db_params['dbHost']}:{db_params['dbPort']}/{db_params['dbName']}" + f"postgresql://{username}:{password}@{db_params['dbHost']}:{db_params['dbPort']}/{db_params['dbName']}" ) return settings diff --git a/lib/serve/rest-api/src/utils/guardrails.py b/lib/serve/rest-api/src/utils/guardrails.py index e6587e127..def552a48 100644 --- a/lib/serve/rest-api/src/utils/guardrails.py +++ b/lib/serve/rest-api/src/utils/guardrails.py @@ -27,15 +27,14 @@ async def get_model_guardrails(model_id: str) -> list[dict[str, Any]]: - """ - Query the guardrails DynamoDB table for guardrails associated with a model. + """Query the guardrails DynamoDB table for guardrails associated with a model. Parameters ---------- model_id : str The model ID to query guardrails for. - Returns + Returns: ------- List[Dict[str, Any]] List of guardrail configurations for the model. Returns empty list if no guardrails found. @@ -61,8 +60,7 @@ async def get_model_guardrails(model_id: str) -> list[dict[str, Any]]: def get_applicable_guardrails(user_groups: list[str], guardrails: list[dict[str, Any]], model_id: str) -> list[str]: - """ - Determine which guardrails apply to a user based on group membership. + """Determine which guardrails apply to a user based on group membership. A guardrail applies if: - It has no allowed_groups (public guardrail, applies to everyone) @@ -79,7 +77,7 @@ def get_applicable_guardrails(user_groups: list[str], guardrails: list[dict[str, model_id : str The model ID being invoked. Used to construct the full LiteLLM guardrail name. - Returns + Returns: ------- List[str] List of LiteLLM guardrail names (format: {guardrail_name}-{model_id}) that should be applied to the request. @@ -116,15 +114,14 @@ def get_applicable_guardrails(user_groups: list[str], guardrails: list[dict[str, def is_guardrail_violation(error_msg: str) -> bool: - """ - Check if an error message indicates a guardrail policy violation. + """Check if an error message indicates a guardrail policy violation. Parameters ---------- error_msg : str The error message to check. - Returns + Returns: ------- bool True if the error message indicates a guardrail violation, False otherwise. @@ -133,15 +130,14 @@ def is_guardrail_violation(error_msg: str) -> bool: def extract_guardrail_response(error_msg: str) -> str | None: - """ - Extract the bedrock_guardrail_response from an error message. + """Extract the bedrock_guardrail_response from an error message. Parameters ---------- error_msg : str The error message containing the guardrail response. - Returns + Returns: ------- Optional[str] The extracted guardrail response text, or None if not found. @@ -151,8 +147,7 @@ def extract_guardrail_response(error_msg: str) -> str | None: def create_guardrail_streaming_response(guardrail_response: str, model_id: str, created: int = 0) -> Iterator[str]: - """ - Generate streaming response chunks for a guardrail violation. + """Generate streaming response chunks for a guardrail violation. Parameters ---------- @@ -163,7 +158,7 @@ def create_guardrail_streaming_response(guardrail_response: str, model_id: str, created : int, optional The creation timestamp, by default 0. - Yields + Yields: ------ str Properly formatted SSE chunks for the guardrail response. 
@@ -205,8 +200,7 @@ def create_guardrail_streaming_response(guardrail_response: str, model_id: str, def create_guardrail_json_response(guardrail_response: str, model_id: str, created: int = 0) -> JSONResponse: - """ - Create a JSON response for a guardrail violation. + """Create a JSON response for a guardrail violation. Parameters ---------- @@ -217,7 +211,7 @@ def create_guardrail_json_response(guardrail_response: str, model_id: str, creat created : int, optional The creation timestamp, by default 0. - Returns + Returns: ------- JSONResponse A properly formatted JSON response for the guardrail violation. diff --git a/lib/serve/rest-api/src/utils/header_sanitizer.py b/lib/serve/rest-api/src/utils/header_sanitizer.py index f9ae5ad94..5da384108 100644 --- a/lib/serve/rest-api/src/utils/header_sanitizer.py +++ b/lib/serve/rest-api/src/utils/header_sanitizer.py @@ -14,8 +14,8 @@ """Utility for sanitizing HTTP headers before logging to prevent log injection attacks. -This module is adapted for the serve API (ECS context) where we don't have -API Gateway event context. Instead, we extract real client IP from ECS/ALB headers. +This module is adapted for the serve API (ECS context) where we don't have API Gateway event context. Instead, we +extract real client IP from ECS/ALB headers. """ from fastapi import Request @@ -33,8 +33,7 @@ def get_real_client_ip(request: Request) -> str: - """ - Extract the real client IP address from the request. + """Extract the real client IP address from the request. In ECS behind ALB, the real client IP is typically in the last entry of x-forwarded-for added by the ALB, or we can use the client host. @@ -62,8 +61,7 @@ def sanitize_headers_for_logging( headers: dict[str, str], real_client_ip: str | None = None, ) -> dict[str, str]: - """ - Sanitize HTTP headers by replacing user-controlled values with server-controlled values. + """Sanitize HTTP headers by replacing user-controlled values with server-controlled values. This prevents attackers from manipulating security-critical headers in logs, which could be used to hide their true source IP or manipulate audit trails. @@ -132,8 +130,7 @@ def sanitize_headers_for_logging( def get_sanitized_headers_from_request(request: Request) -> dict[str, str]: - """ - Extract and sanitize headers from a FastAPI request for safe logging. + """Extract and sanitize headers from a FastAPI request for safe logging. This is a convenience function that extracts headers from the request and sanitizes them in one step. diff --git a/lib/serve/rest-api/src/utils/metrics.py b/lib/serve/rest-api/src/utils/metrics.py index 15de724f7..eb6fbfed3 100644 --- a/lib/serve/rest-api/src/utils/metrics.py +++ b/lib/serve/rest-api/src/utils/metrics.py @@ -30,8 +30,7 @@ def extract_messages_for_metrics(params: dict) -> list[dict]: - """ - Extract messages from chat completion request parameters. + """Extract messages from chat completion request parameters. Args: params: The request parameters containing messages @@ -93,8 +92,7 @@ def extract_messages_for_metrics(params: dict) -> list[dict]: def extract_token_usage(response_body: dict | None) -> tuple[int | None, int | None]: - """ - Extract token usage from a LLM response body (non-streaming or SSE chunk). + """Extract token usage from an LLM response body (non-streaming or SSE chunk). 
The usage structure is identical in both cases — LiteLLM normalises it: {"usage": {"prompt_tokens": N, "completion_tokens": N, ...}, ...} @@ -123,8 +121,7 @@ def publish_metrics_event( prompt_tokens: int | None = None, completion_tokens: int | None = None, ) -> None: - """ - Publish metrics event to SQS queue for API users. + """Publish metrics event to SQS queue for API users. Includes both message-level metrics (for prompt/RAG/MCP counting) and token-level metrics (prompt_tokens, completion_tokens) if available. diff --git a/lib/serve/rest-api/src/utils/metrics_models.py b/lib/serve/rest-api/src/utils/metrics_models.py index 90b92c22a..eb9011a2a 100644 --- a/lib/serve/rest-api/src/utils/metrics_models.py +++ b/lib/serve/rest-api/src/utils/metrics_models.py @@ -14,9 +14,8 @@ """Pydantic models for metrics events published to SQS. -This module is intentionally kept in sync with lambda/metrics/models.py. -The two files live in separate deployment contexts (FastAPI container vs Lambda) -and cannot share code directly, so any schema changes must be applied to both. +This module is intentionally kept in sync with lambda/metrics/models.py. The two files live in separate deployment +contexts (FastAPI container vs Lambda) and cannot share code directly, so any schema changes must be applied to both. """ from typing import Any diff --git a/lib/serve/rest-api/src/utils/rds_auth.py b/lib/serve/rest-api/src/utils/rds_auth.py index e3adb07ab..5ce38f12c 100644 --- a/lib/serve/rest-api/src/utils/rds_auth.py +++ b/lib/serve/rest-api/src/utils/rds_auth.py @@ -23,7 +23,7 @@ def _get_lambda_role_arn() -> str: """Get the ARN of the Lambda execution role. - Returns + Returns: ------- str The full ARN of the Lambda execution role @@ -36,7 +36,7 @@ def _get_lambda_role_arn() -> str: def get_lambda_role_name() -> str: """Extract the role name from the Lambda execution role ARN. - Returns + Returns: ------- str The name of the Lambda execution role without the full ARN diff --git a/lib/serve/rest-api/src/utils/request_utils.py b/lib/serve/rest-api/src/utils/request_utils.py index 852e415ea..661e5c761 100644 --- a/lib/serve/rest-api/src/utils/request_utils.py +++ b/lib/serve/rest-api/src/utils/request_utils.py @@ -61,13 +61,13 @@ def handle_stream_exceptions( The streaming function to wrap. This function is expected to be an asynchronous generator yielding strings. - Returns + Returns: ------- wrapper : Callable[..., AsyncGenerator[str, None]] The wrapped function, which handles exceptions by yielding them as formatted error messages in the stream. - Yields + Yields: ------ str The items yielded by the original function, or a JSON-formatted error message in case of an exception. @@ -98,8 +98,7 @@ def get_lisa_end_user_id( jwt_data: dict[str, Any] | None, state_username: str | None, ) -> str | None: - """ - Derive a human-readable end-user id for logs/spend attribution. + """Derive a human-readable end-user id for logs/spend attribution. LiteLLM uses the provided end-user identifier for spend/budget/logging. We prefer the same claims used by the authorizer/session to make the diff --git a/lib/serve/rest-api/src/utils/setup_prisma_db.py b/lib/serve/rest-api/src/utils/setup_prisma_db.py index cf9268908..a31b78d06 100644 --- a/lib/serve/rest-api/src/utils/setup_prisma_db.py +++ b/lib/serve/rest-api/src/utils/setup_prisma_db.py @@ -50,8 +50,8 @@ def _get_prisma_schema_dir() -> str: def _generate_prisma_client(schema_dir: str) -> None: """Generate the Prisma Python client from LiteLLM's schema. 
- This writes the generated client to site-packages/prisma/ once, so that - Gunicorn workers find it already present and skip generation. + This writes the generated client to site-packages/prisma/ once, so that Gunicorn workers find it already present and + skip generation. """ schema_path = os.path.join(schema_dir, "schema.prisma") print(f"Generating Prisma client from {schema_path}") diff --git a/lib/user-interface/react/README.md b/lib/user-interface/react/README.md index dfb9cf763..95f6abd22 100644 --- a/lib/user-interface/react/README.md +++ b/lib/user-interface/react/README.md @@ -22,7 +22,6 @@ If you are developing a production application, we recommend updating the config }, ``` - - Replace `plugin:@typescript-eslint/recommended` to `plugin:@typescript-eslint/recommended-type-checked` or `plugin:@typescript-eslint/strict-type-checked` - Optionally add `plugin:@typescript-eslint/stylistic-type-checked` - Install [eslint-plugin-react](https://github.com/jsx-eslint/eslint-plugin-react) and add `plugin:react/recommended` & `plugin:react/jsx-runtime` to the `extends` list diff --git a/lib/user-interface/react/public/branding/base/login.png b/lib/user-interface/react/public/branding/base/login.png index ff3a3e0d3..9f4ceb14e 100644 Binary files a/lib/user-interface/react/public/branding/base/login.png and b/lib/user-interface/react/public/branding/base/login.png differ diff --git a/lib/user-interface/react/scripts/set-revision-info.mjs b/lib/user-interface/react/scripts/set-revision-info.mjs old mode 100755 new mode 100644 diff --git a/lisa-sdk/lisapy/authentication.py b/lisa-sdk/lisapy/authentication.py index f798471e6..ab92edb49 100644 --- a/lisa-sdk/lisapy/authentication.py +++ b/lisa-sdk/lisapy/authentication.py @@ -45,7 +45,7 @@ def get_cognito_token(client_id: str, username: str, region: str = "us-east-1") region : str, default="us-east-1" AWS region. - Returns + Returns: ------- Dict[str, Any] Token response from cognito. @@ -80,17 +80,16 @@ def get_management_key( deployment_stage : str | None Deployment stage. When provided an additional pattern is tried first. - Returns + Returns: ------- str The management API key. - Raises + Raises: ------ RuntimeError If none of the secret-name patterns resolve. """ - secrets_client = boto3.client("secretsmanager", region_name=region) if region else boto3.client("secretsmanager") patterns: list[str] = [] @@ -142,12 +141,12 @@ def create_api_token( ttl_seconds : int Time-to-live in seconds (default 1 hour). - Returns + Returns: ------- str The registered API token (same as *api_key*). - Raises + Raises: ------ Exception If the DynamoDB put_item call fails. @@ -181,12 +180,12 @@ def setup_authentication( deployment_stage : str | None Deployment stage (optional). - Returns + Returns: ------- dict[str, str] Authentication headers (``Api-Key`` and ``Authorization``). - Raises + Raises: ------ RuntimeError If the management key cannot be retrieved. diff --git a/lisa-sdk/lisapy/errors.py b/lisa-sdk/lisapy/errors.py index e14d38a50..9e439ce21 100644 --- a/lisa-sdk/lisapy/errors.py +++ b/lisa-sdk/lisapy/errors.py @@ -72,7 +72,7 @@ def parse_error(status_code: int, response: ErrorResponse = None) -> Exception: response : ErrorResponse, optional API response object (requests.Response) or pre-extracted error message. - Returns + Returns: ------- Exception Parsed exception. 
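The `create_api_token` helper documented above reduces to a single DynamoDB write; a hedged sketch, assuming a table keyed on `token` with an optional `tokenExpiration` UNIX timestamp (the real table name and any token hashing are deployment-specific):

```python
# Hedged sketch of the token-registration write; the table name is a
# placeholder supplied by the caller, not a value taken from this repository.
import time

import boto3


def register_api_token(table_name: str, api_key: str, ttl_seconds: int = 3600) -> str:
    table = boto3.resource("dynamodb").Table(table_name)
    table.put_item(Item={"token": api_key, "tokenExpiration": int(time.time()) + ttl_seconds})
    return api_key
```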
diff --git a/lisa-sdk/lisapy/evaluation/__init__.py b/lisa-sdk/lisapy/evaluation/__init__.py index f20c93fc8..225a55c4e 100644 --- a/lisa-sdk/lisapy/evaluation/__init__.py +++ b/lisa-sdk/lisapy/evaluation/__init__.py @@ -14,8 +14,8 @@ """RAG evaluation module for measuring retrieval quality. -Provides metric functions, golden dataset loading, and evaluator classes -for Bedrock Knowledge Bases and LISA API backends (OpenSearch, PGVector). +Provides metric functions, golden dataset loading, and evaluator classes for Bedrock Knowledge Bases and LISA API +backends (OpenSearch, PGVector). """ from .base import BaseEvaluator diff --git a/lisa-sdk/lisapy/evaluation/base.py b/lisa-sdk/lisapy/evaluation/base.py index 8d7adb831..607aa37b5 100644 --- a/lisa-sdk/lisapy/evaluation/base.py +++ b/lisa-sdk/lisapy/evaluation/base.py @@ -64,7 +64,7 @@ def evaluate(self, golden: list[GoldenDatasetEntry]) -> EvalResult: for doc in entry.expected: if doc not in self.source_map: raise ValueError( - f"Golden dataset references unknown document '{doc}'. " f"Available: {sorted(self.source_map)}" + f"Golden dataset references unknown document '{doc}'. Available: {sorted(self.source_map)}" ) expected = {self.source_map[doc] for doc in entry.expected} rel_map = {self.source_map[doc]: entry.relevance[doc] for doc in entry.expected} diff --git a/lisa-sdk/lisapy/evaluation/lisa_api.py b/lisa-sdk/lisapy/evaluation/lisa_api.py index 0deb94a50..dcdb87088 100644 --- a/lisa-sdk/lisapy/evaluation/lisa_api.py +++ b/lisa-sdk/lisapy/evaluation/lisa_api.py @@ -14,8 +14,7 @@ """Evaluator for LISA API backends (OpenSearch, PGVector). -A single parameterized class covering any vector store backend -accessible via the LISA API's similarity_search endpoint. +A single parameterized class covering any vector store backend accessible via the LISA API's similarity_search endpoint. """ from lisapy.api import LisaApi diff --git a/lisa-sdk/lisapy/evaluation/metrics.py b/lisa-sdk/lisapy/evaluation/metrics.py index f7ba3017c..e2bb51a09 100644 --- a/lisa-sdk/lisapy/evaluation/metrics.py +++ b/lisa-sdk/lisapy/evaluation/metrics.py @@ -23,9 +23,8 @@ def deduplicate_sources(sources: list[str]) -> list[str]: """Deduplicate source paths preserving first-occurrence rank order. - RAG retrieval returns multiple chunks per document. For document-level - evaluation, we only care about the rank of the first chunk from each - unique source document. + RAG retrieval returns multiple chunks per document. For document-level evaluation, we only care about the rank of + the first chunk from each unique source document. """ seen: set[str] = set() deduped: list[str] = [] diff --git a/lisa-sdk/lisapy/langchain.py b/lisa-sdk/lisapy/langchain.py index a60261a37..b21270e47 100644 --- a/lisa-sdk/lisapy/langchain.py +++ b/lisa-sdk/lisapy/langchain.py @@ -32,15 +32,14 @@ class LisaTextgen(LLM): """Lisa text generation adapter. - To use, you should have the `lisapy` python package installed and - a Lisa API available. + To use, you should have the `lisapy` python package installed and a Lisa API available. """ provider: str """Provider of the LISA serve model e.g., ecs.textgen.tgi.""" model_name: str - """Name of LISA serve model e.g. Mixtral-8x7B-Instruct-v0.1""" + """Name of LISA serve model e.g. Mixtral-8x7B-Instruct-v0.1.""" client: LisaLlm """An instance of the Lisa Llm client.""" @@ -112,10 +111,13 @@ class LisaOpenAIEmbeddings(BaseModel, Embeddings): """Model name for Embeddings API.""" api_token: str - """API Token for communicating with LISA Serve. 
This can be a custom API token or the IdP Bearer token.""" + """API Token for communicating with LISA Serve. + + This can be a custom API token or the IdP Bearer token. + """ verify: bool | str - """Cert path or option for verifying SSL""" + """Cert path or option for verifying SSL.""" _embedding_model: OpenAIEmbeddings = PrivateAttr(default_factory=None) """OpenAI-compliant client for making requests against embedding model.""" @@ -153,8 +155,7 @@ async def aembed_query(self, text: str) -> list[float]: class LisaEmbeddings(BaseModel, Embeddings): """Lisa text embedding adapter. - To use, you should have the `lisapy` python package installed and - a Lisa API available. + To use, you should have the `lisapy` python package installed and a Lisa API available. """ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) @@ -163,7 +164,7 @@ class LisaEmbeddings(BaseModel, Embeddings): """Provider of the LISA serve model e.g., ecs.textgen.tgi.""" model_name: str - """Name of LISA serve model e.g. Mistral-8x7B-Instruct-v0.1""" + """Name of LISA serve model e.g. Mixtral-8x7B-Instruct-v0.1.""" client: LisaLlm """An instance of the Lisa client.""" @@ -182,7 +183,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: texts : List[str] The list of texts to embed. - Returns + Returns: ------- List[List[float]] List of embeddings, one for each text. @@ -197,7 +198,7 @@ def embed_query(self, text: str) -> list[float]: text : str The text to embed. - Returns + Returns: ------- List[float] Embedding for the text. @@ -212,7 +213,7 @@ async def aembed_query(self, text: str) -> list[float]: text : str The text to embed. - Returns + Returns: ------- List[float] Embedding for the text. @@ -227,7 +228,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: texts : List[str] The list of texts to embed. - Returns + Returns: ------- List[List[float]] List of embeddings, one for each text. diff --git a/lisa-sdk/lisapy/main.py b/lisa-sdk/lisapy/main.py index 42f0a9051..f79cd7509 100644 --- a/lisa-sdk/lisapy/main.py +++ b/lisa-sdk/lisapy/main.py @@ -78,7 +78,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def list_models(self) -> list[dict[str, Any]]: """List all models from the LiteLLM proxy. - Returns + Returns: ------- list[dict[str, Any]] List of model dicts in OpenAI format (id, object, created, owned_by). @@ -94,7 +94,7 @@ def list_models(self) -> list[dict[str, Any]]: def health(self) -> dict[str, Any]: """Check health of the LiteLLM proxy. - Returns + Returns: ------- dict[str, Any] Health status response from the proxy. @@ -109,7 +109,7 @@ def health(self) -> dict[str, Any]: def health_readiness(self) -> dict[str, Any]: """Check readiness of the LiteLLM proxy. - Returns + Returns: ------- dict[str, Any] Readiness status response from the proxy. @@ -128,7 +128,7 @@ def health_liveliness(self) -> dict[str, Any]: normalizes it to ``{"status": "I'm alive!"}`` for a consistent dict return type across all health methods. - Returns + Returns: ------- dict[str, Any] Liveliness status response from the proxy. @@ -149,7 +149,7 @@ def get_model_info(self) -> list[ModelInfoEntry]: Returns the full LiteLLM model database including litellm_params, provider details, and model configuration. - Returns + Returns: ------- list[ModelInfoEntry] List of model info entries with name, params, and metadata. 
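A hedged usage sketch of the proxy-client surface shown above; the constructor arguments are assumptions, since only the method signatures appear in this diff:

```python
# Construction is assumed; only the method calls below are confirmed by the diff.
from lisapy.main import LisaLlm

client = LisaLlm(url="https://lisa.example.com/v2/serve", headers={"Api-Key": "<token>"}, verify=False)

print(client.health_liveliness())  # normalized to {"status": "I'm alive!"}
for entry in client.get_model_info():  # full LiteLLM model database entries
    print(entry)
```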
@@ -244,7 +244,7 @@ def complete(self, prompt: str, model: str, **kwargs: Any) -> CompletionResponse Additional OpenAI completions parameters (temperature, max_tokens, etc.). Unknown parameters are filtered out. - Returns + Returns: ------- CompletionResponse Legacy completion response with id, choices, and usage. @@ -274,7 +274,7 @@ async def acomplete(self, prompt: str, model: str, **kwargs: Any) -> CompletionR Additional OpenAI completions parameters (temperature, max_tokens, etc.). Unknown parameters are filtered out. - Returns + Returns: ------- CompletionResponse Legacy completion response with id, choices, and usage. @@ -312,7 +312,7 @@ def generate_image(self, prompt: str, model: str, **kwargs: Any) -> ImageRespons **kwargs : Any Additional parameters (n, size, quality, response_format, style). - Returns + Returns: ------- ImageResponse Image generation response with created timestamp and image data. @@ -341,7 +341,7 @@ async def agenerate_image(self, prompt: str, model: str, **kwargs: Any) -> Image **kwargs : Any Additional parameters (n, size, quality, response_format, style). - Returns + Returns: ------- ImageResponse Image generation response with created timestamp and image data. @@ -382,7 +382,7 @@ def text_to_speech(self, text: str, model: str, voice: str = "alloy", **kwargs: **kwargs : Any Additional parameters (response_format, speed). - Returns + Returns: ------- bytes Raw audio content. @@ -414,7 +414,7 @@ async def atext_to_speech(self, text: str, model: str, voice: str = "alloy", **k **kwargs : Any Additional parameters (response_format, speed). - Returns + Returns: ------- bytes Raw audio content. @@ -456,12 +456,11 @@ def transcribe(self, file: str | bytes, model: str, filename: str = "audio.mp3", **kwargs : Any Additional parameters (language, prompt, response_format, temperature). - Returns + Returns: ------- dict[str, Any] Transcription response with text and metadata. """ - if isinstance(file, str): file_path = Path(file) if not file_path.is_file(): @@ -501,12 +500,11 @@ async def atranscribe( **kwargs : Any Additional parameters (language, prompt, response_format, temperature). - Returns + Returns: ------- dict[str, Any] Transcription response with text and metadata. """ - if isinstance(file, str): file_path = Path(file) if not file_path.is_file(): @@ -549,7 +547,7 @@ def generate(self, prompt: str, model: FoundationModel) -> Response: model : FoundationModel Foundation model for text generation. - Returns + Returns: ------- Response Text generation response. @@ -579,7 +577,7 @@ async def agenerate(self, prompt: str, model: FoundationModel) -> Response: model : FoundationModel Foundation model for text generation. - Returns + Returns: ------- Response Text generation response. @@ -618,7 +616,7 @@ def generate_stream(self, prompt: str, model: FoundationModel) -> Generator[Stre model : FoundationModel Foundation model for text generation. - Returns + Returns: ------- Generator[StreamingResponse, None, None] Text generation streaming response. @@ -662,7 +660,7 @@ async def agenerate_stream(self, prompt: str, model: FoundationModel) -> AsyncGe model : FoundationModel Foundation model for text generation. - Returns + Returns: ------- AsyncGenerator[StreamingResponse, None] Text generation streaming response. @@ -712,7 +710,7 @@ def embed(self, texts: str | list[str], model: FoundationModel) -> list[list[flo model : FoundationModel Foundation model for text embeddings. - Returns + Returns: ------- List[List[float]] Text embeddings as a batched response. 
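To make the batched-embedding contract concrete, a short usage sketch; the `FoundationModel` construction is elided because its fields are not shown in this diff:

```python
# client is a LisaLlm instance; model is a FoundationModel for an embedding
# model (its fields are not visible in this diff, so it is left as a stub).
model = ...  # stub: replace with a real FoundationModel
vectors = client.embed(["first passage", "second passage"], model)
assert len(vectors) == 2  # one vector of floats per input text
```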
@@ -739,7 +737,7 @@ async def aembed(self, texts: str | list[str], model: FoundationModel) -> list[l model : FoundationModel Foundation model for text embeddings. - Returns + Returns: ------- List[List[float]] Text embeddings as a batched response. diff --git a/mcp_server_deployer/src/lib/scripts/http-s3-entrypoint.sh b/mcp_server_deployer/src/lib/scripts/http-s3-entrypoint.sh old mode 100644 new mode 100755 diff --git a/mcp_server_deployer/src/lib/scripts/stdio-prebuilt-s3-entrypoint.sh b/mcp_server_deployer/src/lib/scripts/stdio-prebuilt-s3-entrypoint.sh old mode 100644 new mode 100755 diff --git a/mcp_server_deployer/src/lib/scripts/stdio-s3-entrypoint.sh b/mcp_server_deployer/src/lib/scripts/stdio-s3-entrypoint.sh old mode 100644 new mode 100755 diff --git a/pyproject.toml b/pyproject.toml index bece6911b..102e835ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,14 +52,20 @@ exclude = [ line-length = 120 [tool.ruff.lint] -ignore = ["D401"] -select = ["E", "F", "W", "PLC0415"] # Enable pycodestyle errors, pyflakes, warnings, and import-outside-toplevel +ignore = ["D401", "D100", "D104", "D103", "D101", "D102", "D105", "D107"] +select = ["E", "F", "W", "PLC0415", "D"] # Enable pycodestyle, pyflakes, warnings, import-outside-toplevel, pydocstyle + +[tool.ruff.lint.pydocstyle] +convention = "google" [tool.ruff.lint.per-file-ignores] -# Allow imports in functions for test files +# Allow imports in functions for test files, relax docstring requirements in tests "test/lambda/test_session_lambda.py" = ["E402", "PLC0415"] -"test/**/*.py" = ["E402", "PLC0415"] -"**/test_*.py" = ["E402", "PLC0415"] +"test/**/*.py" = ["E402", "PLC0415", "D"] +"**/test_*.py" = ["E402", "PLC0415", "D"] + +[tool.ruff.format] +docstring-code-format = true [tool.pytest.ini_options] addopts = "--strict-markers -vv -x" diff --git a/scripts/audit_dependencies.py b/scripts/audit_dependencies.py index 8749809fd..704db6c59 100755 --- a/scripts/audit_dependencies.py +++ b/scripts/audit_dependencies.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Audit dependency versions across all LISA components. +"""Audit dependency versions across all LISA components. This script finds version inconsistencies across: - requirements.txt files @@ -48,8 +47,7 @@ def find_files(self, pattern: str, exclude_dirs: list[str] | None = None) -> lis return sorted(files) def parse_requirement_line(self, line: str) -> tuple[str, str] | None: - """ - Parse a requirement line into (package, version_spec). + """Parse a requirement line into (package, version_spec). Returns None if line is a comment, empty, or not a simple requirement. """ @@ -196,8 +194,7 @@ def normalize_version_spec(self, spec: str) -> str: return spec def are_versions_compatible(self, specs: set[str]) -> bool: - """ - Check if version specs are compatible. + """Check if version specs are compatible. This is a simplified check - it considers specs compatible if: - They're identical @@ -232,8 +229,7 @@ def are_versions_compatible(self, specs: set[str]) -> bool: return len(exact_versions) <= 1 def generate_report(self) -> tuple[dict[str, dict[str, set[str]]], int]: - """ - Generate inconsistency report. + """Generate inconsistency report. 
Returns (inconsistencies, total_packages_checked) """ @@ -252,8 +248,7 @@ def generate_report(self) -> tuple[dict[str, dict[str, set[str]]], int]: return inconsistencies, len(self.package_versions) def run_audit(self) -> int: - """ - Run full audit and print report. + """Run full audit and print report. Returns exit code (0 if no issues, 1 if inconsistencies found). """ diff --git a/scripts/bootstrap.mjs b/scripts/bootstrap.mjs old mode 100644 new mode 100755 diff --git a/scripts/check-for-models.mjs b/scripts/check-for-models.mjs old mode 100644 new mode 100755 diff --git a/scripts/config.mjs b/scripts/config.mjs old mode 100644 new mode 100755 diff --git a/scripts/convert-to-safetensors.py b/scripts/convert-to-safetensors.py index 6be27bb30..2bf3f0bc6 100644 --- a/scripts/convert-to-safetensors.py +++ b/scripts/convert-to-safetensors.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Script that will convert standard pytorch weights to safe tensor weights. +"""Script that will convert standard pytorch weights to safe tensor weights. Parameters ---------- diff --git a/scripts/copy-deps.mjs b/scripts/copy-deps.mjs old mode 100644 new mode 100755 diff --git a/scripts/deploy.mjs b/scripts/deploy.mjs old mode 100644 new mode 100755 diff --git a/scripts/destroy.mjs b/scripts/destroy.mjs old mode 100644 new mode 100755 diff --git a/scripts/docker-login.mjs b/scripts/docker-login.mjs old mode 100644 new mode 100755 diff --git a/scripts/generate-baseline.mjs b/scripts/generate-baseline.mjs old mode 100644 new mode 100755 diff --git a/scripts/install-python.sh b/scripts/install-python.sh old mode 100644 new mode 100755 diff --git a/scripts/integration-env.mjs b/scripts/integration-env.mjs old mode 100644 new mode 100755 diff --git a/scripts/model-check.mjs b/scripts/model-check.mjs old mode 100644 new mode 100755 diff --git a/scripts/run-integration-tests.mjs b/scripts/run-integration-tests.mjs old mode 100644 new mode 100755 diff --git a/scripts/run-pytest.sh b/scripts/run-pytest.sh old mode 100644 new mode 100755 diff --git a/scripts/verify-config.mjs b/scripts/verify-config.mjs old mode 100644 new mode 100755 diff --git a/test/README.md b/test/README.md index 31cf60e0e..85a62306d 100644 --- a/test/README.md +++ b/test/README.md @@ -17,11 +17,13 @@ test/ ## Running Tests ### Run all tests + ```bash pytest ``` ### Run tests for a specific module + ```bash # MCP Workbench tests pytest test/mcp-workbench/ @@ -34,6 +36,7 @@ pytest test/lambda/ ``` ### Run a specific test file + ```bash pytest test/mcp-workbench/test_core.py ``` @@ -41,21 +44,26 @@ pytest test/mcp-workbench/test_core.py ## Module Test Organization ### MCP Workbench (`test/mcp-workbench/`) + Tests for the MCP Workbench module located in `lib/serve/mcp-workbench/src/mcpworkbench/` ### LISA SDK (`test/lisa-sdk/`) + Tests for the LISA Python SDK located in `lisa-sdk/lisapy/` ### Lambda (`test/lambda/`) + Tests for Lambda functions located in `lambda/` ## Configuration Tests are configured in: + - `pytest.ini` - Main pytest configuration with PYTHONPATH settings - `pyproject.toml` - Additional pytest and mypy configuration The PYTHONPATH is configured to include: + - `lambda/` - `lisa-sdk/` - `lib/serve/rest-api/src/` @@ -66,6 +74,7 @@ This allows tests to import modules from their respective source directories. 
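The PYTHONPATH arrangement described above maps naturally onto pytest's built-in `pythonpath` ini option (pytest 7+); an illustrative snippet, not the repository's actual `pytest.ini`:

```ini
# illustrative only; mirrors the paths listed above
[pytest]
pythonpath =
    lambda
    lisa-sdk
    lib/serve/rest-api/src
```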
## Type Checking Run mypy type checking: + ```bash mypy --config-file=pyproject.toml lisa-sdk/ lib/serve/mcp-workbench/src/ ``` diff --git a/test/integration/config_loader.py b/test/integration/config_loader.py index 8b77da62e..40ef0b78a 100644 --- a/test/integration/config_loader.py +++ b/test/integration/config_loader.py @@ -14,8 +14,8 @@ """Load integration test config from config-custom.yaml (and config-base.yaml). -Values are used as defaults when CLI options are not provided. Mirrors the -behavior of scripts/config.mjs and scripts/integration-env.mjs. +Values are used as defaults when CLI options are not provided. Mirrors the behavior of scripts/config.mjs and +scripts/integration-env.mjs. """ from __future__ import annotations @@ -66,8 +66,8 @@ def _deep_merge(base: dict, override: dict) -> None: def get_config_values() -> dict[str, str]: - """ - Extract deployment-related values from config. + """Extract deployment-related values from config. + Supports both flat config and env-based config (env: X, X: { deploymentName, ... }). """ config = load_config() @@ -95,8 +95,8 @@ def get(key: str, default: str = "") -> str: def fetch_url_from_aws(kind: str) -> str: - """ - Fetch API or ALB URL from AWS via integration-env.mjs. + """Fetch API or ALB URL from AWS via integration-env.mjs. + kind: "api" -> API Gateway URL, "alb" -> REST/ALB URL. Returns empty string on failure. """ diff --git a/test/integration/rag/test_rag_collections_integration.py b/test/integration/rag/test_rag_collections_integration.py old mode 100644 new mode 100755 index 767c86844..ae20092b8 --- a/test/integration/rag/test_rag_collections_integration.py +++ b/test/integration/rag/test_rag_collections_integration.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Integration tests for RAG Collections. +"""Integration tests for RAG Collections. This test suite validates end-to-end functionality of RAG collections including: - Collection creation and management @@ -60,7 +59,10 @@ class RagIntegrationFixtures: - """Shared fixtures for RAG integration tests. Contains no test methods.""" + """Shared fixtures for RAG integration tests. + + Contains no test methods. + """ created_collections = [] @@ -573,18 +575,16 @@ def test_06_delete_collection_with_documents( class TestPipelineRagCollectionIntegration(RagIntegrationFixtures): """Integration tests for RAG collections using the S3 pipeline EventBridge trigger. - Tests end-to-end pipeline ingestion and deletion by dropping/removing files - directly in the pipeline S3 bucket and verifying the event-driven handler - processes them correctly. + Tests end-to-end pipeline ingestion and deletion by dropping/removing files directly in the pipeline S3 bucket and + verifying the event-driven handler processes them correctly. """ @pytest.fixture(scope="class") def pipeline_info(self, lisa_client: LisaApi, test_repository_id: str, test_embedding_model: str): - """Discover the pipeline S3 bucket/prefix from TEST_REPOSITORY_ID and create - a collection whose name is 'default' to test pipeline name-to-UUID resolution. + """Discover the pipeline S3 bucket/prefix from TEST_REPOSITORY_ID and create a collection whose name is + 'default' to test pipeline name-to-UUID resolution. - Yields the info dict needed by test_01 and test_02. - Cleans up the collection on teardown. + Yields the info dict needed by test_01 and test_02. Cleans up the collection on teardown. 
""" # Discover pipeline S3 bucket from the existing repository repo = lisa_client.get_repository(test_repository_id) @@ -624,9 +624,8 @@ def test_01_pipeline_ingest_resolves_collection_name_to_uuid( test_embedding_model: str, pipeline_info: dict, ): - """Drop a file into the pipeline S3 bucket and verify it ingests under the - UUID collection, confirming the pipeline resolves collection name to UUID. - """ + """Drop a file into the pipeline S3 bucket and verify it ingests under the UUID collection, confirming the + pipeline resolves collection name to UUID.""" import boto3 region = os.getenv("AWS_DEFAULT_REGION", "us-east-1") @@ -649,11 +648,11 @@ def test_01_pipeline_ingest_resolves_collection_name_to_uuid( try: documents = lisa_client.list_documents(test_repository_id, collection_uuid) if documents: - logger.info(f"✓ Document ingested after {int(time.time()-start)}s under UUID {collection_uuid}") + logger.info(f"✓ Document ingested after {int(time.time() - start)}s under UUID {collection_uuid}") break except Exception as e: logger.debug(f"Waiting for pipeline ingest: {e}") - logger.info(f"Polling for ingested document... ({int(time.time()-start)}s elapsed)") + logger.info(f"Polling for ingested document... ({int(time.time() - start)}s elapsed)") time.sleep(15) collection_name = pipeline_info["collection_name"] @@ -675,9 +674,8 @@ def test_02_pipeline_delete_resolves_collection_name_to_uuid( test_repository_id: str, pipeline_info: dict, ): - """Delete the S3 file and verify the document is removed, confirming the - pipeline delete handler resolves collection name to UUID. - """ + """Delete the S3 file and verify the document is removed, confirming the pipeline delete handler resolves + collection name to UUID.""" import boto3 region = os.getenv("AWS_DEFAULT_REGION", "us-east-1") @@ -696,9 +694,9 @@ def test_02_pipeline_delete_resolves_collection_name_to_uuid( while time.time() - start < max_wait: remaining = lisa_client.list_documents(test_repository_id, collection_uuid) if not any(d["document_id"] == document_id for d in remaining): - logger.info(f"✓ Document deleted after {int(time.time()-start)}s via pipeline deletion event") + logger.info(f"✓ Document deleted after {int(time.time() - start)}s via pipeline deletion event") break - logger.info(f"Polling for document deletion... ({int(time.time()-start)}s elapsed)") + logger.info(f"Polling for document deletion... ({int(time.time() - start)}s elapsed)") time.sleep(15) else: assert False, f"Pipeline deletion timed out after {max_wait}s. Delete handler did not process the event." @@ -707,16 +705,16 @@ def test_02_pipeline_delete_resolves_collection_name_to_uuid( class TestDefaultCollectionPath(RagIntegrationFixtures): """Tests that ingest and delete work when no collectionId is specified (default path). - The repository's embeddingModelId acts as the implicit collection. After create_default_collection - runs in the state machine, a real UUID-backed collection exists in DDB for this path. + The repository's embeddingModelId acts as the implicit collection. After create_default_collection runs in the state + machine, a real UUID-backed collection exists in DDB for this path. """ @pytest.fixture(scope="class") def default_collection_id(self, lisa_client: LisaApi, test_repository_id: str) -> str: """Resolve the default collection UUID for the repository. - Calls list_collections and returns the one marked default, or falls back to - the repository's embeddingModelId if no default collection exists yet. 
+ Calls list_collections and returns the one marked default, or falls back to the repository's embeddingModelId if + no default collection exists yet. """ collections_resp = lisa_client.list_collections(test_repository_id) collections = collections_resp.get("collections", []) @@ -789,7 +787,7 @@ def test_02_ingest_to_default_collection( try: documents = lisa_client.list_documents(test_repository_id, default_collection_id) if documents: - logger.info(f"✓ Document ingested after {int(time.time()-start)}s") + logger.info(f"✓ Document ingested after {int(time.time() - start)}s") break except Exception as e: logger.debug(f"Waiting: {e}") @@ -824,7 +822,7 @@ def test_03_delete_from_default_collection( while time.time() - start < max_wait: remaining = lisa_client.list_documents(test_repository_id, default_collection_id) if not any(d.get("document_id") == document_id for d in remaining): - logger.info(f"✓ Document deleted after {int(time.time()-start)}s") + logger.info(f"✓ Document deleted after {int(time.time() - start)}s") break time.sleep(10) else: diff --git a/test/integration/sdk/README.md b/test/integration/sdk/README.md index 0debe8606..8d06071e1 100644 --- a/test/integration/sdk/README.md +++ b/test/integration/sdk/README.md @@ -85,6 +85,7 @@ When adding new integration tests: ### Authentication Errors If you see authentication errors: + - Verify AWS credentials are configured correctly - Check that the deployment name matches your LISA deployment - Ensure the management key exists in Secrets Manager @@ -92,6 +93,7 @@ If you see authentication errors: ### Connection Errors If you see connection errors: + - Verify the API URL is correct and accessible - Check SSL verification settings (`--verify false` for self-signed certs) - Ensure network connectivity to the LISA deployment @@ -99,6 +101,7 @@ If you see connection errors: ### Skipped Tests Many tests are skipped by default because they require: + - Specific models to be deployed (TGI, instructor embeddings, etc.) - Specific configurations (API Gateway vs REST URL) - Management tokens (not all deployments support this) diff --git a/test/integration/sdk/conftest.py b/test/integration/sdk/conftest.py index 4226437bd..a2295edcd 100644 --- a/test/integration/sdk/conftest.py +++ b/test/integration/sdk/conftest.py @@ -14,11 +14,11 @@ """Sets the input parameters for lisa-sdk tests. -Note: pytest_addoption for --api, --url, etc. is in the root conftest.py because -pytest parses command-line options before loading subdirectory conftests. +Note: pytest_addoption for --api, --url, etc. is in the root conftest.py because pytest parses command-line options +before loading subdirectory conftests. -When --api/--url are not provided, values are loaded from config-custom.yaml or -fetched from AWS via scripts/integration-env.mjs (same as run-integration-tests.sh). +When --api/--url are not provided, values are loaded from config-custom.yaml or fetched from AWS via +scripts/integration-env.mjs (same as run-integration-tests.sh). """ import logging @@ -55,7 +55,10 @@ def _resolve_url_option(pytestconfig: pytest.Config, kind: str) -> str: @pytest.fixture(scope="session") def url(pytestconfig: pytest.Config) -> str: - """Get the REST url (ALB). From --url, or config-custom.yaml + AWS.""" + """Get the REST url (ALB). + + From --url, or config-custom.yaml + AWS. 
+ """ val = _resolve_url_option(pytestconfig, "url") if not val: pytest.skip( @@ -67,7 +70,10 @@ def url(pytestconfig: pytest.Config) -> str: @pytest.fixture(scope="session") def api(pytestconfig: pytest.Config) -> str: - """Get the API Gateway url. From --api, or config-custom.yaml + AWS.""" + """Get the API Gateway url. + + From --api, or config-custom.yaml + AWS. + """ val = _resolve_url_option(pytestconfig, "api") if not val: pytest.skip( @@ -95,7 +101,10 @@ def verify(pytestconfig: pytest.Config) -> bool | Any: @pytest.fixture(scope="session") def api_key(pytestconfig: pytest.Config) -> str: - """Get management key from Secrets Manager. Uses same multi-pattern lookup as RAG tests.""" + """Get management key from Secrets Manager. + + Uses same multi-pattern lookup as RAG tests. + """ profile = _resolve_option(pytestconfig, "profile", "profile") or "default" deployment_name = _resolve_option(pytestconfig, "deployment", "deployment") or "app" stage = _resolve_option(pytestconfig, "stage", "stage") or "prod" @@ -116,9 +125,7 @@ def api_key(pytestconfig: pytest.Config) -> str: @pytest.fixture(scope="session") def api_token(pytestconfig: pytest.Config, api_key: str) -> Generator: - """ - Create a token item in DynamoDB with expiration if none is provided - """ + """Create a token item in DynamoDB with expiration if none is provided.""" auth_token = pytestconfig.getoption("auth_token") if auth_token is not None: return diff --git a/test/integration/sdk/test_integration_sdk_rag.py b/test/integration/sdk/test_integration_sdk_rag.py index f3478e454..93eb8a641 100644 --- a/test/integration/sdk/test_integration_sdk_rag.py +++ b/test/integration/sdk/test_integration_sdk_rag.py @@ -14,8 +14,8 @@ """Integration tests for RAG SDK document operations. -Tests document ingestion, listing, and deletion via the LISA SDK against a deployed environment. -Requires: deployed LISA with at least one repository and one embedding model. +Tests document ingestion, listing, and deletion via the LISA SDK against a deployed environment. Requires: deployed LISA +with at least one repository and one embedding model. """ import logging @@ -37,7 +37,6 @@ class TestLisaRag: - @pytest.fixture(autouse=True, scope="class") def setup_class(self, lisa_api: LisaApi, request: pytest.FixtureRequest) -> None: # Find the specific test repository @@ -84,7 +83,10 @@ def cleanup_ingested_documents(self, lisa_api: LisaApi, setup_class: None) -> No logger.exception(f"CLEANUP: Failed to delete document {doc_id} — ignoring") def _upload_and_ingest(self, lisa_api: LisaApi, content: str, prefix: str) -> str: - """Upload and ingest a single temp file. Returns the s3Path from the ingestion job.""" + """Upload and ingest a single temp file. + + Returns the s3Path from the ingestion job. + """ with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", prefix=prefix, delete=False) as f: f.write(content) temp_path = f.name @@ -117,8 +119,9 @@ def _extract_doc_id(doc: dict) -> str | None: return doc.get("document_id") or doc.get("documentId") or doc.get("id") def _wait_for_document(self, lisa_api: LisaApi, s3_path: str) -> str: - """Poll until a document with matching source appears in list_documents. Returns its document_id. + """Poll until a document with matching source appears in list_documents. + Returns its document_id. The ingest API returns s3Path (e.g. s3://bucket/key) which matches RagDocument.source exactly. 
""" start = time.time() diff --git a/test/integration/sdk/test_llm_proxy.py b/test/integration/sdk/test_llm_proxy.py index fadfb44a5..7674ca794 100644 --- a/test/integration/sdk/test_llm_proxy.py +++ b/test/integration/sdk/test_llm_proxy.py @@ -14,8 +14,8 @@ """Integration tests for the LLM proxy (LiteLLM REST API). -Tests the LiteLLM proxy endpoints via the LisaLlm SDK against a deployed environment. -Requires: deployed LISA with at least one textgen model. +Tests the LiteLLM proxy endpoints via the LisaLlm SDK against a deployed environment. Requires: deployed LISA with at +least one textgen model. """ import logging @@ -34,9 +34,8 @@ def _get_textgen_model(lisa_llm: LisaLlm) -> FoundationModel: """Discover the first available textgen model using the model/info endpoint. - Uses get_model_info() which returns the full LiteLLM model database including - model_info.mode (e.g. "chat", "embedding") for reliable model type detection. - Falls back to name-based heuristics if mode is not available. + Uses get_model_info() which returns the full LiteLLM model database including model_info.mode (e.g. "chat", + "embedding") for reliable model type detection. Falls back to name-based heuristics if mode is not available. """ try: entries = lisa_llm.get_model_info() @@ -221,7 +220,6 @@ def test_health_liveliness(lisa_llm: LisaLlm) -> None: def test_get_model_info(lisa_llm: LisaLlm) -> None: """Model info endpoint should return a list of ModelInfoEntry objects.""" - result = lisa_llm.get_model_info() assert isinstance(result, list), f"Expected list, got {type(result)}" if not result: @@ -240,7 +238,6 @@ def test_get_model_info(lisa_llm: LisaLlm) -> None: def test_complete(lisa_llm: LisaLlm) -> None: """Legacy completions endpoint should return a CompletionResponse.""" - model = _get_textgen_model(lisa_llm) try: result = lisa_llm.complete( diff --git a/test/integration/test_repository_update_metadata_preservation.py b/test/integration/test_repository_update_metadata_preservation.py index ac4a3fa19..8e04e702f 100644 --- a/test/integration/test_repository_update_metadata_preservation.py +++ b/test/integration/test_repository_update_metadata_preservation.py @@ -142,12 +142,12 @@ def test_bedrock_kb_update_preserves_existing_metadata( } # Act - with patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), patch( - "repository.lambda_functions.build_pipeline_configs_from_kb_config" - ) as mock_build_pipelines, patch("utilities.auth.is_admin", return_value=True), patch( - "utilities.auth.user_has_group_access", return_value=True + with ( + patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), + patch("repository.lambda_functions.build_pipeline_configs_from_kb_config") as mock_build_pipelines, + patch("utilities.auth.is_admin", return_value=True), + patch("utilities.auth.user_has_group_access", return_value=True), ): - # Mock the pipeline builder to return pipelines without metadata mock_build_pipelines.return_value = [ { @@ -215,9 +215,11 @@ def test_direct_pipeline_update_preserves_metadata_when_missing( } # Act - with patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), patch( - "utilities.auth.is_admin", return_value=True - ), patch("utilities.auth.user_has_group_access", return_value=True): + with ( + patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), + patch("utilities.auth.is_admin", return_value=True), + patch("utilities.auth.user_has_group_access", return_value=True), + ): _result = update_repository(event, lambda_context) # Assert @@ 
-273,9 +275,11 @@ def test_partial_metadata_update_preserves_missing_tags( } # Act - with patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), patch( - "utilities.auth.is_admin", return_value=True - ), patch("utilities.auth.user_has_group_access", return_value=True): + with ( + patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), + patch("utilities.auth.is_admin", return_value=True), + patch("utilities.auth.user_has_group_access", return_value=True), + ): _result = update_repository(event, lambda_context) # Assert @@ -326,9 +330,11 @@ def test_complete_metadata_replacement_when_tags_provided( } # Act - with patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), patch( - "utilities.auth.is_admin", return_value=True - ), patch("utilities.auth.user_has_group_access", return_value=True): + with ( + patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), + patch("utilities.auth.is_admin", return_value=True), + patch("utilities.auth.user_has_group_access", return_value=True), + ): _result = update_repository(event, lambda_context) # Assert @@ -361,7 +367,9 @@ def test_no_metadata_preservation_for_new_collections( knowledgeBaseId="kb-123", dataSources=[ BedrockDataSource( - id="datasource-2", name="New Data Source", s3Uri="s3://new-docs/" # New data source ID + id="datasource-2", + name="New Data Source", + s3Uri="s3://new-docs/", # New data source ID ) ], ) @@ -376,12 +384,12 @@ def test_no_metadata_preservation_for_new_collections( } # Act - with patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), patch( - "repository.lambda_functions.build_pipeline_configs_from_kb_config" - ) as mock_build_pipelines, patch("utilities.auth.is_admin", return_value=True), patch( - "utilities.auth.user_has_group_access", return_value=True + with ( + patch("repository.lambda_functions.vs_repo", mock_vector_store_repo), + patch("repository.lambda_functions.build_pipeline_configs_from_kb_config") as mock_build_pipelines, + patch("utilities.auth.is_admin", return_value=True), + patch("utilities.auth.user_has_group_access", return_value=True), ): - mock_build_pipelines.return_value = [ { "s3Bucket": "new-docs", diff --git a/test/lambda/test_audit_logging.py b/test/lambda/test_audit_logging.py index 8e26b6ca6..681940a18 100644 --- a/test/lambda/test_audit_logging.py +++ b/test/lambda/test_audit_logging.py @@ -83,14 +83,13 @@ async def test_create_model_logs_all_required_fields(self, caplog): ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -200,14 +199,13 @@ async def test_create_model_logs_container_details_for_lisa_hosted(self, caplog) ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - 
"utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -270,14 +268,13 @@ async def test_create_model_logs_without_container_config(self, caplog): ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -337,14 +334,13 @@ async def test_create_model_does_not_log_sensitive_data(self, caplog): ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -407,14 +403,13 @@ async def test_create_model_logs_for_successful_request(self, caplog): ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks for successful creation mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -462,14 +457,13 @@ async def test_create_model_logs_for_failed_request(self, caplog): ) # Mock the handler to raise ModelAlreadyExistsError - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - 
"utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -535,14 +529,13 @@ async def test_create_model_extracts_real_ip_from_api_gateway_context(self, capl ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -587,14 +580,12 @@ async def test_create_model_handles_missing_event_context(self, caplog): ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.auth.is_admin" - ) as mock_is_admin, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_decorator_is_admin, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin, + caplog.at_level(logging.INFO), ): - # Setup mocks - simulate admin with no event context mock_is_admin.return_value = True mock_decorator_is_admin.return_value = True @@ -707,14 +698,13 @@ async def test_create_model_extracts_registry_domain_from_various_formats(self, ) # Mock the handler and auth functions - with patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, patch( - "utilities.fastapi_middleware.auth_decorators.is_admin" - ) as mock_is_admin, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, caplog.at_level( - logging.INFO + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_handler_class, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + caplog.at_level(logging.INFO), ): - # Setup mocks mock_is_admin.return_value = True mock_get_groups.return_value = ["admin-group"] @@ -736,6 +726,6 @@ async def test_create_model_extracts_registry_domain_from_various_formats(self, assert log_record is not None, f"Log not found for image: {image_url}" actual_domain = log_record.container["registry_domain"] - assert actual_domain == expected_domain, ( - f"Expected domain '{expected_domain}' for image 
'{image_url}', " f"got '{actual_domain}'" - ) + assert ( + actual_domain == expected_domain + ), f"Expected domain '{expected_domain}' for image '{image_url}', got '{actual_domain}'" diff --git a/test/lambda/test_authorizer_lambda.py b/test/lambda/test_authorizer_lambda.py index d0d1530d5..5705e6432 100644 --- a/test/lambda/test_authorizer_lambda.py +++ b/test/lambda/test_authorizer_lambda.py @@ -760,8 +760,8 @@ def test_lambda_handler_rag_admin_only_user_gets_allow( ): """Test lambda_handler allows a user who is only in the rag_admin group. - A user only in the rag_admin group (not admin, not user group) should get - an Allow policy from the authorizer, not a Deny. + A user only in the rag_admin group (not admin, not user group) should get an Allow policy from the authorizer, not a + Deny. """ mock_get_management_tokens.return_value = [] mock_is_valid_api_token.return_value = False diff --git a/test/lambda/test_chat_assistant_stacks_lambda.py b/test/lambda/test_chat_assistant_stacks_lambda.py index dee9ff127..6ad2d5963 100644 --- a/test/lambda/test_chat_assistant_stacks_lambda.py +++ b/test/lambda/test_chat_assistant_stacks_lambda.py @@ -72,14 +72,17 @@ def wrapper(event, context): @pytest.fixture(scope="function") def chat_stacks_handlers(patch_is_admin_for_chat_stacks): - """Patch retry_config and api_wrapper only for this module, then import handlers. No global mocks. - Depends on patch_is_admin_for_chat_stacks so handlers are imported after admin_only is restored - (test_repository_lambda patches it at module load). Clear cache to force fresh import with current admin_only.""" + """Patch retry_config and api_wrapper only for this module, then import handlers. + + No global mocks. Depends on patch_is_admin_for_chat_stacks so handlers are imported after admin_only is restored + (test_repository_lambda patches it at module load). Clear cache to force fresh import with current admin_only. + """ for mod in list(sys.modules.keys()): if mod == "chat_assistant_stacks" or mod.startswith("chat_assistant_stacks."): del sys.modules[mod] - with patch("utilities.common_functions.retry_config", retry_config), patch( - "utilities.common_functions.api_wrapper", mock_api_wrapper + with ( + patch("utilities.common_functions.retry_config", retry_config), + patch("utilities.common_functions.api_wrapper", mock_api_wrapper), ): from chat_assistant_stacks.lambda_functions import ( create, diff --git a/test/lambda/test_chunking_strategy.py b/test/lambda/test_chunking_strategy.py old mode 100644 new mode 100755 diff --git a/test/lambda/test_collection_api_integration.py b/test/lambda/test_collection_api_integration.py index 13875f1b8..2a828e50f 100644 --- a/test/lambda/test_collection_api_integration.py +++ b/test/lambda/test_collection_api_integration.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Integration tests for cross-repository collection API. +"""Integration tests for cross-repository collection API. -These tests verify end-to-end functionality with real repository implementations -(using mocked DynamoDB tables). +These tests verify end-to-end functionality with real repository implementations (using mocked DynamoDB tables). """ import os @@ -168,8 +166,7 @@ def mock_count_by_repo(repository_id): def test_cross_repository_query_integration(integration_collection_service, mock_dynamodb_tables): - """ - Full flow: multiple repos in DB → query → aggregated results. 
+ """Full flow: multiple repos in DB → query → aggregated results. Integration test verifying: 1. Service queries multiple repositories @@ -203,8 +200,7 @@ def test_cross_repository_query_integration(integration_collection_service, mock def test_permission_enforcement_integration(integration_collection_service, mock_dynamodb_tables): - """ - Full flow: repos with different permissions → filtered results. + """Full flow: repos with different permissions → filtered results. Integration test verifying: 1. Repository-level permissions are enforced @@ -237,8 +233,7 @@ def test_permission_enforcement_integration(integration_collection_service, mock def test_pagination_with_large_dataset_integration(integration_collection_service, mock_dynamodb_tables): - """ - Full flow: 1000+ collections → paginated results. + """Full flow: 1000+ collections → paginated results. Integration test verifying: 1. Large datasets trigger appropriate pagination strategy @@ -315,8 +310,7 @@ def mock_list_by_repo(repository_id, **kwargs): def test_scalable_pagination_activation_integration(integration_collection_service, mock_dynamodb_tables): - """ - Full flow: large dataset triggers scalable strategy. + """Full flow: large dataset triggers scalable strategy. Integration test verifying: 1. Service estimates collection count @@ -345,8 +339,7 @@ def test_scalable_pagination_activation_integration(integration_collection_servi def test_repository_metadata_enrichment_integration(integration_collection_service, mock_dynamodb_tables): - """ - Full flow: collections enriched with repo names. + """Full flow: collections enriched with repo names. Integration test verifying: 1. Collections are queried from repositories diff --git a/test/lambda/test_collection_id_resolution.py b/test/lambda/test_collection_id_resolution.py index 8ca52e77b..014069350 100644 --- a/test/lambda/test_collection_id_resolution.py +++ b/test/lambda/test_collection_id_resolution.py @@ -14,11 +14,9 @@ """Tests for pipeline collectionId resolution in ingest and delete handlers. -Pipeline configs may store a collection name (e.g. "default") as the collectionId -rather than the auto-generated UUID used as the collections table primary key. -These tests verify that find_by_id_or_name correctly resolves a name to its -corresponding UUID at read time, ensuring both ingest and delete handlers -operate on the correct collection. +Pipeline configs may store a collection name (e.g. "default") as the collectionId rather than the auto-generated UUID +used as the collections table primary key. These tests verify that find_by_id_or_name correctly resolves a name to its +corresponding UUID at read time, ensuring both ingest and delete handlers operate on the correct collection. """ import os @@ -125,9 +123,8 @@ def _build_real_collection_service(mock_dynamodb_table): def test_ingest_name_resolves_to_uuid_via_real_service(setup_env, mock_dynamodb_table): """Ingest handler resolves a collection name to its UUID before saving the job. - When a pipeline config references a collection by name, the ingest handler - must resolve that name to the collection's UUID so the job is persisted - with the correct identifier. + When a pipeline config references a collection by name, the ingest handler must resolve that name to the + collection's UUID so the job is persisted with the correct identifier. 
""" # DynamoDB: get_item (UUID lookup) misses; query (name lookup) returns the collection mock_dynamodb_table.get_item.return_value = {} @@ -147,12 +144,12 @@ def test_ingest_name_resolves_to_uuid_via_real_service(setup_env, mock_dynamodb_ } } - with patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_handlers.collection_service", real_svc - ), patch("repository.pipeline_ingest_handlers.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingest_handlers.ingestion_service" + with ( + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_handlers.collection_service", real_svc), + patch("repository.pipeline_ingest_handlers.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_handlers.ingestion_service"), ): - mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} from repository.pipeline_ingest_handlers import handle_pipeline_ingest_event @@ -169,9 +166,8 @@ def test_ingest_name_resolves_to_uuid_via_real_service(setup_env, mock_dynamodb_ def test_delete_name_resolves_to_uuid_via_real_service(setup_env, mock_dynamodb_table): """Delete handler resolves a collection name to its UUID before removing documents. - When a pipeline config references a collection by name, the delete handler - must resolve that name to the collection's UUID and proceed with document - deletion rather than skipping the operation. + When a pipeline config references a collection by name, the delete handler must resolve that name to the + collection's UUID and proceed with document deletion rather than skipping the operation. """ mock_dynamodb_table.get_item.return_value = {} mock_dynamodb_table.query.return_value = {"Items": [_make_collection_item("uuid-abc", "default")]} @@ -187,14 +183,13 @@ def test_delete_name_resolves_to_uuid_via_real_service(setup_env, mock_dynamodb_ } } - with patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_handlers.collection_service", real_svc - ), patch("repository.pipeline_ingest_handlers.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_ingest_handlers.ingestion_job_repository" - ), patch( - "repository.pipeline_ingest_handlers.ingestion_service" + with ( + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_handlers.collection_service", real_svc), + patch("repository.pipeline_ingest_handlers.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_ingest_handlers.ingestion_job_repository"), + patch("repository.pipeline_ingest_handlers.ingestion_service"), ): - mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_doc_repo.find_by_source.return_value = [] diff --git a/test/lambda/test_collection_service.py b/test/lambda/test_collection_service.py index 8f64e914b..d451dfbb7 100644 --- a/test/lambda/test_collection_service.py +++ b/test/lambda/test_collection_service.py @@ -38,7 +38,7 @@ def setup_env(monkeypatch): def test_create_collection(): - """Test collection creation""" + """Test collection creation.""" from repository.collection_service import CollectionService mock_repo = Mock() @@ -66,7 +66,7 @@ def test_create_collection(): def test_get_collection(): - """Test get collection""" + """Test get collection.""" from repository.collection_service import CollectionService mock_repo = Mock() @@ 
-93,8 +93,7 @@ def test_get_collection(): def test_list_collections(): - """Test list collections""" - + """Test list collections.""" from repository.collection_service import CollectionService mock_repo = Mock() @@ -163,10 +162,10 @@ def test_delete_collection(): mock_ingestion_job_repo = Mock() mock_ingestion_service = Mock() - with patch("repository.collection_service.IngestionJobRepository", return_value=mock_ingestion_job_repo), patch( - "repository.collection_service.DocumentIngestionService", return_value=mock_ingestion_service + with ( + patch("repository.collection_service.IngestionJobRepository", return_value=mock_ingestion_job_repo), + patch("repository.collection_service.DocumentIngestionService", return_value=mock_ingestion_service), ): - result = service.delete_collection( repository_id="test-repo", collection_id="test-coll", @@ -203,10 +202,10 @@ def test_delete_default_collection(): mock_ingestion_job_repo = Mock() mock_ingestion_service = Mock() - with patch("repository.collection_service.IngestionJobRepository", return_value=mock_ingestion_job_repo), patch( - "repository.collection_service.DocumentIngestionService", return_value=mock_ingestion_service + with ( + patch("repository.collection_service.IngestionJobRepository", return_value=mock_ingestion_job_repo), + patch("repository.collection_service.DocumentIngestionService", return_value=mock_ingestion_service), ): - result = service.delete_collection( repository_id="test-repo", collection_id=None, @@ -237,7 +236,7 @@ def test_delete_default_collection(): def test_create_collection_lambda_with_embedding_model(): - """Test create_collection lambda with embedding model specified""" + """Test create_collection lambda with embedding model specified.""" import json from unittest.mock import Mock, patch @@ -281,10 +280,11 @@ def test_create_collection_lambda_with_embedding_model(): status=CollectionStatus.ACTIVE, ) - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_service, patch("utilities.auth.is_admin") as mock_is_admin: - + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_service, + patch("utilities.auth.is_admin") as mock_is_admin, + ): mock_get_repo.return_value = mock_repository mock_service.create_collection.return_value = mock_collection mock_is_admin.return_value = True # Mock admin check to pass @@ -302,7 +302,7 @@ def test_create_collection_lambda_with_embedding_model(): def test_create_collection_lambda_without_embedding_model_with_repository_default(): - """Test create_collection lambda without embedding model but repository has default""" + """Test create_collection lambda without embedding model but repository has default.""" import json from unittest.mock import Mock, patch @@ -346,10 +346,11 @@ def test_create_collection_lambda_without_embedding_model_with_repository_defaul status=CollectionStatus.ACTIVE, ) - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_service, patch("utilities.auth.is_admin") as mock_is_admin: - + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_service, + patch("utilities.auth.is_admin") as mock_is_admin, + ): mock_get_repo.return_value = mock_repository mock_service.create_collection.return_value = 
mock_collection mock_is_admin.return_value = True # Mock admin check to pass @@ -366,7 +367,7 @@ def test_create_collection_lambda_without_embedding_model_with_repository_defaul def test_create_collection_lambda_without_embedding_model_no_repository_default(): - """Test create_collection lambda fails when no embedding model and no repository default""" + """Test create_collection lambda fails when no embedding model and no repository default.""" import json from unittest.mock import Mock, patch @@ -401,10 +402,11 @@ def test_create_collection_lambda_without_embedding_model_no_repository_default( "embeddingModelId": None, # No default embedding model } - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_service, patch("utilities.auth.is_admin") as mock_is_admin: - + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_service, + patch("utilities.auth.is_admin") as mock_is_admin, + ): mock_get_repo.return_value = mock_repository mock_is_admin.return_value = True # Mock admin check to pass @@ -423,7 +425,7 @@ def test_create_collection_lambda_without_embedding_model_no_repository_default( def test_create_collection_lambda_original_payload(): - """Test create_collection lambda with the original failing payload""" + """Test create_collection lambda with the original failing payload.""" import json from unittest.mock import Mock, patch @@ -463,10 +465,11 @@ def test_create_collection_lambda_original_payload(): status=CollectionStatus.ACTIVE, ) - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_service, patch("utilities.auth.is_admin") as mock_is_admin: - + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_service, + patch("utilities.auth.is_admin") as mock_is_admin, + ): mock_get_repo.return_value = mock_repository mock_service.create_collection.return_value = mock_collection mock_is_admin.return_value = True # Mock admin check to pass diff --git a/test/lambda/test_collection_service_cross_repo.py b/test/lambda/test_collection_service_cross_repo.py index 0fc14c061..238a11d55 100644 --- a/test/lambda/test_collection_service_cross_repo.py +++ b/test/lambda/test_collection_service_cross_repo.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Unit tests for cross-repository collection queries. +"""Unit tests for cross-repository collection queries. These tests follow API-level testing principles: - Test complete workflows, not individual lines @@ -159,8 +158,7 @@ def sample_collections(): def test_list_all_user_collections_admin_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories, sample_collections ): - """ - Complete workflow: Admin requests collections → service queries all repos → returns all collections. + """Complete workflow: Admin requests collections → service queries all repos → returns all collections. Workflow: 1. 
Admin user requests all collections @@ -201,8 +199,7 @@ def mock_list_by_repo(repository_id, **kwargs): def test_list_all_user_collections_group_access_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories, sample_collections ): - """ - Complete workflow: User with groups → filtered by repo permissions → returns accessible collections. + """Complete workflow: User with groups → filtered by repo permissions → returns accessible collections. Workflow: 1. User with group1 requests collections @@ -241,8 +238,7 @@ def mock_list_by_repo(repository_id, **kwargs): def test_list_all_user_collections_no_access_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories ): - """ - Complete workflow: User with no access → empty list returned. + """Complete workflow: User with no access → empty list returned. Workflow: 1. User with no matching groups requests collections @@ -271,8 +267,7 @@ def test_list_all_user_collections_no_access_workflow( def test_list_all_user_collections_private_collections_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories, sample_collections ): - """ - Complete workflow: User sees collections based on ownership and group access. + """Complete workflow: User sees collections based on ownership and group access. Workflow: 1. User2 (owner of coll-2) requests collections @@ -326,8 +321,7 @@ def mock_list_by_repo(repository_id, **kwargs): def test_pagination_strategy_selection_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories ): - """ - Complete workflow: Service estimates count → selects correct strategy. + """Complete workflow: Service estimates count → selects correct strategy. Workflow: 1. User requests collections @@ -359,8 +353,7 @@ def test_pagination_strategy_selection_workflow( def test_paginate_collections_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories, sample_collections ): - """ - Complete workflow: Request with filter/sort → paginated results. + """Complete workflow: Request with filter/sort → paginated results. Workflow: 1. User requests collections with filter and sort @@ -398,9 +391,8 @@ def mock_list_by_repo(repository_id, **kwargs): def test_rag_admin_sees_admin_restricted_collection_in_accessible_repo( collection_service, mock_vector_store_repo, mock_collection_repo ): - """ - RAG admin sees all collections in repos they have group access to, - including collections restricted to AdminGroup. + """RAG admin sees all collections in repos they have group access to, including collections restricted to + AdminGroup. Workflow: 1. RAG admin (in rag-team, NOT in AdminGroup) requests collections @@ -461,9 +453,9 @@ def test_rag_admin_sees_admin_restricted_collection_in_accessible_repo( def test_rag_admin_cannot_see_collections_in_admin_restricted_repo( collection_service, mock_vector_store_repo, mock_collection_repo ): - """ - RAG admin is blocked from repos where they don't have group access, - even with is_rag_admin=True. Repo-level filtering is unchanged. + """RAG admin is blocked from repos where they don't have group access, even with is_rag_admin=True. + + Repo-level filtering is unchanged. Workflow: 1. 
RAG admin (in rag-team) requests collections @@ -511,8 +503,8 @@ def test_rag_admin_cannot_see_collections_in_admin_restricted_repo( def test_regular_user_still_filtered_by_collection_allowed_groups( collection_service, mock_vector_store_repo, mock_collection_repo ): - """ - Regular users (is_rag_admin=False) are still filtered by collection allowedGroups. + """Regular users (is_rag_admin=False) are still filtered by collection allowedGroups. + Adding is_rag_admin parameter must not affect existing user behavior. """ now = datetime.now(timezone.utc) @@ -563,8 +555,8 @@ def test_regular_user_still_filtered_by_collection_allowed_groups( def test_full_admin_sees_all_collections_regardless_of_allowed_groups( collection_service, mock_vector_store_repo, mock_collection_repo ): - """ - Full admin (is_admin=True) still bypasses all collection-level filtering. + """Full admin (is_admin=True) still bypasses all collection-level filtering. + Regression: is_rag_admin parameter must not affect admin behavior. """ now = datetime.now(timezone.utc) @@ -602,8 +594,7 @@ def test_full_admin_sees_all_collections_regardless_of_allowed_groups( def test_repository_metadata_enrichment_workflow( collection_service, mock_vector_store_repo, mock_collection_repo, sample_repositories, sample_collections ): - """ - Complete workflow: Collections queried → enriched with repo names. + """Complete workflow: Collections queried → enriched with repo names. Workflow: 1. User requests collections diff --git a/test/lambda/test_create_mcp_server_state_machine.py b/test/lambda/test_create_mcp_server_state_machine.py index e15ff6c26..98d654e4a 100644 --- a/test/lambda/test_create_mcp_server_state_machine.py +++ b/test/lambda/test_create_mcp_server_state_machine.py @@ -145,9 +145,10 @@ def test_handle_deploy_server_success(mcp_servers_table, sample_mcp_server_event """Test successful deployment of server.""" from mcp_server.state_machine.create_mcp_server import handle_deploy_server - with patch("mcp_server.state_machine.create_mcp_server.lambdaClient") as mock_lambda, patch( - "mcp_server.state_machine.create_mcp_server.cfnClient" - ) as mock_cfn: + with ( + patch("mcp_server.state_machine.create_mcp_server.lambdaClient") as mock_lambda, + patch("mcp_server.state_machine.create_mcp_server.cfnClient") as mock_cfn, + ): mock_response = MagicMock() mock_response.read.return_value = json.dumps({"stackName": "test-stack-name"}).encode() mock_lambda.invoke.return_value = {"Payload": mock_response} @@ -181,9 +182,11 @@ def test_handle_deploy_server_missing_stack_name(sample_mcp_server_event): """Test deployment failure when stack name is missing.""" from mcp_server.state_machine.create_mcp_server import handle_deploy_server - with patch("mcp_server.state_machine.create_mcp_server.lambdaClient") as mock_lambda, patch( - "mcp_server.state_machine.create_mcp_server.cfnClient" - ), patch("mcp_server.state_machine.create_mcp_server.mcp_servers_table"): + with ( + patch("mcp_server.state_machine.create_mcp_server.lambdaClient") as mock_lambda, + patch("mcp_server.state_machine.create_mcp_server.cfnClient"), + patch("mcp_server.state_machine.create_mcp_server.mcp_servers_table"), + ): mock_response = MagicMock() mock_response.read.return_value = json.dumps({}).encode() # Missing stackName mock_lambda.invoke.return_value = {"Payload": mock_response} @@ -227,9 +230,10 @@ def test_handle_deploy_server_with_optional_fields(mcp_servers_table, sample_mcp "retries": 3, } - with patch("mcp_server.state_machine.create_mcp_server.lambdaClient") as 
mock_lambda, patch( - "mcp_server.state_machine.create_mcp_server.cfnClient" - ) as mock_cfn: + with ( + patch("mcp_server.state_machine.create_mcp_server.lambdaClient") as mock_lambda, + patch("mcp_server.state_machine.create_mcp_server.cfnClient") as mock_cfn, + ): mock_response = MagicMock() mock_response.read.return_value = json.dumps({"stackName": "test-stack"}).encode() mock_lambda.invoke.return_value = {"Payload": mock_response} @@ -374,8 +378,9 @@ def test_handle_add_server_to_active_with_connections_table(mcp_servers_table, s event = sample_mcp_server_event.copy() event["stack_name"] = "test-stack-name" - with patch("mcp_server.state_machine.create_mcp_server.ssmClient") as mock_ssm, patch.dict( - os.environ, {"DEPLOYMENT_PREFIX": "/test/lisa"} + with ( + patch("mcp_server.state_machine.create_mcp_server.ssmClient") as mock_ssm, + patch.dict(os.environ, {"DEPLOYMENT_PREFIX": "/test/lisa"}), ): # Mock SSM to return table name and API URL mock_ssm.get_parameter.side_effect = [ diff --git a/test/lambda/test_create_model_state_machine.py b/test/lambda/test_create_model_state_machine.py index 35e576f45..afe0f8801 100644 --- a/test/lambda/test_create_model_state_machine.py +++ b/test/lambda/test_create_model_state_machine.py @@ -814,8 +814,9 @@ def test_handle_add_guardrails_to_litellm_with_guardrails(model_table, guardrail }, } - with patch("models.state_machine.create_model.model_table", model_table), patch( - "models.state_machine.create_model.guardrails_table", guardrails_table + with ( + patch("models.state_machine.create_model.model_table", model_table), + patch("models.state_machine.create_model.guardrails_table", guardrails_table), ): result = handle_add_guardrails_to_litellm(event, lambda_context) @@ -860,8 +861,9 @@ def test_fetch_context_window_from_litellm_no_max_input_tokens(): "model_info": {"id": "test-litellm-id"}, } - with patch("models.state_machine.create_model.litellm_client", mock_litellm_client), patch( - "models.state_machine.create_model.time.sleep" + with ( + patch("models.state_machine.create_model.litellm_client", mock_litellm_client), + patch("models.state_machine.create_model.time.sleep"), ): result = _fetch_context_window_from_litellm("test-litellm-id") assert result is None @@ -871,8 +873,9 @@ def test_fetch_context_window_from_litellm_exception(): """Test fetching context window from LiteLLM when get_model raises an exception.""" mock_litellm_client.get_model.side_effect = Exception("Connection error") - with patch("models.state_machine.create_model.litellm_client", mock_litellm_client), patch( - "models.state_machine.create_model.time.sleep" + with ( + patch("models.state_machine.create_model.litellm_client", mock_litellm_client), + patch("models.state_machine.create_model.time.sleep"), ): result = _fetch_context_window_from_litellm("test-litellm-id") assert result is None @@ -895,8 +898,9 @@ class NoSuchKey(Exception): mock_s3.exceptions = MockS3Exceptions() - with patch("models.state_machine.create_model.s3_client", mock_s3), patch.dict( - os.environ, {"MODELS_BUCKET_NAME": "test-bucket"} + with ( + patch("models.state_machine.create_model.s3_client", mock_s3), + patch.dict(os.environ, {"MODELS_BUCKET_NAME": "test-bucket"}), ): result = _fetch_context_window_from_s3("mistralai/Mistral-7B-Instruct-v0.3", "textgen") assert result == 32768 @@ -928,8 +932,9 @@ class NoSuchKey(Exception): {"Body": MagicMock(read=lambda: json.dumps({"max_position_embeddings": 77}).encode())}, ] - with patch("models.state_machine.create_model.s3_client", mock_s3), patch.dict( 
- os.environ, {"MODELS_BUCKET_NAME": "test-bucket"} + with ( + patch("models.state_machine.create_model.s3_client", mock_s3), + patch.dict(os.environ, {"MODELS_BUCKET_NAME": "test-bucket"}), ): result = _fetch_context_window_from_s3("sd-model/stable-diffusion-v1", "imagegen") assert result == 77 @@ -952,8 +957,9 @@ class NoSuchKey(Exception): mock_s3.exceptions = MockS3Exceptions() mock_s3.get_object.side_effect = MockS3Exceptions.NoSuchKey("not found") - with patch("models.state_machine.create_model.s3_client", mock_s3), patch.dict( - os.environ, {"MODELS_BUCKET_NAME": "test-bucket"} + with ( + patch("models.state_machine.create_model.s3_client", mock_s3), + patch.dict(os.environ, {"MODELS_BUCKET_NAME": "test-bucket"}), ): result = _fetch_context_window_from_s3("nonexistent-model", "textgen") assert result is None @@ -973,8 +979,9 @@ def test_handle_enrich_context_window_non_lisa_managed(model_table, lambda_conte "model_info": {"id": "test-litellm-id", "max_input_tokens": 100000}, } - with patch("models.state_machine.create_model.model_table", model_table), patch( - "models.state_machine.create_model.litellm_client", mock_litellm_client + with ( + patch("models.state_machine.create_model.model_table", model_table), + patch("models.state_machine.create_model.litellm_client", mock_litellm_client), ): result = handle_enrich_context_window(event, lambda_context) @@ -1007,9 +1014,11 @@ class NoSuchKey(Exception): "Body": MagicMock(read=lambda: json.dumps({"max_position_embeddings": 32768}).encode()) } - with patch("models.state_machine.create_model.model_table", model_table), patch( - "models.state_machine.create_model.s3_client", mock_s3 - ), patch.dict(os.environ, {"MODELS_BUCKET_NAME": "test-bucket"}): + with ( + patch("models.state_machine.create_model.model_table", model_table), + patch("models.state_machine.create_model.s3_client", mock_s3), + patch.dict(os.environ, {"MODELS_BUCKET_NAME": "test-bucket"}), + ): result = handle_enrich_context_window(event, lambda_context) assert result["modelId"] == "lisa-model" @@ -1052,9 +1061,11 @@ def test_handle_enrich_context_window_non_blocking_on_failure(model_table, lambd mock_litellm_client.get_model.side_effect = Exception("LiteLLM is down") - with patch("models.state_machine.create_model.model_table", model_table), patch( - "models.state_machine.create_model.litellm_client", mock_litellm_client - ), patch("models.state_machine.create_model.time.sleep"): + with ( + patch("models.state_machine.create_model.model_table", model_table), + patch("models.state_machine.create_model.litellm_client", mock_litellm_client), + patch("models.state_machine.create_model.time.sleep"), + ): # Should NOT raise result = handle_enrich_context_window(event, lambda_context) assert result["modelId"] == "fail-model" diff --git a/test/lambda/test_encoders.py b/test/lambda/test_encoders.py index bca61ec1f..2677fc9e4 100644 --- a/test/lambda/test_encoders.py +++ b/test/lambda/test_encoders.py @@ -31,7 +31,7 @@ def test_convert_decimal_with_decimal(): - """Test convert_decimal with Decimal values""" + """Test convert_decimal with Decimal values.""" # Test single Decimal result = convert_decimal(Decimal("123.45")) assert result == 123.45 @@ -49,7 +49,7 @@ def test_convert_decimal_with_decimal(): def test_convert_decimal_with_dict(): - """Test convert_decimal with dictionary containing Decimals""" + """Test convert_decimal with dictionary containing Decimals.""" input_dict = { "price": Decimal("99.99"), "quantity": Decimal("5"), @@ -73,7 +73,7 @@ def 
test_convert_decimal_with_dict(): def test_convert_decimal_with_list(): - """Test convert_decimal with list containing Decimals""" + """Test convert_decimal with list containing Decimals.""" input_list = [Decimal("10.5"), "string_value", 42, [Decimal("3.14"), "nested"], {"amount": Decimal("100.00")}] result = convert_decimal(input_list) @@ -113,7 +113,7 @@ def test_convert_decimal_with_non_decimal_types(): def test_convert_decimal_with_empty_collections(): - """Test convert_decimal with empty dict and list""" + """Test convert_decimal with empty dict and list.""" # Test empty dict result = convert_decimal({}) assert result == {} @@ -124,7 +124,7 @@ def test_convert_decimal_with_empty_collections(): def test_convert_decimal_with_complex_nested_structure(): - """Test convert_decimal with deeply nested structure""" + """Test convert_decimal with deeply nested structure.""" complex_data = { "users": [ { @@ -175,7 +175,7 @@ def test_convert_decimal_with_complex_nested_structure(): def test_convert_decimal_with_mixed_types_in_list(): - """Test convert_decimal with list containing mixed types including nested structures""" + """Test convert_decimal with list containing mixed types including nested structures.""" mixed_list = [ Decimal("123.45"), {"price": Decimal("99.99"), "name": "Item"}, @@ -205,7 +205,7 @@ def test_convert_decimal_with_mixed_types_in_list(): def test_convert_decimal_preserves_original_structure(): - """Test that convert_decimal preserves the original data structure""" + """Test that convert_decimal preserves the original data structure.""" original = {"level1": {"level2": {"level3": [{"value": Decimal("42.0")}, {"value": Decimal("84.0")}]}}} result = convert_decimal(original) diff --git a/test/lambda/test_file_processing.py b/test/lambda/test_file_processing.py index 53cd68394..8d723fdf0 100644 --- a/test/lambda/test_file_processing.py +++ b/test/lambda/test_file_processing.py @@ -76,10 +76,10 @@ def test_generate_chunks_success_with_valid_path(sample_ingestion_job): """Test generate_chunks with valid S3 path.""" # Use a supported file extension for the test sample_ingestion_job.s3_path = "s3://test-bucket/test-key.txt" - with patch("utilities.file_processing.boto3.client") as mock_client, patch( - "utilities.file_processing.s3" - ) as mock_s3_global: - + with ( + patch("utilities.file_processing.boto3.client") as mock_client, + patch("utilities.file_processing.s3") as mock_s3_global, + ): # Setup mocks mock_s3 = MagicMock() mock_client.return_value = mock_s3 diff --git a/test/lambda/test_lambda_auth.py b/test/lambda/test_lambda_auth.py index 47a87cb28..228d7576e 100644 --- a/test/lambda/test_lambda_auth.py +++ b/test/lambda/test_lambda_auth.py @@ -530,7 +530,7 @@ def test_get_username_default(setup_env): def test_user_has_group(): - """Test user_has_group_access helper function""" + """Test user_has_group_access helper function.""" from utilities.auth import user_has_group_access # Test user has group diff --git a/test/lambda/test_lambda_decorators.py b/test/lambda/test_lambda_decorators.py index 0099d8e54..12fd8c4e4 100644 --- a/test/lambda/test_lambda_decorators.py +++ b/test/lambda/test_lambda_decorators.py @@ -176,7 +176,6 @@ class TestGetLambdaContext: def test_get_context_when_set(self): """Test get_lambda_context returns context when set.""" - mock_context = SimpleNamespace(function_name="get-context-test", aws_request_id="req-999") ctx_context.set(mock_context) diff --git a/test/lambda/test_mcp_server_lambda.py b/test/lambda/test_mcp_server_lambda.py index 
5eb333982..35847a5b4 100644 --- a/test/lambda/test_mcp_server_lambda.py +++ b/test/lambda/test_mcp_server_lambda.py @@ -165,8 +165,8 @@ def get_error_message(body): def setup_mcp_patches(request, mock_auth): """Set up per-test patches for MCP server lambda functions. - This fixture runs after conftest's setup_auth_patches and ensures - api_wrapper is properly mocked and adds additional patches needed. + This fixture runs after conftest's setup_auth_patches; it ensures api_wrapper is properly mocked and adds the + additional patches needed. """ # Skip patching for test_lambda_auth.py since it tests the auth module itself if "test_lambda_auth" in request.node.nodeid: diff --git a/test/lambda/test_mcp_workbench_lambda.py b/test/lambda/test_mcp_workbench_lambda.py index bf3d602bb..c13a0718f 100644 --- a/test/lambda/test_mcp_workbench_lambda.py +++ b/test/lambda/test_mcp_workbench_lambda.py @@ -114,7 +114,10 @@ def hello_world(): @pytest.fixture def s3_setup(): - """Set up S3 with moto and create bucket. Uses complete isolation to avoid test interference.""" + """Set up S3 with moto and create a bucket. + + Uses complete isolation to avoid test interference. + """ # More aggressive approach: Temporarily replace boto3.client entirely import importlib @@ -212,8 +215,9 @@ def test_get_tool_from_s3(s3_setup): from mcp_workbench.lambda_functions import _get_tool_from_s3 # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), ): tool = _get_tool_from_s3(SAMPLE_TOOL_ID) @@ -228,8 +232,9 @@ def test_get_tool_from_s3_not_found(s3_setup): from mcp_workbench.lambda_functions import _get_tool_from_s3 # Test retrieving non-existent tool with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), ): with pytest.raises(Exception) as excinfo: _get_tool_from_s3("non_existent_tool.py") @@ -250,8 +255,9 @@ def test_get_tool_from_s3_adds_py_extension(s3_setup): from mcp_workbench.lambda_functions import _get_tool_from_s3 # Use the actual function with moto S3, but request without .py extension - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), ): tool = _get_tool_from_s3("test_tool") @@ -279,9 +285,11 @@ def test_read_success(s3_setup, lambda_context): from mcp_workbench.lambda_functions import read # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = read(event, lambda_context) assert response["statusCode"] == 200 @@ -302,9 
+310,11 @@ def test_read_not_admin(s3_setup, lambda_context): from mcp_workbench.lambda_functions import read # Use the actual function with moto S3 and patched is_admin - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "utilities.auth.get_username", return_value="regular-user" - ), patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("utilities.auth.get_username", return_value="regular-user"), + patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper), + ): response = read(event, lambda_context) assert response["statusCode"] == 403 @@ -326,9 +336,11 @@ def test_read_not_found(s3_setup, lambda_context): from mcp_workbench.lambda_functions import read # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = read(event, lambda_context) assert response["statusCode"] == 404 @@ -350,8 +362,9 @@ def test_read_missing_tool_id(s3_setup, lambda_context): from mcp_workbench.lambda_functions import read # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), ): response = read(event, lambda_context) @@ -382,9 +395,11 @@ def test_list_success(s3_setup, lambda_context): from mcp_workbench.lambda_functions import list as list_tools # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = list_tools(event, lambda_context) assert response["statusCode"] == 200 @@ -408,9 +423,11 @@ def test_list_not_admin(s3_setup, lambda_context): from mcp_workbench.lambda_functions import list as list_tools # Use the actual function with moto S3 and patched is_admin - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "utilities.auth.get_username", return_value="regular-user" - ), patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("utilities.auth.get_username", return_value="regular-user"), + patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper), + ): response = list_tools(event, lambda_context) assert response["statusCode"] == 403 @@ -429,9 +446,11 @@ def test_list_empty_bucket(s3_setup, lambda_context): from mcp_workbench.lambda_functions import list as list_tools # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", 
return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = list_tools(event, lambda_context) assert response["statusCode"] == 200 @@ -452,9 +471,11 @@ def test_create_success(s3_setup, lambda_context): from mcp_workbench.lambda_functions import create # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = create(event, lambda_context) assert response["statusCode"] == 200 @@ -480,9 +501,11 @@ def test_create_without_py_extension(s3_setup, lambda_context): from mcp_workbench.lambda_functions import create # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = create(event, lambda_context) assert response["statusCode"] == 200 @@ -507,9 +530,11 @@ def test_create_not_admin(s3_setup, lambda_context): from mcp_workbench.lambda_functions import create # Use the actual function with moto S3 and patched is_admin - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "utilities.auth.get_username", return_value="regular-user" - ), patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("utilities.auth.get_username", return_value="regular-user"), + patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper), + ): response = create(event, lambda_context) assert response["statusCode"] == 403 @@ -531,8 +556,9 @@ def test_create_missing_fields(s3_setup, lambda_context): from mcp_workbench.lambda_functions import create # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), ): response = create(event, lambda_context) @@ -570,9 +596,11 @@ def updated_function(): from mcp_workbench.lambda_functions import update # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): 
response = update(event, lambda_context) assert response["statusCode"] == 200 @@ -599,9 +627,11 @@ def test_update_not_admin(s3_setup, lambda_context): from mcp_workbench.lambda_functions import update # Use the actual function with moto S3 and patched is_admin - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "utilities.auth.get_username", return_value="regular-user" - ), patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("utilities.auth.get_username", return_value="regular-user"), + patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper), + ): response = update(event, lambda_context) assert response["statusCode"] == 403 @@ -624,9 +654,11 @@ def test_update_not_found(s3_setup, lambda_context): from mcp_workbench.lambda_functions import update # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = update(event, lambda_context) assert response["statusCode"] == 404 @@ -649,8 +681,9 @@ def test_update_missing_tool_id(s3_setup, lambda_context): from mcp_workbench.lambda_functions import update # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), ): response = update(event, lambda_context) @@ -682,8 +715,9 @@ def test_update_missing_contents(s3_setup, lambda_context): from mcp_workbench.lambda_functions import update # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), ): response = update(event, lambda_context) @@ -715,9 +749,11 @@ def test_delete_success(s3_setup, lambda_context): from mcp_workbench.lambda_functions import delete # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = delete(event, lambda_context) assert response["statusCode"] == 200 @@ -746,9 +782,11 @@ def test_delete_not_admin(s3_setup, lambda_context): from mcp_workbench.lambda_functions import delete # Use the actual function with moto S3 and patched is_admin - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "utilities.auth.get_username", return_value="regular-user" - ), patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper): + with ( + 
patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("utilities.auth.get_username", return_value="regular-user"), + patch("mcp_workbench.lambda_functions.api_wrapper", mock_api_wrapper), + ): response = delete(event, lambda_context) assert response["statusCode"] == 403 @@ -770,9 +808,11 @@ def test_delete_not_found(s3_setup, lambda_context): from mcp_workbench.lambda_functions import delete # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True - ), patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET): + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), + patch("mcp_workbench.lambda_functions.WORKBENCH_BUCKET", WORKBENCH_BUCKET), + ): response = delete(event, lambda_context) assert response["statusCode"] == 404 @@ -794,8 +834,9 @@ def test_delete_missing_tool_id(s3_setup, lambda_context): from mcp_workbench.lambda_functions import delete # Use the actual function with moto S3 - with patch("mcp_workbench.lambda_functions.s3_client", s3_setup), patch( - "mcp_workbench.lambda_functions.is_admin", return_value=True + with ( + patch("mcp_workbench.lambda_functions.s3_client", s3_setup), + patch("mcp_workbench.lambda_functions.is_admin", return_value=True), ): response = delete(event, lambda_context) diff --git a/test/lambda/test_metrics_lambda.py b/test/lambda/test_metrics_lambda.py index be63a90ac..0c86462be 100644 --- a/test/lambda/test_metrics_lambda.py +++ b/test/lambda/test_metrics_lambda.py @@ -418,10 +418,10 @@ def test_daily_metrics_handler(self, dynamodb_table, multiple_usage_metrics, lam Expected: Should call both metrics functions and return the combined results. """ - with patch("metrics.lambda_functions.count_unique_users_and_publish_metric") as mock_count_users, patch( - "metrics.lambda_functions.count_users_by_group_and_publish_metric" - ) as mock_count_by_group: - + with ( + patch("metrics.lambda_functions.count_unique_users_and_publish_metric") as mock_count_users, + patch("metrics.lambda_functions.count_users_by_group_and_publish_metric") as mock_count_by_group, + ): mock_count_users.return_value = 3 mock_count_by_group.return_value = {"group1": 2, "group2": 2, "group3": 2} @@ -815,10 +815,10 @@ def test_update_user_metrics_by_session_exception_handling(self, dynamodb_table) "mcpToolUsage": {}, } - with patch("metrics.lambda_functions.usage_metrics_table.get_item") as mock_get_item, patch( - "metrics.lambda_functions.logger.error" - ) as mock_logger: - + with ( + patch("metrics.lambda_functions.usage_metrics_table.get_item") as mock_get_item, + patch("metrics.lambda_functions.logger.error") as mock_logger, + ): mock_get_item.side_effect = ClientError( error_response={"Error": {"Code": "ResourceNotFoundException", "Message": "Table not found"}}, operation_name="GetItem", @@ -1047,10 +1047,10 @@ def test_publish_metric_deltas_exception_handling(self): Expected: Should log error and not raise exception when CloudWatch fails. 
""" - with patch("metrics.lambda_functions.cloudwatch.put_metric_data") as mock_put_metric, patch( - "metrics.lambda_functions.logger.error" - ) as mock_logger: - + with ( + patch("metrics.lambda_functions.cloudwatch.put_metric_data") as mock_put_metric, + patch("metrics.lambda_functions.logger.error") as mock_logger, + ): mock_put_metric.side_effect = Exception("CloudWatch error") # Should not raise exception @@ -1097,8 +1097,8 @@ def test_publish_metric_deltas_token_changes_only(self): assert "TotalMCPToolCalls" not in metric_names def test_update_user_metrics_token_only_new_user(self, dynamodb_table): - """token_only event for a brand-new user creates a DynamoDB record with token totals - but no sessionMetrics entry. + """token_only event for a brand-new user creates a DynamoDB record with token totals but no sessionMetrics + entry. Expected: record exists with totalPromptTokens/totalCompletionTokens set and sessionMetrics is empty dict. @@ -1130,8 +1130,7 @@ def test_update_user_metrics_token_only_new_user(self, dynamodb_table): assert item["sessionMetrics"] == {} def test_update_user_metrics_token_only_existing_user(self, dynamodb_table): - """token_only event for an existing user accumulates token totals without - creating sessionMetrics entries. + """token_only event for an existing user accumulates token totals without creating sessionMetrics entries. Expected: totalPromptTokens increases by delta on each call. """ diff --git a/test/lambda/test_model_context_window_backfill.py b/test/lambda/test_model_context_window_backfill.py index b792cec67..1ed42d6d5 100644 --- a/test/lambda/test_model_context_window_backfill.py +++ b/test/lambda/test_model_context_window_backfill.py @@ -266,9 +266,11 @@ def test_run_backfill_enriches_bedrock_model(model_table): "model_info": {"id": "litellm-abc", "max_input_tokens": 200000}, } - with patch("models.model_context_window_backfill._get_litellm_client", return_value=mock_litellm), patch( - "models.model_context_window_backfill.boto3" - ) as mock_boto, patch("models.model_context_window_backfill.now", return_value=123456): + with ( + patch("models.model_context_window_backfill._get_litellm_client", return_value=mock_litellm), + patch("models.model_context_window_backfill.boto3") as mock_boto, + patch("models.model_context_window_backfill.now", return_value=123456), + ): mock_boto.resource.return_value.Table.return_value = model_table mock_boto.client.return_value = MagicMock() result = _run_backfill() @@ -294,9 +296,10 @@ def test_run_backfill_skips_already_enriched(model_table): mock_litellm = MagicMock() - with patch("models.model_context_window_backfill._get_litellm_client", return_value=mock_litellm), patch( - "models.model_context_window_backfill.boto3" - ) as mock_boto: + with ( + patch("models.model_context_window_backfill._get_litellm_client", return_value=mock_litellm), + patch("models.model_context_window_backfill.boto3") as mock_boto, + ): mock_boto.resource.return_value.Table.return_value = model_table mock_boto.client.return_value = MagicMock() result = _run_backfill() @@ -324,9 +327,11 @@ def test_run_backfill_defaults_to_zero_when_not_found(model_table): mock_litellm.get_model.side_effect = Exception("Not found") mock_litellm.list_models.return_value = [] - with patch("models.model_context_window_backfill._get_litellm_client", return_value=mock_litellm), patch( - "models.model_context_window_backfill.boto3" - ) as mock_boto, patch("models.model_context_window_backfill.now", return_value=123456): + with ( + 
patch("models.model_context_window_backfill._get_litellm_client", return_value=mock_litellm), + patch("models.model_context_window_backfill.boto3") as mock_boto, + patch("models.model_context_window_backfill.now", return_value=123456), + ): mock_boto.resource.return_value.Table.return_value = model_table mock_boto.client.return_value = MagicMock() result = _run_backfill() @@ -347,10 +352,14 @@ def test_run_backfill_counts_failures(model_table): } ) - with patch("models.model_context_window_backfill._get_litellm_client", return_value=MagicMock()), patch( - "models.model_context_window_backfill._fetch_context_window_from_litellm", - side_effect=RuntimeError("Unexpected crash"), - ), patch("models.model_context_window_backfill.boto3") as mock_boto: + with ( + patch("models.model_context_window_backfill._get_litellm_client", return_value=MagicMock()), + patch( + "models.model_context_window_backfill._fetch_context_window_from_litellm", + side_effect=RuntimeError("Unexpected crash"), + ), + patch("models.model_context_window_backfill.boto3") as mock_boto, + ): mock_boto.resource.return_value.Table.return_value = model_table mock_boto.client.return_value = MagicMock() result = _run_backfill() @@ -361,9 +370,10 @@ def test_run_backfill_counts_failures(model_table): def test_run_backfill_empty_table(model_table): """Test backfill on an empty table returns all zeros.""" - with patch("models.model_context_window_backfill._get_litellm_client", return_value=MagicMock()), patch( - "models.model_context_window_backfill.boto3" - ) as mock_boto: + with ( + patch("models.model_context_window_backfill._get_litellm_client", return_value=MagicMock()), + patch("models.model_context_window_backfill.boto3") as mock_boto, + ): mock_boto.resource.return_value.Table.return_value = model_table mock_boto.client.return_value = MagicMock() result = _run_backfill() diff --git a/test/lambda/test_models_lambda.py b/test/lambda/test_models_lambda.py index 625cd2b7d..69d3fc8f4 100644 --- a/test/lambda/test_models_lambda.py +++ b/test/lambda/test_models_lambda.py @@ -356,11 +356,11 @@ def test_delete_model_handler( ) # Mock SSM client, VectorStoreRepository, and CollectionRepository - with patch("models.handler.delete_model_handler.ssm_client") as mock_ssm, patch( - "models.handler.delete_model_handler.VectorStoreRepository" - ) as mock_repo_class, patch( - "models.handler.delete_model_handler.CollectionRepository" - ) as mock_collection_repo_class: + with ( + patch("models.handler.delete_model_handler.ssm_client") as mock_ssm, + patch("models.handler.delete_model_handler.VectorStoreRepository") as mock_repo_class, + patch("models.handler.delete_model_handler.CollectionRepository") as mock_collection_repo_class, + ): mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "vector-store-table"}} mock_repo = mock_repo_class.return_value @@ -401,11 +401,11 @@ def test_delete_model_handler_model_in_use_by_repository( ) # Mock SSM client, VectorStoreRepository, and CollectionRepository - with patch("models.handler.delete_model_handler.ssm_client") as mock_ssm, patch( - "models.handler.delete_model_handler.VectorStoreRepository" - ) as mock_repo_class, patch( - "models.handler.delete_model_handler.CollectionRepository" - ) as mock_collection_repo_class: + with ( + patch("models.handler.delete_model_handler.ssm_client") as mock_ssm, + patch("models.handler.delete_model_handler.VectorStoreRepository") as mock_repo_class, + patch("models.handler.delete_model_handler.CollectionRepository") as mock_collection_repo_class, + ): 
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "vector-store-table"}} mock_repo = mock_repo_class.return_value @@ -444,11 +444,11 @@ def test_delete_model_handler_model_in_use_by_pipeline( ) # Mock SSM client, VectorStoreRepository, and CollectionRepository - with patch("models.handler.delete_model_handler.ssm_client") as mock_ssm, patch( - "models.handler.delete_model_handler.VectorStoreRepository" - ) as mock_repo_class, patch( - "models.handler.delete_model_handler.CollectionRepository" - ) as mock_collection_repo_class: + with ( + patch("models.handler.delete_model_handler.ssm_client") as mock_ssm, + patch("models.handler.delete_model_handler.VectorStoreRepository") as mock_repo_class, + patch("models.handler.delete_model_handler.CollectionRepository") as mock_collection_repo_class, + ): mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "vector-store-table"}} mock_repo = mock_repo_class.return_value @@ -586,9 +586,10 @@ def test_update_model_handler( ) # Mock the to_lisa_model function to return a model with streaming=False - with patch("models.handler.update_model_handler.to_lisa_model") as mock_to_lisa_model, patch.object( - handler, "_stepfunctions" - ) as mock_sf: + with ( + patch("models.handler.update_model_handler.to_lisa_model") as mock_to_lisa_model, + patch.object(handler, "_stepfunctions") as mock_sf, + ): # Configure the mock to return a model with streaming=False mock_model = LISAModel( modelId="test-model", @@ -715,7 +716,6 @@ def test_update_model_validation( @pytest.mark.asyncio async def test_exception_handlers(): """Test exception handlers.""" - # Setup mock request request = MagicMock() @@ -756,14 +756,13 @@ async def test_fastapi_endpoints( client = TestClient(app) # Setup mocks for the handlers - with patch("models.lambda_functions.CreateModelHandler") as mock_create_handler, patch( - "models.lambda_functions.ListModelsHandler" - ) as mock_list_handler, patch("models.lambda_functions.GetModelHandler") as mock_get_handler, patch( - "models.lambda_functions.UpdateModelHandler" - ) as mock_update_handler, patch( - "models.lambda_functions.DeleteModelHandler" - ) as mock_delete_handler: - + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_create_handler, + patch("models.lambda_functions.ListModelsHandler") as mock_list_handler, + patch("models.lambda_functions.GetModelHandler") as mock_get_handler, + patch("models.lambda_functions.UpdateModelHandler") as mock_update_handler, + patch("models.lambda_functions.DeleteModelHandler") as mock_delete_handler, + ): # Setup handler mocks create_handler_instance = MagicMock() create_model_response = CreateModelResponse( @@ -879,7 +878,6 @@ async def test_fastapi_endpoints( @pytest.mark.asyncio async def test_get_instances(): """Test get_instances endpoint.""" - # Mock the shape_for method to return a mock with enum attribute mock_shape = MagicMock() mock_shape.enum = ["t2.micro", "t3.small", "m5.large"] @@ -926,7 +924,6 @@ def non_admin_event(): @pytest.mark.asyncio async def test_get_admin_status_and_groups(mock_auth): """Test the get_admin_status_and_groups helper function.""" - # Test with admin event admin_event = { "requestContext": { @@ -978,7 +975,6 @@ async def test_create_model_admin_required( sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, non_admin_event ): """Test that create_model endpoint requires admin access.""" - # Test non-admin cannot create mock_request = MagicMock(spec=Request) mock_request.scope = {"aws.event": 
non_admin_event} @@ -987,9 +983,11 @@ async def test_create_model_admin_required( modelId="test-model", modelName="test-model", modelType=ModelType.TEXTGEN, streaming=True ) - with patch("utilities.auth.is_admin") as mock_is_admin, patch( - "utilities.auth.get_groups" - ) as mock_get_groups, patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin: + with ( + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin, + ): mock_is_admin.return_value = False mock_decorator_is_admin.return_value = False mock_get_groups.return_value = [] @@ -1005,16 +1003,17 @@ async def test_update_model_admin_required( sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, non_admin_event ): """Test that update_model endpoint requires admin access.""" - # Test non-admin cannot update mock_request = MagicMock(spec=Request) mock_request.scope = {"aws.event": non_admin_event} update_request = UpdateModelRequest(streaming=False) - with patch("utilities.auth.is_admin") as mock_is_admin, patch( - "utilities.auth.get_groups" - ) as mock_get_groups, patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin: + with ( + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin, + ): mock_is_admin.return_value = False mock_decorator_is_admin.return_value = False mock_get_groups.return_value = [] @@ -1030,14 +1029,15 @@ async def test_delete_model_admin_required( sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, non_admin_event ): """Test that delete_model endpoint requires admin access.""" - # Test non-admin cannot delete mock_request = MagicMock(spec=Request) mock_request.scope = {"aws.event": non_admin_event} - with patch("utilities.auth.is_admin") as mock_is_admin, patch( - "utilities.auth.get_groups" - ) as mock_get_groups, patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin: + with ( + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.fastapi_middleware.auth_decorators.is_admin") as mock_decorator_is_admin, + ): mock_is_admin.return_value = False mock_decorator_is_admin.return_value = False mock_get_groups.return_value = [] @@ -1053,17 +1053,17 @@ async def test_create_update_delete_admin_allowed( sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, mock_auth ): """Test that admin users can successfully create, update, and delete models.""" - # Set admin access via mock_auth fixture mock_auth.set_user("admin-user", ["admin-group"], is_admin=True) mock_request = MagicMock(spec=Request) mock_request.scope = {"aws.event": admin_event} - with patch("models.lambda_functions.CreateModelHandler") as mock_create_handler, patch( - "models.lambda_functions.UpdateModelHandler" - ) as mock_update_handler, patch("models.lambda_functions.DeleteModelHandler") as mock_delete_handler: - + with ( + patch("models.lambda_functions.CreateModelHandler") as mock_create_handler, + patch("models.lambda_functions.UpdateModelHandler") as mock_update_handler, + patch("models.lambda_functions.DeleteModelHandler") as mock_delete_handler, + ): # Mock create handler 
create_handler_instance = MagicMock() create_model_response = CreateModelResponse( diff --git a/test/lambda/test_numeric_type_preservation.py b/test/lambda/test_numeric_type_preservation.py old mode 100644 new mode 100755 index df89691ab..82dd5c175 --- a/test/lambda/test_numeric_type_preservation.py +++ b/test/lambda/test_numeric_type_preservation.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Unit test to verify numeric type preservation through encryption/decryption. +"""Unit test to verify numeric type preservation through encryption/decryption. -This test ensures that numeric types (floats, ints) are preserved correctly -when session data goes through the encryption/decryption process. +This test ensures that numeric types (floats, ints) are preserved correctly when session data goes through the +encryption/decryption process. """ import json diff --git a/test/lambda/test_pipeline_delete_documents.py b/test/lambda/test_pipeline_delete_documents.py index fceef2e09..e0cd34b53 100644 --- a/test/lambda/test_pipeline_delete_documents.py +++ b/test/lambda/test_pipeline_delete_documents.py @@ -42,9 +42,10 @@ def test_drop_opensearch_index(setup_env): mock_service = Mock() - with patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_delete_documents.RepositoryServiceFactory" - ) as mock_factory: + with ( + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.RepositoryServiceFactory") as mock_factory, + ): mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_factory.create_service.return_value = mock_service @@ -59,9 +60,10 @@ def test_drop_opensearch_index_not_exists(setup_env): mock_service = Mock() - with patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_delete_documents.RepositoryServiceFactory" - ) as mock_factory: + with ( + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.RepositoryServiceFactory") as mock_factory, + ): mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_factory.create_service.return_value = mock_service @@ -76,9 +78,10 @@ def test_drop_pgvector_collection(setup_env): mock_service = Mock() - with patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_delete_documents.RepositoryServiceFactory" - ) as mock_factory: + with ( + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.RepositoryServiceFactory") as mock_factory, + ): mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "pgvector"} mock_factory.create_service.return_value = mock_service @@ -101,14 +104,13 @@ def test_pipeline_delete_collection_opensearch(setup_env): job_type=JobActionType.COLLECTION_DELETION, ) - with patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_delete_documents.drop_opensearch_index" - ) as mock_drop, patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_delete_documents.collection_repo" - ), patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ) as mock_job_repo: - + with ( + 
patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.drop_opensearch_index") as mock_drop, + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.collection_repo"), + patch("repository.pipeline_delete_documents.ingestion_job_repository") as mock_job_repo, + ): mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.OPENSEARCH} from repository.pipeline_delete_documents import pipeline_delete_collection @@ -134,16 +136,14 @@ def test_pipeline_delete_collection_bedrock_kb(setup_env): job_type=JobActionType.COLLECTION_DELETION, ) - with patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_delete_documents.boto3" - ) as mock_boto3, patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_delete_documents.collection_repo" - ), patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ), patch( - "repository.pipeline_delete_documents.bulk_delete_documents_from_kb" - ) as mock_bulk_delete: - + with ( + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.boto3") as mock_boto3, + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.collection_repo"), + patch("repository.pipeline_delete_documents.ingestion_job_repository"), + patch("repository.pipeline_delete_documents.bulk_delete_documents_from_kb") as mock_bulk_delete, + ): mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.BEDROCK_KB} mock_dynamodb = Mock() @@ -179,12 +179,12 @@ def test_pipeline_delete_collection_failure(setup_env): job_type=JobActionType.COLLECTION_DELETION, ) - with patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_delete_documents.rag_document_repository" - ) as mock_doc_repo, patch("repository.pipeline_delete_documents.collection_repo") as mock_coll_repo, patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ) as mock_job_repo: - + with ( + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.collection_repo") as mock_coll_repo, + patch("repository.pipeline_delete_documents.ingestion_job_repository") as mock_job_repo, + ): mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.OPENSEARCH} mock_doc_repo.delete_all.side_effect = Exception("Delete failed") @@ -224,12 +224,12 @@ def test_pipeline_delete_document(setup_env): chunk_strategy=FixedChunkingStrategy(size=1000, overlap=100), ) - with patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_delete_documents.vs_repo" - ) as mock_vs_repo, patch("repository.pipeline_delete_documents.remove_document_from_vectorstore"), patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ) as mock_job_repo: - + with ( + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.remove_document_from_vectorstore"), + 
patch("repository.pipeline_delete_documents.ingestion_job_repository") as mock_job_repo, + ): mock_doc_repo.find_by_id.return_value = rag_doc mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.OPENSEARCH} @@ -254,10 +254,10 @@ def test_pipeline_delete_document_not_found(setup_env): document_id="doc1", ) - with patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ) as mock_job_repo: - + with ( + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.ingestion_job_repository") as mock_job_repo, + ): mock_doc_repo.find_by_id.return_value = None from repository.pipeline_delete_documents import pipeline_delete_document @@ -278,12 +278,12 @@ def test_handle_pipeline_delete_event(setup_env): } } - with patch("repository.pipeline_ingest_handlers.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_ingest_handlers.ingestion_job_repository" - ) as mock_job_repo, patch("repository.pipeline_ingest_handlers.ingestion_service") as mock_service, patch( - "repository.pipeline_ingest_handlers.vs_repo" - ) as mock_vs_repo: - + with ( + patch("repository.pipeline_ingest_handlers.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_ingest_handlers.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_handlers.ingestion_service") as mock_service, + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + ): from models.domain_objects import FixedChunkingStrategy, RagDocument rag_doc = RagDocument( @@ -391,12 +391,12 @@ def test_pipeline_delete_documents_batch(setup_env): for i in range(1, 4) ] - with patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_delete_documents.vs_repo" - ) as mock_vs_repo, patch("repository.pipeline_delete_documents.remove_document_from_vectorstore"), patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ) as mock_job_repo: - + with ( + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.remove_document_from_vectorstore"), + patch("repository.pipeline_delete_documents.ingestion_job_repository") as mock_job_repo, + ): mock_doc_repo.find_by_id.side_effect = rag_docs mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.OPENSEARCH} @@ -434,12 +434,12 @@ def test_pipeline_delete_documents_batch_with_failures(setup_env): chunk_strategy=FixedChunkingStrategy(size=1000, overlap=100), ) - with patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, patch( - "repository.pipeline_delete_documents.vs_repo" - ) as mock_vs_repo, patch("repository.pipeline_delete_documents.remove_document_from_vectorstore"), patch( - "repository.pipeline_delete_documents.ingestion_job_repository" - ) as mock_job_repo: - + with ( + patch("repository.pipeline_delete_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_delete_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_delete_documents.remove_document_from_vectorstore"), + patch("repository.pipeline_delete_documents.ingestion_job_repository") as mock_job_repo, + ): # First succeeds, second fails, third succeeds 
mock_doc_repo.find_by_id.side_effect = [rag_doc1, Exception("Delete failed"), rag_doc1] mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.OPENSEARCH} diff --git a/test/lambda/test_pipeline_ingest_documents.py b/test/lambda/test_pipeline_ingest_documents.py index f12b6b578..b917adb2f 100644 --- a/test/lambda/test_pipeline_ingest_documents.py +++ b/test/lambda/test_pipeline_ingest_documents.py @@ -128,9 +128,11 @@ def test_store_chunks_in_vectorstore(setup_env): mock_service = Mock() mock_service.get_vector_store_client.return_value = mock_vs - with patch("repository.pipeline_ingest_documents.RagEmbeddings"), patch( - "repository.pipeline_ingest_documents.VectorStoreRepository" - ) as mock_vs_repo, patch("repository.pipeline_ingest_documents.RepositoryServiceFactory") as mock_factory: + with ( + patch("repository.pipeline_ingest_documents.RagEmbeddings"), + patch("repository.pipeline_ingest_documents.VectorStoreRepository") as mock_vs_repo, + patch("repository.pipeline_ingest_documents.RepositoryServiceFactory") as mock_factory, + ): mock_vs_repo.return_value.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_factory.create_service.return_value = mock_service @@ -153,9 +155,11 @@ def test_store_chunks_in_vectorstore_failure(setup_env): mock_service = Mock() mock_service.get_vector_store_client.return_value = mock_vs - with patch("repository.pipeline_ingest_documents.RagEmbeddings"), patch( - "repository.pipeline_ingest_documents.VectorStoreRepository" - ) as mock_vs_repo, patch("repository.pipeline_ingest_documents.RepositoryServiceFactory") as mock_factory: + with ( + patch("repository.pipeline_ingest_documents.RagEmbeddings"), + patch("repository.pipeline_ingest_documents.VectorStoreRepository") as mock_vs_repo, + patch("repository.pipeline_ingest_documents.RepositoryServiceFactory") as mock_factory, + ): mock_vs_repo.return_value.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_factory.create_service.return_value = mock_service @@ -185,10 +189,11 @@ def test_pipeline_ingest_bedrock_kb(setup_env): }, } - with patch("repository.pipeline_ingest_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_documents.rag_document_repository" - ) as mock_doc_repo, patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo: - + with ( + patch("repository.pipeline_ingest_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, + ): mock_vs_repo.find_repository_by_id.return_value = bedrock_kb_repo mock_doc_repo.find_by_source.return_value = [] @@ -225,14 +230,13 @@ def test_pipeline_ingest_bedrock_kb_copy_from_lisa_bucket(setup_env): }, } - with patch("repository.pipeline_ingest_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_documents.rag_document_repository" - ) as mock_doc_repo, patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingest_documents.s3" - ) as mock_s3, patch( - "repository.pipeline_ingest_documents.bedrock_agent" - ) as mock_bedrock_agent: - + with ( + patch("repository.pipeline_ingest_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_ingest_documents.ingestion_job_repository") as 
mock_job_repo, + patch("repository.pipeline_ingest_documents.s3") as mock_s3, + patch("repository.pipeline_ingest_documents.bedrock_agent") as mock_bedrock_agent, + ): mock_vs_repo.find_repository_by_id.return_value = bedrock_kb_repo mock_doc_repo.find_by_source.return_value = [] @@ -296,18 +300,15 @@ def test_pipeline_ingest_with_previous_document(setup_env): document_id="prev-doc", ) - with patch("repository.pipeline_ingest_documents.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_documents.generate_chunks" - ) as mock_chunks, patch("repository.pipeline_ingest_documents.prepare_chunks") as mock_prepare, patch( - "repository.pipeline_ingest_documents.store_chunks_in_vectorstore" - ) as mock_store, patch( - "repository.pipeline_ingest_documents.rag_document_repository" - ) as mock_doc_repo, patch( - "repository.pipeline_ingest_documents.ingestion_job_repository" - ) as mock_job_repo, patch( - "repository.pipeline_ingest_documents.remove_document_from_vectorstore" + with ( + patch("repository.pipeline_ingest_documents.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_documents.generate_chunks") as mock_chunks, + patch("repository.pipeline_ingest_documents.prepare_chunks") as mock_prepare, + patch("repository.pipeline_ingest_documents.store_chunks_in_vectorstore") as mock_store, + patch("repository.pipeline_ingest_documents.rag_document_repository") as mock_doc_repo, + patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_documents.remove_document_from_vectorstore"), ): - mock_vs_repo.find_repository_by_id.return_value = {"type": RepositoryType.OPENSEARCH} mock_chunks.return_value = [Mock(page_content="text", metadata={})] mock_prepare.return_value = (["text"], [{"key": "value"}]) @@ -335,14 +336,12 @@ def test_handle_pipeline_ingest_event(setup_env): "requestContext": {"authorizer": {"username": "user1"}}, } - with patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_handlers.collection_service" - ) as mock_coll_service, patch( - "repository.pipeline_ingest_handlers.ingestion_job_repository" - ) as mock_job_repo, patch( - "repository.pipeline_ingest_handlers.ingestion_service" - ) as mock_service: - + with ( + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_handlers.collection_service") as mock_coll_service, + patch("repository.pipeline_ingest_handlers.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_handlers.ingestion_service") as mock_service, + ): mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1"} mock_coll_service.get_collection_metadata.return_value = {} @@ -369,14 +368,13 @@ def test_handle_pipline_ingest_schedule(setup_env): now = datetime.now(timezone.utc) recent = now - timedelta(hours=12) - with patch("repository.pipeline_ingest_handlers.s3") as mock_s3, patch( - "repository.pipeline_ingest_handlers.vs_repo" - ) as mock_vs_repo, patch("repository.pipeline_ingest_handlers.collection_service") as mock_coll_service, patch( - "repository.pipeline_ingest_handlers.ingestion_job_repository" - ) as mock_job_repo, patch( - "repository.pipeline_ingest_handlers.ingestion_service" - ) as mock_service: - + with ( + patch("repository.pipeline_ingest_handlers.s3") as mock_s3, + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_handlers.collection_service") as 
mock_coll_service, + patch("repository.pipeline_ingest_handlers.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_handlers.ingestion_service") as mock_service, + ): mock_paginator = Mock() mock_paginator.paginate.return_value = [ { @@ -411,10 +409,11 @@ def test_handle_pipline_ingest_schedule_no_contents(setup_env): "requestContext": {"authorizer": {"username": "user1"}}, } - with patch("repository.pipeline_ingest_handlers.s3") as mock_s3, patch( - "repository.pipeline_ingest_handlers.vs_repo" - ) as mock_vs_repo, patch("repository.pipeline_ingest_handlers.collection_service") as mock_coll_service: - + with ( + patch("repository.pipeline_ingest_handlers.s3") as mock_s3, + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_handlers.collection_service") as mock_coll_service, + ): mock_paginator = Mock() mock_paginator.paginate.return_value = [{}] # No Contents key mock_s3.get_paginator.return_value = mock_paginator @@ -446,9 +445,11 @@ def test_remove_document_from_vectorstore(setup_env): mock_service = Mock() mock_service.get_vector_store_client.return_value = mock_vs - with patch("repository.pipeline_ingest_documents.RagEmbeddings"), patch( - "repository.pipeline_ingest_documents.VectorStoreRepository" - ) as mock_vs_repo, patch("repository.pipeline_ingest_documents.RepositoryServiceFactory") as mock_factory: + with ( + patch("repository.pipeline_ingest_documents.RagEmbeddings"), + patch("repository.pipeline_ingest_documents.VectorStoreRepository") as mock_vs_repo, + patch("repository.pipeline_ingest_documents.RepositoryServiceFactory") as mock_factory, + ): mock_vs_repo.return_value.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_factory.create_service.return_value = mock_service @@ -474,10 +475,10 @@ def test_pipeline_ingest_documents_batch(setup_env): s3_paths=["s3://bucket/key1", "s3://bucket/key2", "s3://bucket/key3"], ) - with patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingest_documents.pipeline_ingest_document" - ) as mock_ingest_doc: - + with ( + patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_documents.pipeline_ingest_document") as mock_ingest_doc, + ): from repository.pipeline_ingest_documents import pipeline_ingest_documents pipeline_ingest_documents(job) @@ -501,10 +502,10 @@ def test_pipeline_ingest_documents_batch_with_failures(setup_env): s3_paths=["s3://bucket/key1", "s3://bucket/key2", "s3://bucket/key3"], ) - with patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingest_documents.pipeline_ingest_document" - ) as mock_ingest_doc: - + with ( + patch("repository.pipeline_ingest_documents.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_documents.pipeline_ingest_document") as mock_ingest_doc, + ): # First succeeds, second fails, third succeeds mock_ingest_doc.side_effect = [None, Exception("Ingest failed"), None] @@ -610,7 +611,9 @@ def test_pipeline_ingest_routes_to_single_ingestion(setup_env): def test_handle_pipeline_ingest_event_resolves_collection_by_name(setup_env): """Pipeline collectionId name string must be resolved to UUID before job creation. - The job must be created with the UUID, not the name string.""" + + The job must be created with the UUID, not the name string. 
+ """ event = { "detail": { "bucket": "test-bucket", @@ -632,14 +635,12 @@ def test_handle_pipeline_ingest_event_resolves_collection_by_name(setup_env): "metadata": None, } - with patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, patch( - "repository.pipeline_ingest_handlers.collection_service" - ) as mock_coll_service, patch( - "repository.pipeline_ingest_handlers.ingestion_job_repository" - ) as mock_job_repo, patch( - "repository.pipeline_ingest_handlers.ingestion_service" - ) as mock_service: - + with ( + patch("repository.pipeline_ingest_handlers.vs_repo") as mock_vs_repo, + patch("repository.pipeline_ingest_handlers.collection_service") as mock_coll_service, + patch("repository.pipeline_ingest_handlers.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingest_handlers.ingestion_service") as mock_service, + ): mock_vs_repo.find_repository_by_id.return_value = {"repositoryId": "repo1", "type": "opensearch"} mock_coll_service.collection_repo.find_by_id_or_name.return_value = mock_collection diff --git a/test/lambda/test_pipeline_ingestion.py b/test/lambda/test_pipeline_ingestion.py index f70bee4a8..37f6447f9 100644 --- a/test/lambda/test_pipeline_ingestion.py +++ b/test/lambda/test_pipeline_ingestion.py @@ -52,10 +52,10 @@ def sample_ingestion_job(): def test_ingest_success(sample_ingestion_job): """Test successful ingest function.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingestion.pipeline_ingest" - ) as mock_pipeline_ingest: - + with ( + patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingestion.pipeline_ingest") as mock_pipeline_ingest, + ): # Setup mocks mock_job_repo.update_status.return_value = sample_ingestion_job mock_pipeline_ingest.return_value = None @@ -73,10 +73,10 @@ def test_ingest_success(sample_ingestion_job): def test_ingest_error(sample_ingestion_job): """Test ingest function with error.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingestion.pipeline_ingest" - ) as mock_pipeline_ingest: - + with ( + patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingestion.pipeline_ingest") as mock_pipeline_ingest, + ): # Setup mocks mock_job_repo.update_status.return_value = sample_ingestion_job mock_pipeline_ingest.side_effect = Exception("Test error") @@ -94,10 +94,10 @@ def test_ingest_error(sample_ingestion_job): def test_delete_success(sample_ingestion_job): """Test successful delete function.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingestion.pipeline_delete" - ) as mock_pipeline_delete: - + with ( + patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingestion.pipeline_delete") as mock_pipeline_delete, + ): # Setup mocks mock_job_repo.update_status.return_value = sample_ingestion_job mock_pipeline_delete.return_value = None @@ -115,10 +115,10 @@ def test_delete_success(sample_ingestion_job): def test_delete_error(sample_ingestion_job): """Test delete function with error.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingestion.pipeline_delete" - ) as mock_pipeline_delete: - + with ( + 
patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingestion.pipeline_delete") as mock_pipeline_delete, + ): # Setup mocks mock_job_repo.update_status.return_value = sample_ingestion_job mock_pipeline_delete.side_effect = Exception("Test error") @@ -136,10 +136,11 @@ def test_delete_error(sample_ingestion_job): def test_main_ingest_action(sample_ingestion_job): """Test main function with ingest action.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingestion.ingest" - ) as mock_ingest, patch("sys.argv", ["pipeline_ingestion.py", "ingest", "test-job-id"]): - + with ( + patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingestion.ingest") as mock_ingest, + patch("sys.argv", ["pipeline_ingestion.py", "ingest", "test-job-id"]), + ): # Setup mocks mock_job_repo.find_by_id.return_value = sample_ingestion_job mock_ingest.return_value = None @@ -162,10 +163,11 @@ def test_main_ingest_action(sample_ingestion_job): def test_main_delete_action(sample_ingestion_job): """Test main function with delete action.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, patch( - "repository.pipeline_ingestion.delete" - ) as mock_delete, patch("sys.argv", ["pipeline_ingestion.py", "delete", "test-job-id"]): - + with ( + patch("repository.pipeline_ingestion.ingestion_job_repository") as mock_job_repo, + patch("repository.pipeline_ingestion.delete") as mock_delete, + patch("sys.argv", ["pipeline_ingestion.py", "delete", "test-job-id"]), + ): # Setup mocks mock_job_repo.find_by_id.return_value = sample_ingestion_job mock_delete.return_value = None @@ -188,14 +190,13 @@ def test_main_delete_action(sample_ingestion_job): def test_main_invalid_action(sample_ingestion_job): """Test main function with invalid action.""" - with patch("repository.pipeline_ingestion.ingestion_job_repository"), patch( - "repository.pipeline_ingestion.ingest" - ), patch("repository.pipeline_ingestion.delete"), patch( - "sys.argv", ["pipeline_ingestion.py", "invalid", "test-job-id"] - ), patch( - "sys.exit" - ) as mock_exit: - + with ( + patch("repository.pipeline_ingestion.ingestion_job_repository"), + patch("repository.pipeline_ingestion.ingest"), + patch("repository.pipeline_ingestion.delete"), + patch("sys.argv", ["pipeline_ingestion.py", "invalid", "test-job-id"]), + patch("sys.exit") as mock_exit, + ): # Import the module # Simulate the main function logic - should not call any functions for invalid action import sys @@ -219,7 +220,6 @@ def test_main_invalid_action(sample_ingestion_job): def test_main_missing_arguments(sample_ingestion_job): """Test main function with missing arguments.""" with patch("sys.argv", ["pipeline_ingestion.py"]), patch("sys.exit") as mock_exit: - # Import the module # Simulate the main function logic - should exit when not enough arguments import sys diff --git a/test/lambda/test_projects_lambda.py b/test/lambda/test_projects_lambda.py index dbbc66bdd..ab0c22cdd 100644 --- a/test/lambda/test_projects_lambda.py +++ b/test/lambda/test_projects_lambda.py @@ -852,9 +852,10 @@ def test_delete_project_cascade_deletes_sessions(projects_table, sessions_table, {"sessionId": "sess-1", "userId": "test-user", "projectId": "proj-1"}, {"sessionId": "sess-2", "userId": "test-user", "projectId": "proj-1"}, ] - with patch("projects.lambda_functions.get_all_user_sessions", 
return_value=project_sessions), patch( - "projects.lambda_functions.delete_user_session" - ) as mock_delete_session: + with ( + patch("projects.lambda_functions.get_all_user_sessions", return_value=project_sessions), + patch("projects.lambda_functions.delete_user_session") as mock_delete_session, + ): delete_project(_delete_event(body={"deleteSessions": True}), lambda_context) assert mock_delete_session.call_count == 2 diff --git a/test/lambda/test_prompt_templates_lambda.py b/test/lambda/test_prompt_templates_lambda.py index 6ced981df..2692d9b25 100644 --- a/test/lambda/test_prompt_templates_lambda.py +++ b/test/lambda/test_prompt_templates_lambda.py @@ -50,8 +50,8 @@ def mock_api_wrapper(func): """Mock API wrapper that handles both success and error cases for testing. - For successful function calls, it wraps the result in an HTTP response format. - For error cases, it returns an appropriate error response with proper status code. + For successful function calls, it wraps the result in an HTTP response format. For error cases, it returns an + appropriate error response with proper status code. """ def wrapper(event, context): diff --git a/test/lambda/test_rag_admin_repository.py b/test/lambda/test_rag_admin_repository.py index 6d975ce92..b5af634c6 100644 --- a/test/lambda/test_rag_admin_repository.py +++ b/test/lambda/test_rag_admin_repository.py @@ -14,14 +14,13 @@ """Tests for RAG Admin authorization boundaries in repository lambda functions. -Uses _auth_context() to patch auth references directly on repository.lambda_functions. -This is necessary because the module uses `from utilities.auth import ...` which creates -local bindings that conftest's patches on utilities.auth do not reach. - -Note: The conftest patches decorators (admin_only, rag_admin_or_admin) as passthroughs -when test_repository_lambda.py runs first (module-level import). So these tests focus on -the inner function logic: group access filtering, effective_admin, field restrictions, -and document ownership bypass. +Uses _auth_context() to patch auth references directly on repository.lambda_functions. This is necessary because the +module uses `from utilities.auth import ...` which creates local bindings that conftest's patches on utilities.auth do +not reach. + +Note: The conftest patches decorators (admin_only, rag_admin_or_admin) as passthroughs when test_repository_lambda.py +runs first (module-level import). So these tests focus on the inner function logic: group access filtering, +effective_admin, field restrictions, and document ownership bypass. """ import json @@ -88,8 +87,8 @@ def _make_event(username="test-user", groups=None): def _auth_context(username, groups, is_admin_val=False, is_rag_admin_val=False): """Patch all auth references on repository.lambda_functions for a test. - Because repository.lambda_functions uses `from utilities.auth import ...`, - the module has local bindings that must be patched directly. + Because repository.lambda_functions uses `from utilities.auth import ...`, the module has local bindings that must + be patched directly. 
""" stack = ExitStack() for p in [ @@ -111,9 +110,11 @@ def test_rag_admin_can_create_collection_on_accessible_repo(ctx): event["pathParameters"] = {"repositoryId": "repo-1"} event["body"] = json.dumps({"name": "New Collection", "embeddingModel": "model-1"}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs, patch("repository.lambda_functions.collection_service") as mcs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + patch("repository.lambda_functions.collection_service") as mcs, + ): mvs.find_repository_by_id.return_value = ACCESSIBLE_REPO mock_coll = MagicMock() mock_coll.model_dump.return_value = {"collectionId": "new-coll", "name": "New Collection"} @@ -134,9 +135,10 @@ def test_rag_admin_cannot_create_collection_on_inaccessible_repo(ctx): event["pathParameters"] = {"repositoryId": "repo-2"} event["body"] = json.dumps({"name": "New Collection", "embeddingModel": "model-2"}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = INACCESSIBLE_REPO from repository.lambda_functions import create_collection @@ -151,9 +153,11 @@ def test_rag_admin_can_update_collection_on_accessible_repo(ctx): event["pathParameters"] = {"repositoryId": "repo-1", "collectionId": "coll-1"} event["body"] = json.dumps({"name": "Updated Collection"}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs, patch("repository.lambda_functions.collection_service") as mcs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + patch("repository.lambda_functions.collection_service") as mcs, + ): mvs.find_repository_by_id.return_value = ACCESSIBLE_REPO mock_coll = MagicMock() mock_coll.model_dump.return_value = {"collectionId": "coll-1", "name": "Updated Collection"} @@ -173,9 +177,11 @@ def test_rag_admin_can_delete_collection_on_accessible_repo(ctx): event["pathParameters"] = {"repositoryId": "repo-1", "collectionId": "coll-1"} event["queryStringParameters"] = {} - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs, patch("repository.lambda_functions.collection_service") as mcs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + patch("repository.lambda_functions.collection_service") as mcs, + ): mvs.find_repository_by_id.return_value = ACCESSIBLE_REPO mcs.delete_collection.return_value = {"deleted": True} @@ -197,8 +203,8 @@ def test_rag_admin_can_delete_collection_on_accessible_repo(ctx): def test_effective_admin_passed_to_collection_service(ctx, is_rag_admin_val, expected): """Verify effective_admin (is_admin OR is_rag_admin) is passed to collection_service. - Call-arg inspection is necessary here because collection_service is always mocked - at this layer — it's an external dependency boundary. 
+ Call-arg inspection is necessary here because collection_service is always mocked at this layer — it's an external + dependency boundary. """ username = "rag-admin-user" if is_rag_admin_val else "regular-user" groups = ["rag-team", "rag-admins"] if is_rag_admin_val else ["rag-team"] @@ -206,9 +212,11 @@ def test_effective_admin_passed_to_collection_service(ctx, is_rag_admin_val, exp event["pathParameters"] = {"repositoryId": "repo-1", "collectionId": "coll-1"} event["body"] = json.dumps({"name": "Updated"}) - with _auth_context(username, groups, is_rag_admin_val=is_rag_admin_val), patch( - "repository.lambda_functions.vs_repo" - ) as mvs, patch("repository.lambda_functions.collection_service") as mcs: + with ( + _auth_context(username, groups, is_rag_admin_val=is_rag_admin_val), + patch("repository.lambda_functions.vs_repo") as mvs, + patch("repository.lambda_functions.collection_service") as mcs, + ): mvs.find_repository_by_id.return_value = ACCESSIBLE_REPO mock_coll = MagicMock() mock_coll.model_dump.return_value = {"collectionId": "coll-1"} @@ -240,9 +248,10 @@ def test_rag_admin_can_update_pipelines_on_accessible_repo(ctx): event["pathParameters"] = {"repositoryId": "repo-1"} event["body"] = json.dumps({"pipelines": new_pipelines}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = {**ACCESSIBLE_REPO, "config": ACCESSIBLE_REPO} mvs.update.return_value = {**ACCESSIBLE_REPO, "pipelines": new_pipelines} @@ -277,11 +286,12 @@ def test_rag_admin_can_add_new_pipeline_to_accessible_repo(ctx): "executionArn": "arn:execution:123", } - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs, patch("repository.lambda_functions.ssm_client") as mock_ssm, patch( - "repository.lambda_functions.step_functions_client" - ) as mock_sf: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + patch("repository.lambda_functions.ssm_client") as mock_ssm, + patch("repository.lambda_functions.step_functions_client") as mock_sf, + ): mvs.find_repository_by_id.return_value = {**ACCESSIBLE_REPO, "config": ACCESSIBLE_REPO} mvs.update.return_value = updated_config mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "arn:test-state-machine"}} @@ -303,9 +313,10 @@ def test_rag_admin_cannot_update_allowed_groups(ctx): event["pathParameters"] = {"repositoryId": "repo-1"} event["body"] = json.dumps({"allowedGroups": ["new-group"]}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = {**ACCESSIBLE_REPO, "config": ACCESSIBLE_REPO} from repository.lambda_functions import update_repository @@ -321,9 +332,10 @@ def test_rag_admin_cannot_update_mixed_fields(ctx): event["pathParameters"] = {"repositoryId": "repo-1"} event["body"] = json.dumps({"pipelines": [], "name": "sneaky-rename"}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], 
is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = {**ACCESSIBLE_REPO, "config": ACCESSIBLE_REPO} from repository.lambda_functions import update_repository @@ -339,9 +351,10 @@ def test_rag_admin_cannot_update_mixed_fields(ctx): def test_rag_admin_sees_only_group_accessible_repos_in_list(ctx): event = _make_event("rag-admin-user", ["rag-team", "rag-admins"]) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.get_registered_repositories.return_value = [ACCESSIBLE_REPO, INACCESSIBLE_REPO] from repository.lambda_functions import list_all @@ -408,9 +421,10 @@ def test_rag_admin_update_bad_body_does_not_500(ctx, body_value, expected_status else: event["body"] = body_value - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = {**ACCESSIBLE_REPO, "config": ACCESSIBLE_REPO} mvs.update.return_value = ACCESSIBLE_REPO @@ -427,9 +441,10 @@ def test_rag_admin_cannot_update_repository_on_inaccessible_repo(ctx): event["pathParameters"] = {"repositoryId": "repo-2"} event["body"] = json.dumps({"pipelines": []}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = INACCESSIBLE_REPO from repository.lambda_functions import update_repository @@ -445,17 +460,17 @@ def test_rag_admin_cannot_update_repository_on_inaccessible_repo(ctx): def test_rag_admin_list_user_collections_passes_is_rag_admin(ctx): """list_user_collections passes is_rag_admin=True for RAG admin callers. - RAG admins get scoped-admin collection access (bypass collection-level - allowedGroups) within repos they have group access to. Repo-level filtering - uses is_admin (real flag), so RAG admins do NOT see all repos — only their - group-accessible ones. is_rag_admin is threaded through to collection filtering. + RAG admins get scoped-admin collection access (bypass collection-level allowedGroups) within repos they have group + access to. Repo-level filtering uses is_admin (real flag), so RAG admins do NOT see all repos — only their group- + accessible ones. is_rag_admin is threaded through to collection filtering. 
""" event = _make_event("rag-admin-user", ["rag-team", "rag-admins"]) event["queryStringParameters"] = {} - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.collection_service" - ) as mcs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.collection_service") as mcs, + ): mcs.list_all_user_collections.return_value = ([], None) from repository.lambda_functions import list_user_collections @@ -489,9 +504,10 @@ def test_rag_admin_can_update_bedrock_knowledge_base_config(ctx): "config": {**ACCESSIBLE_REPO, "type": "bedrock_kb"}, } - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = bedrock_repo mvs.update.return_value = bedrock_repo @@ -508,8 +524,8 @@ def test_rag_admin_can_update_bedrock_knowledge_base_config(ctx): def test_rag_admin_update_filters_serialized_output(ctx): """Defense-in-depth filter strips non-allowed fields from model_dump output. - Even if Pydantic populates default values during serialization, the second - filter (lines 1613-1615) ensures only allowed fields reach the update call. + Even if Pydantic populates default values during serialization, the second filter (lines 1613-1615) ensures only + allowed fields reach the update call. """ event = _make_event("rag-admin-user", ["rag-team", "rag-admins"]) event["pathParameters"] = {"repositoryId": "repo-1"} @@ -525,9 +541,10 @@ def test_rag_admin_update_filters_serialized_output(ctx): ] event["body"] = json.dumps({"pipelines": new_pipelines}) - with _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs: + with ( + _auth_context("rag-admin-user", ["rag-team", "rag-admins"], is_rag_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + ): mvs.find_repository_by_id.return_value = {**ACCESSIBLE_REPO, "config": ACCESSIBLE_REPO} mvs.update.return_value = {**ACCESSIBLE_REPO, "pipelines": new_pipelines} @@ -554,9 +571,11 @@ def test_admin_can_create_collection(ctx): event["pathParameters"] = {"repositoryId": "repo-1"} event["body"] = json.dumps({"name": "New Collection", "embeddingModel": "model-1"}) - with _auth_context("admin-user", ["admin"], is_admin_val=True), patch( - "repository.lambda_functions.vs_repo" - ) as mvs, patch("repository.lambda_functions.collection_service") as mcs: + with ( + _auth_context("admin-user", ["admin"], is_admin_val=True), + patch("repository.lambda_functions.vs_repo") as mvs, + patch("repository.lambda_functions.collection_service") as mcs, + ): mvs.find_repository_by_id.return_value = ACCESSIBLE_REPO mock_coll = MagicMock() mock_coll.model_dump.return_value = {"collectionId": "new-coll", "name": "New Collection"} diff --git a/test/lambda/test_rag_document_repo_lambda.py b/test/lambda/test_rag_document_repo_lambda.py index b51fd6c7e..5af5dac0a 100644 --- a/test/lambda/test_rag_document_repo_lambda.py +++ b/test/lambda/test_rag_document_repo_lambda.py @@ -80,9 +80,10 @@ def mock_vector_store_repo(): @pytest.fixture(autouse=True) def patch_dynamodb_methods(): - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - 
"repository.rag_document_repo.boto3.client" - ) as mock_client: + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_doc_table.put_item.return_value = {} @@ -96,10 +97,11 @@ def patch_dynamodb_methods(): def test_rag_document_repository_init(): """Test RagDocumentRepository initialization.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -119,10 +121,11 @@ def test_rag_document_repository_init(): def test_delete_by_id_success(sample_rag_document): """Test successful deletion by ID.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -157,10 +160,11 @@ def test_delete_by_id_success(sample_rag_document): def test_delete_by_id_no_document(sample_rag_sub_document): """Test deletion by ID when document doesn't exist.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -195,10 +199,11 @@ def test_delete_by_id_no_document(sample_rag_sub_document): def test_delete_by_id_client_error(): """Test deletion by ID with client error.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -232,10 +237,11 @@ def test_delete_by_id_client_error(): def test_save_success(sample_rag_document): """Test successful save operation.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, 
patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -266,10 +272,11 @@ def test_save_success(sample_rag_document): def test_save_client_error(sample_rag_document): """Test save operation with client error.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -301,10 +308,11 @@ def test_save_client_error(sample_rag_document): def test_find_by_id_success(sample_rag_document): """Test successful find by ID.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -332,10 +340,11 @@ def test_find_by_id_success(sample_rag_document): def test_find_by_id_not_found(): """Test find by ID when document doesn't exist.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -362,10 +371,11 @@ def test_find_by_id_not_found(): def test_find_by_source_success(sample_rag_document): """Test find by source with results.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -392,10 +402,11 @@ def test_find_by_source_success(sample_rag_document): def test_find_by_source_no_results(): """Test find by source with no results.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - 
"repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -422,10 +433,11 @@ def test_find_by_source_no_results(): def test_find_by_source_client_error(): """Test find by source with client error.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -451,10 +463,11 @@ def test_find_by_source_client_error(): def test_list_all_with_repository_id_only(sample_rag_document): """Test list_all with repository_id only.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -482,10 +495,11 @@ def test_list_all_with_repository_id_only(sample_rag_document): def test_list_all_with_collection_id(sample_rag_document): """Test list_all with collection_id.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -513,10 +527,11 @@ def test_list_all_with_collection_id(sample_rag_document): def test_list_all_with_pagination(sample_rag_document): """Test list_all with pagination.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -548,12 +563,12 @@ def test_list_all_with_pagination(sample_rag_document): def 
test_list_all_with_join_docs(sample_rag_document): """Test list_all with join_docs=True.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, patch( - "repository.rag_document_repo.RagDocumentRepository._get_subdoc_ids", return_value=["subdoc1"] + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + patch("repository.rag_document_repo.RagDocumentRepository._get_subdoc_ids", return_value=["subdoc1"]), ): - mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -587,10 +602,11 @@ def test_list_all_with_join_docs(sample_rag_document): def test_find_subdocs_by_id_success(sample_rag_sub_document): """Test successful find_subdocs_by_id.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -617,10 +633,11 @@ def test_find_subdocs_by_id_success(sample_rag_sub_document): def test_find_subdocs_by_id_with_pagination(sample_rag_sub_document): """Test find_subdocs_by_id with pagination.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -649,10 +666,11 @@ def test_find_subdocs_by_id_with_pagination(sample_rag_sub_document): def test_delete_s3_object_success(): """Test successful S3 object deletion.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -676,10 +694,11 @@ def test_delete_s3_object_success(): def test_delete_s3_object_client_error(): """Test S3 object deletion with client error.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + 
patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -704,10 +723,11 @@ def test_delete_s3_object_client_error(): def test_delete_s3_docs_manual_ingestion(sample_rag_document): """Test delete_s3_docs with manual ingestion.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -737,10 +757,11 @@ def test_delete_s3_docs_auto_ingestion_with_auto_remove(sample_rag_document): # Create document with auto_remove=True sample_rag_document.ingestion_type = IngestionType.AUTO - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -775,10 +796,11 @@ def test_delete_s3_docs_auto_ingestion_without_auto_remove(sample_rag_document): # Create document with auto_remove=False sample_rag_document.ingestion_type = IngestionType.AUTO - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() @@ -810,10 +832,11 @@ def test_delete_s3_docs_auto_ingestion_without_auto_remove(sample_rag_document): def test_delete_s3_docs_no_pipelines(sample_rag_document): """Test delete_s3_docs with no pipelines.""" - with patch("repository.rag_document_repo.boto3.resource") as mock_resource, patch( - "repository.rag_document_repo.boto3.client" - ) as mock_client, patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class: - + with ( + patch("repository.rag_document_repo.boto3.resource") as mock_resource, + patch("repository.rag_document_repo.boto3.client") as mock_client, + patch("repository.rag_document_repo.VectorStoreRepository") as mock_vs_repo_class, + ): mock_doc_table = MagicMock() mock_subdoc_table = MagicMock() mock_s3_client = MagicMock() diff --git a/test/lambda/test_repository_lambda.py b/test/lambda/test_repository_lambda.py index da08a7764..f9f91680c 100644 --- a/test/lambda/test_repository_lambda.py +++ 
b/test/lambda/test_repository_lambda.py @@ -273,9 +273,8 @@ def mock_boto3_client(*args, **kwargs): def _admin_only_patch_fixture(): """Ensure admin_only and rag_admin_or_admin patches are stopped when this module's tests complete. - The patches must be started at import time so repository.lambda_functions - imports with the mocked decorators. This fixture cleans them up to avoid - leaking into other test modules and order-dependent failures. + The patches must be started at import time so repository.lambda_functions imports with the mocked decorators. This + fixture cleans them up to avoid leaking into other test modules and order-dependent failures. """ yield _admin_only_patch.stop() @@ -381,7 +380,7 @@ def sample_repository(): def test_list_all(): - """Test list_all lambda function""" + """Test list_all lambda function.""" # Create a patched version that returns the expected repository list def mock_list_all_func(event, context): @@ -408,7 +407,7 @@ def mock_list_all_func(event, context): def test_list_status(): - """Test list_status lambda function""" + """Test list_status lambda function.""" # Create a patched version that returns the expected repository status def mock_list_status_func(event, context): @@ -434,7 +433,7 @@ def mock_list_status_func(event, context): def test_similarity_search(): - """Test similarity_search lambda function""" + """Test similarity_search lambda function.""" # Create a patched version that returns the expected search results def mock_similarity_search_func(event, context): @@ -465,7 +464,7 @@ def mock_similarity_search_func(event, context): def test_ingest_documents(): - """Test ingest_documents lambda function""" + """Test ingest_documents lambda function.""" # Create a patched version that returns the expected response def mock_ingest_documents_func(event, context): @@ -498,7 +497,7 @@ def mock_ingest_documents_func(event, context): def test_download_document(): - """Test download_document lambda function""" + """Test download_document lambda function.""" # Create a patched version that returns the expected URL def mock_download_document_func(event, context): @@ -524,7 +523,7 @@ def mock_download_document_func(event, context): def test_list_docs(): - """Test list_docs lambda function""" + """Test list_docs lambda function.""" # Create a patched version that returns the expected document list def mock_list_docs_func(event, context): @@ -553,7 +552,7 @@ def mock_list_docs_func(event, context): def test_delete(): - """Test delete lambda function""" + """Test delete lambda function.""" # Create a patched version that returns the success response def mock_delete_func(event, context): @@ -582,7 +581,7 @@ def mock_delete_func(event, context): def test_delete_documents_by_id(): - """Test delete_documents lambda function by document id""" + """Test delete_documents lambda function by document id.""" # Create a patched version that returns the expected response def mock_delete_documents_func(event, context): @@ -618,7 +617,7 @@ def mock_delete_documents_func(event, context): def test_delete_documents_by_name(): - """Test delete_documents lambda function by document name""" + """Test delete_documents lambda function by document name.""" # Create a patched version that returns the expected response def mock_delete_documents_func(event, context): @@ -648,7 +647,7 @@ def mock_delete_documents_func(event, context): def test_delete_documents_error(): - """Test delete_documents lambda function with no document parameters""" + """Test delete_documents lambda function 
with no document parameters.""" # Create a patched version that raises an exception def mock_delete_documents_func(event, context): @@ -677,7 +676,7 @@ def mock_delete_documents_func(event, context): def test_delete_documents_unauthorized(): - """Test delete_documents lambda function with unauthorized access""" + """Test delete_documents lambda function with unauthorized access.""" # Create a patched version that raises an exception def mock_delete_documents_func(event, context): @@ -709,7 +708,7 @@ def mock_delete_documents_func(event, context): def test_presigned_url(): - """Test presigned_url lambda function""" + """Test presigned_url lambda function.""" # Create test event event = {"requestContext": {"authorizer": {"claims": {"username": "test-user"}}}, "body": "test-key"} @@ -728,7 +727,7 @@ def test_presigned_url(): def test_create(): - """Test create lambda function""" + """Test create lambda function.""" # Create a patched version that returns the success response def mock_create_func(event, context): @@ -759,7 +758,7 @@ def mock_create_func(event, context): def test_delete_legacy(): - """Test delete lambda function with legacy repository""" + """Test delete lambda function with legacy repository.""" # Create a patched version that returns the success response def mock_delete_func(event, context): @@ -788,7 +787,7 @@ def mock_delete_func(event, context): def test_delete_missing_repository_id(): - """Test delete lambda function with missing repository ID""" + """Test delete lambda function with missing repository ID.""" # Create a patched version that raises a ValidationError def mock_delete_func(event, context): @@ -815,7 +814,7 @@ def mock_delete_func(event, context): def test_RagEmbeddings_error(): - """Test error handling in RagEmbeddings function""" + """Test error handling in RagEmbeddings function.""" # Create a patched version of the class that raises an error def mock_RagEmbeddings(model_name, api_key): @@ -829,7 +828,7 @@ def mock_RagEmbeddings(model_name, api_key): def test_similarity_search_forbidden(): - """Test similarity_search with forbidden access""" + """Test similarity_search with forbidden access.""" # Create a patched version that raises a permission error def mock_similarity_search_func(event, context): @@ -861,7 +860,7 @@ def mock_similarity_search_func(event, context): def test_remove_legacy(): - """Test _remove_legacy function""" + """Test _remove_legacy function.""" # Create a patched version of the function def mock_remove_legacy(repository_id): @@ -876,7 +875,7 @@ def mock_remove_legacy(repository_id): def test_pipeline_embeddings_embed_documents_error(): - """Test error handling in LisaOpenAIEmbeddings.embed_documents""" + """Test error handling in RagEmbeddings.embed_documents.""" # Mock the function to raise an exception def mock_embed_documents(docs): @@ -896,7 +895,7 @@ def mock_embed_documents(docs): def test_embeddings_embed_query_error(): - """Test error handling in OpenAIEmbeddings.embed_query""" + """Test error handling in OpenAIEmbeddings.embed_query.""" # Mock the function to raise an exception def mock_embed_query(query): @@ -919,7 +918,7 @@ def mock_embed_query(query): def test_get_repository_unauthorized(): - """Test get_repository with unauthorized access""" + """Test get_repository with unauthorized access.""" # Create a mock function that raises an exception def mock_get_repository(event, repository_id): @@ -942,7 +941,7 @@ def mock_get_repository(event, repository_id): def test_document_ownership_validation(): - """Test document
ownership validation logic""" + """Test document ownership validation logic.""" from models.domain_objects import ChunkingStrategyType, FixedChunkingStrategy, RagDocument # Test case 1: User is admin @@ -1017,8 +1016,7 @@ def test_document_ownership_validation(): def test_validate_model_name(): - """Test validate_model_name function""" - + """Test validate_model_name function.""" # Test valid model name assert validate_model_name("embedding-model") is True @@ -1031,15 +1029,15 @@ def test_validate_model_name(): def test_repository_access_validation(): - """Test get_repository access validation logic""" - + """Test get_repository access validation logic.""" # Test case 1: User is admin - get_repository should return the repository event = { "requestContext": {"authorizer": {"claims": {"username": "admin-user"}, "groups": json.dumps(["admin-group"])}} } - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.is_admin", return_value=True + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.is_admin", return_value=True), ): mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["admin-group"], "status": "active"} # Admin should always have access @@ -1054,8 +1052,9 @@ def test_repository_access_validation(): "requestContext": {"authorizer": {"claims": {"username": "test-user"}, "groups": json.dumps(["test-group"])}} } - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.is_admin", return_value=False + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.is_admin", return_value=False), ): mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} # User has the right group @@ -1067,9 +1066,11 @@ def test_repository_access_validation(): "requestContext": {"authorizer": {"claims": {"username": "test-user"}, "groups": json.dumps(["wrong-group"])}} } - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.is_admin", return_value=False - ), patch("repository.lambda_functions.get_groups", return_value=["wrong-group"]): + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.is_admin", return_value=False), + patch("repository.lambda_functions.get_groups", return_value=["wrong-group"]), + ): mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} # User doesn't have the right group with pytest.raises(HTTPException) as exc_info: @@ -1081,13 +1082,14 @@ def test_repository_access_validation(): def test_RagEmbeddings_function(): - """Test the RagEmbeddings function""" + """Test the RagEmbeddings function.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings.get_management_key") as mock_key: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings.get_management_key") as mock_key, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" mock_key.return_value = "test-token" @@ -1100,13 +1102,14 @@ def test_RagEmbeddings_function(): def 
test_pipeline_embeddings_init(): - """Test RagEmbeddings initialization""" + """Test RagEmbeddings initialization.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_management_key") as mock_management_key, patch( - "repository.embeddings.get_rest_api_container_endpoint" - ) as mock_endpoint, patch("repository.embeddings.get_cert_path") as mock_cert: - + with ( + patch("repository.embeddings.get_management_key") as mock_management_key, + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + ): mock_management_key.return_value = "test-token" mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1121,7 +1124,7 @@ def test_pipeline_embeddings_init_error(): - """Test LisaOpenAIEmbeddings initialization error handling""" + """Test RagEmbeddings initialization error handling.""" from repository.embeddings import RagEmbeddings with patch("repository.embeddings.ssm_client") as mock_ssm: @@ -1132,13 +1135,14 @@ def test_pipeline_embeddings_embed_documents(): - """Test RagEmbeddings embed_documents method""" + """Test RagEmbeddings embed_documents method.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1161,13 +1165,14 @@ def test_pipeline_embeddings_embed_documents_no_texts(): - """Test RagEmbeddings embed_documents with no texts""" + """Test RagEmbeddings embed_documents with no texts.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings.get_management_key") as mock_key: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings.get_management_key") as mock_key, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" mock_key.return_value = "test-token" @@ -1179,13 +1184,14 @@ def test_pipeline_embeddings_embed_documents_api_error(): - """Test RagEmbeddings embed_documents with API error""" + """Test RagEmbeddings embed_documents with API error.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ):
mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1200,13 +1206,14 @@ def test_pipeline_embeddings_embed_documents_api_error(): def test_pipeline_embeddings_embed_documents_timeout(): - """Test RagEmbeddings embed_documents with timeout""" + """Test RagEmbeddings embed_documents with timeout.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1221,13 +1228,14 @@ def test_pipeline_embeddings_embed_documents_timeout(): def test_pipeline_embeddings_embed_documents_different_formats(): - """Test RagEmbeddings embed_documents with different response formats""" + """Test RagEmbeddings embed_documents with different response formats.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1251,13 +1259,14 @@ def test_pipeline_embeddings_embed_documents_different_formats(): def test_pipeline_embeddings_embed_documents_no_embeddings(): - """Test RagEmbeddings embed_documents with no embeddings in response""" + """Test RagEmbeddings embed_documents with no embeddings in response.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1276,13 +1285,14 @@ def test_pipeline_embeddings_embed_documents_no_embeddings(): def test_pipeline_embeddings_embed_documents_mismatch(): - """Test RagEmbeddings embed_documents with embedding count mismatch""" + """Test RagEmbeddings embed_documents with embedding count mismatch.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ): 
mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1301,13 +1311,14 @@ def test_pipeline_embeddings_embed_documents_mismatch(): def test_pipeline_embeddings_embed_query(): - """Test RagEmbeddings embed_query method""" + """Test RagEmbeddings embed_query method.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings._get_http_session") as mock_get_session: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings._get_http_session") as mock_get_session, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" @@ -1326,13 +1337,14 @@ def test_pipeline_embeddings_embed_query(): def test_pipeline_embeddings_embed_query_invalid(): - """Test RagEmbeddings embed_query with invalid input""" + """Test RagEmbeddings embed_query with invalid input.""" from repository.embeddings import RagEmbeddings - with patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, patch( - "repository.embeddings.get_cert_path" - ) as mock_cert, patch("repository.embeddings.get_management_key") as mock_key: - + with ( + patch("repository.embeddings.get_rest_api_container_endpoint") as mock_endpoint, + patch("repository.embeddings.get_cert_path") as mock_cert, + patch("repository.embeddings.get_management_key") as mock_key, + ): mock_endpoint.return_value = "https://api.example.com" mock_cert.return_value = "/path/to/cert" mock_key.return_value = "test-token" @@ -1347,14 +1359,14 @@ def test_pipeline_embeddings_embed_query_invalid(): def test_real_list_all_function(): - """Test the actual list_all function with real imports""" + """Test the actual list_all function with real imports.""" from repository.lambda_functions import list_all # Mock the vs_repo to return test data - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "utilities.auth.get_groups" - ) as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): mock_get_groups.return_value = ["test-group"] mock_vs_repo.get_registered_repositories.return_value = [ {"name": "Test Repo", "type": "opensearch", "allowedGroups": ["test-group"], "status": "active"} @@ -1376,7 +1388,7 @@ def test_real_list_all_function(): def test_real_list_status_function(): - """Test the actual list_status function with real imports""" + """Test the actual list_status function with real imports.""" from repository.lambda_functions import list_status with patch("repository.lambda_functions.vs_repo") as mock_vs_repo: @@ -1395,15 +1407,15 @@ def test_real_list_status_function(): def test_real_similarity_search_function(): - """Test the actual similarity_search function with real imports""" + """Test the actual similarity_search function with real imports.""" from repository.lambda_functions import similarity_search - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.embeddings.RagEmbeddings" - ) as mock_RagEmbeddings, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.common_functions.get_id_token" - ) as mock_get_token: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + 
patch("repository.embeddings.RagEmbeddings") as mock_RagEmbeddings, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.common_functions.get_id_token") as mock_get_token, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_get_token.return_value = "test-token" @@ -1440,7 +1452,7 @@ def test_real_similarity_search_function(): def test_real_similarity_search_missing_params(): - """Test similarity_search with missing required parameters""" + """Test similarity_search with missing required parameters.""" from repository.lambda_functions import similarity_search with patch("repository.lambda_functions.vs_repo") as mock_vs_repo: @@ -1465,17 +1477,16 @@ def test_real_similarity_search_missing_params(): def test_real_delete_documents_function(): - """Test the actual delete_documents function""" + """Test the actual delete_documents function.""" from repository.lambda_functions import delete_documents - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, patch( - "utilities.auth.is_admin" - ) as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_get_username.return_value = "test-user" @@ -1500,15 +1511,15 @@ def test_real_delete_documents_function(): def test_real_ingest_documents_function(): - """Test the actual ingest_documents function""" + """Test the actual ingest_documents function.""" from repository.lambda_functions import ingest_documents - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.ingestion_service" - ) as mock_ingestion, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.ingestion_service") as mock_ingestion, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as mock_get_username, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_get_username.return_value = "test-user" @@ -1532,19 +1543,17 @@ def test_real_ingest_documents_function(): def test_real_download_document_function(): - """Test the actual download_document function""" + """Test the actual download_document function.""" from repository.lambda_functions import download_document - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("repository.lambda_functions.s3") as mock_s3, patch( - "utilities.auth.get_groups" - ) as mock_get_groups, patch( - "utilities.auth.get_username" - ) as mock_get_username, patch( - "utilities.auth.is_admin" - ) as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("repository.lambda_functions.s3") as mock_s3, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.get_username") as 
mock_get_username, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_get_username.return_value = "test-user" @@ -1578,13 +1587,14 @@ def test_real_download_document_function(): @mock_aws() def test_real_list_docs_function(): - """Test the actual list_docs function""" + """Test the actual list_docs function.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] @@ -1613,13 +1623,14 @@ def test_real_list_docs_function(): @mock_aws() def test_list_docs_with_pagination(): - """Test list_docs function with pagination parameters""" + """Test list_docs function with pagination parameters.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} @@ -1661,13 +1672,14 @@ def test_list_docs_with_pagination(): @mock_aws() def test_list_docs_with_previous_page(): - """Test list_docs function with previous page indicator""" + """Test list_docs function with previous page indicator.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} @@ -1698,13 +1710,14 @@ def test_list_docs_with_previous_page(): @mock_aws() def test_list_docs_with_custom_page_size(): - """Test list_docs function with custom page size""" + """Test list_docs function with custom page size.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} @@ -1732,13 +1745,14 @@ def test_list_docs_with_custom_page_size(): @mock_aws() def test_list_docs_with_edge_case_page_sizes(): - 
"""Test list_docs function with edge case page sizes""" + """Test list_docs function with edge case page sizes.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} @@ -1768,13 +1782,14 @@ def test_list_docs_with_edge_case_page_sizes(): @mock_aws() def test_list_docs_with_encoded_pagination_keys(): - """Test list_docs function with URL-encoded pagination keys""" + """Test list_docs function with URL-encoded pagination keys.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_doc_repo, patch("utilities.auth.get_groups") as mock_get_groups: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.doc_repo") as mock_doc_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} @@ -1807,13 +1822,14 @@ def test_list_docs_with_encoded_pagination_keys(): def test_real_create_function(): - """Test the actual create function""" + """Test the actual create function.""" from repository.lambda_functions import create - with patch("repository.lambda_functions.step_functions_client") as mock_sf, patch( - "repository.lambda_functions.ssm_client" - ) as mock_ssm, patch("utilities.auth.is_admin") as mock_is_admin: - + with ( + patch("repository.lambda_functions.step_functions_client") as mock_sf, + patch("repository.lambda_functions.ssm_client") as mock_ssm, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Setup mocks mock_is_admin.return_value = True mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "test-arn"}} @@ -1833,15 +1849,15 @@ def test_real_create_function(): def test_real_delete_function(): - """Test the actual delete function""" + """Test the actual delete function.""" from repository.lambda_functions import delete - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.step_functions_client" - ) as mock_sf, patch("repository.embeddings.ssm_client") as mock_ssm, patch( - "utilities.auth.is_admin" - ) as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.step_functions_client") as mock_sf, + patch("repository.embeddings.ssm_client") as mock_ssm, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Setup mocks mock_is_admin.return_value = True mock_vs_repo.find_repository_by_id.return_value = {"stackName": "test-stack"} @@ -1862,13 +1878,14 @@ def test_real_delete_function(): def test_real_delete_function_legacy(): - """Test the actual delete function with legacy repository""" + """Test the actual delete function with legacy repository.""" from repository.lambda_functions import delete - with patch("repository.lambda_functions.vs_repo") as 
mock_vs_repo, patch( - "utilities.auth.is_admin" - ) as mock_is_admin, patch("repository.lambda_functions._remove_legacy") as mock_remove_legacy: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("utilities.auth.is_admin") as mock_is_admin, + patch("repository.lambda_functions._remove_legacy") as mock_remove_legacy, + ): # Setup mocks mock_is_admin.return_value = True # Return a legacy repository config instead of None @@ -1890,11 +1907,10 @@ def test_real_delete_function_legacy(): def test_remove_legacy_function(): - """Test the _remove_legacy function""" + """Test the _remove_legacy function.""" from repository.lambda_functions import _remove_legacy with patch("repository.lambda_functions.ssm_client") as mock_ssm: - # Mock SSM to return valid JSON with repositories repositories = [ {"repositoryId": "test-repo", "name": "Test Repo"}, @@ -1912,13 +1928,14 @@ def test_remove_legacy_function(): def test_ensure_repository_access_edge_cases(): """Test repository access validation with edge cases (now handled in get_repository)""" - # Test with missing groups in event - get_groups returns empty list, so user has no access event = {"requestContext": {"authorizer": {"claims": {"username": "test-user"}}}} - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.is_admin", return_value=False - ), patch("repository.lambda_functions.get_groups", return_value=[]): + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.is_admin", return_value=False), + patch("repository.lambda_functions.get_groups", return_value=[]), + ): mock_vs_repo.find_repository_by_id.return_value = {"allowedGroups": ["test-group"], "status": "active"} # get_repository will raise HTTPException because user has no groups (empty list) @@ -1928,8 +1945,7 @@ def test_ensure_repository_access_edge_cases(): def test_ensure_document_ownership_edge_cases(): - """Test _ensure_document_ownership with edge cases""" - + """Test _ensure_document_ownership with edge cases.""" # Test with empty docs list event = {"requestContext": {"authorizer": {"claims": {"username": "test-user"}}}} @@ -1938,7 +1954,7 @@ def test_ensure_document_ownership_edge_cases(): def test_enrich_metadata_with_document_id(): - """Test enrich_metadata_with_document_id function""" + """Test enrich_metadata_with_document_id function.""" from repository.lambda_functions import enrich_metadata_with_document_id with patch("repository.lambda_functions.doc_repo") as mock_doc_repo: @@ -1968,7 +1984,7 @@ def test_enrich_metadata_with_document_id(): def test_enrich_metadata_with_document_id_missing_source(): - """Test enrich_metadata_with_document_id with missing source""" + """Test enrich_metadata_with_document_id with missing source.""" from repository.lambda_functions import enrich_metadata_with_document_id with patch("repository.lambda_functions.doc_repo") as mock_doc_repo: @@ -1987,7 +2003,7 @@ def test_enrich_metadata_with_document_id_missing_source(): def test_enrich_metadata_with_document_id_not_found(): - """Test enrich_metadata_with_document_id when RAG document not found""" + """Test enrich_metadata_with_document_id when RAG document not found.""" from repository.lambda_functions import enrich_metadata_with_document_id with patch("repository.lambda_functions.doc_repo") as mock_doc_repo: @@ -2005,7 +2021,7 @@ def test_enrich_metadata_with_document_id_not_found(): def test_enrich_metadata_with_document_id_exception(): - """Test 
enrich_metadata_with_document_id handles exceptions gracefully""" + """Test enrich_metadata_with_document_id handles exceptions gracefully.""" from repository.lambda_functions import enrich_metadata_with_document_id with patch("repository.lambda_functions.doc_repo") as mock_doc_repo: @@ -2023,17 +2039,16 @@ def test_enrich_metadata_with_document_id_exception(): def test_real_similarity_search_bedrock_kb_function(): - """Test the actual similarity_search function for Bedrock Knowledge Base repositories""" + """Test the actual similarity_search function for Bedrock Knowledge Base repositories.""" from repository.lambda_functions import similarity_search - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.bedrock_client" - ) as mock_bedrock, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service, patch( - "repository.lambda_functions.enrich_metadata_with_document_id" - ) as mock_enrich: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.bedrock_client") as mock_bedrock, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + patch("repository.lambda_functions.enrich_metadata_with_document_id") as mock_enrich, + ): mock_get_groups.return_value = ["test-group"] mock_vs_repo.find_repository_by_id.return_value = { "repositoryId": "test-repo", @@ -2094,7 +2109,7 @@ def test_real_similarity_search_bedrock_kb_function(): @mock_aws() def test_list_jobs_function(): - """Test the list_jobs function""" + """Test the list_jobs function.""" from repository.lambda_functions import list_jobs # Override global mocks for this test @@ -2102,16 +2117,14 @@ def test_list_jobs_function(): mock_common.is_admin.return_value = True try: - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.ingestion_job_repository" - ) as mock_job_repo, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.is_admin" - ) as mock_is_admin, patch( - "utilities.auth.get_username" - ) as mock_get_username, patch( - "repository.lambda_functions.get_user_context" - ) as mock_get_user_context: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.ingestion_job_repository") as mock_job_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_username") as mock_get_username, + patch("repository.lambda_functions.get_user_context") as mock_get_user_context, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_is_admin.return_value = True # Admin access required @@ -2197,7 +2210,7 @@ def test_list_jobs_function(): @mock_aws() def test_list_jobs_missing_repository_id(): - """Test list_jobs function with missing repository ID""" + """Test list_jobs function with missing repository ID.""" from repository.lambda_functions import list_jobs with patch("utilities.auth.is_admin") as mock_is_admin: @@ -2220,13 +2233,14 @@ def test_list_jobs_missing_repository_id(): @mock_aws() def test_list_jobs_unauthorized_access(): - """Test list_jobs function with unauthorized access""" + """Test list_jobs function with unauthorized access.""" from repository.lambda_functions import list_jobs - with 
patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "utilities.auth.get_groups" - ) as mock_get_groups, patch("utilities.auth.is_admin") as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Setup mocks - user is not admin and doesn't have group access mock_get_groups.return_value = ["other-group"] mock_is_admin.return_value = False @@ -2254,17 +2268,16 @@ def test_list_jobs_unauthorized_access(): @mock_aws() def test_list_jobs_empty_results(): - """Test list_jobs function with no jobs found""" + """Test list_jobs function with no jobs found.""" from repository.lambda_functions import list_jobs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.ingestion_job_repository" - ) as mock_job_repo, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.is_admin" - ) as mock_is_admin, patch( - "utilities.auth.get_username" - ) as mock_get_username: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.ingestion_job_repository") as mock_job_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_username") as mock_get_username, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_is_admin.return_value = True @@ -2301,15 +2314,15 @@ def test_list_jobs_empty_results(): @mock_aws() def test_list_jobs_malformed_dynamodb_items(): - """Test list_jobs function with error in repository layer""" + """Test list_jobs function with error in repository layer.""" from repository.lambda_functions import list_jobs - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.ingestion_job_repository" - ) as mock_job_repo, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.is_admin" - ) as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.ingestion_job_repository") as mock_job_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_is_admin.return_value = True @@ -2336,7 +2349,7 @@ def test_list_jobs_malformed_dynamodb_items(): @mock_aws() def test_list_jobs_with_pagination(): - """Test list_jobs function with pagination parameters""" + """Test list_jobs function with pagination parameters.""" from repository.lambda_functions import list_jobs # Override global mocks for this test @@ -2344,16 +2357,14 @@ def test_list_jobs_with_pagination(): mock_common.is_admin.return_value = True try: - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.ingestion_job_repository" - ) as mock_job_repo, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.is_admin" - ) as mock_is_admin, patch( - "utilities.auth.get_username" - ) as mock_get_username, patch( - "repository.lambda_functions.get_user_context" - ) as mock_get_user_context: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.ingestion_job_repository") as mock_job_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + 
patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_username") as mock_get_username, + patch("repository.lambda_functions.get_user_context") as mock_get_user_context, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_is_admin.return_value = True @@ -2429,7 +2440,7 @@ def test_list_jobs_with_pagination(): @mock_aws() def test_list_jobs_with_last_evaluated_key(): - """Test list_jobs function with lastEvaluatedKey parameter""" + """Test list_jobs function with lastEvaluatedKey parameter.""" from repository.lambda_functions import list_jobs # Override global mocks for this test @@ -2437,16 +2448,14 @@ def test_list_jobs_with_last_evaluated_key(): mock_common.is_admin.return_value = True try: - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.ingestion_job_repository" - ) as mock_job_repo, patch("utilities.auth.get_groups") as mock_get_groups, patch( - "utilities.auth.is_admin" - ) as mock_is_admin, patch( - "utilities.auth.get_username" - ) as mock_get_username, patch( - "repository.lambda_functions.get_user_context" - ) as mock_get_user_context: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.ingestion_job_repository") as mock_job_repo, + patch("utilities.auth.get_groups") as mock_get_groups, + patch("utilities.auth.is_admin") as mock_is_admin, + patch("utilities.auth.get_username") as mock_get_username, + patch("repository.lambda_functions.get_user_context") as mock_get_user_context, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_is_admin.return_value = True @@ -2522,24 +2531,19 @@ def test_list_jobs_with_last_evaluated_key(): @mock_aws() def test_ingest_documents_with_chunking_override(): - """Test ingest_documents with chunking strategy override""" + """Test ingest_documents with chunking strategy override.""" from models.domain_objects import CollectionStatus, FixedChunkingStrategy, RagCollectionConfig from repository.lambda_functions import ingest_documents - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service, patch( - "repository.lambda_functions.ingestion_job_repository" - ) as mock_ingestion_job_repo, patch( - "repository.lambda_functions.ingestion_service" - ) as mock_ingestion_service, patch( - "repository.lambda_functions.get_groups" - ) as mock_get_groups, patch( - "repository.lambda_functions.get_username" - ) as mock_get_username, patch( - "repository.lambda_functions.is_admin" - ) as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + patch("repository.lambda_functions.ingestion_job_repository") as mock_ingestion_job_repo, + patch("repository.lambda_functions.ingestion_service") as mock_ingestion_service, + patch("repository.lambda_functions.get_groups") as mock_get_groups, + patch("repository.lambda_functions.get_username") as mock_get_username, + patch("repository.lambda_functions.is_admin") as mock_is_admin, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_get_username.return_value = "test-user" @@ -2608,18 +2612,17 @@ def test_ingest_documents_with_chunking_override(): def test_ingest_documents_access_denied(): - """Test ingest_documents with access denied to collection""" + """Test ingest_documents with access denied to collection.""" from 
repository.lambda_functions import ingest_documents from utilities.validation import ValidationError - with patch("repository.lambda_functions.vs_repo") as mock_vs_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service, patch("repository.lambda_functions.get_groups") as mock_get_groups, patch( - "repository.lambda_functions.get_username" - ) as mock_get_username, patch( - "repository.lambda_functions.is_admin" - ) as mock_is_admin: - + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + patch("repository.lambda_functions.get_groups") as mock_get_groups, + patch("repository.lambda_functions.get_username") as mock_get_username, + patch("repository.lambda_functions.is_admin") as mock_is_admin, + ): # Setup mocks mock_get_groups.return_value = ["test-group"] mock_get_username.return_value = "test-user" @@ -2655,11 +2658,12 @@ def test_ingest_documents_access_denied(): def test_get_repository_admin(): - """Test get_repository with admin user""" + """Test get_repository with admin user.""" from repository.lambda_functions import get_repository - with patch("repository.lambda_functions.vs_repo") as mock_repo, patch( - "repository.lambda_functions.is_admin", return_value=True + with ( + patch("repository.lambda_functions.vs_repo") as mock_repo, + patch("repository.lambda_functions.is_admin", return_value=True), ): mock_repo.find_repository_by_id.return_value = {"allowedGroups": ["group1"]} event = {"requestContext": {"authorizer": {"groups": json.dumps(["group2"])}}} @@ -2669,12 +2673,14 @@ def test_get_repository_admin(): def test_get_repository_with_access(): - """Test get_repository with group access""" + """Test get_repository with group access.""" from repository.lambda_functions import get_repository - with patch("repository.lambda_functions.vs_repo") as mock_repo, patch( - "repository.lambda_functions.is_admin", return_value=False - ), patch("repository.lambda_functions.get_groups", return_value=["group1"]): + with ( + patch("repository.lambda_functions.vs_repo") as mock_repo, + patch("repository.lambda_functions.is_admin", return_value=False), + patch("repository.lambda_functions.get_groups", return_value=["group1"]), + ): mock_repo.find_repository_by_id.return_value = {"allowedGroups": ["group1"]} event = {"requestContext": {"authorizer": {"groups": json.dumps(["group1"])}}} @@ -2683,12 +2689,13 @@ def test_get_repository_with_access(): def test_get_repository_no_access(): - """Test get_repository without access""" + """Test get_repository without access.""" from repository.lambda_functions import get_repository from utilities.exceptions import HTTPException - with patch("repository.lambda_functions.vs_repo") as mock_repo, patch( - "repository.lambda_functions.is_admin", return_value=False + with ( + patch("repository.lambda_functions.vs_repo") as mock_repo, + patch("repository.lambda_functions.is_admin", return_value=False), ): mock_repo.find_repository_by_id.return_value = {"allowedGroups": ["group1"]} event = {"requestContext": {"authorizer": {"groups": json.dumps(["group2"])}}} @@ -2698,7 +2705,7 @@ def test_get_repository_no_access(): def test_similarity_search_with_score(): - """Test retrieve_documents with score via service layer""" + """Test retrieve_documents with score via service layer.""" from repository.services.opensearch_repository_service import OpenSearchRepositoryService repository = {"repositoryId": "test-repo", "type": 
"opensearch"} @@ -2720,7 +2727,7 @@ def test_similarity_search_with_score(): def test_similarity_search_without_score(): - """Test retrieve_documents without score via service layer""" + """Test retrieve_documents without score via service layer.""" from repository.services.opensearch_repository_service import OpenSearchRepositoryService repository = {"repositoryId": "test-repo", "type": "opensearch"} @@ -2742,12 +2749,13 @@ def test_similarity_search_without_score(): def test_ensure_document_ownership_admin(): - """Test _ensure_document_ownership with admin""" + """Test _ensure_document_ownership with admin.""" from models.domain_objects import FixedChunkingStrategy, RagDocument from repository.lambda_functions import _ensure_document_ownership - with patch("repository.lambda_functions.get_username", return_value="admin"), patch( - "repository.lambda_functions.is_admin", return_value=True + with ( + patch("repository.lambda_functions.get_username", return_value="admin"), + patch("repository.lambda_functions.is_admin", return_value=True), ): event = {} doc = RagDocument( @@ -2764,12 +2772,13 @@ def test_ensure_document_ownership_admin(): def test_ensure_document_ownership_owner(): - """Test _ensure_document_ownership with owner""" + """Test _ensure_document_ownership with owner.""" from models.domain_objects import FixedChunkingStrategy, RagDocument from repository.lambda_functions import _ensure_document_ownership - with patch("repository.lambda_functions.get_username", return_value="user1"), patch( - "repository.lambda_functions.is_admin", return_value=False + with ( + patch("repository.lambda_functions.get_username", return_value="user1"), + patch("repository.lambda_functions.is_admin", return_value=False), ): event = {} doc = RagDocument( @@ -2786,12 +2795,13 @@ def test_ensure_document_ownership_owner(): def test_ensure_document_ownership_not_owner(): - """Test _ensure_document_ownership without ownership""" + """Test _ensure_document_ownership without ownership.""" from models.domain_objects import FixedChunkingStrategy, RagDocument from repository.lambda_functions import _ensure_document_ownership - with patch("repository.lambda_functions.get_username", return_value="user1"), patch( - "repository.lambda_functions.is_admin", return_value=False + with ( + patch("repository.lambda_functions.get_username", return_value="user1"), + patch("repository.lambda_functions.is_admin", return_value=False), ): event = {} doc = RagDocument( @@ -2809,12 +2819,14 @@ def test_ensure_document_ownership_not_owner(): def test_list_all_with_groups(): - """Test list_all filters by groups""" + """Test list_all filters by groups.""" from repository.lambda_functions import list_all - with patch("repository.lambda_functions.vs_repo") as mock_repo, patch( - "repository.lambda_functions.get_user_context", return_value=("test-user", False, ["group1"]) - ), patch("repository.lambda_functions.is_admin", return_value=False): + with ( + patch("repository.lambda_functions.vs_repo") as mock_repo, + patch("repository.lambda_functions.get_user_context", return_value=("test-user", False, ["group1"])), + patch("repository.lambda_functions.is_admin", return_value=False), + ): mock_repo.get_registered_repositories.return_value = [ {"allowedGroups": ["group1"], "name": "repo1"}, {"allowedGroups": ["group2"], "name": "repo2"}, @@ -2829,11 +2841,12 @@ def test_list_all_with_groups(): def test_list_status_admin(): - """Test list_status requires admin""" + """Test list_status requires admin.""" from repository.lambda_functions 
import list_status - with patch("repository.lambda_functions.vs_repo") as mock_repo, patch( - "repository.lambda_functions.is_admin", return_value=True + with ( + patch("repository.lambda_functions.vs_repo") as mock_repo, + patch("repository.lambda_functions.is_admin", return_value=True), ): mock_repo.get_repository_status.return_value = {"repo1": "active"} event = {} @@ -2844,7 +2857,7 @@ def test_list_status_admin(): def test_get_repository_by_id(): - """Test get_repository_by_id""" + """Test get_repository_by_id.""" from repository.lambda_functions import get_repository_by_id with patch("repository.lambda_functions.get_repository") as mock_get: @@ -2857,7 +2870,7 @@ def test_get_repository_by_id(): def test_get_repository_by_id_missing(): - """Test get_repository_by_id with missing id""" + """Test get_repository_by_id with missing id.""" from repository.lambda_functions import get_repository_by_id event = {"pathParameters": {}} @@ -2868,11 +2881,12 @@ def test_get_repository_by_id_missing(): def test_presigned_url_success(): - """Test presigned_url generation""" + """Test presigned_url generation.""" from repository.lambda_functions import presigned_url - with patch("repository.lambda_functions.s3") as mock_s3, patch( - "repository.lambda_functions.get_username", return_value="user1" + with ( + patch("repository.lambda_functions.s3") as mock_s3, + patch("repository.lambda_functions.get_username", return_value="user1"), ): mock_s3.generate_presigned_post.return_value = {"url": "https://test.com", "fields": {}} event = { @@ -2886,12 +2900,13 @@ def test_presigned_url_success(): def test_get_document_success(): - """Test get_document""" + """Test get_document.""" from repository.lambda_functions import get_document - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_repo: + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.doc_repo") as mock_repo, + ): mock_get_repo.return_value = {"repositoryId": "repo1", "allowedGroups": ["users"]} mock_doc = MagicMock() mock_doc.model_dump.return_value = {"documentId": "doc1"} @@ -2908,12 +2923,14 @@ def test_get_document_success(): def test_download_document_success(): - """Test download_document""" + """Test download_document.""" from repository.lambda_functions import download_document - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.doc_repo" - ) as mock_repo, patch("repository.lambda_functions.s3") as mock_s3: + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.doc_repo") as mock_repo, + patch("repository.lambda_functions.s3") as mock_s3, + ): mock_get_repo.return_value = {"repositoryId": "repo1", "allowedGroups": ["users"]} mock_doc = MagicMock() mock_doc.source = "s3://bucket/key" @@ -2931,12 +2948,13 @@ def test_download_document_success(): def test_list_docs_success(): - """Test list_docs""" + """Test list_docs.""" from repository.lambda_functions import list_docs - with patch("repository.lambda_functions.get_repository"), patch( - "repository.lambda_functions.doc_repo" - ) as mock_repo: + with ( + patch("repository.lambda_functions.get_repository"), + patch("repository.lambda_functions.doc_repo") as mock_repo, + ): mock_doc = MagicMock() mock_doc.model_dump.return_value = {"documentId": "doc1"} mock_repo.list_all.return_value = ([mock_doc], None, 1) @@ -2949,7 
+2967,7 @@ def test_list_docs_success(): def test_update_repository_success(): - """Test update_repository""" + """Test update_repository.""" from repository.lambda_functions import update_repository with patch("repository.lambda_functions.vs_repo") as mock_vs: @@ -2964,7 +2982,7 @@ def test_update_repository_success(): def test_update_repository_missing_id(): - """Test update_repository with missing id""" + """Test update_repository with missing id.""" from repository.lambda_functions import update_repository event = {"pathParameters": {}, "body": "{}"} @@ -2975,14 +2993,15 @@ def test_update_repository_missing_id(): def test_update_repository_with_pipeline_change(): - """Test update_repository triggers state machine when pipeline changes""" + """Test update_repository triggers state machine when pipeline changes.""" from repository.lambda_functions import update_repository - with patch("repository.lambda_functions.vs_repo") as mock_vs, patch( - "repository.lambda_functions.ssm_client" - ) as mock_ssm, patch("repository.lambda_functions.step_functions_client") as mock_sf, patch( - "utilities.auth.is_admin" - ) as mock_is_admin: + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs, + patch("repository.lambda_functions.ssm_client") as mock_ssm, + patch("repository.lambda_functions.step_functions_client") as mock_sf, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Mock admin access mock_is_admin.return_value = True @@ -3085,12 +3104,14 @@ def test_update_repository_with_pipeline_change(): def test_update_repository_without_pipeline_change(): - """Test update_repository does not trigger state machine when pipeline unchanged""" + """Test update_repository does not trigger state machine when pipeline unchanged.""" from repository.lambda_functions import update_repository - with patch("repository.lambda_functions.vs_repo") as mock_vs, patch( - "repository.lambda_functions.step_functions_client" - ) as mock_sf, patch("utilities.auth.is_admin") as mock_is_admin: + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs, + patch("repository.lambda_functions.step_functions_client") as mock_sf, + patch("utilities.auth.is_admin") as mock_is_admin, + ): # Mock admin access mock_is_admin.return_value = True @@ -3134,12 +3155,14 @@ def test_update_repository_without_pipeline_change(): def test_create_success(): - """Test create repository""" + """Test create repository.""" from repository.lambda_functions import create - with patch("repository.lambda_functions.ssm_client") as mock_ssm, patch( - "repository.lambda_functions.step_functions_client" - ) as mock_sf, patch("utilities.auth.is_admin") as mock_is_admin: + with ( + patch("repository.lambda_functions.ssm_client") as mock_ssm, + patch("repository.lambda_functions.step_functions_client") as mock_sf, + patch("utilities.auth.is_admin") as mock_is_admin, + ): mock_is_admin.return_value = True mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "arn:test"}} mock_sf.start_execution.return_value = {"executionArn": "arn:execution"} @@ -3155,12 +3178,14 @@ def test_create_success(): def test_delete_legacy_repository(): - """Test delete with legacy repository""" + """Test delete with legacy repository.""" from repository.lambda_functions import delete - with patch("repository.lambda_functions.vs_repo") as mock_vs, patch( - "repository.lambda_functions._remove_legacy" - ), patch("repository.lambda_functions.collection_service") as mock_coll: + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs, + 
patch("repository.lambda_functions._remove_legacy"), + patch("repository.lambda_functions.collection_service") as mock_coll, + ): mock_vs.find_repository_by_id.return_value = {"legacy": True, "repositoryId": "repo1"} mock_coll.list_collections.return_value = MagicMock(collections=[]) @@ -3173,14 +3198,15 @@ def test_delete_legacy_repository(): def test_delete_non_legacy_repository(): - """Test delete with non-legacy repository""" + """Test delete with non-legacy repository.""" from repository.lambda_functions import delete - with patch("repository.lambda_functions.vs_repo") as mock_vs, patch( - "repository.lambda_functions.ssm_client" - ) as mock_ssm, patch("repository.lambda_functions.step_functions_client") as mock_sf, patch( - "repository.lambda_functions.collection_service" - ) as mock_coll: + with ( + patch("repository.lambda_functions.vs_repo") as mock_vs, + patch("repository.lambda_functions.ssm_client") as mock_ssm, + patch("repository.lambda_functions.step_functions_client") as mock_sf, + patch("repository.lambda_functions.collection_service") as mock_coll, + ): mock_vs.find_repository_by_id.return_value = {"stackName": "test-stack", "repositoryId": "repo1"} mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "arn:test"}} mock_sf.start_execution.return_value = {"executionArn": "arn:execution"} @@ -3195,7 +3221,7 @@ def test_delete_non_legacy_repository(): # Additional coverage tests for repository lambda functions def test_similarity_search_helpers(): - """Test retrieve_documents via service layer""" + """Test retrieve_documents via service layer.""" import os from unittest.mock import MagicMock, patch @@ -3243,8 +3269,7 @@ def lambda_event_user_collections(): def test_list_user_collections_endpoint_success_workflow( lambda_event_user_collections, lambda_context, mock_collection_service_for_lambda ): - """ - Complete API workflow: event → handler → service → response with collections. + """Complete API workflow: event → handler → service → response with collections. Workflow: 1. API Gateway sends event with user context @@ -3288,8 +3313,7 @@ def test_list_user_collections_endpoint_success_workflow( def test_list_user_collections_endpoint_auth_workflow(lambda_context): - """ - Complete auth workflow: missing auth → 401 response. + """Complete auth workflow: missing auth → 401 response. Workflow: 1. API Gateway sends event without auth context @@ -3312,8 +3336,7 @@ def test_list_user_collections_endpoint_auth_workflow(lambda_context): def test_list_user_collections_endpoint_pagination_workflow( lambda_event_user_collections, lambda_context, mock_collection_service_for_lambda ): - """ - Complete pagination workflow: request with token → next page returned. + """Complete pagination workflow: request with token → next page returned. Workflow: 1. API Gateway sends event with pagination token @@ -3357,8 +3380,7 @@ def test_list_user_collections_endpoint_pagination_workflow( def test_list_user_collections_endpoint_filtering_workflow( lambda_event_user_collections, lambda_context, mock_collection_service_for_lambda ): - """ - Complete filtering workflow: filter param → filtered results. + """Complete filtering workflow: filter param → filtered results. Workflow: 1. 
API Gateway sends event with filter parameter @@ -3401,8 +3423,7 @@ def test_list_user_collections_endpoint_filtering_workflow( def test_list_user_collections_endpoint_error_handling_workflow( lambda_event_user_collections, lambda_context, mock_collection_service_for_lambda ): - """ - Complete error handling workflow: service error → 500 response with logging. + """Complete error handling workflow: service error → 500 response with logging. Workflow: 1. API Gateway sends valid event @@ -3439,9 +3460,10 @@ def test_list_bedrock_knowledge_bases_success(): } } - with patch("repository.lambda_functions.list_knowledge_bases") as mock_list_kb, patch( - "repository.lambda_functions.vs_repo" - ) as mock_vs_repo: + with ( + patch("repository.lambda_functions.list_knowledge_bases") as mock_list_kb, + patch("repository.lambda_functions.vs_repo") as mock_vs_repo, + ): mock_kb = MagicMock() mock_kb.knowledgeBaseId = "KB123" mock_kb.name = "Test KB" @@ -3477,9 +3499,10 @@ def test_list_bedrock_data_sources_success(): "pathParameters": {"kbId": "KB123"}, } - with patch("repository.lambda_functions.get_available_data_sources") as mock_get_ds, patch( - "repository.lambda_functions.validate_bedrock_kb_exists" - ) as mock_validate: + with ( + patch("repository.lambda_functions.get_available_data_sources") as mock_get_ds, + patch("repository.lambda_functions.validate_bedrock_kb_exists") as mock_validate, + ): mock_ds = MagicMock() mock_ds.dataSourceId = "DS123" mock_ds.name = "Test DS" @@ -3575,9 +3598,10 @@ def test_update_collection_success(): ), } - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service: + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + ): mock_get_repo.return_value = { "repositoryId": "test-repo", "repositoryType": "opensearch", @@ -3607,9 +3631,10 @@ def test_delete_collection_success(): "pathParameters": {"repositoryId": "test-repo", "collectionId": "test-collection"}, } - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service: + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + ): mock_get_repo.return_value = { "repositoryId": "test-repo", "repositoryType": "opensearch", @@ -3637,9 +3662,10 @@ def test_list_user_collections_success(): "queryStringParameters": {}, } - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service: + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + ): mock_get_repo.return_value = { "repositoryId": "test-repo", "repositoryType": "opensearch", @@ -3679,14 +3705,13 @@ def test_ingest_documents_bedrock_kb_s3_scan(): ), } - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.collection_service" - ) as mock_collection_service, patch("repository.lambda_functions.MetadataGenerator") as mock_metadata_gen, patch( - "repository.lambda_functions.S3MetadataManager" - ) as mock_s3_metadata, patch( - "repository.lambda_functions.ingestion_service" - 
) as mock_ingestion_service, patch( - "repository.lambda_functions.ingestion_job_repository" + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + patch("repository.lambda_functions.MetadataGenerator") as mock_metadata_gen, + patch("repository.lambda_functions.S3MetadataManager") as mock_s3_metadata, + patch("repository.lambda_functions.ingestion_service") as mock_ingestion_service, + patch("repository.lambda_functions.ingestion_job_repository"), ): mock_get_repo.return_value = { "repositoryId": "test-repo", @@ -3785,9 +3810,11 @@ def test_create_repository_bedrock_kb(): ), } - with patch("repository.lambda_functions.build_pipeline_configs_from_kb_config") as mock_build_pipelines, patch( - "repository.lambda_functions.step_functions_client" - ) as mock_sfn, patch("repository.lambda_functions.ssm_client") as mock_ssm: + with ( + patch("repository.lambda_functions.build_pipeline_configs_from_kb_config") as mock_build_pipelines, + patch("repository.lambda_functions.step_functions_client") as mock_sfn, + patch("repository.lambda_functions.ssm_client") as mock_ssm, + ): mock_build_pipelines.return_value = [ { "collectionId": "DS123", @@ -3830,9 +3857,11 @@ def test_similarity_search_bedrock_kb(): }, } - with patch("repository.lambda_functions.get_repository") as mock_get_repo, patch( - "repository.lambda_functions.RepositoryServiceFactory" - ) as mock_factory, patch("repository.lambda_functions.collection_service") as mock_collection_service: + with ( + patch("repository.lambda_functions.get_repository") as mock_get_repo, + patch("repository.lambda_functions.RepositoryServiceFactory") as mock_factory, + patch("repository.lambda_functions.collection_service") as mock_collection_service, + ): mock_get_repo.return_value = { "repositoryId": "test-repo", "type": "bedrock_knowledge_base", diff --git a/test/lambda/test_scheduling_domain_objects.py b/test/lambda/test_scheduling_domain_objects.py index 29335bbd4..a034f8263 100644 --- a/test/lambda/test_scheduling_domain_objects.py +++ b/test/lambda/test_scheduling_domain_objects.py @@ -143,7 +143,6 @@ def test_each_day_schedule(self): def test_invalid_schedule_consistency(self): """Test invalid schedule consistency.""" - # Test RecurringSchedulingConfig requires recurringSchedule with pytest.raises(ValidationError, match="Field required"): RecurringSchedulingConfig(timezone="UTC") diff --git a/test/lambda/test_session_encryption.py b/test/lambda/test_session_encryption.py old mode 100644 new mode 100755 index d2d27b743..4a4926617 --- a/test/lambda/test_session_encryption.py +++ b/test/lambda/test_session_encryption.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Comprehensive tests for session encryption utilities. +"""Comprehensive tests for session encryption utilities. -This test module provides comprehensive coverage for all functions in -lambda/utilities/session_encryption.py, including error conditions, -edge cases, and exception handling paths. +This test module provides comprehensive coverage for all functions in lambda/utilities/session_encryption.py, including +error conditions, edge cases, and exception handling paths. 
""" import base64 diff --git a/test/lambda/test_session_lambda.py b/test/lambda/test_session_lambda.py index bbb4467b0..21317307f 100644 --- a/test/lambda/test_session_lambda.py +++ b/test/lambda/test_session_lambda.py @@ -438,7 +438,6 @@ def test_list_sessions_empty(dynamodb_table, lambda_context): def test_is_session_encryption_enabled_true(config_table, lambda_context): """Test session encryption enabled via global configuration.""" - # Add global configuration entry with encryption enabled config_table.put_item( Item={ @@ -460,7 +459,6 @@ def test_is_session_encryption_enabled_true(config_table, lambda_context): def test_is_session_encryption_enabled_false(config_table, lambda_context): """Test session encryption disabled via global configuration.""" - # Add global configuration entry with encryption disabled config_table.put_item( Item={ @@ -482,7 +480,6 @@ def test_is_session_encryption_enabled_false(config_table, lambda_context): def test_is_session_encryption_enabled_default_fallback(config_table, lambda_context): """Test session encryption defaults to disabled when configuration is missing.""" - # Don't add any configuration entry # Clear cache to ensure fresh result from session.lambda_functions import cache @@ -530,7 +527,6 @@ def test_is_session_encryption_enabled_client_error(mock_config_table, lambda_co @patch("session.lambda_functions.config_table") def test_is_session_encryption_enabled_general_exception(mock_config_table, lambda_context): """Test session encryption with general exception in configuration lookup.""" - # Mock config table to raise general exception mock_config_table.query.side_effect = Exception("General database error") @@ -1015,7 +1011,6 @@ def test_get_session_model_config_update(mock_update_config, dynamodb_table, sam # Image Attachment Tests def test_attach_image_to_session_success(lambda_context): """Test attach_image_to_session with valid image data.""" - # Create a simple base64 encoded image image_data = ( "data:image/png;base64," @@ -1060,7 +1055,6 @@ def test_attach_image_to_session_missing_message(lambda_context): def test_attach_image_to_session_s3_upload_error(lambda_context): """Test attach_image_to_session with S3 upload error.""" - image_data = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB" event = { @@ -1149,7 +1143,6 @@ def test_put_session_encryption_enabled_success( mock_migrate, mock_encryption_enabled, dynamodb_table, config_table, sample_session, lambda_context ): """Test put_session with encryption enabled and successful encryption.""" - # Mock encryption enabled mock_encryption_enabled.return_value = True @@ -1212,7 +1205,6 @@ def test_put_session_sqs_metrics_success( mock_sqs_client, mock_get_user_context, dynamodb_table, config_table, sample_session, lambda_context ): """Test put_session with successful SQS metrics publishing.""" - # Set environment variable for metrics queue os.environ["USAGE_METRICS_QUEUE_NAME"] = "test-metrics-queue" @@ -1236,7 +1228,6 @@ def test_put_session_sqs_metrics_missing_queue( mock_sqs_client, dynamodb_table, config_table, sample_session, lambda_context ): """Test put_session with missing USAGE_METRICS_QUEUE_NAME environment variable.""" - # Remove environment variable if "USAGE_METRICS_QUEUE_NAME" in os.environ: del os.environ["USAGE_METRICS_QUEUE_NAME"] @@ -1260,7 +1251,6 @@ def test_put_session_sqs_metrics_error( mock_sqs_client, mock_get_user_context, dynamodb_table, config_table, sample_session, lambda_context ): """Test put_session with SQS metrics publishing error.""" - # Set 
environment variable for metrics queue os.environ["USAGE_METRICS_QUEUE_NAME"] = "test-metrics-queue" @@ -1286,7 +1276,6 @@ def test_put_session_model_config_update( mock_update_config, dynamodb_table, config_table, sample_session, lambda_context ): """Test put_session with model configuration update.""" - # Mock model config update - return SessionConfigurationModel updated_config = SessionConfigurationModel( selectedModel=SelectedModel(modelId="test-model", features=[{"name": "new-feature", "overview": ""}]) @@ -1446,7 +1435,6 @@ def test_extract_video_s3_keys_string_content(): def test_generate_presigned_video_url_success(mock_s3_client): """Test _generate_presigned_video_url with success.""" mock_s3_client.generate_presigned_url.return_value = "https://presigned-video-url.com" - result = _generate_presigned_video_url("videos/test-video.mp4") assert result == "https://presigned-video-url.com" @@ -1576,7 +1564,7 @@ def test_map_session_strips_merged_context_from_string_message(): "history": [ { "type": "human", - "content": ("Context from document search:\n" "Some retrieved content\n\n" "who is dustin?"), + "content": ("Context from document search:\nSome retrieved content\n\nwho is dustin?"), } ], } diff --git a/test/lambda/test_similarity_functions.py b/test/lambda/test_similarity_functions.py index 8bd00a2c4..ee3760c71 100644 --- a/test/lambda/test_similarity_functions.py +++ b/test/lambda/test_similarity_functions.py @@ -32,7 +32,7 @@ def test_opensearch_retrieve_documents_without_score(): - """Test OpenSearch retrieve_documents without scores""" + """Test OpenSearch retrieve_documents without scores.""" repository = {"repositoryId": "test-repo", "type": "opensearch"} service = OpenSearchRepositoryService(repository) @@ -55,7 +55,7 @@ def test_opensearch_retrieve_documents_without_score(): def test_pgvector_retrieve_documents_with_score(): - """Test PGVector retrieve_documents with score normalization""" + """Test PGVector retrieve_documents with score normalization.""" repository = {"repositoryId": "test-repo", "type": "pgvector"} service = PGVectorRepositoryService(repository) @@ -77,7 +77,7 @@ def test_pgvector_retrieve_documents_with_score(): def test_opensearch_retrieve_documents_with_score(): - """Test OpenSearch retrieve_documents with score""" + """Test OpenSearch retrieve_documents with score.""" repository = {"repositoryId": "test-repo", "type": "opensearch"} service = OpenSearchRepositoryService(repository) diff --git a/test/lambda/test_update_mcp_server_state_machine.py b/test/lambda/test_update_mcp_server_state_machine.py index d358b50a0..4c62f3c82 100644 --- a/test/lambda/test_update_mcp_server_state_machine.py +++ b/test/lambda/test_update_mcp_server_state_machine.py @@ -337,10 +337,14 @@ def test_connections_table_updates(mcp_servers_table, in_service_server, lambda_ # Seed entry connections_table.put_item(Item={"id": "server-inservice", "owner": "lisa:public", "status": "active"}) - with patch("mcp_server.state_machine.update_mcp_server.mcp_servers_table", mcp_servers_table), patch( - "mcp_server.state_machine.update_mcp_server.ssm_client.get_parameter", - return_value={"Parameter": {"Value": "connections-table"}}, - ), patch("mcp_server.state_machine.update_mcp_server.ddbResource", dynamodb): + with ( + patch("mcp_server.state_machine.update_mcp_server.mcp_servers_table", mcp_servers_table), + patch( + "mcp_server.state_machine.update_mcp_server.ssm_client.get_parameter", + return_value={"Parameter": {"Value": "connections-table"}}, + ), + 
patch("mcp_server.state_machine.update_mcp_server.ddbResource", dynamodb), + ): # Disable updates status to inactive handle_job_intake({"server_id": "server-inservice", "update_payload": {"enabled": False}}, lambda_context) row = connections_table.get_item(Key={"id": "server-inservice", "owner": "lisa:public"}).get("Item") diff --git a/test/lambda/test_vector_store_repo.py b/test/lambda/test_vector_store_repo.py index 945e1deeb..ba9d77e54 100644 --- a/test/lambda/test_vector_store_repo.py +++ b/test/lambda/test_vector_store_repo.py @@ -37,7 +37,7 @@ def setup_env(monkeypatch): def test_vector_store_repo_find_by_id(): - """Test vector store repository find by id""" + """Test vector store repository find by id.""" with patch("boto3.resource") as mock_resource: mock_table = Mock() mock_resource.return_value.Table.return_value = mock_table @@ -66,7 +66,7 @@ def test_vector_store_repo_find_by_id(): def test_vector_store_repo_get_registered(): - """Test vector store repository get registered repositories""" + """Test vector store repository get registered repositories.""" with patch("boto3.resource") as mock_resource: mock_table = Mock() mock_resource.return_value.Table.return_value = mock_table @@ -96,7 +96,7 @@ def test_vector_store_repo_get_registered(): def test_vector_store_repo_save(): - """Test vector store repository save""" + """Test vector store repository save.""" with patch("boto3.resource") as mock_resource: mock_table = Mock() mock_resource.return_value.Table.return_value = mock_table @@ -127,7 +127,7 @@ def test_vector_store_repo_save(): def test_vector_store_repo_delete(): - """Test vector store repository delete""" + """Test vector store repository delete.""" with patch("boto3.resource") as mock_resource: mock_table = Mock() mock_resource.return_value.Table.return_value = mock_table @@ -144,7 +144,7 @@ def test_vector_store_repo_delete(): def test_vector_store_repo_get_status(): - """Test vector store repository get status""" + """Test vector store repository get status.""" with patch("boto3.resource") as mock_resource: mock_table = Mock() mock_resource.return_value.Table.return_value = mock_table diff --git a/test/mcp-workbench/conftest.py b/test/mcp-workbench/conftest.py index 26677aa02..9a2e9dccf 100644 --- a/test/mcp-workbench/conftest.py +++ b/test/mcp-workbench/conftest.py @@ -60,23 +60,22 @@ def temp_tools_dir() -> Generator[Path]: @pytest.fixture(scope="function") def sample_function_tool_content() -> str: """Sample function-based tool content.""" - return """ -from mcpworkbench.core.annotations import mcp_tool - -@mcp_tool( - name="echo_test", - description="Echo back the input text for testing", -) -def echo_message(message: str): - return {"echoed": message, "length": len(message)} - -@mcp_tool( - name="add_test", - description="Add two numbers together for testing", -) -async def add_numbers(a: float, b: float): - return {"a": a, "b": b, "sum": a + b} -""" + return """From mcpworkbench.core.annotations import mcp_tool. 
+ + @mcp_tool( + name="echo_test", + description="Echo back the input text for testing", + ) + def echo_message(message: str): + return {"echoed": message, "length": len(message)} + + @mcp_tool( + name="add_test", + description="Add two numbers together for testing", + ) + async def add_numbers(a: float, b: float): + return {"a": a, "b": b, "sum": a + b} + """ @pytest.fixture(scope="function") diff --git a/test/mcp-workbench/test_aws_session_store.py b/test/mcp-workbench/test_aws_session_store.py index 8a9ea50aa..bf7c10ec7 100644 --- a/test/mcp-workbench/test_aws_session_store.py +++ b/test/mcp-workbench/test_aws_session_store.py @@ -83,11 +83,10 @@ def test_get_session_respects_expiration() -> None: @pytest.mark.parametrize("safety_margin_seconds", [0, 30]) def test_get_session_treats_near_expiration_as_expired(safety_margin_seconds: int) -> None: - """ - Ensure that sessions very close to expiration are treated as expired. + """Ensure that sessions very close to expiration are treated as expired. - This gives us a small safety buffer so MCP tools don't start long-running - operations with credentials that are about to expire. + This gives us a small safety buffer so MCP tools don't start long-running operations with credentials that are about + to expire. """ store = InMemoryAwsSessionStore(safety_margin_seconds=safety_margin_seconds) # Expires in 10 seconds; with a 30 second safety margin this should be expired. diff --git a/test/mcp-workbench/test_manual.py b/test/mcp-workbench/test_manual.py old mode 100644 new mode 100755 index bbf380345..bc200f659 --- a/test/mcp-workbench/test_manual.py +++ b/test/mcp-workbench/test_manual.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Test script for MCP Workbench. +"""Test script for MCP Workbench. -This script creates a temporary tools directory, populates it with example tools, -starts the MCP workbench server, and tests the API endpoints. +This script creates a temporary tools directory, populates it with example tools, starts the MCP workbench server, and +tests the API endpoints. """ import subprocess @@ -31,7 +30,6 @@ def create_test_tools(tools_dir: Path): """Create test tools in the given directory.""" - # Create a simple function-based tool function_tool = """ from mcpworkbench.core.annotations import mcp_tool diff --git a/test/python/README.md b/test/python/README.md index 23d308b8b..3393ef196 100644 --- a/test/python/README.md +++ b/test/python/README.md @@ -73,16 +73,20 @@ The test creates the following resources with predictable names for easy identif ### Resource Lifecycle -#### Without --cleanup flag: +#### Without --cleanup flag + Resources remain deployed for manual testing and must be cleaned up through: + - LISA UI (Model Management and Configuration pages) - Running the script again with `--cleanup` - Manual AWS resource deletion -#### With --cleanup flag: +#### With --cleanup flag + Integration-test-scoped resources (models and repositories created by this script) are automatically deleted at the end of the test run. User-created resources are not affected. -#### With --wait flag: +#### With --wait flag + Script monitors resource deployment status and waits up to 30 minutes for each resource to become ready. Useful for validating full deployment pipeline. 
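+
+For example, a full round trip that deploys the resources, waits for them to become ready, and tears them down at the end might look like the following (a minimal sketch; the flags are combined here, and any connection arguments your LISA deployment requires are omitted):
+
+```bash
+python test/python/integration-setup-test.py --wait --cleanup
+```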
### Exit Codes @@ -111,12 +115,14 @@ Script monitors resource deployment status and waits up to 30 minutes for each r The integration setup test includes new SDK functions in the lisa-sdk package: ### Model Management + - `create_bedrock_model()` - Create Bedrock model configurations - `create_self_hosted_model()` - Deploy self-hosted models with full configuration - `delete_model()` - Remove models from LISA - `get_model()` - Retrieve model details and status ### Repository Management + - `create_repository()` - Generic repository creation - `create_pgvector_repository()` - Create PGVector repositories with RDS - `create_opensearch_repository()` - Create OpenSearch repositories with clusters diff --git a/test/python/integration-setup-test.py b/test/python/integration-setup-test.py old mode 100644 new mode 100755 index 8a88f01f1..6451ae05e --- a/test/python/integration-setup-test.py +++ b/test/python/integration-setup-test.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Integration test script that deploys resources to LISA. +"""Integration test script that deploys resources to LISA. + This script creates: - Self-hosted and Bedrock models (textgen and embedding) - PGVector, OpenSearch, and Bedrock Knowledge Base repositories @@ -115,7 +115,7 @@ def wait_for_resource_ready( except Exception as e: print(f" Check failed: {e}") if i < max_iterations - 1: - print(f" Still waiting... ({i+1}/{max_iterations})") + print(f" Still waiting... ({i + 1}/{max_iterations})") time.sleep(15) print(f"✗ Timeout waiting for {resource_type} '{resource_id}' to be ready") return False @@ -888,7 +888,10 @@ def cleanup_integ_repositories(lisa_client: LisaApi) -> None: def cleanup_resources(lisa_client: LisaApi, created_resources: dict[str, list], region: str | None = None) -> None: - """Clean up only integration test resources. Does NOT delete all models/repos.""" + """Clean up only integration test resources. + + Does NOT delete all models/repos. + """ print("\nCleaning up integration test resources...") cleanup_integ_models(lisa_client) cleanup_integ_repositories(lisa_client) diff --git a/test/python/integration_definitions.py b/test/python/integration_definitions.py old mode 100644 new mode 100755 index a4ec3e0f6..fbed8ec84 --- a/test/python/integration_definitions.py +++ b/test/python/integration_definitions.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Resource definitions and deploy lists for the LISA integration setup test. +"""Resource definitions and deploy lists for the LISA integration setup test. -To control what gets deployed, edit the deploy_* lists at the bottom of this file. -To add a new resource, add an entry to the appropriate *_DEFINITIONS dict and -optionally add its key to the corresponding deploy_* list. Feel free to use this as -an example for deploying models. +To control what gets deployed, edit the deploy_* lists at the bottom of this file. To add a new resource, add an entry +to the appropriate *_DEFINITIONS dict and optionally add its key to the corresponding deploy_* list. Feel free to use +this as an example for deploying models. 
""" # --------------------------------------------------------------------------- diff --git a/test/python/list-integ-models.py b/test/python/list-integ-models.py old mode 100644 new mode 100755 diff --git a/test/rest-api/README.md b/test/rest-api/README.md index 1fdfcf212..ca7721135 100644 --- a/test/rest-api/README.md +++ b/test/rest-api/README.md @@ -19,9 +19,11 @@ test/rest-api/ ## Current Test Coverage ### Utils Module (3 tests) ✅ + - **Singleton Decorator**: Single instance creation, state preservation ### Auth Module (22 tests) ✅ + - **AuthHeaders**: Enum values and methods - **Token Extraction**: Authorization and Api-Key headers - **Group Membership**: Simple and nested JWT properties @@ -29,9 +31,11 @@ test/rest-api/ - **User Context**: API users, JWT users, group membership ### Request Utils Module (3 tests) ✅ + - **Stream Exception Handling**: Normal operation, exception handling, error formatting ### Guardrails Module (18 tests) ✅ + - **Model Guardrails**: Retrieval, empty results, error handling - **Applicable Guardrails**: Public/group-specific, deletion markers - **Violation Detection**: Error message parsing @@ -40,10 +44,12 @@ test/rest-api/ - **JSON Responses**: Structure, status codes, metadata ### Metrics Module (12 tests) ✅ + - **Message Extraction**: Simple/array content, RAG context, tool calls - **Metrics Publishing**: Success, error handling, queue configuration, session IDs ### Routes Module (7 tests) ✅ + - **Health Check**: Logic validation for success, missing vars, exceptions - **Router/Middleware/Lifespan/Passthrough**: Placeholder tests (full testing requires complete app with aiobotocore, text_generation, etc.) @@ -108,30 +114,35 @@ The REST API uses relative imports (e.g., `from .utils import ...`) which makes ## Test Categories ### Authentication Tests (`test_auth.py`) + - Token validation (API tokens, management tokens, JWT) - Group membership and access control - User context extraction - Authorization for different user types ### Request Utilities Tests (`test_request_utils.py`) + - Model validation against registered models - Model and validator retrieval from cache/registry - Request preparation and validation - Stream exception handling ### Guardrails Tests (`test_guardrails.py`) + - Guardrail retrieval from DynamoDB - Determining applicable guardrails based on user groups - Violation detection and response extraction - Streaming and JSON response formatting ### Metrics Tests (`test_metrics.py`) + - Message extraction for metrics calculation - Metrics event publishing to SQS - RAG context and tool call detection - Error handling and queue configuration ### Routes Tests (`test_routes.py`) + - Health check endpoint - Router configuration with/without auth - Middleware functionality (request IDs, CORS, errors) diff --git a/test/rest-api/test_guardrails.py b/test/rest-api/test_guardrails.py index d6b3a543b..bb8a97097 100644 --- a/test/rest-api/test_guardrails.py +++ b/test/rest-api/test_guardrails.py @@ -61,7 +61,6 @@ async def test_get_guardrails_success(self, mock_env_vars): mock_dynamodb.Table.return_value = mock_table with patch.dict("os.environ", mock_env_vars), patch("boto3.resource", return_value=mock_dynamodb): - result = await get_model_guardrails("test-model") assert result == mock_guardrails @@ -77,7 +76,6 @@ async def test_get_guardrails_empty(self, mock_env_vars): mock_dynamodb.Table.return_value = mock_table with patch.dict("os.environ", mock_env_vars), patch("boto3.resource", return_value=mock_dynamodb): - result = await 
get_model_guardrails("test-model") assert result == [] @@ -92,7 +90,6 @@ async def test_get_guardrails_error(self, mock_env_vars): mock_dynamodb.Table.return_value = mock_table with patch.dict("os.environ", mock_env_vars), patch("boto3.resource", return_value=mock_dynamodb): - result = await get_model_guardrails("test-model") assert result == [] diff --git a/test/rest-api/test_metrics.py b/test/rest-api/test_metrics.py index b18ac8b56..7d37e91c4 100644 --- a/test/rest-api/test_metrics.py +++ b/test/rest-api/test_metrics.py @@ -166,10 +166,12 @@ def test_publish_metrics_success(self, mock_env_vars, mock_request): mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("test-user", ["users"]) - ), patch("utils.metrics.is_api_user", return_value=True): - + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("test-user", ["users"])), + patch("utils.metrics.is_api_user", return_value=True), + ): publish_metrics_event(mock_request, params, 200) mock_sqs.send_message.assert_called_once() @@ -193,7 +195,6 @@ def test_publish_metrics_non_200_status(self, mock_env_vars, mock_request): mock_sqs = MagicMock() with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs): - publish_metrics_event(mock_request, params, 400) mock_sqs.send_message.assert_not_called() @@ -206,7 +207,6 @@ def test_publish_metrics_no_queue_url(self, mock_env_vars, mock_request): mock_sqs = MagicMock() with patch.dict("os.environ", mock_env_vars, clear=True), patch("utils.metrics.sqs_client", mock_sqs): - publish_metrics_event(mock_request, params, 200) mock_sqs.send_message.assert_not_called() @@ -219,10 +219,11 @@ def test_publish_metrics_error_handling(self, mock_env_vars, mock_request): mock_sqs = MagicMock() mock_sqs.send_message.side_effect = Exception("SQS error") - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("test-user", []) + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("test-user", [])), ): - # Should not raise exception publish_metrics_event(mock_request, params, 200) @@ -233,10 +234,12 @@ def test_publish_metrics_session_id_format(self, mock_env_vars, mock_request): params = {"messages": []} mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("api-user", []) - ), patch("utils.metrics.is_api_user", return_value=True): - + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("api-user", [])), + patch("utils.metrics.is_api_user", return_value=True), + ): publish_metrics_event(mock_request, params, 200) call_args = mock_sqs.send_message.call_args @@ -258,10 +261,12 @@ def test_publish_metrics_with_complex_messages(self, mock_env_vars, mock_request mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("user", ["users"]) - ), patch("utils.metrics.is_api_user", return_value=True): - + with ( + patch.dict("os.environ", mock_env_vars), + 
patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("user", ["users"])), + patch("utils.metrics.is_api_user", return_value=True), + ): publish_metrics_event(mock_request, params, 200) call_args = mock_sqs.send_message.call_args @@ -354,10 +359,12 @@ def test_jwt_user_with_tokens_publishes_token_only_event(self, mock_env_vars, mo params = {"messages": [{"role": "user", "content": "Hello"}], "model": "my-model"} mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("jwt-user", ["users"]) - ), patch("utils.metrics.is_api_user", return_value=False): - + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("jwt-user", ["users"])), + patch("utils.metrics.is_api_user", return_value=False), + ): publish_metrics_event(mock_request, params, 200, prompt_tokens=50, completion_tokens=20) mock_sqs.send_message.assert_called_once() @@ -378,17 +385,19 @@ def test_jwt_user_without_tokens_skips_publish(self, mock_env_vars, mock_request params = {"messages": [{"role": "user", "content": "Hello"}]} mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("jwt-user", ["users"]) - ), patch("utils.metrics.is_api_user", return_value=False): - + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("jwt-user", ["users"])), + patch("utils.metrics.is_api_user", return_value=False), + ): publish_metrics_event(mock_request, params, 200) # no tokens passed mock_sqs.send_message.assert_not_called() def test_tokens_extracted_from_response_body_for_non_streaming(self, mock_env_vars, mock_request): - """When response_body is provided and prompt_tokens is not passed directly, - tokens should be extracted from the response body before publishing. + """When response_body is provided and prompt_tokens is not passed directly, tokens should be extracted from the + response body before publishing. Expected: Published message contains promptTokens/completionTokens from the response body. 
""" @@ -400,10 +409,12 @@ def test_tokens_extracted_from_response_body_for_non_streaming(self, mock_env_va } mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("api-user", []) - ), patch("utils.metrics.is_api_user", return_value=True): - + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("api-user", [])), + patch("utils.metrics.is_api_user", return_value=True), + ): publish_metrics_event(mock_request, params, 200, response_body=response_body) body = json.loads(mock_sqs.send_message.call_args[1]["MessageBody"]) @@ -419,10 +430,12 @@ def test_api_user_publishes_full_event_type(self, mock_env_vars, mock_request): params = {"messages": [{"role": "user", "content": "Hello"}]} mock_sqs = MagicMock() - with patch.dict("os.environ", mock_env_vars), patch("utils.metrics.sqs_client", mock_sqs), patch( - "utils.metrics.get_user_context", return_value=("api-user", []) - ), patch("utils.metrics.is_api_user", return_value=True): - + with ( + patch.dict("os.environ", mock_env_vars), + patch("utils.metrics.sqs_client", mock_sqs), + patch("utils.metrics.get_user_context", return_value=("api-user", [])), + patch("utils.metrics.is_api_user", return_value=True), + ): publish_metrics_event(mock_request, params, 200) body = json.loads(mock_sqs.send_message.call_args[1]["MessageBody"]) diff --git a/test/rest-api/test_rate_limit_middleware.py b/test/rest-api/test_rate_limit_middleware.py index 12071adf2..8eb91d9c3 100644 --- a/test/rest-api/test_rate_limit_middleware.py +++ b/test/rest-api/test_rate_limit_middleware.py @@ -622,7 +622,7 @@ async def test_oidc_user_override(self): @pytest.mark.asyncio async def test_zero_rpm_override_does_not_crash(self): - """rpm=0 should not raise and should return a 429 after burst is consumed.""" + """Rpm=0 should not raise and should return a 429 after burst is consumed.""" mod = self._mod mod.RATE_LIMIT_OVERRIDES = {"token:zero-rpm": {"rpm": 0, "burst": 1}} diff --git a/test/sdk/README.md b/test/sdk/README.md index a2ee3ec98..b82e4c618 100644 --- a/test/sdk/README.md +++ b/test/sdk/README.md @@ -84,6 +84,7 @@ pytest test/sdk --cov=lisa-sdk/lisapy --cov-report=html The test suite provides comprehensive coverage of all SDK operations: ### ModelMixin (11 tests) + - ✅ List models - ✅ List embedding models - ✅ List instance types @@ -95,6 +96,7 @@ The test suite provides comprehensive coverage of all SDK operations: - ✅ Error handling ### RepositoryMixin (10 tests) + - ✅ List repositories - ✅ Create repository - ✅ Create PGVector repository @@ -104,6 +106,7 @@ The test suite provides comprehensive coverage of all SDK operations: - ✅ Error handling ### CollectionMixin (14 tests) + - ✅ Create collection (basic, with chunking, with metadata) - ✅ Get collection - ✅ Update collection (name, description, status) @@ -113,6 +116,7 @@ The test suite provides comprehensive coverage of all SDK operations: - ✅ Error handling ### RagMixin (13 tests) + - ✅ List documents - ✅ Get document by ID - ✅ Delete documents (by IDs, by name) @@ -123,17 +127,20 @@ The test suite provides comprehensive coverage of all SDK operations: - ✅ Error handling ### ConfigMixin (4 tests) + - ✅ Get configs (global, custom scope) - ✅ Empty configs - ✅ Error handling ### SessionMixin (5 tests) + - ✅ List sessions - ✅ Get session by user - ✅ Empty sessions - ✅ Error handling ### DocsMixin (2 tests) 
+
 - ✅ Get API documentation
 - ✅ Error handling

@@ -198,17 +205,20 @@ def test_new_operation(self, lisa_api: LisaApi, api_url: str):
 ## Benefits of This Approach

 ### Fast Execution
+
 - All 57 tests run in ~0.1 seconds
 - No network latency
 - No AWS resource dependencies

 ### Fully Isolated
+
 - No external dependencies
 - No deployed LISA environment required
 - No AWS credentials needed
 - Can run in CI/CD without infrastructure

 ### Comprehensive Coverage
+
 - Tests all SDK methods
 - Tests error handling
 - Tests request formatting
@@ -216,6 +226,7 @@ def test_new_operation(self, lisa_api: LisaApi, api_url: str):
 - Tests query parameters and request bodies

 ### Easy to Maintain
+
 - Clear test structure
 - Reusable fixtures
 - Simple mock responses
@@ -236,6 +247,7 @@ def test_new_operation(self, lisa_api: LisaApi, api_url: str):
 ## Continuous Integration

 These tests are ideal for CI/CD pipelines because they:
+
 - Run quickly
 - Require no infrastructure
 - Have no external dependencies
@@ -247,6 +259,7 @@ These tests are ideal for CI/CD pipelines because they:
 ### Import Errors

 If you see import errors for `responses`:
+
 ```bash
 pip install responses
 ```
@@ -254,6 +267,7 @@ pip install responses
 ### Fixture Not Found

 If you see fixture errors, ensure you're running from the LISA root directory:
+
 ```bash
 cd /path/to/LISA
 pytest test/sdk -v
 ```
@@ -262,6 +276,7 @@ pytest test/sdk -v
 ### Test Failures

 If tests fail:
+
 1. Check that you're using the latest SDK code
 2. Verify the mock responses match the expected API format
 3. Check for changes in the SDK that require test updates
diff --git a/test/sdk/test_langchain.py b/test/sdk/test_langchain.py
index 590823244..6c2461919 100644
--- a/test/sdk/test_langchain.py
+++ b/test/sdk/test_langchain.py
@@ -14,8 +14,8 @@

 """Unit tests for LISA SDK langchain adapters.

-Note: These are simplified tests. Full integration tests with real LisaLlm
-instances are better suited for integration testing due to Pydantic validation complexity.
+Note: These are simplified tests. Full integration tests with real LisaLlm instances are better suited for integration
+testing due to Pydantic validation complexity.
""" import sys diff --git a/test/sdk/test_main.py b/test/sdk/test_main.py index b81ad0276..16f5c3e34 100644 --- a/test/sdk/test_main.py +++ b/test/sdk/test_main.py @@ -483,7 +483,7 @@ class TestLisaLlmComplete: """Test suite for legacy text completions.""" def test_complete_success(self): - """complete() should return a CompletionResponse with parsed fields.""" + """Complete() should return a CompletionResponse with parsed fields.""" from lisapy.main import LisaLlm from lisapy.types import CompletionResponse @@ -512,7 +512,7 @@ def test_complete_success(self): mock_post.assert_called_once() def test_complete_with_kwargs(self): - """complete() should forward allowed kwargs and filter unknown ones.""" + """Complete() should forward allowed kwargs and filter unknown ones.""" from lisapy.main import LisaLlm llm = LisaLlm(url="https://api.example.com") @@ -540,7 +540,7 @@ def test_complete_with_kwargs(self): assert "unknown_param" not in payload def test_complete_error(self): - """complete() should raise on non-200 response.""" + """Complete() should raise on non-200 response.""" from lisapy.main import LisaLlm llm = LisaLlm(url="https://api.example.com") diff --git a/test/utils/integration_test_utils.py b/test/utils/integration_test_utils.py old mode 100644 new mode 100755 index 6593a9efb..38bb39744 --- a/test/utils/integration_test_utils.py +++ b/test/utils/integration_test_utils.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Common utilities for LISA integration tests. +"""Common utilities for LISA integration tests. This module provides reusable functions for: - Authentication setup (re-exported from lisapy.authentication) @@ -122,7 +121,7 @@ def wait_for_resource_ready( logger.debug(f"Check failed: {e}") if i < max_iterations - 1: - logger.debug(f"Still waiting... ({i+1}/{max_iterations})") + logger.debug(f"Still waiting... ({i + 1}/{max_iterations})") time.sleep(check_interval_seconds) logger.warning(f"Timeout waiting for {resource_type} '{resource_id}' to be ready")