diff --git a/src/backend/base/langflow/plugin_routes.py b/src/backend/base/langflow/plugin_routes.py index 81101d7287c7..9028ab1a9ba2 100644 --- a/src/backend/base/langflow/plugin_routes.py +++ b/src/backend/base/langflow/plugin_routes.py @@ -96,6 +96,10 @@ def add_api_route(self, path: str, endpoint, **kwargs): self._check_and_reserve(path, set(methods)) return self._app.add_api_route(path, endpoint, **kwargs) + def add_middleware(self, middleware_class, **kwargs): + """Allow plugins to register ASGI middleware on the host app.""" + self._app.add_middleware(middleware_class, **kwargs) + def load_plugin_routes(app: FastAPI) -> None: """Discover and register additional routers from enterprise plugins. diff --git a/src/backend/saas/alembic.ini b/src/backend/saas/alembic.ini new file mode 100644 index 000000000000..01b78c4a2b84 --- /dev/null +++ b/src/backend/saas/alembic.ini @@ -0,0 +1,41 @@ +[alembic] +# Path to migration scripts, relative to this file. +script_location = langflow_saas/migrations + +# sqlalchemy.url is intentionally left blank — the env.py reads +# SAAS_DATABASE_URL (falling back to LANGFLOW_DATABASE_URL) at runtime. +sqlalchemy.url = + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/backend/saas/langflow_saas/__init__.py b/src/backend/saas/langflow_saas/__init__.py new file mode 100644 index 000000000000..3748ade5b78b --- /dev/null +++ b/src/backend/saas/langflow_saas/__init__.py @@ -0,0 +1 @@ +"""langflow-saas: pluggable SaaS / multi-tenancy layer for Langflow.""" diff --git a/src/backend/saas/langflow_saas/api/__init__.py b/src/backend/saas/langflow_saas/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/backend/saas/langflow_saas/api/billing.py b/src/backend/saas/langflow_saas/api/billing.py new file mode 100644 index 000000000000..f6d830574d50 --- /dev/null +++ b/src/backend/saas/langflow_saas/api/billing.py @@ -0,0 +1,270 @@ +"""Billing, plans, usage, and Stripe webhook endpoints. + +Routes: + GET /api/saas/v1/plans — list active plans (public) + GET /api/saas/v1/orgs/{org_id}/billing — get subscription details + POST /api/saas/v1/orgs/{org_id}/billing/checkout — create Stripe checkout session + GET /api/saas/v1/orgs/{org_id}/usage — get usage summary + POST /api/saas/v1/billing/webhook — Stripe webhook (no auth, HMAC-verified) + GET /api/saas/v1/audit — audit log (admin+) +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, HTTPException, Request, status +from sqlmodel import select + +from langflow_saas.dependencies import CurrentOrgContext, RequireAdmin, assert_org_match +from langflow_saas.models import ( + AuditLog, + Organization, + Plan, + PlanRead, + Subscription, + SubscriptionRead, + UsageMetric, + UsageRecord, + UsageSummary, +) +from langflow_saas.services import get_billing_service +from langflow_saas.settings import get_saas_settings + +router = APIRouter(tags=["Billing & Plans"]) + + +# --------------------------------------------------------------------------- +# Plans (public, no auth) +# --------------------------------------------------------------------------- + + +@router.get("/plans", response_model=list[PlanRead]) +async def list_plans(): + """Return all active plans. Safe to call without authentication.""" + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec(select(Plan).where(Plan.is_active == True)) # noqa: E712 + return [PlanRead.model_validate(p) for p in result.all()] + + +# --------------------------------------------------------------------------- +# Subscription info +# --------------------------------------------------------------------------- + + +@router.get("/orgs/{org_id}/billing", response_model=SubscriptionRead | None) +async def get_subscription(org_id: UUID, ctx: CurrentOrgContext): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + sub_result = await db.exec(select(Subscription).where(Subscription.org_id == org_id)) + sub = sub_result.first() + if not sub: + return None + + plan_result = await db.exec(select(Plan).where(Plan.id == sub.plan_id)) + plan = plan_result.first() + if not plan: + return None + + return SubscriptionRead( + id=sub.id, + org_id=sub.org_id, + status=sub.status, + plan=PlanRead.model_validate(plan), + current_period_end=sub.current_period_end, + cancel_at_period_end=sub.cancel_at_period_end, + trial_end=sub.trial_end, + ) + + +class CheckoutRequest(PlanRead): + stripe_price_id: str + billing_cycle: str = "monthly" # "monthly" | "yearly" + + +@router.post("/orgs/{org_id}/billing/checkout") +async def create_checkout(org_id: UUID, request: Request, ctx: RequireAdmin): + """Create a Stripe Checkout Session and return the redirect URL.""" + assert_org_match(org_id, ctx) + settings = get_saas_settings() + if not settings.billing_enabled: + raise HTTPException(501, "Billing is not enabled on this instance.") + + body = await request.json() + price_id: str = body.get("stripe_price_id", "") + if not price_id: + raise HTTPException(400, "stripe_price_id is required.") + + from langflow.services.deps import session_scope + + async with session_scope() as db: + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + if not org: + raise HTTPException(404, "Organization not found.") + + # Fetch owner email from Langflow user table. + from langflow.services.database.models.user.model import User + + user_result = await db.exec(select(User).where(User.id == org.owner_id)) + owner = user_result.first() + owner_email = getattr(owner, "email", "") or f"{org.slug}@noemail.local" + + url = await get_billing_service().create_checkout_session( + org_id=org_id, + org_name=org.name, + owner_email=owner_email, + stripe_price_id=price_id, + success_url=f"{settings.app_base_url}/settings/billing?success=1", + cancel_url=f"{settings.app_base_url}/settings/billing?canceled=1", + ) + return {"checkout_url": url} + + +# --------------------------------------------------------------------------- +# Usage summary +# --------------------------------------------------------------------------- + + +@router.get("/orgs/{org_id}/usage", response_model=UsageSummary) +async def get_usage(org_id: UUID, ctx: CurrentOrgContext): + assert_org_match(org_id, ctx) + settings = get_saas_settings() + from langflow.services.deps import session_scope + from sqlalchemy import func + + today_start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) + + async with session_scope() as db: + # Get plan limits. + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + plan: Plan | None = None + if org and org.plan_id: + plan_result = await db.exec(select(Plan).where(Plan.id == org.plan_id)) + plan = plan_result.first() + + max_flows = plan.max_flows if plan else settings.default_max_flows + max_exec = plan.max_executions_per_day if plan else settings.default_max_executions_per_day + max_storage = plan.max_storage_mb if plan else settings.default_max_storage_mb + + # Count executions today. + exec_result = await db.exec( + select(func.sum(UsageRecord.value)).where( + UsageRecord.org_id == org_id, + UsageRecord.metric == UsageMetric.FLOW_EXECUTION, + UsageRecord.recorded_at >= today_start, + ) + ) + execs_today = int(exec_result.first() or 0) + + # Count API calls today. + api_result = await db.exec( + select(func.sum(UsageRecord.value)).where( + UsageRecord.org_id == org_id, + UsageRecord.metric == UsageMetric.API_CALL, + UsageRecord.recorded_at >= today_start, + ) + ) + api_calls_today = int(api_result.first() or 0) + + # Storage (sum of all storage_bytes records for this org). + storage_result = await db.exec( + select(func.sum(UsageRecord.value)).where( + UsageRecord.org_id == org_id, UsageRecord.metric == UsageMetric.STORAGE_BYTES + ) + ) + storage_bytes = int(storage_result.first() or 0) + + # Count flows from Langflow's flows table for org members. + # We aggregate flows belonging to all members of the org. + from langflow.services.database.models.flow.model import Flow + + from langflow_saas.models import UserOrganization + + member_result = await db.exec(select(UserOrganization.user_id).where(UserOrganization.org_id == org_id)) + member_ids = [r for r in member_result.all()] + flow_count = 0 + if member_ids: + flow_count_result = await db.exec( + select(func.count(Flow.id)).where(Flow.user_id.in_(member_ids)) # type: ignore[attr-defined] + ) + flow_count = int(flow_count_result.first() or 0) + + return UsageSummary( + org_id=org_id, + executions_today=execs_today, + executions_limit=max_exec, + flows_count=flow_count, + flows_limit=max_flows, + storage_mb=round(storage_bytes / (1024 * 1024), 2), + storage_limit_mb=max_storage, + api_calls_today=api_calls_today, + plan_slug=plan.slug if plan else "free", + ) + + +# --------------------------------------------------------------------------- +# Stripe Webhook (no auth — Stripe HMAC-verified) +# --------------------------------------------------------------------------- + + +@router.post("/billing/webhook", status_code=status.HTTP_200_OK) +async def stripe_webhook(request: Request): + settings = get_saas_settings() + if not settings.billing_enabled: + raise HTTPException(501, "Billing not enabled.") + + payload = await request.body() + sig_header = request.headers.get("stripe-signature", "") + + try: + result = await get_billing_service().handle_webhook(payload=payload, sig_header=sig_header) + except Exception as exc: + raise HTTPException(400, f"Webhook processing failed: {exc}") from exc + + return result + + +# --------------------------------------------------------------------------- +# Audit Log +# --------------------------------------------------------------------------- + + +@router.get("/audit") +async def get_audit_log( + ctx: RequireAdmin, + limit: int = 100, + offset: int = 0, +): + """Paginated audit log for the current organization.""" + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec( + select(AuditLog) + .where(AuditLog.org_id == ctx.org_id) + .order_by(AuditLog.created_at.desc()) # type: ignore[union-attr] + .offset(offset) + .limit(min(limit, 500)) + ) + entries = result.all() + + return [ + { + "id": str(e.id), + "action": e.action, + "user_id": str(e.user_id) if e.user_id else None, + "resource_type": e.resource_type, + "resource_id": e.resource_id, + "metadata": e.log_metadata, + "ip_address": e.ip_address, + "created_at": e.created_at.isoformat(), + } + for e in entries + ] diff --git a/src/backend/saas/langflow_saas/api/flows.py b/src/backend/saas/langflow_saas/api/flows.py new file mode 100644 index 000000000000..7604eca3d00d --- /dev/null +++ b/src/backend/saas/langflow_saas/api/flows.py @@ -0,0 +1,139 @@ +"""Org-scoped flow management endpoints. + +Routes: + GET /api/saas/v1/orgs/{org_id}/flows — list all flows owned by the org + POST /api/saas/v1/orgs/{org_id}/flows/{flow_id}/assign — assign an existing flow to the org + DELETE /api/saas/v1/orgs/{org_id}/flows/{flow_id}/assign — unassign (remove org ownership) + +These endpoints operate on Langflow's native ``flow`` rows via a shadow table +(``saas_flow_org``) — Langflow's own schema is never modified. + +Newly created flows are auto-assigned by FlowOwnershipMiddleware so in most +cases callers never need the assign/unassign endpoints directly. +""" + +from __future__ import annotations + +from datetime import datetime +from uuid import UUID + +from fastapi import APIRouter, HTTPException, status +from sqlmodel import SQLModel, select + +from langflow_saas.dependencies import CurrentOrgContext, RequireAdmin, assert_org_match +from langflow_saas.models import FlowOrg + +router = APIRouter(tags=["Org Flows"]) + + +# --------------------------------------------------------------------------- +# Response schema (mirrors Langflow's FlowBase fields we care about) +# --------------------------------------------------------------------------- + + +class OrgFlowRead(SQLModel): + id: UUID + name: str + description: str | None = None + user_id: UUID | None = None + updated_at: datetime | None = None + assigned_at: datetime + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@router.get("/orgs/{org_id}/flows", response_model=list[OrgFlowRead]) +async def list_org_flows(org_id: UUID, ctx: CurrentOrgContext): + """Return all flows assigned to this org, enriched with Langflow metadata.""" + assert_org_match(org_id, ctx) + + from langflow.services.database.models.flow.model import Flow + from langflow.services.deps import session_scope + + async with session_scope() as db: + # Get all flow_org mappings for this org. + fo_result = await db.exec(select(FlowOrg).where(FlowOrg.org_id == org_id)) + flow_orgs = fo_result.all() + + if not flow_orgs: + return [] + + flow_id_to_assigned = {fo.flow_id: fo.assigned_at for fo in flow_orgs} + flow_ids = list(flow_id_to_assigned.keys()) + + # Fetch the actual Langflow flow rows. + flows_result = await db.exec(select(Flow).where(Flow.id.in_(flow_ids))) # type: ignore[attr-defined] + flows = flows_result.all() + + return [ + OrgFlowRead( + id=UUID(str(f.id)), + name=f.name, + description=getattr(f, "description", None), + user_id=UUID(str(f.user_id)) if f.user_id else None, + updated_at=getattr(f, "updated_at", None), + assigned_at=flow_id_to_assigned[UUID(str(f.id))], + ) + for f in flows + ] + + +@router.post( + "/orgs/{org_id}/flows/{flow_id}/assign", + status_code=status.HTTP_201_CREATED, +) +async def assign_flow(org_id: UUID, flow_id: UUID, ctx: RequireAdmin): + """Manually assign an existing flow to this org. + + Useful for flows created before the plugin was installed, or flows created + by users who weren't in an org at the time. + """ + assert_org_match(org_id, ctx) + + from langflow.services.database.models.flow.model import Flow + from langflow.services.deps import session_scope + + async with session_scope() as db: + # Verify the flow exists in Langflow. + flow_result = await db.exec(select(Flow).where(Flow.id == flow_id)) # type: ignore[arg-type] + if not flow_result.first(): + raise HTTPException(404, "Flow not found.") + + # Reject if already assigned to a *different* org. + existing = await db.exec(select(FlowOrg).where(FlowOrg.flow_id == flow_id)) + existing_fo = existing.first() + if existing_fo: + if existing_fo.org_id == org_id: + return {"ok": True, "already_assigned": True} + raise HTTPException( + 409, + f"Flow is already assigned to org {existing_fo.org_id}. Unassign it first.", + ) + + db.add(FlowOrg(flow_id=flow_id, org_id=org_id, assigned_by=ctx.user_id)) + await db.commit() + + return {"ok": True, "org_id": str(org_id), "flow_id": str(flow_id)} + + +@router.delete( + "/orgs/{org_id}/flows/{flow_id}/assign", + status_code=status.HTTP_204_NO_CONTENT, +) +async def unassign_flow(org_id: UUID, flow_id: UUID, ctx: RequireAdmin): + """Remove org ownership of a flow (the flow itself is NOT deleted).""" + assert_org_match(org_id, ctx) + + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec(select(FlowOrg).where(FlowOrg.flow_id == flow_id, FlowOrg.org_id == org_id)) + fo = result.first() + if not fo: + raise HTTPException(404, "Flow is not assigned to this org.") + + await db.delete(fo) + await db.commit() diff --git a/src/backend/saas/langflow_saas/api/members.py b/src/backend/saas/langflow_saas/api/members.py new file mode 100644 index 000000000000..0c73ddc585e1 --- /dev/null +++ b/src/backend/saas/langflow_saas/api/members.py @@ -0,0 +1,387 @@ +"""Membership and invitation management endpoints. + +Routes: + GET /api/saas/v1/orgs/{org_id}/members — list members + PATCH /api/saas/v1/orgs/{org_id}/members/{user_id} — change role (admin+) + DELETE /api/saas/v1/orgs/{org_id}/members/{user_id} — remove member (admin+) + POST /api/saas/v1/orgs/{org_id}/invitations — invite by email (admin+) + GET /api/saas/v1/orgs/{org_id}/invitations — list pending invitations + DELETE /api/saas/v1/orgs/{org_id}/invitations/{invite_id} — revoke invitation + GET /api/saas/v1/invitations/{token} — get invitation info (public) + POST /api/saas/v1/invitations/{token}/accept — accept invitation (authenticated) +""" + +from __future__ import annotations + +import hashlib +import hmac +from datetime import datetime, timedelta, timezone +from uuid import UUID + +from fastapi import APIRouter, HTTPException, Request, status +from sqlmodel import select + +from langflow_saas.dependencies import CurrentOrgContext, RequireAdmin, assert_org_match +from langflow_saas.models import ( + Invitation, + InvitationCreate, + InvitationRead, + InvitationStatus, + MemberRead, + OrgRole, + UserOrganization, +) +from langflow_saas.services import get_audit_service, get_email_service +from langflow_saas.settings import get_saas_settings + +router = APIRouter(tags=["Members & Invitations"]) + + +def _make_token(invitation_id: UUID, secret: str) -> str: + """Produce a URL-safe HMAC token from the invitation ID.""" + sig = hmac.new(secret.encode(), str(invitation_id).encode(), hashlib.sha256).hexdigest() + return f"{invitation_id.hex}_{sig}" + + +def _verify_token(token: str, secret: str) -> UUID | None: + """Return the invitation UUID if the token is valid, else None.""" + try: + id_part, sig_part = token.split("_", 1) + inv_id = UUID(id_part) + expected = hmac.new(secret.encode(), str(inv_id).encode(), hashlib.sha256).hexdigest() + if hmac.compare_digest(expected, sig_part): + return inv_id + except Exception: # noqa: BLE001 + pass + return None + + +# --------------------------------------------------------------------------- +# Members +# --------------------------------------------------------------------------- + + +@router.get("/orgs/{org_id}/members", response_model=list[MemberRead]) +async def list_members(org_id: UUID, ctx: CurrentOrgContext): + assert_org_match(org_id, ctx) + from langflow.services.database.models.user.model import User + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec( + select(UserOrganization, User) + .join(User, User.id == UserOrganization.user_id) # type: ignore[arg-type] + .where(UserOrganization.org_id == org_id) + ) + rows = result.all() + return [ + MemberRead( + user_id=UUID(str(uo.user_id)), + username=user.username, + role=uo.role, + joined_at=uo.created_at, + ) + for uo, user in rows + ] + + +@router.patch("/orgs/{org_id}/members/{target_user_id}", status_code=status.HTTP_200_OK) +async def update_member_role( + org_id: UUID, + target_user_id: UUID, + role: OrgRole, + ctx: RequireAdmin, + request: Request, +): + assert_org_match(org_id, ctx) + + if role == OrgRole.OWNER: + raise HTTPException(400, "Transfer ownership via the dedicated transfer-ownership endpoint.") + + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec( + select(UserOrganization).where( + UserOrganization.org_id == org_id, UserOrganization.user_id == target_user_id + ) + ) + membership = result.first() + if not membership: + raise HTTPException(404, "Member not found in this organization.") + if membership.role == OrgRole.OWNER and ctx.role != OrgRole.OWNER: + raise HTTPException(403, "Only the owner can change the owner's role.") + + membership.role = role + db.add(membership) + await db.commit() + + await get_audit_service().log( + action="member.role_changed", + org_id=org_id, + user_id=ctx.user_id, + resource_type="user", + resource_id=str(target_user_id), + log_metadata={"new_role": role.value}, + ip_address=request.client.host if request.client else None, + ) + return {"ok": True} + + +@router.delete("/orgs/{org_id}/members/{target_user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def remove_member(org_id: UUID, target_user_id: UUID, ctx: RequireAdmin, request: Request): + assert_org_match(org_id, ctx) + + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec( + select(UserOrganization).where( + UserOrganization.org_id == org_id, UserOrganization.user_id == target_user_id + ) + ) + membership = result.first() + if not membership: + raise HTTPException(404, "Member not found.") + if membership.role == OrgRole.OWNER: + raise HTTPException(400, "Cannot remove the owner. Transfer ownership first.") + + await db.delete(membership) + await db.commit() + + await get_audit_service().log( + action="member.removed", + org_id=org_id, + user_id=ctx.user_id, + resource_type="user", + resource_id=str(target_user_id), + ip_address=request.client.host if request.client else None, + ) + + +# --------------------------------------------------------------------------- +# Invitations +# --------------------------------------------------------------------------- + + +@router.post( + "/orgs/{org_id}/invitations", + response_model=InvitationRead, + status_code=status.HTTP_201_CREATED, +) +async def invite_member(org_id: UUID, body: InvitationCreate, ctx: RequireAdmin, request: Request): + assert_org_match(org_id, ctx) + settings = get_saas_settings() + from langflow.services.deps import session_scope + + from langflow_saas.models import Organization + + async with session_scope() as db: + # Check member cap. + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + + count_result = await db.exec(select(UserOrganization).where(UserOrganization.org_id == org_id)) + member_count = len(count_result.all()) + if org and member_count >= settings.default_max_members: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Member limit ({settings.default_max_members}) reached. " + "Upgrade your plan to invite more members.", + ) + + # Revoke any open invitation for same email+org. + existing_inv = await db.exec( + select(Invitation).where( + Invitation.org_id == org_id, + Invitation.email == body.email, + Invitation.status == InvitationStatus.PENDING, + ) + ) + for old in existing_inv.all(): + old.status = InvitationStatus.REVOKED + db.add(old) + + expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.invitation_expire_hours) + invitation = Invitation( + org_id=org_id, + email=body.email, + role=body.role, + invited_by=ctx.user_id, + expires_at=expires_at, + token_hash="", # filled below after we have the id + ) + db.add(invitation) + await db.flush() + + token = _make_token(invitation.id, settings.invitation_secret.get_secret_value()) + invitation.token_hash = hashlib.sha256(token.encode()).hexdigest() + db.add(invitation) + await db.commit() + await db.refresh(invitation) + + accept_url = f"{settings.app_base_url}/invitations/{token}/accept" + org_name = org.name if org else str(org_id) + await get_email_service().send_invitation( + to_email=body.email, + org_name=org_name, + inviter_name=ctx.username, + role=body.role.value, + accept_url=accept_url, + expire_hours=settings.invitation_expire_hours, + ) + + await get_audit_service().log( + action="member.invited", + org_id=org_id, + user_id=ctx.user_id, + resource_type="invitation", + resource_id=str(invitation.id), + log_metadata={"email": body.email, "role": body.role.value}, + ip_address=request.client.host if request.client else None, + ) + + return InvitationRead( + id=invitation.id, + email=invitation.email, + role=invitation.role, + status=invitation.status, + expires_at=invitation.expires_at, + created_at=invitation.created_at, + ) + + +@router.get("/orgs/{org_id}/invitations", response_model=list[InvitationRead]) +async def list_invitations(org_id: UUID, ctx: RequireAdmin): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec( + select(Invitation).where(Invitation.org_id == org_id, Invitation.status == InvitationStatus.PENDING) + ) + return [ + InvitationRead( + id=inv.id, + email=inv.email, + role=inv.role, + status=inv.status, + expires_at=inv.expires_at, + created_at=inv.created_at, + ) + for inv in result.all() + ] + + +@router.delete("/orgs/{org_id}/invitations/{invite_id}", status_code=status.HTTP_204_NO_CONTENT) +async def revoke_invitation(org_id: UUID, invite_id: UUID, ctx: RequireAdmin): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec(select(Invitation).where(Invitation.id == invite_id, Invitation.org_id == org_id)) + inv = result.first() + if not inv: + raise HTTPException(404, "Invitation not found.") + inv.status = InvitationStatus.REVOKED + db.add(inv) + await db.commit() + + +# --------------------------------------------------------------------------- +# Public invitation acceptance (no org context — user may not be in the org yet) +# --------------------------------------------------------------------------- + + +@router.get("/invitations/{token}") +async def get_invitation_info(token: str): + """Return public invitation details (org name, role, expiry) — no auth required.""" + settings = get_saas_settings() + inv_id = _verify_token(token, settings.invitation_secret.get_secret_value()) + if not inv_id: + raise HTTPException(400, "Invalid invitation token.") + + from langflow.services.deps import session_scope + + from langflow_saas.models import Organization + + async with session_scope() as db: + result = await db.exec(select(Invitation).where(Invitation.id == inv_id)) + inv = result.first() + if not inv: + raise HTTPException(404, "Invitation not found.") + + if inv.status != InvitationStatus.PENDING: + raise HTTPException(410, f"Invitation is {inv.status.value}.") + if inv.expires_at < datetime.now(timezone.utc): + raise HTTPException(410, "Invitation has expired.") + + org_result = await db.exec(select(Organization).where(Organization.id == inv.org_id)) + org = org_result.first() + org_name = org.name if org else str(inv.org_id) + + return { + "id": str(inv.id), + "org_name": org_name, + "email": inv.email, + "role": inv.role.value, + "expires_at": inv.expires_at.isoformat(), + } + + +@router.post("/invitations/{token}/accept", status_code=status.HTTP_200_OK) +async def accept_invitation(token: str, ctx: CurrentOrgContext, request: Request): + """Accept a pending invitation. Caller must be authenticated as the invited email + OR an admin can accept on behalf. For simplicity we trust the authenticated user. + """ + settings = get_saas_settings() + inv_id = _verify_token(token, settings.invitation_secret.get_secret_value()) + if not inv_id: + raise HTTPException(400, "Invalid invitation token.") + + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec(select(Invitation).where(Invitation.id == inv_id)) + inv = result.first() + if not inv: + raise HTTPException(404, "Invitation not found.") + if inv.status != InvitationStatus.PENDING: + raise HTTPException(410, f"Invitation is {inv.status.value}.") + if inv.expires_at < datetime.now(timezone.utc): + raise HTTPException(410, "Invitation has expired.") + + # Check already a member. + existing_m = await db.exec( + select(UserOrganization).where( + UserOrganization.user_id == ctx.user_id, + UserOrganization.org_id == inv.org_id, + ) + ) + if existing_m.first(): + raise HTTPException(409, "You are already a member of this organization.") + + # Create membership. + membership = UserOrganization( + user_id=ctx.user_id, + org_id=inv.org_id, + role=inv.role, + invitation_id=inv.id, + ) + db.add(membership) + + inv.status = InvitationStatus.ACCEPTED + inv.accepted_at = datetime.now(timezone.utc) + inv.accepted_by = ctx.user_id + db.add(inv) + await db.commit() + + await get_audit_service().log( + action="member.joined", + org_id=inv.org_id, + user_id=ctx.user_id, + resource_type="invitation", + resource_id=str(inv.id), + ip_address=request.client.host if request.client else None, + ) + return {"ok": True, "org_id": str(inv.org_id)} diff --git a/src/backend/saas/langflow_saas/api/orgs.py b/src/backend/saas/langflow_saas/api/orgs.py new file mode 100644 index 000000000000..cd1979715025 --- /dev/null +++ b/src/backend/saas/langflow_saas/api/orgs.py @@ -0,0 +1,208 @@ +"""Organization CRUD endpoints. + +Routes: + POST /api/saas/v1/orgs — create org + GET /api/saas/v1/orgs — list caller's orgs + GET /api/saas/v1/orgs/{org_id} — get org details + PATCH /api/saas/v1/orgs/{org_id} — update (admin+) + DELETE /api/saas/v1/orgs/{org_id} — delete (owner only) +""" + +from __future__ import annotations + +import re +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, HTTPException, Request, status +from sqlmodel import select + +from langflow_saas.dependencies import CurrentOrgContext, RequireAdmin, RequireOwner, assert_org_match +from langflow_saas.models import ( + Organization, + OrganizationCreate, + OrganizationRead, + OrganizationUpdate, + OrgRole, + Plan, + PlanRead, + UserOrganization, +) +from langflow_saas.services import get_audit_service + +router = APIRouter(prefix="/orgs", tags=["Organizations"]) + +_SLUG_RE = re.compile(r"^[a-z0-9][a-z0-9\-]{1,61}[a-z0-9]$") + + +def _slugify(name: str) -> str: + slug = re.sub(r"[^a-z0-9]+", "-", name.lower()).strip("-") + return slug[:63] + + +def _validate_slug(slug: str) -> None: + if not _SLUG_RE.match(slug): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Slug must be 3–63 lowercase alphanumeric characters or hyphens, cannot start or end with a hyphen.", + ) + + +async def _org_to_read(org: Organization, role: OrgRole, db) -> OrganizationRead: + plan: Plan | None = None + if org.plan_id: + result = await db.exec(select(Plan).where(Plan.id == org.plan_id)) + plan = result.first() + + return OrganizationRead( + id=org.id, + name=org.name, + slug=org.slug, + owner_id=org.owner_id, + is_personal=org.is_personal, + is_active=org.is_active, + created_at=org.created_at, + role=role, + plan=PlanRead.model_validate(plan) if plan else None, + ) + + +@router.post("", response_model=OrganizationRead, status_code=status.HTTP_201_CREATED) +async def create_org(body: OrganizationCreate, ctx: CurrentOrgContext, request: Request): + """Create a new organization. The caller becomes its owner.""" + from langflow.services.deps import session_scope + + slug = body.slug or _slugify(body.name) + _validate_slug(slug) + + async with session_scope() as db: + # Check slug uniqueness. + existing = await db.exec(select(Organization).where(Organization.slug == slug)) + if existing.first(): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Slug '{slug}' is already taken.", + ) + + org = Organization(name=body.name, slug=slug, owner_id=ctx.user_id) + db.add(org) + await db.flush() # populate org.id before creating membership + + membership = UserOrganization(user_id=ctx.user_id, org_id=org.id, role=OrgRole.OWNER) + db.add(membership) + await db.commit() + await db.refresh(org) + + await get_audit_service().log( + action="org.created", + org_id=org.id, + user_id=ctx.user_id, + resource_type="organization", + resource_id=str(org.id), + ip_address=request.client.host if request.client else None, + ) + + async with session_scope() as db: + return await _org_to_read(org, OrgRole.OWNER, db) + + +@router.get("", response_model=list[OrganizationRead]) +async def list_orgs(ctx: CurrentOrgContext): + """List all organizations the caller belongs to.""" + from langflow.services.deps import session_scope + + async with session_scope() as db: + memberships_result = await db.exec(select(UserOrganization).where(UserOrganization.user_id == ctx.user_id)) + memberships = memberships_result.all() + + result = [] + for m in memberships: + org_result = await db.exec( + select(Organization).where(Organization.id == m.org_id, Organization.is_active == True) # noqa: E712 + ) + org = org_result.first() + if org: + result.append(await _org_to_read(org, m.role, db)) + return result + + +@router.get("/{org_id}", response_model=OrganizationRead) +async def get_org(org_id: UUID, ctx: CurrentOrgContext): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found.") + return await _org_to_read(org, ctx.role, db) + + +@router.patch("/{org_id}", response_model=OrganizationRead) +async def update_org(org_id: UUID, body: OrganizationUpdate, ctx: RequireAdmin, request: Request): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found.") + if org.is_personal: + raise HTTPException(status_code=400, detail="Personal organizations cannot be renamed.") + + if body.name: + org.name = body.name + if body.slug: + _validate_slug(body.slug) + existing = await db.exec( + select(Organization).where(Organization.slug == body.slug, Organization.id != org_id) + ) + if existing.first(): + raise HTTPException(409, detail=f"Slug '{body.slug}' is already taken.") + org.slug = body.slug + + org.updated_at = datetime.now(timezone.utc) + db.add(org) + await db.commit() + await db.refresh(org) + + await get_audit_service().log( + action="org.updated", + org_id=org_id, + user_id=ctx.user_id, + resource_type="organization", + resource_id=str(org_id), + log_metadata=body.model_dump(exclude_none=True), + ip_address=request.client.host if request.client else None, + ) + + async with session_scope() as db: + return await _org_to_read(org, ctx.role, db) + + +@router.delete("/{org_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_org(org_id: UUID, ctx: RequireOwner, request: Request): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found.") + if org.is_personal: + raise HTTPException(status_code=400, detail="Personal organizations cannot be deleted.") + + await db.delete(org) + await db.commit() + + await get_audit_service().log( + action="org.deleted", + org_id=org_id, + user_id=ctx.user_id, + resource_type="organization", + resource_id=str(org_id), + ip_address=request.client.host if request.client else None, + ) diff --git a/src/backend/saas/langflow_saas/api/router.py b/src/backend/saas/langflow_saas/api/router.py new file mode 100644 index 000000000000..7a692f54f7a1 --- /dev/null +++ b/src/backend/saas/langflow_saas/api/router.py @@ -0,0 +1,18 @@ +"""Root SaaS API router. + +All SaaS endpoints are mounted under /api/saas/v1/. +This prefix keeps them completely separate from Langflow's /api/v1/ and +/api/v2/ routes so there is zero risk of collision. +""" + +from fastapi import APIRouter + +from langflow_saas.api import billing, flows, members, orgs, teams + +router = APIRouter(prefix="/api/saas/v1") + +router.include_router(orgs.router) +router.include_router(members.router) +router.include_router(teams.router) +router.include_router(billing.router) +router.include_router(flows.router) diff --git a/src/backend/saas/langflow_saas/api/teams.py b/src/backend/saas/langflow_saas/api/teams.py new file mode 100644 index 000000000000..234f76c3697a --- /dev/null +++ b/src/backend/saas/langflow_saas/api/teams.py @@ -0,0 +1,129 @@ +"""Team management endpoints. + +Routes: + GET /api/saas/v1/orgs/{org_id}/teams — list teams + POST /api/saas/v1/orgs/{org_id}/teams — create team (admin+) + DELETE /api/saas/v1/orgs/{org_id}/teams/{team_id} — delete team (admin+) + POST /api/saas/v1/orgs/{org_id}/teams/{team_id}/members — add member (admin+) + DELETE /api/saas/v1/orgs/{org_id}/teams/{team_id}/members/{user_id} — remove member +""" + +from __future__ import annotations + +from uuid import UUID + +from fastapi import APIRouter, HTTPException, Request, status +from sqlmodel import select + +from langflow_saas.dependencies import CurrentOrgContext, RequireAdmin, assert_org_match +from langflow_saas.models import Team, TeamCreate, TeamMember, TeamRead, UserOrganization +from langflow_saas.services import get_audit_service + +router = APIRouter(tags=["Teams"]) + + +@router.get("/orgs/{org_id}/teams", response_model=list[TeamRead]) +async def list_teams(org_id: UUID, ctx: CurrentOrgContext): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec(select(Team).where(Team.org_id == org_id)) + return [TeamRead(id=t.id, org_id=t.org_id, name=t.name, description=t.description) for t in result.all()] + + +@router.post("/orgs/{org_id}/teams", response_model=TeamRead, status_code=status.HTTP_201_CREATED) +async def create_team(org_id: UUID, body: TeamCreate, ctx: RequireAdmin, request: Request): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + # Name uniqueness within org. + existing = await db.exec(select(Team).where(Team.org_id == org_id, Team.name == body.name)) + if existing.first(): + raise HTTPException(409, f"A team named '{body.name}' already exists in this org.") + + team = Team(org_id=org_id, name=body.name, description=body.description) + db.add(team) + await db.commit() + await db.refresh(team) + + await get_audit_service().log( + action="team.created", + org_id=org_id, + user_id=ctx.user_id, + resource_type="team", + resource_id=str(team.id), + ip_address=request.client.host if request.client else None, + ) + return TeamRead(id=team.id, org_id=team.org_id, name=team.name, description=team.description) + + +@router.delete("/orgs/{org_id}/teams/{team_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_team(org_id: UUID, team_id: UUID, ctx: RequireAdmin, request: Request): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec(select(Team).where(Team.id == team_id, Team.org_id == org_id)) + team = result.first() + if not team: + raise HTTPException(404, "Team not found.") + await db.delete(team) + await db.commit() + + await get_audit_service().log( + action="team.deleted", + org_id=org_id, + user_id=ctx.user_id, + resource_type="team", + resource_id=str(team_id), + ip_address=request.client.host if request.client else None, + ) + + +@router.post("/orgs/{org_id}/teams/{team_id}/members", status_code=status.HTTP_201_CREATED) +async def add_team_member(org_id: UUID, team_id: UUID, target_user_id: UUID, ctx: RequireAdmin): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + # Verify user is an org member. + mem_result = await db.exec( + select(UserOrganization).where( + UserOrganization.org_id == org_id, UserOrganization.user_id == target_user_id + ) + ) + if not mem_result.first(): + raise HTTPException(400, "User is not a member of this organization.") + + # Check already on team. + existing = await db.exec( + select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == target_user_id) + ) + if existing.first(): + raise HTTPException(409, "User is already on this team.") + + db.add(TeamMember(team_id=team_id, user_id=target_user_id)) + await db.commit() + + return {"ok": True} + + +@router.delete( + "/orgs/{org_id}/teams/{team_id}/members/{target_user_id}", + status_code=status.HTTP_204_NO_CONTENT, +) +async def remove_team_member(org_id: UUID, team_id: UUID, target_user_id: UUID, ctx: RequireAdmin): + assert_org_match(org_id, ctx) + from langflow.services.deps import session_scope + + async with session_scope() as db: + result = await db.exec( + select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == target_user_id) + ) + member = result.first() + if not member: + raise HTTPException(404, "Team member not found.") + await db.delete(member) + await db.commit() diff --git a/src/backend/saas/langflow_saas/dependencies.py b/src/backend/saas/langflow_saas/dependencies.py new file mode 100644 index 000000000000..2cd125c1ce79 --- /dev/null +++ b/src/backend/saas/langflow_saas/dependencies.py @@ -0,0 +1,84 @@ +"""FastAPI dependency callables for the SaaS plugin. + +All SaaS route handlers use these deps instead of Langflow's +``get_current_active_user`` so that org context and RBAC checks are +applied consistently. + +Dependency graph: + CurrentOrgContext — resolves OrgContextData from request.state + └─ RequireOrgRole(...) — asserts a minimum role level + └─ RequireOrgAdmin — shortcut for admin+ + └─ RequireOrgOwner — shortcut for owner only +""" + +from __future__ import annotations + +from typing import Annotated +from uuid import UUID + +from fastapi import Depends, HTTPException, Request, status + +from langflow_saas.middleware import OrgContextData +from langflow_saas.models import OrgRole + +# Role ordering for gte comparison. +_ROLE_RANK: dict[OrgRole, int] = { + OrgRole.VIEWER: 0, + OrgRole.MEMBER: 1, + OrgRole.ADMIN: 2, + OrgRole.OWNER: 3, +} + + +async def get_org_context(request: Request) -> OrgContextData: + """Retrieve the tenant context set by TenantContextMiddleware. + + Raises 401 if the context is absent (unauthenticated request) and 403 + if the user is authenticated but has no org membership (should not + happen in normal flows after personal-org auto-creation is enabled). + """ + ctx: OrgContextData | None = getattr(request.state, "saas_context", None) + if ctx is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required. Provide a Bearer token, cookie, or x-api-key.", + ) + return ctx + + +CurrentOrgContext = Annotated[OrgContextData, Depends(get_org_context)] + + +def require_role(minimum_role: OrgRole): + """Dependency factory: asserts caller has at least ``minimum_role``.""" + + async def _check(ctx: CurrentOrgContext) -> OrgContextData: + if _ROLE_RANK[ctx.role] < _ROLE_RANK[minimum_role]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"This action requires {minimum_role.value} role or higher. " + f"Your role in this org is {ctx.role.value}.", + ) + return ctx + + return _check + + +RequireMember = Annotated[OrgContextData, Depends(require_role(OrgRole.MEMBER))] +RequireAdmin = Annotated[OrgContextData, Depends(require_role(OrgRole.ADMIN))] +RequireOwner = Annotated[OrgContextData, Depends(require_role(OrgRole.OWNER))] + + +# --------------------------------------------------------------------------- +# Shorthand for reading a UUID path param and validating it matches the +# authenticated org (prevents IDOR on org-scoped resources). +# --------------------------------------------------------------------------- + + +def assert_org_match(path_org_id: UUID, ctx: OrgContextData) -> None: + """Raise 403 if the path org_id doesn't match the authenticated context.""" + if path_org_id != ctx.org_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this organization.", + ) diff --git a/src/backend/saas/langflow_saas/middleware.py b/src/backend/saas/langflow_saas/middleware.py new file mode 100644 index 000000000000..74477af27b29 --- /dev/null +++ b/src/backend/saas/langflow_saas/middleware.py @@ -0,0 +1,703 @@ +"""ASGI middleware for the SaaS plugin. + +Three middleware classes, applied outermost-first via plugin.py: + + RateLimitMiddleware — Redis sliding-window rate limiting per user/IP. + TenantContextMiddleware — Resolves org membership from JWT + X-Org-ID header, + stores OrgContextData in request.state.saas_context. + QuotaEnforcementMiddleware — Enforces per-org daily execution quotas on run endpoints. + +Upgrade safety: these classes import from Langflow only via clearly bounded +helpers (token decoding, DB session). If those helpers move in a future +Langflow version, update _extract_user_id() and _open_db() here — nothing else. +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from uuid import UUID + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.responses import JSONResponse + +from langflow_saas.models import OrgRole, UsageMetric +from langflow_saas.settings import get_saas_settings + +if TYPE_CHECKING: + from langflow_saas.models import Organization, Plan + +logger = logging.getLogger("langflow_saas.middleware") + + +# --------------------------------------------------------------------------- +# Shared context object stored on request.state +# --------------------------------------------------------------------------- + + +@dataclass +class OrgContextData: + """Resolved tenant context stored on every authenticated API request.""" + + user_id: UUID + username: str + org_id: UUID + org_slug: str + role: OrgRole + plan_slug: str + rpm_limit: int + + +# --------------------------------------------------------------------------- +# Helpers: integrate with Langflow internals in ONE place +# --------------------------------------------------------------------------- + + +def _extract_token(request: Request) -> str | None: + """Pull JWT / API key from request using the same precedence as Langflow.""" + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:] + token = request.cookies.get("access_token_lf") + if token: + return token + return request.headers.get("x-api-key") or request.query_params.get("x-api-key") + + +async def _resolve_user_id_from_token(token: str) -> UUID | None: + """Decode JWT and return user_id without hitting the DB. + + Langflow integration point: imports get_settings_service and PyJWT directly. + If Langflow changes its JWT structure, update only this function. + """ + try: + import jwt as pyjwt + from langflow.services.deps import get_settings_service + + settings_service = get_settings_service() + algo = settings_service.auth_settings.ALGORITHM.value + if settings_service.auth_settings.ALGORITHM.is_asymmetric(): + key = settings_service.auth_settings.PUBLIC_KEY + else: + key = settings_service.auth_settings.SECRET_KEY.get_secret_value() + + payload = pyjwt.decode(token, key, algorithms=[algo], options={"verify_exp": True}) + sub = payload.get("sub") + if not sub: + return None + return UUID(sub) + except Exception: # noqa: BLE001 + return None + + +async def _resolve_user_id_from_api_key(api_key: str) -> tuple[UUID, str] | None: + """Look up (user_id, username) from an API key via Langflow's DB. + + Opens its own short-lived DB session so it's safe to call from middleware. + """ + try: + from langflow.services.database.models.api_key.model import ApiKey + from langflow.services.database.models.user.model import User + from langflow.services.deps import session_scope + from sqlmodel import select + + async with session_scope() as db: + result = await db.exec( + select(ApiKey, User) + .join(User, User.id == ApiKey.user_id) # type: ignore[arg-type] + .where(ApiKey.api_key == api_key, ApiKey.is_active == True) # noqa: E712 + ) + row = result.first() + if row: + api_key_obj, user = row + return UUID(str(user.id)), user.username + except Exception: # noqa: BLE001 + pass + return None + + +async def _get_username(user_id: UUID) -> str: + """Fetch username for a user_id from Langflow's user table.""" + try: + from langflow.services.database.models.user.model import User + from langflow.services.deps import session_scope + from sqlmodel import select + + async with session_scope() as db: + result = await db.exec(select(User).where(User.id == user_id)) # type: ignore[arg-type] + user = result.first() + return user.username if user else str(user_id) + except Exception: # noqa: BLE001 + return str(user_id) + + +async def _bootstrap_personal_org(db, user_id: UUID, username: str): + """Create a personal org + OWNER membership for a first-time user. + + Called lazily from TenantContextMiddleware when a valid user has no memberships. + Returns the newly created UserOrganization, or None on failure. + """ + import re + from datetime import datetime, timezone + + from sqlmodel import select + + from langflow_saas.models import Organization, OrgRole, Plan, Subscription, SubscriptionStatus, UserOrganization + + try: + # Pick up the Free plan if available; leave plan_id=None otherwise. + plan_result = await db.exec(select(Plan).where(Plan.slug == "free", Plan.is_active == True)) # noqa: E712 + free_plan = plan_result.first() + + # Build a collision-safe slug from username. + base_slug = re.sub(r"[^a-z0-9]+", "-", username.lower()).strip("-")[:50] + slug_candidate = base_slug + attempt = 0 + while True: + collision = await db.exec(select(Organization).where(Organization.slug == slug_candidate)) + if not collision.first(): + break + attempt += 1 + slug_candidate = f"{base_slug}-{attempt}" + + now = datetime.now(timezone.utc) + org = Organization( + name=f"{username}'s workspace", + slug=slug_candidate, + owner_id=user_id, + plan_id=free_plan.id if free_plan else None, + is_personal=True, + created_at=now, + updated_at=now, + ) + db.add(org) + await db.flush() # populate org.id before referencing it + + membership = UserOrganization( + user_id=user_id, + org_id=org.id, + role=OrgRole.OWNER, + created_at=now, + ) + db.add(membership) + + # Auto-provision a Free subscription so quota checks have a real plan row. + if free_plan: + subscription = Subscription( + org_id=org.id, + plan_id=free_plan.id, + status=SubscriptionStatus.ACTIVE, + created_at=now, + updated_at=now, + ) + db.add(subscription) + + await db.commit() + await db.refresh(membership) + + logger.info("langflow-saas: created personal org %s for user %s", org.slug, username) + return membership + except Exception: + logger.exception("langflow-saas: failed to bootstrap personal org for user %s", username) + await db.rollback() + return None + + +# --------------------------------------------------------------------------- +# Redis helpers (graceful no-op when Redis is unavailable) +# --------------------------------------------------------------------------- + + +def _get_redis(): + """Return a Redis client or None if Redis is not reachable.""" + try: + import redis.asyncio as aioredis + + settings = get_saas_settings() + return aioredis.from_url(settings.redis_url, decode_responses=True, socket_connect_timeout=1) + except Exception: # noqa: BLE001 + return None + + +# --------------------------------------------------------------------------- +# 1. Rate Limit Middleware +# --------------------------------------------------------------------------- + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Sliding-window rate limiter backed by Redis. + + Key structure: ``saas:rl:{user_id_or_ip}:{unix_minute}`` + + Degrades gracefully (allows requests) when Redis is unavailable so that + a Redis outage never takes down the API. + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + settings = get_saas_settings() + if not settings.rate_limit_enabled: + return await call_next(request) + + if not any(request.url.path.startswith(p) for p in settings.rate_limit_paths): + return await call_next(request) + + # Determine a stable identity key (user_id preferred, IP fallback). + key_id: str | None = None + rpm_limit = settings.rate_limit_default_rpm + + # Cheaply check for existing resolved context (set by TenantContextMiddleware + # when it runs *before* this one — ordering is set in plugin.py). + ctx: OrgContextData | None = getattr(request.state, "saas_context", None) + if ctx: + key_id = str(ctx.user_id) + rpm_limit = ctx.rpm_limit + else: + token = _extract_token(request) + if token: + uid = await _resolve_user_id_from_token(token) + if uid: + key_id = str(uid) + + if not key_id: + forwarded_for = request.headers.get("X-Forwarded-For") + key_id = forwarded_for or request.client.host if request.client else "unknown" + + redis = _get_redis() + if redis is None: + return await call_next(request) + + try: + minute_bucket = int(time.time()) // 60 + redis_key = f"saas:rl:{key_id}:{minute_bucket}" + burst_limit = rpm_limit * settings.rate_limit_burst_multiplier + + async with redis: + current = await redis.incr(redis_key) + if current == 1: + await redis.expire(redis_key, 120) # TTL: 2 minutes + + reset_ts = (minute_bucket + 1) * 60 + remaining = max(0, burst_limit - current) + + if current > burst_limit: + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded. Please slow down."}, + headers={ + "X-RateLimit-Limit": str(burst_limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(reset_ts), + "Retry-After": "60", + }, + ) + + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(burst_limit) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Reset"] = str(reset_ts) + return response + except Exception: # noqa: BLE001 + # Redis errors must never block the request. + return await call_next(request) + + +# --------------------------------------------------------------------------- +# 2. Tenant Context Middleware +# --------------------------------------------------------------------------- + + +class TenantContextMiddleware(BaseHTTPMiddleware): + """Resolve the authenticated user's active organization and store it in + ``request.state.saas_context`` as an ``OrgContextData`` instance. + + Logic: + 1. Skip non-API paths (static assets, health checks). + 2. Extract JWT or API key from the request. + 3. Decode user_id from JWT (no DB hit) or look up API key (one DB query). + 4. Determine active org: use ``X-Org-ID`` header if present, else the + user's single org (auto-resolve), else skip if user has multiple orgs + and ``require_org_header`` is True. + 5. Load the org's plan details and store the resolved context. + + On any failure the middleware allows the request through so Langflow's + own auth layer returns the proper 401/403. + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + # Only process API paths. + if not request.url.path.startswith("/api/"): + return await call_next(request) + + try: + await self._set_context(request) + except Exception: # noqa: BLE001 + pass # Let Langflow's auth handle it. + + return await call_next(request) + + async def _set_context(self, request: Request) -> None: + settings = get_saas_settings() + + token = _extract_token(request) + if not token: + return + + # Resolve user identity. + user_id: UUID | None = None + username: str = "" + + # Try JWT first (no DB hit). + user_id = await _resolve_user_id_from_token(token) + if user_id: + username = await _get_username(user_id) + else: + # Might be an API key. + result = await _resolve_user_id_from_api_key(token) + if result: + user_id, username = result + + if not user_id: + return + + from langflow.services.deps import session_scope + from sqlmodel import select + + from langflow_saas.models import Organization, Plan, UserOrganization + + async with session_scope() as db: + # Find the org context. + org_id_header = request.headers.get("X-Org-ID") + membership: UserOrganization | None = None + + if org_id_header: + try: + requested_org_id = UUID(org_id_header) + except ValueError: + return + result = await db.exec( + select(UserOrganization).where( + UserOrganization.user_id == user_id, + UserOrganization.org_id == requested_org_id, + ) + ) + membership = result.first() + else: + # Auto-resolve: get all memberships, pick the personal org or the + # only org if the user belongs to exactly one. + result = await db.exec(select(UserOrganization).where(UserOrganization.user_id == user_id)) + memberships = result.all() + if not memberships: + if settings.auto_create_personal_org: + bootstrapped = await _bootstrap_personal_org(db, user_id, username) + if bootstrapped: + memberships = [bootstrapped] + else: + return + else: + return + if len(memberships) == 1: + membership = memberships[0] + else: + # Multiple orgs: require explicit header if configured. + if settings.require_org_header: + return + # Otherwise pick personal org as default. + personal_orgs = [] + for m in memberships: + org_result = await db.exec( + select(Organization).where(Organization.id == m.org_id, Organization.is_personal == True) # noqa: E712 + ) + if org_result.first(): + personal_orgs.append(m) + membership = personal_orgs[0] if personal_orgs else memberships[0] + + if not membership: + return + + # Load org + plan. + org_result = await db.exec( + select(Organization).where(Organization.id == membership.org_id, Organization.is_active == True) # noqa: E712 + ) + org: Organization | None = org_result.first() + if not org: + return + + plan_slug = "free" + rpm_limit = settings.default_max_executions_per_day // (60 * 24) # crude default + rpm_limit = max(rpm_limit, settings.rate_limit_default_rpm) + + if org.plan_id: + plan_result = await db.exec(select(Plan).where(Plan.id == org.plan_id)) + plan: Plan | None = plan_result.first() + if plan: + plan_slug = plan.slug + rpm_limit = plan.rpm_limit + + request.state.saas_context = OrgContextData( + user_id=user_id, + username=username, + org_id=membership.org_id, + org_slug=org.slug, + role=membership.role, + plan_slug=plan_slug, + rpm_limit=rpm_limit, + ) + + +# --------------------------------------------------------------------------- +# 3. Quota Enforcement Middleware +# --------------------------------------------------------------------------- + +# Paths that count as a "flow execution" for metering. +_EXECUTION_PATHS = ("/api/v1/run/", "/api/v2/flows/") + + +class QuotaEnforcementMiddleware(BaseHTTPMiddleware): + """Block flow executions when the org has exhausted its daily quota. + + Only queries the DB for POST requests on execution paths. All other + requests pass through with no overhead. + + After a successful execution (2xx response), a UsageRecord is inserted + asynchronously so the quota counter is updated for the next request. + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + settings = get_saas_settings() + + if not settings.billing_enabled: + return await call_next(request) + + is_execution = request.method == "POST" and any(request.url.path.startswith(p) for p in _EXECUTION_PATHS) + if not is_execution: + return await call_next(request) + + ctx: OrgContextData | None = getattr(request.state, "saas_context", None) + if not ctx: + return await call_next(request) + + # Check quota before executing. + quota_ok, limit, used = await self._check_execution_quota(ctx.org_id) + if not quota_ok: + return JSONResponse( + status_code=429, + content={ + "detail": f"Daily execution quota exceeded ({used}/{limit}). " + "Upgrade your plan or wait until midnight UTC." + }, + headers={"X-Quota-Limit": str(limit), "X-Quota-Used": str(used)}, + ) + + response = await call_next(request) + + # Record usage after a successful execution. + if 200 <= response.status_code < 300: + await self._record_execution(ctx) + + return response + + async def _check_execution_quota(self, org_id: UUID) -> tuple[bool, int, int]: + """Return (quota_ok, limit, used_today).""" + try: + from langflow.services.deps import session_scope + from sqlalchemy import func + from sqlmodel import select + + from langflow_saas.models import Organization, Plan, UsageRecord + + today_start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) + + async with session_scope() as db: + # Get limit from plan. + org_result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = org_result.first() + limit = get_saas_settings().default_max_executions_per_day + + if org and org.plan_id: + plan_result = await db.exec(select(Plan).where(Plan.id == org.plan_id)) + plan = plan_result.first() + if plan and plan.max_executions_per_day != -1: + limit = plan.max_executions_per_day + + # Count today's executions. + count_result = await db.exec( + select(func.sum(UsageRecord.value)).where( + UsageRecord.org_id == org_id, + UsageRecord.metric == UsageMetric.FLOW_EXECUTION, + UsageRecord.recorded_at >= today_start, + ) + ) + used = int(count_result.first() or 0) + + if limit == -1: + return True, -1, used + return used < limit, limit, used + except Exception: # noqa: BLE001 + return True, -1, 0 # Fail open on DB errors. + + async def _record_execution(self, ctx: OrgContextData) -> None: + try: + from langflow.services.deps import session_scope + + from langflow_saas.models import UsageRecord + + record = UsageRecord( + org_id=ctx.org_id, + user_id=ctx.user_id, + metric=UsageMetric.FLOW_EXECUTION, + value=1, + ) + async with session_scope() as db: + db.add(record) + await db.commit() + except Exception: # noqa: BLE001 + logger.warning("Failed to record execution usage for org %s", ctx.org_id) + + +# --------------------------------------------------------------------------- +# 4. Flow Ownership Middleware +# --------------------------------------------------------------------------- + +# Paths that create a new flow in Langflow's native API. +_FLOW_CREATE_PATHS = ("/api/v1/flows", "/api/v1/flows/") + + +class FlowOwnershipMiddleware(BaseHTTPMiddleware): + """Auto-assign newly created Langflow flows to the creator's current org. + + Intercepts successful POST /api/v1/flows responses, extracts the new + flow's UUID from the JSON body, and inserts a ``saas_flow_org`` row so + the org-scoped flows API can surface it. + + The response body is buffered only for flow-creation requests — all other + requests pass through with zero overhead. Failures never surface to the + caller (the flow still gets created, it just won't be org-scoped yet). + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + is_flow_create = request.method == "POST" and request.url.path.rstrip("/") == "/api/v1/flows" + + if not is_flow_create: + return await call_next(request) + + ctx: OrgContextData | None = getattr(request.state, "saas_context", None) + if not ctx: + return await call_next(request) + + response = await call_next(request) + + if response.status_code not in (200, 201): + return response + + # Buffer the response body so we can extract the flow id. + body = b"" + async for chunk in response.body_iterator: + body += chunk + + try: + import json as _json + + data = _json.loads(body) + flow_id_str = data.get("id") + if flow_id_str: + await self._assign_flow(UUID(flow_id_str), ctx.org_id, ctx.user_id) + except Exception: # noqa: BLE001 + pass # Never break the create response. + + # Re-emit the buffered body as a new response. + from starlette.responses import Response as _Response + + return _Response( + content=body, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + async def _assign_flow(self, flow_id: UUID, org_id: UUID, user_id: UUID) -> None: + try: + from langflow.services.deps import session_scope + from sqlmodel import select + + from langflow_saas.models import FlowOrg + + async with session_scope() as db: + existing = await db.exec(select(FlowOrg).where(FlowOrg.flow_id == flow_id)) + if existing.first(): + return + db.add(FlowOrg(flow_id=flow_id, org_id=org_id, assigned_by=user_id)) + await db.commit() + logger.debug("langflow-saas: assigned flow %s to org %s", flow_id, org_id) + except Exception: # noqa: BLE001 + logger.warning("langflow-saas: failed to assign flow %s to org %s", flow_id, org_id) + + +# --------------------------------------------------------------------------- +# 5. User Registration Middleware +# --------------------------------------------------------------------------- + + +class UserRegistrationMiddleware(BaseHTTPMiddleware): + """Provision a personal org for every newly registered Langflow user. + + Intercepts successful POST /api/v1/users/ responses (Langflow's public + signup endpoint), extracts the new user's id + username from the JSON + body, and calls ``_bootstrap_personal_org()`` immediately — so the user + has a valid org context before their very first API request. + + The response body is buffered only for this specific endpoint. Failures + are silent: the user account is always created regardless. + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + is_registration = request.method == "POST" and request.url.path.rstrip("/") == "/api/v1/users" + + if not is_registration: + return await call_next(request) + + response = await call_next(request) + + # Only provision on successful creation (201). + if response.status_code != 201: + return response + + body = b"" + async for chunk in response.body_iterator: + body += chunk + + try: + import json as _json + + data = _json.loads(body) + user_id_str = data.get("id") + username = data.get("username", "") + if user_id_str and username: + await self._provision(UUID(user_id_str), username) + except Exception: # noqa: BLE001 + pass + + from starlette.responses import Response as _Response + + return _Response( + content=body, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + async def _provision(self, user_id: UUID, username: str) -> None: + settings = get_saas_settings() + if not settings.auto_create_personal_org: + return + try: + from langflow.services.deps import session_scope + + async with session_scope() as db: + await _bootstrap_personal_org(db, user_id, username) + logger.info("langflow-saas: provisioned org for new user %s on registration", username) + except Exception: # noqa: BLE001 + logger.warning("langflow-saas: failed to provision org for new user %s", username) diff --git a/src/backend/saas/langflow_saas/migrations/env.py b/src/backend/saas/langflow_saas/migrations/env.py new file mode 100644 index 000000000000..4a58eba16a49 --- /dev/null +++ b/src/backend/saas/langflow_saas/migrations/env.py @@ -0,0 +1,101 @@ +"""Alembic env.py for the langflow-saas package. + +Key design decisions: + - Connects to the SAME database as Langflow (reads LANGFLOW_DATABASE_URL or + SAAS_DATABASE_URL override) so saas_* tables live alongside Langflow's tables. + - Uses ``include_object`` to manage ONLY tables whose names start with + ``saas_``. Langflow's own tables are never touched by these migrations. + - Shares Langflow's ``alembic_version`` table via Alembic branch labels so + there is no separate version table (avoids the Langflow migration drift check + flagging an unknown table). +""" + +import os +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool + +# --------------------------------------------------------------------------- +# Alembic Config object — gives access to values in alembic.ini. +# --------------------------------------------------------------------------- +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# --------------------------------------------------------------------------- +# Resolve database URL. +# SAAS_DATABASE_URL overrides LANGFLOW_DATABASE_URL for scenarios where the +# SaaS tables live in a separate database (advanced multi-DB setups). +# --------------------------------------------------------------------------- +db_url = os.getenv("SAAS_DATABASE_URL") or os.getenv("LANGFLOW_DATABASE_URL") or "sqlite:///./langflow.db" +config.set_main_option("sqlalchemy.url", db_url) + +# --------------------------------------------------------------------------- +# Import ALL SaaS models so SQLAlchemy registers them in its metadata before +# Alembic performs autogenerate comparison. +# --------------------------------------------------------------------------- +from langflow_saas.models import ( # noqa: E402 — must be after config setup + saas_metadata, +) + +target_metadata = saas_metadata + + +def include_object(obj, name, type_, reflected, compare_to): + """Only manage objects belonging to this plugin. + + Tables must start with ``saas_`` to be managed here; everything else + (Langflow's tables, Alembic's own version table) is left untouched. + """ + if type_ == "table": + return str(name).startswith("saas_") + # Always include indices, constraints, etc. that belong to managed tables. + if hasattr(obj, "table"): + return str(obj.table.name).startswith("saas_") + return True + + +def run_migrations_offline() -> None: + """Run migrations without a live DB connection (generates SQL script).""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + include_object=include_object, + # Use shared alembic_version table via branch labels (no separate table needed). + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations with a live DB connection.""" + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + include_object=include_object, + version_table="saas_alembic_version", + # compare_type=True makes Alembic detect column type changes. + compare_type=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/backend/saas/langflow_saas/migrations/script.py.mako b/src/backend/saas/langflow_saas/migrations/script.py.mako new file mode 100644 index 000000000000..1ba49a84b3ed --- /dev/null +++ b/src/backend/saas/langflow_saas/migrations/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/src/backend/saas/langflow_saas/migrations/versions/001_saas_foundation.py b/src/backend/saas/langflow_saas/migrations/versions/001_saas_foundation.py new file mode 100644 index 000000000000..5d86728d38e1 --- /dev/null +++ b/src/backend/saas/langflow_saas/migrations/versions/001_saas_foundation.py @@ -0,0 +1,327 @@ +"""SaaS foundation tables. + +Creates all saas_* tables in one migration. These tables are additive — +Langflow's own schema is never touched. + +Revision ID: 001saas +Revises: (none — initial SaaS migration) +Create Date: 2025-01-01 00:00:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import inspect as _sa_inspect + +revision: str = "001saas" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = ("saas",) +depends_on: str | Sequence[str] | None = None + +# Use a DB-agnostic UUID type: native UUID on PostgreSQL, CHAR(36) elsewhere. +_uuid = sa.Uuid() + + +def _has_table(name: str) -> bool: + """Return True if the table already exists in the target database.""" + return _sa_inspect(op.get_bind()).has_table(name) + + +def _seed_plans() -> None: + """Upsert the three built-in plans. Safe to call multiple times.""" + op.execute( + """ + INSERT INTO saas_plan + (id, name, slug, is_active, max_flows, max_executions_per_day, + max_members, max_storage_mb, max_api_keys, rpm_limit, + price_monthly_cents, price_yearly_cents, created_at, updated_at) + VALUES + ('00000000-0000-0000-0000-000000000001', 'Free', 'free', TRUE, + 50, 1000, 5, 500, 5, 60, 0, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP), + ('00000000-0000-0000-0000-000000000002', 'Pro', 'pro', TRUE, + 500, 10000, 25, 5000, 20, 300, 2900, 29000, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP), + ('00000000-0000-0000-0000-000000000003', 'Enterprise', 'enterprise', TRUE, + -1, -1, -1, -1, -1, 1000, 0, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + ON CONFLICT (slug) DO NOTHING + """ + ) + + +def upgrade() -> None: + # ------------------------------------------------------------------ + # saas_plan + # ------------------------------------------------------------------ + if _has_table("saas_plan"): + # Tables were already created (e.g., plugin reinstalled on existing DB). + # Skip DDL but still re-seed plans in case the rows are missing. + _seed_plans() + return + + op.create_table( + "saas_plan", + sa.Column("id", _uuid, primary_key=True), + sa.Column("name", sa.String(), nullable=False, unique=True), + sa.Column("slug", sa.String(), nullable=False, unique=True), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column("max_flows", sa.Integer(), nullable=False, server_default="50"), + sa.Column("max_executions_per_day", sa.Integer(), nullable=False, server_default="1000"), + sa.Column("max_members", sa.Integer(), nullable=False, server_default="5"), + sa.Column("max_storage_mb", sa.Integer(), nullable=False, server_default="500"), + sa.Column("max_api_keys", sa.Integer(), nullable=False, server_default="5"), + sa.Column("rpm_limit", sa.Integer(), nullable=False, server_default="60"), + sa.Column("price_monthly_cents", sa.Integer(), nullable=False, server_default="0"), + sa.Column("price_yearly_cents", sa.Integer(), nullable=False, server_default="0"), + sa.Column("stripe_monthly_price_id", sa.String(), nullable=True), + sa.Column("stripe_yearly_price_id", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_saas_plan_slug", "saas_plan", ["slug"]) + op.create_index("ix_saas_plan_name", "saas_plan", ["name"]) + + _seed_plans() + + # ------------------------------------------------------------------ + # saas_organization + # ------------------------------------------------------------------ + op.create_table( + "saas_organization", + sa.Column("id", _uuid, primary_key=True), + sa.Column("name", sa.String(), nullable=False), + sa.Column("slug", sa.String(), nullable=False, unique=True), + sa.Column( + "owner_id", + _uuid, + sa.ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "plan_id", + _uuid, + sa.ForeignKey("saas_plan.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("is_personal", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column("stripe_customer_id", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_saas_org_slug", "saas_organization", ["slug"]) + op.create_index("ix_saas_org_owner", "saas_organization", ["owner_id"]) + op.create_index("ix_saas_org_stripe_customer", "saas_organization", ["stripe_customer_id"]) + + # ------------------------------------------------------------------ + # saas_invitation (must exist before saas_user_organization FK) + # ------------------------------------------------------------------ + op.create_table( + "saas_invitation", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("email", sa.String(), nullable=False), + sa.Column("role", sa.String(), nullable=False, server_default="member"), + sa.Column( + "invited_by", + _uuid, + sa.ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("token_hash", sa.String(), nullable=False, unique=True), + sa.Column("status", sa.String(), nullable=False, server_default="pending"), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("accepted_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "accepted_by", + _uuid, + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_saas_invitation_org", "saas_invitation", ["org_id"]) + op.create_index("ix_saas_invitation_email", "saas_invitation", ["email"]) + op.create_index("ix_saas_invitation_token_hash", "saas_invitation", ["token_hash"], unique=True) + + # ------------------------------------------------------------------ + # saas_user_organization + # ------------------------------------------------------------------ + op.create_table( + "saas_user_organization", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "user_id", + _uuid, + sa.ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("role", sa.String(), nullable=False, server_default="member"), + sa.Column( + "invitation_id", + _uuid, + sa.ForeignKey("saas_invitation.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("user_id", "org_id", name="uq_saas_user_org"), + ) + op.create_index("ix_saas_user_org_user", "saas_user_organization", ["user_id"]) + op.create_index("ix_saas_user_org_org", "saas_user_organization", ["org_id"]) + + # ------------------------------------------------------------------ + # saas_team + # ------------------------------------------------------------------ + op.create_table( + "saas_team", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("org_id", "name", name="uq_saas_team_name_in_org"), + ) + op.create_index("ix_saas_team_org", "saas_team", ["org_id"]) + + # ------------------------------------------------------------------ + # saas_team_member + # ------------------------------------------------------------------ + op.create_table( + "saas_team_member", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "team_id", + _uuid, + sa.ForeignKey("saas_team.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "user_id", + _uuid, + sa.ForeignKey("user.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("added_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("team_id", "user_id", name="uq_saas_team_member"), + ) + op.create_index("ix_saas_team_member_team", "saas_team_member", ["team_id"]) + op.create_index("ix_saas_team_member_user", "saas_team_member", ["user_id"]) + + # ------------------------------------------------------------------ + # saas_subscription + # ------------------------------------------------------------------ + op.create_table( + "saas_subscription", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + unique=True, + ), + sa.Column( + "plan_id", + _uuid, + sa.ForeignKey("saas_plan.id", ondelete="RESTRICT"), + nullable=False, + ), + sa.Column("status", sa.String(), nullable=False, server_default="trialing"), + sa.Column("stripe_subscription_id", sa.String(), nullable=True, unique=True), + sa.Column("stripe_price_id", sa.String(), nullable=True), + sa.Column("current_period_start", sa.DateTime(timezone=True), nullable=True), + sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True), + sa.Column("cancel_at_period_end", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("trial_end", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_saas_subscription_org", "saas_subscription", ["org_id"]) + op.create_index("ix_saas_subscription_stripe", "saas_subscription", ["stripe_subscription_id"]) + + # ------------------------------------------------------------------ + # saas_usage_record + # ------------------------------------------------------------------ + op.create_table( + "saas_usage_record", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "user_id", + _uuid, + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("metric", sa.String(), nullable=False), + sa.Column("value", sa.Integer(), nullable=False, server_default="1"), + sa.Column("resource_id", sa.String(), nullable=True), + sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index( + "ix_saas_usage_org_metric_time", + "saas_usage_record", + ["org_id", "metric", "recorded_at"], + ) + + # ------------------------------------------------------------------ + # saas_audit_log + # ------------------------------------------------------------------ + op.create_table( + "saas_audit_log", + sa.Column("id", _uuid, primary_key=True), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column( + "user_id", + _uuid, + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("action", sa.String(), nullable=False), + sa.Column("resource_type", sa.String(), nullable=True), + sa.Column("resource_id", sa.String(), nullable=True), + sa.Column("log_metadata", sa.JSON(), nullable=True), + sa.Column("ip_address", sa.String(), nullable=True), + sa.Column("user_agent", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_saas_audit_org_time", "saas_audit_log", ["org_id", "created_at"]) + op.create_index("ix_saas_audit_user_time", "saas_audit_log", ["user_id", "created_at"]) + op.create_index("ix_saas_audit_action", "saas_audit_log", ["action"]) + + +def downgrade() -> None: + # Drop in reverse FK dependency order. + op.drop_table("saas_audit_log") + op.drop_table("saas_usage_record") + op.drop_table("saas_subscription") + op.drop_table("saas_team_member") + op.drop_table("saas_team") + op.drop_table("saas_user_organization") + op.drop_table("saas_invitation") + op.drop_table("saas_organization") + op.drop_table("saas_plan") diff --git a/src/backend/saas/langflow_saas/migrations/versions/002_saas_flow_org.py b/src/backend/saas/langflow_saas/migrations/versions/002_saas_flow_org.py new file mode 100644 index 000000000000..a8c01e77923f --- /dev/null +++ b/src/backend/saas/langflow_saas/migrations/versions/002_saas_flow_org.py @@ -0,0 +1,53 @@ +"""Add saas_flow_org shadow table for org-scoped flow ownership. + +Revision ID: 002saas +Revises: 001saas +Create Date: 2026-04-26 00:00:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import inspect as _sa_inspect + +revision: str = "002saas" +down_revision: str | None = "001saas" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +_uuid = sa.Uuid() + + +def _has_table(name: str) -> bool: + return _sa_inspect(op.get_bind()).has_table(name) + + +def upgrade() -> None: + if _has_table("saas_flow_org"): + return + + op.create_table( + "saas_flow_org", + sa.Column("id", _uuid, primary_key=True), + sa.Column("flow_id", _uuid, nullable=False, unique=True), + sa.Column( + "org_id", + _uuid, + sa.ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "assigned_by", + _uuid, + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("assigned_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_saas_flow_org_flow", "saas_flow_org", ["flow_id"], unique=True) + op.create_index("ix_saas_flow_org_org", "saas_flow_org", ["org_id"]) + + +def downgrade() -> None: + op.drop_table("saas_flow_org") diff --git a/src/backend/saas/langflow_saas/models.py b/src/backend/saas/langflow_saas/models.py new file mode 100644 index 000000000000..697b15eb775c --- /dev/null +++ b/src/backend/saas/langflow_saas/models.py @@ -0,0 +1,425 @@ +"""All SaaS database models. + +Every table is prefixed with ``saas_`` to avoid any collision with +Langflow's own tables. Foreign keys to Langflow's ``user`` table use +ON DELETE CASCADE / SET NULL so removing a Langflow user automatically +cleans up all SaaS artefacts. + +Upgrade safety: these models never import Langflow model *classes* — only +primitive types and ``sqlalchemy``. This means Langflow can rename its +internal model fields without breaking this package. +""" + +from __future__ import annotations + +import enum +from datetime import datetime, timezone +from typing import Any +from uuid import UUID, uuid4 + +import sqlalchemy as sa +from sqlalchemy import JSON, Column, ForeignKey, Index, UniqueConstraint +from sqlmodel import Field, SQLModel + +# --------------------------------------------------------------------------- +# Enumerations +# --------------------------------------------------------------------------- + + +class OrgRole(str, enum.Enum): + OWNER = "owner" + ADMIN = "admin" + MEMBER = "member" + VIEWER = "viewer" + + +class InvitationStatus(str, enum.Enum): + PENDING = "pending" + ACCEPTED = "accepted" + EXPIRED = "expired" + REVOKED = "revoked" + + +class SubscriptionStatus(str, enum.Enum): + TRIALING = "trialing" + ACTIVE = "active" + PAST_DUE = "past_due" + CANCELED = "canceled" + UNPAID = "unpaid" + + +class UsageMetric(str, enum.Enum): + FLOW_EXECUTION = "flow_execution" + API_CALL = "api_call" + STORAGE_BYTES = "storage_bytes" + + +# --------------------------------------------------------------------------- +# Plan (created at deployment time / via admin, not user-facing CRUD) +# --------------------------------------------------------------------------- + + +class Plan(SQLModel, table=True): # type: ignore[call-arg] + """Pricing plan / tier definition. Rows are managed by operators.""" + + __tablename__ = "saas_plan" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str = Field(unique=True, index=True) # "Free", "Pro", "Enterprise" + slug: str = Field(unique=True, index=True) # "free", "pro", "enterprise" + is_active: bool = Field(default=True) + # Quotas (-1 = unlimited) + max_flows: int = Field(default=50) + max_executions_per_day: int = Field(default=1000) + max_members: int = Field(default=5) + max_storage_mb: int = Field(default=500) + max_api_keys: int = Field(default=5) + # Rate limits + rpm_limit: int = Field(default=60, description="Requests per minute for this plan.") + # Billing + price_monthly_cents: int = Field(default=0) + price_yearly_cents: int = Field(default=0) + stripe_monthly_price_id: str | None = Field(default=None) + stripe_yearly_price_id: str | None = Field(default=None) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# Organization +# --------------------------------------------------------------------------- + + +class Organization(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_organization" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str = Field(index=True) + slug: str = Field(unique=True, index=True) + owner_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="CASCADE"), nullable=False, index=True) + ) + plan_id: UUID | None = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_plan.id", ondelete="SET NULL"), nullable=True) + ) + # Personal orgs are created automatically for every user and cannot be + # renamed, shared, or deleted while the user exists. + is_personal: bool = Field(default=False) + is_active: bool = Field(default=True) + stripe_customer_id: str | None = Field(default=None, index=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# UserOrganization (membership + role) +# --------------------------------------------------------------------------- + + +class UserOrganization(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_user_organization" + __table_args__ = (UniqueConstraint("user_id", "org_id", name="uq_saas_user_org"),) + + id: UUID = Field(default_factory=uuid4, primary_key=True) + user_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="CASCADE"), nullable=False, index=True) + ) + org_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_organization.id", ondelete="CASCADE"), nullable=False, index=True) + ) + role: OrgRole = Field(default=OrgRole.MEMBER) + # Tracks which invitation brought this member in; nullable for founders. + invitation_id: UUID | None = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_invitation.id", ondelete="SET NULL"), nullable=True) + ) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# Team (sub-group inside an org) +# --------------------------------------------------------------------------- + + +class Team(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_team" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + org_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_organization.id", ondelete="CASCADE"), nullable=False, index=True) + ) + name: str = Field() + description: str | None = Field(default=None) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + __table_args__ = (UniqueConstraint("org_id", "name", name="uq_saas_team_name_in_org"),) + + +class TeamMember(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_team_member" + __table_args__ = (UniqueConstraint("team_id", "user_id", name="uq_saas_team_member"),) + + id: UUID = Field(default_factory=uuid4, primary_key=True) + team_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_team.id", ondelete="CASCADE"), nullable=False, index=True) + ) + user_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="CASCADE"), nullable=False, index=True) + ) + added_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# Invitation +# --------------------------------------------------------------------------- + + +class Invitation(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_invitation" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + org_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_organization.id", ondelete="CASCADE"), nullable=False, index=True) + ) + email: str = Field(index=True) + role: OrgRole = Field(default=OrgRole.MEMBER) + invited_by: UUID = Field(sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="CASCADE"), nullable=False)) + # HMAC-signed token — the actual secret is derived from the invitation ID + # and SAAS_INVITATION_SECRET so the token is never stored in cleartext. + token_hash: str = Field(unique=True, index=True) + status: InvitationStatus = Field(default=InvitationStatus.PENDING) + expires_at: datetime = Field() + accepted_at: datetime | None = Field(default=None) + accepted_by: UUID | None = Field( + sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="SET NULL"), nullable=True) + ) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# Subscription +# --------------------------------------------------------------------------- + + +class Subscription(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_subscription" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + org_id: UUID = Field( + sa_column=Column( + sa.Uuid(), + ForeignKey("saas_organization.id", ondelete="CASCADE"), + nullable=False, + unique=True, # one active subscription per org + index=True, + ) + ) + plan_id: UUID = Field(sa_column=Column(sa.Uuid(), ForeignKey("saas_plan.id", ondelete="RESTRICT"), nullable=False)) + status: SubscriptionStatus = Field(default=SubscriptionStatus.TRIALING) + stripe_subscription_id: str | None = Field(default=None, unique=True, index=True) + stripe_price_id: str | None = Field(default=None) + current_period_start: datetime | None = Field(default=None) + current_period_end: datetime | None = Field(default=None) + cancel_at_period_end: bool = Field(default=False) + trial_end: datetime | None = Field(default=None) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# UsageRecord (append-only metering log; roll-up queries for quota checks) +# --------------------------------------------------------------------------- + + +class UsageRecord(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_usage_record" + __table_args__ = ( + # Composite index for the quota-check query: + # WHERE org_id=? AND metric=? AND recorded_at >= today + Index("ix_saas_usage_org_metric_time", "org_id", "metric", "recorded_at"), + ) + + id: UUID = Field(default_factory=uuid4, primary_key=True) + org_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_organization.id", ondelete="CASCADE"), nullable=False) + ) + user_id: UUID | None = Field(sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="SET NULL"), nullable=True)) + metric: UsageMetric = Field() + # For FLOW_EXECUTION / API_CALL: value=1 per event. + # For STORAGE_BYTES: value = delta bytes (can be negative for deletions). + value: int = Field(default=1) + # Resource that triggered the usage (e.g. flow_id for executions). + resource_id: str | None = Field(default=None) + recorded_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# AuditLog (immutable append-only; never update or delete rows) +# --------------------------------------------------------------------------- + + +class AuditLog(SQLModel, table=True): # type: ignore[call-arg] + __tablename__ = "saas_audit_log" + __table_args__ = ( + Index("ix_saas_audit_org_time", "org_id", "created_at"), + Index("ix_saas_audit_user_time", "user_id", "created_at"), + ) + + id: UUID = Field(default_factory=uuid4, primary_key=True) + # Nullable so system-level events (e.g. Stripe webhook) don't require a user. + org_id: UUID | None = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_organization.id", ondelete="SET NULL"), nullable=True) + ) + user_id: UUID | None = Field(sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="SET NULL"), nullable=True)) + # Dot-separated action name: "org.created", "member.invited", "subscription.upgraded" + action: str = Field(index=True) + resource_type: str | None = Field(default=None) # "flow", "org", "team", … + resource_id: str | None = Field(default=None) + log_metadata: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) + ip_address: str | None = Field(default=None) + user_agent: str | None = Field(default=None) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# Pydantic response / request schemas (separate from table models) +# --------------------------------------------------------------------------- + + +class PlanRead(SQLModel): + id: UUID + name: str + slug: str + max_flows: int + max_executions_per_day: int + max_members: int + max_storage_mb: int + rpm_limit: int + price_monthly_cents: int + price_yearly_cents: int + is_active: bool + + +class OrganizationCreate(SQLModel): + name: str + slug: str | None = None + + +class OrganizationRead(SQLModel): + id: UUID + name: str + slug: str + owner_id: UUID + is_personal: bool + is_active: bool + created_at: datetime + role: OrgRole | None = None # caller's role — filled at query time + plan: PlanRead | None = None + + +class OrganizationUpdate(SQLModel): + name: str | None = None + slug: str | None = None + + +class MemberRead(SQLModel): + user_id: UUID + username: str + role: OrgRole + joined_at: datetime + + +class InvitationCreate(SQLModel): + email: str + role: OrgRole = OrgRole.MEMBER + + +class InvitationRead(SQLModel): + id: UUID + email: str + role: OrgRole + status: InvitationStatus + expires_at: datetime + created_at: datetime + + +class TeamCreate(SQLModel): + name: str + description: str | None = None + + +class TeamRead(SQLModel): + id: UUID + org_id: UUID + name: str + description: str | None + + +class UsageSummary(SQLModel): + org_id: UUID + executions_today: int + executions_limit: int + flows_count: int + flows_limit: int + storage_mb: float + storage_limit_mb: int + api_calls_today: int + plan_slug: str + + +class SubscriptionRead(SQLModel): + id: UUID + org_id: UUID + status: SubscriptionStatus + plan: PlanRead + current_period_end: datetime | None + cancel_at_period_end: bool + trial_end: datetime | None + + +# --------------------------------------------------------------------------- +# SaaS Alembic version tracking table +# Declared as a SQLModel so it appears in SQLModel.metadata — this prevents +# Langflow's migration drift-checker from flagging it as an unknown table. +# --------------------------------------------------------------------------- + + +class SaasAlembicVersion(SQLModel, table=True): # type: ignore[call-arg] + """Tracks the applied SaaS migration revisions (mirrors alembic_version).""" + + __tablename__ = "saas_alembic_version" + + version_num: str = Field(primary_key=True, max_length=32) + + +# --------------------------------------------------------------------------- +# FlowOrg (shadow table — links Langflow flows to SaaS orgs) +# --------------------------------------------------------------------------- + + +class FlowOrg(SQLModel, table=True): # type: ignore[call-arg] + """Maps a Langflow flow (by its UUID) to the org that owns it. + + Never touches Langflow's ``flow`` table — pure shadow so the SaaS layer + can filter/share flows without modifying core Langflow. One flow belongs + to at most one org at a time. + """ + + __tablename__ = "saas_flow_org" + __table_args__ = (Index("ix_saas_flow_org_org", "org_id"),) + + id: UUID = Field(default_factory=uuid4, primary_key=True) + flow_id: UUID = Field(sa_column=Column(sa.Uuid(), nullable=False, unique=True, index=True)) + org_id: UUID = Field( + sa_column=Column(sa.Uuid(), ForeignKey("saas_organization.id", ondelete="CASCADE"), nullable=False) + ) + assigned_by: UUID | None = Field( + sa_column=Column(sa.Uuid(), ForeignKey("user.id", ondelete="SET NULL"), nullable=True) + ) + assigned_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# SQLModel metadata — exported so Alembic env.py can include it +# --------------------------------------------------------------------------- + +saas_metadata = SQLModel.metadata diff --git a/src/backend/saas/langflow_saas/plugin.py b/src/backend/saas/langflow_saas/plugin.py new file mode 100644 index 000000000000..10d4a0d53add --- /dev/null +++ b/src/backend/saas/langflow_saas/plugin.py @@ -0,0 +1,216 @@ +"""SaaS plugin entry point. + +Langflow calls ``register(app_wrapper)`` at startup via the +``langflow.plugins`` entry-point group declared in pyproject.toml. + +What this function does: + 1. Runs SaaS database migrations (saas_* tables) on every startup — safe + because Alembic is idempotent. + 2. Registers three ASGI middleware classes on the real FastAPI app: + a. RateLimitMiddleware (outermost) + b. TenantContextMiddleware (resolves org context) + c. QuotaEnforcementMiddleware (innermost, checks execution quotas) + 3. Mounts all SaaS REST routes under /api/saas/v1/. + 4. Registers a startup handler that auto-creates personal organisations + for any existing Langflow users who don't yet have one (idempotent). + +Upgrade safety: + This file calls only two things from Langflow: + - app_wrapper.add_middleware() (added in plugin_routes.py) + - app_wrapper.include_router() (already present in plugin_routes.py) + If Langflow changes either of these signatures, update only here. +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger("langflow_saas.plugin") + + +# --------------------------------------------------------------------------- +# Migration helper +# --------------------------------------------------------------------------- + + +def _run_migrations() -> None: + """Run SaaS Alembic migrations synchronously at startup. + + Safe to call on every startup — Alembic no-ops when already up to date. + """ + try: + from pathlib import Path + + from alembic import command + from alembic.config import Config + + ini_path = Path(__file__).parent.parent.parent / "alembic.ini" + if not ini_path.exists(): + # Installed as a wheel: resolve relative to the package. + ini_path = Path(__file__).parent.parent / "alembic.ini" + + if not ini_path.exists(): + logger.warning("langflow-saas: alembic.ini not found at %s — skipping migrations.", ini_path) + return + + alembic_cfg = Config(str(ini_path)) + + # Ensure the migrations folder is discoverable regardless of CWD. + migrations_dir = Path(__file__).parent / "migrations" + alembic_cfg.set_main_option("script_location", str(migrations_dir)) + + from langflow_saas.settings import get_saas_settings + + db_url = get_saas_settings().database_url + alembic_cfg.set_main_option("sqlalchemy.url", db_url) + + command.upgrade(alembic_cfg, "heads") + logger.info("langflow-saas: migrations applied.") + except Exception: + logger.exception("langflow-saas: migration failed — SaaS features may not work correctly.") + + +# --------------------------------------------------------------------------- +# Personal-org auto-creation +# --------------------------------------------------------------------------- + + +async def _ensure_personal_orgs() -> None: + """Create personal orgs for any existing Langflow users that don't have one. + + Idempotent: skips users who already have a personal org. + """ + from langflow_saas.settings import get_saas_settings + + if not get_saas_settings().auto_create_personal_org: + return + + try: + import re + from datetime import datetime, timezone + + from langflow.services.database.models.user.model import User + from langflow.services.deps import session_scope + from sqlmodel import select + + from langflow_saas.models import Organization, OrgRole, UserOrganization + + async with session_scope() as db: + users_result = await db.exec(select(User)) + all_users = users_result.all() + + for user in all_users: + uid = user.id + # Check for existing personal org. + existing_personal = await db.exec( + select(Organization).where( + Organization.owner_id == uid, + Organization.is_personal == True, # noqa: E712 + ) + ) + if existing_personal.first(): + continue + + # Generate a unique slug from username. + base_slug = re.sub(r"[^a-z0-9]+", "-", user.username.lower()).strip("-")[:50] + slug = base_slug + attempt = 0 + while True: + collision = await db.exec(select(Organization).where(Organization.slug == slug)) + if not collision.first(): + break + attempt += 1 + slug = f"{base_slug}-{attempt}" + + now = datetime.now(timezone.utc) + org = Organization( + name=f"{user.username}'s workspace", + slug=slug, + owner_id=uid, + is_personal=True, + created_at=now, + updated_at=now, + ) + db.add(org) + await db.flush() + + membership = UserOrganization(user_id=uid, org_id=org.id, role=OrgRole.OWNER, created_at=now) + db.add(membership) + + await db.commit() + logger.info("langflow-saas: personal orgs ensured for all users.") + except Exception: + logger.exception("langflow-saas: failed to ensure personal orgs.") + + +# --------------------------------------------------------------------------- +# Plugin registration — the ONLY function called by Langflow +# --------------------------------------------------------------------------- + + +def register(app) -> None: # ``app`` is _PluginAppWrapper from plugin_routes.py + """Called by Langflow's load_plugin_routes() at startup. + + Parameters + ---------- + app: + A ``_PluginAppWrapper`` instance that exposes ``include_router`` and + ``add_middleware`` (the latter added by our one-line patch to + plugin_routes.py). + """ + logger.info("langflow-saas: registering SaaS plugin…") + + # 1. Run DB migrations (sync — happens before the ASGI server starts + # accepting requests, so there is no race condition). + _run_migrations() + + # 2. Register middleware. Order matters: add_middleware() inserts at the + # outermost position each time, so the last add_middleware call wraps + # everything. We want: + # QuotaEnforcement (innermost — only fires on execution paths) + # TenantContext (resolves org for all API paths) + # RateLimit (outermost — fast Redis check, no DB) + # + # Because add_middleware() inserts outermost each time, we add them + # in REVERSE order of desired execution: + from langflow_saas.middleware import ( + FlowOwnershipMiddleware, + QuotaEnforcementMiddleware, + RateLimitMiddleware, + TenantContextMiddleware, + UserRegistrationMiddleware, + ) + + # Desired execution order (outermost → innermost): + # RateLimit → UserRegistration → TenantContext → FlowOwnership → QuotaEnforcement + # add_middleware() inserts at the outermost position each call, so we add + # in reverse order: + app.add_middleware(QuotaEnforcementMiddleware) # innermost — quota gate on executions + app.add_middleware(FlowOwnershipMiddleware) # auto-assigns new flows to creator's org + app.add_middleware(TenantContextMiddleware) # resolves org context from JWT + app.add_middleware(UserRegistrationMiddleware) # provisions org on signup + app.add_middleware(RateLimitMiddleware) # outermost — fast Redis check + + # 3. Mount SaaS API routes under /api/saas/v1/. + from langflow_saas.api.router import router as saas_router + + app.include_router(saas_router) + + # 4. Register startup hook for personal-org auto-creation. + # We access the real FastAPI app via the wrapper's private _app attribute + # to use the @app.on_event pattern. This is the only place we touch + # the private attribute; if Langflow ever changes the wrapper, update here. + try: + real_app = app._app # _PluginAppWrapper stores the real FastAPI app here + + @real_app.on_event("startup") + async def _saas_startup(): + await _ensure_personal_orgs() + + except AttributeError: + logger.warning( + "langflow-saas: could not register startup hook " + "(wrapper._app not accessible). Personal orgs will not be auto-created." + ) + + logger.info("langflow-saas: plugin registered. Routes at /api/saas/v1/") diff --git a/src/backend/saas/langflow_saas/services.py b/src/backend/saas/langflow_saas/services.py new file mode 100644 index 000000000000..d22b1e666191 --- /dev/null +++ b/src/backend/saas/langflow_saas/services.py @@ -0,0 +1,476 @@ +"""Service layer for the SaaS plugin. + +Three services: + + AuditService — append-only audit log writer. + EmailService — abstraction over multiple email providers (console, + SMTP, SendGrid, Resend). Provider is selected at startup + from SAAS_EMAIL_PROVIDER. + BillingService — Stripe integration: create/sync subscriptions, handle + webhook events. + +All services are module-level singletons initialised lazily on first use. +""" + +from __future__ import annotations + +import logging +import smtplib +import ssl +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from typing import Any +from uuid import UUID + +logger = logging.getLogger("langflow_saas.services") + + +# =========================================================================== +# Audit Service +# =========================================================================== + + +class AuditService: + """Append-only structured audit log. Never modifies or deletes rows.""" + + async def log( + self, + *, + action: str, + org_id: UUID | None = None, + user_id: UUID | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + log_metadata: dict[str, Any] | None = None, + ip_address: str | None = None, + user_agent: str | None = None, + ) -> None: + try: + from langflow.services.deps import session_scope + + from langflow_saas.models import AuditLog + + entry = AuditLog( + action=action, + org_id=org_id, + user_id=user_id, + resource_type=resource_type, + resource_id=resource_id, + log_metadata=log_metadata, + ip_address=ip_address, + user_agent=user_agent, + ) + async with session_scope() as db: + db.add(entry) + await db.commit() + except Exception: # noqa: BLE001 + # Audit failures must never surface to callers. + logger.warning("Failed to write audit log: action=%s org=%s user=%s", action, org_id, user_id) + + +_audit_service: AuditService | None = None + + +def get_audit_service() -> AuditService: + global _audit_service + if _audit_service is None: + _audit_service = AuditService() + return _audit_service + + +# =========================================================================== +# Email Service +# =========================================================================== + + +class BaseEmailService(ABC): + """Abstract email sender. Implement send_raw() in subclasses.""" + + @abstractmethod + async def send_raw(self, *, to: str, subject: str, html: str, text: str) -> None: ... + + async def send_invitation( + self, + *, + to_email: str, + org_name: str, + inviter_name: str, + role: str, + accept_url: str, + expire_hours: int, + ) -> None: + subject = f"You've been invited to join {org_name} on Langflow" + html = f""" +

You're invited!

+

{inviter_name} has invited you to join {org_name} + as a {role}.

+

This invitation expires in {expire_hours} hours.

+

Accept Invitation

+

Or copy this URL: {accept_url}

+ """ + text = ( + f"{inviter_name} invited you to join {org_name} as {role}.\n" + f"Accept here: {accept_url}\n" + f"Expires in {expire_hours} hours." + ) + await self.send_raw(to=to_email, subject=subject, html=html, text=text) + + async def send_password_reset(self, *, to_email: str, reset_url: str, expire_hours: int) -> None: + subject = "Reset your Langflow password" + html = f""" +

Password Reset Request

+

We received a request to reset your password.

+

This link expires in {expire_hours} hours.

+

Reset Password

+

If you didn't request this, ignore this email.

+ """ + text = f"Reset your password: {reset_url}\nExpires in {expire_hours} hours." + await self.send_raw(to=to_email, subject=subject, html=html, text=text) + + async def send_quota_warning(self, *, to_email: str, org_name: str, metric: str, used: int, limit: int) -> None: + pct = int(used / limit * 100) if limit else 0 + subject = f"[{org_name}] Usage alert: {pct}% of {metric} quota used" + html = f""" +

Usage Alert for {org_name}

+

Your organization has used {used}/{limit} ({pct}%) of its + daily {metric} quota.

+

Upgrade your plan to increase your limits.

+ """ + text = f"[{org_name}] {metric}: {used}/{limit} ({pct}%) used." + await self.send_raw(to=to_email, subject=subject, html=html, text=text) + + +class ConsoleEmailService(BaseEmailService): + """Development stub — prints emails to stdout instead of sending them.""" + + async def send_raw(self, *, to: str, subject: str, html: str, text: str) -> None: + logger.info("=== [EMAIL] To: %s | Subject: %s ===\n%s", to, subject, text) + + +class SMTPEmailService(BaseEmailService): + def __init__(self) -> None: + from langflow_saas.settings import get_saas_settings + + s = get_saas_settings() + self._host = s.smtp_host + self._port = s.smtp_port + self._starttls = s.smtp_starttls + self._user = s.smtp_user + self._password = s.smtp_password.get_secret_value() + self._from = f"{s.email_from_name} <{s.email_from}>" + + async def send_raw(self, *, to: str, subject: str, html: str, text: str) -> None: + import asyncio + + await asyncio.to_thread(self._send_sync, to=to, subject=subject, html=html, text=text) + + def _send_sync(self, *, to: str, subject: str, html: str, text: str) -> None: + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = self._from + msg["To"] = to + msg.attach(MIMEText(text, "plain")) + msg.attach(MIMEText(html, "html")) + + context = ssl.create_default_context() + with smtplib.SMTP(self._host, self._port) as server: + if self._starttls: + server.starttls(context=context) + if self._user: + server.login(self._user, self._password) + server.sendmail(self._from, [to], msg.as_string()) + + +class SendGridEmailService(BaseEmailService): + def __init__(self) -> None: + from langflow_saas.settings import get_saas_settings + + s = get_saas_settings() + self._api_key = s.sendgrid_api_key.get_secret_value() + self._from = s.email_from + self._from_name = s.email_from_name + + async def send_raw(self, *, to: str, subject: str, html: str, text: str) -> None: + import httpx + + payload = { + "personalizations": [{"to": [{"email": to}]}], + "from": {"email": self._from, "name": self._from_name}, + "subject": subject, + "content": [ + {"type": "text/plain", "value": text}, + {"type": "text/html", "value": html}, + ], + } + async with httpx.AsyncClient() as client: + resp = await client.post( + "https://api.sendgrid.com/v3/mail/send", + json=payload, + headers={"Authorization": f"Bearer {self._api_key}"}, + timeout=10, + ) + resp.raise_for_status() + + +class ResendEmailService(BaseEmailService): + def __init__(self) -> None: + from langflow_saas.settings import get_saas_settings + + s = get_saas_settings() + self._api_key = s.resend_api_key.get_secret_value() + self._from = f"{s.email_from_name} <{s.email_from}>" + + async def send_raw(self, *, to: str, subject: str, html: str, text: str) -> None: + import httpx + + payload = {"from": self._from, "to": [to], "subject": subject, "html": html, "text": text} + async with httpx.AsyncClient() as client: + resp = await client.post( + "https://api.resend.com/emails", + json=payload, + headers={"Authorization": f"Bearer {self._api_key}"}, + timeout=10, + ) + resp.raise_for_status() + + +_email_service: BaseEmailService | None = None + + +def get_email_service() -> BaseEmailService: + global _email_service + if _email_service is None: + from langflow_saas.settings import get_saas_settings + + provider = get_saas_settings().email_provider.lower() + _email_service = { + "console": ConsoleEmailService, + "smtp": SMTPEmailService, + "sendgrid": SendGridEmailService, + "resend": ResendEmailService, + }.get(provider, ConsoleEmailService)() + return _email_service + + +# =========================================================================== +# Billing Service (Stripe) +# =========================================================================== + + +class BillingService: + """Stripe billing operations. + + All Stripe calls are gated behind ``settings.billing_enabled`` so the + plugin works without Stripe credentials in development. + """ + + def _stripe(self): + """Return the configured stripe module or raise if not available.""" + import stripe as _stripe + + from langflow_saas.settings import get_saas_settings + + key = get_saas_settings().stripe_secret_key.get_secret_value() + if not key: + raise RuntimeError("SAAS_STRIPE_SECRET_KEY is not set.") + _stripe.api_key = key + return _stripe + + async def get_or_create_customer(self, *, org_id: UUID, org_name: str, email: str) -> str: + """Return Stripe customer_id, creating one if absent.""" + import asyncio + + from langflow.services.deps import session_scope + from sqlmodel import select + + from langflow_saas.models import Organization + + async with session_scope() as db: + result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = result.first() + if org and org.stripe_customer_id: + return org.stripe_customer_id + + # Create new Stripe customer. + stripe = self._stripe() + customer = await asyncio.to_thread( + stripe.Customer.create, + name=org_name, + email=email, + metadata={"langflow_org_id": str(org_id)}, + ) + customer_id: str = customer["id"] + + # Persist the customer_id. + async with session_scope() as db: + result = await db.exec(select(Organization).where(Organization.id == org_id)) + org = result.first() + if org: + org.stripe_customer_id = customer_id + org.updated_at = datetime.now(timezone.utc) + db.add(org) + await db.commit() + + return customer_id + + async def create_checkout_session( + self, + *, + org_id: UUID, + org_name: str, + owner_email: str, + stripe_price_id: str, + success_url: str, + cancel_url: str, + ) -> str: + """Create a Stripe Checkout Session and return the redirect URL.""" + import asyncio + + stripe = self._stripe() + customer_id = await self.get_or_create_customer(org_id=org_id, org_name=org_name, email=owner_email) + session = await asyncio.to_thread( + stripe.checkout.Session.create, + customer=customer_id, + mode="subscription", + line_items=[{"price": stripe_price_id, "quantity": 1}], + success_url=success_url, + cancel_url=cancel_url, + metadata={"langflow_org_id": str(org_id)}, + ) + return session["url"] + + async def handle_webhook(self, *, payload: bytes, sig_header: str) -> dict[str, Any]: + """Verify and process a Stripe webhook event. + + Returns a dict describing what was processed (for logging). + """ + import asyncio + + from langflow_saas.settings import get_saas_settings + + stripe = self._stripe() + webhook_secret = get_saas_settings().stripe_webhook_secret.get_secret_value() + + event = await asyncio.to_thread(stripe.Webhook.construct_event, payload, sig_header, webhook_secret) + + event_type: str = event["type"] + handlers = { + "customer.subscription.created": self._on_subscription_created, + "customer.subscription.updated": self._on_subscription_updated, + "customer.subscription.deleted": self._on_subscription_deleted, + "invoice.payment_failed": self._on_payment_failed, + } + + handler = handlers.get(event_type) + if handler: + await handler(event["data"]["object"]) + return {"processed": True, "event_type": event_type} + + return {"processed": False, "event_type": event_type} + + async def _on_subscription_created(self, subscription: dict) -> None: + await self._upsert_subscription(subscription, status_override=None) + + async def _on_subscription_updated(self, subscription: dict) -> None: + await self._upsert_subscription(subscription, status_override=None) + + async def _on_subscription_deleted(self, subscription: dict) -> None: + await self._upsert_subscription(subscription, status_override="canceled") + + async def _on_payment_failed(self, invoice: dict) -> None: + sub_id = invoice.get("subscription") + if not sub_id: + return + from langflow.services.deps import session_scope + from sqlmodel import select + + from langflow_saas.models import Subscription, SubscriptionStatus + + async with session_scope() as db: + result = await db.exec(select(Subscription).where(Subscription.stripe_subscription_id == sub_id)) + sub = result.first() + if sub: + sub.status = SubscriptionStatus.PAST_DUE + sub.updated_at = datetime.now(timezone.utc) + db.add(sub) + await db.commit() + + async def _upsert_subscription(self, stripe_sub: dict, *, status_override: str | None) -> None: + """Sync a Stripe subscription object into our DB.""" + from langflow.services.deps import session_scope + from sqlmodel import select + + from langflow_saas.models import Organization, Plan, Subscription, SubscriptionStatus + + customer_id: str = stripe_sub["customer"] + stripe_sub_id: str = stripe_sub["id"] + stripe_price_id: str | None = stripe_sub.get("items", {}).get("data", [{}])[0].get("price", {}).get("id") + raw_status = status_override or stripe_sub.get("status", "active") + + try: + stripe_status = SubscriptionStatus(raw_status) + except ValueError: + stripe_status = SubscriptionStatus.ACTIVE + + period_start = stripe_sub.get("current_period_start") + period_end = stripe_sub.get("current_period_end") + trial_end = stripe_sub.get("trial_end") + + async with session_scope() as db: + # Find org by Stripe customer_id. + org_result = await db.exec(select(Organization).where(Organization.stripe_customer_id == customer_id)) + org = org_result.first() + if not org: + logger.warning("Stripe webhook: no org found for customer %s", customer_id) + return + + # Match plan by Stripe price_id. + plan: Plan | None = None + if stripe_price_id: + plan_result = await db.exec( + select(Plan).where( + (Plan.stripe_monthly_price_id == stripe_price_id) + | (Plan.stripe_yearly_price_id == stripe_price_id) + ) + ) + plan = plan_result.first() + + sub_result = await db.exec(select(Subscription).where(Subscription.org_id == org.id)) + sub = sub_result.first() + + now = datetime.now(timezone.utc) + if sub is None: + sub = Subscription( + org_id=org.id, + plan_id=plan.id if plan else org.plan_id, # type: ignore[arg-type] + ) + db.add(sub) + + sub.stripe_subscription_id = stripe_sub_id + sub.stripe_price_id = stripe_price_id + sub.status = stripe_status + sub.current_period_start = datetime.fromtimestamp(period_start, tz=timezone.utc) if period_start else None + sub.current_period_end = datetime.fromtimestamp(period_end, tz=timezone.utc) if period_end else None + sub.trial_end = datetime.fromtimestamp(trial_end, tz=timezone.utc) if trial_end else None + sub.cancel_at_period_end = stripe_sub.get("cancel_at_period_end", False) + sub.updated_at = now + + if plan: + org.plan_id = plan.id + org.updated_at = now + db.add(org) + await db.commit() + + +_billing_service: BillingService | None = None + + +def get_billing_service() -> BillingService: + global _billing_service + if _billing_service is None: + _billing_service = BillingService() + return _billing_service diff --git a/src/backend/saas/langflow_saas/settings.py b/src/backend/saas/langflow_saas/settings.py new file mode 100644 index 000000000000..3acec3382def --- /dev/null +++ b/src/backend/saas/langflow_saas/settings.py @@ -0,0 +1,120 @@ +"""All SaaS configuration, driven entirely by environment variables. + +Every setting has a safe default so the plugin works out-of-the-box in +development with zero extra configuration. In production, set the variables +that are relevant to your deployment (billing, email, Redis, etc.). + +Prefix: SAAS_ (e.g. SAAS_REDIS_URL, SAAS_STRIPE_SECRET_KEY) +""" + +from __future__ import annotations + +import os + +from pydantic import Field, SecretStr +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class SaaSSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="SAAS_", + env_file=".env", + extra="ignore", + populate_by_name=True, + ) + + # ------------------------------------------------------------------ + # Database + # Falls back to Langflow's own DB URL so no extra config is needed + # in single-DB setups. + # ------------------------------------------------------------------ + database_url: str = Field( + default_factory=lambda: os.getenv("LANGFLOW_DATABASE_URL", "sqlite:///./langflow.db"), + description="DB URL for SaaS tables (defaults to LANGFLOW_DATABASE_URL).", + ) + + # ------------------------------------------------------------------ + # Redis (rate limiting + usage counters) + # ------------------------------------------------------------------ + redis_url: str = Field(default="redis://localhost:6379/1") + redis_enabled: bool = Field(default=True) + + # ------------------------------------------------------------------ + # Rate limiting + # ------------------------------------------------------------------ + rate_limit_enabled: bool = Field(default=True) + rate_limit_default_rpm: int = Field(default=60, description="Requests per minute per user (default plan).") + rate_limit_burst_multiplier: int = Field(default=2, description="Burst capacity = rpm * multiplier.") + # Path prefixes to apply rate limiting to. + rate_limit_paths: list[str] = Field(default=["/api/v1/", "/api/v2/", "/api/saas/"]) + + # ------------------------------------------------------------------ + # Email + # ------------------------------------------------------------------ + email_provider: str = Field( + default="console", + description="Email backend: console | smtp | sendgrid | resend", + ) + email_from: str = Field(default="noreply@example.com") + email_from_name: str = Field(default="Langflow") + # SMTP + smtp_host: str = Field(default="localhost") + smtp_port: int = Field(default=587) + smtp_starttls: bool = Field(default=True) + smtp_user: str = Field(default="") + smtp_password: SecretStr = Field(default=SecretStr("")) + # SendGrid + sendgrid_api_key: SecretStr = Field(default=SecretStr("")) + # Resend + resend_api_key: SecretStr = Field(default=SecretStr("")) + + # ------------------------------------------------------------------ + # Billing (Stripe) + # ------------------------------------------------------------------ + billing_enabled: bool = Field(default=False) + stripe_secret_key: SecretStr = Field(default=SecretStr("")) + stripe_webhook_secret: SecretStr = Field(default=SecretStr("")) + stripe_publishable_key: str = Field(default="") + + # ------------------------------------------------------------------ + # Invitations + # ------------------------------------------------------------------ + invitation_expire_hours: int = Field(default=48) + app_base_url: str = Field( + default_factory=lambda: os.getenv("LANGFLOW_BASE_URL", "http://localhost:7860"), + description="Public base URL used in invitation/reset email links.", + ) + invitation_secret: SecretStr = Field( + default_factory=lambda: SecretStr(os.getenv("SAAS_INVITATION_SECRET", "change-me-in-production")), + description="Secret for signing invitation tokens.", + ) + + # ------------------------------------------------------------------ + # Multi-tenancy behaviour + # ------------------------------------------------------------------ + auto_create_personal_org: bool = Field( + default=True, + description="Automatically create a personal org when a new Langflow user is created.", + ) + require_org_header: bool = Field( + default=False, + description="Reject API calls that don't include X-Org-ID when the user belongs to multiple orgs.", + ) + + # ------------------------------------------------------------------ + # Default plan quotas (used before billing is set up) + # ------------------------------------------------------------------ + default_max_flows: int = Field(default=50) + default_max_executions_per_day: int = Field(default=1000) + default_max_members: int = Field(default=5) + default_max_storage_mb: int = Field(default=500) + + +_settings: SaaSSettings | None = None + + +def get_saas_settings() -> SaaSSettings: + global _settings + if _settings is None: + _settings = SaaSSettings() + return _settings diff --git a/src/backend/saas/pyproject.toml b/src/backend/saas/pyproject.toml new file mode 100644 index 000000000000..73762b26e230 --- /dev/null +++ b/src/backend/saas/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "langflow-saas" +version = "0.1.0" +description = "SaaS multi-tenancy plugin for Langflow — organizations, billing, rate-limiting, invitations" +requires-python = ">=3.10,<3.14" +license = "MIT" + +dependencies = [ + # Ties to the same langflow-base major version so APIs stay compatible. + # Bump the lower bound whenever a Langflow upgrade changes an integration + # point (currently only: langflow.services.deps session_scope, + # langflow.services.auth.utils get_current_user_from_access_token, + # and the langflow.plugins entry-point protocol in plugin_routes.py). + "langflow-base>=0.9.0,<1.0.0", + "sqlmodel~=0.0.37", + "pydantic>=2.0.0", + "pydantic-settings>=2.2.0", + "alembic>=1.13.0,<2.0.0", + "redis>=5.0.0,<6.0.0", + "stripe>=8.0.0,<9.0.0", + "httpx>=0.27.0", # for email providers without a heavy SDK + "nanoid>=2.0.0,<3.0.0", + "itsdangerous>=2.0.0", # for signed invitation tokens +] + +[project.entry-points."langflow.plugins"] +# This is the ONLY integration point with Langflow core. +# Langflow calls register(wrapper) at startup via load_plugin_routes(). +saas = "langflow_saas.plugin:register" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["langflow_saas"]