|
| 1 | +from typing import Callable, Dict |
| 2 | +import pytest |
| 3 | +import os |
| 4 | +from _pytest.fixtures import FixtureRequest |
| 5 | + |
| 6 | +S3_AUTO_CREATE_BUCKET = os.getenv("LLS_FILES_S3_AUTO_CREATE_BUCKET", "true") |
| 7 | + |
| 8 | + |
| 9 | +@pytest.fixture(scope="class") |
| 10 | +def files_provider_config_factory( |
| 11 | + request: FixtureRequest, |
| 12 | +) -> Callable[[str], list[Dict[str, str]]]: |
| 13 | + """ |
| 14 | + Factory fixture for configuring external files providers and returning their configuration. |
| 15 | +
|
| 16 | + This fixture returns a factory function that can configure additional files storage providers |
| 17 | + (such as S3/minio) and return the necessary environment variables |
| 18 | + for configuring the LlamaStack server to use these providers. |
| 19 | +
|
| 20 | + Args: |
| 21 | + request: Pytest fixture request object for accessing other fixtures |
| 22 | +
|
| 23 | + Returns: |
| 24 | + Callable[[str], list[Dict[str, str]]]: Factory function that takes a provider name |
| 25 | + and returns a list of environment variable dictionaries |
| 26 | +
|
| 27 | + Supported Providers: |
| 28 | + - "local": defaults to using just local filesystem |
| 29 | + - "s3": a remote S3/Minio storage provider |
| 30 | +
|
| 31 | + Environment Variables by Provider: |
| 32 | + - "s3": |
| 33 | + * ENABLE_S3: Enables S3/Minio storage provider |
| 34 | + * CI_S3_BUCKET_NAME: Name of the S3/Minio bucket |
| 35 | + * CI_S3_BUCKET_REGION: Region of the S3/Minio bucket |
| 36 | + * CI_S3_BUCKET_ENDPOINT: Endpoint URL of the S3/Minio bucket |
| 37 | + * AWS_ACCESS_KEY_ID: Access key ID for the S3/Minio bucket |
| 38 | + * AWS_SECRET_ACCESS_KEY: Secret access key for the S3/Minio bucket |
| 39 | + * S3_AUTO_CREATE_BUCKET: Whether to automatically create the S3/Minio bucket if it doesn't exist |
| 40 | +
|
| 41 | + Example: |
| 42 | + def test_with_s3(files_provider_config_factory): |
| 43 | + env_vars = files_provider_config_factory("s3") |
| 44 | + # env_vars contains S3_BUCKET_NAME, S3_BUCKET_ENDPOINT_URL, etc. |
| 45 | + """ |
| 46 | + |
| 47 | + def _factory(provider_name: str) -> list[Dict[str, str]]: |
| 48 | + env_vars: list[dict[str, str]] = [] |
| 49 | + |
| 50 | + if provider_name == "local" or provider_name is None: |
| 51 | + # Default case - no additional environment variables needed |
| 52 | + pass |
| 53 | + elif provider_name == "s3": |
| 54 | + env_vars.append({"name": "ENABLE_S3", "value": "s3"}) |
| 55 | + env_vars.append({"name": "S3_BUCKET_NAME", "value": request.getfixturevalue(argname="ci_s3_bucket_name")}) |
| 56 | + env_vars.append({ |
| 57 | + "name": "AWS_DEFAULT_REGION", |
| 58 | + "value": request.getfixturevalue(argname="ci_s3_bucket_region"), |
| 59 | + }) |
| 60 | + env_vars.append({ |
| 61 | + "name": "S3_ENDPOINT_URL", |
| 62 | + "value": request.getfixturevalue(argname="ci_s3_bucket_endpoint"), |
| 63 | + }) |
| 64 | + env_vars.append({ |
| 65 | + "name": "AWS_ACCESS_KEY_ID", |
| 66 | + "value": request.getfixturevalue(argname="aws_access_key_id"), |
| 67 | + }) |
| 68 | + env_vars.append({ |
| 69 | + "name": "AWS_SECRET_ACCESS_KEY", |
| 70 | + "value": request.getfixturevalue(argname="aws_secret_access_key"), |
| 71 | + }) |
| 72 | + env_vars.append({"name": "S3_AUTO_CREATE_BUCKET", "value": S3_AUTO_CREATE_BUCKET}) |
| 73 | + |
| 74 | + return env_vars |
| 75 | + |
| 76 | + return _factory |
0 commit comments