diff --git a/parsons/aws/s3.py b/parsons/aws/s3.py index 2afd7475be6..52c54680eca 100644 --- a/parsons/aws/s3.py +++ b/parsons/aws/s3.py @@ -1,11 +1,10 @@ import logging -import os import re import boto3 from botocore.client import ClientError -from parsons.utilities import files +from parsons.utilities import check_env, files logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ def __init__( # whenever the aws_access_key_id and aws_secret_access_key are passed. if aws_session_token is None and use_env_token: - aws_session_token = os.getenv("AWS_SESSION_TOKEN") + aws_session_token = check_env.check("AWS_SESSION_TOKEN", None, optional=True) self.session = boto3.Session( aws_access_key_id=aws_access_key_id, diff --git a/parsons/azure/azure_blob_storage.py b/parsons/azure/azure_blob_storage.py index cb7d11289b6..83f2b42ca2a 100644 --- a/parsons/azure/azure_blob_storage.py +++ b/parsons/azure/azure_blob_storage.py @@ -1,5 +1,4 @@ import logging -import os from pathlib import Path from typing import Literal from urllib.parse import urlparse @@ -50,7 +49,7 @@ def __init__( account_domain="blob.core.windows.net", account_url=None, ): - self.account_url = os.getenv("AZURE_ACCOUNT_URL", account_url) + self.account_url = check_env.check("AZURE_ACCOUNT_URL", account_url, optional=True) self.credential = check_env.check("AZURE_CREDENTIAL", credential) if not self.account_url: self.account_name = check_env.check("AZURE_ACCOUNT_NAME", account_name) diff --git a/parsons/box/box.py b/parsons/box/box.py index e4f44a37e47..2cd1e5ab8ea 100644 --- a/parsons/box/box.py +++ b/parsons/box/box.py @@ -17,7 +17,6 @@ """ import logging -import os import tempfile from pathlib import Path from typing import Literal @@ -31,6 +30,7 @@ from box_sdk_gen.networking.auth import Authentication from parsons.etl.table import Table +from parsons.utilities import check_env from parsons.utilities.files import create_temp_file, create_temp_file_for_path logger = logging.getLogger(__name__) @@ -65,7 +65,7 @@ class Box: def __init__(self, auth: Authentication | None = None) -> None: if auth is None: - access_token = os.environ["BOX_ACCESS_TOKEN"] + access_token = check_env.check("BOX_ACCESS_TOKEN", None) oauth = BoxDeveloperTokenAuth(token=access_token) self.client = BoxClient(auth=oauth) else: diff --git a/parsons/databases/mysql/mysql.py b/parsons/databases/mysql/mysql.py index a53ce6e912b..eb07e81cd59 100644 --- a/parsons/databases/mysql/mysql.py +++ b/parsons/databases/mysql/mysql.py @@ -1,5 +1,4 @@ import logging -import os import pickle from contextlib import contextmanager from pathlib import Path @@ -37,18 +36,23 @@ class MySQL(DatabaseConnector, MySQLCreateTable, Alchemy): db: str Required if env variable ``MYSQL_DB`` not populated port: int - Can be set by env variable ``MYSQL_PORT`` or argument. + If omitted or ``None``, uses ``MYSQL_PORT`` when set, otherwise 3306. If passed + (including ``3306``), the argument takes precedence over ``MYSQL_PORT``. """ - def __init__(self, host=None, username=None, password=None, db=None, port=3306): + def __init__(self, host=None, username=None, password=None, db=None, port=None): super().__init__() self.username = check_env.check("MYSQL_USERNAME", username) self.password = check_env.check("MYSQL_PASSWORD", password) self.host = check_env.check("MYSQL_HOST", host) self.db = check_env.check("MYSQL_DB", db) - self.port = port or os.environ.get("MYSQL_PORT") + if port is not None: + self.port = port + else: + env_port = check_env.check("MYSQL_PORT", None, optional=True) + self.port = int(env_port) if env_port is not None else 3306 @contextmanager def connection(self): diff --git a/parsons/databases/postgres/postgres.py b/parsons/databases/postgres/postgres.py index 9df9908ef8b..dcee24fe66d 100644 --- a/parsons/databases/postgres/postgres.py +++ b/parsons/databases/postgres/postgres.py @@ -1,5 +1,4 @@ import logging -import os from pathlib import Path from typing import Literal @@ -8,6 +7,7 @@ from parsons.databases.postgres.postgres_core import PostgresCore from parsons.databases.table import BaseTable from parsons.etl.table import Table +from parsons.utilities import check_env logger = logging.getLogger(__name__) @@ -27,20 +27,25 @@ class Postgres(PostgresCore, Alchemy, DatabaseConnector): db: str Required if env variable ``PGDATABASE`` not populated port: int - Required if env variable ``PGPORT`` not populated. + If omitted or ``None``, uses ``PGPORT`` when set, otherwise 5432. If passed + (including ``5432``), the argument takes precedence over ``PGPORT``. timeout: int Seconds to timeout if connection not established. """ - def __init__(self, username=None, password=None, host=None, db=None, port=5432, timeout=10): + def __init__(self, username=None, password=None, host=None, db=None, port=None, timeout=10): super().__init__() - self.username = username or os.environ.get("PGUSER") - self.password = password or os.environ.get("PGPASSWORD") - self.host = host or os.environ.get("PGHOST") - self.db = db or os.environ.get("PGDATABASE") - self.port = port or os.environ.get("PGPORT") + self.username = check_env.check("PGUSER", username, optional=True) + self.password = check_env.check("PGPASSWORD", password, optional=True) + self.host = check_env.check("PGHOST", host, optional=True) + self.db = check_env.check("PGDATABASE", db, optional=True) + if port is not None: + self.port = port + else: + env_port = check_env.check("PGPORT", None, optional=True) + self.port = int(env_port) if env_port is not None else 5432 # Check if there is a pgpass file. Psycopg2 will search for this file first when # creating a connection. diff --git a/parsons/databases/redshift/redshift.py b/parsons/databases/redshift/redshift.py index 2a295d1a19c..90ab277b65a 100644 --- a/parsons/databases/redshift/redshift.py +++ b/parsons/databases/redshift/redshift.py @@ -1,7 +1,6 @@ import datetime import json import logging -import os import pickle import random from contextlib import contextmanager @@ -20,7 +19,7 @@ from parsons.databases.redshift.rs_table_utilities import RedshiftTableUtilities from parsons.databases.table import BaseTable from parsons.etl.table import Table -from parsons.utilities import files, sql_helpers +from parsons.utilities import check_env, files, sql_helpers # Max number of rows that we query at a time, so we can avoid loading huge # data sets into memory. @@ -92,19 +91,15 @@ def __init__( ): super().__init__() - try: - self.username = username or os.environ["REDSHIFT_USERNAME"] - self.password = password or os.environ["REDSHIFT_PASSWORD"] - self.host = host or os.environ["REDSHIFT_HOST"] - self.db = db or os.environ["REDSHIFT_DB"] - self.port = port or os.environ["REDSHIFT_PORT"] - except KeyError as error: - logger.error("Connection info missing. Most include as kwarg or env variable.") - raise error + self.username = check_env.check("REDSHIFT_USERNAME", username) + self.password = check_env.check("REDSHIFT_PASSWORD", password) + self.host = check_env.check("REDSHIFT_HOST", host) + self.db = check_env.check("REDSHIFT_DB", db) + self.port = check_env.check("REDSHIFT_PORT", port) self.timeout = timeout self.dialect = "redshift" - self.s3_temp_bucket = s3_temp_bucket or os.environ.get("S3_TEMP_BUCKET") + self.s3_temp_bucket = check_env.check("S3_TEMP_BUCKET", s3_temp_bucket, optional=True) # Set prefix for temp S3 bucket paths that include subfolders self.s3_temp_bucket_prefix = None if self.s3_temp_bucket and "/" in self.s3_temp_bucket: diff --git a/parsons/facebook_ads/facebook_ads.py b/parsons/facebook_ads/facebook_ads.py index d4fc31ca5ce..407bb4bcd2a 100644 --- a/parsons/facebook_ads/facebook_ads.py +++ b/parsons/facebook_ads/facebook_ads.py @@ -10,6 +10,7 @@ from joblib import Parallel, delayed from parsons.etl.table import Table +from parsons.utilities import check_env logger = logging.getLogger(__name__) @@ -79,16 +80,10 @@ class FacebookAds: } def __init__(self, app_id=None, app_secret=None, access_token=None, ad_account_id=None): - try: - self.app_id = app_id or os.environ["FB_APP_ID"] - self.app_secret = app_secret or os.environ["FB_APP_SECRET"] - self.access_token = access_token or os.environ["FB_ACCESS_TOKEN"] - self.ad_account_id = ad_account_id or os.environ["FB_AD_ACCOUNT_ID"] - except KeyError as error: - logger.error( - "FB Marketing API credentials missing. Must be specified as env vars or kwargs" - ) - raise error + self.app_id = check_env.check("FB_APP_ID", app_id) + self.app_secret = check_env.check("FB_APP_SECRET", app_secret) + self.access_token = check_env.check("FB_ACCESS_TOKEN", access_token) + self.ad_account_id = check_env.check("FB_AD_ACCOUNT_ID", ad_account_id) FacebookAdsApi.init(self.app_id, self.app_secret, self.access_token) self.ad_account = AdAccount(f"act_{self.ad_account_id}") diff --git a/parsons/mobilize_america/ma.py b/parsons/mobilize_america/ma.py index 7280c1c368f..e6d3bdc827d 100644 --- a/parsons/mobilize_america/ma.py +++ b/parsons/mobilize_america/ma.py @@ -1,12 +1,12 @@ import collections.abc import logging -import os import re import petl from requests import request as _request from parsons.etl.table import Table +from parsons.utilities import check_env from parsons.utilities.datetime import date_to_timestamp logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class MobilizeAmerica: def __init__(self, api_key=None): self.uri = MA_URI - self.api_key = api_key or os.environ.get("MOBILIZE_AMERICA_API_KEY") + self.api_key = check_env.check("MOBILIZE_AMERICA_API_KEY", api_key, optional=True) if not self.api_key: logger.info( diff --git a/parsons/notifications/slack.py b/parsons/notifications/slack.py index fb25ca15cda..dd9690fa2fe 100644 --- a/parsons/notifications/slack.py +++ b/parsons/notifications/slack.py @@ -1,4 +1,3 @@ -import os import warnings from pathlib import Path @@ -12,18 +11,7 @@ class Slack: def __init__(self, api_key=None): - if api_key is None: - try: - self.api_key = os.environ["SLACK_API_TOKEN"] - - except KeyError as e: - raise KeyError( - "Missing api_key. It must be passed as an " - "argument or stored as environmental variable" - ) from e - - else: - self.api_key = api_key + self.api_key = check("SLACK_API_TOKEN", api_key) # Create client with built-in rate limit handler rate_limit_handler = RateLimitErrorRetryHandler(max_retry_count=1) diff --git a/parsons/salesforce/salesforce.py b/parsons/salesforce/salesforce.py index 584d7e012bd..13ec67e048b 100644 --- a/parsons/salesforce/salesforce.py +++ b/parsons/salesforce/salesforce.py @@ -1,6 +1,5 @@ import json import logging -import os from simple_salesforce import Salesforce as _Salesforce @@ -54,12 +53,12 @@ def __init__( domain=None, authentication_method=None, ): - if authentication_method: - self.authentication_method = authentication_method - elif env_authentication_method := os.environ.get("SALESFORCE_AUTHENTICATION_METHOD"): - self.authentication_method = env_authentication_method - else: - self.authentication_method = "password" + self.authentication_method = ( + check_env.check( + "SALESFORCE_AUTHENTICATION_METHOD", authentication_method, optional=True + ) + or "password" + ) if self.authentication_method == "password": self.username = check_env.check("SALESFORCE_USERNAME", username) diff --git a/test/test_mysql/test_mysql.py b/test/test_mysql/test_mysql.py index 3ae533a51e0..50089a09de3 100644 --- a/test/test_mysql/test_mysql.py +++ b/test/test_mysql/test_mysql.py @@ -1,4 +1,6 @@ +import os import unittest +import unittest.mock as mock import pytest @@ -6,6 +8,50 @@ from parsons.databases.mysql.create_table import MySQLCreateTable from test.conftest import assert_matching_tables +_MYSQL_CONN_KWARGS = { + "host": "test-host", + "username": "test-user", + "password": "test-pass", + "db": "test-db", +} + + +class TestMySQLPortPrecedence(unittest.TestCase): + """ + Port resolution for MySQL: + + - ``MYSQL_PORT`` is used when ``port`` is omitted (or ``port=None``). + - An explicit ``port=`` argument wins over ``MYSQL_PORT`` (including ``port=3306``). + """ + + def _env_without_mysql_port(self): + return {key: value for key, value in os.environ.items() if key != "MYSQL_PORT"} + + def test_mysql_port_env_overrides_default_port(self): + with mock.patch.dict(os.environ, {"MYSQL_PORT": "3307"}, clear=False): + mysql = MySQL(**_MYSQL_CONN_KWARGS) + assert mysql.port == 3307 + + def test_default_port_when_mysql_port_unset(self): + with mock.patch.dict(os.environ, self._env_without_mysql_port(), clear=True): + mysql = MySQL(**_MYSQL_CONN_KWARGS) + assert mysql.port == 3306 + + def test_explicit_port_kwarg_when_mysql_port_unset(self): + with mock.patch.dict(os.environ, self._env_without_mysql_port(), clear=True): + mysql = MySQL(**_MYSQL_CONN_KWARGS, port=8888) + assert mysql.port == 8888 + + def test_explicit_port_kwarg_overrides_mysql_port_env(self): + with mock.patch.dict(os.environ, {"MYSQL_PORT": "3307"}, clear=False): + mysql = MySQL(**_MYSQL_CONN_KWARGS, port=8888) + assert mysql.port == 8888 + + def test_explicit_default_port_kwarg_ignores_mysql_port_env(self): + with mock.patch.dict(os.environ, {"MYSQL_PORT": "3307"}, clear=False): + mysql = MySQL(**_MYSQL_CONN_KWARGS, port=3306) + assert mysql.port == 3306 + # These tests interact directly with the MySQL database. To run, set env variable "LIVE_TEST=True" @pytest.mark.live diff --git a/test/test_postgres/test_postgres.py b/test/test_postgres/test_postgres.py index 5f7dfd246d8..7883a849753 100644 --- a/test/test_postgres/test_postgres.py +++ b/test/test_postgres/test_postgres.py @@ -1,5 +1,6 @@ import os import unittest +import unittest.mock as mock import pytest @@ -146,6 +147,51 @@ def test_create_statement(self): self.pg.create_statement(empty_table, "tmc.test") +_POSTGRES_CONN_KWARGS = { + "username": "test-user", + "password": "test-pass", + "host": "test-host", + "db": "test-db", +} + + +class TestPostgresPortPrecedence(unittest.TestCase): + """ + Port resolution for Postgres: + + - ``PGPORT`` is used when ``port`` is omitted (or ``port=None``). + - An explicit ``port=`` argument wins over ``PGPORT`` (including ``port=5432``). + """ + + def _env_without_pgport(self): + return {key: value for key, value in os.environ.items() if key != "PGPORT"} + + def test_pgport_env_overrides_default_port(self): + with mock.patch.dict(os.environ, {"PGPORT": "5433"}, clear=False): + pg = Postgres(**_POSTGRES_CONN_KWARGS) + assert pg.port == 5433 + + def test_default_port_when_pgport_unset(self): + with mock.patch.dict(os.environ, self._env_without_pgport(), clear=True): + pg = Postgres(**_POSTGRES_CONN_KWARGS) + assert pg.port == 5432 + + def test_explicit_port_kwarg_when_pgport_unset(self): + with mock.patch.dict(os.environ, self._env_without_pgport(), clear=True): + pg = Postgres(**_POSTGRES_CONN_KWARGS, port=9999) + assert pg.port == 9999 + + def test_explicit_port_kwarg_overrides_pgport_env(self): + with mock.patch.dict(os.environ, {"PGPORT": "5433"}, clear=False): + pg = Postgres(**_POSTGRES_CONN_KWARGS, port=9999) + assert pg.port == 9999 + + def test_explicit_default_port_kwarg_ignores_pgport_env(self): + with mock.patch.dict(os.environ, {"PGPORT": "5433"}, clear=False): + pg = Postgres(**_POSTGRES_CONN_KWARGS, port=5432) + assert pg.port == 5432 + + # These tests interact directly with the Postgres database