Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http.http_client import MessageRepresentationAirbyteTracedErrors
from airbyte_cdk.sources.streams.http.requests_native_auth import MultipleTokenAuthenticator
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from source_github.utils import MultipleTokenAuthenticatorWithRateLimiter
Expand Down Expand Up @@ -184,7 +185,7 @@ def user_friendly_error_message(self, message: str) -> str:
# 404 Client Error: Not Found for url: https://api.github.com/orgs/airbytehqBLA/repos?per_page=100
org_name = message.split("https://api.github.com/orgs/")[1].split("/")[0]
user_message = f'Organization name: "{org_name}" is unknown, "repository" config option should be updated. Please validate your repository config.'
elif "401 Client Error: Unauthorized for url" in message:
elif "401 Client Error: Unauthorized for url" in message or "401. Error: Unauthorized" in message:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this condition be more general, for example, like 'Error: Unauthorized' in message and '401' in message. So we don't need to update this check every time we get new format of the error message?

# 401 Client Error: Unauthorized for url: https://api.github.com/orgs/datarootsio/repos?per_page=100&sort=updated&direction=desc
user_message = (
"Github credentials have expired or changed, please review your credentials and re-authenticate or renew your access token."
Expand All @@ -203,6 +204,9 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) ->
)
return True, None

except MessageRepresentationAirbyteTracedErrors as e:
user_message = self.user_friendly_error_message(e.message)
return False, user_message or e.message
except Exception as e:
message = repr(e)
user_message = self.user_friendly_error_message(message)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
import time
from dataclasses import dataclass
from datetime import timedelta
Expand All @@ -10,10 +10,12 @@

import requests

from airbyte_cdk.models import SyncMode
from airbyte_cdk.models import FailureType, SyncMode
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http import HttpClient
from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import AbstractHeaderAuthenticator
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse


Expand Down Expand Up @@ -59,14 +61,30 @@ class MultipleTokenAuthenticatorWithRateLimiter(AbstractHeaderAuthenticator):
DURATION = timedelta(seconds=3600) # Duration at which the current rate limit window resets

def __init__(self, tokens: List[str], auth_method: str = "token", auth_header: str = "Authorization"):
self._logger = logging.getLogger("airbyte")
self._auth_method = auth_method
self._auth_header = auth_header
self._tokens = {t: Token() for t in tokens}
# It would've been nice to instantiate a single client on this authenticator. However, we are checking
# the limits of each token which is associated with a TokenAuthenticator. And each HttpClient can only
# correspond to one authenticator.
self._token_to_http_client: Mapping[str, HttpClient] = self._initialize_http_clients(tokens)
self.check_all_tokens()
self._tokens_iter = cycle(self._tokens)
self._active_token = next(self._tokens_iter)
self._max_time = 60 * 10 # 10 minutes as default

def _initialize_http_clients(self, tokens: List[str]) -> Mapping[str, HttpClient]:
return {
token: HttpClient(
name="token_validator",
logger=self._logger,
authenticator=TokenAuthenticator(token, auth_method=self._auth_method),
use_cache=False, # We don't want to reuse cached valued because rate limit values change frequently
)
for token in tokens
}

@property
def auth_header(self) -> str:
return self._auth_header
Expand Down Expand Up @@ -114,14 +132,27 @@ def max_time(self, value: int) -> None:

def _check_token_limits(self, token: str):
"""check that token is not limited"""
headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
rate_limit_info = (
requests.get(
"https://api.github.com/rate_limit", headers=headers, auth=TokenAuthenticator(token, auth_method=self._auth_method)
)
.json()
.get("resources")

http_client = self._token_to_http_client.get(token)
if not http_client:
raise ValueError("No HttpClient was initialized for this token. This is unexpected. Please contact Airbyte support.")

_, response = http_client.send_request(
http_method="GET",
url="https://api.github.com/rate_limit",
headers={"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"},
request_kwargs={},
)

response_body = response.json()
if "resources" not in response_body:
raise AirbyteTracedException(
failure_type=FailureType.config_error,
internal_message=f"Token rate limit info response did not contain expected key: resources",
message="Unable to validate token. Please double check that specified authentication tokens are correct",
)

rate_limit_info = response_body.get("resources")
token_info = self._tokens[token]
remaining_info_core = rate_limit_info.get("core")
token_info.count_rest, token_info.reset_at_rest = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import pytest
import responses
from freezegun import freeze_time
from requests import JSONDecodeError
from source_github import SourceGithub
from source_github.streams import Organizations
from source_github.utils import MultipleTokenAuthenticatorWithRateLimiter, read_full_refresh

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.streams.http.http_client import MessageRepresentationAirbyteTracedErrors
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_helpers import ab_datetime_now

Expand Down Expand Up @@ -148,3 +150,21 @@ def request_callback_orgs(request, context):
list(read_full_refresh(stream))
sleep_mock.assert_called_once_with(ACCEPTED_WAITING_TIME_IN_SECONDS)
assert [(x.count_rest, x.count_graphql) for x in authenticator._tokens.values()] == [(500, 500), (500, 500), (498, 500)]


def test_invalid_credentials_error_message(requests_mock):
"""
Test that validates that invalid or expired credentials are gracefully caught and surfaced back in a way
that the connector can display actionable messages back to users
"""

requests_mock.get(
"https://api.github.com/rate_limit",
status_code=401,
json={"message": "Bad credentials", "documentation_url": "https://docs.github.com/rest", "status": "401"},
)

with pytest.raises(AirbyteTracedException) as e:
MultipleTokenAuthenticatorWithRateLimiter(tokens=["token1", "token2", "token3"])

assert "HTTP Status Code: 401" in e.value.message
Loading