|
| 1 | +from abc import ABC, abstractmethod |
| 2 | +from dataclasses import dataclass, field, replace |
| 3 | +from typing import Optional, Mapping, Any, Dict |
| 4 | +from psycopg import conninfo |
| 5 | + |
| 6 | +from django.conf import settings |
| 7 | + |
| 8 | + |
| 9 | +POSTGRES_ENGINE = "django.db.backends.postgresql" |
| 10 | + |
| 11 | + |
| 12 | +# Refer Django docs: https://docs.djangoproject.com/en/4.2/ref/settings/#databases |
| 13 | +# Inspiration: https://github.com/jazzband/dj-database-url/blob/master/dj_database_url/__init__.py |
| 14 | +# We do not use dj-database-url directly due to a bug with it's parsing logic when unix sockets |
| 15 | +# are involved. |
| 16 | +# Following class only contains a subset of the available django configurations. |
| 17 | +@dataclass(frozen=True) |
| 18 | +class DBConfig(ABC): |
| 19 | + dbname: str |
| 20 | + engine: str |
| 21 | + host: Optional[str] = None |
| 22 | + port: Optional[int] = None |
| 23 | + role: Optional[str] = None |
| 24 | + password: Optional[str] = field(default=None, repr=False) |
| 25 | + options: Mapping[str, Any] = field(default_factory=dict) |
| 26 | + atomic_requests: Optional[bool] = None |
| 27 | + autocommit: Optional[bool] = None |
| 28 | + conn_max_age: Optional[int] = 0 |
| 29 | + conn_health_checks: Optional[bool] = False |
| 30 | + disable_server_side_cursors: Optional[bool] = False |
| 31 | + time_zone: Optional[str] = None |
| 32 | + |
| 33 | + @classmethod |
| 34 | + @abstractmethod |
| 35 | + def from_connection_string(cls, url: str) -> "DBConfig": |
| 36 | + """ |
| 37 | + Parse a database URL into a DBConfig instance. |
| 38 | + Must be implemented by subclasses. |
| 39 | + """ |
| 40 | + raise NotImplementedError |
| 41 | + |
| 42 | + @classmethod |
| 43 | + def from_django_dict(cls, cfg: Mapping[str, Any]) -> "DBConfig": |
| 44 | + """ |
| 45 | + Create DBConfig instance from a Django DATABASES dict |
| 46 | + """ |
| 47 | + missing = {"ENGINE", "NAME"} - set(cfg.keys()) |
| 48 | + if missing: |
| 49 | + raise ValueError(f"DBConfig.from_dict missing required fields: {missing}") |
| 50 | + |
| 51 | + return cls( |
| 52 | + engine=cfg["ENGINE"], |
| 53 | + dbname=cfg["NAME"], |
| 54 | + host=cfg.get("HOST"), |
| 55 | + port=parse_port(cfg.get("PORT")), |
| 56 | + role=cfg.get("USER"), |
| 57 | + password=cfg.get("PASSWORD"), |
| 58 | + options=cfg.get("OPTIONS", {}).copy(), |
| 59 | + atomic_requests=cfg.get("ATOMIC_REQUESTS"), |
| 60 | + autocommit=cfg.get("AUTOCOMMIT"), |
| 61 | + conn_max_age=cfg.get("CONN_MAX_AGE"), |
| 62 | + conn_health_checks=cfg.get("CONN_HEALTH_CHECKS"), |
| 63 | + disable_server_side_cursors=cfg.get("DISABLE_SERVER_SIDE_CURSORS"), |
| 64 | + time_zone=cfg.get("TIME_ZONE"), |
| 65 | + ) |
| 66 | + |
| 67 | + def to_django_dict(self) -> Dict[str, Any]: |
| 68 | + result: Dict[str, Any] = { |
| 69 | + "ENGINE": self.engine, |
| 70 | + "NAME": self.dbname, |
| 71 | + } |
| 72 | + if self.host is not None: |
| 73 | + result["HOST"] = self.host |
| 74 | + if self.port is not None: |
| 75 | + result["PORT"] = str(self.port) |
| 76 | + if self.role is not None: |
| 77 | + result["USER"] = self.role |
| 78 | + if self.password is not None: |
| 79 | + result["PASSWORD"] = self.password |
| 80 | + if self.atomic_requests is not None: |
| 81 | + result["ATOMIC_REQUESTS"] = self.atomic_requests |
| 82 | + if self.autocommit is not None: |
| 83 | + result["AUTOCOMMIT"] = self.autocommit |
| 84 | + if self.conn_max_age is not None: |
| 85 | + result["CONN_MAX_AGE"] = self.conn_max_age |
| 86 | + if self.conn_health_checks is not None: |
| 87 | + result["CONN_HEALTH_CHECKS"] = self.conn_health_checks |
| 88 | + if self.disable_server_side_cursors is not None: |
| 89 | + result["DISABLE_SERVER_SIDE_CURSORS"] = self.disable_server_side_cursors |
| 90 | + if self.time_zone is not None: |
| 91 | + result["TIME_ZONE"] = self.time_zone |
| 92 | + if self.options: |
| 93 | + result["OPTIONS"] = dict(self.options) |
| 94 | + return result |
| 95 | + |
| 96 | + |
| 97 | +# Reference: https://docs.djangoproject.com/en/4.2/ref/databases/#postgresql-notes |
| 98 | +# Options are merged from values passed to the class, only sslmode is explicitly handled here. |
| 99 | +@dataclass(frozen=True) |
| 100 | +class PostgresConfig(DBConfig): |
| 101 | + engine: str = POSTGRES_ENGINE |
| 102 | + sslmode: Optional[str] = None |
| 103 | + |
| 104 | + # Inject sslmode into OPTIONS |
| 105 | + # https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION |
| 106 | + # TODO: Avoid doing this in the frozen class, find a better way. |
| 107 | + def __post_init__(self) -> None: |
| 108 | + base_opts = dict(self.options) |
| 109 | + if self.sslmode is not None: |
| 110 | + base_opts["sslmode"] = self.sslmode |
| 111 | + object.__setattr__(self, "options", base_opts) |
| 112 | + |
| 113 | + @classmethod |
| 114 | + def from_connection_string(cls, url: str) -> "PostgresConfig": |
| 115 | + params = conninfo.conninfo_to_dict(url) |
| 116 | + dbname = params.get("dbname") |
| 117 | + if not dbname: |
| 118 | + raise ValueError("PostgresConfig.from_connection_string: missing database name in URL") |
| 119 | + |
| 120 | + return cls( |
| 121 | + dbname=dbname, |
| 122 | + host=params.get("host"), |
| 123 | + port=parse_port(params.get("port")), |
| 124 | + role=params.get("user"), |
| 125 | + password=params.get("password"), |
| 126 | + sslmode=params.get("sslmode"), |
| 127 | + ) |
| 128 | + |
| 129 | + @classmethod |
| 130 | + def from_django_dict(cls, cfg: Mapping[str, Any]) -> "PostgresConfig": |
| 131 | + raw_opts = cfg.get("OPTIONS", {}).copy() |
| 132 | + sslmode = raw_opts.pop("sslmode", None) |
| 133 | + base_cfg = dict(cfg, OPTIONS=raw_opts) |
| 134 | + base = super().from_django_dict(base_cfg) |
| 135 | + return replace(base, sslmode=sslmode) |
| 136 | + |
| 137 | + |
| 138 | +def get_internal_database_config(): |
| 139 | + conn_info = settings.DATABASES.get("default") |
| 140 | + if not conn_info: |
| 141 | + raise KeyError("settings.DATABASES['default'] is not defined") |
| 142 | + engine = conn_info.get("ENGINE") |
| 143 | + if engine == POSTGRES_ENGINE: |
| 144 | + return PostgresConfig.from_django_dict(conn_info) |
| 145 | + raise NotImplementedError(f"Database engine '{engine}' is not supported") |
| 146 | + |
| 147 | + |
| 148 | +def parse_port(raw_port): |
| 149 | + port = None |
| 150 | + if raw_port not in (None, ""): |
| 151 | + try: |
| 152 | + port = int(raw_port) |
| 153 | + except (TypeError, ValueError): |
| 154 | + raise ValueError(f"Invalid PORT value: {raw_port!r}") |
| 155 | + return port |
0 commit comments