diff --git a/apps/web/src/app/api/ai/journl-agent/route.ts b/apps/web/src/app/api/ai/journl-agent/route.ts index c43f005..0b7bedc 100644 --- a/apps/web/src/app/api/ai/journl-agent/route.ts +++ b/apps/web/src/app/api/ai/journl-agent/route.ts @@ -1,3 +1,4 @@ +import { after } from "next/server"; import { journlAgent, setJournlRuntimeContext } from "~/ai/agents/journl-agent"; import type { JournlAgentContext } from "~/ai/agents/journl-agent-context"; import { handler as corsHandler } from "~/app/api/_cors/cors"; @@ -18,30 +19,6 @@ async function handler(req: Request) { const result = await journlAgent.streamVNext(messages, { format: "aisdk", - onFinish: async (result) => { - const modelData = await journlAgent.getModel(); - - const provider = modelData.provider; - const model = modelData.modelId; - - if (result.usage && session.user?.id) { - await api.usage.trackModelUsage({ - metrics: [ - { - quantity: result.usage.promptTokens, - unit: "input_tokens", - }, - { - quantity: result.usage.completionTokens, - unit: "output_tokens", - }, - ], - model_id: model, - model_provider: provider, - user_id: session.user.id, - }); - } - }, runtimeContext: setJournlRuntimeContext({ ...rest.context, user: { @@ -51,6 +28,43 @@ async function handler(req: Request) { } satisfies JournlAgentContext), }); + if (session.user?.id) { + after(async () => { + try { + const fullOutput = await result.getFullOutput(); + const usage = fullOutput.usage; + + if (usage) { + const modelData = await journlAgent.getModel(); + const provider = modelData.provider; + const model = modelData.modelId; + + await api.usage.trackModelUsage({ + metrics: [ + { + quantity: usage.promptTokens || 0, + unit: "input_tokens", + }, + { + quantity: usage.completionTokens || 0, + unit: "output_tokens", + }, + { + quantity: usage.reasoningTokens || 0, + unit: "reasoning_tokens", + }, + ], + model_id: model, + model_provider: provider, + user_id: session.user.id, + }); + } + } catch (error) { + console.error("[usage tracking] error:", error); + } + }); + } + return result.toUIMessageStreamResponse(); } catch (error) { console.error("[api.chat.route] error 👀", error); diff --git a/apps/web/src/app/api/supabase/usage/route.ts b/apps/web/src/app/api/supabase/usage/route.ts new file mode 100644 index 0000000..74b8b63 --- /dev/null +++ b/apps/web/src/app/api/supabase/usage/route.ts @@ -0,0 +1,39 @@ +import { zUsageEventWebhook } from "@acme/db/schema"; +import { NextResponse } from "next/server"; +import { api } from "~/trpc/server"; +import { handler } from "../_lib/webhook-handler"; + +/** + * This webhook processes usage events when they are created or updated. + * Usage events are created when AI features are used (chat, embedding, etc.) + */ +export const POST = handler(zUsageEventWebhook, async (payload) => { + // Skip DELETE events + if (payload.type === "DELETE") { + return NextResponse.json({ success: true }); + } + + // Skip processing if the event is already processed + if (payload.record?.status === "processed") { + return NextResponse.json({ success: true }); + } + + try { + // Process the usage event with the usage period + const result = await api.usage.processUsageEvent({ + usage_event_id: payload.record.id, + user_id: payload.record.user_id, + }); + + return NextResponse.json({ result, success: true }); + } catch (error) { + return NextResponse.json( + { + details: error instanceof Error ? error.message : "Unknown error", + error: "Processing failed", + success: false, + }, + { status: 500 }, + ); + } +}); diff --git a/packages/api/src/api-router/index.ts b/packages/api/src/api-router/index.ts index 2fbcadd..e69b9c8 100644 --- a/packages/api/src/api-router/index.ts +++ b/packages/api/src/api-router/index.ts @@ -2,6 +2,7 @@ import { createTRPCRouter } from "../trpc.js"; import { authRouter } from "./auth.js"; import { documentRouter } from "./document.js"; import { journalRouter } from "./journal.js"; +import { modelPricingRouter } from "./model-pricing.js"; import { notesRouter } from "./notes.js"; import { pagesRouter } from "./pages.js"; import { subscriptionRouter } from "./subscription.js"; @@ -11,6 +12,7 @@ export const apiRouter = createTRPCRouter({ auth: authRouter, document: documentRouter, journal: journalRouter, + modelPricing: modelPricingRouter, notes: notesRouter, pages: pagesRouter, subscription: subscriptionRouter, diff --git a/packages/api/src/api-router/model-pricing.ts b/packages/api/src/api-router/model-pricing.ts new file mode 100644 index 0000000..559e90f --- /dev/null +++ b/packages/api/src/api-router/model-pricing.ts @@ -0,0 +1,100 @@ +import { and, desc, eq, lte } from "@acme/db"; +import { ModelPricing, zInsertModelPricing } from "@acme/db/schema"; +import { TRPCError, type TRPCRouterRecord } from "@trpc/server"; +import { z } from "zod/v4"; +import { publicProcedure } from "../trpc.js"; + +export const modelPricingRouter = { + getAllPricingForModel: publicProcedure + .input( + z.object({ + model_id: z.string(), + model_provider: z.string(), + }), + ) + .query(async ({ ctx, input }) => { + try { + const pricing = await ctx.db.query.ModelPricing.findMany({ + orderBy: [desc(ModelPricing.effective_date)], + where: and( + eq(ModelPricing.model_id, input.model_id), + eq(ModelPricing.model_provider, input.model_provider), + lte(ModelPricing.effective_date, new Date().toISOString()), + ), + }); + + return pricing; + } catch (error) { + console.error( + "Database error in modelPricing.getAllPricingForModel:", + error, + ); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to get model pricing", + }); + } + }), + getCurrentPricing: publicProcedure + .input( + z.object({ + model_id: z.string(), + model_provider: z.string(), + unit_type: z.string(), + }), + ) + .query(async ({ ctx, input }) => { + try { + const pricing = await ctx.db.query.ModelPricing.findFirst({ + orderBy: [desc(ModelPricing.effective_date)], + where: and( + eq(ModelPricing.model_id, input.model_id), + eq(ModelPricing.model_provider, input.model_provider), + eq(ModelPricing.unit_type, input.unit_type), + lte(ModelPricing.effective_date, new Date().toISOString()), + ), + }); + + return pricing; + } catch (error) { + console.error( + "Database error in modelPricing.getCurrentPricing:", + error, + ); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to get current pricing", + }); + } + }), + + upsertPricing: publicProcedure + .input(zInsertModelPricing) + .mutation(async ({ ctx, input }) => { + try { + const [pricing] = await ctx.db + .insert(ModelPricing) + .values(input) + .onConflictDoUpdate({ + set: { + price_per_unit: input.price_per_unit, + }, + target: [ + ModelPricing.model_id, + ModelPricing.model_provider, + ModelPricing.unit_type, + ModelPricing.effective_date, + ], + }) + .returning(); + + return pricing; + } catch (error) { + console.error("Database error in modelPricing.upsertPricing:", error); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to upsert pricing", + }); + } + }), +} satisfies TRPCRouterRecord; diff --git a/packages/api/src/api-router/subscription.ts b/packages/api/src/api-router/subscription.ts index b830aa4..5b94947 100644 --- a/packages/api/src/api-router/subscription.ts +++ b/packages/api/src/api-router/subscription.ts @@ -1,7 +1,6 @@ -import { and, eq, or } from "@acme/db"; -import { Subscription } from "@acme/db/schema"; import type { TRPCRouterRecord } from "@trpc/server"; import { z } from "zod/v4"; +import { getActiveSubscription } from "../shared/subscription"; import { protectedProcedure } from "../trpc"; /** @@ -66,36 +65,8 @@ export const subscriptionRouter = { quota: plan.quota, }); }), - // getActivePlan: protectedProcedure.query(async ({ ctx }) => { - // const activeSubscription = await getActiveSubscription({ ctx }); - // if (!activeSubscription?.priceId) return null; - - // const plan = await ctx.db.query.Plan.findFirst({ - // where: eq(Plan.id, activeSubscription.priceId), - // with: { - // prices: true, - // }, - // }); - - // return plan; - // }), getSubscription: protectedProcedure.query(async ({ ctx }) => { - const subscription = await ctx.db.query.Subscription.findFirst({ - where: and( - eq(Subscription.referenceId, ctx.session.user.id), - or( - eq(Subscription.status, "active"), - eq(Subscription.status, "trialing"), - ), - ), - with: { - plan: { - with: { - price: true, - }, - }, - }, - }); + const subscription = await getActiveSubscription({ ctx }); if (!subscription?.plan) { return null; diff --git a/packages/api/src/api-router/usage.ts b/packages/api/src/api-router/usage.ts index fb15c0a..656a6ea 100644 --- a/packages/api/src/api-router/usage.ts +++ b/packages/api/src/api-router/usage.ts @@ -1,10 +1,231 @@ -import { eq } from "@acme/db"; -import { UsageEvent, UsageEventStatus } from "@acme/db/schema"; +import { and, desc, eq, lte, sql } from "@acme/db"; +import { + ModelPricing, + UsageAggregate, + UsageEvent, + UsageEventStatus, + UsagePeriod, +} from "@acme/db/schema"; import { TRPCError, type TRPCRouterRecord } from "@trpc/server"; +import { gte } from "drizzle-orm"; import { z } from "zod/v4"; +import type { TRPCContext } from "../trpc.js"; import { publicProcedure } from "../trpc.js"; +/** + * Get the current usage period aggregate for a user + * Returns a complete aggregate with period, plan, and subscription data + * Periods are created by webhooks (pro) and cron jobs (free) + */ +async function getCurrentUsagePeriod({ + ctx, + userId, +}: { + ctx: TRPCContext; + userId: string; +}) { + const now = new Date().toISOString(); + + const usagePeriodAggregate = await ctx.db.query.UsagePeriod.findFirst({ + orderBy: (fields, { desc }) => [ + desc(fields.subscription_id), + desc(fields.created_at), + ], + where: and( + eq(UsagePeriod.user_id, userId), + lte(UsagePeriod.period_start, now), + gte(UsagePeriod.period_end, now), + ), + with: { + plan: true, + subscription: { + with: { + plan: true, + }, + }, + usageAggregate: true, + }, + }); + + if (!usagePeriodAggregate) { + throw new TRPCError({ + code: "PRECONDITION_FAILED", + message: + "No active usage period found. Usage periods are created automatically via webhooks and scheduled jobs.", + }); + } + + return usagePeriodAggregate; +} + export const usageRouter = { + checkUsage: publicProcedure + .input(z.object({ user_id: z.string() })) + .query(async ({ ctx, input }) => { + try { + const usagePeriodAggregate = await getCurrentUsagePeriod({ + ctx, + userId: input.user_id, + }); + + const { plan } = usagePeriodAggregate; + if (!plan) { + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Usage period missing plan data", + }); + } + + const quotaUsd = plan.quota / 100; + + const currentUsageUsd = usagePeriodAggregate.usageAggregate + ? Number.parseFloat(usagePeriodAggregate.usageAggregate.total_cost) + : 0; + + const canUse = currentUsageUsd < quotaUsd; + const remainingQuotaUsd = quotaUsd - currentUsageUsd; + + return { + canUse, + currentUsageUsd, + quotaUsd, + remainingQuotaUsd: Math.max(0, remainingQuotaUsd), + subscriptionType: usagePeriodAggregate.subscription ? "pro" : "free", + usagePeriodId: usagePeriodAggregate.id, + }; + } catch (error) { + console.error("Error checking usage:", error); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to check usage", + }); + } + }), + getCurrentUsagePeriod: publicProcedure + .input(z.object({ user_id: z.string() })) + .query(async ({ ctx, input }) => { + return getCurrentUsagePeriod({ + ctx, + userId: input.user_id, + }); + }), + processUsageEvent: publicProcedure + .input( + z.object({ + usage_event_id: z.string(), + user_id: z.string(), + }), + ) + .mutation(async ({ ctx, input }) => { + return await ctx.db.transaction(async (tx) => { + try { + // Get the usage event + const usageEvent = await tx.query.UsageEvent.findFirst({ + where: eq(UsageEvent.id, input.usage_event_id), + }); + + if (!usageEvent) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "Usage event not found", + }); + } + + if (usageEvent.status === "processed") { + return { message: "Already processed", success: true }; + } + + const usagePeriod = await getCurrentUsagePeriod({ + ctx, + userId: input.user_id, + }); + + const eventDate = usageEvent.created_at + ? usageEvent.created_at + : new Date().toISOString(); + + // Fetch all pricing for this model upfront + const allPricing = await tx + .select() + .from(ModelPricing) + .where( + and( + eq(ModelPricing.model_id, usageEvent.model_id), + eq(ModelPricing.model_provider, usageEvent.model_provider), + lte(ModelPricing.effective_date, eventDate), + ), + ) + .orderBy(desc(ModelPricing.effective_date)); + + // Create lookup map + const pricingMap = new Map(); + for (const price of allPricing) { + if (!pricingMap.has(price.unit_type)) { + pricingMap.set(price.unit_type, price); + } + } + + // Calculate total cost using in-memory pricing lookups + let totalCost = 0; + for (const metric of usageEvent.metrics) { + const pricing = pricingMap.get(metric.unit); + + if (pricing) { + const cost = + Number.parseFloat(pricing.price_per_unit) * metric.quantity; + totalCost += cost; + } else { + console.warn( + `No pricing found for model ${usageEvent.model_id} (${usageEvent.model_provider}) unit ${metric.unit}`, + ); + } + } + + // Update or create usage aggregate + await tx + .insert(UsageAggregate) + .values({ + total_cost: totalCost.toFixed(6), + usage_period_id: usagePeriod.id, + user_id: usageEvent.user_id, + }) + .onConflictDoUpdate({ + set: { + total_cost: sql`${UsageAggregate.total_cost} + ${totalCost.toFixed(6)}`, + }, + target: [UsageAggregate.user_id, UsageAggregate.usage_period_id], + }); + + // Mark usage event as processed and store the calculated cost + await tx + .update(UsageEvent) + .set({ + status: "processed", + total_cost: totalCost.toFixed(6), + }) + .where(eq(UsageEvent.id, input.usage_event_id)); + + return { + cost_added: totalCost, + success: true, + usage_period_id: usagePeriod.id, + }; + } catch (error) { + console.error("Database error in usage.processUsageEvent:", error); + + // Mark usage event as failed within the transaction + await tx + .update(UsageEvent) + .set({ status: "failed" }) + .where(eq(UsageEvent.id, input.usage_event_id)); + + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to process usage event", + }); + } + }); + }), trackModelUsage: publicProcedure .input( z.object({ diff --git a/packages/api/src/shared/subscription.ts b/packages/api/src/shared/subscription.ts new file mode 100644 index 0000000..7ec07a8 --- /dev/null +++ b/packages/api/src/shared/subscription.ts @@ -0,0 +1,47 @@ +import { and, eq, or } from "@acme/db"; +import { Subscription } from "@acme/db/schema"; +import type { TRPCContext } from "../trpc"; + +/** + * Get the free plan from the database + */ +export async function getFreePlan(ctx: TRPCContext) { + return ctx.db.query.Plan.findFirst({ + where: (plans, { eq, and }) => + and(eq(plans.active, true), eq(plans.name, "free")), + }); +} + +/** + * Get the active subscription for a user + */ +export async function getActiveSubscription({ + ctx, + userId, +}: { + ctx: TRPCContext; + userId?: string; +}) { + const userIdToUse = userId ?? ctx.session?.user?.id; + + if (!userIdToUse) { + return null; + } + + return ctx.db.query.Subscription.findFirst({ + where: and( + eq(Subscription.referenceId, userIdToUse), + or( + eq(Subscription.status, "active"), + eq(Subscription.status, "trialing"), + ), + ), + with: { + plan: { + with: { + price: true, + }, + }, + }, + }); +} diff --git a/packages/auth/src/index.ts b/packages/auth/src/index.ts index b9121a9..9746ceb 100644 --- a/packages/auth/src/index.ts +++ b/packages/auth/src/index.ts @@ -7,6 +7,10 @@ import { oAuthProxy, organization } from "better-auth/plugins"; import { eq } from "drizzle-orm"; import { stripeClient } from "./stripe-client"; import { handleStripeWebhookEvent } from "./stripe-webhooks"; +import { + createInitialUsagePeriodForUser, + createUsagePeriodForSubscription, +} from "./usage/usage-period-lifecycle"; export function initAuth(options: { appName: string; @@ -38,6 +42,9 @@ export function initAuth(options: { plugins: [ stripe({ createCustomerOnSignUp: true, + onCustomerCreate: async ({ user }) => { + await createInitialUsagePeriodForUser(user.id); + }, onEvent: handleStripeWebhookEvent, schema: { subscription: { @@ -112,3 +119,5 @@ export function initAuth(options: { export type Auth = ReturnType; export type Session = Auth["$Infer"]["Session"]; + +export { createInitialUsagePeriodForUser, createUsagePeriodForSubscription }; diff --git a/packages/auth/src/stripe/plan-handler.ts b/packages/auth/src/stripe/plan-handler.ts index db04959..39f6abf 100644 --- a/packages/auth/src/stripe/plan-handler.ts +++ b/packages/auth/src/stripe/plan-handler.ts @@ -4,19 +4,31 @@ import { eq } from "drizzle-orm"; import type Stripe from "stripe"; async function upsertPlan(product: Stripe.Product) { - const data: InsertPlan = { + const quota = Number(product.metadata.quota) || 0; + + const insertData: InsertPlan = { active: product.active, description: product.description, displayName: product.name, id: product.id, + metadata: product.metadata, name: product.name.toLowerCase(), - quota: Number(product.metadata.quota), + quota, }; - return db.insert(Plan).values(data).onConflictDoUpdate({ - set: data, - target: Plan.id, - }); + return db + .insert(Plan) + .values(insertData) + .onConflictDoUpdate({ + set: { + active: insertData.active, + description: insertData.description, + displayName: insertData.displayName, + metadata: insertData.metadata, + quota: insertData.quota, + }, + target: Plan.id, + }); } async function deletePlan(productId: string) { diff --git a/packages/auth/src/stripe/price-handler.ts b/packages/auth/src/stripe/price-handler.ts index 9dffbc5..12a0526 100644 --- a/packages/auth/src/stripe/price-handler.ts +++ b/packages/auth/src/stripe/price-handler.ts @@ -4,30 +4,47 @@ import { eq } from "drizzle-orm"; import type Stripe from "stripe"; async function upsertPrice(price: Stripe.Price) { - if (!price.unit_amount || !price.recurring) { - throw new Error(`Price data is missing for price: ${price.id}`); + if (!price.recurring) { + throw new Error( + `Price data is missing recurring info for price: ${price.id}`, + ); } - const data: InsertPrice = { + const planId = + typeof price.product === "string" ? price.product : price.product.id; + + const insertData: InsertPrice = { active: price.active, currency: price.currency, id: price.id, lookupKey: price.lookup_key, + metadata: price.metadata, nickname: price.nickname, - planId: - typeof price.product === "string" ? price.product : price.product.id, + planId, recurring: { interval: price.recurring.interval, intervalCount: price.recurring.interval_count, }, type: price.type, - unitAmount: price.unit_amount, + unitAmount: price.unit_amount ?? 0, }; - return db.insert(Price).values(data).onConflictDoUpdate({ - set: data, - target: Price.id, - }); + return db + .insert(Price) + .values(insertData) + .onConflictDoUpdate({ + set: { + active: insertData.active, + currency: insertData.currency, + lookupKey: insertData.lookupKey, + metadata: insertData.metadata, + nickname: insertData.nickname, + recurring: insertData.recurring, + type: insertData.type, + unitAmount: insertData.unitAmount, + }, + target: Price.id, + }); } async function deletePrice(priceId: string) { diff --git a/packages/auth/src/stripe/subscription-handler.ts b/packages/auth/src/stripe/subscription-handler.ts index 822b4d8..4bdc25d 100644 --- a/packages/auth/src/stripe/subscription-handler.ts +++ b/packages/auth/src/stripe/subscription-handler.ts @@ -3,6 +3,7 @@ import { Subscription } from "@acme/db/schema"; import { eq } from "drizzle-orm"; import type Stripe from "stripe"; import { stripeClient } from "../stripe-client"; +import { createUsagePeriodForSubscription } from "../usage/usage-period-lifecycle"; export async function handleSubscriptionEvents( event: @@ -62,11 +63,34 @@ export async function handleSubscriptionEvents( const [stripeSubscription] = subscriptions.data; if (!stripeSubscription) return; + const firstItem = stripeSubscription.items.data[0]; + if (!firstItem) { + console.error("No subscription items found"); + return; + } + + const periodStart = firstItem.current_period_start + ? new Date(firstItem.current_period_start * 1000) + : undefined; + const periodEnd = firstItem.current_period_end + ? new Date(firstItem.current_period_end * 1000) + : undefined; + await db .update(Subscription) .set({ cancelAtPeriodEnd: stripeSubscription.cancel_at_period_end, + periodEnd, + periodStart, status: stripeSubscription.status, }) .where(eq(Subscription.stripeSubscriptionId, stripeSubscription.id)); + + const updatedSubscription = await db.query.Subscription.findFirst({ + where: eq(Subscription.stripeSubscriptionId, stripeSubscription.id), + }); + + if (updatedSubscription && updatedSubscription.status === "active") { + await createUsagePeriodForSubscription(updatedSubscription); + } } diff --git a/packages/auth/src/usage/usage-period-lifecycle.ts b/packages/auth/src/usage/usage-period-lifecycle.ts new file mode 100644 index 0000000..ce153c4 --- /dev/null +++ b/packages/auth/src/usage/usage-period-lifecycle.ts @@ -0,0 +1,116 @@ +import { db } from "@acme/db/client"; +import { Plan, type Subscription, UsagePeriod } from "@acme/db/schema"; +import { and, eq, gte, isNull } from "drizzle-orm"; + +function get30DayPeriod(startDate: Date = new Date()) { + const periodStart = new Date(startDate); + const periodEnd = new Date(startDate); + periodEnd.setDate(periodEnd.getDate() + 30); + periodEnd.setMilliseconds(periodEnd.getMilliseconds() - 1); + return { periodEnd, periodStart }; +} + +export async function createInitialUsagePeriodForUser(userId: string) { + const freePlan = await db.query.Plan.findFirst({ + where: eq(Plan.name, "free"), + }); + + if (!freePlan) { + console.error("No free plan found, cannot create usage period for user"); + return; + } + + const { periodStart, periodEnd } = get30DayPeriod(); + + await db + .insert(UsagePeriod) + .values({ + period_end: periodEnd.toISOString(), + period_start: periodStart.toISOString(), + plan_id: freePlan.id, + subscription_id: null, + user_id: userId, + }) + .onConflictDoNothing({ + target: [ + UsagePeriod.user_id, + UsagePeriod.period_start, + UsagePeriod.period_end, + ], + }); +} + +export async function createUsagePeriodForSubscription( + subscription: Subscription, +) { + if (!subscription.referenceId) { + console.error( + "Subscription missing referenceId (userId), cannot create usage period", + ); + return; + } + + if (!subscription.periodStart || !subscription.periodEnd) { + console.error( + "Subscription missing period dates, cannot create usage period", + ); + return; + } + + if (!subscription.planName) { + console.error("Subscription missing planName, cannot create usage period"); + return; + } + + const plan = await db.query.Plan.findFirst({ + where: eq(Plan.name, subscription.planName), + }); + + if (!plan) { + console.error( + `Plan ${subscription.planName} not found, cannot create usage period`, + ); + return; + } + + // Handle free-to-pro upgrades: trim overlapping free periods + // Example: User on free (Jan 1-31), upgrades Jan 20 + // Result: Free period (Jan 1-19), Pro period (Jan 20-Feb 20) + const overlappingPeriods = await db.query.UsagePeriod.findMany({ + where: and( + eq(UsagePeriod.user_id, subscription.referenceId), + isNull(UsagePeriod.subscription_id), + gte(UsagePeriod.period_end, subscription.periodStart.toISOString()), + ), + }); + + // Trim each overlapping free period to end 1ms before pro period starts + // Preserves usage data and aggregates tied to the free period + for (const period of overlappingPeriods) { + const newEndDate = new Date(subscription.periodStart.getTime() - 1); + await db + .update(UsagePeriod) + .set({ + period_end: newEndDate.toISOString(), + }) + .where(eq(UsagePeriod.id, period.id)); + } + + // Create new pro period linked to subscription + await db + .insert(UsagePeriod) + .values({ + period_end: subscription.periodEnd.toISOString(), + period_start: subscription.periodStart.toISOString(), + plan_id: plan.id, + subscription_id: subscription.id, + user_id: subscription.referenceId, + }) + .onConflictDoNothing({ + target: [ + UsagePeriod.user_id, + UsagePeriod.period_start, + UsagePeriod.period_end, + ], + }); +} diff --git a/packages/db/src/billing/plan.schema.ts b/packages/db/src/billing/plan.schema.ts index bfdb8c9..443ed62 100644 --- a/packages/db/src/billing/plan.schema.ts +++ b/packages/db/src/billing/plan.schema.ts @@ -3,6 +3,7 @@ import { boolean, index, integer, + jsonb, pgTable, text, timestamp, @@ -22,11 +23,13 @@ export const Plan = pgTable( description: varchar("description", { length: TEXT_LIMITS.DESCRIPTION }), active: boolean("active").default(true).notNull(), quota: integer("quota").notNull(), + metadata: jsonb("metadata").$type>().default({}), created_at: timestamp("created_at") .$defaultFn(() => /* @__PURE__ */ new Date()) .notNull(), updated_at: timestamp("updated_at") .$defaultFn(() => /* @__PURE__ */ new Date()) + .$onUpdateFn(() => sql`now()`) .notNull(), }, (t) => [index("plan_name_lower").on(sql`lower(${t.name})`)], diff --git a/packages/db/src/billing/price.schema.ts b/packages/db/src/billing/price.schema.ts index 4a0d65d..5397ce3 100644 --- a/packages/db/src/billing/price.schema.ts +++ b/packages/db/src/billing/price.schema.ts @@ -1,4 +1,4 @@ -import { relations } from "drizzle-orm"; +import { relations, sql } from "drizzle-orm"; import { boolean, integer, @@ -33,11 +33,13 @@ export const Price = pgTable( }).notNull(), active: boolean("active").default(true).notNull(), lookupKey: varchar("lookup_key", { length: TEXT_LIMITS.LOOKUP_KEY }), + metadata: jsonb("metadata").$type>().default({}), createdAt: timestamp("created_at") .$defaultFn(() => /* @__PURE__ */ new Date()) .notNull(), updatedAt: timestamp("updated_at") .$defaultFn(() => /* @__PURE__ */ new Date()) + .$onUpdateFn(() => sql`now()`) .notNull(), }, (t) => [unique("price_plan_id_active").on(t.planId, t.active)], diff --git a/packages/db/src/schema.ts b/packages/db/src/schema.ts index 44109e4..16b8e47 100644 --- a/packages/db/src/schema.ts +++ b/packages/db/src/schema.ts @@ -12,4 +12,7 @@ export * from "./core/document-embedding.schema.js"; export * from "./core/document-embedding-task.schema.js"; export * from "./core/journal-entry.schema.js"; export * from "./core/page.schema.js"; +export * from "./usage/model-pricing.schema.js"; +export * from "./usage/usage-aggregate.schema.js"; export * from "./usage/usage-event.schema.js"; +export * from "./usage/usage-period.schema.js"; diff --git a/packages/db/src/usage/model-pricing.schema.ts b/packages/db/src/usage/model-pricing.schema.ts new file mode 100644 index 0000000..7c74151 --- /dev/null +++ b/packages/db/src/usage/model-pricing.schema.ts @@ -0,0 +1,56 @@ +import { sql } from "drizzle-orm"; +import { decimal, index, pgTable, unique, varchar } from "drizzle-orm/pg-core"; +import { createInsertSchema, createSelectSchema } from "drizzle-zod"; +import { TEXT_LIMITS } from "../constants/resource-limits.js"; + +export const ModelPricing = pgTable( + "model_pricing", + (t) => ({ + id: t.uuid().notNull().primaryKey().defaultRandom(), + model_id: varchar("model_id", { length: TEXT_LIMITS.MODEL_ID }).notNull(), + model_provider: varchar("model_provider", { + length: TEXT_LIMITS.MODEL_PROVIDER, + }).notNull(), + unit_type: varchar("unit_type", { length: 50 }).notNull(), // e.g., "input_tokens", "output_tokens", "reasoning_tokens", "requests" + price_per_unit: decimal("price_per_unit", { + precision: 12, + scale: 8, + }).notNull(), // High precision for token pricing (e.g., $0.00000150 per token) + effective_date: t + .timestamp({ mode: "string", withTimezone: true }) + .notNull() + .defaultNow(), + created_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull(), + updated_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull() + .$onUpdateFn(() => sql`now()`), + }), + (t) => [ + // Ensure unique pricing per model, provider, unit type, and effective date + unique("model_pricing_unique").on( + t.model_id, + t.model_provider, + t.unit_type, + t.effective_date, + ), + // Optimize queries by model and provider + index("model_pricing_model_provider").on(t.model_id, t.model_provider), + // Optimize queries by effective date for getting current pricing + index("model_pricing_effective_date").on(t.effective_date), + ], +); + +export type ModelPricing = typeof ModelPricing.$inferSelect; + +export const zInsertModelPricing = createInsertSchema(ModelPricing).omit({ + created_at: true, + id: true, + updated_at: true, +}); + +export const zModelPricing = createSelectSchema(ModelPricing); diff --git a/packages/db/src/usage/usage-aggregate.schema.ts b/packages/db/src/usage/usage-aggregate.schema.ts new file mode 100644 index 0000000..18b97ac --- /dev/null +++ b/packages/db/src/usage/usage-aggregate.schema.ts @@ -0,0 +1,53 @@ +import { sql } from "drizzle-orm"; +import { decimal, index, pgTable, text, unique } from "drizzle-orm/pg-core"; +import { createInsertSchema, createSelectSchema } from "drizzle-zod"; +import { user } from "../auth/user.schema.js"; +import { UsagePeriod } from "./usage-period.schema.js"; + +export const UsageAggregate = pgTable( + "usage_aggregate", + (t) => ({ + id: t.uuid().notNull().primaryKey().defaultRandom(), + user_id: text() + .notNull() + .references(() => user.id), + usage_period_id: t + .uuid() + .notNull() + .references(() => UsagePeriod.id), + + // Aggregated cost for the entire period + total_cost: decimal("total_cost", { precision: 10, scale: 6 }) + .notNull() + .default("0"), + + created_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull(), + updated_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull() + .$onUpdateFn(() => sql`now()`), + }), + (t) => [ + // One aggregate per user per period + unique("usage_aggregate_user_period_unique").on( + t.user_id, + t.usage_period_id, + ), + // Optimize queries by user and period + index("usage_aggregate_user_period").on(t.user_id, t.usage_period_id), + ], +); + +export type UsageAggregate = typeof UsageAggregate.$inferSelect; + +export const zInsertUsageAggregate = createInsertSchema(UsageAggregate).omit({ + created_at: true, + id: true, + updated_at: true, +}); + +export const zUsageAggregate = createSelectSchema(UsageAggregate); diff --git a/packages/db/src/usage/usage-event.schema.ts b/packages/db/src/usage/usage-event.schema.ts index 16bb374..8ffec8c 100644 --- a/packages/db/src/usage/usage-event.schema.ts +++ b/packages/db/src/usage/usage-event.schema.ts @@ -1,14 +1,15 @@ import { sql } from "drizzle-orm"; import { check, + decimal, jsonb, pgEnum, pgTable, text, - timestamp, varchar, } from "drizzle-orm/pg-core"; import { createInsertSchema, createSelectSchema } from "drizzle-zod"; +import { z } from "zod/v4"; import { JSONB_LIMITS, TEXT_LIMITS } from "../constants/resource-limits.js"; import { user } from "../schema.js"; @@ -39,10 +40,18 @@ export const UsageEvent = pgTable( >() .notNull(), status: UsageEventStatus().notNull().default("pending"), - created_at: timestamp().defaultNow(), - updated_at: timestamp() + total_cost: decimal("total_cost", { precision: 10, scale: 6 }) + .notNull() + .default("0"), + created_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull(), + updated_at: t + .timestamp({ mode: "string", withTimezone: true }) .defaultNow() - .$onUpdateFn(() => new Date()), + .notNull() + .$onUpdateFn(() => sql`now()`), }), (t) => [ // Resource protection constraints for JSONB fields @@ -66,3 +75,13 @@ export const zInsertUsageEvent = createInsertSchema(UsageEvent).omit({ }); export const zUsageEvent = createSelectSchema(UsageEvent); + +// Webhook-compatible schema that accepts string timestamps and handles decimal fields +export const zUsageEventWebhook = zUsageEvent.extend({ + created_at: z.string().optional(), + updated_at: z.string().optional(), + total_cost: z + .union([z.string(), z.number()]) + .optional() + .transform((val) => (val !== undefined ? String(val) : "0")), +}); diff --git a/packages/db/src/usage/usage-period.schema.ts b/packages/db/src/usage/usage-period.schema.ts new file mode 100644 index 0000000..5b9bb9b --- /dev/null +++ b/packages/db/src/usage/usage-period.schema.ts @@ -0,0 +1,75 @@ +import { relations, sql } from "drizzle-orm"; +import { index, pgTable, text, unique } from "drizzle-orm/pg-core"; +import { createInsertSchema, createSelectSchema } from "drizzle-zod"; +import { user } from "../auth/user.schema.js"; +import { Plan } from "../billing/plan.schema.js"; +import { Subscription } from "../billing/subscription.schema.js"; +import { UsageAggregate } from "./usage-aggregate.schema.js"; + +export const UsagePeriod = pgTable( + "usage_period", + (t) => ({ + id: t.uuid().notNull().primaryKey().defaultRandom(), + user_id: text() + .notNull() + .references(() => user.id), + plan_id: text().references(() => Plan.id), + subscription_id: text().references(() => Subscription.id), + period_start: t.timestamp({ mode: "string", withTimezone: true }).notNull(), + period_end: t.timestamp({ mode: "string", withTimezone: true }).notNull(), + created_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull(), + updated_at: t + .timestamp({ mode: "string", withTimezone: true }) + .defaultNow() + .notNull() + .$onUpdateFn(() => sql`now()`), + }), + (t) => [ + // Prevent duplicate periods for the same user and timeframe + unique("usage_period_user_period").on( + t.user_id, + t.period_start, + t.period_end, + ), + // Optimize queries by user and date ranges + index("usage_period_user_dates").on( + t.user_id, + t.period_start, + t.period_end, + ), + // Optimize subscription-based queries + index("usage_period_subscription").on(t.subscription_id), + ], +); + +export const UsagePeriodRelations = relations(UsagePeriod, ({ one }) => ({ + plan: one(Plan, { + fields: [UsagePeriod.plan_id], + references: [Plan.id], + }), + subscription: one(Subscription, { + fields: [UsagePeriod.subscription_id], + references: [Subscription.id], + }), + user: one(user, { + fields: [UsagePeriod.user_id], + references: [user.id], + }), + usageAggregate: one(UsageAggregate, { + fields: [UsagePeriod.id, UsagePeriod.user_id], + references: [UsageAggregate.usage_period_id, UsageAggregate.user_id], + }), +})); + +export type UsagePeriod = typeof UsagePeriod.$inferSelect; + +export const zInsertUsagePeriod = createInsertSchema(UsagePeriod).omit({ + created_at: true, + id: true, + updated_at: true, +}); + +export const zUsagePeriod = createSelectSchema(UsagePeriod);