@@ -6,12 +6,17 @@ import (
66 "time"
77
88 "github.com/alpacahq/alpacadecimal"
9+ "github.com/openmeterio/openmeter/api"
910 "github.com/openmeterio/openmeter/openmeter/billing"
1011 "github.com/openmeterio/openmeter/openmeter/ent/db"
12+ "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice"
13+ "github.com/openmeterio/openmeter/pkg/clock"
1114 "github.com/openmeterio/openmeter/pkg/convert"
1215 "github.com/openmeterio/openmeter/pkg/framework/entutils"
1316 "github.com/openmeterio/openmeter/pkg/models"
17+ "github.com/openmeterio/openmeter/pkg/pagination"
1418 "github.com/openmeterio/openmeter/pkg/slicesx"
19+ "github.com/openmeterio/openmeter/pkg/sortx"
1520 "github.com/openmeterio/openmeter/pkg/timeutil"
1621 "github.com/samber/lo"
1722)
@@ -88,11 +93,265 @@ func (a *adapter) CreateGatheringInvoice(ctx context.Context, input billing.Crea
8893 // Let's add required edges for mapping
8994 newInvoice .Edges .BillingWorkflowConfig = clonedWorkflowConfig
9095
91- return tx .mapGatheringInvoiceFromDB (ctx , newInvoice , billing .InvoiceExpandAll )
96+ return tx .mapGatheringInvoiceFromDB (ctx , newInvoice , billing.GatheringInvoiceExpands {} )
9297 })
9398}
9499
95- func (a * adapter ) mapGatheringInvoiceFromDB (ctx context.Context , invoice * db.BillingInvoice , expand billing.InvoiceExpand ) (billing.GatheringInvoice , error ) {
100+ func (a * adapter ) UpdateGatheringInvoice (ctx context.Context , in billing.GatheringInvoice ) error {
101+ if err := in .Validate (); err != nil {
102+ return fmt .Errorf ("validating gathering invoice: %w" , err )
103+ }
104+
105+ return entutils .TransactingRepoWithNoValue (ctx , a , func (ctx context.Context , tx * adapter ) error {
106+ existingInvoice , err := tx .db .BillingInvoice .Query ().
107+ Where (billinginvoice .ID (in .ID )).
108+ Where (billinginvoice .Namespace (in .Namespace )).
109+ Only (ctx )
110+ if err != nil {
111+ return err
112+ }
113+
114+ if err := tx .validateUpdateGatheringInvoiceRequest (in , existingInvoice ); err != nil {
115+ return err
116+ }
117+
118+ updateQuery := tx .db .BillingInvoice .UpdateOneID (in .ID ).
119+ Where (billinginvoice .Namespace (in .Namespace )).
120+ SetMetadata (in .Metadata ).
121+ // Currency is immutable
122+ SetStatus (billing .StandardInvoiceStatusGathering ).
123+ ClearStatusDetailsCache ().
124+ // Type is immutable
125+ SetNumber (in .Number ).
126+ SetOrClearDescription (in .Description ).
127+ ClearDueAt ().
128+ SetCollectionAt (in .NextCollectionAt .In (time .UTC )).
129+ ClearPaymentProcessingEnteredAt ().
130+ ClearDraftUntil ().
131+ ClearIssuedAt ().
132+ ClearDeletedAt ().
133+ ClearSentToCustomerAt ().
134+ ClearQuantitySnapshotedAt ().
135+ // Totals
136+ SetAmount (alpacadecimal .Zero ).
137+ SetChargesTotal (alpacadecimal .Zero ).
138+ SetDiscountsTotal (alpacadecimal .Zero ).
139+ SetTaxesTotal (alpacadecimal .Zero ).
140+ SetTaxesExclusiveTotal (alpacadecimal .Zero ).
141+ SetTaxesInclusiveTotal (alpacadecimal .Zero ).
142+ SetTotal (alpacadecimal .Zero )
143+
144+ updateQuery = updateQuery .
145+ SetPeriodStart (in .ServicePeriod .From .In (time .UTC )).
146+ SetPeriodEnd (in .ServicePeriod .To .In (time .UTC ))
147+
148+ // Supplier
149+ updateQuery = updateQuery .
150+ SetSupplierName ("UNSET" ). // Hack until we split the invoices table
151+ ClearSupplierAddressCountry ().
152+ ClearSupplierAddressPostalCode ().
153+ ClearSupplierAddressCity ().
154+ ClearSupplierAddressState ().
155+ ClearSupplierAddressLine1 ().
156+ ClearSupplierAddressLine2 ().
157+ ClearSupplierAddressPhoneNumber ()
158+
159+ // Customer
160+ updateQuery = updateQuery .
161+ // CustomerID is immutable
162+ SetCustomerName ("UNSET" ). // hack until we split the invoices table
163+ ClearCustomerKey ()
164+
165+ updateQuery = updateQuery .
166+ ClearCustomerAddressCountry ().
167+ ClearCustomerAddressPostalCode ().
168+ ClearCustomerAddressCity ().
169+ ClearCustomerAddressState ().
170+ ClearCustomerAddressLine1 ().
171+ ClearCustomerAddressLine2 ().
172+ ClearCustomerAddressPhoneNumber ()
173+
174+ // ExternalIDs
175+ updateQuery = updateQuery .
176+ ClearInvoicingAppExternalID ().
177+ ClearPaymentAppExternalID ()
178+
179+ _ , err = updateQuery .Save (ctx )
180+ if err != nil {
181+ return err
182+ }
183+
184+ if in .Lines .IsPresent () {
185+ err := a .updateGatheringLines (ctx , in .Lines .OrEmpty ())
186+ if err != nil {
187+ return err
188+ }
189+ }
190+
191+ return nil
192+ })
193+ }
194+
195+ func (a * adapter ) ListGatheringInvoices (ctx context.Context , input billing.ListGatheringInvoicesInput ) (pagination.Result [billing.GatheringInvoice ], error ) {
196+ if err := input .Validate (); err != nil {
197+ return pagination.Result [billing.GatheringInvoice ]{}, err
198+ }
199+
200+ return entutils .TransactingRepo (ctx , a , func (ctx context.Context , tx * adapter ) (pagination.Result [billing.GatheringInvoice ], error ) {
201+ query := tx .db .BillingInvoice .Query ().
202+ Where (billinginvoice .NamespaceIn (input .Namespaces ... ))
203+
204+ if len (input .Customers ) > 0 {
205+ query = query .Where (billinginvoice .CustomerIDIn (input .Customers ... ))
206+ }
207+
208+ if len (input .Currencies ) > 0 {
209+ query = query .Where (billinginvoice .CurrencyIn (input .Currencies ... ))
210+ }
211+
212+ order := entutils .GetOrdering (sortx .OrderDefault )
213+ if ! input .Order .IsDefaultValue () {
214+ order = entutils .GetOrdering (input .Order )
215+ }
216+
217+ if input .Expand .Has (billing .GatheringInvoiceExpandLines ) {
218+ query = query .WithBillingInvoiceLines (func (q * db.BillingInvoiceLineQuery ) {
219+ q .WithUsageBasedLine ()
220+ })
221+ }
222+
223+ switch input .OrderBy {
224+ case api .InvoiceOrderByCustomerName :
225+ query = query .Order (billinginvoice .ByCustomerName (order ... ))
226+ case api .InvoiceOrderByIssuedAt :
227+ query = query .Order (billinginvoice .ByIssuedAt (order ... ))
228+ case api .InvoiceOrderByPeriodStart :
229+ query = query .Order (billinginvoice .ByPeriodStart (order ... ))
230+ case api .InvoiceOrderByStatus :
231+ query = query .Order (billinginvoice .ByStatus (order ... ))
232+ case api .InvoiceOrderByUpdatedAt :
233+ query = query .Order (billinginvoice .ByUpdatedAt (order ... ))
234+ case api .InvoiceOrderByCreatedAt :
235+ fallthrough
236+ default :
237+ query = query .Order (billinginvoice .ByCreatedAt (order ... ))
238+ }
239+
240+ if ! input .IncludeDeleted {
241+ query = query .Where (billinginvoice .DeletedAtIsNil ())
242+ }
243+
244+ response := pagination.Result [billing.GatheringInvoice ]{
245+ Page : input .Page ,
246+ }
247+
248+ paged , err := query .Paginate (ctx , input .Page )
249+ if err != nil {
250+ return response , err
251+ }
252+
253+ result := make ([]billing.GatheringInvoice , 0 , len (paged .Items ))
254+ for _ , invoice := range paged .Items {
255+ mapped , err := tx .mapGatheringInvoiceFromDB (ctx , invoice , input .Expand )
256+ if err != nil {
257+ return response , err
258+ }
259+
260+ result = append (result , mapped )
261+ }
262+
263+ response .TotalCount = paged .TotalCount
264+ response .Items = result
265+
266+ return response , nil
267+ })
268+ }
269+
270+ func (a * adapter ) validateUpdateGatheringInvoiceRequest (req billing.GatheringInvoice , existing * db.BillingInvoice ) error {
271+ if req .Currency != existing .Currency {
272+ return billing.ValidationError {
273+ Err : fmt .Errorf ("currency cannot be changed" ),
274+ }
275+ }
276+
277+ if billing .InvoiceTypeStandard != existing .Type {
278+ return billing.ValidationError {
279+ Err : fmt .Errorf ("type cannot be changed" ),
280+ }
281+ }
282+
283+ if req .CustomerID != existing .CustomerID {
284+ return billing.ValidationError {
285+ Err : fmt .Errorf ("customer cannot be changed" ),
286+ }
287+ }
288+
289+ return nil
290+ }
291+
292+ func (a * adapter ) DeleteGatheringInvoice (ctx context.Context , input billing.DeleteGatheringInvoiceAdapterInput ) error {
293+ if err := input .Validate (); err != nil {
294+ return fmt .Errorf ("validating delete gathering invoice input: %w" , err )
295+ }
296+
297+ return entutils .TransactingRepoWithNoValue (ctx , a , func (ctx context.Context , tx * adapter ) error {
298+ invoice , err := tx .db .BillingInvoice .Query ().
299+ Where (billinginvoice .ID (input .ID )).
300+ Where (billinginvoice .Namespace (input .Namespace )).
301+ Only (ctx )
302+ if err != nil {
303+ return err
304+ }
305+
306+ if invoice .Status != billing .StandardInvoiceStatusGathering {
307+ return billing.ValidationError {
308+ Err : fmt .Errorf ("invoice is not a gathering invoice [id=%s]" , invoice .ID ),
309+ }
310+ }
311+
312+ if invoice .DeletedAt != nil {
313+ return nil
314+ }
315+
316+ _ , err = tx .db .BillingInvoice .Update ().
317+ Where (billinginvoice .ID (input .ID )).
318+ Where (billinginvoice .Namespace (input .Namespace )).
319+ SetDeletedAt (clock .Now ()).
320+ Save (ctx )
321+ if err != nil {
322+ return err
323+ }
324+
325+ return nil
326+ })
327+ }
328+
329+ func (a * adapter ) GetGatheringInvoiceById (ctx context.Context , input billing.GetGatheringInvoiceByIdInput ) (billing.GatheringInvoice , error ) {
330+ if err := input .Validate (); err != nil {
331+ return billing.GatheringInvoice {}, fmt .Errorf ("validating get gathering invoice by id input: %w" , err )
332+ }
333+
334+ return entutils .TransactingRepo (ctx , a , func (ctx context.Context , tx * adapter ) (billing.GatheringInvoice , error ) {
335+ query := tx .db .BillingInvoice .Query ().
336+ Where (billinginvoice .ID (input .Invoice .ID )).
337+ Where (billinginvoice .Namespace (input .Invoice .Namespace ))
338+
339+ if input .Expand .Has (billing .GatheringInvoiceExpandLines ) {
340+ query = query .WithBillingInvoiceLines (func (q * db.BillingInvoiceLineQuery ) {
341+ q .WithUsageBasedLine ()
342+ })
343+ }
344+
345+ invoice , err := query .Only (ctx )
346+ if err != nil {
347+ return billing.GatheringInvoice {}, err
348+ }
349+
350+ return tx .mapGatheringInvoiceFromDB (ctx , invoice , input .Expand )
351+ })
352+ }
353+
354+ func (a * adapter ) mapGatheringInvoiceFromDB (ctx context.Context , invoice * db.BillingInvoice , expand billing.GatheringInvoiceExpands ) (billing.GatheringInvoice , error ) {
96355 if invoice .Status != billing .StandardInvoiceStatusGathering {
97356 return billing.GatheringInvoice {}, fmt .Errorf ("invoice is not a gathering invoice [id=%s]" , invoice .ID )
98357 }
@@ -132,7 +391,7 @@ func (a *adapter) mapGatheringInvoiceFromDB(ctx context.Context, invoice *db.Bil
132391 },
133392 }
134393
135- if expand .Lines {
394+ if expand .Has ( billing . GatheringInvoiceExpandLines ) {
136395 mappedLines , err := a .mapGatheringInvoiceLinesFromDB (invoice .SchemaLevel , invoice .Edges .BillingInvoiceLines )
137396 if err != nil {
138397 return billing.GatheringInvoice {}, err
@@ -167,39 +426,41 @@ func (a *adapter) mapGatheringInvoiceLineFromDB(schemaLevel int, dbLine *db.Bill
167426 }
168427
169428 line := billing.GatheringLine {
170- ManagedResource : models .NewManagedResource (models.ManagedResourceInput {
171- Namespace : dbLine .Namespace ,
172- ID : dbLine .ID ,
173- CreatedAt : dbLine .CreatedAt .In (time .UTC ),
174- UpdatedAt : dbLine .UpdatedAt .In (time .UTC ),
175- DeletedAt : convert .TimePtrIn (dbLine .DeletedAt , time .UTC ),
176- Name : dbLine .Name ,
177- Description : dbLine .Description ,
178- }),
179-
180- Metadata : dbLine .Metadata ,
181- Annotations : dbLine .Annotations ,
182- InvoiceID : dbLine .InvoiceID ,
183- ManagedBy : dbLine .ManagedBy ,
184-
185- ServicePeriod : timeutil.ClosedPeriod {
186- From : dbLine .PeriodStart .In (time .UTC ),
187- To : dbLine .PeriodEnd .In (time .UTC ),
188- },
429+ GatheringLineBase : billing.GatheringLineBase {
430+ ManagedResource : models .NewManagedResource (models.ManagedResourceInput {
431+ Namespace : dbLine .Namespace ,
432+ ID : dbLine .ID ,
433+ CreatedAt : dbLine .CreatedAt .In (time .UTC ),
434+ UpdatedAt : dbLine .UpdatedAt .In (time .UTC ),
435+ DeletedAt : convert .TimePtrIn (dbLine .DeletedAt , time .UTC ),
436+ Name : dbLine .Name ,
437+ Description : dbLine .Description ,
438+ }),
439+
440+ Metadata : dbLine .Metadata ,
441+ Annotations : dbLine .Annotations ,
442+ InvoiceID : dbLine .InvoiceID ,
443+ ManagedBy : dbLine .ManagedBy ,
444+
445+ ServicePeriod : timeutil.ClosedPeriod {
446+ From : dbLine .PeriodStart .In (time .UTC ),
447+ To : dbLine .PeriodEnd .In (time .UTC ),
448+ },
189449
190- SplitLineGroupID : dbLine .SplitLineGroupID ,
191- ChildUniqueReferenceID : dbLine .ChildUniqueReferenceID ,
450+ SplitLineGroupID : dbLine .SplitLineGroupID ,
451+ ChildUniqueReferenceID : dbLine .ChildUniqueReferenceID ,
192452
193- InvoiceAt : dbLine .InvoiceAt .In (time .UTC ),
453+ InvoiceAt : dbLine .InvoiceAt .In (time .UTC ),
194454
195- Currency : dbLine .Currency ,
455+ Currency : dbLine .Currency ,
196456
197- TaxConfig : lo .EmptyableToPtr (dbLine .TaxConfig ),
198- RateCardDiscounts : lo .FromPtr (dbLine .RatecardDiscounts ),
457+ TaxConfig : lo .EmptyableToPtr (dbLine .TaxConfig ),
458+ RateCardDiscounts : lo .FromPtr (dbLine .RatecardDiscounts ),
199459
200- UBPConfigID : ubpLine .ID ,
201- FeatureKey : lo .FromPtr (ubpLine .FeatureKey ),
202- Price : lo .FromPtr (ubpLine .Price ),
460+ UBPConfigID : ubpLine .ID ,
461+ FeatureKey : lo .FromPtr (ubpLine .FeatureKey ),
462+ Price : lo .FromPtr (ubpLine .Price ),
463+ },
203464 }
204465
205466 if dbLine .SubscriptionID != nil && dbLine .SubscriptionPhaseID != nil && dbLine .SubscriptionItemID != nil {
0 commit comments