diff --git a/metrics_utility/automation_controller_billing/collector.py b/metrics_utility/automation_controller_billing/collector.py index 9129a38ad..5d9b6bd46 100644 --- a/metrics_utility/automation_controller_billing/collector.py +++ b/metrics_utility/automation_controller_billing/collector.py @@ -3,12 +3,12 @@ from django.conf import settings from django.core.serializers.json import DjangoJSONEncoder -from django.db import connection import metrics_utility.base as base from metrics_utility.automation_controller_billing.helpers import get_last_entries_from_db from metrics_utility.automation_controller_billing.package.factory import Factory as PackageFactory +from metrics_utility.db import get_connection from metrics_utility.library.lock import lock from metrics_utility.logger import logger @@ -43,7 +43,7 @@ def gather(self, dest=None, subset=None, since=None, until=None, billing_provide if suffix: key = f'gather_automation_controller_billing_{suffix}_lock' - with lock(key, wait=False, db=connection) as acquired: + with lock(key, wait=False, db=get_connection()) as acquired: if not acquired: logger.log(self.log_level, 'Not gathering Automation Controller billing data, another task holds lock') return None @@ -79,7 +79,7 @@ def _gather_config(self): @staticmethod def db_connection(): - return connection + return get_connection() @classmethod def registered_collectors(cls, module=None): @@ -104,7 +104,7 @@ def _gather_finalize(self): if self.ship and not disabled: # We need to wait on analytics lock, to update the last collected timestamp settings # so we don't clash with analytics job collection. - with lock('gather_analytics_lock', wait=True, db=connection): + with lock('gather_analytics_lock', wait=True, db=get_connection()): # We need to load fresh settings again as we're obtaning the lock, since # Analytics job could have changed this on the background and we'd be resetting # the Analytics values here. diff --git a/metrics_utility/automation_controller_billing/collectors.py b/metrics_utility/automation_controller_billing/collectors.py index c24808b22..cff31c4ec 100644 --- a/metrics_utility/automation_controller_billing/collectors.py +++ b/metrics_utility/automation_controller_billing/collectors.py @@ -9,7 +9,6 @@ import distro -from django.db import connection from django.db.utils import ProgrammingError from django.utils.timezone import now, timedelta from django.utils.translation import gettext_lazy as _ @@ -21,6 +20,7 @@ ) from metrics_utility.base import register from metrics_utility.base.utils import get_max_gather_period_days, get_optional_collectors +from metrics_utility.db import get_connection from metrics_utility.exceptions import MetricsException, MissingRequiredEnvVar from metrics_utility.library import CsvFileSplitter from metrics_utility.library.collectors.util import date_where @@ -161,7 +161,7 @@ def _copy_table(table, query, path, prepend_query=None): file_path = os.path.join(path, table + '_table.csv') file = CsvFileSplitter(filespec=file_path) - with connection.cursor() as cursor: + with get_connection().cursor() as cursor: if prepend_query: cursor.execute(prepend_query) @@ -877,8 +877,7 @@ def main_jobevent_service_table(since, full_path, until, **kwargs): """ jobs = [] - # do raw sql for django.db connection - with connection.cursor() as cursor: + with get_connection().cursor() as cursor: cursor.execute(jobs_query, {'since': since, 'until': until}) jobs = cursor.fetchall() diff --git a/metrics_utility/automation_controller_billing/extract/extractor_controller_db.py b/metrics_utility/automation_controller_billing/extract/extractor_controller_db.py index edaade953..a63799aa4 100644 --- a/metrics_utility/automation_controller_billing/extract/extractor_controller_db.py +++ b/metrics_utility/automation_controller_billing/extract/extractor_controller_db.py @@ -2,7 +2,7 @@ import pandas as pd -from django.db import connection +from metrics_utility.db import get_connection class ExtractorControllerDB: @@ -12,7 +12,7 @@ def __init__(self, extra_params): self.extra_params = extra_params def iter_batches(self): - with connection.cursor() as cursor: + with get_connection().cursor() as cursor: cursor.execute(self.pg_functions()) since = self.extra_params['opt_since'] diff --git a/metrics_utility/automation_controller_billing/helpers.py b/metrics_utility/automation_controller_billing/helpers.py index 66f7fbbbe..d58087fe2 100644 --- a/metrics_utility/automation_controller_billing/helpers.py +++ b/metrics_utility/automation_controller_billing/helpers.py @@ -5,9 +5,9 @@ import pandas as pd -from django.db import connection from django.utils.dateparse import parse_datetime +from metrics_utility.db import get_connection from metrics_utility.logger import logger @@ -19,7 +19,7 @@ def get_last_entries_from_db() -> Dict: Optional[str]: JSON string from database, or None if not found or error occurs """ try: - with connection.cursor() as cursor: + with get_connection().cursor() as cursor: cursor.execute(""" SELECT value FROM conf_setting @@ -41,7 +41,7 @@ def get_config_and_settings_from_db() -> Tuple[Dict[str, Any], Dict[str, Any]]: license_info = {} settings_info = {} try: - with connection.cursor() as cursor: + with get_connection().cursor() as cursor: cursor.execute(""" SELECT key, value FROM conf_setting @@ -77,7 +77,7 @@ def _fetch_one(db, sql): def get_controller_version_from_db() -> str: """Get AWX/Controller version from the main_instance DB table.""" return _fetch_one( - connection, + get_connection(), """ SELECT version FROM main_instance diff --git a/metrics_utility/base/collector.py b/metrics_utility/base/collector.py index aacc5ef84..b02187005 100644 --- a/metrics_utility/base/collector.py +++ b/metrics_utility/base/collector.py @@ -7,9 +7,9 @@ from abc import abstractmethod -from django.db import connection from django.utils.timezone import now, timedelta +from metrics_utility.db import get_connection from metrics_utility.library.lock import lock from metrics_utility.logger import logger @@ -87,17 +87,6 @@ def config_present(self): return self.collections.get('config') is not None - @staticmethod - @abstractmethod - def db_connection(): - """ - DB connection for advisory lock. Can be - - django.db.connection or - - sqlalchemy.engine.base.Engine.raw_connection() - - etc. - """ - pass - def gather(self, dest=None, subset=None, since=None, until=None): """Entry point for gathering @@ -107,7 +96,7 @@ def gather(self, dest=None, subset=None, since=None, until=None): :param until: (datetime) - high threshold of data changes (defaults to now) :return: None or list of paths to tarballs (.tar.gz) """ - with lock('gather_analytics_lock', wait=False, db=connection) as acquired: + with lock('gather_analytics_lock', wait=False, db=get_connection()) as acquired: if not acquired: logger.log(self.log_level, 'Not gathering analytics, another task holds lock') return None diff --git a/metrics_utility/db.py b/metrics_utility/db.py new file mode 100644 index 000000000..7af8833e3 --- /dev/null +++ b/metrics_utility/db.py @@ -0,0 +1,36 @@ +import json + + +_configured_connection = None + + +def configure_db(db_json=None): + """Configure the database connection. Call this early in command handle().""" + global _configured_connection + + if db_json is None: + # Use default Django connection + from django.db import connection as django_connection + + _configured_connection = django_connection + return + + # Parse JSON with defaults + config = json.loads(db_json) + config.setdefault('ENGINE', 'django.db.backends.postgresql') + config.setdefault('HOST', 'localhost') + config.setdefault('PORT', '5432') + + # Create custom connection + from django.db.utils import ConnectionHandler + + _configured_connection = ConnectionHandler({'default': config})['default'] + + +def get_connection(): + """Get the configured connection, or default Django connection if not configured.""" + if _configured_connection is None: + from django.db import connection as django_connection + + return django_connection + return _configured_connection diff --git a/metrics_utility/management/commands/build_report.py b/metrics_utility/management/commands/build_report.py index 969bb2cf3..8b32721c9 100644 --- a/metrics_utility/management/commands/build_report.py +++ b/metrics_utility/management/commands/build_report.py @@ -10,6 +10,7 @@ from metrics_utility.automation_controller_billing.extract.factory import Factory as ExtractorFactory from metrics_utility.automation_controller_billing.report.factory import Factory as ReportFactory from metrics_utility.automation_controller_billing.report_saver.factory import Factory as ReportSaverFactory +from metrics_utility.db import configure_db from metrics_utility.exceptions import BadRequiredEnvVar, BadShipTarget, MissingRequiredEnvVar from metrics_utility.logger import debug, logger from metrics_utility.management.validation import ( @@ -59,6 +60,7 @@ class Command(BaseCommand): ), 'force': ('With this option, the existing reports will be overwritten if running this command again.'), 'verbose': ('Print debug information to console.'), + 'db': ('Custom database configuration as JSON string. Example: \'{"NAME":"awx","USER":"user","PASSWORD":"pass"}\''), } def create_parser(self, prog_name, subcommand, **kwargs): @@ -117,8 +119,12 @@ def add_arguments(self, parser): parser.add_argument('--ephemeral', dest='ephemeral', action='store', help=self.help_texts.get('ephemeral')) parser.add_argument('--force', dest='force', action='store_true', help=self.help_texts.get('force')) parser.add_argument('--verbose', dest='verbose', action='store_true', help=self.help_texts.get('verbose')) + parser.add_argument('--db', dest='db', action='store', help=self.help_texts.get('db')) def handle(self, *args, **options): + # Configure database connection first + configure_db(options.get('db')) + if options.get('verbose'): debug() diff --git a/metrics_utility/management/commands/gather_automation_controller_billing_data.py b/metrics_utility/management/commands/gather_automation_controller_billing_data.py index 0cf6632ba..f9bac4e57 100644 --- a/metrics_utility/management/commands/gather_automation_controller_billing_data.py +++ b/metrics_utility/management/commands/gather_automation_controller_billing_data.py @@ -5,6 +5,7 @@ from django.core.management.base import BaseCommand from metrics_utility.automation_controller_billing.collector import Collector +from metrics_utility.db import configure_db from metrics_utility.exceptions import ( BadShipTarget, NoAnalyticsCollected, @@ -34,6 +35,7 @@ class Command(BaseCommand): 'dry-run': ('Gather billing metrics without shipping.'), 'ship': ('Enable shipping of billing metrics to the console.redhat.com'), 'verbose': ('Print debug information to console.'), + 'db': ('Custom database configuration as JSON string. Example: \'{"NAME":"awx","USER":"user","PASSWORD":"pass"}\''), } def create_parser(self, prog_name, subcommand, **kwargs): @@ -88,8 +90,12 @@ def add_arguments(self, parser): parser.add_argument('--since', dest='since', action='store', help=self.help_texts.get('since')) parser.add_argument('--until', dest='until', action='store', help=self.help_texts.get('until')) parser.add_argument('--verbose', dest='verbose', action='store_true', help=self.help_texts.get('verbose')) + parser.add_argument('--db', dest='db', action='store', help=self.help_texts.get('db')) def handle(self, *args, **options): + # Configure database connection first + configure_db(options.get('db')) + if options.get('verbose'): debug() handle_env_validation('gather')