diff --git a/openmeter/billing/adapter/gatheringinvoice.go b/openmeter/billing/adapter/gatheringinvoice.go index ffa71a98ea..52fac8dc6f 100644 --- a/openmeter/billing/adapter/gatheringinvoice.go +++ b/openmeter/billing/adapter/gatheringinvoice.go @@ -422,6 +422,8 @@ func (a *adapter) mapGatheringInvoiceFromDB(ctx context.Context, invoice *db.Bil NextCollectionAt: invoice.CollectionAt.In(time.UTC), SchemaLevel: invoice.SchemaLevel, }, + + Expands: expand, } if expand.Has(billing.GatheringInvoiceExpandLines) { diff --git a/openmeter/billing/gatheringinvoice.go b/openmeter/billing/gatheringinvoice.go index 7e9c19b8d1..00fdb9ab7d 100644 --- a/openmeter/billing/gatheringinvoice.go +++ b/openmeter/billing/gatheringinvoice.go @@ -75,7 +75,7 @@ type GatheringInvoice struct { // these lines too. AvailableActions *GatheringInvoiceAvailableActions `json:"availableActions,omitempty"` - SplitLineHierarchy *SplitLineHierarchy `json:"splitLineHierarchy,omitempty"` + Expands GatheringInvoiceExpands `json:"expands,omitempty"` } func (g GatheringInvoice) WithoutDBState() (GatheringInvoice, error) { @@ -158,6 +158,7 @@ func (g GatheringInvoice) Clone() (GatheringInvoice, error) { } clone.Lines = clonedLines + clone.Expands = g.Expands.Clone() return clone, nil } @@ -681,6 +682,16 @@ func (g GatheringLine) AsNewStandardLine(invoiceID string) (*StandardLine, error subscription = g.Subscription.Clone() } + var splitLineHierarchy *SplitLineHierarchy + if g.SplitLineHierarchy != nil { + clonedSHierarchy, err := g.SplitLineHierarchy.Clone() + if err != nil { + return nil, fmt.Errorf("cloning split line hierarchy: %w", err) + } + + splitLineHierarchy = lo.ToPtr(clonedSHierarchy) + } + convertedLine := &StandardLine{ StandardLineBase: StandardLineBase{ ManagedResource: g.ManagedResource, @@ -708,6 +719,8 @@ func (g GatheringLine) AsNewStandardLine(invoiceID string) (*StandardLine, error FeatureKey: g.FeatureKey, }, + SplitLineHierarchy: splitLineHierarchy, + DBState: nil, // We don't want to reuse the state from the gathering line (so let's make it explicit) } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index e7ce556e0f..12a482f90c 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -166,13 +166,35 @@ func (s *Service) calculateGatheringInvoiceAsStandardInvoice(ctx context.Context return nil, fmt.Errorf("creating standard invoice from gathering invoice: %w", err) } - wasLinesAbsent := invoice.Lines.IsAbsent() + wasLinesPresent := invoice.Lines.IsPresent() - if wasLinesAbsent { + shouldReloadLines := !wasLinesPresent + + if !invoice.Expands.Has(billing.GatheringInvoiceExpandSplitLineHierarchy) && wasLinesPresent { + // If the invoice has lines and the splitline hierarchy is not expanded, we need to check if there are any progressive billed lines + // and reload the invoice as price calculations depend on the presence of the split line hierarchy. + + progressiveBilledLineCount := lo.CountBy(invoice.Lines.OrEmpty(), func(line billing.GatheringLine) bool { + if line.DeletedAt != nil { + return false + } + + return line.SplitLineGroupID != nil + }) + + if progressiveBilledLineCount > 0 { + shouldReloadLines = true + } + } + + // If the gathering invoice has no splitline hierarchy expanded we need to reload the invoice so that the price calculations can + // properly proceed. + if shouldReloadLines { // Let's reload the whole invoice with lines expanded invoice, err = s.adapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ Invoice: in.Invoice.GetInvoiceID(), - Expand: billing.GatheringInvoiceExpandAll, + Expand: billing.GatheringInvoiceExpandAll. + With(billing.GatheringInvoiceExpandSplitLineHierarchy), }) if err != nil { return nil, fmt.Errorf("fetching gathering invoice: %w", err) @@ -223,7 +245,7 @@ func (s *Service) calculateGatheringInvoiceAsStandardInvoice(ctx context.Context return nil, fmt.Errorf("calculating invoice: %w", err) } - if wasLinesAbsent { + if !wasLinesPresent { // If the original user intent was to not to receive the lines, let's not send them out.Lines = billing.StandardInvoiceLines{} } else { diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 21e0bbdcc3..c22e75de75 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -283,6 +283,8 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { }, Lines: billing.NewGatheringInvoiceLines([]billing.GatheringLine{expectedUSDLine}), + + Expands: []billing.GatheringInvoiceExpand{billing.GatheringInvoiceExpandLines}, } s.NoError(invoicecalc.GatheringInvoiceCollectionAt(&expectedInvoice))