@@ -7,13 +7,15 @@ import {
77 UsagePeriod ,
88} from "@acme/db/schema" ;
99import { TRPCError , type TRPCRouterRecord } from "@trpc/server" ;
10+ import { gte } from "drizzle-orm" ;
1011import { z } from "zod/v4" ;
11- import { getActiveSubscription , getFreePlan } from "../shared/subscription.js" ;
1212import type { TRPCContext } from "../trpc.js" ;
1313import { publicProcedure } from "../trpc.js" ;
1414
1515/**
16- * Get or create the current usage period for a user
16+ * Get the current usage period aggregate for a user
17+ * Returns a complete aggregate with period, plan, and subscription data
18+ * Periods are created by webhooks (pro) and cron jobs (free)
1719 */
1820async function getCurrentUsagePeriod ( {
1921 ctx,
@@ -22,156 +24,63 @@ async function getCurrentUsagePeriod({
2224 ctx : TRPCContext ;
2325 userId : string ;
2426} ) {
25- return await ctx . db . transaction ( async ( tx ) => {
26- // First, check if user has an active subscription
27- const activeSubscription = await getActiveSubscription ( {
28- ctx,
29- userId,
30- } ) ;
27+ const now = new Date ( ) . toISOString ( ) ;
3128
32- if ( activeSubscription ) {
33- // Validate subscription has required fields
34- if (
35- ! activeSubscription . periodStart ||
36- ! activeSubscription . periodEnd ||
37- ! activeSubscription . plan
38- ) {
39- throw new TRPCError ( {
40- code : "INTERNAL_SERVER_ERROR" ,
41- message : "Invalid subscription data" ,
42- } ) ;
43- }
44-
45- // Pro user: try to find existing usage period for current subscription period
46- let usagePeriod = await tx . query . UsagePeriod . findFirst ( {
47- where : and (
48- eq ( UsagePeriod . user_id , userId ) ,
49- eq ( UsagePeriod . subscription_id , activeSubscription . id ) ,
50- eq ( UsagePeriod . period_start , activeSubscription . periodStart ) ,
51- eq ( UsagePeriod . period_end , activeSubscription . periodEnd ) ,
52- ) ,
53- } ) ;
54-
55- if ( ! usagePeriod ) {
56- // Create usage period for subscription period
57- const [ newUsagePeriod ] = await tx
58- . insert ( UsagePeriod )
59- . values ( {
60- period_end : activeSubscription . periodEnd ,
61- period_start : activeSubscription . periodStart ,
62- plan_id : activeSubscription . plan . id ,
63- subscription_id : activeSubscription . id ,
64- user_id : userId ,
65- } )
66- . returning ( ) ;
67-
68- usagePeriod = newUsagePeriod ;
69- }
70-
71- if ( ! usagePeriod ) {
72- throw new TRPCError ( {
73- code : "INTERNAL_SERVER_ERROR" ,
74- message : "Failed to create or retrieve usage period for subscription" ,
75- } ) ;
76- }
77-
78- return usagePeriod ;
79- } else {
80- // For free users, we will use a monthly usage period (1st of each month) for simplicity
81- const now = new Date ( ) ;
82- const monthStart = new Date ( now . getFullYear ( ) , now . getMonth ( ) , 1 ) ;
83- // The end of the month is the last day of the month at 23:59:59.999
84- const monthEnd = new Date (
85- now . getFullYear ( ) ,
86- now . getMonth ( ) + 1 ,
87- 0 ,
88- 23 ,
89- 59 ,
90- 59 ,
91- 999 ,
92- ) ;
93-
94- // Try to find existing usage period for current month
95- let usagePeriod = await tx . query . UsagePeriod . findFirst ( {
96- where : and (
97- eq ( UsagePeriod . user_id , userId ) ,
98- eq ( UsagePeriod . period_start , monthStart ) ,
99- eq ( UsagePeriod . period_end , monthEnd ) ,
100- ) ,
101- } ) ;
102-
103- if ( ! usagePeriod ) {
104- // Get the free plan for free users
105- const freePlan = await getFreePlan ( ctx ) ;
106-
107- // Create usage period for free user
108- const [ newUsagePeriod ] = await tx
109- . insert ( UsagePeriod )
110- . values ( {
111- period_end : monthEnd ,
112- period_start : monthStart ,
113- plan_id : freePlan ?. id ,
114- subscription_id : null , // Free users don't have a subscription_id
115- user_id : userId ,
116- } )
117- . returning ( ) ;
118-
119- usagePeriod = newUsagePeriod ;
120- }
29+ const usagePeriodAggregate = await ctx . db . query . UsagePeriod . findFirst ( {
30+ orderBy : ( fields , { desc } ) => [
31+ desc ( fields . subscription_id ) ,
32+ desc ( fields . created_at ) ,
33+ ] ,
34+ where : and (
35+ eq ( UsagePeriod . user_id , userId ) ,
36+ lte ( UsagePeriod . period_start , now ) ,
37+ gte ( UsagePeriod . period_end , now ) ,
38+ ) ,
39+ with : {
40+ plan : true ,
41+ subscription : {
42+ with : {
43+ plan : true ,
44+ } ,
45+ } ,
46+ usageAggregate : true ,
47+ } ,
48+ } ) ;
12149
122- if ( ! usagePeriod ) {
123- throw new TRPCError ( {
124- code : "INTERNAL_SERVER_ERROR" ,
125- message : "Failed to create or retrieve usage period for free user" ,
126- } ) ;
127- }
50+ if ( ! usagePeriodAggregate ) {
51+ throw new TRPCError ( {
52+ code : "PRECONDITION_FAILED" ,
53+ message :
54+ "No active usage period found. Usage periods are created automatically via webhooks and scheduled jobs." ,
55+ } ) ;
56+ }
12857
129- return usagePeriod ;
130- }
131- } ) ;
58+ return usagePeriodAggregate ;
13259}
13360
13461export const usageRouter = {
13562 checkUsage : publicProcedure
13663 . input ( z . object ( { user_id : z . string ( ) } ) )
13764 . query ( async ( { ctx, input } ) => {
13865 try {
139- const activeSubscription = await getActiveSubscription ( {
66+ const usagePeriodAggregate = await getCurrentUsagePeriod ( {
14067 ctx,
14168 userId : input . user_id ,
14269 } ) ;
14370
144- const usagePeriod = await getCurrentUsagePeriod ( {
145- ctx,
146- userId : input . user_id ,
147- } ) ;
148-
149- let quotaUsd : number ;
150- if ( activeSubscription ?. plan ) {
151- quotaUsd = activeSubscription . plan . quota / 100 ;
152- } else {
153- const freePlan = await getFreePlan ( ctx ) ;
154- if ( ! freePlan ) {
155- throw new TRPCError ( {
156- code : "INTERNAL_SERVER_ERROR" ,
157- message : "No free plan found" ,
158- } ) ;
159- }
160- quotaUsd = freePlan . quota / 100 ; // Convert cents to dollars
71+ const { plan } = usagePeriodAggregate ;
72+ if ( ! plan ) {
73+ throw new TRPCError ( {
74+ code : "INTERNAL_SERVER_ERROR" ,
75+ message : "Usage period missing plan data" ,
76+ } ) ;
16177 }
16278
163- // Get current usage for the period
164- let currentUsageUsd = 0 ;
165- const usageAggregate = await ctx . db . query . UsageAggregate . findFirst ( {
166- where : and (
167- eq ( UsageAggregate . user_id , input . user_id ) ,
168- eq ( UsageAggregate . usage_period_id , usagePeriod . id ) ,
169- ) ,
170- } ) ;
79+ const quotaUsd = plan . quota / 100 ;
17180
172- if ( usageAggregate ) {
173- currentUsageUsd = Number . parseFloat ( usageAggregate . total_cost ) ;
174- }
81+ const currentUsageUsd = usagePeriodAggregate . usageAggregate
82+ ? Number . parseFloat ( usagePeriodAggregate . usageAggregate . total_cost )
83+ : 0 ;
17584
17685 const canUse = currentUsageUsd < quotaUsd ;
17786 const remainingQuotaUsd = quotaUsd - currentUsageUsd ;
@@ -181,8 +90,8 @@ export const usageRouter = {
18190 currentUsageUsd,
18291 quotaUsd,
18392 remainingQuotaUsd : Math . max ( 0 , remainingQuotaUsd ) ,
184- subscriptionType : activeSubscription ? "pro" : "free" ,
185- usagePeriodId : usagePeriod . id ,
93+ subscriptionType : usagePeriodAggregate . subscription ? "pro" : "free" ,
94+ usagePeriodId : usagePeriodAggregate . id ,
18695 } ;
18796 } catch ( error ) {
18897 console . error ( "Error checking usage:" , error ) ;
@@ -226,34 +135,40 @@ export const usageRouter = {
226135 return { message : "Already processed" , success : true } ;
227136 }
228137
229- // Get or create the current usage period
230138 const usagePeriod = await getCurrentUsagePeriod ( {
231139 ctx,
232140 userId : input . user_id ,
233141 } ) ;
234142
235- // Calculate total cost for this usage event
236- let totalCost = 0 ;
143+ const eventDate = usageEvent . created_at
144+ ? usageEvent . created_at
145+ : new Date ( ) . toISOString ( ) ;
146+
147+ // Fetch all pricing for this model upfront
148+ const allPricing = await tx
149+ . select ( )
150+ . from ( ModelPricing )
151+ . where (
152+ and (
153+ eq ( ModelPricing . model_id , usageEvent . model_id ) ,
154+ eq ( ModelPricing . model_provider , usageEvent . model_provider ) ,
155+ lte ( ModelPricing . effective_date , eventDate ) ,
156+ ) ,
157+ )
158+ . orderBy ( desc ( ModelPricing . effective_date ) ) ;
159+
160+ // Create lookup map
161+ const pricingMap = new Map < string , ( typeof allPricing ) [ 0 ] > ( ) ;
162+ for ( const price of allPricing ) {
163+ if ( ! pricingMap . has ( price . unit_type ) ) {
164+ pricingMap . set ( price . unit_type , price ) ;
165+ }
166+ }
237167
168+ // Calculate total cost using in-memory pricing lookups
169+ let totalCost = 0 ;
238170 for ( const metric of usageEvent . metrics ) {
239- // Get current pricing for this model and unit type
240- const eventDate = usageEvent . created_at
241- ? new Date ( usageEvent . created_at )
242- : new Date ( ) ;
243- const pricing = await tx
244- . select ( )
245- . from ( ModelPricing )
246- . where (
247- and (
248- eq ( ModelPricing . model_id , usageEvent . model_id ) ,
249- eq ( ModelPricing . model_provider , usageEvent . model_provider ) ,
250- eq ( ModelPricing . unit_type , metric . unit ) ,
251- lte ( ModelPricing . effective_date , eventDate ) ,
252- ) ,
253- )
254- . orderBy ( desc ( ModelPricing . effective_date ) )
255- . limit ( 1 )
256- . then ( ( results ) => results [ 0 ] || null ) ;
171+ const pricing = pricingMap . get ( metric . unit ) ;
257172
258173 if ( pricing ) {
259174 const cost =
@@ -277,7 +192,6 @@ export const usageRouter = {
277192 . onConflictDoUpdate ( {
278193 set : {
279194 total_cost : sql `${ UsageAggregate . total_cost } + ${ totalCost . toFixed ( 6 ) } ` ,
280- updated_at : new Date ( ) ,
281195 } ,
282196 target : [ UsageAggregate . user_id , UsageAggregate . usage_period_id ] ,
283197 } ) ;
0 commit comments