Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions parsons/aws/lambda_distribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from parsons.aws.s3 import S3
from parsons.etl.table import Table
from parsons.utilities.check_env import check
from parsons.utilities import check_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -203,7 +203,7 @@ def distribute_task(
"""
if storage not in ("s3", "local"):
raise DistributeTaskException("storage argument must be s3 or local")
bucket = check("S3_TEMP_BUCKET", bucket)
bucket = check_env.check("S3_TEMP_BUCKET", bucket)
csvdata = StringIO()
outcsv = csv.writer(csvdata)
outcsv.writerows(table.table.data())
Expand Down
10 changes: 5 additions & 5 deletions parsons/braintree/braintree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import braintree

from parsons.etl.table import Table
from parsons.utilities.check_env import check as check_env
from parsons.utilities import check_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -184,10 +184,10 @@ def __init__(
timeout=None,
production=True,
):
merchant_id = check_env("BRAINTREE_MERCHANT_ID", merchant_id)
public_key = check_env("BRAINTREE_PUBLIC_KEY", public_key)
private_key = check_env("BRAINTREE_PRIVATE_KEY", private_key)
timeout = check_env("BRAINTREE_TIMEOUT", timeout, optional=True) or 200
merchant_id = check_env.check("BRAINTREE_MERCHANT_ID", merchant_id)
public_key = check_env.check("BRAINTREE_PUBLIC_KEY", public_key)
private_key = check_env.check("BRAINTREE_PRIVATE_KEY", private_key)
timeout = check_env.check("BRAINTREE_TIMEOUT", timeout, optional=True) or 200

self.gateway = braintree.BraintreeGateway(
braintree.Configuration(
Expand Down
4 changes: 2 additions & 2 deletions parsons/notifications/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from slack_sdk.http_retry.builtin_handlers import RateLimitErrorRetryHandler

from parsons.etl.table import Table
from parsons.utilities.check_env import check
from parsons.utilities import check_env


class Slack:
Expand Down Expand Up @@ -122,7 +122,7 @@ def message(cls, channel, text, webhook=None, parent_message_id=None):
The `ts` value of the parent message. If used, this will thread the message.

"""
webhook = check("SLACK_API_WEBHOOK", webhook, optional=True)
webhook = check_env.check("SLACK_API_WEBHOOK", webhook, optional=True)
payload = {"channel": channel, "text": text}
if parent_message_id:
payload["thread_ts"] = parent_message_id
Expand Down
18 changes: 9 additions & 9 deletions parsons/redash/redash.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests

from parsons.etl.table import Table
from parsons.utilities.check_env import check
from parsons.utilities import check_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,10 +45,10 @@ def __init__(
timeout=0, # never timeout
verify=True,
):
self.base_url = check("REDASH_BASE_URL", base_url)
self.user_api_key = check("REDASH_USER_API_KEY", user_api_key, optional=True)
self.pause = int(check("REDASH_PAUSE_TIME", str(pause_time), optional=True))
self.timeout = int(check("REDASH_TIMEOUT", str(timeout), optional=True))
self.base_url = check_env.check("REDASH_BASE_URL", base_url)
self.user_api_key = check_env.check("REDASH_USER_API_KEY", user_api_key, optional=True)
self.pause = int(check_env.check("REDASH_PAUSE_TIME", pause_time))
self.timeout = int(check_env.check("REDASH_TIMEOUT", timeout))

self.verify = verify # for https requests
self.session = requests.Session()
Expand Down Expand Up @@ -164,8 +164,8 @@ def get_fresh_query_results(self, query_id=None, params=None):
Table Class

"""
query_id = check("REDASH_QUERY_ID", query_id, optional=True)
params_from_env = check("REDASH_QUERY_PARAMS", "", optional=True)
query_id = check_env.check("REDASH_QUERY_ID", query_id, optional=True)
params_from_env = check_env.check("REDASH_QUERY_PARAMS", "", optional=True)
redash_params = (
{f"p_{k}": str(v).replace("'", "''") for k, v in params.items()} if params else {}
)
Expand Down Expand Up @@ -210,8 +210,8 @@ def get_cached_query_results(self, query_id=None, query_api_key=None):
Table Class

"""
query_id = check("REDASH_QUERY_ID", query_id)
query_api_key = check("REDASH_QUERY_API_KEY", query_api_key, optional=True)
query_id = check_env.check("REDASH_QUERY_ID", query_id)
query_api_key = check_env.check("REDASH_QUERY_API_KEY", query_api_key, optional=True)
params = {}
if not self.user_api_key and query_api_key:
params["api_key"] = query_api_key
Expand Down
114 changes: 101 additions & 13 deletions parsons/utilities/check_env.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,106 @@
import os
import warnings
from typing import Literal, TypeVar, overload

T = TypeVar("T")

def check(env: str, field: str | None, optional: bool | None = False) -> str | None:

@overload
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are these @overload decorators doing?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overload is a type hinting feature.
It tells static type checkers that if you provide a certain type as an input you get a certain type as an output.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha. Why are there three of them? It feels more verbose than is ideal (but maybe if we can use a comment to signal to the reader that these are type hints with no impact on the code itself, the verbosity is less of an issue)

Copy link
Copy Markdown
Collaborator Author

@bmos bmos May 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's just how the syntax works, unfortunately.
That's why overload is only used situationally. I think it's worth it here since check is used in a ton of connectors so having accurate typing is extra useful.

Copy link
Copy Markdown
Collaborator Author

@bmos bmos May 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to chat through how this works sometime, I just keep being busy on Thursdays 😅
It's a very helpful feature.

Copy link
Copy Markdown
Collaborator

@shaunagm shaunagm May 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've got some other PRs I could review too so yeah let's schedule a call.

def check(
env: str,
value: T,
opt: bool | None = ...,
*,
optional: bool = ...,
field: T | None = ...,
) -> T: ...


@overload
def check(
env: str,
value: None = None,
opt: bool | None = ...,
*,
optional: Literal[True],
field: None = None,
) -> str | None: ...


@overload
def check(
env: str,
value: None = None,
opt: bool | None = ...,
*,
optional: Literal[False] = False,
field: None = None,
) -> str: ...


def check(
env: str,
value: T | None = None,
opt: bool | None = None,
*,
optional: bool = False,
field: T | None = None,
) -> T | str | None:
"""
Check if an environment variable has been set. If it has not been set
and the passed field or arguments have not been passed, then raise an
error.
Check if an environment variable has been set or value has been provided.

Args:
env:
Name of environment variable to check.
value:
If provided, ignore environment variable and return this.
opt:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the opt parameter coming from? I get that it's being listed as deprecated, but I'm not seeing where it's used at all.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional (the former 3rd positional argument) is now a keyword-only argument. So I added "opt" as the 3rd positional argument with the deprecation warning. That way it's not a breaking change if people try to provide the optional parameter as a positional argument.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha. What's your reasoning on moving it from a 3rd positional argument to a keyword-only argument?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deprecated; use `optional` instead.
If ``True``, return ``None`` if no value is found instead of raising ``KeyError``.

Keyword Args:
optional:
If ``True``, return ``None`` if no value is found instead of raising ``KeyError``.
field:
Deprecated; use `value` instead.
If provided, ignore environment variable and return this.

Returns:
The value of the requested environment variable (str) or the provided value (T).
If called with ``optional=True``, will return ``None`` if no value is found or provided.

Raises:
KeyError: If no value is found/provided and `optional` is ``False``.

"""
if field:
return field
try:
return os.environ[env]
except KeyError as e:
if not optional:
raise KeyError(
f"No {env} found. Store as environment variable or pass as an argument."
) from e
# Handle deprecated arguments
if opt is not None:
warnings.warn(
"The 'opt' positional argument is deprecated. "
"Use the 'optional' keyword argument. "
"Overriding 'optional' with value of 'opt' for backwards-compatibility.",
DeprecationWarning,
stacklevel=2,
)
optional = opt

if field is not None:
warnings.warn(
"The 'field' keyword argument is deprecated. "
"Use the 'value' positional or keyword argument. "
"Overriding 'value' with value of 'field' for backwards-compatibility.",
DeprecationWarning,
stacklevel=2,
)
value = field

if value is not None:
return value

if (environment_variable := os.environ.get(env)) is not None:
return environment_variable

if optional:
return None

raise KeyError(f"No '{env}' found. Store as environment variable or pass as an argument.")
56 changes: 34 additions & 22 deletions test/test_utilities/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
import os
import unittest
from pathlib import Path
from unittest import mock

Expand Down Expand Up @@ -142,25 +141,38 @@ def test_redact_credentials():
assert sql_helpers.redact_credentials(test_str) == test_result


class TestCheckEnv(unittest.TestCase):
def test_environment_field(self):
"""Test check field"""
result = check_env.check("PARAM", "param")
assert result == "param"

@mock.patch.dict(os.environ, {"PARAM": "env_param"})
def test_environment_env(self):
"""Test check env"""
result = check_env.check("PARAM", None)
assert result == "env_param"

@mock.patch.dict(os.environ, {"PARAM": "env_param"})
def test_environment_field_env(self):
"""Test check field with env and field"""
result = check_env.check("PARAM", "param")
assert result == "param"

def test_envrionment_error(self):
"""Test check env raises error"""
with pytest.raises(KeyError):
class TestCheckEnv:
@pytest.mark.parametrize(
("environment_value", "input_value", "expected_result"),
[
({}, "param", "param"),
({"PARAM": "env_param"}, None, "env_param"),
({"PARAM": "env_param"}, "param", "param"),
],
ids=["test_environment_field", "test_environment_env", "test_environment_field_env"],
)
def test_check_env_success(
self, environment_value: dict, input_value: str | None, expected_result: str
):
"""Tests successful retrieval of parameters from field or environment."""
with mock.patch.dict(os.environ, environment_value):
result = check_env.check("PARAM", input_value)
assert result == expected_result

def test_check_env_returns_none(self):
"""Tests returning None when environment and value are empty and optional value is passed."""
with mock.patch.dict(os.environ, {}, clear=True):
result = check_env.check("PARAM", None, optional=True)
assert result is None

def test_environment_error(self):
"""Test check env raises error when both are missing."""
# We ensure the environment is empty for this key to trigger the KeyError
with (
mock.patch.dict(os.environ, {}, clear=True),
pytest.raises(
KeyError,
match="No 'PARAM' found. Store as environment variable or pass as an argument.",
),
):
check_env.check("PARAM", None)
Loading