Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion invenio_config/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2018 CERN.
# Copyright (C) 2024 KTH Royal Institute of Technology.
# Copyright (C) 2024-2025 KTH Royal Institute of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -44,3 +44,112 @@ def init_app(self, app):

# Set value
app.config[varname] = value


def _get_env_var(prefix: str, keys: list) -> dict:
"""Retrieve environment variables with a given prefix."""
return {k: os.environ.get(f"{prefix}_{k.upper()}") for k in keys}


def _check_prefix(prefix: str) -> str:
"""Check if prefix ends with an underscore and remove it."""
if not prefix:
return "INVENIO"
if prefix.endswith("_"):
prefix = prefix[:-1]
return prefix


def build_db_uri(prefix="INVENIO"):
"""
Build database URI from environment variables or use default.

Priority order:
1. INVENIO_SQLALCHEMY_DATABASE_URI
2. SQLALCHEMY_DATABASE_URI
3. INVENIO_DB_* specific environment variables
4. Default URI

Note: For option 3, to assert that the INVENIO_DB_* settings take effect,
you need to set SQLALCHEMY_DATABASE_URI="" in your environment.
"""
prefix = _check_prefix(prefix)
default_uri = "postgresql+psycopg2://invenio-app-rdm:invenio-app-rdm@localhost/invenio-app-rdm"

for key in [f"{prefix}_SQLALCHEMY_DATABASE_URI", "SQLALCHEMY_DATABASE_URI"]:
if uri := os.environ.get(key):
return uri

db_params = _get_env_var(
f"{prefix}_DB", ["user", "password", "host", "port", "name", "protocol"]
)
if all(db_params.values()):
uri = f"{db_params['protocol']}://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['name']}"
return uri

return default_uri


def build_broker_url(prefix="INVENIO"):
"""
Build broker URL from environment variables or use default.

Priority order:
1. INVENIO_BROKER_URL
2. BROKER_URL
3. INVENIO_AMQP_BROKER_* specific environment variables
4. Default URL
Note: see: https://docs.celeryq.dev/en/stable/userguide/configuration.html#new-lowercase-settings
"""
prefix = _check_prefix(prefix)
default_url = "amqp://guest:guest@localhost:5672/"

for key in [f"{prefix}_BROKER_URL", "BROKER_URL"]:
if broker_url := os.environ.get(key):
return broker_url

broker_params = _get_env_var(
f"{prefix}_AMQP_BROKER", ["user", "password", "host", "port", "protocol"]
)
if all(broker_params.values()):
broker_env_var = f"{prefix}_AMQP_BROKER_VHOST"
vhost = f"{os.environ.get(broker_env_var).removeprefix('/')}"
broker_url = f"{broker_params['protocol']}://{broker_params['user']}:{broker_params['password']}@{broker_params['host']}:{broker_params['port']}/{vhost}"
return broker_url
return default_url


def build_redis_url(db=None, prefix="INVENIO"):
"""
Build Redis URL from environment variables or use default.

Priority order:
1. INVENIO_CACHE_REDIS_URL
2. CACHE_REDIS_URL
3. INVENIO_KV_CACHE_* specific environment variables
4. Default URL
"""
prefix = _check_prefix(prefix)
db = db if db is not None else 0
default_url = f"redis://localhost:6379/{db}"

for key in [f"{prefix}_CACHE_REDIS_URL", "CACHE_REDIS_URL"]:
if cache_url := os.environ.get(key):
if cache_url.startswith(("redis://", "rediss://", "unix://")):
return cache_url

redis_params = _get_env_var(
f"{prefix}_KV_CACHE", ["host", "port", "password", "protocol"]
)

if redis_params["host"] and redis_params["port"]:
protocol = redis_params.get("protocol", "redis")
password = (
f":{redis_params['password']}@" if redis_params.get("password") else ""
)
cache_url = (
f"{protocol}://{password}{redis_params['host']}:{redis_params['port']}/{db}"
)
return cache_url

return default_url
188 changes: 188 additions & 0 deletions tests/test_invenio_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import warnings
from os.path import join

import pytest
from flask import Flask
from mock import patch
from pkg_resources import EntryPoint
Expand All @@ -28,6 +29,7 @@
create_config_loader,
)
from invenio_config.default import ALLOWED_HTML_ATTRS, ALLOWED_HTML_TAGS
from invenio_config.env import build_broker_url, build_db_uri, build_redis_url


class ConfigEP(EntryPoint):
Expand Down Expand Up @@ -230,3 +232,189 @@ class Config(object):
assert app.config["ENV"] == "env"
finally:
shutil.rmtree(tmppath)


def set_env_vars(monkeypatch, env_vars):
"""Helper function to set environment variables."""
for key in env_vars:
monkeypatch.delenv(key, raising=False)
for key, value in env_vars.items():
monkeypatch.setenv(key, value)


@pytest.mark.parametrize(
"env_vars, expected_uri",
[
(
{
"INVENIO_DB_USER": "testuser",
"INVENIO_DB_PASSWORD": "testpassword",
"INVENIO_DB_HOST": "testhost",
"INVENIO_DB_PORT": "5432",
"INVENIO_DB_NAME": "testdb",
"INVENIO_DB_PROTOCOL": "postgresql+psycopg2",
},
"postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb",
),
(
{
"INVENIO_SQLALCHEMY_DATABASE_URI": "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb"
},
"postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb",
),
(
{
"SQLALCHEMY_DATABASE_URI": "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb"
},
"postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb",
),
(
{},
"postgresql+psycopg2://invenio-app-rdm:invenio-app-rdm@localhost/invenio-app-rdm",
),
],
)
def test_build_db_uri(monkeypatch, env_vars, expected_uri):
"""Test building database URI."""
set_env_vars(monkeypatch, env_vars)
assert build_db_uri() == expected_uri


@pytest.mark.parametrize(
"env_vars, expected_url",
[
(
{
"INVENIO_AMQP_BROKER_USER": "testuser",
"INVENIO_AMQP_BROKER_PASSWORD": "testpassword",
"INVENIO_AMQP_BROKER_HOST": "testhost",
"INVENIO_AMQP_BROKER_PORT": "5672",
"INVENIO_AMQP_BROKER_PROTOCOL": "amqp",
"INVENIO_AMQP_BROKER_VHOST": "/testvhost",
},
"amqp://testuser:testpassword@testhost:5672/testvhost",
),
(
{
"INVENIO_AMQP_BROKER_USER": "testuser",
"INVENIO_AMQP_BROKER_PASSWORD": "testpassword",
"INVENIO_AMQP_BROKER_HOST": "testhost",
"INVENIO_AMQP_BROKER_PORT": "5672",
"INVENIO_AMQP_BROKER_PROTOCOL": "amqp",
"INVENIO_AMQP_BROKER_VHOST": "testvhost",
},
"amqp://testuser:testpassword@testhost:5672/testvhost",
),
(
{
"INVENIO_AMQP_BROKER_USER": "testuser",
"INVENIO_AMQP_BROKER_PASSWORD": "testpassword",
"INVENIO_AMQP_BROKER_HOST": "testhost",
"INVENIO_AMQP_BROKER_PORT": "5672",
"INVENIO_AMQP_BROKER_PROTOCOL": "amqp",
"INVENIO_AMQP_BROKER_VHOST": "",
},
"amqp://testuser:testpassword@testhost:5672/",
),
(
{
"INVENIO_AMQP_BROKER_USER": "testuser",
"INVENIO_AMQP_BROKER_PASSWORD": "testpassword",
"INVENIO_AMQP_BROKER_HOST": "testhost",
"INVENIO_AMQP_BROKER_PORT": "5672",
"INVENIO_AMQP_BROKER_PROTOCOL": "amqp",
},
"amqp://testuser:testpassword@testhost:5672/",
),
],
)
def test_build_broker_url_with_vhost(monkeypatch, env_vars, expected_url):
"""Test building broker URL with vhost."""
set_env_vars(monkeypatch, env_vars)
assert build_broker_url() == expected_url


@pytest.mark.parametrize(
"env_vars, expected_url",
[
(
{
"INVENIO_AMQP_BROKER_USER": "testuser",
"INVENIO_AMQP_BROKER_PASSWORD": "testpassword",
"INVENIO_AMQP_BROKER_HOST": "testhost",
"INVENIO_AMQP_BROKER_PORT": "5672",
"INVENIO_AMQP_BROKER_PROTOCOL": "amqp",
"INVENIO_AMQP_BROKER_VHOST": "testvhost",
},
"amqp://testuser:testpassword@testhost:5672/testvhost",
),
(
{
"INVENIO_AMQP_BROKER_USER": "testuser",
"INVENIO_AMQP_BROKER_PASSWORD": "testpassword",
"INVENIO_AMQP_BROKER_HOST": "testhost",
"INVENIO_AMQP_BROKER_PORT": "5672",
"INVENIO_AMQP_BROKER_PROTOCOL": "amqp",
"INVENIO_AMQP_BROKER_VHOST": "",
},
"amqp://testuser:testpassword@testhost:5672/",
),
(
{"INVENIO_BROKER_URL": "amqp://guest:guest@localhost:5672/"},
"amqp://guest:guest@localhost:5672/",
),
(
{},
"amqp://guest:guest@localhost:5672/",
),
],
)
def test_build_broker_url_with_vhost(monkeypatch, env_vars, expected_url):
"""Test building broker URL with vhost."""
set_env_vars(monkeypatch, env_vars)
assert build_broker_url() == expected_url


@pytest.mark.parametrize(
"env_vars, db, expected_url",
[
(
{
"INVENIO_KV_CACHE_HOST": "testhost",
"INVENIO_KV_CACHE_PORT": "6379",
"INVENIO_KV_CACHE_PASSWORD": "testpassword",
"INVENIO_KV_CACHE_PROTOCOL": "redis",
},
2,
"redis://:testpassword@testhost:6379/2",
),
(
{
"INVENIO_KV_CACHE_HOST": "testhost",
"INVENIO_KV_CACHE_PORT": "6379",
"INVENIO_KV_CACHE_PROTOCOL": "redis",
},
1,
"redis://testhost:6379/1",
),
(
{"BROKER_URL": "redis://localhost:6379/0"},
None,
"redis://localhost:6379/0",
),
(
{"INVENIO_KV_CACHE_URL": "redis://localhost:6379/3"},
3,
"redis://localhost:6379/3",
),
(
{},
4,
"redis://localhost:6379/4",
),
],
)
def test_build_redis_url(monkeypatch, env_vars, db, expected_url):
"""Test building Redis URL."""
set_env_vars(monkeypatch, env_vars)
assert build_redis_url(db=db) == expected_url
Loading