Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions estuary-cdk/estuary_cdk/capture/base_capture_connector.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
from datetime import datetime, timedelta, UTC
from estuary_cdk.capture.connector_status import ConnectorStatus
from pydantic import BaseModel
from typing import Generic, Awaitable, Any, BinaryIO, Callable
from logging import Logger
import abc
import asyncio
import json
import sys
from datetime import UTC, datetime, timedelta
from typing import Any, Awaitable, BinaryIO, Callable, Generic

from pydantic import BaseModel

from estuary_cdk.capture.connector_status import ConnectorStatus

from . import request, response, Request, Response, Task
from .. import BaseConnector, Stopped
from .common import ConnectorState, _ConnectorState
from ..flow import (
ConnectorSpec,
ConnectorState as GeneralConnectorState,
ConnectorStateUpdate,
EndpointConfig,
ResourceConfig,
RotatingOAuth2Credentials,
)
from ..http import HTTPError, HTTPMixin, TokenSource
from ..flow import (
ConnectorState as GeneralConnectorState,
)
from ..http import HTTPMixin, TokenSource
from ..logger import FlowLogger
from ..utils import format_error_message, sort_dict
from . import Request, Response, Task, request, response
from .common import _ConnectorState


class BaseCaptureConnector(
BaseConnector[Request[EndpointConfig, ResourceConfig, _ConnectorState]],
Expand Down Expand Up @@ -99,7 +103,9 @@ async def periodic_stop() -> None:
stopping.event.set()

# Rotate OAuth2 tokens for credentials with periodically expiring tokens.
if isinstance(self.token_source, TokenSource) and isinstance(self.token_source.credentials, RotatingOAuth2Credentials):
if isinstance(self.token_source, TokenSource) and isinstance(
self.token_source.credentials, RotatingOAuth2Credentials
):
await self._rotate_oauth2_tokens(log, open)

# Gracefully exit after a moderate period of time.
Expand All @@ -124,9 +130,7 @@ async def periodic_stop() -> None:
if stopping.first_error:
msg = format_error_message(stopping.first_error)

raise Stopped(
f"Task {stopping.first_error_task}: {msg}"
)
raise Stopped(f"Task {stopping.first_error_task}: {msg}")
else:
raise Stopped(None)

Expand All @@ -136,18 +140,16 @@ async def periodic_stop() -> None:
else:
raise RuntimeError("malformed request", request)

def _emit(self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState]):
def _emit(
self, response: Response[EndpointConfig, ResourceConfig, GeneralConnectorState]
):
self.output.write(
response.model_dump_json(by_alias=True, exclude_unset=True).encode()
)
self.output.write(b"\n")
self.output.flush()

def _checkpoint(
self,
state: GeneralConnectorState,
merge_patch: bool = True
):
def _checkpoint(self, state: GeneralConnectorState, merge_patch: bool = True):
r = Response[Any, Any, GeneralConnectorState](
checkpoint=response.Checkpoint(
state=ConnectorStateUpdate(updated=state, mergePatch=merge_patch)
Expand All @@ -168,7 +170,9 @@ async def _encrypt_config(
# include=config.model_fields_set ensures only the fields that are explicitly set on the
# model. This means any default values that are set are included, but any fields that are unset
# & fallback to some default value are left unset.
unencrypted_config = config.model_dump(mode="json", include=config.model_fields_set)
unencrypted_config = config.model_dump(
mode="json", include=config.model_fields_set
)

body = {
# Flow always sorts object properties lexicographically. This keeps a consistent property
Expand All @@ -179,9 +183,11 @@ async def _encrypt_config(
"schema": config.model_json_schema(),
}

encrypted_config = await self.request(log, ENCRYPTION_URL, "POST", json=body, _with_token = False)
encrypted_config = await self.request(
log, ENCRYPTION_URL, "POST", json=body, _with_token=False
)

return json.loads(encrypted_config.decode('utf-8'))
return json.loads(encrypted_config.decode("utf-8"))

async def _rotate_oauth2_tokens(
self,
Expand All @@ -198,14 +204,20 @@ async def _rotate_oauth2_tokens(
if access_token_expiration < now + timedelta(minutes=5):
# Exchange for new access token & refresh token.
log.info("Attempting to rotate OAuth2 tokens.")
token_exchange_response = await self.token_source.initialize_oauth2_tokens(log, self)
token_exchange_response = await self.token_source.initialize_oauth2_tokens(
log, self
)

# Replace tokens in the config.
credentials = self.token_source.credentials
credentials.access_token = token_exchange_response.access_token
credentials.refresh_token = token_exchange_response.refresh_token
credentials.access_token_expires_at = now + timedelta(seconds=token_exchange_response.expires_in)
credentials.access_token_expires_at = now + timedelta(
seconds=token_exchange_response.expires_in
)

# Encrypt the updated config and emit an event telling the control plane to publish a new spec.
encrypted_config = await self._encrypt_config(log, open.capture.config)
log.event.config_update("Rotating OAuth2 tokens in endpoint config.", encrypted_config)
log.event.config_update(
"Rotating OAuth2 tokens in endpoint config.", encrypted_config
)
1 change: 0 additions & 1 deletion estuary-cdk/estuary_cdk/capture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
CaptureBinding,
ClientCredentialsOAuth2Credentials,
OAuth2TokenFlowSpec,
OAuth2RotatingTokenSpec,
AuthorizationCodeFlowOAuth2Credentials,
LongLivedClientCredentialsOAuth2Credentials,
ResourceOwnerPasswordOAuth2Credentials,
Expand Down
Loading
Loading