diff --git a/litellm/proxy/common_utils/load_config_utils.py b/litellm/proxy/common_utils/load_config_utils.py index 38e7b3f33b26..604b7b17e727 100644 --- a/litellm/proxy/common_utils/load_config_utils.py +++ b/litellm/proxy/common_utils/load_config_utils.py @@ -1,9 +1,11 @@ +from typing import Callable + import yaml from litellm._logging import verbose_proxy_logger -def get_file_contents_from_s3(bucket_name, object_key): +async def get_file_contents_from_s3(bucket_name, object_key): try: # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc import tempfile @@ -40,6 +42,13 @@ def get_file_contents_from_s3(bucket_name, object_key): with open(temp_file_path, "r") as yaml_file: config = yaml.safe_load(yaml_file) + # include file config + config = await process_includes_from_bucket( + config=config, + get_file_method=get_file_contents_from_s3, + bucket_name=bucket_name, + ) + return config except ImportError as e: # this is most likely if a user is not using the litellm docker container @@ -64,6 +73,12 @@ async def get_config_file_contents_from_gcs(bucket_name, object_key): file_contents = file_contents.decode("utf-8") # convert to yaml config = yaml.safe_load(file_contents) + # include file config + config = await process_includes_from_bucket( + config=config, + get_file_method=get_config_file_contents_from_gcs, + bucket_name=bucket_name, + ) return config except Exception as e: @@ -71,6 +86,44 @@ async def get_config_file_contents_from_gcs(bucket_name, object_key): return None +async def process_includes_from_bucket( + config: dict, get_file_method: Callable, bucket_name: str +) -> dict: + """ + Process includes by appending their contents to the main config + + Handles nested config.yamls with `include` section + + Example config: This will get the contents from files in `include` and append it + ```yaml + include: + - /path/to/key/model_config.yaml + + litellm_settings: + callbacks: ["prometheus"] + ``` + """ + if "include" not in config: + return config + + if not isinstance(config["include"], list): + raise ValueError("'include' must be a list of file paths") + + # Load and append all included files + for include_file in config["include"]: + included_config = await get_file_method(bucket_name, include_file) + # Simply update/extend the main config with included config + for key, value in included_config.items(): + if isinstance(value, list) and key in config: + config[key].extend(value) + else: + config[key] = value + + # Remove the include directive + del config["include"] + return config + + # # Example usage # bucket_name = 'litellm-proxy' # object_key = 'litellm_proxy_config.yaml' diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 18e1d8d98a75..0d3a7551c81d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1489,7 +1489,7 @@ async def get_config(self, config_file_path: Optional[str] = None) -> dict: bucket_name=bucket_name, object_key=object_key ) else: - config = get_file_contents_from_s3( + config = await get_file_contents_from_s3( bucket_name=bucket_name, object_key=object_key ) diff --git a/poetry.lock b/poetry.lock index a2abc9283185..80df16e51f98 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -395,14 +395,14 @@ uvloop = ["uvloop (>=0.15.2)"] name = "boto3" version = "1.34.34" description = "The AWS SDK for Python" -optional = true +optional = false python-versions = ">= 3.8" -groups = ["main"] -markers = "extra == \"proxy\"" +groups = ["main", "dev"] files = [ {file = "boto3-1.34.34-py3-none-any.whl", hash = "sha256:33a8b6d9136fa7427160edb92d2e50f2035f04e9d63a2d1027349053e12626aa"}, {file = "boto3-1.34.34.tar.gz", hash = "sha256:b2f321e20966f021ec800b7f2c01287a3dd04fc5965acdfbaa9c505a24ca45d1"}, ] +markers = {main = "extra == \"proxy\""} [package.dependencies] botocore = ">=1.34.34,<1.35.0" @@ -416,14 +416,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.34.162" description = "Low-level, data-driven core of boto 3." -optional = true +optional = false python-versions = ">=3.8" -groups = ["main"] -markers = "extra == \"proxy\"" +groups = ["main", "dev"] files = [ {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, ] +markers = {main = "extra == \"proxy\""} [package.dependencies] jmespath = ">=0.7.1,<2.0.0" @@ -1584,7 +1584,7 @@ version = "3.1.6" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" -groups = ["main", "proxy-dev"] +groups = ["main", "dev", "proxy-dev"] files = [ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, @@ -1686,14 +1686,14 @@ files = [ name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" -optional = true +optional = false python-versions = ">=3.7" -groups = ["main"] -markers = "extra == \"proxy\"" +groups = ["main", "dev"] files = [ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +markers = {main = "extra == \"proxy\""} [[package]] name = "jsonschema" @@ -1791,7 +1791,7 @@ version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.7" -groups = ["main", "proxy-dev"] +groups = ["main", "dev", "proxy-dev"] files = [ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, @@ -1947,6 +1947,53 @@ numpy = [ [package.extras] dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] +[[package]] +name = "moto" +version = "5.0.24" +description = "" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "moto-5.0.24-py3-none-any.whl", hash = "sha256:4d826f1574849f18ddd2fcbf614d97f82c8fddfb9d95fac1078da01a39b57c10"}, + {file = "moto-5.0.24.tar.gz", hash = "sha256:dba6426bd770fbb9d892633fbd35253cbc181eeaa0eba97d6f058720a8fe9b42"}, +] + +[package.dependencies] +boto3 = ">=1.9.201" +botocore = ">=1.14.0,<1.35.45 || >1.35.45,<1.35.46 || >1.35.46" +cryptography = ">=35.0.0" +Jinja2 = ">=2.10.1" +python-dateutil = ">=2.1,<3.0.0" +requests = ">=2.5" +responses = ">=0.15.0" +werkzeug = ">=0.5,<2.2.0 || >2.2.0,<2.2.1 || >2.2.1" +xmltodict = "*" + +[package.extras] +all = ["PyYAML (>=5.1)", "antlr4-python3-runtime", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "graphql-core", "joserfc (>=0.9.0)", "jsondiff (>=1.1.2)", "jsonpath-ng", "jsonschema", "multipart", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.6)", "pyparsing (>=3.0.7)", "setuptools"] +apigateway = ["PyYAML (>=5.1)", "joserfc (>=0.9.0)", "openapi-spec-validator (>=0.5.0)"] +apigatewayv2 = ["PyYAML (>=5.1)", "openapi-spec-validator (>=0.5.0)"] +appsync = ["graphql-core"] +awslambda = ["docker (>=3.0.0)"] +batch = ["docker (>=3.0.0)"] +cloudformation = ["PyYAML (>=5.1)", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "graphql-core", "joserfc (>=0.9.0)", "jsondiff (>=1.1.2)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.6)", "pyparsing (>=3.0.7)", "setuptools"] +cognitoidp = ["joserfc (>=0.9.0)"] +dynamodb = ["docker (>=3.0.0)", "py-partiql-parser (==0.5.6)"] +dynamodbstreams = ["docker (>=3.0.0)", "py-partiql-parser (==0.5.6)"] +events = ["jsonpath-ng"] +glue = ["pyparsing (>=3.0.7)"] +iotdata = ["jsondiff (>=1.1.2)"] +proxy = ["PyYAML (>=5.1)", "antlr4-python3-runtime", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=2.5.1)", "graphql-core", "joserfc (>=0.9.0)", "jsondiff (>=1.1.2)", "jsonpath-ng", "multipart", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.6)", "pyparsing (>=3.0.7)", "setuptools"] +quicksight = ["jsonschema"] +resourcegroupstaggingapi = ["PyYAML (>=5.1)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "graphql-core", "joserfc (>=0.9.0)", "jsondiff (>=1.1.2)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.6)", "pyparsing (>=3.0.7)"] +s3 = ["PyYAML (>=5.1)", "py-partiql-parser (==0.5.6)"] +s3crc32c = ["PyYAML (>=5.1)", "crc32c", "py-partiql-parser (==0.5.6)"] +server = ["PyYAML (>=5.1)", "antlr4-python3-runtime", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "flask (!=2.2.0,!=2.2.1)", "flask-cors", "graphql-core", "joserfc (>=0.9.0)", "jsondiff (>=1.1.2)", "jsonpath-ng", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.6)", "pyparsing (>=3.0.7)", "setuptools"] +ssm = ["PyYAML (>=5.1)"] +stepfunctions = ["antlr4-python3-runtime", "jsonpath-ng"] +xray = ["aws-xray-sdk (>=0.93,!=0.96)", "setuptools"] + [[package]] name = "msal" version = "1.32.3" @@ -3204,14 +3251,14 @@ dev = ["pre-commit", "pytest-asyncio", "tox"] name = "python-dateutil" version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -groups = ["main"] -markers = "extra == \"proxy\"" +groups = ["main", "dev"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] +markers = {main = "extra == \"proxy\""} [package.dependencies] six = ">=1.5" @@ -3789,14 +3836,14 @@ files = [ name = "s3transfer" version = "0.10.4" description = "An Amazon S3 Transfer Manager" -optional = true +optional = false python-versions = ">=3.8" -groups = ["main"] -markers = "extra == \"proxy\"" +groups = ["main", "dev"] files = [ {file = "s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e"}, {file = "s3transfer-0.10.4.tar.gz", hash = "sha256:29edc09801743c21eb5ecbc617a152df41d3c287f67b615f73e5f750583666a7"}, ] +markers = {main = "extra == \"proxy\""} [package.dependencies] botocore = ">=1.33.2,<2.0a.0" @@ -3808,14 +3855,14 @@ crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] name = "six" version = "1.17.0" description = "Python 2 and 3 compatibility utilities" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -groups = ["main"] -markers = "extra == \"extra-proxy\" or extra == \"proxy\"" +groups = ["main", "dev"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, ] +markers = {main = "extra == \"extra-proxy\" or extra == \"proxy\""} [[package]] name = "sniffio" @@ -4468,6 +4515,44 @@ files = [ {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, ] +[[package]] +name = "werkzeug" +version = "3.0.6" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.10\"" +files = [ + {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, + {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + +[[package]] +name = "werkzeug" +version = "3.1.3" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"}, + {file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "wrapt" version = "1.17.2" @@ -4572,6 +4657,18 @@ files = [ [package.dependencies] h11 = ">=0.9.0,<1" +[[package]] +name = "xmltodict" +version = "0.14.2" +description = "Makes working with XML feel like you are working with JSON" +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "xmltodict-0.14.2-py2.py3-none-any.whl", hash = "sha256:20cc7d723ed729276e808f26fb6b3599f786cbc37e06c65e192ba77c40f20aac"}, + {file = "xmltodict-0.14.2.tar.gz", hash = "sha256:201e7c28bb210e374999d1dde6382923ab0ed1a8a5faeece48ab525b7810a553"}, +] + [[package]] name = "yarl" version = "1.15.2" @@ -4712,4 +4809,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", [metadata] lock-version = "2.1" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "f68f232a0ce2bd3b82b536f10d2d7fd8a34dc43a4a4d9c7f8030affa69d90d21" +content-hash = "d69a492b85a85ebf4d3d03475191b34c7b8e16df5d50d16fea6b9b343bfda5b5" diff --git a/pyproject.toml b/pyproject.toml index bf4cfc45ad52..d13155986bf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,7 @@ types-PyYAML = "*" opentelemetry-api = "1.25.0" opentelemetry-sdk = "1.25.0" opentelemetry-exporter-otlp = "1.25.0" +moto = "5.0.24" [tool.poetry.group.proxy-dev.dependencies] prisma = "0.11.0" diff --git a/tests/proxy_unit_tests/test_proxy_config_unit_test.py b/tests/proxy_unit_tests/test_proxy_config_unit_test.py index a1586ab6bde0..565f98307f5f 100644 --- a/tests/proxy_unit_tests/test_proxy_config_unit_test.py +++ b/tests/proxy_unit_tests/test_proxy_config_unit_test.py @@ -5,6 +5,8 @@ import pytest from dotenv import load_dotenv +from moto import mock_aws +import boto3 import litellm.proxy import litellm.proxy.proxy_server @@ -264,3 +266,68 @@ def test_add_callbacks_invalid_input(): # Cleanup litellm.success_callback = [] litellm.failure_callback = [] + + +@pytest.mark.asyncio +async def test_reading_configs_with_includes_from_s3(): + """ + Test that the config is read correctly from the S3 and read includes + """ + BUCKET_NAME = "config-bucket" + config_files = ("config_with_multiple_includes.yaml", "models_file_1.yaml", "models_file_2.yaml") + os.environ["LITELLM_CONFIG_BUCKET_NAME"] = BUCKET_NAME + os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] = config_files[0] + os.environ["LITELLM_CONFIG_BUCKET_TYPE"] = "S3" + with mock_aws(): + # setup s3 bucket and put config files + current_path = os.path.dirname(os.path.abspath(__file__)) + s3_client = boto3.client("s3") + s3_client.create_bucket(Bucket=BUCKET_NAME) + for file_name in config_files: + config_path = os.path.join( + current_path, "example_config_yaml", file_name, + ) + s3_client.put_object( + Bucket="config-bucket", Key=file_name, Body=open(config_path, 'rb') + ) + proxy_config_instance = ProxyConfig() + config = await proxy_config_instance.get_config() + + assert config == { + 'model_list': [ + {'model_name': 'included-model-1', 'litellm_params': {'model': 'gpt-4'}}, + {'model_name': 'included-model-2', 'litellm_params': {'model': 'gpt-3.5-turbo'}} + ], + 'litellm_settings': {'callbacks': ['prometheus']}, + } + + # unset the env variable to avoid side-effects on other tests + del os.environ["LITELLM_CONFIG_BUCKET_NAME"] + del os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] + del os.environ["LITELLM_CONFIG_BUCKET_TYPE"] + + +@pytest.mark.asyncio +async def test_reading_configs_from_s3_file_not_found(): + """ + Test that the config is not present S3 + """ + BUCKET_NAME = "config-bucket" + config_files = ("config_with_multiple_includes.yaml", "models_file_1.yaml", "models_file_2.yaml") + os.environ["LITELLM_CONFIG_BUCKET_NAME"] = BUCKET_NAME + os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] = config_files[0] + os.environ["LITELLM_CONFIG_BUCKET_TYPE"] = "S3" + with mock_aws(): + # setup s3 bucket but do not put file + s3_client = boto3.client("s3") + s3_client.create_bucket(Bucket=BUCKET_NAME) + proxy_config_instance = ProxyConfig() + with pytest.raises(Exception) as ex: + await proxy_config_instance.get_config() + + assert str(ex.value) == "Unable to load config from given source." + + # unset the env variable to avoid side-effects on other tests + del os.environ["LITELLM_CONFIG_BUCKET_NAME"] + del os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] + del os.environ["LITELLM_CONFIG_BUCKET_TYPE"] \ No newline at end of file