diff --git a/packages/atproto_client/client/client.py b/packages/atproto_client/client/client.py index 95048079..b1738730 100644 --- a/packages/atproto_client/client/client.py +++ b/packages/atproto_client/client/client.py @@ -6,6 +6,7 @@ from atproto_client import models from atproto_client.client.methods_mixin import SessionMethodsMixin, TimeMethodsMixin from atproto_client.client.methods_mixin.headers import HeadersConfigurationMethodsMixin +from atproto_client.client.methods_mixin.oauth import OauthSessionMethodsMixin from atproto_client.client.methods_mixin.session import SessionDispatchMixin from atproto_client.client.raw import ClientRaw from atproto_client.client.session import Session, SessionEvent, SessionResponse @@ -18,7 +19,13 @@ from atproto_client.request import Response -class Client(SessionDispatchMixin, SessionMethodsMixin, TimeMethodsMixin, HeadersConfigurationMethodsMixin, ClientRaw): +class Client(OauthSessionMethodsMixin, + SessionDispatchMixin, + SessionMethodsMixin, + TimeMethodsMixin, + HeadersConfigurationMethodsMixin, + ClientRaw + ): """High-level client for XRPC of ATProto.""" def __init__(self, base_url: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any) -> None: diff --git a/packages/atproto_client/client/methods_mixin/dpop.py b/packages/atproto_client/client/methods_mixin/dpop.py new file mode 100644 index 00000000..0fc1e472 --- /dev/null +++ b/packages/atproto_client/client/methods_mixin/dpop.py @@ -0,0 +1,229 @@ +"""DPoP (Demonstrating Proof-of-Possession) implementation.""" + +import contextlib +import hashlib +import json +import secrets +import time +import typing as t +from base64 import urlsafe_b64encode +from urllib.parse import urlparse + +import httpx +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature + +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey + + +class DPoPManager: + """Manages DPoP proof generation for OAuth.""" + + @staticmethod + def generate_keypair() -> 'EllipticCurvePrivateKey': + """Generate ES256 keypair for DPoP. + + Returns: + EC private key (P-256 curve). + """ + return ec.generate_private_key(ec.SECP256R1()) + + @staticmethod + def _key_to_jwk(private_key: 'EllipticCurvePrivateKey', include_private: bool = False) -> t.Dict[str, t.Any]: + """Convert EC private key to JWK format. + + Args: + private_key: The EC private key. + include_private: Whether to include private key components. + + Returns: + JWK dictionary. + """ + public_key = private_key.public_key() + public_numbers = public_key.public_numbers() + + # Convert to bytes and base64url encode + def int_to_base64url(n: int, length: int) -> str: + byte_len = (length + 7) // 8 + return urlsafe_b64encode(n.to_bytes(byte_len, 'big')).decode('utf-8').rstrip('=') + + jwk = { + 'kty': 'EC', + 'crv': 'P-256', + 'x': int_to_base64url(public_numbers.x, 256), + 'y': int_to_base64url(public_numbers.y, 256), + } + + if include_private: + private_numbers = private_key.private_numbers() + jwk['d'] = int_to_base64url(private_numbers.private_value, 256) + + return jwk + + @staticmethod + def _sign_jwt( + header: t.Dict[str, t.Any], payload: t.Dict[str, t.Any], private_key: 'EllipticCurvePrivateKey' + ) -> str: + """Sign a JWT using ES256. + + Args: + header: JWT header. + payload: JWT payload. + private_key: EC private key for signing. + + Returns: + Complete JWT string. + """ + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + + # Encode header and payload + header_b64 = urlsafe_b64encode(json.dumps(header, separators=(',', ':')).encode()).decode().rstrip('=') + payload_b64 = urlsafe_b64encode(json.dumps(payload, separators=(',', ':')).encode()).decode().rstrip('=') + + # Create signing input + signing_input = f'{header_b64}.{payload_b64}'.encode() + + # Sign (returns DER-encoded signature) + der_signature = private_key.sign(signing_input, ec.ECDSA(hashes.SHA256())) + + # Convert DER signature to IEEE P1363 format (raw r|s concatenated) + # ES256 uses P-256 curve, so r and s are each 32 bytes + r, s = decode_dss_signature(der_signature) + + # Convert r and s to 32-byte big-endian sequences + r_bytes = r.to_bytes(32, 'big') + s_bytes = s.to_bytes(32, 'big') + + # Concatenate and encode + raw_signature = r_bytes + s_bytes + signature_b64 = urlsafe_b64encode(raw_signature).decode().rstrip('=') + + return f'{header_b64}.{payload_b64}.{signature_b64}' + + @classmethod + def create_proof( + cls, + method: str, + url: str, + private_key: 'EllipticCurvePrivateKey', + nonce: t.Optional[str] = None, + access_token: t.Optional[str] = None, + ) -> str: + """Generate DPoP proof JWT. + + Args: + method: HTTP method (e.g., 'GET', 'POST'). + url: Full URL of the request. + private_key: EC private key for signing. + nonce: Optional server-provided nonce. + access_token: Optional access token (for 'ath' claim). + + Returns: + DPoP proof JWT string. + """ + # Get public key JWK + public_jwk = cls._key_to_jwk(private_key, include_private=False) + + # Create header + header = { + 'typ': 'dpop+jwt', + 'alg': 'ES256', + 'jwk': public_jwk, + } + + # Strip query and fragment from URL per RFC 9449 + parsed_url = urlparse(url) + htu = f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}' + + # Create payload + now = int(time.time()) + payload = { + 'jti': secrets.token_urlsafe(16), + 'htm': method.upper(), + 'htu': htu, + 'iat': now, + 'exp': now + 60, # Valid for 60 seconds + } + + # Add optional claims + if nonce: + payload['nonce'] = nonce + + if access_token: + # Hash access token for 'ath' claim (same as PKCE S256) + ath_hash = hashlib.sha256(access_token.encode('utf-8')).digest() + payload['ath'] = urlsafe_b64encode(ath_hash).decode('utf-8').rstrip('=') + + return cls._sign_jwt(header, payload, private_key) + + @staticmethod + def extract_nonce_from_response(response: t.Union[httpx.Response, t.Any]) -> t.Optional[str]: + """Extract DPoP nonce from HTTP response. + + Checks both the 'DPoP-Nonce' header and error responses. + + Args: + response: HTTP response object (httpx.Response or atproto Response). + + Returns: + DPoP nonce string if present, None otherwise. + """ + # Handle both httpx.Response and wrapped Response objects + headers = response.headers + # Try both cases for header name (httpx is case-insensitive, dict is not) + nonce = headers.get('DPoP-Nonce') or headers.get('dpop-nonce') + if nonce: + return nonce + + # Check for error response with use_dpop_nonce + if response.status_code in (400, 401): + with contextlib.suppress(json.JSONDecodeError, AttributeError, TypeError): + # Handle both httpx.Response (.json()) and wrapped Response (.content) + if hasattr(response, 'json'): + error_body = response.json() + else: + error_body = response.content + if hasattr(error_body, 'error'): + # XrpcError object + if error_body.error == 'use_dpop_nonce': + return headers.get('DPoP-Nonce') or headers.get('dpop-nonce') + elif isinstance(error_body, dict) and error_body.get('error') == 'use_dpop_nonce': + return headers.get('DPoP-Nonce') or headers.get('dpop-nonce') + + return None + + @staticmethod + def is_dpop_nonce_error(response: t.Union[httpx.Response, t.Any]) -> bool: + """Check if response indicates DPoP nonce error. + + Args: + response: HTTP response object (httpx.Response or atproto Response). + + Returns: + True if response indicates need for new DPoP nonce. + """ + if response.status_code not in (400, 401): + return False + + headers = response.headers + # Check WWW-Authenticate header (try both cases) + www_auth = headers.get('WWW-Authenticate', '') or headers.get('www-authenticate', '') + if www_auth and 'use_dpop_nonce' in www_auth.lower(): + return True + + # Check error response + with contextlib.suppress(json.JSONDecodeError, AttributeError, TypeError): + # Handle both httpx.Response (.json()) and wrapped Response (.content) + if hasattr(response, 'json'): + error_body = response.json() + else: + error_body = response.content + if hasattr(error_body, 'error'): + # XrpcError object + return error_body.error == 'use_dpop_nonce' + if isinstance(error_body, dict) and error_body.get('error') == 'use_dpop_nonce': + return True + + return False diff --git a/packages/atproto_client/client/methods_mixin/oauth.py b/packages/atproto_client/client/methods_mixin/oauth.py new file mode 100644 index 00000000..b70d5cc9 --- /dev/null +++ b/packages/atproto_client/client/methods_mixin/oauth.py @@ -0,0 +1,579 @@ +import contextlib +import logging +import secrets +import time +import typing as t +from urllib.parse import urlencode + +import httpx +from atproto_identity.resolver import IdResolver + +from atproto_client.client.base import InvokeType + +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey + + from atproto_client.request import Response + +from atproto_client.client.methods_mixin.dpop import DPoPManager +from atproto_client.client.methods_mixin.pkce import PKCEManager +from atproto_client.client.oauth_session import ( + AuthServerMetadata, + OAuthSession, + OAuthState, + TokenResponse, + discover_authserver_from_pds, + fetch_authserver_metadata, + is_safe_url, +) +from atproto_client.client.session import SessionEvent +from atproto_client.exceptions import OAuthStateError, OAuthTokenError + +logger = logging.getLogger(__name__) + + + + +class OauthSessionMethodsMixin: + def __init__(self, *args: t.Any, + client_id: t.Optional[str] = None, + redirect_uri: t.Optional[str] = None, + scope: t.Optional[str] = None, + # state_store: StateStore, + # session_store: SessionStore, + client_secret_key: t.Optional['EllipticCurvePrivateKey'] = None, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + + # OAuth configuration (optional - only needed for OAuth flows) + self._oauth_client_id = client_id + self._oauth_redirect_uri = redirect_uri + self._oauth_scope = scope + self._oauth_client_secret_key = client_secret_key + + # Lazy-initialize OAuth components only when OAuth is configured + self._oauth_initialized = False + self._id_resolver: t.Optional[IdResolver] = None + self._dpop: t.Optional[DPoPManager] = None + self._pkce: t.Optional[PKCEManager] = None + + self._oauth_session: t.Optional[OAuthSession] = None + + def _ensure_oauth_initialized(self) -> None: + """Initialize OAuth components on first use.""" + if self._oauth_initialized: + return + + if not self._oauth_client_id or not self._oauth_redirect_uri or not self._oauth_scope: + raise ValueError('OAuth not configured. Provide client_id, redirect_uri, and scope.') + + self._id_resolver = IdResolver() + self._dpop = DPoPManager() + self._pkce = PKCEManager() + self._oauth_initialized = True + + def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response': + """Override _invoke to handle OAuth sessions with DPoP.""" + from atproto_client.client.base import _handle_kwargs + + # Process kwargs the same way as base client (removes input/output_encoding, etc.) + _handle_kwargs(kwargs) + + # Non-OAuth requests use normal flow + if not self._oauth_session or not self._is_oauth_session(): + if invoke_type is InvokeType.QUERY: + return self.request.get(**kwargs) + return self.request.post(**kwargs) + + # OAuth requests - add DPoP headers with nonce retry + self._ensure_oauth_initialized() + url = kwargs.get('url', '') + headers = kwargs.pop('headers', {}) + + # Use PDS nonce for PDS requests (separate from auth server nonce) + from atproto_client.exceptions import UnauthorizedError + + current_nonce = self._oauth_session.dpop_pds_nonce or '' + for attempt in range(2): + logger.info(f"DPoP request attempt {attempt}, nonce: {current_nonce[:20] if current_nonce else 'None'}...") + dpop_proof = self._dpop.create_proof( + method='GET' if invoke_type is InvokeType.QUERY else 'POST', + url=url, + private_key=self._oauth_session.dpop_private_key, + nonce=current_nonce if current_nonce else None, + access_token=self._oauth_session.access_jwt, + ) + + headers['Authorization'] = f'DPoP {self._oauth_session.access_jwt}' + headers['DPoP'] = dpop_proof + + try: + if invoke_type is InvokeType.QUERY: + response = self.request.get(headers=headers, **kwargs) + else: + response = self.request.post(headers=headers, **kwargs) + except UnauthorizedError as e: + # Check if it's a DPoP nonce error that we can retry + response = e.response + logger.info(f'DPoP caught UnauthorizedError, status: {response.status_code}') + + is_nonce_error = self._dpop.is_dpop_nonce_error(response) + logger.info(f'is_nonce_error: {is_nonce_error}') + + if is_nonce_error: + new_nonce = self._dpop.extract_nonce_from_response(response) + logger.info(f"Extracted nonce: {new_nonce[:20] if new_nonce else 'None'}...") + if new_nonce and attempt == 0: + current_nonce = new_nonce + continue + # Not a nonce error or can't retry - re-raise + raise + + is_nonce_error = self._dpop.is_dpop_nonce_error(response) + if is_nonce_error: + new_nonce = self._dpop.extract_nonce_from_response(response) + if new_nonce and attempt == 0: + current_nonce = new_nonce + continue + + # Update stored PDS nonce + self._oauth_session.dpop_pds_nonce = ( + self._dpop.extract_nonce_from_response(response) or current_nonce + ) + return response + + return response + + def _make_token_request( + self, + token_url: str, + params: t.Dict[str, str], + dpop_key: 'EllipticCurvePrivateKey', + dpop_nonce: str, + ) -> t.Tuple[str, httpx.Response]: + """Make token request with DPoP and client assertion. + + Handles DPoP nonce rotation automatically. + + Returns: + Tuple of (updated_dpop_nonce, response). + """ + self._ensure_oauth_initialized() + if not is_safe_url(token_url): + raise ValueError(f'Unsafe token URL: {token_url}') + + # Add client authentication + if self._oauth_client_secret_key: + # Confidential client - use client assertion + client_assertion = self._create_client_assertion(token_url) + params['client_id'] = self._oauth_client_id + params['client_assertion_type'] = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer' + params['client_assertion'] = client_assertion + else: + # Public client + params['client_id'] = self._oauth_client_id + + # Try request with DPoP nonce retry + current_nonce = dpop_nonce + for attempt in range(2): + # Create DPoP proof + dpop_proof = self._dpop.create_proof( + method='POST', + url=token_url, + private_key=dpop_key, + nonce=current_nonce if current_nonce else None, + ) + + # Make request + with httpx.Client() as client: + response = client.post( + token_url, + data=params, + headers={'DPoP': dpop_proof}, + ) + + # Check for DPoP nonce error + if self._dpop.is_dpop_nonce_error(response): + new_nonce = self._dpop.extract_nonce_from_response(response) + if new_nonce and attempt == 0: + current_nonce = new_nonce + continue # Retry with new nonce + + # Extract final nonce + final_nonce = self._dpop.extract_nonce_from_response(response) or current_nonce + + return final_nonce, response + + return current_nonce, response + + def _create_client_assertion(self, audience: str) -> str: + """Create client assertion JWT for confidential client.""" + if not self._oauth_client_secret_key: + raise ValueError('Client secret key required for client assertion') + + header = { + 'alg': 'ES256', + 'typ': 'JWT', + } + + now = int(time.time()) + payload = { + 'iss': self._oauth_client_id, + 'sub': self._oauth_client_id, + 'aud': audience, + 'jti': secrets.token_urlsafe(16), + 'iat': now, + 'exp': now + 60, # Valid for 60 seconds + } + + return self._dpop._sign_jwt(header, payload, self._oauth_client_secret_key) + + def start_authorization( + self, + handle_or_did: str, + ) -> t.Tuple[str, str, OAuthState]: + """Start OAuth authorization flow. + + Args: + handle_or_did: User handle (e.g., 'user.bsky.social') or DID. + + Returns: + Tuple of (authorization_url, state) for redirecting user. + + Raises: + ValueError: If handle/DID resolution fails or URL validation fails. + OAuthError: If authorization server discovery or PAR fails. + """ + self._ensure_oauth_initialized() + # 1. Resolve identity + if handle_or_did.startswith('did:'): + # Input is a DID + did = handle_or_did + else: + # Input is a handle - resolve to DID first + resolved_did = self._id_resolver.handle.resolve(handle_or_did) + if not resolved_did: + raise ValueError(f'Failed to resolve handle: {handle_or_did}') + did = resolved_did + + # 2. Resolve DID to get ATProto data (includes PDS, handle, etc.) + atproto_data = self._id_resolver.did.resolve_atproto_data(did) + + handle = atproto_data.handle or handle_or_did + pds_url = atproto_data.pds + + if not pds_url: + raise ValueError(f'No PDS endpoint found in DID document for {did}') + + # 3. Discover authorization server + authserver_url = discover_authserver_from_pds(pds_url) + authserver_url = authserver_url.rstrip('/') + + # 4. Fetch authorization server metadata + authserver_meta = fetch_authserver_metadata(authserver_url) + + # 5. Generate PKCE verifier and challenge + pkce_verifier, pkce_challenge = self._pkce.generate_pair() + + # 6. Generate DPoP keypair + dpop_key = self._dpop.generate_keypair() + + # 7. Generate state token + state_token = secrets.token_urlsafe(32) + + # 8. Send PAR (Pushed Authorization Request) + request_uri, dpop_nonce = self._send_par_request( + authserver_meta=authserver_meta, + login_hint=handle_or_did, + pkce_challenge=pkce_challenge, + dpop_key=dpop_key, + state=state_token, + ) + + # 9. Store state + oauth_state = OAuthState( + state=state_token, + pkce_verifier=pkce_verifier, + redirect_uri=self._oauth_redirect_uri, + scope=self._oauth_scope, + authserver_iss=authserver_meta.issuer, + dpop_private_key=dpop_key, + dpop_authserver_nonce=dpop_nonce, + did=did, + handle=handle, + pds_url=pds_url, + ) + + # 10. Build authorization URL + auth_params = { + 'client_id': self._oauth_client_id, + 'request_uri': request_uri, + } + auth_url = f'{authserver_meta.authorization_endpoint}?{urlencode(auth_params)}' + + if not is_safe_url(auth_url): + raise ValueError(f'Generated unsafe authorization URL: {auth_url}') + + return auth_url, state_token, oauth_state + + def _exchange_code_for_tokens( + self, + code: str, + oauth_state: OAuthState, + ) -> t.Tuple[TokenResponse, str]: + """Exchange authorization code for tokens. + + Returns: + Tuple of (token_response, dpop_nonce). + """ + # Fetch metadata again (could have changed) + authserver_meta = fetch_authserver_metadata(oauth_state.authserver_iss) + + params = { + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': oauth_state.pkce_verifier, + 'redirect_uri': self._oauth_redirect_uri, + } + + dpop_nonce, response = self._make_token_request( + token_url=authserver_meta.token_endpoint, + params=params, + dpop_key=oauth_state.dpop_private_key, + dpop_nonce=oauth_state.dpop_authserver_nonce, + ) + + if response.status_code not in (200, 201): + raise OAuthTokenError(f'Token exchange failed: {response.status_code} {response.text}') + + token_data = response.json() + token_response = TokenResponse( + access_token=token_data['access_token'], + token_type=token_data['token_type'], + scope=token_data['scope'], + sub=token_data['sub'], + refresh_token=token_data.get('refresh_token'), + expires_in=token_data.get('expires_in'), + ) + + return token_response, dpop_nonce + + + def handle_callback( + self, + code: str, + iss: str, + oauth_state: OAuthState + ) -> OAuthSession: + """Handle OAuth callback and complete authorization. + + Args: + code: Authorization code from callback. + iss: Issuer parameter from callback. + oauth_state: OAuth state object from start_authorization. + + Returns: + OAuth session with tokens. + + Raises: + OAuthStateError: If state validation fails. + OAuthTokenError: If token exchange fails. + """ + # 1. Retrieve and verify state + if not oauth_state: + raise OAuthStateError('Invalid or expired state parameter') + + if oauth_state.authserver_iss != iss: + raise OAuthStateError(f'Issuer mismatch: expected {oauth_state.authserver_iss}, got {iss}') + + # 2. Exchange code for tokens + token_response, dpop_nonce = self._exchange_code_for_tokens( + code=code, + oauth_state=oauth_state, + ) + + # 3. Verify token response + if token_response.sub != oauth_state.did: + raise OAuthTokenError(f'DID mismatch in token: expected {oauth_state.did}, got {token_response.sub}') + + if token_response.scope != self._oauth_scope: + raise OAuthTokenError(f'Scope mismatch: expected {self._oauth_scope}, got {token_response.scope}') + + # 4. Create and store session + session = OAuthSession( + did=oauth_state.did or token_response.sub, + handle=oauth_state.handle or '', + pds_url=oauth_state.pds_url or '', + authserver_iss=oauth_state.authserver_iss, + access_jwt=token_response.access_token, + refresh_jwt=token_response.refresh_token or '', + dpop_private_key=oauth_state.dpop_private_key, + dpop_authserver_nonce=dpop_nonce, + scope=token_response.scope, + ) + + self._set_oauth_session(SessionEvent.CREATE, session) + + return session + + def _send_par_request(self, + authserver_meta: AuthServerMetadata, + login_hint: str, + pkce_challenge: str, + dpop_key: 'EllipticCurvePrivateKey', + state: str, + ) -> t.Tuple[str, str]: + """Send Pushed Authorization Request. + + Returns: + Tuple of (request_uri, dpop_nonce). + """ + par_url = authserver_meta.pushed_authorization_request_endpoint + + params = { + 'response_type': 'code', + 'code_challenge': pkce_challenge, + 'code_challenge_method': 'S256', + 'state': state, + 'redirect_uri': self._oauth_redirect_uri, + 'scope': self._oauth_scope, + 'login_hint': login_hint, + } + + # Make PAR request with DPoP + dpop_nonce, response = self._make_token_request( + token_url=par_url, + params=params, + dpop_key=dpop_key, + dpop_nonce='', # Initial request has no nonce + ) + + if response.status_code not in (200, 201): + raise OAuthTokenError(f'PAR request failed: {response.status_code} {response.text}') + + par_response = response.json() + return par_response['request_uri'], dpop_nonce + + + def refresh_session(self, session: OAuthSession) -> OAuthSession: + """Refresh OAuth session tokens. + + Args: + session: Current OAuth session. + + Returns: + Updated OAuth session with new tokens. + + Raises: + OAuthTokenError: If token refresh fails. + """ + # Fetch current auth server metadata + authserver_meta = fetch_authserver_metadata(session.authserver_iss) + + # Prepare refresh token request + params = { + 'grant_type': 'refresh_token', + 'refresh_token': session.refresh_jwt, + } + + # Make token request with DPoP + dpop_nonce, response = self._make_token_request( + token_url=authserver_meta.token_endpoint, + params=params, + dpop_key=session.dpop_private_key, + dpop_nonce=session.dpop_authserver_nonce, + ) + + if response.status_code not in (200, 201): + raise OAuthTokenError(f'Token refresh failed: {response.status_code} {response.text}') + + token_data = response.json() + token_response = TokenResponse( + access_token=token_data['access_token'], + token_type=token_data['token_type'], + scope=token_data['scope'], + sub=token_data['sub'], + refresh_token=token_data.get('refresh_token', session.refresh_jwt), + expires_in=token_data.get('expires_in'), + ) + + # Update session + session.access_jwt = token_response.access_token + session.refresh_jwt = token_response.refresh_token + session.dpop_authserver_nonce = dpop_nonce + + return session + + def revoke_session(self, session: OAuthSession) -> None: + """Revoke OAuth session tokens. + + Args: + session: OAuth session to revoke. + """ + authserver_meta = fetch_authserver_metadata(session.authserver_iss) + + if not authserver_meta.revocation_endpoint: + # Revocation not supported + return + + # Revoke both access and refresh tokens + for token_type in ['access_token', 'refresh_token']: + token = session.access_jwt if token_type == 'access_token' else session.refresh_jwt # noqa: S105 - token type identifier, not a password + if not token: + continue + + params = { + 'token': token, + 'token_type_hint': token_type, + } + + # Best-effort revocation; failures are intentionally silent + with contextlib.suppress(OAuthTokenError, ValueError): + self._make_token_request( + token_url=authserver_meta.revocation_endpoint, + params=params, + dpop_key=session.dpop_private_key, + dpop_nonce=session.dpop_authserver_nonce, + ) + + + def import_oauth_session(self, session: OAuthSession) -> None: + self._set_oauth_session(SessionEvent.IMPORT, session) + self._ensure_oauth_initialized() + + def _set_oauth_session(self, event: SessionEvent, session: OAuthSession) -> None: + # Update base URL to PDS so all requests go through PDS (which proxies to AppView) + self.update_base_url(session.pds_url) + + if not self._oauth_session: + self._oauth_session = OAuthSession( + access_jwt=session.access_jwt, + refresh_jwt=session.refresh_jwt, + did=session.did, + authserver_iss=session.authserver_iss, + handle=session.handle, + pds_url=session.pds_url, + dpop_private_key=session.dpop_private_key, + dpop_authserver_nonce=session.dpop_authserver_nonce, + scope=session.scope, + dpop_pds_nonce=session.dpop_pds_nonce, + expires_at=session.expires_at, + created_at=session.created_at + ) + + else: + logger.info('_set_oauth_session: session {}'.format(self._oauth_session)) + self._oauth_session.access_jwt = session.access_jwt + self._oauth_session.refresh_jwt = session.refresh_jwt + self._oauth_session.authserver_iss = session.authserver_iss + self._oauth_session.did = session.did + self._oauth_session.handle = session.handle + self._oauth_session.pds_url = session.pds_url + self._oauth_session.dpop_private_key = session.dpop_private_key + self._oauth_session.dpop_authserver_nonce = session.dpop_authserver_nonce + self._oauth_session.scope = session.scope + self._oauth_session.dpop_pds_nonce = session.dpop_pds_nonce + self._oauth_session.expires_at = session.expires_at + self._oauth_session.created_at = session.created_at + + def _is_oauth_session(self) -> bool: + return self._oauth_session is not None diff --git a/packages/atproto_client/client/methods_mixin/pkce.py b/packages/atproto_client/client/methods_mixin/pkce.py new file mode 100644 index 00000000..642e483c --- /dev/null +++ b/packages/atproto_client/client/methods_mixin/pkce.py @@ -0,0 +1,57 @@ +"""PKCE (Proof Key for Code Exchange) implementation.""" + +import base64 +import hashlib +import secrets +import typing as t + + +class PKCEManager: + """Manages PKCE code verifier and challenge generation.""" + + @staticmethod + def generate_verifier(length: int = 128) -> str: + """Generate a PKCE code verifier. + + Args: + length: Length of the verifier (43-128 characters). + + Returns: + Base64url-encoded verifier string. + + Raises: + ValueError: If length is not between 43 and 128. + """ + if not 43 <= length <= 128: + raise ValueError('PKCE verifier length must be between 43 and 128') + + # Generate random bytes and encode as base64url + verifier_bytes = secrets.token_bytes(length) + return base64.urlsafe_b64encode(verifier_bytes).decode('utf-8').rstrip('=')[:length] + + @staticmethod + def generate_challenge(verifier: str) -> str: + """Generate S256 PKCE code challenge from verifier. + + Args: + verifier: The code verifier string. + + Returns: + Base64url-encoded SHA256 hash of the verifier. + """ + digest = hashlib.sha256(verifier.encode('utf-8')).digest() + return base64.urlsafe_b64encode(digest).decode('utf-8').rstrip('=') + + @classmethod + def generate_pair(cls, length: int = 128) -> t.Tuple[str, str]: + """Generate both verifier and challenge. + + Args: + length: Length of the verifier. + + Returns: + Tuple of (verifier, challenge). + """ + verifier = cls.generate_verifier(length) + challenge = cls.generate_challenge(verifier) + return verifier, challenge diff --git a/packages/atproto_client/client/oauth_session.py b/packages/atproto_client/client/oauth_session.py new file mode 100644 index 00000000..dc3c9c08 --- /dev/null +++ b/packages/atproto_client/client/oauth_session.py @@ -0,0 +1,327 @@ +import typing as t +from dataclasses import dataclass, field +from datetime import datetime, timezone +from urllib.parse import urlparse + +import httpx + +from atproto_client.exceptions import AuthServerMetadata, UnsupportedAuthServerError + +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey + +# Hardened HTTP client configuration +DEFAULT_TIMEOUT = 5.0 +MAX_REDIRECTS = 3 +ALLOWED_SCHEMES = {'https', 'http'} # http only for localhost +BLOCKED_HOSTS = { + '0.0.0.0', # noqa: S104 - blocking connections to this IP, not binding to this address + '127.0.0.1', + 'localhost', + '::1', + '169.254.169.254', # AWS metadata + 'metadata.google.internal', # GCP metadata +} + +def _is_private_ip(hostname: str) -> bool: + """Check if hostname is a private IP address.""" + if hostname.startswith('10.'): + return True + if hostname.startswith('172.'): + try: + second_octet = int(hostname.split('.')[1]) + if 16 <= second_octet <= 31: + return True + except (IndexError, ValueError): + pass + return hostname.startswith('192.168.') + + +def is_safe_url(url: str, allow_localhost: bool = False) -> bool: + """Validate URL for security (SSRF protection). + + Args: + url: URL to validate. + allow_localhost: Whether to allow localhost URLs. + + Returns: + True if URL is safe to use. + """ + try: + parsed = urlparse(url) + except ValueError: + return False + + if parsed.scheme not in ALLOWED_SCHEMES: + return False + + # For http, only allow localhost if explicitly permitted + if parsed.scheme == 'http' and (not allow_localhost or parsed.hostname not in ('localhost', '127.0.0.1', '::1')): + return False + + if parsed.hostname in BLOCKED_HOSTS and not allow_localhost: + return False + + return not (parsed.hostname and _is_private_ip(parsed.hostname)) + +def validate_authserver_metadata(metadata: t.Dict[str, t.Any], fetch_url: str) -> None: + """Validate authorization server metadata against ATProto requirements. + + Args: + metadata: Metadata dictionary from server. + fetch_url: URL where metadata was fetched from. + + Raises: + ValueError: If metadata doesn't meet requirements. + """ + issuer_url = urlparse(metadata['issuer']) + fetch_parsed = urlparse(fetch_url) + + # Issuer must match fetch URL host + if issuer_url.hostname != fetch_parsed.hostname: + raise ValueError(f'Issuer hostname mismatch: {issuer_url.hostname} != {fetch_parsed.hostname}') + + # Issuer must be HTTPS with no path/params/fragment + if issuer_url.scheme != 'https': + raise ValueError(f'Issuer must be HTTPS: {issuer_url.scheme}') + if issuer_url.port is not None: + raise ValueError(f'Issuer must not have explicit port: {issuer_url.port}') + if issuer_url.path not in ('', '/'): + raise ValueError(f'Issuer must not have path: {issuer_url.path}') + if issuer_url.params or issuer_url.fragment: + raise ValueError('Issuer must not have params or fragment') + + # Check required grant types and methods + required_checks = [ + ('code' in metadata.get('response_types_supported', []), 'response_types_supported must include "code"'), + ( + 'authorization_code' in metadata.get('grant_types_supported', []), + 'grant_types_supported must include "authorization_code"', + ), + ( + 'refresh_token' in metadata.get('grant_types_supported', []), + 'grant_types_supported must include "refresh_token"', + ), + ( + 'S256' in metadata.get('code_challenge_methods_supported', []), + 'code_challenge_methods_supported must include "S256"', + ), + ( + 'none' in metadata.get('token_endpoint_auth_methods_supported', []), + 'token_endpoint_auth_methods_supported must include "none"', + ), + ( + 'private_key_jwt' in metadata.get('token_endpoint_auth_methods_supported', []), + 'token_endpoint_auth_methods_supported must include "private_key_jwt"', + ), + ( + 'ES256' in metadata.get('token_endpoint_auth_signing_alg_values_supported', []), + 'token_endpoint_auth_signing_alg_values_supported must include "ES256"', + ), + ('atproto' in metadata.get('scopes_supported', []), 'scopes_supported must include "atproto"'), + ( + metadata.get('authorization_response_iss_parameter_supported') is True, + 'authorization_response_iss_parameter_supported must be true', + ), + ( + metadata.get('pushed_authorization_request_endpoint') is not None, + 'pushed_authorization_request_endpoint is required', + ), + ( + metadata.get('require_pushed_authorization_requests') is True, + 'require_pushed_authorization_requests must be true', + ), + ( + 'ES256' in metadata.get('dpop_signing_alg_values_supported', []), + 'dpop_signing_alg_values_supported must include "ES256"', + ), + ( + metadata.get('client_id_metadata_document_supported') is True, + 'client_id_metadata_document_supported must be true', + ), + ] + + for check, error_msg in required_checks: + if not check: + raise ValueError(error_msg) +@dataclass +class TokenResponse: + """Token response from authorization server.""" + + access_token: str + token_type: str + scope: str + sub: str # DID + refresh_token: t.Optional[str] = None + expires_in: t.Optional[int] = None + +def discover_authserver_from_pds(pds_url: str, timeout: float = 5.0) -> str: + """Discover authorization server URL from PDS. + + Args: + pds_url: PDS endpoint URL. + timeout: Request timeout in seconds. + + Returns: + Authorization server URL. + + Raises: + ValueError: If PDS URL is unsafe or response is invalid. + httpx.HTTPError: If request fails. + """ + if not is_safe_url(pds_url): + raise ValueError(f'Unsafe PDS URL: {pds_url}') + + with httpx.Client(timeout=timeout) as client: + response = client.get(f'{pds_url}/.well-known/oauth-protected-resource') + response.raise_for_status() + + if response.status_code != 200: + raise ValueError(f'PDS returned non-200 status: {response.status_code}') + + data = response.json() + if not isinstance(data, dict) or 'authorization_servers' not in data: + raise ValueError('Invalid oauth-protected-resource response') + + auth_servers = data['authorization_servers'] + if not auth_servers or not isinstance(auth_servers, list): + raise ValueError('No authorization servers found') + + return auth_servers[0] + +def fetch_authserver_metadata(authserver_url: str, timeout: float = 5.0) -> AuthServerMetadata: + """Fetch and validate authorization server metadata. + + Args: + authserver_url: Authorization server URL. + timeout: Request timeout in seconds. + + Returns: + Validated metadata object. + + Raises: + ValueError: If URL is unsafe. + UnsupportedAuthServerError: If metadata doesn't meet requirements. + httpx.HTTPError: If request fails. + """ + if not is_safe_url(authserver_url): + raise ValueError(f'Unsafe authorization server URL: {authserver_url}') + + fetch_url = f'{authserver_url}/.well-known/oauth-authorization-server' + + with httpx.Client(timeout=timeout) as client: + response = client.get(fetch_url) + response.raise_for_status() + + metadata_dict = response.json() + + # Validate against ATProto requirements + try: + validate_authserver_metadata(metadata_dict, fetch_url) + except ValueError as e: + raise UnsupportedAuthServerError(str(e)) from e + + # Parse into model + return AuthServerMetadata( + issuer=metadata_dict['issuer'], + authorization_endpoint=metadata_dict['authorization_endpoint'], + token_endpoint=metadata_dict['token_endpoint'], + pushed_authorization_request_endpoint=metadata_dict['pushed_authorization_request_endpoint'], + response_types_supported=metadata_dict['response_types_supported'], + grant_types_supported=metadata_dict['grant_types_supported'], + code_challenge_methods_supported=metadata_dict['code_challenge_methods_supported'], + token_endpoint_auth_methods_supported=metadata_dict['token_endpoint_auth_methods_supported'], + token_endpoint_auth_signing_alg_values_supported=metadata_dict[ + 'token_endpoint_auth_signing_alg_values_supported' + ], + scopes_supported=metadata_dict['scopes_supported'], + dpop_signing_alg_values_supported=metadata_dict['dpop_signing_alg_values_supported'], + authorization_response_iss_parameter_supported=metadata_dict[ + 'authorization_response_iss_parameter_supported' + ], + require_pushed_authorization_requests=metadata_dict['require_pushed_authorization_requests'], + client_id_metadata_document_supported=metadata_dict['client_id_metadata_document_supported'], + revocation_endpoint=metadata_dict.get('revocation_endpoint'), + jwks_uri=metadata_dict.get('jwks_uri'), + require_request_uri_registration=metadata_dict.get('require_request_uri_registration'), + ) + + +@dataclass +class AuthServerMetadata: + """Authorization Server metadata from discovery.""" + + issuer: str + authorization_endpoint: str + token_endpoint: str + pushed_authorization_request_endpoint: str + response_types_supported: t.List[str] + grant_types_supported: t.List[str] + code_challenge_methods_supported: t.List[str] + token_endpoint_auth_methods_supported: t.List[str] + token_endpoint_auth_signing_alg_values_supported: t.List[str] + scopes_supported: t.List[str] + dpop_signing_alg_values_supported: t.List[str] + authorization_response_iss_parameter_supported: bool + require_pushed_authorization_requests: bool + client_id_metadata_document_supported: bool + revocation_endpoint: t.Optional[str] = None + jwks_uri: t.Optional[str] = None + require_request_uri_registration: t.Optional[bool] = None + + +@dataclass +class OAuthSession: + """OAuth session with tokens and metadata.""" + + did: str + handle: str + pds_url: str + authserver_iss: str + access_jwt: str + refresh_jwt: str + dpop_private_key: 'EllipticCurvePrivateKey' + dpop_authserver_nonce: str + scope: str + dpop_pds_nonce: t.Optional[str] = None + expires_at: t.Optional[datetime] = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + +@dataclass +class OAuthState: + """OAuth state for CSRF protection during authorization flow.""" + + state: str + pkce_verifier: str + redirect_uri: str + scope: str + authserver_iss: str + dpop_private_key: 'EllipticCurvePrivateKey' + dpop_authserver_nonce: str + did: t.Optional[str] = None + handle: t.Optional[str] = None + pds_url: t.Optional[str] = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +@dataclass +class AuthServerMetadata: + """Authorization Server metadata from discovery.""" + + issuer: str + authorization_endpoint: str + token_endpoint: str + pushed_authorization_request_endpoint: str + response_types_supported: t.List[str] + grant_types_supported: t.List[str] + code_challenge_methods_supported: t.List[str] + token_endpoint_auth_methods_supported: t.List[str] + token_endpoint_auth_signing_alg_values_supported: t.List[str] + scopes_supported: t.List[str] + dpop_signing_alg_values_supported: t.List[str] + authorization_response_iss_parameter_supported: bool + require_pushed_authorization_requests: bool + client_id_metadata_document_supported: bool + revocation_endpoint: t.Optional[str] = None + jwks_uri: t.Optional[str] = None + require_request_uri_registration: t.Optional[bool] = None diff --git a/packages/atproto_client/exceptions.py b/packages/atproto_client/exceptions.py index 5e270943..6943c472 100644 --- a/packages/atproto_client/exceptions.py +++ b/packages/atproto_client/exceptions.py @@ -38,3 +38,18 @@ class BadRequestError(RequestErrorBase): ... class LoginRequiredError(AtProtocolError): def __init__(self, message: t.Optional[str] = _DEFAULT_LOGING_REQUIRED_ERROR_MESSAGE) -> None: super().__init__(message) + +"""OAuth-specific exceptions.""" +class OAuthError(AtProtocolError): + """Base exception for OAuth errors.""" + +class OAuthStateError(OAuthError): ... + + +class OAuthTokenError(OAuthError): ... + + +class UnsupportedAuthServerError(OAuthError): ... + + +class AuthServerMetadata(OAuthError): ... diff --git a/tests/test_atproto_client/client/test_oauth.py b/tests/test_atproto_client/client/test_oauth.py new file mode 100644 index 00000000..79be8d2c --- /dev/null +++ b/tests/test_atproto_client/client/test_oauth.py @@ -0,0 +1,124 @@ +"""Tests for OAuth mixin and related components.""" + +import json +import typing as t +from base64 import urlsafe_b64decode +from unittest.mock import MagicMock + +import pytest +from atproto_client.client.client import Client +from atproto_client.client.methods_mixin.dpop import DPoPManager +from atproto_client.client.methods_mixin.pkce import PKCEManager + + +class TestPKCEManager: + """Tests for PKCEManager.""" + + def test_generate_pair_produces_valid_challenge(self) -> None: + """PKCE pair should produce matching verifier and challenge.""" + verifier, challenge = PKCEManager.generate_pair() + + assert len(verifier) == 128 + assert PKCEManager.generate_challenge(verifier) == challenge + + +class TestDPoPManager: + """Tests for DPoPManager.""" + + def test_generate_keypair(self) -> None: + """Should generate ES256 keypair on P-256 curve.""" + from cryptography.hazmat.primitives.asymmetric import ec + + key = DPoPManager.generate_keypair() + assert isinstance(key, ec.EllipticCurvePrivateKey) + assert isinstance(key.curve, ec.SECP256R1) + + def test_create_proof_returns_valid_jwt(self) -> None: + """Proof should be valid JWT with required DPoP claims.""" + key = DPoPManager.generate_keypair() + proof = DPoPManager.create_proof( + method='POST', + url='https://example.com/token?query=ignored', + private_key=key, + nonce='test-nonce', + ) + + # Decode and verify structure + parts = proof.split('.') + assert len(parts) == 3 + + header = json.loads(urlsafe_b64decode(parts[0] + '==')) + payload = json.loads(urlsafe_b64decode(parts[1] + '==')) + + assert header['typ'] == 'dpop+jwt' + assert header['alg'] == 'ES256' + assert 'jwk' in header + + assert payload['htm'] == 'POST' + assert payload['htu'] == 'https://example.com/token' # query stripped + assert payload['nonce'] == 'test-nonce' + + @pytest.mark.parametrize( + 'status_code,headers,json_body,expected', + [ + # WWW-Authenticate header detection + (401, {'WWW-Authenticate': 'DPoP error="use_dpop_nonce"'}, None, True), + # JSON body detection + (400, {}, {'error': 'use_dpop_nonce'}, True), + # Not a nonce error - different error + (401, {'WWW-Authenticate': 'Bearer error="invalid_token"'}, None, False), + # Not a nonce error - success response + (200, {}, None, False), + ], + ) + def test_is_dpop_nonce_error( + self, status_code: int, headers: dict, json_body: t.Optional[dict], expected: bool + ) -> None: + """Should detect DPoP nonce errors from various response formats.""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.headers = headers + mock_response.json.return_value = json_body + + assert DPoPManager.is_dpop_nonce_error(mock_response) is expected + + @pytest.mark.parametrize( + 'headers,expected_nonce', + [ + ({'DPoP-Nonce': 'server-nonce-123'}, 'server-nonce-123'), + ({'dpop-nonce': 'lowercase-nonce'}, 'lowercase-nonce'), + ({}, None), + ], + ) + def test_extract_nonce_from_response(self, headers: dict, expected_nonce: t.Optional[str]) -> None: + """Should extract DPoP nonce from response headers.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = headers + + assert DPoPManager.extract_nonce_from_response(mock_response) == expected_nonce + + +class TestOAuthMixin: + """Tests for OauthSessionMethodsMixin integration with Client.""" + + def test_client_works_without_oauth_params(self) -> None: + """Client should instantiate normally without OAuth configuration.""" + client = Client() + assert client._oauth_client_id is None + assert client._oauth_initialized is False + + def test_client_accepts_oauth_params(self) -> None: + """Client should accept and store OAuth configuration.""" + client = Client( + client_id='https://example.com/client', + redirect_uri='https://example.com/callback', + scope='atproto', + ) + assert client._oauth_client_id == 'https://example.com/client' + + def test_ensure_oauth_initialized_requires_config(self) -> None: + """_ensure_oauth_initialized should raise without OAuth config.""" + client = Client() + with pytest.raises(ValueError, match='OAuth not configured'): + client._ensure_oauth_initialized()