diff --git a/backend/alembic/versions/add_codex_oauth_to_llm_models.py b/backend/alembic/versions/add_codex_oauth_to_llm_models.py new file mode 100644 index 000000000..dc90ba8d0 --- /dev/null +++ b/backend/alembic/versions/add_codex_oauth_to_llm_models.py @@ -0,0 +1,57 @@ +"""Add Codex OAuth columns to llm_models and make api_key_encrypted nullable. + +Revision ID: add_codex_oauth_to_llm_models +Revises: increase_api_key_length +Create Date: 2026-04-22 +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = "add_codex_oauth_to_llm_models" +down_revision: Union[str, None] = "increase_api_key_length" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # OAuth-mode models don't have a static api_key, so relax the NOT NULL + op.execute("ALTER TABLE llm_models ALTER COLUMN api_key_encrypted DROP NOT NULL") + + op.add_column( + "llm_models", + sa.Column( + "auth_type", + sa.String(20), + nullable=False, + server_default="static", + ), + ) + op.add_column( + "llm_models", + sa.Column("oauth_access_token_encrypted", sa.String(4096), nullable=True), + ) + op.add_column( + "llm_models", + sa.Column("oauth_refresh_token_encrypted", sa.String(1024), nullable=True), + ) + op.add_column( + "llm_models", + sa.Column("oauth_expires_at", sa.DateTime(timezone=True), nullable=True), + ) + op.add_column( + "llm_models", + sa.Column("oauth_account_id", sa.String(255), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("llm_models", "oauth_account_id") + op.drop_column("llm_models", "oauth_expires_at") + op.drop_column("llm_models", "oauth_refresh_token_encrypted") + op.drop_column("llm_models", "oauth_access_token_encrypted") + op.drop_column("llm_models", "auth_type") + # Revert to NOT NULL; assumes no rows with null api_key_encrypted remain. 
+ op.execute("ALTER TABLE llm_models ALTER COLUMN api_key_encrypted SET NOT NULL") diff --git a/backend/app/api/codex_oauth.py b/backend/app/api/codex_oauth.py new file mode 100644 index 000000000..3286028bf --- /dev/null +++ b/backend/app/api/codex_oauth.py @@ -0,0 +1,410 @@ +"""Codex OAuth provisioning API. + +Flow (browser): + 1) POST /llm-models/codex-oauth/start — backend mints PKCE verifier/state, + tries to bind a local loopback listener on 127.0.0.1:1455, returns the + authorize URL (+ whether the loopback is available). + 2) User's browser navigates to the authorize URL, logs in with ChatGPT, and + is redirected back to http://localhost:1455/auth/callback?code=...&state=... + 3) Frontend polls GET /llm-models/codex-oauth/poll?state=X until a code + appears (loopback mode), OR the user manually pastes the redirect URL. + 4) POST /llm-models/codex-oauth/complete — backend exchanges code for tokens + and creates an LLMModel row with auth_type='codex_oauth'. + +Alternative (Mode B / no loopback): + POST /llm-models/codex-oauth/paste-creds — user pastes access/refresh/expiry + (e.g. from a local `codex login` run) and Clawith starts managing refresh. + +The loopback listener uses stdlib http.server in a background thread. Only one +binding per process; if the port is already held (by another Clawith worker or +an unrelated process), `/start` returns loopback_ready=false and the frontend +must fall back to Mode B or to manual URL paste. 
+""" + +from __future__ import annotations + +import threading +import uuid +from datetime import datetime, timedelta, timezone +from uuid import UUID +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any +from urllib.parse import parse_qs, urlparse + +from fastapi import APIRouter, Depends, HTTPException, status +from loguru import logger +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.core.security import encrypt_data, get_current_admin +from app.database import get_db +from app.models.llm import LLMModel +from app.services.llm.codex_oauth import ( + CODEX_OAUTH_MODELS, + LOOPBACK_HOST, + LOOPBACK_PORT, + REDIRECT_URI, + build_authorize_url, + decode_account_id, + exchange_code, + generate_pkce, + generate_state, +) + +router = APIRouter(prefix="/llm-models/codex-oauth", tags=["codex-oauth"]) + +# ─── In-memory OAuth session cache (per backend process) ────────────────────── +_SESSION_TTL = timedelta(minutes=10) +_sessions_lock = threading.Lock() +_sessions: dict[str, dict[str, Any]] = {} +# Keyed by state. 
Each entry: {"verifier": str, "expires_at": datetime, "code": str | None, "error": str | None} + +_listener_lock = threading.Lock() +_listener: HTTPServer | None = None +_listener_thread: threading.Thread | None = None + + +def _put_session(state: str, verifier: str) -> None: + _gc_sessions() + with _sessions_lock: + _sessions[state] = { + "verifier": verifier, + "expires_at": datetime.now(tz=timezone.utc) + _SESSION_TTL, + "code": None, + "error": None, + } + + +def _get_session(state: str) -> dict[str, Any] | None: + with _sessions_lock: + entry = _sessions.get(state) + if entry is None: + return None + if entry["expires_at"] < datetime.now(tz=timezone.utc): + _sessions.pop(state, None) + return None + return dict(entry) + + +def _record_callback(state: str, code: str | None, error: str | None) -> None: + with _sessions_lock: + entry = _sessions.get(state) + if entry is None: + return + if code: + entry["code"] = code + if error: + entry["error"] = error + + +def _consume_session(state: str) -> dict[str, Any] | None: + """Read-and-remove a session on successful code exchange.""" + with _sessions_lock: + return _sessions.pop(state, None) + + +def _gc_sessions() -> None: + now = datetime.now(tz=timezone.utc) + with _sessions_lock: + dead = [s for s, v in _sessions.items() if v["expires_at"] < now] + for s in dead: + _sessions.pop(s, None) + + +# ─── Loopback listener ──────────────────────────────────────────────────────── +class _CallbackHandler(BaseHTTPRequestHandler): + def do_GET(self) -> None: # noqa: N802 — required method name + parsed = urlparse(self.path) + if parsed.path != "/auth/callback": + self.send_response(404) + self.end_headers() + return + + params = parse_qs(parsed.query) + state = (params.get("state") or [""])[0] + code = (params.get("code") or [None])[0] + error = (params.get("error") or [None])[0] + + if not state: + self.send_response(400) + self.end_headers() + self.wfile.write(b"Missing state") + return + + _record_callback(state, code, 
error) + + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write( + b"
" + b"You can close this tab and return to Clawith.
" + b"" + ) + + def log_message(self, fmt: str, *args: Any) -> None: # noqa: A003 + # Suppress default stderr spam from BaseHTTPRequestHandler + return + + +def _ensure_listener() -> bool: + """Start the loopback listener if not already running in this process. + + Returns True if it's running (freshly or pre-existing) and False if the + port couldn't be bound (typically because another process holds it). + """ + global _listener, _listener_thread + with _listener_lock: + if _listener is not None: + return True + try: + _listener = HTTPServer((LOOPBACK_HOST, LOOPBACK_PORT), _CallbackHandler) + except OSError as e: + logger.warning(f"[codex_oauth] Could not bind {LOOPBACK_HOST}:{LOOPBACK_PORT}: {e}") + _listener = None + return False + _listener_thread = threading.Thread( + target=_listener.serve_forever, + name="codex-oauth-loopback", + daemon=True, + ) + _listener_thread.start() + logger.info(f"[codex_oauth] Loopback listener bound to {LOOPBACK_HOST}:{LOOPBACK_PORT}") + return True + + +# ─── Schemas ────────────────────────────────────────────────────────────────── +class StartRequest(BaseModel): + pass + + +class StartResponse(BaseModel): + authorize_url: str + state: str + redirect_uri: str + loopback_ready: bool + manual_paste_hint: str = Field( + default=( + "If the loopback isn't available (port 1455 busy, or Clawith is on a " + "remote host), complete the login in your browser, then paste the full " + "redirect URL (or just the code) into /complete." 
+ ) + ) + + +class PollResponse(BaseModel): + code: str | None = None + error: str | None = None + expired: bool = False + + +class CompleteRequest(BaseModel): + state: str + code: str + label: str + model: str + enabled: bool = True + + +class PasteCredsRequest(BaseModel): + access_token: str + refresh_token: str + expires_in_seconds: int = 3600 + account_id: str | None = None + label: str + model: str + enabled: bool = True + + +class ModelCreatedResponse(BaseModel): + id: uuid.UUID + label: str + provider: str + model: str + oauth_account_id: str | None = None + + +# ─── Helpers ────────────────────────────────────────────────────────────────── +def _reject_non_codex_model(model: str) -> None: + if model not in CODEX_OAUTH_MODELS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Model '{model}' is not a Codex OAuth-supported model. " + f"Allowed: {', '.join(CODEX_OAUTH_MODELS)}" + ), + ) + + +def _resolve_tenant_id(tenant_id_override: str | None, current_user: Any) -> UUID | None: + """Resolve the target tenant for a new LLM model row. + + Mirrors the override semantics of `add_llm_model` in enterprise.py so a + platform admin managing another tenant can create OAuth models on that + tenant's behalf. Falls back to the caller's own tenant. 
+ """ + raw = tenant_id_override or ( + str(current_user.tenant_id) if getattr(current_user, "tenant_id", None) else None + ) + if not raw: + return None + try: + return UUID(raw) + except (TypeError, ValueError) as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid tenant_id: {raw!r}", + ) from e + + +async def _insert_oauth_model( + db: AsyncSession, + *, + current_user: Any, + tenant_id_override: str | None, + label: str, + model: str, + access_token: str, + refresh_token_value: str, + expires_at: datetime, + account_id: str | None, +) -> LLMModel: + settings = get_settings() + row = LLMModel( + tenant_id=_resolve_tenant_id(tenant_id_override, current_user), + provider="codex-oauth", + model=model, + label=label, + auth_type="codex_oauth", + api_key_encrypted=None, + oauth_access_token_encrypted=encrypt_data(access_token, settings.SECRET_KEY), + oauth_refresh_token_encrypted=encrypt_data(refresh_token_value, settings.SECRET_KEY), + oauth_expires_at=expires_at, + oauth_account_id=account_id, + base_url="https://chatgpt.com/backend-api", + enabled=True, + ) + db.add(row) + await db.flush() + await db.commit() + await db.refresh(row) + return row + + +# ─── Endpoints ──────────────────────────────────────────────────────────────── +@router.post("/start", response_model=StartResponse) +async def start_oauth( + _req: StartRequest | None = None, + current_user: Any = Depends(get_current_admin), +) -> StartResponse: + """Kick off a Codex OAuth flow: mint PKCE+state, try to bind loopback, return authorize URL.""" + pkce = generate_pkce() + state = generate_state() + _put_session(state, pkce.verifier) + loopback_ready = _ensure_listener() + authorize_url = build_authorize_url(pkce.challenge, state) + return StartResponse( + authorize_url=authorize_url, + state=state, + redirect_uri=REDIRECT_URI, + loopback_ready=loopback_ready, + ) + + +@router.get("/poll", response_model=PollResponse) +async def poll_oauth( + state: str, + 
current_user: Any = Depends(get_current_admin), +) -> PollResponse: + """Check whether the loopback listener has received the auth code yet.""" + entry = _get_session(state) + if entry is None: + return PollResponse(expired=True) + return PollResponse(code=entry.get("code"), error=entry.get("error")) + + +@router.post("/complete", response_model=ModelCreatedResponse) +async def complete_oauth( + req: CompleteRequest, + tenant_id: str | None = None, + db: AsyncSession = Depends(get_db), + current_user: Any = Depends(get_current_admin), +) -> ModelCreatedResponse: + """Exchange the authorization code for tokens and persist the model. + + `tenant_id` (query) is optional and matches the semantics of + `POST /enterprise/llm-models`: a platform admin can provision the model + on behalf of another tenant. Without it, the caller's own tenant is used. + """ + _reject_non_codex_model(req.model) + + entry = _consume_session(req.state) + if entry is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="OAuth session not found or expired. 
Restart the flow via /start.", + ) + + try: + bundle = await exchange_code(req.code, entry["verifier"]) + except Exception as e: + logger.exception("[codex_oauth] exchange_code failed") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Token exchange failed: {e}", + ) from e + + row = await _insert_oauth_model( + db, + current_user=current_user, + tenant_id_override=tenant_id, + label=req.label, + model=req.model, + access_token=bundle.access_token, + refresh_token_value=bundle.refresh_token, + expires_at=bundle.expires_at, + account_id=bundle.account_id, + ) + return ModelCreatedResponse( + id=row.id, + label=row.label, + provider=row.provider, + model=row.model, + oauth_account_id=row.oauth_account_id, + ) + + +@router.post("/paste-creds", response_model=ModelCreatedResponse) +async def paste_credentials( + req: PasteCredsRequest, + tenant_id: str | None = None, + db: AsyncSession = Depends(get_db), + current_user: Any = Depends(get_current_admin), +) -> ModelCreatedResponse: + """Mode B fallback: import tokens directly (e.g. from ~/.codex/auth.json). + + `tenant_id` (query) — same override semantics as `/complete`. 
+ """ + _reject_non_codex_model(req.model) + expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=max(0, req.expires_in_seconds)) + account_id = req.account_id or decode_account_id(req.access_token) + row = await _insert_oauth_model( + db, + current_user=current_user, + tenant_id_override=tenant_id, + label=req.label, + model=req.model, + access_token=req.access_token, + refresh_token_value=req.refresh_token, + expires_at=expires_at, + account_id=account_id, + ) + return ModelCreatedResponse( + id=row.id, + label=row.label, + provider=row.provider, + model=row.model, + oauth_account_id=row.oauth_account_id, + ) diff --git a/backend/app/api/enterprise.py b/backend/app/api/enterprise.py index efa7fd2d3..77d8901f0 100644 --- a/backend/app/api/enterprise.py +++ b/backend/app/api/enterprise.py @@ -151,6 +151,27 @@ async def list_llm_models( return models +_OAUTH_ONLY_PROVIDERS = {"codex-oauth"} + + +def _reject_oauth_provider_via_static_form(provider: str | None) -> None: + """Block creating/updating OAuth-backed providers through the static API-key form. + + These providers require OAuth tokens (not an API key) and must be provisioned + through their dedicated flow so both provider and auth_type stay consistent. + """ + if provider and provider.strip().lower() in _OAUTH_ONLY_PROVIDERS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Provider '{provider}' cannot be created or updated via the " + "static form. Use POST /api/llm-models/codex-oauth/start for the " + "full OAuth flow, or POST /api/llm-models/codex-oauth/paste-creds " + "to import existing tokens." 
+ ), + ) + + @router.post("/llm-models", response_model=LLMModelOut, status_code=status.HTTP_201_CREATED) async def add_llm_model( data: LLMModelCreate, @@ -159,6 +180,7 @@ async def add_llm_model( db: AsyncSession = Depends(get_db), ): """Add a new LLM model to the tenant's pool (admin).""" + _reject_oauth_provider_via_static_form(data.provider) tid = tenant_id or (str(current_user.tenant_id) if current_user.tenant_id else None) model = LLMModel( provider=data.provider, @@ -235,6 +257,32 @@ async def update_llm_model( if not model: raise HTTPException(status_code=404, detail="Model not found") + # Guard against mis-routing between static and OAuth provisioning paths. + # Edit forms re-submit the current provider on every save, so we must + # distinguish "this row is already OAuth" (allow provider='codex-oauth' as + # an unchanged passthrough) from "this row is static but caller wants to + # switch to an OAuth-only slug" (reject — use the dedicated endpoints). + if model.auth_type == "codex_oauth": + if data.provider and data.provider != model.provider: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot change provider on a Codex OAuth model — re-connect instead.", + ) + if hasattr(data, "base_url") and data.base_url is not None and data.base_url != model.base_url: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot change base_url on a Codex OAuth model.", + ) + if data.api_key and data.api_key.strip() and not data.api_key.startswith("****"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Codex OAuth models don't use an API key; rotate via re-connect.", + ) + else: + # Static row; block any attempt to convert it into an OAuth-only provider + # through this form. 
+ _reject_oauth_provider_via_static_form(data.provider) + try: if data.provider: model.provider = data.provider diff --git a/backend/app/main.py b/backend/app/main.py index 896976db5..d66816081 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -303,6 +303,7 @@ def _bg_task_error(t): from app.api.pages import router as pages_router, public_router as pages_public_router from app.api.agent_credentials import router as credentials_router from app.api.agentbay_control import router as agentbay_control_router +from app.api.codex_oauth import router as codex_oauth_router app.include_router(auth_router, prefix=settings.API_PREFIX) app.include_router(agents_router, prefix=settings.API_PREFIX) @@ -344,6 +345,7 @@ def _bg_task_error(t): app.include_router(pages_public_router) # Public endpoint for /p/{short_id}, no API prefix app.include_router(credentials_router, prefix=settings.API_PREFIX) app.include_router(agentbay_control_router, prefix=settings.API_PREFIX) +app.include_router(codex_oauth_router, prefix=settings.API_PREFIX) @app.get("/api/health", response_model=HealthResponse, tags=["health"]) diff --git a/backend/app/models/llm.py b/backend/app/models/llm.py index cbe5fddd2..ba51687f1 100644 --- a/backend/app/models/llm.py +++ b/backend/app/models/llm.py @@ -19,7 +19,13 @@ class LLMModel(Base): tenant_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True) provider: Mapped[str] = mapped_column(String(50), nullable=False) # anthropic, openai, deepseek, etc. model: Mapped[str] = mapped_column(String(100), nullable=False) # claude-opus-4-6, gpt-4o, etc. - api_key_encrypted: Mapped[str] = mapped_column(String(1024), nullable=False) + # 'static' (default) uses api_key_encrypted; 'codex_oauth' uses oauth_* columns below. 
+ auth_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default="static") + api_key_encrypted: Mapped[str | None] = mapped_column(String(1024), nullable=True) + oauth_access_token_encrypted: Mapped[str | None] = mapped_column(String(4096), nullable=True) + oauth_refresh_token_encrypted: Mapped[str | None] = mapped_column(String(1024), nullable=True) + oauth_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + oauth_account_id: Mapped[str | None] = mapped_column(String(255), nullable=True) base_url: Mapped[str | None] = mapped_column(String(500)) label: Mapped[str] = mapped_column(String(200), nullable=False) # Display name max_tokens_per_day: Mapped[int | None] = mapped_column(Integer) diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index 2fafa229a..40594afce 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -5089,6 +5089,7 @@ async def _send_message_to_agent(from_agent_id: uuid.UUID, args: dict) -> str: create_llm_client, LLMMessage, get_model_api_key, + get_llm_client_for_model, LLMError, ) from app.services.agent_tools import get_agent_tools_for_llm, execute_tool @@ -5109,11 +5110,8 @@ async def _send_message_to_agent(from_agent_id: uuid.UUID, args: dict) -> str: from app.services.token_tracker import record_token_usage, extract_usage_tokens, estimate_tokens_from_chars - llm_client = create_llm_client( - provider=target_model.provider, - api_key=get_model_api_key(target_model), - model=target_model.model, - base_url=base_url, + llm_client = get_llm_client_for_model( + target_model, timeout=float(getattr(target_model, 'request_timeout', None) or 120.0), ) _A2A_RETRYABLE_MARKERS = ( diff --git a/backend/app/services/heartbeat.py b/backend/app/services/heartbeat.py index dcaffec20..6b4bdb672 100644 --- a/backend/app/services/heartbeat.py +++ b/backend/app/services/heartbeat.py @@ -169,6 +169,10 @@ async def 
_execute_heartbeat(agent_id: uuid.UUID): model_temperature = model.temperature model_max_output_tokens = getattr(model, 'max_output_tokens', None) model_request_timeout = getattr(model, 'request_timeout', None) + model_auth_type = getattr(model, 'auth_type', 'static') + # Keep a reference to the detached model row for OAuth-backed calls (CodexOAuthClient + # opens its own session via session_factory on demand, so the detached instance is fine). + cached_model = model # Read HEARTBEAT.md if it exists, otherwise use default from pathlib import Path @@ -255,17 +259,30 @@ async def _execute_heartbeat(agent_id: uuid.UUID): full_instruction = heartbeat_instruction + recent_context + inbox_context # Call LLM with tools using unified client - from app.services.llm import create_llm_client, get_max_tokens, LLMMessage, LLMError, get_model_api_key + from app.services.llm import ( + create_llm_client, + get_llm_client_for_model, + get_max_tokens, + LLMMessage, + LLMError, + get_model_api_key, + ) from app.services.agent_tools import execute_tool, get_agent_tools_for_llm try: - client = create_llm_client( - provider=model_provider, - api_key=model_api_key, - model=model_model, - base_url=model_base_url, - timeout=float(model_request_timeout or 120.0), - ) + if model_auth_type == "codex_oauth": + client = get_llm_client_for_model( + cached_model, + timeout=float(model_request_timeout or 120.0), + ) + else: + client = create_llm_client( + provider=model_provider, + api_key=model_api_key, + model=model_model, + base_url=model_base_url, + timeout=float(model_request_timeout or 120.0), + ) except Exception as e: logger.error(f"Failed to create LLM client: {e}") return diff --git a/backend/app/services/llm/__init__.py b/backend/app/services/llm/__init__.py index 7bc80cab1..02af43bc9 100644 --- a/backend/app/services/llm/__init__.py +++ b/backend/app/services/llm/__init__.py @@ -31,7 +31,14 @@ ) from .client import LLMClient, LLMResponse, LLMError, LLMMessage from .failover import 
classify_error, FailoverErrorType -from .utils import create_llm_client, get_max_tokens, get_model_api_key, get_provider_base_url, get_provider_manifest +from .utils import ( + create_llm_client, + get_llm_client_for_model, + get_max_tokens, + get_model_api_key, + get_provider_base_url, + get_provider_manifest, +) __all__ = [ # Core caller functions @@ -51,6 +58,7 @@ "LLMMessage", # Utilities "create_llm_client", + "get_llm_client_for_model", "get_max_tokens", "get_model_api_key", "get_provider_base_url", diff --git a/backend/app/services/llm/caller.py b/backend/app/services/llm/caller.py index d3a3b66d1..39a204a30 100644 --- a/backend/app/services/llm/caller.py +++ b/backend/app/services/llm/caller.py @@ -27,7 +27,7 @@ from .client import LLMError from .failover import classify_error, FailoverErrorType -from .utils import LLMMessage, create_llm_client, get_max_tokens, get_model_api_key +from .utils import LLMMessage, create_llm_client, get_llm_client_for_model, get_max_tokens, get_model_api_key if TYPE_CHECKING: from app.models.agent import Agent @@ -329,13 +329,7 @@ async def call_llm( # Create the unified LLM client try: - client = create_llm_client( - provider=model.provider, - api_key=get_model_api_key(model), - model=model.model, - base_url=model.base_url, - timeout=_get_model_timeout(model), - ) + client = get_llm_client_for_model(model, timeout=_get_model_timeout(model)) except Exception as e: return f"[Error] Failed to create LLM client: {e}" @@ -692,13 +686,7 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]: _accumulated_tokens = 0 tool_executed = False try: - client = create_llm_client( - provider=model.provider, - api_key=get_model_api_key(model), - model=model.model, - base_url=model.base_url, - timeout=_get_model_timeout(model), - ) + client = get_llm_client_for_model(model, timeout=_get_model_timeout(model)) max_tokens = get_max_tokens( model.provider, model.model, diff --git a/backend/app/services/llm/client.py 
b/backend/app/services/llm/client.py index 9bd68788b..0a94b26c4 100644 --- a/backend/app/services/llm/client.py +++ b/backend/app/services/llm/client.py @@ -888,6 +888,127 @@ async def close(self) -> None: await self._client.aclose() +# ============================================================================ +# Codex OAuth Client (ChatGPT Plus/Pro Subscription) +# ============================================================================ + +class CodexOAuthClient(OpenAIResponsesClient): + """OAuth-authenticated client for Codex via chatgpt.com/backend-api. + + Uses the ChatGPT subscription OAuth flow (see app.services.llm.codex_oauth) + instead of a static API key. Token lifecycle (expiry check + refresh) is + handled against the llm_models row identified by `model_id`. + """ + + DEFAULT_BASE_URL = "https://chatgpt.com/backend-api" + + def __init__( + self, + model_id: Any, + session_factory: Any, + model: str, + timeout: float = 120.0, + supports_tool_choice: bool = True, + ): + super().__init__( + api_key="", + base_url=self.DEFAULT_BASE_URL, + model=model, + timeout=timeout, + supports_tool_choice=supports_tool_choice, + ) + self.model_id = model_id + self.session_factory = session_factory + self.account_id: str | None = None + + def _get_headers(self) -> dict[str, str]: + from .codex_oauth import ORIGINATOR, OPENAI_BETA + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + "OpenAI-Beta": OPENAI_BETA, + "originator": ORIGINATOR, + } + if self.account_id: + headers["chatgpt-account-id"] = self.account_id + return headers + + async def _ensure_fresh_token(self) -> None: + """Load access token from DB; refresh if near expiry, persisting the new tokens.""" + from sqlalchemy import select + from app.config import get_settings + from app.core.security import decrypt_data, encrypt_data + from app.models.llm import LLMModel + from .codex_oauth import is_near_expiry, refresh_token + + settings = get_settings() + async 
with self.session_factory() as session: + stmt = select(LLMModel).where(LLMModel.id == self.model_id).with_for_update() + row = (await session.execute(stmt)).scalar_one_or_none() + if row is None: + raise LLMError(f"Codex OAuth model {self.model_id} not found") + if row.auth_type != "codex_oauth" or not row.oauth_refresh_token_encrypted: + raise LLMError( + f"Model {self.model_id} is not configured for Codex OAuth" + ) + + expires_at = row.oauth_expires_at + needs_refresh = expires_at is None or is_near_expiry(expires_at) + + if needs_refresh: + refresh_plain = decrypt_data(row.oauth_refresh_token_encrypted, settings.SECRET_KEY) + bundle = await refresh_token(refresh_plain) + row.oauth_access_token_encrypted = encrypt_data(bundle.access_token, settings.SECRET_KEY) + row.oauth_refresh_token_encrypted = encrypt_data(bundle.refresh_token, settings.SECRET_KEY) + row.oauth_expires_at = bundle.expires_at + if bundle.account_id: + row.oauth_account_id = bundle.account_id + await session.commit() + self.api_key = bundle.access_token + self.account_id = bundle.account_id or row.oauth_account_id + else: + self.api_key = decrypt_data(row.oauth_access_token_encrypted or "", settings.SECRET_KEY) + self.account_id = row.oauth_account_id + + async def complete( + self, + messages: list[LLMMessage], + tools: list[dict] | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + **kwargs: Any, + ) -> LLMResponse: + await self._ensure_fresh_token() + return await super().complete( + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) + + async def stream( + self, + messages: list[LLMMessage], + tools: list[dict] | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + on_chunk: ChunkCallback | None = None, + on_thinking: ThinkingCallback | None = None, + **kwargs: Any, + ) -> LLMResponse: + await self._ensure_fresh_token() + return await super().stream( + messages=messages, + 
tools=tools, + temperature=temperature, + max_tokens=max_tokens, + on_chunk=on_chunk, + on_thinking=on_thinking, + **kwargs, + ) + + # ============================================================================ # Gemini Native Client # ============================================================================ @@ -1709,7 +1830,7 @@ class ProviderSpec: provider: str display_name: str - protocol: Literal["openai_compatible", "anthropic", "openai_responses", "gemini"] + protocol: Literal["openai_compatible", "anthropic", "openai_responses", "gemini", "codex_oauth"] default_base_url: str | None supports_tool_choice: bool = True default_max_tokens: int = 4096 @@ -1845,6 +1966,19 @@ class ProviderSpec: default_base_url=None, default_max_tokens=4096, ), + "codex-oauth": ProviderSpec( + provider="codex-oauth", + display_name="OpenAI Codex (ChatGPT Subscription)", + protocol="codex_oauth", + default_base_url="https://chatgpt.com/backend-api", + supports_tool_choice=True, + default_max_tokens=16384, + model_max_tokens={ + "gpt-5.1-codex": 16384, + "gpt-5.2-codex": 16384, + "gpt-5.1-codex-max": 32768, + }, + ), } @@ -1860,9 +1994,18 @@ def get_provider_spec(provider: str) -> ProviderSpec | None: def get_provider_manifest() -> list[dict[str, Any]]: - """List supported providers and capabilities for UI/config discovery.""" + """List supported providers and capabilities for UI/config discovery. + + Providers whose protocol requires out-of-band provisioning (e.g. an OAuth + flow) are intentionally hidden here so the generic "Add Model" form can't + create inconsistent rows. They are reachable through their dedicated + endpoints instead. 
+ """ + hidden_protocols = {"codex_oauth"} out: list[dict[str, Any]] = [] for spec in PROVIDER_REGISTRY.values(): + if spec.protocol in hidden_protocols: + continue out.append({ "provider": spec.provider, "display_name": spec.display_name, @@ -1881,6 +2024,8 @@ def get_provider_manifest() -> list[dict[str, Any]]: spec.provider: ( AnthropicClient if spec.protocol == "anthropic" + else CodexOAuthClient + if spec.protocol == "codex_oauth" else OpenAIResponsesClient if spec.protocol == "openai_responses" else GeminiClient @@ -1959,6 +2104,9 @@ def create_llm_client( model: str, base_url: str | None = None, timeout: float = 120.0, + *, + model_id: Any = None, + session_factory: Any = None, ) -> LLMClient: """Create an LLM client for the given provider. @@ -1968,12 +2116,14 @@ def create_llm_client( model: Model name base_url: Optional custom base URL timeout: Request timeout in seconds + model_id: Required for codex_oauth protocol; identifies the llm_models row + session_factory: Required for codex_oauth protocol; async session factory for token refresh Returns: An instance of the appropriate LLMClient subclass Raises: - ValueError: If provider is not supported + ValueError: If provider is not supported or required OAuth context is missing. """ normalized_provider = normalize_provider(provider) spec = get_provider_spec(normalized_provider) @@ -1989,6 +2139,19 @@ def create_llm_client( model=model, timeout=timeout, ) + elif spec and spec.protocol == "codex_oauth": + if model_id is None or session_factory is None: + raise ValueError( + "codex_oauth provider requires model_id and session_factory; " + "use get_llm_client_for_model(model) instead." 
+ ) + return CodexOAuthClient( + model_id=model_id, + session_factory=session_factory, + model=model, + timeout=timeout, + supports_tool_choice=spec.supports_tool_choice, + ) elif spec and spec.protocol == "openai_responses": return OpenAIResponsesClient( api_key=api_key, diff --git a/backend/app/services/llm/codex_oauth.py b/backend/app/services/llm/codex_oauth.py new file mode 100644 index 000000000..9ca033cb1 --- /dev/null +++ b/backend/app/services/llm/codex_oauth.py @@ -0,0 +1,216 @@ +"""OAuth 2.1 + PKCE client for OpenAI Codex (ChatGPT Plus/Pro subscription). + +Lets Clawith act as a third-party OAuth client against OpenAI's authorization +server so users can authenticate with their ChatGPT subscription instead of an +API key. The resulting access token is used to call the Codex inference endpoint +at https://chatgpt.com/backend-api/responses. + +All constants below are intentionally hardcoded — they mirror the values used by +the official Codex CLI and by community integrations (OpenClaw, Hermes, +numman-ali/opencode-openai-codex-auth). The client_id is a public, shared OSS +identifier, not a secret. + +This module contains only pure primitives (no DB, no HTTP server, no FastAPI +dependency) so it can be unit-tested in isolation. +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import os +import secrets +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from urllib.parse import urlencode + +import httpx + +# ─── OAuth constants ────────────────────────────────────────────────────────── +CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" +TOKEN_URL = "https://auth.openai.com/oauth/token" +# REDIRECT_URI must match what OpenAI registered for this public client_id. Do +# not change — the value is fixed on OpenAI's side. The loopback listener just +# needs to service this URL from the user's browser perspective. 
+REDIRECT_URI = "http://localhost:1455/auth/callback" +# Default loopback listener bind. In a Docker deployment, set +# CODEX_OAUTH_LOOPBACK_HOST=0.0.0.0 so the mapped host port can reach the +# container; the listener surface only handles /auth/callback with state, so +# exposing it inside a private network is safe. +LOOPBACK_HOST = os.environ.get("CODEX_OAUTH_LOOPBACK_HOST", "127.0.0.1") +LOOPBACK_PORT = int(os.environ.get("CODEX_OAUTH_LOOPBACK_PORT", "1455")) +SCOPE = "openid profile email offline_access" + +# ─── Inference constants ────────────────────────────────────────────────────── +CODEX_BASE_URL = "https://chatgpt.com/backend-api" +ORIGINATOR = "codex_cli_rs" +OPENAI_BETA = "responses=experimental" +# JWT claim that holds the ChatGPT account record; used for chatgpt-account-id header +JWT_AUTH_CLAIM = "https://api.openai.com/auth" + +# ─── Models the Codex OAuth endpoint accepts ────────────────────────────────── +CODEX_OAUTH_MODELS = ( + "gpt-5.1", + "gpt-5.1-codex", + "gpt-5.1-codex-mini", + "gpt-5.1-codex-max", + "gpt-5.2", + "gpt-5.2-codex", + "codex-mini-latest", +) + +# Refresh a bit before actual expiry to avoid racing against wall clock skew +_REFRESH_LEEWAY_SECONDS = 60 + + +@dataclass(frozen=True) +class PKCEPair: + verifier: str + challenge: str + + +@dataclass(frozen=True) +class TokenBundle: + access_token: str + refresh_token: str + expires_at: datetime # timezone-aware UTC + account_id: str | None = None + + +def generate_pkce() -> PKCEPair: + """Generate a PKCE verifier/challenge pair per RFC 7636 (S256).""" + # 64 URL-safe bytes → ~86 chars, within RFC 7636 limit (43–128) + verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(verifier.encode("ascii")).digest() + challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return PKCEPair(verifier=verifier, challenge=challenge) + + +def generate_state() -> str: + """CSRF state token for the OAuth flow.""" + return secrets.token_hex(16) + + +def 
build_authorize_url(challenge: str, state: str, redirect_uri: str = REDIRECT_URI) -> str: + """Build the OpenAI authorize URL for the PKCE flow. + + Includes `codex_cli_simplified_flow` / `id_token_add_organizations` / + `originator` params that the Codex CLI sends; without these the authorize + page behaves differently. + """ + params = { + "response_type": "code", + "client_id": CLIENT_ID, + "redirect_uri": redirect_uri, + "scope": SCOPE, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "originator": ORIGINATOR, + } + return f"{AUTHORIZE_URL}?{urlencode(params)}" + + +def _parse_token_response(payload: dict) -> TokenBundle: + access = payload.get("access_token") + refresh = payload.get("refresh_token") + expires_in = payload.get("expires_in") + if not (isinstance(access, str) and access and isinstance(refresh, str) and refresh): + raise ValueError(f"token response missing access_token/refresh_token: {payload!r}") + if not isinstance(expires_in, int): + raise ValueError(f"token response missing or non-integer expires_in: {payload!r}") + return TokenBundle( + access_token=access, + refresh_token=refresh, + expires_at=datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in), + account_id=decode_account_id(access), + ) + + +async def exchange_code( + code: str, + verifier: str, + redirect_uri: str = REDIRECT_URI, + *, + client: httpx.AsyncClient | None = None, +) -> TokenBundle: + """Exchange an authorization code (+ PKCE verifier) for access+refresh tokens.""" + data = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": redirect_uri, + } + return await _post_token(data, client) + + +async def refresh_token( + refresh_token_value: str, + *, + client: httpx.AsyncClient | None = None, +) -> TokenBundle: + """Trade a refresh_token for a new access+refresh pair.""" + data = { 
+ "grant_type": "refresh_token", + "refresh_token": refresh_token_value, + "client_id": CLIENT_ID, + } + return await _post_token(data, client) + + +async def _post_token(data: dict, client: httpx.AsyncClient | None) -> TokenBundle: + owns_client = client is None + http = client or httpx.AsyncClient(timeout=30.0) + try: + resp = await http.post( + TOKEN_URL, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if resp.status_code != 200: + raise ValueError( + f"token endpoint returned {resp.status_code}: {resp.text[:500]}" + ) + return _parse_token_response(resp.json()) + finally: + if owns_client: + await http.aclose() + + +def decode_jwt_payload(token: str) -> dict | None: + """Base64-decode the payload segment of a JWT (no signature verification).""" + parts = token.split(".") + if len(parts) != 3: + return None + payload = parts[1] + # base64url — pad to multiple of 4 + padding = "=" * (-len(payload) % 4) + try: + raw = base64.urlsafe_b64decode(payload + padding) + return json.loads(raw.decode("utf-8")) + except Exception: + return None + + +def decode_account_id(access_token_jwt: str) -> str | None: + """Extract the ChatGPT account id from the access token JWT's auth claim.""" + payload = decode_jwt_payload(access_token_jwt) + if not isinstance(payload, dict): + return None + claim = payload.get(JWT_AUTH_CLAIM) + if not isinstance(claim, dict): + return None + account_id = claim.get("chatgpt_account_id") or claim.get("account_id") + return account_id if isinstance(account_id, str) else None + + +def is_near_expiry(expires_at: datetime, leeway_seconds: int = _REFRESH_LEEWAY_SECONDS) -> bool: + """True if the token expires within the leeway window (or is already expired).""" + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return datetime.now(tz=timezone.utc) + timedelta(seconds=leeway_seconds) >= expires_at diff --git a/backend/app/services/llm/utils.py b/backend/app/services/llm/utils.py 
def get_llm_client_for_model(
    model: LLMModel,
    *,
    timeout: float | None = None,
    session_factory=None,
) -> LLMClient:
    """Create the correct LLMClient for a model, dispatching on auth_type.

    Static-key providers pull the decrypted ``api_key_encrypted`` as before.
    OAuth-backed providers (auth_type='codex_oauth') bypass the static key and
    hand the token lifecycle to CodexOAuthClient, which reads/refreshes tokens
    on the llm_models row through the provided async session factory.
    """
    # Explicit timeout wins; otherwise fall back to the model's configured
    # request_timeout, defaulting to 120 seconds.
    if timeout is not None:
        effective_timeout = float(timeout)
    else:
        effective_timeout = float(getattr(model, "request_timeout", None) or 120.0)

    if getattr(model, "auth_type", "static") != "codex_oauth":
        # Standard path: decrypt the stored key and build a static-key client.
        return create_llm_client(
            provider=model.provider,
            api_key=get_model_api_key(model),
            model=model.model,
            base_url=model.base_url,
            timeout=effective_timeout,
        )

    # OAuth path: no static api_key; the client needs the row id plus a
    # session factory so it can persist refreshed tokens itself.
    return create_llm_client(
        provider=model.provider or "codex-oauth",
        api_key="",
        model=model.model,
        base_url=model.base_url,
        timeout=effective_timeout,
        model_id=model.id,
        session_factory=session_factory or async_session,
    )
mode 100644 index 000000000..325c6ee97 --- /dev/null +++ b/backend/tests/test_codex_oauth.py @@ -0,0 +1,217 @@ +"""Unit tests for app.services.llm.codex_oauth. + +Covers the pure primitives: PKCE generation, authorize URL shape, token +exchange/refresh over a mocked httpx client, and JWT payload decoding. No DB or +FastAPI fixtures — that lives in the integration test layer. +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +from datetime import datetime, timezone +from urllib.parse import parse_qs, urlparse + +import pytest + +from app.services.llm import codex_oauth as co + + +# ── PKCE ──────────────────────────────────────────────────────────────── + +def test_generate_pkce_pair_is_consistent(): + pair = co.generate_pkce() + # RFC 7636: challenge = base64url(sha256(verifier)), no padding + digest = hashlib.sha256(pair.verifier.encode("ascii")).digest() + expected = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + assert pair.challenge == expected + # Verifier length must be in 43..128 per RFC + assert 43 <= len(pair.verifier) <= 128 + # Characters must be URL-safe + assert all(c.isalnum() or c in "-_" for c in pair.verifier) + + +def test_generate_pkce_produces_unique_pairs(): + seen = {co.generate_pkce().verifier for _ in range(20)} + assert len(seen) == 20 + + +def test_generate_state_is_hex(): + s = co.generate_state() + assert len(s) == 32 + int(s, 16) # must parse as hex + + +# ── Authorize URL ─────────────────────────────────────────────────────── + +def test_build_authorize_url_has_required_params(): + url = co.build_authorize_url(challenge="chal", state="st") + parsed = urlparse(url) + assert parsed.scheme == "https" + assert parsed.netloc == "auth.openai.com" + assert parsed.path == "/oauth/authorize" + q = parse_qs(parsed.query) + assert q["response_type"] == ["code"] + assert q["client_id"] == [co.CLIENT_ID] + assert q["redirect_uri"] == [co.REDIRECT_URI] + assert q["scope"] == [co.SCOPE] + 
assert q["code_challenge"] == ["chal"] + assert q["code_challenge_method"] == ["S256"] + assert q["state"] == ["st"] + # Codex-specific opt-ins the CLI sends — behavior of the login page changes + # without them, so lock them in. + assert q["codex_cli_simplified_flow"] == ["true"] + assert q["id_token_add_organizations"] == ["true"] + assert q["originator"] == [co.ORIGINATOR] + + +# ── Token parsing ─────────────────────────────────────────────────────── + +def _fake_jwt(payload: dict) -> str: + """Build a JWT-shaped string (header.payload.signature, no real sig).""" + header = base64.urlsafe_b64encode(b'{"alg":"none"}').rstrip(b"=").decode("ascii") + body_raw = json.dumps(payload).encode("utf-8") + body = base64.urlsafe_b64encode(body_raw).rstrip(b"=").decode("ascii") + return f"{header}.{body}.sig" + + +def test_decode_jwt_payload_roundtrip(): + payload = {"sub": "user-123", co.JWT_AUTH_CLAIM: {"chatgpt_account_id": "acc-xyz"}} + jwt = _fake_jwt(payload) + decoded = co.decode_jwt_payload(jwt) + assert decoded == payload + + +def test_decode_account_id_prefers_chatgpt_account_id(): + jwt = _fake_jwt({co.JWT_AUTH_CLAIM: {"chatgpt_account_id": "acc-xyz"}}) + assert co.decode_account_id(jwt) == "acc-xyz" + + +def test_decode_account_id_falls_back_to_account_id(): + jwt = _fake_jwt({co.JWT_AUTH_CLAIM: {"account_id": "acc-fallback"}}) + assert co.decode_account_id(jwt) == "acc-fallback" + + +def test_decode_account_id_returns_none_for_invalid_jwt(): + assert co.decode_account_id("not-a-jwt") is None + assert co.decode_account_id("one.two") is None + + +# ── Exchange / refresh ────────────────────────────────────────────────── + +class _FakeResponse: + def __init__(self, status_code: int, body: dict | str): + self.status_code = status_code + self._body = body + + @property + def text(self) -> str: + return self._body if isinstance(self._body, str) else json.dumps(self._body) + + def json(self) -> dict: + if isinstance(self._body, dict): + return self._body + return 
json.loads(self._body) + + +class _FakeHttpClient: + def __init__(self, response: _FakeResponse): + self._response = response + self.last_call: dict | None = None + + async def post(self, url, data=None, headers=None): + self.last_call = {"url": url, "data": dict(data or {}), "headers": dict(headers or {})} + return self._response + + async def aclose(self): + return None + + +@pytest.mark.asyncio +async def test_exchange_code_builds_correct_request_and_parses_tokens(): + jwt = _fake_jwt({co.JWT_AUTH_CLAIM: {"chatgpt_account_id": "acc-42"}}) + resp = _FakeResponse( + 200, + { + "access_token": jwt, + "refresh_token": "rt-123", + "expires_in": 3600, + }, + ) + client = _FakeHttpClient(resp) + + before = datetime.now(tz=timezone.utc) + bundle = await co.exchange_code(code="ac-code", verifier="ver", client=client) + after = datetime.now(tz=timezone.utc) + + assert bundle.access_token == jwt + assert bundle.refresh_token == "rt-123" + assert bundle.account_id == "acc-42" + # Expiry should be approximately now + 3600s + assert before.timestamp() + 3500 <= bundle.expires_at.timestamp() <= after.timestamp() + 3700 + + # Request shape + assert client.last_call["url"] == co.TOKEN_URL + assert client.last_call["headers"]["Content-Type"] == "application/x-www-form-urlencoded" + data = client.last_call["data"] + assert data["grant_type"] == "authorization_code" + assert data["client_id"] == co.CLIENT_ID + assert data["code"] == "ac-code" + assert data["code_verifier"] == "ver" + assert data["redirect_uri"] == co.REDIRECT_URI + + +@pytest.mark.asyncio +async def test_refresh_token_sends_refresh_grant(): + resp = _FakeResponse( + 200, + { + "access_token": _fake_jwt({}), + "refresh_token": "rt-new", + "expires_in": 1800, + }, + ) + client = _FakeHttpClient(resp) + bundle = await co.refresh_token("rt-old", client=client) + data = client.last_call["data"] + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == "rt-old" + assert data["client_id"] == 
co.CLIENT_ID + assert bundle.refresh_token == "rt-new" + + +@pytest.mark.asyncio +async def test_exchange_code_raises_on_http_error(): + resp = _FakeResponse(400, {"error": "invalid_grant"}) + client = _FakeHttpClient(resp) + with pytest.raises(ValueError, match="token endpoint returned 400"): + await co.exchange_code(code="x", verifier="y", client=client) + + +@pytest.mark.asyncio +async def test_exchange_code_raises_on_missing_fields(): + resp = _FakeResponse(200, {"access_token": "", "refresh_token": "", "expires_in": 0}) + client = _FakeHttpClient(resp) + with pytest.raises(ValueError): + await co.exchange_code(code="x", verifier="y", client=client) + + +# ── Expiry helper ─────────────────────────────────────────────────────── + +def test_is_near_expiry_triggers_within_leeway(): + from datetime import timedelta + + soon = datetime.now(tz=timezone.utc) + timedelta(seconds=30) + assert co.is_near_expiry(soon) is True + + later = datetime.now(tz=timezone.utc) + timedelta(seconds=3600) + assert co.is_near_expiry(later) is False + + +def test_is_near_expiry_accepts_naive_datetime(): + from datetime import timedelta + + naive_future = (datetime.now(tz=timezone.utc) + timedelta(seconds=3600)).replace(tzinfo=None) + # Should treat naive as UTC and not blow up + assert co.is_near_expiry(naive_future) is False diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 000000000..321fbcdbe --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,91 @@ +services: + postgres: + image: postgres:15-alpine + restart: unless-stopped + networks: + - clawith_codex_net + environment: + POSTGRES_USER: clawith + POSTGRES_PASSWORD: clawith + POSTGRES_DB: clawith + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U clawith" ] + interval: 5s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + restart: unless-stopped + networks: + - clawith_codex_net + volumes: + - redisdata:/data + healthcheck: + 
test: [ "CMD", "redis-cli", "ping" ] + interval: 5s + timeout: 5s + retries: 5 + + backend: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + SECRET_KEY: codex-test-secret-key-do-not-use-in-prod + JWT_SECRET_KEY: codex-test-jwt-secret-do-not-use-in-prod + CORS_ORIGINS: '["*"]' + DOCKER_NETWORK: clawith_codex_net + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-http://localhost:3009} + CODEX_OAUTH_LOOPBACK_HOST: 0.0.0.0 + ports: + - "1455:1455" + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + networks: + - clawith_codex_net + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + frontend: + build: ./frontend + restart: unless-stopped + ports: + - "3009:3000" + environment: + VITE_API_URL: http://localhost:8000 + volumes: + - ./frontend/src:/app/src + - ./frontend/public:/app/public + networks: + - clawith_codex_net + depends_on: + - backend + +volumes: + pgdata: + redisdata: + +networks: + clawith_codex_net: + name: clawith_codex_network diff --git a/docker-compose.yml b/docker-compose.yml index c57a296ec..3e31b152c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -55,6 +55,13 @@ services: PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} # Password reset token lifetime in minutes (default: 30) PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + # Bind the Codex OAuth loopback listener on 0.0.0.0 inside the container + # so that the mapped host port (1455) can reach it from the browser. 
+ CODEX_OAUTH_LOOPBACK_HOST: 0.0.0.0 + ports: + # Loopback port for Codex OAuth flow — OpenAI redirects the user's browser + # to http://localhost:1455/auth/callback, which must reach this container. + - "1455:1455" volumes: - ./backend:/app - ./backend/agent_data:/data/agents diff --git a/frontend/src/components/ConnectChatGPTModal.tsx b/frontend/src/components/ConnectChatGPTModal.tsx new file mode 100644 index 000000000..56cced935 --- /dev/null +++ b/frontend/src/components/ConnectChatGPTModal.tsx @@ -0,0 +1,467 @@ +import { useEffect, useRef, useState } from 'react'; +import type { CSSProperties, ReactNode } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + CODEX_OAUTH_MODELS, + codexOauthApi, + type CodexOauthStartResponse, +} from '../services/api'; + +interface ConnectChatGPTModalProps { + open: boolean; + onClose: () => void; + onCreated: (modelId: string) => void; + /** Tenant to provision the model under. Required for platform-admin sessions + * managing a non-default tenant; when unset the backend falls back to the + * caller's own tenant. */ + tenantId?: string | null; +} + +type FlowTab = 'oauth' | 'paste'; +type OauthStep = 'idle' | 'authorizing' | 'got-code' | 'submitting' | 'done'; + +const DEFAULT_LABEL = 'Codex (ChatGPT subscription)'; +const POLL_INTERVAL_MS = 1500; +const POLL_MAX_DURATION_MS = 5 * 60_000; + +export default function ConnectChatGPTModal({ open, onClose, onCreated, tenantId }: ConnectChatGPTModalProps) { + const { t } = useTranslation(); + const [tab, setTab] = useState+ {t('enterprise.llm.codex.subtitle')} +
+ +{t('enterprise.llm.codex.oauth.idleHint')}
+ ++ {oauthSession.loopback_ready + ? t('enterprise.llm.codex.oauth.waitingLoopback') + : t('enterprise.llm.codex.oauth.waitingManual')} +
+ ++ {t('enterprise.llm.codex.oauth.codeReceived')} +
+ {labelInput} + {modelSelect} + ++ {t('enterprise.llm.codex.oauth.done')} +
+ )} + {oauthError && ( +{oauthError}
+ )} +{t('enterprise.llm.codex.paste.hint')}
+ + + + + {labelInput} + {modelSelect} + + {pasteError &&{pasteError}
} +