Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions openmeter/billing/adapter/gatheringinvoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (a *adapter) CreateGatheringInvoice(ctx context.Context, input billing.Crea
// Let's add required edges for mapping
newInvoice.Edges.BillingWorkflowConfig = clonedWorkflowConfig

return tx.mapGatheringInvoiceFromDB(ctx, newInvoice, billing.GatheringInvoiceExpands{})
return tx.mapGatheringInvoiceFromDB(newInvoice, billing.GatheringInvoiceExpands{})
})
}

Expand Down Expand Up @@ -257,7 +257,7 @@ func (a *adapter) ListGatheringInvoices(ctx context.Context, input billing.ListG

result := make([]billing.GatheringInvoice, 0, len(paged.Items))
for _, invoice := range paged.Items {
mapped, err := tx.mapGatheringInvoiceFromDB(ctx, invoice, input.Expand)
mapped, err := tx.mapGatheringInvoiceFromDB(invoice, input.Expand)
if err != nil {
return response, err
}
Expand Down Expand Up @@ -337,8 +337,8 @@ func (a *adapter) expandGatheringInvoiceLines(q *db.BillingInvoiceQuery, expand
q = q.Where(billinginvoiceline.DeletedAtIsNil())
}
q = q.
Where(billinginvoiceline.TypeEQ(billing.InvoiceLineTypeUsageBased)). // Only include usage based lines (there are some detailed lines existing for gathering invoices)
Where(billinginvoiceline.ParentLineIDIsNil()) // Only include top-level lines (there are some detailed lines existing for gathering invoices)
Where(billinginvoiceline.TypeEQ(billing.InvoiceLineAdapterTypeUsageBased)). // Only include usage based lines (there are some detailed lines existing for gathering invoices)
Where(billinginvoiceline.ParentLineIDIsNil()) // Only include top-level lines (there are some detailed lines existing for gathering invoices)
q.WithUsageBasedLine()
})
}
Expand Down Expand Up @@ -368,11 +368,11 @@ func (a *adapter) GetGatheringInvoiceById(ctx context.Context, input billing.Get
return billing.GatheringInvoice{}, err
}

return tx.mapGatheringInvoiceFromDB(ctx, invoice, input.Expand)
return tx.mapGatheringInvoiceFromDB(invoice, input.Expand)
})
}

func (a *adapter) mapGatheringInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoice, expand billing.GatheringInvoiceExpands) (billing.GatheringInvoice, error) {
func (a *adapter) mapGatheringInvoiceFromDB(invoice *db.BillingInvoice, expand billing.GatheringInvoiceExpands) (billing.GatheringInvoice, error) {
if invoice.Status != billing.StandardInvoiceStatusGathering {
return billing.GatheringInvoice{}, fmt.Errorf("invoice is not a gathering invoice [id=%s]", invoice.ID)
}
Expand Down
4 changes: 2 additions & 2 deletions openmeter/billing/adapter/gatheringlines.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (a *adapter) updateGatheringLines(ctx context.Context, lines billing.Gather
SetInvoiceAt(line.InvoiceAt.In(time.UTC)).
SetStatus(billing.InvoiceLineStatusValid).
SetManagedBy(line.ManagedBy).
SetType(billing.InvoiceLineTypeUsageBased).
SetType(billing.InvoiceLineAdapterTypeUsageBased).
SetName(line.Name).
SetNillableDescription(line.Description).
SetCurrency(line.Currency).
Expand Down Expand Up @@ -260,7 +260,7 @@ func (a *adapter) mapGatheringInvoiceLinesFromDB(schemaLevel int, dbLines []*db.
}

func (a *adapter) mapGatheringInvoiceLineFromDB(schemaLevel int, dbLine *db.BillingInvoiceLine) (billing.GatheringLine, error) {
if dbLine.Type != billing.InvoiceLineTypeUsageBased {
if dbLine.Type != billing.InvoiceLineAdapterTypeUsageBased {
return billing.GatheringLine{}, fmt.Errorf("only usage based lines can be gathering invoice lines [line_id=%s]", dbLine.ID)
}

Expand Down
4 changes: 2 additions & 2 deletions openmeter/billing/adapter/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ func (a *adapter) GetInvoiceOwnership(ctx context.Context, in billing.GetInvoice
})
}

func (a *adapter) mapStandardInvoiceBaseFromDB(ctx context.Context, invoice *db.BillingInvoice) billing.StandardInvoiceBase {
func (a *adapter) mapStandardInvoiceBaseFromDB(invoice *db.BillingInvoice) billing.StandardInvoiceBase {
return billing.StandardInvoiceBase{
ID: invoice.ID,
Namespace: invoice.Namespace,
Expand Down Expand Up @@ -697,7 +697,7 @@ func (a *adapter) mapStandardInvoiceBaseFromDB(ctx context.Context, invoice *db.
}

func (a *adapter) mapStandardInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoice, expand billing.InvoiceExpand) (billing.StandardInvoice, error) {
base := a.mapStandardInvoiceBaseFromDB(ctx, invoice)
base := a.mapStandardInvoiceBaseFromDB(invoice)

res := billing.StandardInvoice{
StandardInvoiceBase: base,
Expand Down
71 changes: 57 additions & 14 deletions openmeter/billing/adapter/invoicelinesplitgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ func (a *adapter) GetSplitLineGroup(ctx context.Context, input billing.GetSplitL
).
WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) {
a.expandLineItems(q)
q.WithBillingInvoice()
q.WithBillingInvoice(func(q *db.BillingInvoiceQuery) {
q.WithBillingWorkflowConfig()
})
}).
First(ctx)
if err != nil {
Expand Down Expand Up @@ -209,17 +211,7 @@ func (a *adapter) mapSplitLineHierarchyFromDB(ctx context.Context, dbSplitLineGr
return empty, err
}

mappedLines, err := slicesx.MapWithErr(dbSplitLineGroup.Edges.BillingInvoiceLines, func(dbLine *db.BillingInvoiceLine) (billing.LineWithInvoiceHeader, error) {
line, err := a.mapStandardInvoiceLineWithoutReferences(dbLine)
if err != nil {
return billing.LineWithInvoiceHeader{}, err
}

return billing.LineWithInvoiceHeader{
Line: line,
Invoice: a.mapStandardInvoiceBaseFromDB(ctx, dbLine.Edges.BillingInvoice),
}, nil
})
mappedLines, err := a.mapSplitLineHierarchyLinesFromDB(ctx, dbSplitLineGroup.Edges.BillingInvoiceLines)
if err != nil {
return empty, err
}
Expand All @@ -230,6 +222,55 @@ func (a *adapter) mapSplitLineHierarchyFromDB(ctx context.Context, dbSplitLineGr
}, nil
}

func (a *adapter) mapSplitLineHierarchyLinesFromDB(ctx context.Context, dbLines []*db.BillingInvoiceLine) ([]billing.LineWithInvoiceHeader, error) {
return slicesx.MapWithErr(dbLines, func(dbLine *db.BillingInvoiceLine) (billing.LineWithInvoiceHeader, error) {
if dbLine.Edges.BillingInvoice == nil {
return billing.LineWithInvoiceHeader{}, fmt.Errorf("billing invoice must be expanded when mapping split line hierarchy lines [id=%s]", dbLine.ID)
}

switch dbLine.Edges.BillingInvoice.Status {
case billing.StandardInvoiceStatusGathering:
return a.mapSplitLineHierarchyGatheringLineFromDB(dbLine)
default:
return a.mapSplitLineHierarchyStandardLineFromDB(ctx, dbLine)
}
})
}
Comment thread
turip marked this conversation as resolved.

func (a *adapter) mapSplitLineHierarchyStandardLineFromDB(ctx context.Context, dbLine *db.BillingInvoiceLine) (billing.LineWithInvoiceHeader, error) {
line, err := a.mapStandardInvoiceLineWithoutReferences(dbLine)
if err != nil {
return billing.LineWithInvoiceHeader{}, err
}

invoice, err := a.mapStandardInvoiceFromDB(ctx, dbLine.Edges.BillingInvoice, billing.InvoiceExpand{})
if err != nil {
return billing.LineWithInvoiceHeader{}, err
}

return billing.NewLineWithInvoiceHeader(billing.StandardLineWithInvoiceHeader{
Line: line,
Invoice: invoice,
}), nil
}

func (a *adapter) mapSplitLineHierarchyGatheringLineFromDB(dbLine *db.BillingInvoiceLine) (billing.LineWithInvoiceHeader, error) {
line, err := a.mapGatheringInvoiceLineFromDB(dbLine.Edges.BillingInvoice.SchemaLevel, dbLine)
if err != nil {
return billing.LineWithInvoiceHeader{}, err
}

invoice, err := a.mapGatheringInvoiceFromDB(dbLine.Edges.BillingInvoice, billing.GatheringInvoiceExpands{})
if err != nil {
return billing.LineWithInvoiceHeader{}, err
}

return billing.NewLineWithInvoiceHeader(billing.GatheringLineWithInvoiceHeader{
Line: line,
Invoice: invoice,
}), nil
}

// expandSplitLineHierarchy expands the given lines with their progressive line hierarchy
// This is done by fetching all the lines that are children of the given lines parent lines and then building
// the hierarchy.
Expand Down Expand Up @@ -257,7 +298,7 @@ func (a *adapter) expandSplitLineHierarchy(ctx context.Context, namespace string
hierarchyByLineID := map[string]*billing.SplitLineHierarchy{}
for _, splitLineGroup := range splitLineGroups {
for _, line := range splitLineGroup.Lines {
hierarchyByLineID[line.Line.ID] = &splitLineGroup
hierarchyByLineID[line.Line.GetID()] = &splitLineGroup
}
}

Expand Down Expand Up @@ -285,7 +326,9 @@ func (a *adapter) fetchAllSplitLineGroups(ctx context.Context, namespace string,
).
WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) {
a.expandLineItems(q)
q.WithBillingInvoice() // TODO[later]: we can consider loading this in a separate query, might be more efficient
q.WithBillingInvoice(func(q *db.BillingInvoiceQuery) {
q.WithBillingWorkflowConfig()
}) // TODO[later]: we can consider loading this in a separate query, might be more efficient
})

dbSplitLineGroups, err := query.All(ctx)
Expand Down
37 changes: 20 additions & 17 deletions openmeter/billing/adapter/stdinvoicelinediff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("existing line hierarchy, no changes", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

lineDiff, err := diffInvoiceLines(base)
require.NoError(t, err)
Expand All @@ -100,7 +100,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("existing line hierarchy, one child line is deleted", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

require.True(t, removeDetailedLineByID(base[1], "2.1"), "child line 2.1 should be removed")

Expand All @@ -120,7 +120,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("existing line hierarchy, one child line is changed", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

getDetailedLineByID(base[1], "2.1").Quantity = alpacadecimal.NewFromFloat(10)

Expand All @@ -137,7 +137,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("existing line hierarchy, one parent line is changed", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

base[1].UsageBased.Quantity = lo.ToPtr(alpacadecimal.NewFromFloat(10))

Expand All @@ -151,11 +151,11 @@ func TestInvoiceLineDiffing(t *testing.T) {
}, lineDiff)
})

t.Run("a line is updated in the existing line hieararchy", func(t *testing.T) {
t.Run("a line is updated in the existing line hierarchy", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

// ID change should tirgger a delete/update
// ID change should trigger a delete/update
changedLine := getDetailedLineByID(base[1], "2.1")
changedLine.ID = ""
changedLine.Description = lo.ToPtr("2.3")
Expand All @@ -182,7 +182,7 @@ func TestInvoiceLineDiffing(t *testing.T) {
// Discount handling
t.Run("existing line hierarchy, one discount is deleted", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

getDetailedLineByID(base[1], "2.1").AmountDiscounts = nil

Expand All @@ -200,7 +200,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("existing line hierarchy, one discount is changed", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

getDetailedLineByID(base[1], "2.1").AmountDiscounts[0].Amount = alpacadecimal.NewFromFloat(20)

Expand All @@ -218,7 +218,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("existing line hierarchy, one discount is added/old one is removed", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

discounts := getDetailedLineByID(base[1], "2.1").AmountDiscounts

Expand All @@ -241,7 +241,7 @@ func TestInvoiceLineDiffing(t *testing.T) {
// DeletedAt handling
t.Run("support for detailed lines being deleted using deletedAt", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

getDetailedLineByID(base[1], "2.1").DeletedAt = lo.ToPtr(clock.Now())

Expand All @@ -261,7 +261,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("support for parent lines with children being deleted using deletedAt", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

base[1].DeletedAt = lo.ToPtr(clock.Now())

Expand All @@ -283,7 +283,7 @@ func TestInvoiceLineDiffing(t *testing.T) {

t.Run("support for parent lines without children being deleted using deletedAt", func(t *testing.T) {
base := cloneLines(template)
snapshotAsDBState(base)
snapshotAsDBState(t, base)

base[0].DeletedAt = lo.ToPtr(clock.Now())

Expand All @@ -300,7 +300,7 @@ func TestInvoiceLineDiffing(t *testing.T) {
t.Run("deleted, changed lines are not triggering updates", func(t *testing.T) {
base := cloneLines(template)
base[0].DeletedAt = lo.ToPtr(clock.Now())
snapshotAsDBState(base)
snapshotAsDBState(t, base)
base[0].Description = lo.ToPtr("test")

lineDiff, err := diffInvoiceLines(base)
Expand Down Expand Up @@ -362,14 +362,17 @@ func requireDiff(t *testing.T, expected lineDiffExpectation, actual invoiceLineD

func cloneLines(lines []*billing.StandardLine) []*billing.StandardLine {
return lo.Map(lines, func(line *billing.StandardLine, _ int) *billing.StandardLine {
return line.Clone()
return lo.Must(line.Clone())
})
}

// snapshotAsDBState saves the current state of the lines as if they were in the database
func snapshotAsDBState(lines []*billing.StandardLine) {
func snapshotAsDBState(t *testing.T, lines []*billing.StandardLine) {
t.Helper()

for _, line := range lines {
line.SaveDBSnapshot()
err := line.SaveDBSnapshot()
require.NoError(t, err)
}
}

Expand Down
8 changes: 5 additions & 3 deletions openmeter/billing/adapter/stdinvoicelinemapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ func (a *adapter) mapStandardInvoiceLinesFromDB(schemaLevelByInvoiceID map[strin
}
}

line.SaveDBSnapshot()
if err := line.SaveDBSnapshot(); err != nil {
return nil, fmt.Errorf("saving DB snapshot [id=%s]: %w", line.GetID(), err)
}

lines = append(lines, line)
}
Expand Down Expand Up @@ -109,8 +111,8 @@ func (a *adapter) mapStandardInvoiceLineWithoutReferences(dbLine *db.BillingInvo
}
}

if dbLine.Type != billing.InvoiceLineTypeUsageBased {
return invoiceLine, fmt.Errorf("only usage based lines can be top level lines [line_id=%s]", dbLine.ID)
if dbLine.Type != billing.InvoiceLineAdapterTypeUsageBased {
return nil, fmt.Errorf("only usage based lines can be top level lines [line_id=%s]", dbLine.ID)
}

ubpLine := dbLine.Edges.UsageBasedLine
Expand Down
Loading
Loading