diff --git a/estuary-cdk/estuary_cdk/capture/base_capture_connector.py b/estuary-cdk/estuary_cdk/capture/base_capture_connector.py index e5a267c880..a1a8817679 100644 --- a/estuary-cdk/estuary_cdk/capture/base_capture_connector.py +++ b/estuary-cdk/estuary_cdk/capture/base_capture_connector.py @@ -1,27 +1,31 @@ -from datetime import datetime, timedelta, UTC -from estuary_cdk.capture.connector_status import ConnectorStatus -from pydantic import BaseModel -from typing import Generic, Awaitable, Any, BinaryIO, Callable -from logging import Logger import abc import asyncio import json import sys +from datetime import UTC, datetime, timedelta +from typing import Any, Awaitable, BinaryIO, Callable, Generic + +from pydantic import BaseModel + +from estuary_cdk.capture.connector_status import ConnectorStatus -from . import request, response, Request, Response, Task from .. import BaseConnector, Stopped -from .common import ConnectorState, _ConnectorState from ..flow import ( ConnectorSpec, - ConnectorState as GeneralConnectorState, ConnectorStateUpdate, EndpointConfig, ResourceConfig, RotatingOAuth2Credentials, ) -from ..http import HTTPError, HTTPMixin, TokenSource +from ..flow import ( + ConnectorState as GeneralConnectorState, +) +from ..http import HTTPMixin, TokenSource from ..logger import FlowLogger from ..utils import format_error_message, sort_dict +from . import Request, Response, Task, request, response +from .common import _ConnectorState + class BaseCaptureConnector( BaseConnector[Request[EndpointConfig, ResourceConfig, _ConnectorState]], @@ -99,7 +103,9 @@ async def periodic_stop() -> None: stopping.event.set() # Rotate OAuth2 tokens for credentials with periodically expiring tokens. - if isinstance(self.token_source, TokenSource) and isinstance(self.token_source.credentials, RotatingOAuth2Credentials): + if isinstance(self.token_source, TokenSource) and isinstance( + self.token_source.credentials, RotatingOAuth2Credentials + ): await self._rotate_oauth2_tokens(log, open) # Gracefully exit after a moderate period of time. @@ -124,9 +130,7 @@ async def periodic_stop() -> None: if stopping.first_error: msg = format_error_message(stopping.first_error) - raise Stopped( - f"Task {stopping.first_error_task}: {msg}" - ) + raise Stopped(f"Task {stopping.first_error_task}: {msg}") else: raise Stopped(None) @@ -136,18 +140,16 @@ async def periodic_stop() -> None: else: raise RuntimeError("malformed request", request) - def _emit(self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState]): + def _emit( + self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState] + ): self.output.write( response.model_dump_json(by_alias=True, exclude_unset=True).encode() ) self.output.write(b"\n") self.output.flush() - def _checkpoint( - self, - state: GeneralConnectorState, - merge_patch: bool = True - ): + def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True): r = Response[Any, Any, GeneralConnectorState]( checkpoint=response.Checkpoint( state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch) @@ -168,7 +170,9 @@ async def _encrypt_config( # include=config.model_fields_set ensures only the fields that are explicitly set on the # model. This means any default values that are set are included, but any fields that are unset # & fallback to some default value are left unset. - unencrypted_config = config.model_dump(mode="json", include=config.model_fields_set) + unencrypted_config = config.model_dump( + mode="json", include=config.model_fields_set + ) body = { # Flow always sorts object properties lexicographically. This keeps a consistent property @@ -179,9 +183,11 @@ async def _encrypt_config( "schema": config.model_json_schema(), } - encrypted_config = await self.request(log, ENCRYPTION_URL, "POST", json=body, _with_token = False) + encrypted_config = await self.request( + log, ENCRYPTION_URL, "POST", json=body, _with_token=False + ) - return json.loads(encrypted_config.decode('utf-8')) + return json.loads(encrypted_config.decode("utf-8")) async def _rotate_oauth2_tokens( self, @@ -198,14 +204,20 @@ async def _rotate_oauth2_tokens( if access_token_expiration < now + timedelta(minutes=5): # Exchange for new access token & refresh token. log.info("Attempting to rotate OAuth2 tokens.") - token_exchange_response = await self.token_source.initialize_oauth2_tokens(log, self) + token_exchange_response = await self.token_source.initialize_oauth2_tokens( + log, self + ) # Replace tokens in the config. credentials = self.token_source.credentials credentials.access_token = token_exchange_response.access_token credentials.refresh_token = token_exchange_response.refresh_token - credentials.access_token_expires_at = now + timedelta(seconds=token_exchange_response.expires_in) + credentials.access_token_expires_at = now + timedelta( + seconds=token_exchange_response.expires_in + ) # Encrypt the updated config and emit an event telling the control plane to publish a new spec. encrypted_config = await self._encrypt_config(log, open.capture.config) - log.event.config_update("Rotating OAuth2 tokens in endpoint config.", encrypted_config) + log.event.config_update( + "Rotating OAuth2 tokens in endpoint config.", encrypted_config + ) diff --git a/estuary-cdk/estuary_cdk/capture/common.py b/estuary-cdk/estuary_cdk/capture/common.py index d95bb1fc86..9330e27df3 100644 --- a/estuary-cdk/estuary_cdk/capture/common.py +++ b/estuary-cdk/estuary_cdk/capture/common.py @@ -29,7 +29,6 @@ CaptureBinding, ClientCredentialsOAuth2Credentials, OAuth2TokenFlowSpec, - OAuth2RotatingTokenSpec, AuthorizationCodeFlowOAuth2Credentials, LongLivedClientCredentialsOAuth2Credentials, ResourceOwnerPasswordOAuth2Credentials, diff --git a/estuary-cdk/estuary_cdk/flow.py b/estuary-cdk/estuary_cdk/flow.py index 82931331fd..4940e43067 100644 --- a/estuary-cdk/estuary_cdk/flow.py +++ b/estuary-cdk/estuary_cdk/flow.py @@ -1,8 +1,10 @@ import abc from dataclasses import dataclass from datetime import datetime -from pydantic import BaseModel, NonNegativeInt, PositiveInt, Field, ConfigDict -from typing import Any, Literal, TypeVar, Generic, Literal +from enum import StrEnum, auto +from typing import Any, ClassVar, Generic, Literal, Self, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt from .pydantic_polyfill import GenericModel @@ -12,6 +14,16 @@ "LOCAL", # We're running directly on the host as a local process. ] + +class OAuth2ClientCredentialsPlacement(StrEnum): + """ + Placement of client id and client secret during the OAuth2 token exchange step. + """ + + HEADERS = auto() + FORM = auto() + + # Generic type of a connector's endpoint configuration. EndpointConfig = TypeVar("EndpointConfig") @@ -70,6 +82,7 @@ class Checkpoint(BaseModel): class OAuth2TokenFlowSpec(BaseModel): accessTokenResponseMap: dict[str, str] accessTokenUrlTemplate: str + additionalTokenExchangeBody: dict[str, str | int] = {} class OAuth2Spec(BaseModel): @@ -80,9 +93,9 @@ class OAuth2Spec(BaseModel): accessTokenResponseMap: dict[str, str] accessTokenUrlTemplate: str - -class OAuth2RotatingTokenSpec(OAuth2Spec): - additionalTokenExchangeBody: dict[str, str | int] | None + # additionalTokenExchangeBody pertains to internal connector token exchanges + # and should be excluded from spec responses + additionalTokenExchangeBody: dict[str, str | int] = Field(default={}, exclude=True) class ConnectorSpec(BaseModel): @@ -101,8 +114,7 @@ class ConnectorStateUpdate(GenericModel, Generic[ConnectorState]): class AccessToken(BaseModel): credentials_title: Literal["Private App Credentials"] = Field( - default="Private App Credentials", - json_schema_extra={"type": "string"} + default="Private App Credentials", json_schema_extra={"type": "string"} ) access_token: str = Field( title="Access Token", @@ -112,8 +124,7 @@ class AccessToken(BaseModel): class BasicAuth(BaseModel): credentials_title: Literal["Username & Password"] = Field( - default="Username & Password", - json_schema_extra={"type": "string"} + default="Username & Password", json_schema_extra={"type": "string"} ) username: str password: str = Field( @@ -132,10 +143,17 @@ class ValidationError(Exception): errors: list[str] -class ResourceOwnerPasswordOAuth2Credentials(abc.ABC, BaseModel): +class _BaseOAuth2CredentialsData(BaseModel): + """ + Abstract base class containing common OAuth2 credential fields. + """ + + client_credentials_placement: ClassVar[OAuth2ClientCredentialsPlacement] = ( + OAuth2ClientCredentialsPlacement.FORM + ) + credentials_title: Literal["OAuth Credentials"] = Field( - default="OAuth Credentials", - json_schema_extra={"type": "string"} + default="OAuth Credentials", json_schema_extra={"type": "string"} ) client_id: str = Field( title="Client Id", @@ -146,123 +164,75 @@ class ResourceOwnerPasswordOAuth2Credentials(abc.ABC, BaseModel): json_schema_extra={"secret": True}, ) - -class ClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): # This configuration provides a "title" annotation for the UI to display # instead of the class name. model_config = ConfigDict( title="OAuth", ) - credentials_title: Literal["OAuth Credentials"] = Field( - default="OAuth Credentials", - json_schema_extra={"type": "string"} - ) - client_id: str = Field( - title="Client Id", - json_schema_extra={"secret": True}, - ) - client_secret: str = Field( - title="Client Secret", - json_schema_extra={"secret": True}, - ) - - -class AuthorizationCodeFlowOAuth2Credentials(abc.ABC, BaseModel): - credentials_title: Literal["OAuth Credentials"] = Field( - default="OAuth Credentials", json_schema_extra={"type": "string"} - ) - client_id: str = Field( - title="Client Id", - json_schema_extra={"secret": True}, - ) - client_secret: str = Field( - title="Client Secret", - json_schema_extra={"secret": True}, - ) + @classmethod + def with_client_credentials_placement( + cls, + placement: OAuth2ClientCredentialsPlacement, + ) -> type[Self]: + """ + Returns a subclass with a custom client credentials placement. + """ - @abc.abstractmethod - def _you_must_build_oauth2_credentials_for_a_provider(self): ... + return type( + str(cls.__name__), (cls,), {"client_credentials_placement": placement} + ) # pyright: ignore[reportReturnType] - @staticmethod - def for_provider( - provider: str, - ) -> type["AuthorizationCodeFlowOAuth2Credentials"]: + @classmethod + def for_provider(cls, provider: str) -> type[Self]: """ Builds an OAuth2Credentials model for the given OAuth2 `provider`. This routine is only available in Pydantic V2 environments. """ from pydantic import ConfigDict - class _OAuth2Credentials(AuthorizationCodeFlowOAuth2Credentials): - model_config = ConfigDict( - json_schema_extra={"x-oauth2-provider": provider}, - title="OAuth", - ) + return type( # pyright: ignore[reportReturnType] + cls.__name__, + (cls,), + { + "model_config": ConfigDict( + json_schema_extra={"x-oauth2-provider": provider}, + title="OAuth", + ), + "_you_must_build_oauth2_credentials_for_a_provider": lambda _: None, # pyright: ignore[reportUnknownLambdaType] + }, + ) - def _you_must_build_oauth2_credentials_for_a_provider(self): ... - return _OAuth2Credentials +class ResourceOwnerPasswordOAuth2Credentials(_BaseOAuth2CredentialsData): + grant_type: ClassVar[str] = "password" -class LongLivedClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): - # This configuration provides a "title" annotation for the UI to display - # instead of the class name. - model_config = ConfigDict( - title="OAuth", - ) +class ClientCredentialsOAuth2Credentials(_BaseOAuth2CredentialsData): + grant_type: ClassVar[str] = "client_credentials" + + +class AuthorizationCodeFlowOAuth2Credentials( + _BaseOAuth2CredentialsData, metaclass=abc.ABCMeta +): + grant_type: ClassVar[str] = "authorization_code" - credentials_title: Literal["OAuth Credentials"] = Field( - default="OAuth Credentials", - json_schema_extra={"type": "string"} - ) - client_id: str = Field( - title="Client Id", - json_schema_extra={"secret": True}, - ) - client_secret: str = Field( - title="Client Secret", - json_schema_extra={"secret": True}, - ) - access_token: str = Field( - title="Access Token", - json_schema_extra={"secret": True} - ) @abc.abstractmethod def _you_must_build_oauth2_credentials_for_a_provider(self): ... - @staticmethod - def for_provider(provider: str) -> type["LongLivedClientCredentialsOAuth2Credentials"]: - """ - Builds an OAuth2Credentials model for the given OAuth2 `provider`. - This routine is only available in Pydantic V2 environments. - """ - from pydantic import ConfigDict - class _OAuth2Credentials(LongLivedClientCredentialsOAuth2Credentials): - model_config = ConfigDict( - json_schema_extra={"x-oauth2-provider": provider}, - title="OAuth", - ) +class LongLivedClientCredentialsOAuth2Credentials( + _BaseOAuth2CredentialsData, metaclass=abc.ABCMeta +): + access_token: str = Field(title="Access Token", json_schema_extra={"secret": True}) - def _you_must_build_oauth2_credentials_for_a_provider(self): ... + @abc.abstractmethod + def _you_must_build_oauth2_credentials_for_a_provider(self): ... - return _OAuth2Credentials +class BaseOAuth2Credentials(_BaseOAuth2CredentialsData, metaclass=abc.ABCMeta): + grant_type: ClassVar[str] = "refresh_token" -class BaseOAuth2Credentials(abc.ABC, BaseModel): - credentials_title: Literal["OAuth Credentials"] = Field( - default="OAuth Credentials", - json_schema_extra={"type": "string"} - ) - client_id: str = Field( - title="Client Id", - json_schema_extra={"secret": True}, - ) - client_secret: str = Field( - title="Client Secret", - json_schema_extra={"secret": True}, - ) refresh_token: str = Field( title="Refresh Token", json_schema_extra={"secret": True}, @@ -271,50 +241,12 @@ class BaseOAuth2Credentials(abc.ABC, BaseModel): @abc.abstractmethod def _you_must_build_oauth2_credentials_for_a_provider(self): ... - @staticmethod - def for_provider(provider: str) -> type["BaseOAuth2Credentials"]: - """ - Builds an OAuth2Credentials model for the given OAuth2 `provider`. - This routine is only available in Pydantic V2 environments. - """ - from pydantic import ConfigDict - - class _OAuth2Credentials(BaseOAuth2Credentials): - model_config = ConfigDict( - json_schema_extra={"x-oauth2-provider": provider}, - title="OAuth", - ) - - def _you_must_build_oauth2_credentials_for_a_provider(self): ... - - return _OAuth2Credentials - -class RotatingOAuth2Credentials(BaseOAuth2Credentials): - access_token: str = Field( - title="Access Token", - json_schema_extra={"secret": True} - ) +class RotatingOAuth2Credentials(BaseOAuth2Credentials, metaclass=abc.ABCMeta): + access_token: str = Field(title="Access Token", json_schema_extra={"secret": True}) access_token_expires_at: datetime = Field( title="Access token expiration time.", ) - @staticmethod - def for_provider(provider: str) -> type["RotatingOAuth2Credentials"]: - """ - Builds an OAuth2Credentials model for the given OAuth2 `provider`. - This routine is only available in Pydantic V2 environments. - """ - from pydantic import ConfigDict - - class _OAuth2Credentials(RotatingOAuth2Credentials): - model_config = ConfigDict( - json_schema_extra={"x-oauth2-provider": provider}, - title="OAuth", - ) - - def _you_must_build_oauth2_credentials_for_a_provider(self): ... - - return _OAuth2Credentials class GoogleServiceAccountSpec(BaseModel): @@ -324,11 +256,10 @@ class GoogleServiceAccountSpec(BaseModel): class GoogleServiceAccount(BaseModel): credentials_title: Literal["Google Service Account"] = Field( default="Google Service Account", - json_schema_extra={"type": "string", "order": 0} + json_schema_extra={"type": "string", "order": 0}, ) service_account: str = Field( title="Google Service Account", description="Service account JSON key", json_schema_extra={"secret": True, "multiline": True, "order": 1}, ) - diff --git a/estuary-cdk/estuary_cdk/http.py b/estuary-cdk/estuary_cdk/http.py index ed61170394..1c862c5716 100644 --- a/estuary-cdk/estuary_cdk/http.py +++ b/estuary-cdk/estuary_cdk/http.py @@ -1,39 +1,37 @@ -from dataclasses import dataclass -from logging import Logger -from pydantic import BaseModel -from typing import AsyncGenerator, Any, Awaitable, TypeVar, Callable, Protocol -from multidict import CIMultiDict import abc -import aiohttp import asyncio import base64 import json import time +from dataclasses import dataclass +from logging import Logger +from typing import Any, AsyncGenerator, Awaitable, Callable, Protocol, TypeVar - +import aiohttp from google.auth.credentials import TokenState as GoogleTokenState from google.auth.transport.requests import Request as GoogleAuthRequest from google.oauth2.service_account import Credentials as GoogleServiceAccountCredentials +from multidict import CIMultiDict +from pydantic import BaseModel from . import Mixin from .flow import ( AccessToken, - BasicAuth, - BaseOAuth2Credentials, AuthorizationCodeFlowOAuth2Credentials, + BaseOAuth2Credentials, + BasicAuth, ClientCredentialsOAuth2Credentials, - OAuth2TokenFlowSpec, + GoogleServiceAccount, + GoogleServiceAccountSpec, LongLivedClientCredentialsOAuth2Credentials, + OAuth2ClientCredentialsPlacement, + OAuth2Spec, + OAuth2TokenFlowSpec, ResourceOwnerPasswordOAuth2Credentials, RotatingOAuth2Credentials, - OAuth2Spec, - OAuth2RotatingTokenSpec, - GoogleServiceAccount, - GoogleServiceAccountSpec, ) from .utils import format_error_message - DEFAULT_AUTHORIZATION_HEADER = "Authorization" DEFAULT_AUTHORIZATION_TOKEN_TYPE = "Bearer" @@ -44,6 +42,7 @@ BodyGeneratorFunction = Callable[[], AsyncGenerator[bytes, None]] HeadersAndBodyGenerator = tuple[Headers, BodyGeneratorFunction] + class ShouldRetryProtocol(Protocol): """ ShouldRetryProtocol defines a callback function signature for custom retry logic. @@ -65,12 +64,9 @@ def custom_retry(status: int, headers: Headers, body: bytes, attempt: int) -> bo return False return status >= 500 """ + def __call__( - self, - status: int, - headers: Headers, - body: bytes, - attempt: int + self, status: int, headers: Headers, body: bytes, attempt: int ) -> bool: ... @@ -120,7 +116,15 @@ async def request( chunks: list[bytes] = [] _, body_generator = await self._request_stream( - log, url, method, params, json, form, _with_token, headers, should_retry, + log, + url, + method, + params, + json, + form, + _with_token, + headers, + should_retry, ) async for chunk in body_generator(): @@ -148,7 +152,15 @@ async def request_lines( """Request a url and return its response as streaming lines, as they arrive""" resp_headers, body = await self._request_stream( - log, url, method, params, json, form, True, headers, should_retry, + log, + url, + method, + params, + json, + form, + True, + headers, + should_retry, ) async def gen() -> AsyncGenerator[bytes, None]: @@ -178,8 +190,9 @@ async def request_stream( ) -> tuple[Headers, BodyGeneratorFunction]: """Request a url and and return the raw response as a stream of bytes""" - return await self._request_stream(log, url, method, params, json, form, _with_token, headers, should_retry) - + return await self._request_stream( + log, url, method, params, json, form, _with_token, headers, should_retry + ) @abc.abstractmethod async def _request_stream( @@ -209,7 +222,7 @@ class AccessTokenResponse(BaseModel): refresh_token: str = "" scope: str = "" - oauth_spec: OAuth2Spec | OAuth2TokenFlowSpec | OAuth2RotatingTokenSpec | None + oauth_spec: OAuth2Spec | OAuth2TokenFlowSpec | None credentials: ( BaseOAuth2Credentials | RotatingOAuth2Credentials @@ -228,7 +241,9 @@ class AccessTokenResponse(BaseModel): _fetched_at: int = 0 async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str]: - if isinstance(self.credentials, ( + if isinstance( + self.credentials, + ( AccessToken, LongLivedClientCredentialsOAuth2Credentials, # RotatingOAuth2Credentials are refreshed _only_ at connector startup. @@ -237,7 +252,7 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str # to keep valid tokens in the endpoint config, so we never attempt to # exchange tokens in `fetch_token` for `RotatingOAuth2Credentials`. RotatingOAuth2Credentials, - ) + ), ): return (self.authorization_token_type, self.credentials.access_token) elif isinstance(self.credentials, BasicAuth): @@ -250,9 +265,11 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str elif isinstance(self.credentials, GoogleServiceAccount): assert isinstance(self.google_spec, GoogleServiceAccountSpec) if self._access_token is None: - self._access_token = GoogleServiceAccountCredentials.from_service_account_info( - json.loads(self.credentials.service_account), - scopes=self.google_spec.scopes, + self._access_token = ( + GoogleServiceAccountCredentials.from_service_account_info( + json.loads(self.credentials.service_account), + scopes=self.google_spec.scopes, + ) ) assert isinstance(self._access_token, GoogleServiceAccountCredentials) @@ -263,7 +280,9 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str case GoogleTokenState.STALE | GoogleTokenState.INVALID: self._access_token.refresh(GoogleAuthRequest()) case _: - raise RuntimeError(f"Unknown GoogleTokenState: {self._access_token.token_state}") + raise RuntimeError( + f"Unknown GoogleTokenState: {self._access_token.token_state}" + ) return (self.authorization_token_type, self._access_token.token) @@ -306,7 +325,9 @@ async def initialize_oauth2_tokens( self._fetched_at = int(time.time()) response = await self._fetch_oauth2_token( - log, session, self.credentials, + log, + session, + self.credentials, ) self._access_token = response return response @@ -315,63 +336,46 @@ async def _fetch_oauth2_token( self, log: Logger, session: HTTPSession, - credentials: BaseOAuth2Credentials - | ResourceOwnerPasswordOAuth2Credentials - | ClientCredentialsOAuth2Credentials - | AuthorizationCodeFlowOAuth2Credentials - | RotatingOAuth2Credentials, + credentials: ( + BaseOAuth2Credentials + | ResourceOwnerPasswordOAuth2Credentials + | ClientCredentialsOAuth2Credentials + | AuthorizationCodeFlowOAuth2Credentials + | RotatingOAuth2Credentials + ), ) -> AccessTokenResponse: assert self.oauth_spec - headers = {} - form = {} - - match credentials: - case RotatingOAuth2Credentials(): - assert isinstance(self.oauth_spec, OAuth2RotatingTokenSpec) - form: dict[str, str | int] = { - "grant_type": "refresh_token", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - "refresh_token": credentials.refresh_token, - } - - # Some providers require additional parameters within the form body, like - # an `expires_in` to configure how long the access token remains valid. - if self.oauth_spec.additionalTokenExchangeBody: - form.update(self.oauth_spec.additionalTokenExchangeBody) - - case BaseOAuth2Credentials(): - form = { - "grant_type": "refresh_token", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - "refresh_token": credentials.refresh_token, - } - case ClientCredentialsOAuth2Credentials(): - form = { - "grant_type": "client_credentials", - } - headers = { - "Authorization": "Basic " + headers: dict[str, str | int] = {} + form: dict[str, str | int] = {"grant_type": credentials.grant_type} + + match credentials.client_credentials_placement: + case OAuth2ClientCredentialsPlacement.HEADERS: + headers["Authorization"] = ( + "Basic " + base64.b64encode( f"{credentials.client_id}:{credentials.client_secret}".encode() ).decode() - } - case AuthorizationCodeFlowOAuth2Credentials(): - form = { - "grant_type": "authorization_code", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - } - case ResourceOwnerPasswordOAuth2Credentials(): - form = { - "grant_type": "password", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - } + ) + case OAuth2ClientCredentialsPlacement.FORM: + form.update( + { + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } + ) case _: - raise TypeError(f"Unsupported credentials type: {type(credentials)}.") + raise RuntimeError( + f"Unknown OAuth client credentials placement: {credentials.client_credentials_placement}" + ) + + # Some providers require additional parameters within the form body, like + # an `expires_in` to configure how long the access token remains valid. + form.update(self.oauth_spec.additionalTokenExchangeBody) + + if isinstance(credentials, BaseOAuth2Credentials): + assert isinstance(self.oauth_spec, OAuth2Spec) + form["refresh_token"] = credentials.refresh_token response = await session.request( log, @@ -404,7 +408,7 @@ class RateLimiter: """ delay: float = 1.0 - MAX_DELAY: float = 300.0 # 5 minutes + MAX_DELAY: float = 300.0 # 5 minutes gain: float = 0.01 failed: int = 0 @@ -491,19 +495,24 @@ async def _retry_on_connection_error( try: return await operation() except ( - asyncio.TimeoutError, # Connection timeouts - aiohttp.ClientConnectorError, # DNS, SSL handshake, connection refused errors - aiohttp.ClientConnectorDNSError, # DNS resolution failures - aiohttp.ConnectionTimeoutError, # aiohttp connection timeouts (sock_connect, connect) - ConnectionResetError, # TCP connection reset - aiohttp.ClientOSError, # OS errors (like BrokenPipeError) during request sending + asyncio.TimeoutError, # Connection timeouts + aiohttp.ClientConnectorError, # DNS, SSL handshake, connection refused errors + aiohttp.ClientConnectorDNSError, # DNS resolution failures + aiohttp.ConnectionTimeoutError, # aiohttp connection timeouts (sock_connect, connect) + ConnectionResetError, # TCP connection reset + aiohttp.ClientOSError, # OS errors (like BrokenPipeError) during request sending aiohttp.ClientConnectionResetError, # Connection reset errors - aiohttp.ServerDisconnectedError, # Server disconnections + aiohttp.ServerDisconnectedError, # Server disconnections ) as e: if attempt <= max_attempts: log.warning( f"Connection error occurred while establishing connection (will retry)", - {"url": url, "method": method, "attempt": attempt, "error": format_error_message(e)} + { + "url": url, + "method": method, + "attempt": attempt, + "error": format_error_message(e), + }, ) attempt += 1 else: @@ -531,10 +540,19 @@ async def _request_stream( attempt += 1 resp = await self._retry_on_connection_error( - log, url, method, + log, + url, + method, lambda: self._establish_connection_and_get_response( - log, url, method, params, json, form, _with_token, headers, - ) + log, + url, + method, + params, + json, + form, + _with_token, + headers, + ), ) should_release_response = True @@ -555,9 +573,8 @@ async def _request_stream( elif resp.status >= 500 and resp.status < 600: body = await resp.read() - if ( - should_retry is None - or should_retry(resp.status, resp.headers, body, attempt) + if should_retry is None or should_retry( + resp.status, resp.headers, body, attempt ): log.warning( "server internal error (will retry)", @@ -591,5 +608,3 @@ async def body_generator() -> AsyncGenerator[bytes, None]: finally: if should_release_response: await resp.release() - - diff --git a/source-facebook-marketing-native/tests/snapshots/snapshots__spec__capture.stdout.json b/source-facebook-marketing-native/tests/snapshots/snapshots__spec__capture.stdout.json index 2f815d75f5..6be137d6b5 100644 --- a/source-facebook-marketing-native/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-facebook-marketing-native/tests/snapshots/snapshots__spec__capture.stdout.json @@ -108,7 +108,7 @@ "title": "InsightsConfig", "type": "object" }, - "_OAuth2Credentials": { + "LongLivedClientCredentialsOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -170,13 +170,13 @@ "credentials": { "discriminator": { "mapping": { - "OAuth Credentials": "#/$defs/_OAuth2Credentials" + "OAuth Credentials": "#/$defs/LongLivedClientCredentialsOAuth2Credentials" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/LongLivedClientCredentialsOAuth2Credentials" } ], "title": "Authentication" diff --git a/source-genesys/source_genesys/models.py b/source-genesys/source_genesys/models.py index f4e8e38327..bb53ee340f 100644 --- a/source-genesys/source_genesys/models.py +++ b/source-genesys/source_genesys/models.py @@ -1,17 +1,20 @@ from datetime import datetime, timezone, timedelta from enum import StrEnum from pydantic import AwareDatetime, BaseModel, Field -from typing import ClassVar, Literal, Optional +from typing import TYPE_CHECKING, ClassVar, Literal, Optional from estuary_cdk.capture.common import ( BaseDocument, ConnectorState as GenericConnectorState, - ClientCredentialsOAuth2Credentials, - OAuth2TokenFlowSpec, ResourceConfig, ResourceState, ) +from estuary_cdk.flow import ( + ClientCredentialsOAuth2Credentials, + OAuth2ClientCredentialsPlacement, + OAuth2TokenFlowSpec, +) OAUTH2_SPEC = OAuth2TokenFlowSpec( @@ -23,6 +26,14 @@ } ) +if TYPE_CHECKING: + OAuth2Credentials = ClientCredentialsOAuth2Credentials +else: + OAuth2Credentials = ( + ClientCredentialsOAuth2Credentials.with_client_credentials_placement( + OAuth2ClientCredentialsPlacement.HEADERS + ) + ) def default_start_date(): dt = datetime.now(timezone.utc) - timedelta(days=30) @@ -54,7 +65,7 @@ class EndpointConfig(BaseModel): ] = Field( title="Genesys Cloud Domain" ) - credentials: ClientCredentialsOAuth2Credentials = Field( + credentials: OAuth2Credentials = Field( title="Authentication", discriminator="credentials_title" ) diff --git a/source-google-analytics-data-api-native/tests/snapshots/snapshots__spec__capture.stdout.json b/source-google-analytics-data-api-native/tests/snapshots/snapshots__spec__capture.stdout.json index 7b7b85c45d..9addfc1bf0 100644 --- a/source-google-analytics-data-api-native/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-google-analytics-data-api-native/tests/snapshots/snapshots__spec__capture.stdout.json @@ -16,7 +16,7 @@ "title": "Advanced", "type": "object" }, - "_OAuth2Credentials": { + "BaseOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -71,13 +71,13 @@ "credentials": { "discriminator": { "mapping": { - "OAuth Credentials": "#/$defs/_OAuth2Credentials" + "OAuth Credentials": "#/$defs/BaseOAuth2Credentials" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/BaseOAuth2Credentials" } ], "title": "Authentication" diff --git a/source-google-sheets-native/tests/snapshots/snapshots__spec__stdout.json b/source-google-sheets-native/tests/snapshots/snapshots__spec__stdout.json index 33466c8a2f..141d86285c 100644 --- a/source-google-sheets-native/tests/snapshots/snapshots__spec__stdout.json +++ b/source-google-sheets-native/tests/snapshots/snapshots__spec__stdout.json @@ -23,31 +23,7 @@ "title": "AccessToken", "type": "object" }, - "GoogleServiceAccount": { - "properties": { - "credentials_title": { - "const": "Google Service Account", - "default": "Google Service Account", - "order": 0, - "title": "Credentials Title", - "type": "string" - }, - "service_account": { - "description": "Service account JSON key", - "multiline": true, - "order": 1, - "secret": true, - "title": "Google Service Account", - "type": "string" - } - }, - "required": [ - "service_account" - ], - "title": "GoogleServiceAccount", - "type": "object" - }, - "_OAuth2Credentials": { + "BaseOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -79,6 +55,30 @@ "title": "OAuth", "type": "object", "x-oauth2-provider": "google" + }, + "GoogleServiceAccount": { + "properties": { + "credentials_title": { + "const": "Google Service Account", + "default": "Google Service Account", + "order": 0, + "title": "Credentials Title", + "type": "string" + }, + "service_account": { + "description": "Service account JSON key", + "multiline": true, + "order": 1, + "secret": true, + "title": "Google Service Account", + "type": "string" + } + }, + "required": [ + "service_account" + ], + "title": "GoogleServiceAccount", + "type": "object" } }, "properties": { @@ -86,14 +86,14 @@ "discriminator": { "mapping": { "Google Service Account": "#/$defs/GoogleServiceAccount", - "OAuth Credentials": "#/$defs/_OAuth2Credentials", + "OAuth Credentials": "#/$defs/BaseOAuth2Credentials", "Private App Credentials": "#/$defs/AccessToken" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/BaseOAuth2Credentials" }, { "$ref": "#/$defs/GoogleServiceAccount" diff --git a/source-hubspot-native/tests/snapshots/snapshots__spec__stdout.json b/source-hubspot-native/tests/snapshots/snapshots__spec__stdout.json index b692fb03be..94928ddcf7 100644 --- a/source-hubspot-native/tests/snapshots/snapshots__spec__stdout.json +++ b/source-hubspot-native/tests/snapshots/snapshots__spec__stdout.json @@ -23,7 +23,7 @@ "title": "AccessToken", "type": "object" }, - "_OAuth2Credentials": { + "BaseOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -61,14 +61,14 @@ "credentials": { "discriminator": { "mapping": { - "OAuth Credentials": "#/$defs/_OAuth2Credentials", + "OAuth Credentials": "#/$defs/BaseOAuth2Credentials", "Private App Credentials": "#/$defs/AccessToken" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/BaseOAuth2Credentials" } ], "title": "Authentication" diff --git a/source-intercom-native/tests/snapshots/snapshots__spec__capture.stdout.json b/source-intercom-native/tests/snapshots/snapshots__spec__capture.stdout.json index fec37b6e8e..9eb12587f5 100644 --- a/source-intercom-native/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-intercom-native/tests/snapshots/snapshots__spec__capture.stdout.json @@ -50,7 +50,7 @@ "title": "Advanced", "type": "object" }, - "_OAuth2Credentials": { + "LongLivedClientCredentialsOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -88,14 +88,14 @@ "credentials": { "discriminator": { "mapping": { - "OAuth Credentials": "#/$defs/_OAuth2Credentials", + "OAuth Credentials": "#/$defs/LongLivedClientCredentialsOAuth2Credentials", "Private App Credentials": "#/$defs/AccessToken" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/LongLivedClientCredentialsOAuth2Credentials" }, { "$ref": "#/$defs/AccessToken" diff --git a/source-looker/tests/snapshots/snapshots__spec__capture.stdout.json b/source-looker/tests/snapshots/snapshots__spec__capture.stdout.json index 3b08b2f881..f1bcce68a1 100644 --- a/source-looker/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-looker/tests/snapshots/snapshots__spec__capture.stdout.json @@ -26,7 +26,7 @@ "client_id", "client_secret" ], - "title": "OAuth2", + "title": "OAuth", "type": "object" } }, diff --git a/source-navan/source_navan/models.py b/source-navan/source_navan/models.py index 314125ee02..a7d7635fad 100644 --- a/source-navan/source_navan/models.py +++ b/source-navan/source_navan/models.py @@ -1,4 +1,5 @@ from typing import ( + TYPE_CHECKING, Generic, TypeVar, ) @@ -12,6 +13,7 @@ ) from estuary_cdk.flow import ( ClientCredentialsOAuth2Credentials, + OAuth2ClientCredentialsPlacement, OAuth2TokenFlowSpec, ) from pydantic import ( @@ -28,9 +30,17 @@ accessTokenResponseMap={"access_token": "/access_token"}, ) +if TYPE_CHECKING: + OAuth2Credentials = ClientCredentialsOAuth2Credentials +else: + OAuth2Credentials = ( + ClientCredentialsOAuth2Credentials.with_client_credentials_placement( + OAuth2ClientCredentialsPlacement.HEADERS + ) + ) class EndpointConfig(BaseModel): - credentials: ClientCredentialsOAuth2Credentials = Field( + credentials: OAuth2Credentials = Field( title="Authentication", description="See https://app.navan.com/app/helpcenter/articles/travel/admin/other-integrations/booking-data-integration", ) diff --git a/source-outreach/source_outreach/models.py b/source-outreach/source_outreach/models.py index d0fcf8ccfc..08b4003419 100644 --- a/source-outreach/source_outreach/models.py +++ b/source-outreach/source_outreach/models.py @@ -7,7 +7,7 @@ from estuary_cdk.capture.common import ( BaseDocument, RotatingOAuth2Credentials, - OAuth2RotatingTokenSpec, + OAuth2Spec, ResourceConfig, ResourceState, ) @@ -91,7 +91,7 @@ def urlencode_field(field: str): } -OAUTH2_SPEC = OAuth2RotatingTokenSpec( +OAUTH2_SPEC = OAuth2Spec( provider="outreach", accessTokenBody=json.dumps(accessTokenBody), authUrlTemplate=( @@ -111,7 +111,6 @@ def urlencode_field(field: str): accessTokenHeaders={ "Content-Type": "application/json", }, - additionalTokenExchangeBody=None, ) diff --git a/source-outreach/tests/snapshots/snapshots__spec__capture.stdout.json b/source-outreach/tests/snapshots/snapshots__spec__capture.stdout.json index 2ae6427e32..ea736e00c4 100644 --- a/source-outreach/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-outreach/tests/snapshots/snapshots__spec__capture.stdout.json @@ -3,7 +3,7 @@ "protocol": 3032023, "configSchema": { "$defs": { - "_OAuth2Credentials": { + "RotatingOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -59,13 +59,13 @@ "credentials": { "discriminator": { "mapping": { - "OAuth Credentials": "#/$defs/_OAuth2Credentials" + "OAuth Credentials": "#/$defs/RotatingOAuth2Credentials" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/RotatingOAuth2Credentials" } ], "title": "Authentication" diff --git a/source-shopify-native/tests/snapshots/snapshots__spec__capture.stdout.json b/source-shopify-native/tests/snapshots/snapshots__spec__capture.stdout.json index 3ca9068451..841c32389a 100644 --- a/source-shopify-native/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-shopify-native/tests/snapshots/snapshots__spec__capture.stdout.json @@ -36,7 +36,7 @@ "title": "Advanced", "type": "object" }, - "_OAuth2Credentials": { + "LongLivedClientCredentialsOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -79,7 +79,7 @@ "credentials": { "discriminator": { "mapping": { - "OAuth Credentials": "#/$defs/_OAuth2Credentials", + "OAuth Credentials": "#/$defs/LongLivedClientCredentialsOAuth2Credentials", "Private App Credentials": "#/$defs/AccessToken" }, "propertyName": "credentials_title" @@ -89,7 +89,7 @@ "$ref": "#/$defs/AccessToken" }, { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/LongLivedClientCredentialsOAuth2Credentials" } ], "title": "Authentication" diff --git a/source-zendesk-support-native/source_zendesk_support_native/models.py b/source-zendesk-support-native/source_zendesk_support_native/models.py index 7fa0293986..311afea812 100644 --- a/source-zendesk-support-native/source_zendesk_support_native/models.py +++ b/source-zendesk-support-native/source_zendesk_support_native/models.py @@ -8,7 +8,7 @@ BaseDocument, LongLivedClientCredentialsOAuth2Credentials, RotatingOAuth2Credentials, - OAuth2RotatingTokenSpec, + OAuth2Spec, ResourceConfig, ResourceState, ) @@ -47,7 +47,7 @@ def urlencode_field(field: str): "expires_in": MAX_VALID_ACCESS_TOKEN_DURATION, } -OAUTH2_SPEC = OAuth2RotatingTokenSpec( +OAUTH2_SPEC = OAuth2Spec( provider="zendesk", accessTokenBody=json.dumps(accessTokenBody), authUrlTemplate=( diff --git a/source-zendesk-support-native/tests/snapshots/snapshots__capture__capture.stdout.json b/source-zendesk-support-native/tests/snapshots/snapshots__capture__capture.stdout.json index a3b28346e4..2470550394 100644 --- a/source-zendesk-support-native/tests/snapshots/snapshots__capture__capture.stdout.json +++ b/source-zendesk-support-native/tests/snapshots/snapshots__capture__capture.stdout.json @@ -844,6 +844,10 @@ { "id": 39405384709780, "value": null + }, + { + "id": 43010621832468, + "value": null } ], "custom_status_id": 28835714727188, @@ -861,6 +865,10 @@ { "id": 39405384709780, "value": null + }, + { + "id": 43010621832468, + "value": null } ], "follower_ids": [], @@ -885,6 +893,7 @@ "status": "open", "subject": "SAMPLE TICKET: Meet the ticket", "submitter_id": 28835702984212, + "support_type": "agent", "tags": [ "sample", "support", @@ -929,6 +938,10 @@ { "id": 39405384709780, "value": null + }, + { + "id": 43010621832468, + "value": null } ], "custom_status_id": 28835714727188, @@ -945,6 +958,10 @@ { "id": 39405384709780, "value": null + }, + { + "id": 43010621832468, + "value": null } ], "follower_ids": [], @@ -969,6 +986,7 @@ "status": "open", "subject": "SAMPLE TICKET: Meet the ticket", "submitter_id": 28835702984212, + "support_type": "agent", "tags": [ "sample", "support", diff --git a/source-zendesk-support-native/tests/snapshots/snapshots__spec__capture.stdout.json b/source-zendesk-support-native/tests/snapshots/snapshots__spec__capture.stdout.json index 8e676de72b..3b35e09bd6 100644 --- a/source-zendesk-support-native/tests/snapshots/snapshots__spec__capture.stdout.json +++ b/source-zendesk-support-native/tests/snapshots/snapshots__spec__capture.stdout.json @@ -75,7 +75,7 @@ "type": "object", "x-oauth2-provider": "zendesk" }, - "_OAuth2Credentials": { + "RotatingOAuth2Credentials": { "properties": { "credentials_title": { "const": "OAuth Credentials", @@ -139,13 +139,13 @@ "mapping": { "Deprecated OAuth Credentials": "#/$defs/DeprecatedOAuthCredentials", "Email & API Token": "#/$defs/ApiToken", - "OAuth Credentials": "#/$defs/_OAuth2Credentials" + "OAuth Credentials": "#/$defs/RotatingOAuth2Credentials" }, "propertyName": "credentials_title" }, "oneOf": [ { - "$ref": "#/$defs/_OAuth2Credentials" + "$ref": "#/$defs/RotatingOAuth2Credentials" }, { "$ref": "#/$defs/ApiToken"