diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 61db1ad9b0..53fe596f3f 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -81,6 +81,8 @@ type GatheringInvoiceAdapter interface { DeleteGatheringInvoice(ctx context.Context, input DeleteGatheringInvoiceAdapterInput) error GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error) ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (pagination.Result[GatheringInvoice], error) + + HardDeleteGatheringInvoiceLines(ctx context.Context, invoiceID InvoiceID, lineIDs []string) error } type InvoiceSplitLineGroupAdapter interface { @@ -88,6 +90,7 @@ type InvoiceSplitLineGroupAdapter interface { UpdateSplitLineGroup(ctx context.Context, input UpdateSplitLineGroupInput) (SplitLineGroup, error) DeleteSplitLineGroup(ctx context.Context, input DeleteSplitLineGroupInput) error GetSplitLineGroup(ctx context.Context, input GetSplitLineGroupInput) (SplitLineHierarchy, error) + GetSplitLineGroupHeaders(ctx context.Context, input GetSplitLineGroupHeadersInput) (SplitLineGroupHeaders, error) } type SequenceAdapter interface { diff --git a/openmeter/billing/adapter/gatheringinvoice.go b/openmeter/billing/adapter/gatheringinvoice.go index c12ed65d7c..52a935f9e8 100644 --- a/openmeter/billing/adapter/gatheringinvoice.go +++ b/openmeter/billing/adapter/gatheringinvoice.go @@ -11,6 +11,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/convert" "github.com/openmeterio/openmeter/pkg/framework/entutils" @@ -121,7 +122,7 @@ func (a *adapter) UpdateGatheringInvoice(ctx context.Context, in billing.Gatheri ClearPaymentProcessingEnteredAt(). ClearDraftUntil(). ClearIssuedAt(). - ClearDeletedAt(). + SetOrClearDeletedAt(convert.SafeToUTC(in.DeletedAt)). ClearSentToCustomerAt(). ClearQuantitySnapshotedAt(). // Totals @@ -139,9 +140,16 @@ func (a *adapter) UpdateGatheringInvoice(ctx context.Context, in billing.Gatheri updateQuery = updateQuery.ClearCollectionAt() } - updateQuery = updateQuery. - SetPeriodStart(in.ServicePeriod.From.In(time.UTC)). - SetPeriodEnd(in.ServicePeriod.To.In(time.UTC)) + // Clear period when the invoice is soft-deleted + if in.DeletedAt != nil { + updateQuery = updateQuery. + ClearPeriodStart(). + ClearPeriodEnd() + } else { + updateQuery = updateQuery. + SetPeriodStart(in.ServicePeriod.From.In(time.UTC)). + SetPeriodEnd(in.ServicePeriod.To.In(time.UTC)) + } // Supplier updateQuery = updateQuery. @@ -215,6 +223,9 @@ func (a *adapter) ListGatheringInvoices(ctx context.Context, input billing.ListG if input.Expand.Has(billing.GatheringInvoiceExpandLines) { query = query.WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) { + if !input.Expand.Has(billing.GatheringInvoiceExpandDeletedLines) { + q = q.Where(billinginvoiceline.DeletedAtIsNil()) + } q.WithUsageBasedLine() }) } @@ -337,6 +348,9 @@ func (a *adapter) GetGatheringInvoiceById(ctx context.Context, input billing.Get if input.Expand.Has(billing.GatheringInvoiceExpandLines) { query = query.WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) { + if !input.Expand.Has(billing.GatheringInvoiceExpandDeletedLines) { + q = q.Where(billinginvoiceline.DeletedAtIsNil()) + } q.WithUsageBasedLine() }) } diff --git a/openmeter/billing/adapter/gatheringlines.go b/openmeter/billing/adapter/gatheringlines.go index 84a4eb8597..4d0a3825d7 100644 --- a/openmeter/billing/adapter/gatheringlines.go +++ b/openmeter/billing/adapter/gatheringlines.go @@ -12,16 +12,94 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceusagebasedlineconfig" "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/convert" "github.com/openmeterio/openmeter/pkg/entitydiff" + "github.com/openmeterio/openmeter/pkg/framework/entutils" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/slicesx" "github.com/openmeterio/openmeter/pkg/timeutil" ) +func (a *adapter) HardDeleteGatheringInvoiceLines(ctx context.Context, invoiceID billing.InvoiceID, lineIDs []string) error { + if err := invoiceID.Validate(); err != nil { + return fmt.Errorf("validating invoice ID: %w", err) + } + + if len(lineIDs) == 0 { + return nil + } + + return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { + // Let's validate the delete + invoiceHeader, err := tx.db.BillingInvoice.Query(). + Select(billinginvoice.FieldStatus, billinginvoice.FieldNamespace, billinginvoice.FieldCurrency). + Where(billinginvoice.ID(invoiceID.ID)). + Where(billinginvoice.Namespace(invoiceID.Namespace)). + Only(ctx) + if err != nil { + return err + } + + if invoiceHeader.Status != billing.StandardInvoiceStatusGathering { + return fmt.Errorf("invoice is not a gathering invoice [id=%s, namespace=%s, currency=%s]", invoiceID.ID, invoiceID.Namespace, invoiceHeader.Currency) + } + + // Let's determine the usage based line configs to delete + existingLines, err := tx.db.BillingInvoiceLine.Query(). + Where(billinginvoiceline.InvoiceID(invoiceID.ID)). + Where(billinginvoiceline.Namespace(invoiceID.Namespace)). + Where(billinginvoiceline.IDIn(lineIDs...)). + WithUsageBasedLine(). + All(ctx) + if err != nil { + return err + } + + usageBasedLineConfigIDs, err := slicesx.MapWithErr(existingLines, func(line *db.BillingInvoiceLine) (string, error) { + if line.Edges.UsageBasedLine == nil { + return "", fmt.Errorf("usage based line is missing [line_id=%s]", line.ID) + } + + return line.Edges.UsageBasedLine.ID, nil + }) + if err != nil { + return err + } + + nrDeleted, err := tx.db.BillingInvoiceLine.Delete(). + Where(billinginvoiceline.InvoiceID(invoiceID.ID)). + Where(billinginvoiceline.Namespace(invoiceID.Namespace)). + Where(billinginvoiceline.IDIn(lineIDs...)). + Exec(ctx) + if err != nil { + return err + } + + if nrDeleted != len(lineIDs) { + // Note: this causes a rollback of the transaction + return fmt.Errorf("failed to hard delete all gathering invoice lines [deleted=%d, linesToDelete=%d]", nrDeleted, len(lineIDs)) + } + + nrDeleted, err = tx.db.BillingInvoiceUsageBasedLineConfig.Delete(). + Where(billinginvoiceusagebasedlineconfig.IDIn(usageBasedLineConfigIDs...)). + Where(billinginvoiceusagebasedlineconfig.Namespace(invoiceID.Namespace)). + Exec(ctx) + if err != nil { + return err + } + + if nrDeleted != len(usageBasedLineConfigIDs) { + return fmt.Errorf("failed to hard delete all usage based line configs [deleted=%d, configsToDelete=%d]", nrDeleted, len(usageBasedLineConfigIDs)) + } + + return nil + }) +} + type gatheringLineDiff struct { Line entitydiff.Diff[*billing.GatheringLine] } diff --git a/openmeter/billing/adapter/invoicelinesplitgroup.go b/openmeter/billing/adapter/invoicelinesplitgroup.go index bdf711bd6f..a35c97413e 100644 --- a/openmeter/billing/adapter/invoicelinesplitgroup.go +++ b/openmeter/billing/adapter/invoicelinesplitgroup.go @@ -16,6 +16,8 @@ import ( "github.com/openmeterio/openmeter/pkg/timeutil" ) +var _ billing.InvoiceSplitLineGroupAdapter = (*adapter)(nil) + func (a *adapter) CreateSplitLineGroup(ctx context.Context, input billing.CreateSplitLineGroupAdapterInput) (billing.SplitLineGroup, error) { if err := input.Validate(); err != nil { return billing.SplitLineGroup{}, billing.ValidationError{ @@ -295,3 +297,30 @@ func (a *adapter) fetchAllSplitLineGroups(ctx context.Context, namespace string, return a.mapSplitLineHierarchyFromDB(ctx, dbSplitLineGroup) }) } + +func (a *adapter) GetSplitLineGroupHeaders(ctx context.Context, input billing.GetSplitLineGroupHeadersInput) (billing.SplitLineGroupHeaders, error) { + if err := input.Validate(); err != nil { + return billing.SplitLineGroupHeaders{}, billing.ValidationError{ + Err: err, + } + } + + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.SplitLineGroupHeaders, error) { + dbSplitLineGroups, err := tx.db.BillingInvoiceSplitLineGroup.Query(). + Where(billinginvoicesplitlinegroup.Namespace(input.Namespace)). + Where(billinginvoicesplitlinegroup.IDIn(input.SplitLineGroupIDs...)). + All(ctx) + if err != nil { + return billing.SplitLineGroupHeaders{}, err + } + + splitLineGroups, err := slicesx.MapWithErr(dbSplitLineGroups, func(dbSplitLineGroup *db.BillingInvoiceSplitLineGroup) (billing.SplitLineGroup, error) { + return a.mapSplitLineGroupFromDB(dbSplitLineGroup) + }) + if err != nil { + return billing.SplitLineGroupHeaders{}, err + } + + return splitLineGroups, nil + }) +} diff --git a/openmeter/billing/gatheringinvoice.go b/openmeter/billing/gatheringinvoice.go index 6947599f90..20bebce36e 100644 --- a/openmeter/billing/gatheringinvoice.go +++ b/openmeter/billing/gatheringinvoice.go @@ -147,11 +147,13 @@ func (e GatheringInvoiceExpand) Validate() error { const ( GatheringInvoiceExpandLines GatheringInvoiceExpand = "lines" + GatheringInvoiceExpandDeletedLines GatheringInvoiceExpand = "deletedLines" GatheringInvoiceExpandAvailableActions GatheringInvoiceExpand = "availableActions" ) var GatheringInvoiceExpandValues = []GatheringInvoiceExpand{ GatheringInvoiceExpandLines, + GatheringInvoiceExpandDeletedLines, GatheringInvoiceExpandAvailableActions, } @@ -180,6 +182,18 @@ type GatheringInvoiceAvailableActions struct { type GatheringLines []GatheringLine +func (l GatheringLines) Validate() error { + return errors.Join( + lo.Map(l, func(l GatheringLine, _ int) error { + err := l.Validate() + if err != nil { + return fmt.Errorf("line[%s]: %w", l.ID, err) + } + return nil + })..., + ) +} + type GatheringInvoiceLines struct { mo.Option[GatheringLines] } @@ -189,15 +203,7 @@ func (l GatheringInvoiceLines) Validate() error { return nil } - return errors.Join( - lo.Map(l.OrEmpty(), func(l GatheringLine, _ int) error { - err := l.Validate() - if err != nil { - return fmt.Errorf("line[%s]: %w", l.ID, err) - } - return nil - })..., - ) + return l.OrEmpty().Validate() } func (l *GatheringInvoiceLines) Sort() { @@ -246,6 +252,54 @@ func (l *GatheringInvoiceLines) Append(lines ...GatheringLine) { l.Option = mo.Some(append(l.OrEmpty(), lines...)) } +func (l GatheringInvoiceLines) GetReferencedFeatureKeys() ([]string, error) { + if l.IsAbsent() { + return nil, nil + } + + keys := make([]string, 0, len(l.OrEmpty())) + for _, line := range l.OrEmpty() { + if line.FeatureKey == "" { + continue + } + + keys = append(keys, line.FeatureKey) + } + + return lo.Uniq(keys), nil +} + +func (l GatheringInvoiceLines) GetByID(id string) (GatheringLine, bool) { + if l.IsAbsent() { + return GatheringLine{}, false + } + + lines := l.OrEmpty() + for _, line := range lines { + if line.ID == id { + return line, true + } + } + + return GatheringLine{}, false +} + +func (l *GatheringInvoiceLines) SetByID(line GatheringLine) error { + if l.IsAbsent() { + return fmt.Errorf("lines are absent") + } + + lines := l.OrEmpty() + for i := range lines { + if lines[i].ID == line.ID { + lines[i] = line + return nil + } + } + + return fmt.Errorf("line[%s]: line not found", line.ID) +} + func NewGatheringInvoiceLines(children []GatheringLine) GatheringInvoiceLines { return GatheringInvoiceLines{ Option: mo.Some(GatheringLines(children)), @@ -371,6 +425,30 @@ func (i GatheringLineBase) Clone() (GatheringLineBase, error) { return out, nil } +func (i GatheringLineBase) GetFeatureKey() string { + return i.FeatureKey +} + +func (i GatheringLineBase) GetServicePeriod() timeutil.ClosedPeriod { + return i.ServicePeriod +} + +func (i GatheringLineBase) GetPrice() *productcatalog.Price { + return &i.Price +} + +func (i GatheringLineBase) GetID() string { + return i.ID +} + +func (i GatheringLineBase) GetInvoiceAt() time.Time { + return i.InvoiceAt +} + +func (i GatheringLineBase) GetSplitLineGroupID() *string { + return i.SplitLineGroupID +} + // TODO: rename to GatheringLine type GatheringLine struct { GatheringLineBase `json:",inline"` @@ -390,6 +468,26 @@ func (g GatheringLine) Clone() (GatheringLine, error) { }, nil } +func (i GatheringLine) CloneForCreate(edits ...func(*GatheringLine)) (GatheringLine, error) { + clone, err := i.Clone() + if err != nil { + return GatheringLine{}, fmt.Errorf("cloning line: %w", err) + } + + clone.ID = "" + clone.UBPConfigID = "" + clone.CreatedAt = time.Time{} + clone.UpdatedAt = time.Time{} + clone.DeletedAt = nil + clone.DBState = nil + + for _, edit := range edits { + edit(&clone) + } + + return clone, nil +} + func (g GatheringLine) WithoutDBState() (GatheringLine, error) { clone, err := g.Clone() if err != nil { @@ -516,8 +614,10 @@ type ListGatheringInvoicesInput struct { func (i ListGatheringInvoicesInput) Validate() error { var errs []error - if err := i.Page.Validate(); err != nil { - errs = append(errs, fmt.Errorf("page: %w", err)) + if !lo.IsEmpty(i.Page) { + if err := i.Page.Validate(); err != nil { + errs = append(errs, fmt.Errorf("page: %w", err)) + } } if len(i.Namespaces) == 0 { diff --git a/openmeter/billing/invoicelinesplitgroup.go b/openmeter/billing/invoicelinesplitgroup.go index d951845c45..0e3dcbeec5 100644 --- a/openmeter/billing/invoicelinesplitgroup.go +++ b/openmeter/billing/invoicelinesplitgroup.go @@ -344,3 +344,20 @@ func (i LineOrHierarchy) ServicePeriod() Period { return Period{} } + +type GetSplitLineGroupHeadersInput struct { + Namespace string + SplitLineGroupIDs []string +} + +type SplitLineGroupHeaders = []SplitLineGroup + +func (i GetSplitLineGroupHeadersInput) Validate() error { + var errs []error + + if i.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + + return errors.Join(errs...) +} diff --git a/openmeter/billing/service/featuremeter.go b/openmeter/billing/service/featuremeter.go index 6f6449776d..bb1b1e97a8 100644 --- a/openmeter/billing/service/featuremeter.go +++ b/openmeter/billing/service/featuremeter.go @@ -12,32 +12,25 @@ import ( "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" ) -func (s *Service) resolveFeatureMeters(ctx context.Context, lines billing.StandardLines) (billing.FeatureMeters, error) { - namespaces := lo.Uniq(lo.Map(lines, func(line *billing.StandardLine, _ int) string { - return line.Namespace - })) +type linesFeatureGetter interface { + GetReferencedFeatureKeys() ([]string, error) +} - if len(namespaces) != 1 { - return nil, fmt.Errorf("all lines must be in the same namespace") +func (s *Service) resolveFeatureMeters(ctx context.Context, namespace string, lines linesFeatureGetter) (billing.FeatureMeters, error) { + if namespace == "" { + return nil, fmt.Errorf("namespace is required") } - namespace := namespaces[0] + featureKeys, err := lines.GetReferencedFeatureKeys() + if err != nil { + return nil, fmt.Errorf("getting referenced feature keys: %w", err) + } - featuresToResolve := lo.Uniq( - lo.Filter( - lo.Map(lines, func(line *billing.StandardLine, _ int) string { - // Never happens, as StandardLine is always a usage based line, but until we migrate to a new table let's keep it here - if line.UsageBased == nil { - return "" - } + if len(featureKeys) == 0 { + return billing.FeatureMeters{}, nil + } - return line.UsageBased.FeatureKey - }), - func(featureKey string, _ int) bool { - return featureKey != "" - }, - ), - ) + featuresToResolve := lo.Uniq(featureKeys) // Let's resolve the features features, err := s.featureService.ListFeatures(ctx, feature.ListFeaturesParams{ diff --git a/openmeter/billing/service/gatheringinvoicependinglines.go b/openmeter/billing/service/gatheringinvoicependinglines.go index d0e239e1ff..81424bb3cf 100644 --- a/openmeter/billing/service/gatheringinvoicependinglines.go +++ b/openmeter/billing/service/gatheringinvoicependinglines.go @@ -14,9 +14,13 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/service/lineservice" "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/timeutil" ) // InvoicePendingLines invoices the pending lines for the customer. @@ -76,12 +80,11 @@ func (s *Service) InvoicePendingLines(ctx context.Context, input billing.Invoice asOf := lo.FromPtrOr(input.AsOf, clock.Now()) // let's fetch the existing gathering invoices for the customer - existingGatheringInvoices, err := s.ListInvoices(ctx, billing.ListInvoicesInput{ - Namespaces: []string{input.Customer.Namespace}, - Customers: []string{input.Customer.ID}, - ExtendedStatuses: []billing.StandardInvoiceStatus{billing.StandardInvoiceStatusGathering}, - Expand: billing.InvoiceExpand{ - Lines: true, + existingGatheringInvoices, err := s.adapter.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ + Namespaces: []string{input.Customer.Namespace}, + Customers: []string{input.Customer.ID}, + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, }, }) if err != nil { @@ -97,7 +100,7 @@ func (s *Service) InvoicePendingLines(ctx context.Context, input billing.Invoice } } - invoicesByCurrency := lo.SliceToMap(existingGatheringInvoices.Items, func(i billing.StandardInvoice) (currencyx.Code, gatheringInvoiceWithFeatureMeters) { + invoicesByCurrency := lo.SliceToMap(existingGatheringInvoices.Items, func(i billing.GatheringInvoice) (currencyx.Code, gatheringInvoiceWithFeatureMeters) { return i.Currency, gatheringInvoiceWithFeatureMeters{ Invoice: i, } @@ -105,7 +108,7 @@ func (s *Service) InvoicePendingLines(ctx context.Context, input billing.Invoice // Let's resolve the feature meters for each gathering invoice line for downstream calculations. for currency, gatheringInvoiceWithCurrency := range invoicesByCurrency { - featureMeters, err := s.resolveFeatureMeters(ctx, invoicesByCurrency[currency].Invoice.Lines.OrEmpty()) + featureMeters, err := s.resolveFeatureMeters(ctx, input.Customer.Namespace, invoicesByCurrency[currency].Invoice.Lines) if err != nil { return nil, fmt.Errorf("resolving feature meters: %w", err) } @@ -183,12 +186,12 @@ func (s *Service) InvoicePendingLines(ctx context.Context, input billing.Invoice }) } -type gatheringLineWithBillablePeriod = lineservice.LineWithBillablePeriod[*billing.StandardLine] +type gatheringLineWithBillablePeriod = lineservice.LineWithBillablePeriod[billing.GatheringLine] type handleInvoicePendingLinesForCurrencyInput struct { Currency currencyx.Code Customer customer.Customer - GatheringInvoice billing.StandardInvoice + GatheringInvoice billing.GatheringInvoice FeatureMeters billing.FeatureMeters InScopeLines []gatheringLineWithBillablePeriod EffectiveBillingProfile billing.Profile @@ -236,7 +239,7 @@ func (s *Service) handleInvoicePendingLinesForCurrency(ctx context.Context, in h return nil, fmt.Errorf("lines to bill is nil") } - if len(prepareResults.LineIDsToBill) == 0 { + if len(prepareResults.LinesToBill) == 0 { return nil, billing.ValidationError{ Err: billing.ErrInvoiceCreateNoLines, } @@ -245,15 +248,6 @@ func (s *Service) handleInvoicePendingLinesForCurrency(ctx context.Context, in h gatheringInvoice = prepareResults.GatheringInvoice // Step 2: Let's create the standard invoice and move the lines to the new invoice. - linesToBill := lo.Filter(gatheringInvoice.Lines.OrEmpty(), func(line *billing.StandardLine, _ int) bool { - return slices.Contains(prepareResults.LineIDsToBill, line.ID) - }) - - if len(linesToBill) != len(prepareResults.LineIDsToBill) { - return nil, fmt.Errorf("lines to associate[%d] must contain the same number of lines as lines to bill[%d]", len(linesToBill), len(prepareResults.LineIDsToBill)) - } - - // Let's create the invoice and associate the lines to it // Invariant: // - new invoice: initial calculations are done and persisted to the database // - gathering invoice: lines that have been associated to the new invoice are removed from the gathering invoice @@ -262,25 +256,22 @@ func (s *Service) handleInvoicePendingLinesForCurrency(ctx context.Context, in h Currency: in.Currency, GatheringInvoice: gatheringInvoice, FeatureMeters: in.FeatureMeters, - Lines: linesToBill, + Lines: prepareResults.LinesToBill, EffectiveBillingProfile: in.EffectiveBillingProfile, }) if err != nil { return nil, fmt.Errorf("creating standard invoice and associating lines: %w", err) } - gatheringInvoice = createStandardInvoiceResult.GatheringInvoice - - _, err = s.updateGatheringInvoice(ctx, gatheringInvoice) - if err != nil { - return nil, fmt.Errorf("updating gathering invoice: %w", err) + if createStandardInvoiceResult == nil { + return nil, fmt.Errorf("created invoice is nil") } - return &createStandardInvoiceResult.CreatedInvoice, nil + return createStandardInvoiceResult, nil } type gatheringInvoiceWithFeatureMeters struct { - Invoice billing.StandardInvoice + Invoice billing.GatheringInvoice FeatureMeters billing.FeatureMeters } @@ -304,7 +295,7 @@ func (s *Service) gatherInScopeLines(ctx context.Context, in gatherInScopeLineIn for currency, invoice := range in.GatheringInvoicesByCurrency { linesWithResolvedPeriods, err := lineservice.GetLinesWithBillablePeriods( - lineservice.GetLinesWithBillablePeriodsInput[*billing.StandardLine]{ + lineservice.GetLinesWithBillablePeriodsInput[billing.GatheringLine]{ AsOf: in.AsOf, ProgressiveBilling: in.ProgressiveBilling, Lines: invoice.Invoice.Lines.OrEmpty(), @@ -323,7 +314,7 @@ func (s *Service) gatherInScopeLines(ctx context.Context, in gatherInScopeLineIn // 2. the line does not need to be split but it's invoiceAt is after the line's period end, when the line is technically billable, but // from the user's perspective as they are not requesting progressive billing we should not include it on the invoice. - linesWithResolvedPeriods = lo.Filter(linesWithResolvedPeriods, func(line lineservice.LineWithBillablePeriod[*billing.StandardLine], _ int) bool { + linesWithResolvedPeriods = lo.Filter(linesWithResolvedPeriods, func(line lineservice.LineWithBillablePeriod[billing.GatheringLine], _ int) bool { invoiceAtTruncated := line.Line.InvoiceAt.Truncate(streaming.MinimumWindowSizeDuration) return invoiceAtTruncated.Before(asOfTruncated) || invoiceAtTruncated.Equal(asOfTruncated) @@ -410,10 +401,16 @@ func (s *Service) hasInvoicableLines(ctx context.Context, in hasInvoicableLinesI return false, err } + // TODO: Remove once we have the Union type for generic invoice queries + gatheringInvoice, err := convertStandardInvoiceToGatheringInvoice(in.Invoice) + if err != nil { + return false, fmt.Errorf("converting standard invoice to gathering invoice: %w", err) + } + inScopeLines, err := s.gatherInScopeLines(ctx, gatherInScopeLineInput{ GatheringInvoicesByCurrency: map[currencyx.Code]gatheringInvoiceWithFeatureMeters{ - in.Invoice.Currency: { - Invoice: in.Invoice, + gatheringInvoice.Currency: { + Invoice: gatheringInvoice, FeatureMeters: in.FeatureMeters, }, }, @@ -432,8 +429,84 @@ func (s *Service) hasInvoicableLines(ctx context.Context, in hasInvoicableLinesI return len(res) > 0, nil } +func convertStandardInvoiceToGatheringInvoice(invoice billing.StandardInvoice) (billing.GatheringInvoice, error) { + // TODO: Remove once we have the Union type for generic invoice queries + + lines := billing.GatheringInvoiceLines{} + if invoice.Lines.IsPresent() { + gatheringLines, err := slicesx.MapWithErr(invoice.Lines.OrEmpty(), func(l *billing.StandardLine) (billing.GatheringLine, error) { + if l.UsageBased == nil { + return billing.GatheringLine{}, fmt.Errorf("usage based line is required") + } + + if l.UsageBased.Price == nil { + return billing.GatheringLine{}, fmt.Errorf("usage based line price is required") + } + + return billing.GatheringLine{ + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: l.ManagedResource, + Metadata: l.Metadata, + Annotations: l.Annotations, + ManagedBy: l.ManagedBy, + InvoiceID: l.InvoiceID, + Currency: l.Currency, + ServicePeriod: timeutil.ClosedPeriod{ + From: l.Period.Start, + To: l.Period.End, + }, + InvoiceAt: l.InvoiceAt, + Price: lo.FromPtr(l.UsageBased.Price), + FeatureKey: l.UsageBased.FeatureKey, + TaxConfig: l.TaxConfig, + RateCardDiscounts: l.RateCardDiscounts, + ChildUniqueReferenceID: l.ChildUniqueReferenceID, + Subscription: l.Subscription, + + UBPConfigID: l.UsageBased.ConfigID, + SplitLineGroupID: l.SplitLineGroupID, + }, + }, nil + }) + if err != nil { + return billing.GatheringInvoice{}, fmt.Errorf("mapping lines: %w", err) + } + + lines = billing.NewGatheringInvoiceLines(gatheringLines) + } + + return billing.GatheringInvoice{ + GatheringInvoiceBase: billing.GatheringInvoiceBase{ + ManagedResource: models.ManagedResource{ + NamespacedModel: models.NamespacedModel{ + Namespace: invoice.Namespace, + }, + ManagedModel: models.ManagedModel{ + CreatedAt: invoice.CreatedAt, + UpdatedAt: invoice.UpdatedAt, + DeletedAt: invoice.DeletedAt, + }, + ID: invoice.ID, + Name: invoice.Number, + Description: invoice.Description, + }, + Metadata: invoice.Metadata, + Number: invoice.Number, + CustomerID: invoice.Customer.CustomerID, + Currency: invoice.Currency, + ServicePeriod: timeutil.ClosedPeriod{ + From: invoice.Period.Start, + To: invoice.Period.End, + }, + NextCollectionAt: lo.FromPtrOr(invoice.CollectionAt, clock.Now()), + SchemaLevel: invoice.SchemaLevel, + }, + Lines: lines, + }, nil +} + type prepareLinesToBillInput struct { - GatheringInvoice billing.StandardInvoice + GatheringInvoice billing.GatheringInvoice FeatureMeters billing.FeatureMeters InScopeLines []gatheringLineWithBillablePeriod } @@ -441,10 +514,6 @@ type prepareLinesToBillInput struct { func (i prepareLinesToBillInput) Validate() error { var errs []error - if i.GatheringInvoice.Status != billing.StandardInvoiceStatusGathering { - errs = append(errs, fmt.Errorf("gathering invoice is not in gathering status")) - } - if i.GatheringInvoice.Lines.IsAbsent() { errs = append(errs, fmt.Errorf("gathering invoice must have lines expanded")) } @@ -467,8 +536,8 @@ func (i prepareLinesToBillInput) Validate() error { } type prepareLinesToBillResult struct { - LineIDsToBill []string - GatheringInvoice billing.StandardInvoice + LinesToBill billing.GatheringLines + GatheringInvoice billing.GatheringInvoice } // prepareLinesToBill prepares the lines to be billed from the gathering invoice, if needed @@ -480,14 +549,14 @@ func (s *Service) prepareLinesToBill(ctx context.Context, input prepareLinesToBi gatheringInvoice := input.GatheringInvoice - invoiceLines := make([]*billing.StandardLine, 0, len(input.InScopeLines)) + invoiceLines := make([]billing.GatheringLine, 0, len(input.InScopeLines)) wasSplit := false for _, line := range input.InScopeLines { - if !line.Line.Period.ToClosedPeriod().Equal(line.BillablePeriod) { + if !line.Line.ServicePeriod.Equal(line.BillablePeriod) { // We need to split the line into multiple lines - if !line.Line.Period.Start.Equal(line.BillablePeriod.From) { - return nil, fmt.Errorf("line[%s]: line period start[%s] is not equal to billable period start[%s]", line.Line.ID, line.Line.Period.Start, line.BillablePeriod.From) + if !line.Line.ServicePeriod.From.Equal(line.BillablePeriod.From) { + return nil, fmt.Errorf("line[%s]: line period start[%s] is not equal to billable period start[%s]", line.Line.ID, line.Line.ServicePeriod.From, line.BillablePeriod.From) } splitLine, err := s.splitGatheringInvoiceLine(ctx, splitGatheringInvoiceLineInput{ @@ -500,12 +569,14 @@ func (s *Service) prepareLinesToBill(ctx context.Context, input prepareLinesToBi return nil, fmt.Errorf("line[%s]: splitting line: %w", line.Line.ID, err) } - if splitLine.PreSplitAtLine == nil || splitLine.PreSplitAtLine.DeletedAt != nil { - if splitLine.PreSplitAtLine != nil { - wasSplit = true - } + if splitLine.PreSplitAtLine.DeletedAt != nil { + wasSplit = true - s.logger.WarnContext(ctx, "pre split line is nil, we are not creating empty lines", "line", line.Line.ID, "period_start", line.Line.Period.Start, "period_end", line.Line.Period.End) + s.logger.WarnContext(ctx, "pre split line is nil, skipping collection", + "line", line.Line.ID, + "original_period_start", line.Line.ServicePeriod.From, + "original_period_end", line.Line.ServicePeriod.To, + "split_at", line.BillablePeriod.To) continue } @@ -519,22 +590,32 @@ func (s *Service) prepareLinesToBill(ctx context.Context, input prepareLinesToBi if wasSplit { // Let's update the gathering invoice to contain the new lines that we have split - updatedInvoice, err := s.adapter.UpdateInvoice(ctx, gatheringInvoice) + err := s.adapter.UpdateGatheringInvoice(ctx, gatheringInvoice) if err != nil { return nil, fmt.Errorf("updating gathering invoice: %w", err) } + updatedInvoice, err := s.adapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: gatheringInvoice.InvoiceID(), + Expand: billing.GatheringInvoiceExpands{ + billing.GatheringInvoiceExpandLines, + }, + }) + if err != nil { + return nil, fmt.Errorf("getting gathering invoice: %w", err) + } + gatheringInvoice = updatedInvoice } return &prepareLinesToBillResult{ - LineIDsToBill: lo.Map(invoiceLines, func(l *billing.StandardLine, _ int) string { return l.ID }), + LinesToBill: invoiceLines, GatheringInvoice: gatheringInvoice, }, nil } type splitGatheringInvoiceLineInput struct { - GatheringInvoice billing.StandardInvoice + GatheringInvoice billing.GatheringInvoice FeatureMeters billing.FeatureMeters LineID string SplitAt time.Time @@ -543,10 +624,6 @@ type splitGatheringInvoiceLineInput struct { func (i splitGatheringInvoiceLineInput) Validate() error { var errs []error - if i.GatheringInvoice.Status != billing.StandardInvoiceStatusGathering { - errs = append(errs, fmt.Errorf("gathering invoice is not in gathering status")) - } - if i.LineID == "" { errs = append(errs, fmt.Errorf("line ID is required")) } @@ -567,9 +644,9 @@ func (i splitGatheringInvoiceLineInput) Validate() error { } type splitGatheringInvoiceLineResult struct { - PreSplitAtLine *billing.StandardLine - PostSplitAtLine *billing.StandardLine - GatheringInvoice billing.StandardInvoice + PreSplitAtLine billing.GatheringLine + PostSplitAtLine billing.GatheringLine + GatheringInvoice billing.GatheringInvoice } // splitGatheringInvoiceLine splits a gathering invoice line into two lines, one will be from the @@ -586,11 +663,12 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri gatheringInvoice := in.GatheringInvoice - line := gatheringInvoice.Lines.GetByID(in.LineID) - if line == nil { + line, found := gatheringInvoice.Lines.GetByID(in.LineID) + if !found { return res, fmt.Errorf("line[%s]: line not found in gathering invoice", in.LineID) } - if !line.Period.Contains(in.SplitAt) { + + if !line.ServicePeriod.Contains(in.SplitAt) { return res, fmt.Errorf("line[%s]: splitAt is not within the line period", line.ID) } @@ -603,7 +681,7 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri Name: line.Name, Description: line.Description, - ServicePeriod: line.Period, + ServicePeriod: billing.Period{Start: line.ServicePeriod.From, End: line.ServicePeriod.To}, RatecardDiscounts: line.RateCardDiscounts, TaxConfig: line.TaxConfig, }, @@ -612,8 +690,8 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri Currency: line.Currency, - Price: line.UsageBased.Price, - FeatureKey: lo.EmptyableToPtr(line.UsageBased.FeatureKey), + Price: lo.ToPtr(line.Price), + FeatureKey: lo.EmptyableToPtr(line.FeatureKey), Subscription: line.Subscription, }) @@ -624,15 +702,21 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri splitLineGroupID = splitLineGroup.ID } else { splitLineGroupID = lo.FromPtr(line.SplitLineGroupID) + if splitLineGroupID == "" { + return res, fmt.Errorf("split line group id is empty") + } } - // We have alredy split the line once, we just need to create a new line and update the existing line - postSplitAtLine := line.CloneWithoutDependencies(func(l *billing.StandardLine) { - l.Period.Start = in.SplitAt + // We have already split the line once, we just need to create a new line and update the existing line + postSplitAtLine, err := line.CloneForCreate(func(l *billing.GatheringLine) { + l.ServicePeriod.From = in.SplitAt l.SplitLineGroupID = lo.ToPtr(splitLineGroupID) l.ChildUniqueReferenceID = nil }) + if err != nil { + return res, fmt.Errorf("cloning post split line: %w", err) + } postSplitAtLineEmpty, err := lineservice.IsPeriodEmptyConsideringTruncations(postSplitAtLine) if err != nil { @@ -648,7 +732,7 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri } // Let's update the original line to only contain the period up to the splitAt time - line.Period.End = in.SplitAt + line.ServicePeriod.To = in.SplitAt line.InvoiceAt = in.SplitAt line.SplitLineGroupID = lo.ToPtr(splitLineGroupID) line.ChildUniqueReferenceID = nil @@ -669,6 +753,10 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri } } + if err := gatheringInvoice.Lines.SetByID(preSplitAtLine); err != nil { + return res, fmt.Errorf("setting pre split line: %w", err) + } + return splitGatheringInvoiceLineResult{ PreSplitAtLine: preSplitAtLine, PostSplitAtLine: postSplitAtLine, @@ -679,9 +767,9 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri type createStandardInvoiceFromGatheringLinesInput struct { Customer customer.Customer Currency currencyx.Code - GatheringInvoice billing.StandardInvoice + GatheringInvoice billing.GatheringInvoice FeatureMeters billing.FeatureMeters - Lines billing.StandardLines + Lines billing.GatheringLines EffectiveBillingProfile billing.Profile } @@ -717,16 +805,11 @@ func (in createStandardInvoiceFromGatheringLinesInput) Validate() error { return errors.Join(errs...) } -type createStandardInvoiceFromGatheringLinesResult struct { - CreatedInvoice billing.StandardInvoice - GatheringInvoice billing.StandardInvoice -} - // createStandardInvoiceFromGatheringLines creates a standard invoice from the gathering invoice lines. // Invariant: // - the standard invoice is in draft.created state, and is calculated and persisted to the database -// - the gathering invoice's lines are removed, but not persisted to the database -func (s *Service) createStandardInvoiceFromGatheringLines(ctx context.Context, in createStandardInvoiceFromGatheringLinesInput) (*createStandardInvoiceFromGatheringLinesResult, error) { +// - the gathering invoice's lines are deleted, and persisted to the database +func (s *Service) createStandardInvoiceFromGatheringLines(ctx context.Context, in createStandardInvoiceFromGatheringLinesInput) (*billing.StandardInvoice, error) { if err := in.Validate(); err != nil { return nil, fmt.Errorf("validating input: %w", err) } @@ -767,24 +850,40 @@ func (s *Service) createStandardInvoiceFromGatheringLines(ctx context.Context, i return nil, fmt.Errorf("error resolving workflow apps for invoice [%s]: %w", invoiceID, err) } - moveResults, err := s.moveLinesToInvoice(ctx, moveLinesToInvoiceInput{ - SourceGatheringInvoice: in.GatheringInvoice, - TargetInvoice: invoice, - FeatureMeters: in.FeatureMeters, - LineIDsToMove: lo.Map(in.Lines, func(l *billing.StandardLine, _ int) string { return l.ID }), + convertResults, err := s.convertGatheringLinesToStandardLines(ctx, convertGatheringLinesToStandardLinesInput{ + TargetInvoice: invoice, + FeatureMeters: in.FeatureMeters, + GatheringLinesToConvert: in.Lines, }) if err != nil { return nil, fmt.Errorf("moving lines to invoice: %w", err) } - // Let's create the sub lines as per the meters (we are not setting the QuantitySnapshotedAt field just now, to signal that this is not the final snapshot) - if err := s.snapshotLineQuantitiesInParallel(ctx, invoice.Customer, moveResults.LinesAssociated, in.FeatureMeters); err != nil { + invoice = convertResults.TargetInvoice + + // Let's first update the gathering invoice to make sure deleted lines are synced, as the standard invoice will have expanded split line hierarchies + // and we need to make sure that gathering invoice lines that are already yielded the standard invoice lines are excluded from the split line hierarchy. + // + // Note: this is a hack, on the long term we need to have a Charge type that encapsulates all of this logic. + err = s.removeLinesFromGatheringInvoice(ctx, in.GatheringInvoice, in.Lines) + if err != nil { + return nil, fmt.Errorf("updating gathering invoice: %w", err) + } + + // Prerequisite: we should have the split line group headers expanded so that snapshotting can determine if the preLine + // queries are needed. + if err := s.resolveSplitLineGroupHeadersForLines(ctx, in.Customer.Namespace, convertResults.LinesAssociated); err != nil { + return nil, fmt.Errorf("resolving split line group headers for lines: %w", err) + } + + // Let's snapshot the quantities for the lines that we have converted to standard lines so that calculations can be performed + if err := s.snapshotLineQuantitiesInParallel(ctx, invoice.Customer, convertResults.LinesAssociated, in.FeatureMeters); err != nil { return nil, fmt.Errorf("snapshotting lines: %w", err) } - // Let's persist the target invoice as the state machine always reloads the invoice from the database to make + // Let's persist the snapshotted values to the database as the state machine always reloads the invoice from the database to make // sure we don't have any manual modifications inside the invoice structure. - _, err = s.updateInvoice(ctx, moveResults.TargetInvoice) + invoice, err = s.updateInvoice(ctx, invoice) if err != nil { return nil, fmt.Errorf("updating target invoice: %w", err) } @@ -829,28 +928,16 @@ func (s *Service) createStandardInvoiceFromGatheringLines(ctx context.Context, i return nil, fmt.Errorf("activating invoice: %w", err) } - return &createStandardInvoiceFromGatheringLinesResult{ - CreatedInvoice: invoice, - GatheringInvoice: moveResults.GatheringInvoice, - }, nil + return &invoice, nil } -type moveLinesToInvoiceInput struct { - SourceGatheringInvoice billing.StandardInvoice - FeatureMeters billing.FeatureMeters - TargetInvoice billing.StandardInvoice - LineIDsToMove []string +type convertGatheringLinesToStandardLinesInput struct { + FeatureMeters billing.FeatureMeters + TargetInvoice billing.StandardInvoice + GatheringLinesToConvert billing.GatheringLines } -func (in moveLinesToInvoiceInput) Validate() error { - if err := in.SourceGatheringInvoice.Validate(); err != nil { - return fmt.Errorf("source gathering invoice: %w", err) - } - - if in.SourceGatheringInvoice.Status != billing.StandardInvoiceStatusGathering { - return fmt.Errorf("source gathering invoice must be in gathering status") - } - +func (in convertGatheringLinesToStandardLinesInput) Validate() error { if err := in.TargetInvoice.Validate(); err != nil { return fmt.Errorf("target invoice: %w", err) } @@ -859,20 +946,26 @@ func (in moveLinesToInvoiceInput) Validate() error { return fmt.Errorf("target invoice must be in draft created status") } - if len(in.LineIDsToMove) == 0 { + if len(in.GatheringLinesToConvert) == 0 { return fmt.Errorf("line IDs to move is required") } - if in.TargetInvoice.Currency != in.SourceGatheringInvoice.Currency { - return fmt.Errorf("target invoice currency must be the same as source gathering invoice currency") - } - if in.TargetInvoice.ID == "" { return fmt.Errorf("target invoice ID is required") } - if in.TargetInvoice.Namespace != in.SourceGatheringInvoice.Namespace { - return fmt.Errorf("target invoice namespace must be the same as source gathering invoice namespace") + for _, line := range in.GatheringLinesToConvert { + if err := line.Validate(); err != nil { + return fmt.Errorf("validating gathering line: %w", err) + } + + if line.Currency != in.TargetInvoice.Currency { + return fmt.Errorf("gathering line[%s]: currency[%s] is not equal to target invoice currency[%s]", line.ID, line.Currency, in.TargetInvoice.Currency) + } + + if line.Namespace != in.TargetInvoice.Namespace { + return fmt.Errorf("gathering line[%s]: namespace[%s] is not equal to target invoice namespace[%s]", line.ID, line.Namespace, in.TargetInvoice.Namespace) + } } if in.FeatureMeters == nil { @@ -882,73 +975,127 @@ func (in moveLinesToInvoiceInput) Validate() error { return nil } -type moveLinesToInvoiceResult struct { - GatheringInvoice billing.StandardInvoice - TargetInvoice billing.StandardInvoice - LinesAssociated billing.StandardLines +type convertGatheringLinesToStandardLinesResult struct { + TargetInvoice billing.StandardInvoice + LinesAssociated billing.StandardLines } -// moveLinesToInvoice moves the lines from the source gathering invoice to the target invoice, invariants: -// - the source gathering invoice is updated by removing the lines that have been moved to the target invoice -// - the target invoice is updated by adding the lines that have been moved from the source gathering invoice -// - neither invoices are saved to the database, they are returned as is -func (s *Service) moveLinesToInvoice(ctx context.Context, in moveLinesToInvoiceInput) (*moveLinesToInvoiceResult, error) { +// convertGatheringLinesToStandardLines converts the gathering lines to standard lines and adds them to the target invoice. +// Invariants: +// - the target invoice is updated by adding the standard lines that have been converted from the gathering lines +// - no database changes are made +func (s *Service) convertGatheringLinesToStandardLines(ctx context.Context, in convertGatheringLinesToStandardLinesInput) (*convertGatheringLinesToStandardLinesResult, error) { if err := in.Validate(); err != nil { return nil, fmt.Errorf("validating input: %w", err) } - srcInvoice := in.SourceGatheringInvoice - dstInvoice := in.TargetInvoice - - // Let's find the lines to move from the source gathering invoice - linesToMove := lo.Filter(srcInvoice.Lines.OrEmpty(), func(line *billing.StandardLine, _ int) bool { - return slices.Contains(in.LineIDsToMove, line.ID) - }) + newStandardLines, err := slicesx.MapWithErr(in.GatheringLinesToConvert, func(gatheringLine billing.GatheringLine) (*billing.StandardLine, error) { + newStandardLine, err := convertGatheringLineToNewStandardLine(gatheringLine, in.TargetInvoice.ID) + if err != nil { + return nil, fmt.Errorf("converting gathering line to new standard line: %w", err) + } - for _, line := range linesToMove { - if line.Currency != dstInvoice.Currency { - return nil, fmt.Errorf("line[%s]: currency[%s] is not equal to target invoice currency[%s]", line.ID, line.Currency, dstInvoice.Currency) + if err := newStandardLine.Validate(); err != nil { + return nil, fmt.Errorf("validating new standard line: %w", err) } + + return newStandardLine, nil + }) + if err != nil { + return nil, fmt.Errorf("converting gathering lines to standard lines: %w", err) } - if len(linesToMove) != len(in.LineIDsToMove) { - return nil, fmt.Errorf("lines to move[%d] must contain the same number of lines as line IDs to move[%d]", len(linesToMove), len(in.LineIDsToMove)) + // Let's add the lines to the target invoice + in.TargetInvoice.Lines.Append(newStandardLines...) + + return &convertGatheringLinesToStandardLinesResult{ + TargetInvoice: in.TargetInvoice, + LinesAssociated: newStandardLines, + }, nil +} + +func convertGatheringLineToNewStandardLine(line billing.GatheringLine, invoiceID string) (*billing.StandardLine, error) { + clonedAnnotations, err := line.Annotations.Clone() + if err != nil { + return nil, fmt.Errorf("cloning annotations: %w", err) } - if err := linesToMove.Validate(); err != nil { - return nil, fmt.Errorf("validating lines to move: %w", err) + var taxConfig *productcatalog.TaxConfig + if line.TaxConfig != nil { + taxConfig = lo.ToPtr(line.TaxConfig.Clone()) } - // Let's set the invoice ID of the lines to move to the target invoice ID - for _, line := range linesToMove { - line.InvoiceID = dstInvoice.ID + var subscription *billing.SubscriptionReference + if line.Subscription != nil { + subscription = line.Subscription.Clone() } - // Let's add the lines to the target invoice - dstInvoice.Lines.Append(linesToMove...) + convertedLine := &billing.StandardLine{ + StandardLineBase: billing.StandardLineBase{ + ManagedResource: line.ManagedResource, + Metadata: line.Metadata.Clone(), + Annotations: clonedAnnotations, + ManagedBy: line.ManagedBy, + InvoiceID: invoiceID, + Currency: line.Currency, - // Let's remove the lines from the source gathering invoice - for _, line := range linesToMove { - if !srcInvoice.Lines.RemoveByID(line.ID) { - return nil, fmt.Errorf("line[%s] not found in source gathering invoice", line.ID) - } + Period: billing.Period{ + Start: line.ServicePeriod.From, + End: line.ServicePeriod.To, + }, + InvoiceAt: line.InvoiceAt, + + TaxConfig: taxConfig, + RateCardDiscounts: line.RateCardDiscounts.Clone(), + ChildUniqueReferenceID: line.ChildUniqueReferenceID, + Subscription: subscription, + SplitLineGroupID: line.SplitLineGroupID, + }, + UsageBased: &billing.UsageBasedLine{ + Price: lo.ToPtr(line.Price), + FeatureKey: line.FeatureKey, + }, + + DBState: nil, // We don't want to reuse the state from the gathering line (so let's make it explicit) } - return &moveLinesToInvoiceResult{ - GatheringInvoice: srcInvoice, - TargetInvoice: dstInvoice, - LinesAssociated: linesToMove, - }, nil + return convertedLine, nil } // updateGatheringInvoice updates the gathering invoice's state and if it contains no lines, it will be deleted. // Invariant: // - the invoice is recalculated // - the invoice is updated to the database -func (s *Service) updateGatheringInvoice(ctx context.Context, invoice billing.StandardInvoice) (billing.StandardInvoice, error) { +func (s *Service) removeLinesFromGatheringInvoice(ctx context.Context, invoice billing.GatheringInvoice, linesToRemove billing.GatheringLines) error { + lineIDsToRemove := lo.Map(linesToRemove, func(l billing.GatheringLine, _ int) string { return l.ID }) + + nrLinesRemoved := 0 + invoiceLinesWithoutRemovedLines := lo.Filter(invoice.Lines.OrEmpty(), func(l billing.GatheringLine, _ int) bool { + if slices.Contains(lineIDsToRemove, l.ID) { + nrLinesRemoved++ + return false + } + + return true + }) + + // This makes sure that all the IDs are present on the gathering invoice before invoking the hard delete. + if nrLinesRemoved != len(lineIDsToRemove) { + return fmt.Errorf("lines to remove[%d] must contain the same number of lines as line IDs to remove[%d]", nrLinesRemoved, len(lineIDsToRemove)) + } + + invoice.Lines = billing.NewGatheringInvoiceLines(invoiceLinesWithoutRemovedLines) + + // We need to hard delete the lines from the gathering invoice as now the standard lines are taking their place with the same IDs. + // If we would soft-delete the lines, all downstream services would assume that the line was deleted due to synchronization and + // would recreate it. + if err := s.adapter.HardDeleteGatheringInvoiceLines(ctx, invoice.InvoiceID(), lineIDsToRemove); err != nil { + return fmt.Errorf("hard deleting gathering invoice lines: %w", err) + } + // Let's update the invoice's state - if err := s.invoiceCalculator.CalculateLegacyGatheringInvoice(&invoice); err != nil { - return billing.StandardInvoice{}, fmt.Errorf("calculating gathering invoice: %w", err) + if err := s.invoiceCalculator.CalculateGatheringInvoice(&invoice); err != nil { + return fmt.Errorf("calculating gathering invoice: %w", err) } // The gathering invoice has no lines => delete the invoice @@ -956,10 +1103,54 @@ func (s *Service) updateGatheringInvoice(ctx context.Context, invoice billing.St invoice.DeletedAt = lo.ToPtr(clock.Now()) } - invoice, err := s.adapter.UpdateInvoice(ctx, invoice) + err := s.adapter.UpdateGatheringInvoice(ctx, invoice) + if err != nil { + return fmt.Errorf("updating gathering invoice: %w", err) + } + + return nil +} + +// resolveSplitLineGroupHeadersForLines resolves the split line group headers for the given lines. +// Warning: this will not fetch the lines from the database, so only use this if you are sure that +// only the headers are needed. (e.g. don't use it for invoice calculations or usage discounts +// will be off) +func (s *Service) resolveSplitLineGroupHeadersForLines(ctx context.Context, ns string, lines billing.StandardLines) error { + splitLineGroupIDs := lo.Uniq( + lo.Filter( + lo.Map(lines, func(line *billing.StandardLine, _ int) string { return lo.FromPtr(line.SplitLineGroupID) }), + func(id string, _ int) bool { return id != "" }, + ), + ) + + if len(splitLineGroupIDs) == 0 { + return nil + } + + splitLineGroupHeaders, err := s.adapter.GetSplitLineGroupHeaders(ctx, billing.GetSplitLineGroupHeadersInput{ + Namespace: ns, + SplitLineGroupIDs: splitLineGroupIDs, + }) if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("updating gathering invoice: %w", err) + return fmt.Errorf("getting split line group headers: %w", err) } - return invoice, nil + splitLineGroupHeadersByID := lo.SliceToMap(splitLineGroupHeaders, func(header billing.SplitLineGroup) (string, billing.SplitLineGroup) { return header.ID, header }) + + for idx := range lines { + if lines[idx].SplitLineGroupID == nil { + continue + } + + splitLineGroupHeader, ok := splitLineGroupHeadersByID[lo.FromPtr(lines[idx].SplitLineGroupID)] + if !ok { + return fmt.Errorf("split line group header not found for line[%s]: id[%s]", lines[idx].ID, lo.FromPtr(lines[idx].SplitLineGroupID)) + } + + lines[idx].SplitLineHierarchy = &billing.SplitLineHierarchy{ + Group: splitLineGroupHeader, + } + } + + return nil } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index d9c9d2b6ba..9158d578e9 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -160,7 +160,7 @@ func (s *Service) recalculateGatheringInvoice(ctx context.Context, in recalculat return invoice, fmt.Errorf("customer profile is nil") } - featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Lines.OrEmpty()) + featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Namespace, invoice.Lines) if err != nil { return invoice, fmt.Errorf("resolving feature meters: %w", err) } @@ -459,7 +459,7 @@ func (s *Service) executeTriggerOnInvoice(ctx context.Context, invoiceID billing } } - featureMeters, err := s.resolveFeatureMeters(ctx, sm.Invoice.Lines.OrEmpty()) + featureMeters, err := s.resolveFeatureMeters(ctx, sm.Invoice.Namespace, sm.Invoice.Lines) if err != nil { return fmt.Errorf("resolving feature meters: %w", err) } @@ -588,7 +588,7 @@ func (s *Service) UpdateInvoice(ctx context.Context, input billing.UpdateInvoice } } - featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Lines.OrEmpty()) + featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Namespace, invoice.Lines) if err != nil { return billing.StandardInvoice{}, fmt.Errorf("resolving feature meters: %w", err) } @@ -790,7 +790,7 @@ func (s Service) SimulateInvoice(ctx context.Context, input billing.SimulateInvo } } - featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Lines.OrEmpty()) + featureMeters, err := s.resolveFeatureMeters(ctx, input.Namespace, invoice.Lines) if err != nil { return billing.StandardInvoice{}, fmt.Errorf("resolving feature meters: %w", err) } diff --git a/openmeter/billing/service/quantitysnapshot.go b/openmeter/billing/service/quantitysnapshot.go index 7397fe7e65..7932c8d53e 100644 --- a/openmeter/billing/service/quantitysnapshot.go +++ b/openmeter/billing/service/quantitysnapshot.go @@ -24,7 +24,7 @@ func (s *Service) SnapshotLineQuantity(ctx context.Context, input billing.Snapsh } } - featureMeters, err := s.resolveFeatureMeters(ctx, billing.StandardLines{input.Line}) + featureMeters, err := s.resolveFeatureMeters(ctx, input.Invoice.Namespace, billing.StandardLines{input.Line}) if err != nil { return nil, fmt.Errorf("line[%s]: %w", input.Line.ID, err) } diff --git a/openmeter/billing/service/stdinvoicestate.go b/openmeter/billing/service/stdinvoicestate.go index 1aeb95c5a9..03639e9077 100644 --- a/openmeter/billing/service/stdinvoicestate.go +++ b/openmeter/billing/service/stdinvoicestate.go @@ -21,6 +21,7 @@ import ( type InvoiceStateMachine struct { Invoice billing.StandardInvoice + NeedsDBSave bool Calculator invoicecalc.Calculator StateMachine *stateless.StateMachine Logger *slog.Logger @@ -85,7 +86,10 @@ func allocateStateMachine() *InvoiceStateMachine { Permit(billing.TriggerFailed, billing.StandardInvoiceStatusDraftInvalid). Permit(billing.TriggerDelete, billing.StandardInvoiceStatusDeleteInProgress). Permit(billing.TriggerUpdated, billing.StandardInvoiceStatusDraftUpdating). - OnActive(out.calculateInvoice) + OnActive(allOf( + out.calculateInvoice, + out.requireDBSave, // so that any new detailed lines have IDs + )) stateMachine.Configure(billing.StandardInvoiceStatusDraftWaitingForCollection). Permit( @@ -106,6 +110,7 @@ func allocateStateMachine() *InvoiceStateMachine { allOf( out.snapshotQuantityAsNeeded, out.calculateInvoice, + out.requireDBSave, // so that any new detailed lines have IDs ), ) @@ -118,6 +123,7 @@ func allocateStateMachine() *InvoiceStateMachine { allOf( out.calculateInvoice, out.validateDraftInvoice, + out.requireDBSave, // Due to the calculation, new detailed lines may be added ), ) @@ -134,6 +140,7 @@ func allocateStateMachine() *InvoiceStateMachine { OnActive(allOf( out.calculateInvoice, out.validateDraftInvoice, + out.requireDBSave, // Due to the calculation, new detailed lines may be added )) stateMachine.Configure(billing.StandardInvoiceStatusDraftInvalid). @@ -158,7 +165,9 @@ func allocateStateMachine() *InvoiceStateMachine { ). Permit(billing.TriggerDelete, billing.StandardInvoiceStatusDeleteInProgress). Permit(billing.TriggerFailed, billing.StandardInvoiceStatusDraftSyncFailed). - OnActive(out.syncDraftInvoice) + OnActive(allOf( + out.syncDraftInvoice, + )) stateMachine.Configure(billing.StandardInvoiceStatusDraftSyncFailed). Permit(billing.TriggerRetry, billing.StandardInvoiceStatusDraftValidating). @@ -217,7 +226,9 @@ func allocateStateMachine() *InvoiceStateMachine { ). Permit(billing.TriggerFailed, billing.StandardInvoiceStatusIssuingSyncFailed). Permit(billing.TriggerDelete, billing.StandardInvoiceStatusDeleteInProgress). - OnActive(out.finalizeInvoice) + OnActive(allOf( + out.finalizeInvoice, + )) stateMachine.Configure(billing.StandardInvoiceStatusIssuingSyncFailed). Permit(billing.TriggerDelete, billing.StandardInvoiceStatusDeleteInProgress). @@ -285,6 +296,7 @@ func (s *Service) WithInvoiceStateMachine(ctx context.Context, invoice billing.S sm.FSNamespaceLockdown = s.fsNamespaceLockdown // Stateless doesn't store any state in the state machine, so it's fine to reuse the state machine itself sm.Invoice = invoice + sm.NeedsDBSave = false sm.Calculator = s.invoiceCalculator sm.Service = s @@ -295,6 +307,7 @@ func (s *Service) WithInvoiceStateMachine(ctx context.Context, invoice billing.S sm.Logger = nil sm.Publisher = nil sm.FSNamespaceLockdown = nil + sm.NeedsDBSave = false invoiceStateMachineCache.Put(sm) }() @@ -406,6 +419,12 @@ func (m *InvoiceStateMachine) calculateAvailableActionDetails(ctx context.Contex }, nil } +func (m *InvoiceStateMachine) requireDBSave(ctx context.Context) error { + m.NeedsDBSave = true + + return nil +} + func (m *InvoiceStateMachine) AdvanceUntilStateStable(ctx context.Context) error { for { preAdvanceState, err := billing.NewEventStandardInvoice(m.Invoice) @@ -431,6 +450,16 @@ func (m *InvoiceStateMachine) AdvanceUntilStateStable(ctx context.Context) error return fmt.Errorf("cannot transition to the next status [current_status=%s]: %w", m.Invoice.Status, err) } + if m.NeedsDBSave { + updatedInvoice, err := m.Service.updateInvoice(ctx, m.Invoice) + if err != nil { + return fmt.Errorf("error updating invoice: %w", err) + } + + m.NeedsDBSave = false + m.Invoice = updatedInvoice + } + // Let's emit an event for the transition event, err := billing.NewStandardInvoiceUpdatedEvent(m.Invoice, preAdvanceState) if err != nil { @@ -606,7 +635,7 @@ func (m *InvoiceStateMachine) validateDraftInvoice(ctx context.Context) error { } func (m *InvoiceStateMachine) calculateInvoice(ctx context.Context) error { - featureMeters, err := m.Service.resolveFeatureMeters(ctx, m.Invoice.Lines.OrEmpty()) + featureMeters, err := m.Service.resolveFeatureMeters(ctx, m.Invoice.Namespace, m.Invoice.Lines) if err != nil { return fmt.Errorf("resolving feature meters: %w", err) } @@ -744,7 +773,7 @@ func (m *InvoiceStateMachine) snapshotQuantityAsNeeded(ctx context.Context) erro return nil } - featureMeters, err := m.Service.resolveFeatureMeters(ctx, m.Invoice.Lines.OrEmpty()) + featureMeters, err := m.Service.resolveFeatureMeters(ctx, m.Invoice.Namespace, m.Invoice.Lines) if err != nil { return fmt.Errorf("resolving feature meters: %w", err) } diff --git a/openmeter/billing/stdinvoice.go b/openmeter/billing/stdinvoice.go index 36b1b788ba..b2ca61fd6f 100644 --- a/openmeter/billing/stdinvoice.go +++ b/openmeter/billing/stdinvoice.go @@ -525,6 +525,14 @@ func (c *StandardInvoiceLines) RemoveByID(id string) bool { return true } +func (c StandardInvoiceLines) GetReferencedFeatureKeys() ([]string, error) { + if c.IsAbsent() { + return nil, nil + } + + return c.OrEmpty().GetReferencedFeatureKeys() +} + type StandardInvoiceAvailableActions struct { Advance *StandardInvoiceAvailableActionDetails `json:"advance,omitempty"` Approve *StandardInvoiceAvailableActionDetails `json:"approve,omitempty"` diff --git a/openmeter/billing/stdinvoiceline.go b/openmeter/billing/stdinvoiceline.go index 8e34f26ee3..9f34d85562 100644 --- a/openmeter/billing/stdinvoiceline.go +++ b/openmeter/billing/stdinvoiceline.go @@ -158,6 +158,15 @@ func (i SubscriptionReference) Validate() error { return errors.Join(errs...) } +func (i SubscriptionReference) Clone() *SubscriptionReference { + return &SubscriptionReference{ + SubscriptionID: i.SubscriptionID, + PhaseID: i.PhaseID, + ItemID: i.ItemID, + BillingPeriod: i.BillingPeriod, + } +} + type LineExternalIDs struct { Invoicing string `json:"invoicing,omitempty"` } @@ -691,6 +700,24 @@ func (c *StandardLines) Sort() { } } +func (c StandardLines) GetReferencedFeatureKeys() ([]string, error) { + out := make([]string, 0, len(c)) + + for _, line := range c { + if line.UsageBased == nil { + return nil, fmt.Errorf("usage based line is required") + } + + if line.UsageBased.FeatureKey == "" { + continue + } + + out = append(out, line.UsageBased.FeatureKey) + } + + return lo.Uniq(out), nil +} + func (i StandardLine) SetDiscountExternalIDs(externalIDs map[string]string) []string { foundIDs := []string{} diff --git a/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go b/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go index adc831b224..bbdffdeb9c 100644 --- a/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go +++ b/openmeter/billing/worker/subscriptionsync/service/suitebase_test.go @@ -164,6 +164,11 @@ func (s *SuiteBase) expectNoGatheringInvoice(ctx context.Context, namespace stri }) s.NoError(err) + if len(invoices.Items) > 0 { + for _, invoice := range invoices.Items { + s.DebugDumpInvoice(fmt.Sprintf("unexpected gathering invoice[%s]", invoice.ID), invoice) + } + } s.Len(invoices.Items, 0) } diff --git a/openmeter/billing/worker/subscriptionsync/service/sync.go b/openmeter/billing/worker/subscriptionsync/service/sync.go index afa6e193c1..422bfce34c 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync.go @@ -498,7 +498,7 @@ func (s *Service) hierarchyHasAnnotation(hierarchy *billing.SplitLineHierarchy, // The correction can only happen if the last line the progressively billed group is in scope for the period correction for _, line := range hierarchy.Lines { - if line.Line.Period.End.Equal(servicePeriod.End) { + if line.Line.Period.End.Equal(servicePeriod.End) && line.Line.DeletedAt == nil { return s.lineHasAnnotation(line.Line, annotation) } } diff --git a/openmeter/billing/worker/subscriptionsync/service/sync_test.go b/openmeter/billing/worker/subscriptionsync/service/sync_test.go index a6b2bbd889..c002db1c50 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync_test.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync_test.go @@ -3688,6 +3688,7 @@ func (s *SubscriptionHandlerTestSuite) TestSplitLineManualDeleteSync() { draftInvoice := draftInvoices[0] s.DebugDumpInvoice("draft invoice", draftInvoice) + s.DebugDumpInvoice("gathering invoice - after invoicing", s.gatheringInvoice(ctx, s.Namespace, s.Customer.ID)) var updatedLine *billing.StandardLine editedInvoice, err := s.BillingService.UpdateInvoice(ctx, billing.UpdateInvoiceInput{ diff --git a/test/billing/adapter_test.go b/test/billing/adapter_test.go index 72550ff73e..29a1823a47 100644 --- a/test/billing/adapter_test.go +++ b/test/billing/adapter_test.go @@ -17,9 +17,14 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" billingadapter "github.com/openmeterio/openmeter/openmeter/billing/adapter" "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceusagebasedlineconfig" "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" + "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" ) type BillingAdapterTestSuite struct { @@ -617,3 +622,241 @@ func (s *BillingAdapterTestSuite) findAmountDiscountByDescription(discounts []bi s.T().Fatalf("discount not found: %s", description) return billing.AmountLineDiscountManaged{} } + +func (s *BillingAdapterTestSuite) TestHardDeleteGatheringInvoiceLines() { + ctx := s.T().Context() + namespace := s.GetUniqueNamespace("ns-adapter-hard-delete-gathering-invoice-lines") + featureKey := "in-advance-payment" + + var customerEntity *customer.Customer + + s.Run("Given a customer and default billing profile exists", func() { + sandboxApp := s.InstallSandboxApp(s.T(), namespace) + s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID()) + + customerEntity = s.CreateTestCustomer(namespace, "test-customer") + s.NotNil(customerEntity) + + _ = lo.Must(s.FeatureService.CreateFeature(ctx, feature.CreateFeatureInputs{ + Namespace: namespace, + Name: featureKey, + Key: featureKey, + })) + }) + + var gatheringInvoice billing.GatheringInvoice + s.Run("Given a gathering invoice with two lines", func() { + periodStart := time.Now().Add(-time.Hour) + periodEnd := time.Now().Add(time.Hour) + + res, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ + Customer: customerEntity.GetID(), + Currency: currencyx.Code(currency.USD), + Lines: []billing.GatheringLine{ + { + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: namespace, + Name: "Test line 1", + }), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + Currency: currencyx.Code(currency.USD), + RateCardDiscounts: billing.Discounts{ + Percentage: &billing.PercentageDiscount{ + PercentageDiscount: productcatalog.PercentageDiscount{ + Percentage: models.NewPercentage(10), + }, + }, + }, + FeatureKey: featureKey, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + })), + }, + }, + { + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: namespace, + Name: "Test line 2", + }), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + Currency: currencyx.Code(currency.USD), + RateCardDiscounts: billing.Discounts{ + Percentage: &billing.PercentageDiscount{ + PercentageDiscount: productcatalog.PercentageDiscount{ + Percentage: models.NewPercentage(10), + }, + }, + }, + FeatureKey: featureKey, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + })), + }, + }, + }, + }) + s.NoError(err) + + gatheringInvoice = res.Invoice + }) + + var ( + deletedLine billing.GatheringLine + gatheringInvoiceWithDeletedLine billing.GatheringInvoice + ) + s.Run("When we hard delete one of the lines", func() { + deletedLine = gatheringInvoice.Lines.OrEmpty()[0] + err := s.BillingAdapter.HardDeleteGatheringInvoiceLines(ctx, gatheringInvoice.InvoiceID(), []string{deletedLine.ID}) + s.NoError(err) + + gatheringInvoice, err = s.BillingAdapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: gatheringInvoice.InvoiceID(), + Expand: billing.GatheringInvoiceExpands{billing.GatheringInvoiceExpandLines}, + }) + s.NoError(err) + gatheringInvoiceWithDeletedLine = gatheringInvoice + }) + + s.Run("Then the gathering invoice has only one line", func() { + s.Len(gatheringInvoiceWithDeletedLine.Lines.OrEmpty(), 1) + s.NotEqual(deletedLine.ID, gatheringInvoice.Lines.OrEmpty()[0].ID) + }) + + s.Run("Then the deleted line's usage based config is also deleted", func() { + s.NotEmpty(deletedLine.UBPConfigID) + + _, err := s.DBClient.BillingInvoiceUsageBasedLineConfig.Query(). + Where(billinginvoiceusagebasedlineconfig.ID(deletedLine.UBPConfigID)). + Only(ctx) + + s.Error(err) + s.True(db.IsNotFound(err)) + }) +} + +func (s *BillingAdapterTestSuite) TestHardDeleteGatheringInvoiceLinesNegative() { + ctx := s.T().Context() + namespace := s.GetUniqueNamespace("ns-adapter-hard-delete-gathering-invoice-lines-negative") + featureKey := "in-advance-payment" + + now := lo.Must(time.Parse(time.RFC3339, "2026-01-01T00:00:00Z")) + clock.SetTime(now) + defer clock.ResetTime() + + var customerEntity *customer.Customer + + s.Run("Given a customer and billing profile exists", func() { + sandboxApp := s.InstallSandboxApp(s.T(), namespace) + s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID()) + + customerEntity = s.CreateTestCustomer(namespace, "test-customer") + s.NotNil(customerEntity) + + _ = lo.Must(s.FeatureService.CreateFeature(ctx, feature.CreateFeatureInputs{ + Namespace: namespace, + Name: featureKey, + Key: featureKey, + })) + }) + + var ( + gatheringInvoice billing.GatheringInvoice + standardInvoice billing.StandardInvoice + ) + + s.Run("Given a gathering invoice with two lines", func() { + line1PeriodStart := lo.Must(time.Parse(time.RFC3339, "2026-01-01T00:00:00Z")) + line1PeriodEnd := line1PeriodStart.Add(time.Hour) + + line2PeriodStart := lo.Must(time.Parse(time.RFC3339, "2026-01-01T01:00:00Z")) + line2PeriodEnd := line2PeriodStart.Add(time.Hour) + + createdPendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ + Customer: customerEntity.GetID(), + Currency: currencyx.Code(currency.USD), + Lines: []billing.GatheringLine{ + { + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: namespace, + Name: "Test line 1", + }), + ServicePeriod: timeutil.ClosedPeriod{From: line1PeriodStart, To: line1PeriodEnd}, + InvoiceAt: line1PeriodStart, + ManagedBy: billing.ManuallyManagedLine, + Currency: currencyx.Code(currency.USD), + RateCardDiscounts: billing.Discounts{ + Percentage: &billing.PercentageDiscount{ + PercentageDiscount: productcatalog.PercentageDiscount{ + Percentage: models.NewPercentage(10), + }, + }, + }, + FeatureKey: featureKey, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: productcatalog.InAdvancePaymentTerm, + })), + }, + }, + { + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: namespace, + Name: "Test line 2", + }), + ServicePeriod: timeutil.ClosedPeriod{From: line2PeriodStart, To: line2PeriodEnd}, + InvoiceAt: line2PeriodStart, + ManagedBy: billing.ManuallyManagedLine, + Currency: currencyx.Code(currency.USD), + RateCardDiscounts: billing.Discounts{ + Percentage: &billing.PercentageDiscount{ + PercentageDiscount: productcatalog.PercentageDiscount{ + Percentage: models.NewPercentage(10), + }, + }, + }, + FeatureKey: featureKey, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: productcatalog.InAdvancePaymentTerm, + })), + }, + }, + }, + }) + s.NoError(err) + + standardInvoices, err := s.BillingService.InvoicePendingLines(ctx, billing.InvoicePendingLinesInput{ + Customer: customerEntity.GetID(), + }) + s.NoError(err) + s.Len(standardInvoices, 1) + + standardInvoice = standardInvoices[0] + + gatheringInvoice, err = s.BillingAdapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: createdPendingLines.Invoice.InvoiceID(), + Expand: billing.GatheringInvoiceExpands{billing.GatheringInvoiceExpandLines}, + }) + s.NoError(err) + s.NotNil(gatheringInvoice) + s.Len(gatheringInvoice.Lines.OrEmpty(), 1) + }) + + s.Run("When we try to hard delete a line from the standard invoice then we fail", func() { + err := s.BillingAdapter.HardDeleteGatheringInvoiceLines(ctx, standardInvoice.InvoiceID(), []string{standardInvoice.Lines.OrEmpty()[0].ID}) + s.Error(err) + }) + + s.Run("When we try to delete a line that does not belong to the invoice then we fail", func() { + err := s.BillingAdapter.HardDeleteGatheringInvoiceLines(ctx, gatheringInvoice.InvoiceID(), []string{standardInvoice.Lines.OrEmpty()[0].ID}) + s.Error(err) + }) +} diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 93e091bce3..12ff127a7b 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -470,10 +470,10 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { // Then we should have the items created require.NoError(s.T(), err) require.Len(s.T(), res.Lines, 2) - line1ID := res.Lines[0].ID - line2ID := res.Lines[1].ID - require.NotEmpty(s.T(), line1ID) - require.NotEmpty(s.T(), line2ID) + gatheringLine1 := res.Lines[0] + gatheringLine2 := res.Lines[1] + require.NotEmpty(s.T(), gatheringLine1.ID) + require.NotEmpty(s.T(), gatheringLine2.ID) // Expect that a single gathering invoice has been created require.Equal(s.T(), res.Lines[0].InvoiceID, res.Lines[1].InvoiceID) @@ -535,7 +535,9 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { // Then we should have item1 added to the invoice require.Len(s.T(), invoice[0].Lines.MustGet(), 1) - require.Equal(s.T(), line1ID, invoice[0].Lines.MustGet()[0].ID) + stdInvoiceLine1 := invoice[0].Lines.MustGet()[0] + // The standard invoice line's id must match the gathering line's id + require.Equal(s.T(), stdInvoiceLine1.ID, gatheringLine1.ID) // Then we expect that the gathering invoice is still present, with item2 gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ @@ -545,7 +547,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { require.NoError(s.T(), err) require.Nil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be present") require.Len(s.T(), gatheringInvoice.Lines.MustGet(), 1) - require.Equal(s.T(), line2ID, gatheringInvoice.Lines.MustGet()[0].ID) + require.Equal(s.T(), gatheringLine2.ID, gatheringInvoice.Lines.MustGet()[0].ID) // We expect the freshly generated invoice to be in waiting for auto approval state require.Equal(s.T(), billing.StandardInvoiceStatusDraftWaitingAutoApproval, invoice[0].Status) @@ -567,7 +569,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { ID: customerEntity.ID, Namespace: customerEntity.Namespace, }, - IncludePendingLines: mo.Some([]string{line2ID}), + IncludePendingLines: mo.Some([]string{gatheringLine2.ID}), AsOf: lo.ToPtr(line1IssueAt.Add(time.Minute)), }) @@ -582,7 +584,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { ID: customerEntity.ID, Namespace: customerEntity.Namespace, }, - IncludePendingLines: mo.Some([]string{line2ID}), + IncludePendingLines: mo.Some([]string{gatheringLine2.ID}), AsOf: lo.ToPtr(now), }) @@ -592,7 +594,9 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { // Then we should have item2 added to the invoice require.Len(s.T(), invoice[0].Lines.MustGet(), 1) - require.Equal(s.T(), line2ID, invoice[0].Lines.MustGet()[0].ID) + stdInvoiceLine2 := invoice[0].Lines.MustGet()[0] + // The standard invoice line's id must match the gathering line's id + require.Equal(s.T(), stdInvoiceLine2.ID, gatheringLine2.ID) // Then we expect that the gathering invoice is deleted and empty gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ @@ -600,7 +604,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { Expand: billing.InvoiceExpandAll, }) require.NoError(s.T(), err) - require.NotNil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be present") + require.NotNil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be deleted") require.Len(s.T(), gatheringInvoice.Lines.MustGet(), 0, "deleted gathering invoice is empty") }) @@ -1011,7 +1015,6 @@ func (s *InvoicingTestSuite) TestInvoicingFlowErrorHandling() { // Given that the app will return a validation error mockApp.OnValidateStandardInvoice(billing.NewValidationError("test1", "validation error")) calcMock.OnCalculate(nil) - calcMock.OnCalculateLegacyGatheringInvoice(nil) calcMock.OnCalculateGatheringInvoice(nil) // When we create a draft invoice @@ -1244,7 +1247,6 @@ func (s *InvoicingTestSuite) TestInvoicingFlowErrorHandling() { mockApp.OnFinalizeStandardInvoice(nil) calcMock.OnCalculate(nil) calcMock.OnCalculateGatheringInvoice(nil) - calcMock.OnCalculateLegacyGatheringInvoice(nil) // When we create a draft invoice invoice := s.CreateDraftInvoice(s.T(), ctx, DraftInvoiceInput{ @@ -2078,9 +2080,13 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { for _, line := range i.Lines.OrEmpty() { for _, detailedLine := range line.DetailedLines { + s.NotEmpty(detailedLine.ID, "detailed line id is empty") + out.AddLineExternalID(detailedLine.ID, "final_upsert_"+detailedLine.ID) for _, discount := range detailedLine.AmountDiscounts { + s.NotEmpty(discount.GetID(), "discount id is empty") + out.AddLineDiscountExternalID(discount.GetID(), "final_upsert_"+discount.GetID()) } } @@ -2126,9 +2132,13 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { for _, line := range i.Lines.OrEmpty() { for _, detailedLine := range line.DetailedLines { + s.NotEmpty(detailedLine.ID, "detailed line id is empty") + out.AddLineExternalID(detailedLine.ID, "final_upsert_"+detailedLine.ID) for _, discount := range detailedLine.AmountDiscounts { + s.NotEmpty(discount.GetID(), "discount id is empty") + out.AddLineDiscountExternalID(discount.GetID(), "final_upsert_"+discount.GetID()) } }