Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions parsons/aws/s3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
import os
import re

import boto3
from botocore.client import ClientError

from parsons.utilities import files
from parsons.utilities import check_env, files

logger = logging.getLogger(__name__)

Expand All @@ -31,7 +30,7 @@ def __init__(
# whenever the aws_access_key_id and aws_secret_access_key are passed.

if aws_session_token is None and use_env_token:
aws_session_token = os.getenv("AWS_SESSION_TOKEN")
aws_session_token = check_env.check("AWS_SESSION_TOKEN", None, optional=True)

self.session = boto3.Session(
aws_access_key_id=aws_access_key_id,
Expand Down
3 changes: 1 addition & 2 deletions parsons/azure/azure_blob_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
from typing import Literal
from urllib.parse import urlparse
Expand Down Expand Up @@ -50,7 +49,7 @@ def __init__(
account_domain="blob.core.windows.net",
account_url=None,
):
self.account_url = os.getenv("AZURE_ACCOUNT_URL", account_url)
self.account_url = check_env.check("AZURE_ACCOUNT_URL", account_url, optional=True)
self.credential = check_env.check("AZURE_CREDENTIAL", credential)
if not self.account_url:
self.account_name = check_env.check("AZURE_ACCOUNT_NAME", account_name)
Expand Down
4 changes: 2 additions & 2 deletions parsons/box/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""

import logging
import os
import tempfile
from pathlib import Path
from typing import Literal
Expand All @@ -31,6 +30,7 @@
from box_sdk_gen.networking.auth import Authentication

from parsons.etl.table import Table
from parsons.utilities import check_env
from parsons.utilities.files import create_temp_file, create_temp_file_for_path

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,7 +65,7 @@ class Box:

def __init__(self, auth: Authentication | None = None) -> None:
if auth is None:
access_token = os.environ["BOX_ACCESS_TOKEN"]
access_token = check_env.check("BOX_ACCESS_TOKEN", None)
oauth = BoxDeveloperTokenAuth(token=access_token)
self.client = BoxClient(auth=oauth)
else:
Expand Down
12 changes: 8 additions & 4 deletions parsons/databases/mysql/mysql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
import pickle
from contextlib import contextmanager
from pathlib import Path
Expand Down Expand Up @@ -37,18 +36,23 @@ class MySQL(DatabaseConnector, MySQLCreateTable, Alchemy):
db: str
Required if env variable ``MYSQL_DB`` not populated
port: int
Can be set by env variable ``MYSQL_PORT`` or argument.
If omitted or ``None``, uses ``MYSQL_PORT`` when set, otherwise 3306. If passed
(including ``3306``), the argument takes precedence over ``MYSQL_PORT``.

"""

def __init__(self, host=None, username=None, password=None, db=None, port=3306):
def __init__(self, host=None, username=None, password=None, db=None, port=None):
super().__init__()

self.username = check_env.check("MYSQL_USERNAME", username)
self.password = check_env.check("MYSQL_PASSWORD", password)
self.host = check_env.check("MYSQL_HOST", host)
self.db = check_env.check("MYSQL_DB", db)
self.port = port or os.environ.get("MYSQL_PORT")
if port is not None:
self.port = port
else:
env_port = check_env.check("MYSQL_PORT", None, optional=True)
self.port = int(env_port) if env_port is not None else 3306

@contextmanager
def connection(self):
Expand Down
21 changes: 13 additions & 8 deletions parsons/databases/postgres/postgres.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
from typing import Literal

Expand All @@ -8,6 +7,7 @@
from parsons.databases.postgres.postgres_core import PostgresCore
from parsons.databases.table import BaseTable
from parsons.etl.table import Table
from parsons.utilities import check_env

logger = logging.getLogger(__name__)

Expand All @@ -27,20 +27,25 @@ class Postgres(PostgresCore, Alchemy, DatabaseConnector):
db: str
Required if env variable ``PGDATABASE`` not populated
port: int
Required if env variable ``PGPORT`` not populated.
If omitted or ``None``, uses ``PGPORT`` when set, otherwise 5432. If passed
(including ``5432``), the argument takes precedence over ``PGPORT``.
timeout: int
Seconds to timeout if connection not established.

"""

def __init__(self, username=None, password=None, host=None, db=None, port=5432, timeout=10):
def __init__(self, username=None, password=None, host=None, db=None, port=None, timeout=10):
super().__init__()

self.username = username or os.environ.get("PGUSER")
self.password = password or os.environ.get("PGPASSWORD")
self.host = host or os.environ.get("PGHOST")
self.db = db or os.environ.get("PGDATABASE")
self.port = port or os.environ.get("PGPORT")
self.username = check_env.check("PGUSER", username, optional=True)
self.password = check_env.check("PGPASSWORD", password, optional=True)
self.host = check_env.check("PGHOST", host, optional=True)
self.db = check_env.check("PGDATABASE", db, optional=True)
if port is not None:
self.port = port
else:
env_port = check_env.check("PGPORT", None, optional=True)
self.port = int(env_port) if env_port is not None else 5432

# Check if there is a pgpass file. Psycopg2 will search for this file first when
# creating a connection.
Expand Down
19 changes: 7 additions & 12 deletions parsons/databases/redshift/redshift.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import json
import logging
import os
import pickle
import random
from contextlib import contextmanager
Expand All @@ -20,7 +19,7 @@
from parsons.databases.redshift.rs_table_utilities import RedshiftTableUtilities
from parsons.databases.table import BaseTable
from parsons.etl.table import Table
from parsons.utilities import files, sql_helpers
from parsons.utilities import check_env, files, sql_helpers

# Max number of rows that we query at a time, so we can avoid loading huge
# data sets into memory.
Expand Down Expand Up @@ -92,19 +91,15 @@ def __init__(
):
super().__init__()

try:
self.username = username or os.environ["REDSHIFT_USERNAME"]
self.password = password or os.environ["REDSHIFT_PASSWORD"]
self.host = host or os.environ["REDSHIFT_HOST"]
self.db = db or os.environ["REDSHIFT_DB"]
self.port = port or os.environ["REDSHIFT_PORT"]
except KeyError as error:
logger.error("Connection info missing. Most include as kwarg or env variable.")
raise error
self.username = check_env.check("REDSHIFT_USERNAME", username)
self.password = check_env.check("REDSHIFT_PASSWORD", password)
self.host = check_env.check("REDSHIFT_HOST", host)
self.db = check_env.check("REDSHIFT_DB", db)
self.port = check_env.check("REDSHIFT_PORT", port)

self.timeout = timeout
self.dialect = "redshift"
self.s3_temp_bucket = s3_temp_bucket or os.environ.get("S3_TEMP_BUCKET")
self.s3_temp_bucket = check_env.check("S3_TEMP_BUCKET", s3_temp_bucket, optional=True)
# Set prefix for temp S3 bucket paths that include subfolders
self.s3_temp_bucket_prefix = None
if self.s3_temp_bucket and "/" in self.s3_temp_bucket:
Expand Down
15 changes: 5 additions & 10 deletions parsons/facebook_ads/facebook_ads.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from joblib import Parallel, delayed

from parsons.etl.table import Table
from parsons.utilities import check_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,16 +80,10 @@ class FacebookAds:
}

def __init__(self, app_id=None, app_secret=None, access_token=None, ad_account_id=None):
try:
self.app_id = app_id or os.environ["FB_APP_ID"]
self.app_secret = app_secret or os.environ["FB_APP_SECRET"]
self.access_token = access_token or os.environ["FB_ACCESS_TOKEN"]
self.ad_account_id = ad_account_id or os.environ["FB_AD_ACCOUNT_ID"]
except KeyError as error:
logger.error(
"FB Marketing API credentials missing. Must be specified as env vars or kwargs"
)
raise error
self.app_id = check_env.check("FB_APP_ID", app_id)
self.app_secret = check_env.check("FB_APP_SECRET", app_secret)
self.access_token = check_env.check("FB_ACCESS_TOKEN", access_token)
self.ad_account_id = check_env.check("FB_AD_ACCOUNT_ID", ad_account_id)

FacebookAdsApi.init(self.app_id, self.app_secret, self.access_token)
self.ad_account = AdAccount(f"act_{self.ad_account_id}")
Expand Down
4 changes: 2 additions & 2 deletions parsons/mobilize_america/ma.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import collections.abc
import logging
import os
import re

import petl
from requests import request as _request

from parsons.etl.table import Table
from parsons.utilities import check_env
from parsons.utilities.datetime import date_to_timestamp

logger = logging.getLogger(__name__)
Expand All @@ -28,7 +28,7 @@ class MobilizeAmerica:

def __init__(self, api_key=None):
self.uri = MA_URI
self.api_key = api_key or os.environ.get("MOBILIZE_AMERICA_API_KEY")
self.api_key = check_env.check("MOBILIZE_AMERICA_API_KEY", api_key, optional=True)

if not self.api_key:
logger.info(
Expand Down
14 changes: 1 addition & 13 deletions parsons/notifications/slack.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import warnings
from pathlib import Path

Expand All @@ -12,18 +11,7 @@

class Slack:
def __init__(self, api_key=None):
if api_key is None:
try:
self.api_key = os.environ["SLACK_API_TOKEN"]

except KeyError as e:
raise KeyError(
"Missing api_key. It must be passed as an "
"argument or stored as environmental variable"
) from e

else:
self.api_key = api_key
self.api_key = check("SLACK_API_TOKEN", api_key)

# Create client with built-in rate limit handler
rate_limit_handler = RateLimitErrorRetryHandler(max_retry_count=1)
Expand Down
13 changes: 6 additions & 7 deletions parsons/salesforce/salesforce.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
import os

from simple_salesforce import Salesforce as _Salesforce

Expand Down Expand Up @@ -54,12 +53,12 @@ def __init__(
domain=None,
authentication_method=None,
):
if authentication_method:
self.authentication_method = authentication_method
elif env_authentication_method := os.environ.get("SALESFORCE_AUTHENTICATION_METHOD"):
self.authentication_method = env_authentication_method
else:
self.authentication_method = "password"
self.authentication_method = (
check_env.check(
"SALESFORCE_AUTHENTICATION_METHOD", authentication_method, optional=True
)
or "password"
)

if self.authentication_method == "password":
self.username = check_env.check("SALESFORCE_USERNAME", username)
Expand Down
46 changes: 46 additions & 0 deletions test/test_mysql/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
import os
import unittest
import unittest.mock as mock

import pytest

from parsons import MySQL, Table
from parsons.databases.mysql.create_table import MySQLCreateTable
from test.conftest import assert_matching_tables

_MYSQL_CONN_KWARGS = {
"host": "test-host",
"username": "test-user",
"password": "test-pass",
"db": "test-db",
}


class TestMySQLPortPrecedence(unittest.TestCase):
"""
Port resolution for MySQL:

- ``MYSQL_PORT`` is used when ``port`` is omitted (or ``port=None``).
- An explicit ``port=`` argument wins over ``MYSQL_PORT`` (including ``port=3306``).
"""

def _env_without_mysql_port(self):
return {key: value for key, value in os.environ.items() if key != "MYSQL_PORT"}

def test_mysql_port_env_overrides_default_port(self):
with mock.patch.dict(os.environ, {"MYSQL_PORT": "3307"}, clear=False):
mysql = MySQL(**_MYSQL_CONN_KWARGS)
assert mysql.port == 3307

def test_default_port_when_mysql_port_unset(self):
with mock.patch.dict(os.environ, self._env_without_mysql_port(), clear=True):
mysql = MySQL(**_MYSQL_CONN_KWARGS)
assert mysql.port == 3306

def test_explicit_port_kwarg_when_mysql_port_unset(self):
with mock.patch.dict(os.environ, self._env_without_mysql_port(), clear=True):
mysql = MySQL(**_MYSQL_CONN_KWARGS, port=8888)
assert mysql.port == 8888

def test_explicit_port_kwarg_overrides_mysql_port_env(self):
with mock.patch.dict(os.environ, {"MYSQL_PORT": "3307"}, clear=False):
mysql = MySQL(**_MYSQL_CONN_KWARGS, port=8888)
assert mysql.port == 8888

def test_explicit_default_port_kwarg_ignores_mysql_port_env(self):
with mock.patch.dict(os.environ, {"MYSQL_PORT": "3307"}, clear=False):
mysql = MySQL(**_MYSQL_CONN_KWARGS, port=3306)
assert mysql.port == 3306


# These tests interact directly with the MySQL database. To run, set env variable "LIVE_TEST=True"
@pytest.mark.live
Expand Down
Loading
Loading