diff --git a/api/spec/src/billing/invoices.tsp b/api/spec/src/billing/invoices.tsp index fbb1c9b5c1..2ce91d242a 100644 --- a/api/spec/src/billing/invoices.tsp +++ b/api/spec/src/billing/invoices.tsp @@ -706,6 +706,7 @@ model InvoicePendingLineCreateResponse { @visibility(Lifecycle.Read) lines: InvoiceLine[]; + // TODO: For the V3 api let's not return the invoice /** * The invoice containing the created lines. */ diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 2743010355..61db1ad9b0 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -17,6 +17,7 @@ type Adapter interface { InvoiceLineAdapter InvoiceSplitLineGroupAdapter InvoiceAdapter + GatheringInvoiceAdapter SequenceAdapter InvoiceAppAdapter CustomerSynchronizationAdapter @@ -74,6 +75,14 @@ type InvoiceAdapter interface { GetInvoiceOwnership(ctx context.Context, input GetInvoiceOwnershipAdapterInput) (GetOwnershipAdapterResponse, error) } +type GatheringInvoiceAdapter interface { + CreateGatheringInvoice(ctx context.Context, input CreateGatheringInvoiceAdapterInput) (GatheringInvoice, error) + UpdateGatheringInvoice(ctx context.Context, input UpdateGatheringInvoiceAdapterInput) error + DeleteGatheringInvoice(ctx context.Context, input DeleteGatheringInvoiceAdapterInput) error + GetGatheringInvoiceById(ctx context.Context, input GetGatheringInvoiceByIdInput) (GatheringInvoice, error) + ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (pagination.Result[GatheringInvoice], error) +} + type InvoiceSplitLineGroupAdapter interface { CreateSplitLineGroup(ctx context.Context, input CreateSplitLineGroupAdapterInput) (SplitLineGroup, error) UpdateSplitLineGroup(ctx context.Context, input UpdateSplitLineGroupInput) (SplitLineGroup, error) diff --git a/openmeter/billing/adapter/gatheringinvoice.go b/openmeter/billing/adapter/gatheringinvoice.go new file mode 100644 index 0000000000..c12ed65d7c --- /dev/null +++ b/openmeter/billing/adapter/gatheringinvoice.go @@ -0,0 +1,415 @@ +package billingadapter + +import ( + "context" + "fmt" + "time" + + "github.com/alpacahq/alpacadecimal" + + "github.com/openmeterio/openmeter/api" + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/convert" + "github.com/openmeterio/openmeter/pkg/framework/entutils" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/sortx" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +var _ billing.GatheringInvoiceAdapter = (*adapter)(nil) + +func (a *adapter) CreateGatheringInvoice(ctx context.Context, input billing.CreateGatheringInvoiceAdapterInput) (billing.GatheringInvoice, error) { + if err := input.Validate(); err != nil { + return billing.GatheringInvoice{}, err + } + + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.GatheringInvoice, error) { + customer := input.Customer + supplier := input.MergedProfile.Supplier + + // Clone the workflow config + clonedWorkflowConfig, err := tx.createWorkflowConfig(ctx, input.Namespace, input.MergedProfile.WorkflowConfig) + if err != nil { + return billing.GatheringInvoice{}, fmt.Errorf("clone workflow config: %w", err) + } + + currentSchemaLevel, err := tx.GetInvoiceDefaultSchemaLevel(ctx) + if err != nil { + return billing.GatheringInvoice{}, fmt.Errorf("get invoice write schema level: %w", err) + } + + createMut := tx.db.BillingInvoice.Create(). + SetNamespace(input.Namespace). + SetMetadata(input.Metadata). + SetCurrency(input.Currency). + SetStatus(billing.StandardInvoiceStatusGathering). + SetSourceBillingProfileID(input.MergedProfile.ID). + SetType(billing.InvoiceTypeStandard). // TODO: Migrate to GatheringInvoiceType once we have the type in the database + SetNumber(input.Number). + SetNillableDescription(input.Description). + SetNillableCollectionAt(input.NextCollectionAt). + SetSchemaLevel(currentSchemaLevel). + // Customer snapshot about usage attribution fields + SetCustomerID(input.Customer.ID). + // TODO: Remove all below this line once we have separate tables for gathering invoices + SetBillingWorkflowConfigID(clonedWorkflowConfig.ID). + SetTaxAppID(input.MergedProfile.Apps.Tax.GetID().ID). + SetInvoicingAppID(input.MergedProfile.Apps.Invoicing.GetID().ID). + SetPaymentAppID(input.MergedProfile.Apps.Payment.GetID().ID). + // Totals + SetAmount(alpacadecimal.Zero). + SetChargesTotal(alpacadecimal.Zero). + SetDiscountsTotal(alpacadecimal.Zero). + SetTaxesTotal(alpacadecimal.Zero). + SetTaxesExclusiveTotal(alpacadecimal.Zero). + SetTaxesInclusiveTotal(alpacadecimal.Zero). + SetTotal(alpacadecimal.Zero). + // Supplier contacts + SetSupplierName(supplier.Name) + + // Customer usage attribution + if usageAttr := mapCustomerUsageAttributionToDB(input.Customer); usageAttr != nil { + createMut = createMut.SetCustomerUsageAttribution(usageAttr) + } + createMut = createMut. + SetCustomerName(customer.Name) + + newInvoice, err := createMut.Save(ctx) + if err != nil { + return billing.GatheringInvoice{}, err + } + + // Let's add required edges for mapping + newInvoice.Edges.BillingWorkflowConfig = clonedWorkflowConfig + + return tx.mapGatheringInvoiceFromDB(ctx, newInvoice, billing.GatheringInvoiceExpands{}) + }) +} + +func (a *adapter) UpdateGatheringInvoice(ctx context.Context, in billing.GatheringInvoice) error { + if err := in.Validate(); err != nil { + return fmt.Errorf("validating gathering invoice: %w", err) + } + + return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { + existingInvoice, err := tx.db.BillingInvoice.Query(). + Where(billinginvoice.ID(in.ID)). + Where(billinginvoice.Namespace(in.Namespace)). + Only(ctx) + if err != nil { + return err + } + + if err := tx.validateUpdateGatheringInvoiceRequest(in, existingInvoice); err != nil { + return err + } + + updateQuery := tx.db.BillingInvoice.UpdateOneID(in.ID). + Where(billinginvoice.Namespace(in.Namespace)). + SetMetadata(in.Metadata). + // Currency is immutable + SetStatus(billing.StandardInvoiceStatusGathering). + ClearStatusDetailsCache(). + // Type is immutable + SetNumber(in.Number). + SetOrClearDescription(in.Description). + ClearDueAt(). + ClearPaymentProcessingEnteredAt(). + ClearDraftUntil(). + ClearIssuedAt(). + ClearDeletedAt(). + ClearSentToCustomerAt(). + ClearQuantitySnapshotedAt(). + // Totals + SetAmount(alpacadecimal.Zero). + SetChargesTotal(alpacadecimal.Zero). + SetDiscountsTotal(alpacadecimal.Zero). + SetTaxesTotal(alpacadecimal.Zero). + SetTaxesExclusiveTotal(alpacadecimal.Zero). + SetTaxesInclusiveTotal(alpacadecimal.Zero). + SetTotal(alpacadecimal.Zero) + + if !in.NextCollectionAt.IsZero() { + updateQuery = updateQuery.SetCollectionAt(in.NextCollectionAt.In(time.UTC)) + } else { + updateQuery = updateQuery.ClearCollectionAt() + } + + updateQuery = updateQuery. + SetPeriodStart(in.ServicePeriod.From.In(time.UTC)). + SetPeriodEnd(in.ServicePeriod.To.In(time.UTC)) + + // Supplier + updateQuery = updateQuery. + SetSupplierName("UNSET"). // Hack until we split the invoices table + SetSupplierAddressCountry("XX"). // Hack until we split the invoices table + ClearSupplierAddressPostalCode(). + ClearSupplierAddressCity(). + ClearSupplierAddressState(). + ClearSupplierAddressLine1(). + ClearSupplierAddressLine2(). + ClearSupplierAddressPhoneNumber() + + // Customer + updateQuery = updateQuery. + // CustomerID is immutable + SetCustomerName("UNSET"). // hack until we split the invoices table + ClearCustomerKey() + + updateQuery = updateQuery. + ClearCustomerAddressCountry(). + ClearCustomerAddressPostalCode(). + ClearCustomerAddressCity(). + ClearCustomerAddressState(). + ClearCustomerAddressLine1(). + ClearCustomerAddressLine2(). + ClearCustomerAddressPhoneNumber() + + // ExternalIDs + updateQuery = updateQuery. + ClearInvoicingAppExternalID(). + ClearPaymentAppExternalID() + + _, err = updateQuery.Save(ctx) + if err != nil { + return err + } + + if in.Lines.IsPresent() { + err := tx.updateGatheringLines(ctx, in.Lines.OrEmpty()) + if err != nil { + return err + } + } + + return nil + }) +} + +func (a *adapter) ListGatheringInvoices(ctx context.Context, input billing.ListGatheringInvoicesInput) (pagination.Result[billing.GatheringInvoice], error) { + if err := input.Validate(); err != nil { + return pagination.Result[billing.GatheringInvoice]{}, err + } + + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (pagination.Result[billing.GatheringInvoice], error) { + query := tx.db.BillingInvoice.Query(). + Where(billinginvoice.NamespaceIn(input.Namespaces...)). + Where(billinginvoice.StatusEQ(billing.StandardInvoiceStatusGathering)) + + if len(input.Customers) > 0 { + query = query.Where(billinginvoice.CustomerIDIn(input.Customers...)) + } + + if len(input.Currencies) > 0 { + query = query.Where(billinginvoice.CurrencyIn(input.Currencies...)) + } + + order := entutils.GetOrdering(sortx.OrderDefault) + if !input.Order.IsDefaultValue() { + order = entutils.GetOrdering(input.Order) + } + + if input.Expand.Has(billing.GatheringInvoiceExpandLines) { + query = query.WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) { + q.WithUsageBasedLine() + }) + } + + switch input.OrderBy { + case api.InvoiceOrderByCustomerName: + query = query.Order(billinginvoice.ByCustomerName(order...)) + case api.InvoiceOrderByIssuedAt: + query = query.Order(billinginvoice.ByIssuedAt(order...)) + case api.InvoiceOrderByPeriodStart: + query = query.Order(billinginvoice.ByPeriodStart(order...)) + case api.InvoiceOrderByStatus: + query = query.Order(billinginvoice.ByStatus(order...)) + case api.InvoiceOrderByUpdatedAt: + query = query.Order(billinginvoice.ByUpdatedAt(order...)) + case api.InvoiceOrderByCreatedAt: + fallthrough + default: + query = query.Order(billinginvoice.ByCreatedAt(order...)) + } + + if !input.IncludeDeleted { + query = query.Where(billinginvoice.DeletedAtIsNil()) + } + + response := pagination.Result[billing.GatheringInvoice]{ + Page: input.Page, + } + + paged, err := query.Paginate(ctx, input.Page) + if err != nil { + return response, err + } + + result := make([]billing.GatheringInvoice, 0, len(paged.Items)) + for _, invoice := range paged.Items { + mapped, err := tx.mapGatheringInvoiceFromDB(ctx, invoice, input.Expand) + if err != nil { + return response, err + } + + result = append(result, mapped) + } + + response.TotalCount = paged.TotalCount + response.Items = result + + return response, nil + }) +} + +func (a *adapter) validateUpdateGatheringInvoiceRequest(req billing.GatheringInvoice, existing *db.BillingInvoice) error { + if req.Currency != existing.Currency { + return billing.ValidationError{ + Err: fmt.Errorf("currency cannot be changed"), + } + } + + if billing.InvoiceTypeStandard != existing.Type { + return billing.ValidationError{ + Err: fmt.Errorf("type cannot be changed"), + } + } + + if req.CustomerID != existing.CustomerID { + return billing.ValidationError{ + Err: fmt.Errorf("customer cannot be changed"), + } + } + + return nil +} + +func (a *adapter) DeleteGatheringInvoice(ctx context.Context, input billing.DeleteGatheringInvoiceAdapterInput) error { + if err := input.Validate(); err != nil { + return fmt.Errorf("validating delete gathering invoice input: %w", err) + } + + return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error { + invoice, err := tx.db.BillingInvoice.Query(). + Where(billinginvoice.ID(input.ID)). + Where(billinginvoice.Namespace(input.Namespace)). + Only(ctx) + if err != nil { + return err + } + + if invoice.Status != billing.StandardInvoiceStatusGathering { + return billing.ValidationError{ + Err: fmt.Errorf("invoice is not a gathering invoice [id=%s]", invoice.ID), + } + } + + if invoice.DeletedAt != nil { + return nil + } + + _, err = tx.db.BillingInvoice.Update(). + Where(billinginvoice.ID(input.ID)). + Where(billinginvoice.Namespace(input.Namespace)). + SetDeletedAt(clock.Now()). + Save(ctx) + if err != nil { + return err + } + + return nil + }) +} + +func (a *adapter) GetGatheringInvoiceById(ctx context.Context, input billing.GetGatheringInvoiceByIdInput) (billing.GatheringInvoice, error) { + if err := input.Validate(); err != nil { + return billing.GatheringInvoice{}, fmt.Errorf("validating get gathering invoice by id input: %w", err) + } + + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.GatheringInvoice, error) { + query := tx.db.BillingInvoice.Query(). + Where(billinginvoice.ID(input.Invoice.ID)). + Where(billinginvoice.Namespace(input.Invoice.Namespace)) + + if input.Expand.Has(billing.GatheringInvoiceExpandLines) { + query = query.WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) { + q.WithUsageBasedLine() + }) + } + + invoice, err := query.Only(ctx) + if err != nil { + if db.IsNotFound(err) { + return billing.GatheringInvoice{}, billing.NotFoundError{ + Err: fmt.Errorf("%w [id=%s]", billing.ErrInvoiceNotFound, input.Invoice.ID), + } + } + + return billing.GatheringInvoice{}, err + } + + return tx.mapGatheringInvoiceFromDB(ctx, invoice, input.Expand) + }) +} + +func (a *adapter) mapGatheringInvoiceFromDB(ctx context.Context, invoice *db.BillingInvoice, expand billing.GatheringInvoiceExpands) (billing.GatheringInvoice, error) { + if invoice.Status != billing.StandardInvoiceStatusGathering { + return billing.GatheringInvoice{}, fmt.Errorf("invoice is not a gathering invoice [id=%s]", invoice.ID) + } + + period := timeutil.ClosedPeriod{} + + if invoice.PeriodStart != nil && invoice.PeriodEnd != nil { + period = timeutil.ClosedPeriod{ + From: invoice.PeriodStart.In(time.UTC), + To: invoice.PeriodEnd.In(time.UTC), + } + } + + res := billing.GatheringInvoice{ + GatheringInvoiceBase: billing.GatheringInvoiceBase{ + ManagedResource: models.ManagedResource{ + NamespacedModel: models.NamespacedModel{ + Namespace: invoice.Namespace, + }, + ManagedModel: models.ManagedModel{ + CreatedAt: invoice.CreatedAt.In(time.UTC), + UpdatedAt: invoice.UpdatedAt.In(time.UTC), + DeletedAt: convert.TimePtrIn(invoice.DeletedAt, time.UTC), + }, + ID: invoice.ID, + Name: invoice.Number, + Description: invoice.Description, + }, + + Metadata: invoice.Metadata, + Number: invoice.Number, + CustomerID: invoice.CustomerID, + Currency: invoice.Currency, + ServicePeriod: period, + NextCollectionAt: invoice.CollectionAt.In(time.UTC), + SchemaLevel: invoice.SchemaLevel, + }, + } + + if expand.Has(billing.GatheringInvoiceExpandLines) { + mappedLines, err := a.mapGatheringInvoiceLinesFromDB(invoice.SchemaLevel, invoice.Edges.BillingInvoiceLines) + if err != nil { + return billing.GatheringInvoice{}, err + } + + // TODO[later]: Implement this once we have proper union type for invoices + // mappedLines, err = a.expandSplitLineHierarchy(ctx, invoice.Namespace, mappedLines) + // if err != nil { + // return billing.StandardInvoice{}, err + // } + + res.Lines = billing.NewGatheringInvoiceLines(mappedLines) + } + + return res, nil +} diff --git a/openmeter/billing/adapter/gatheringlines.go b/openmeter/billing/adapter/gatheringlines.go new file mode 100644 index 0000000000..84a4eb8597 --- /dev/null +++ b/openmeter/billing/adapter/gatheringlines.go @@ -0,0 +1,255 @@ +package billingadapter + +import ( + "context" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "github.com/alpacahq/alpacadecimal" + "github.com/oklog/ulid/v2" + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceusagebasedlineconfig" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/convert" + "github.com/openmeterio/openmeter/pkg/entitydiff" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +type gatheringLineDiff struct { + Line entitydiff.Diff[*billing.GatheringLine] +} + +func diffGatheringInvoiceLines(lines billing.GatheringLines) (gatheringLineDiff, error) { + dbState := []*billing.GatheringLine{} + for _, line := range lines { + if line.DBState != nil { + dbState = append(dbState, line.DBState) + } + } + + linePtrs := lo.Map(lines, func(_ billing.GatheringLine, idx int) *billing.GatheringLine { + return &lines[idx] + }) + + diff := gatheringLineDiff{} + + err := entitydiff.DiffByID(entitydiff.DiffByIDInput[*billing.GatheringLine]{ + DBState: dbState, + ExpectedState: linePtrs, + HandleDelete: func(item *billing.GatheringLine) error { + diff.Line.NeedsDelete(item) + return nil + }, + HandleCreate: func(item *billing.GatheringLine) error { + diff.Line.NeedsCreate(item) + return nil + }, + HandleUpdate: func(item entitydiff.DiffUpdate[*billing.GatheringLine]) error { + diff.Line.NeedsUpdate(item) + return nil + }, + }) + if err != nil { + return gatheringLineDiff{}, err + } + + return diff, nil +} + +func (a *adapter) updateGatheringLines(ctx context.Context, lines billing.GatheringLines) error { + diff, err := diffGatheringInvoiceLines(lines) + if err != nil { + return err + } + + err = upsertWithOptions(ctx, a.db, diff.Line, upsertInput[*billing.GatheringLine, *db.BillingInvoiceUsageBasedLineConfigCreate]{ + Create: func(tx *db.Client, line *billing.GatheringLine) (*db.BillingInvoiceUsageBasedLineConfigCreate, error) { + if line.UBPConfigID == "" { + line.UBPConfigID = ulid.Make().String() + } + + create := tx.BillingInvoiceUsageBasedLineConfig.Create(). + SetNamespace(line.Namespace). + SetPriceType(line.Price.Type()). + SetPrice(lo.ToPtr(line.Price)). + SetFeatureKey(line.FeatureKey). + SetID(line.UBPConfigID) + + return create, nil + }, + UpsertItems: func(ctx context.Context, tx *db.Client, items []*db.BillingInvoiceUsageBasedLineConfigCreate) error { + return tx.BillingInvoiceUsageBasedLineConfig. + CreateBulk(items...). + OnConflict( + sql.ConflictColumns(billinginvoiceusagebasedlineconfig.FieldID), + sql.ResolveWithNewValues(), + ).Exec(ctx) + }, + }) + if err != nil { + return fmt.Errorf("creating usage based line configs: %w", err) + } + + invoiceLineUpsertConfig := upsertInput[*billing.GatheringLine, *db.BillingInvoiceLineCreate]{ + Create: func(tx *db.Client, line *billing.GatheringLine) (*db.BillingInvoiceLineCreate, error) { + if line.ID == "" { + line.ID = ulid.Make().String() + } + + create := tx.BillingInvoiceLine.Create(). + SetID(line.ID). + SetNamespace(line.Namespace). + SetInvoiceID(line.InvoiceID). + SetPeriodStart(line.ServicePeriod.From.In(time.UTC)). + SetPeriodEnd(line.ServicePeriod.To.In(time.UTC)). + SetNillableSplitLineGroupID(line.SplitLineGroupID). + SetNillableDeletedAt(line.DeletedAt). + SetInvoiceAt(line.InvoiceAt.In(time.UTC)). + SetStatus(billing.InvoiceLineStatusValid). + SetManagedBy(line.ManagedBy). + SetType(billing.InvoiceLineTypeUsageBased). + SetName(line.Name). + SetNillableDescription(line.Description). + SetCurrency(line.Currency). + SetMetadata(line.Metadata). + SetAnnotations(line.Annotations). + SetNillableChildUniqueReferenceID(line.ChildUniqueReferenceID). + // Totals + SetAmount(alpacadecimal.Zero). + SetChargesTotal(alpacadecimal.Zero). + SetDiscountsTotal(alpacadecimal.Zero). + SetTaxesTotal(alpacadecimal.Zero). + SetTaxesInclusiveTotal(alpacadecimal.Zero). + SetTaxesExclusiveTotal(alpacadecimal.Zero). + SetTotal(alpacadecimal.Zero) + + if line.Subscription != nil { + create = create.SetSubscriptionID(line.Subscription.SubscriptionID). + SetSubscriptionPhaseID(line.Subscription.PhaseID). + SetSubscriptionItemID(line.Subscription.ItemID). + SetSubscriptionBillingPeriodFrom(line.Subscription.BillingPeriod.From.In(time.UTC)). + SetSubscriptionBillingPeriodTo(line.Subscription.BillingPeriod.To.In(time.UTC)) + } + + if line.TaxConfig != nil { + create = create.SetTaxConfig(*line.TaxConfig) + } + + if !line.RateCardDiscounts.IsEmpty() { + create = create.SetRatecardDiscounts(lo.ToPtr(line.RateCardDiscounts)) + } + + create = create. + SetUsageBasedLineID(line.UBPConfigID) + + return create, nil + }, + UpsertItems: func(ctx context.Context, tx *db.Client, items []*db.BillingInvoiceLineCreate) error { + return tx.BillingInvoiceLine. + CreateBulk(items...). + OnConflict(sql.ConflictColumns(billinginvoiceline.FieldID), + sql.ResolveWithNewValues(), + sql.ResolveWith(func(u *sql.UpdateSet) { + u.SetIgnore(billinginvoiceline.FieldCreatedAt) + })). + UpdateChildUniqueReferenceID(). + Exec(ctx) + }, + MarkDeleted: func(ctx context.Context, line *billing.GatheringLine) (*billing.GatheringLine, error) { + line.DeletedAt = lo.ToPtr(clock.Now().In(time.UTC)) + return line, nil + }, + } + + if err := upsertWithOptions(ctx, a.db, diff.Line, invoiceLineUpsertConfig); err != nil { + return fmt.Errorf("creating lines: %w", err) + } + + return nil +} + +func (a *adapter) mapGatheringInvoiceLinesFromDB(schemaLevel int, dbLines []*db.BillingInvoiceLine) ([]billing.GatheringLine, error) { + return slicesx.MapWithErr(dbLines, func(dbLine *db.BillingInvoiceLine) (billing.GatheringLine, error) { + return a.mapGatheringInvoiceLineFromDB(schemaLevel, dbLine) + }) +} + +func (a *adapter) mapGatheringInvoiceLineFromDB(schemaLevel int, dbLine *db.BillingInvoiceLine) (billing.GatheringLine, error) { + if dbLine.Type != billing.InvoiceLineTypeUsageBased { + return billing.GatheringLine{}, fmt.Errorf("only usage based lines can be gathering invoice lines [line_id=%s]", dbLine.ID) + } + + ubpLine := dbLine.Edges.UsageBasedLine + if ubpLine == nil { + return billing.GatheringLine{}, fmt.Errorf("usage based line data is missing [line_id=%s]", dbLine.ID) + } + + line := billing.GatheringLine{ + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: dbLine.Namespace, + ID: dbLine.ID, + CreatedAt: dbLine.CreatedAt.In(time.UTC), + UpdatedAt: dbLine.UpdatedAt.In(time.UTC), + DeletedAt: convert.TimePtrIn(dbLine.DeletedAt, time.UTC), + Name: dbLine.Name, + Description: dbLine.Description, + }), + + Metadata: dbLine.Metadata, + Annotations: dbLine.Annotations, + InvoiceID: dbLine.InvoiceID, + ManagedBy: dbLine.ManagedBy, + + ServicePeriod: timeutil.ClosedPeriod{ + From: dbLine.PeriodStart.In(time.UTC), + To: dbLine.PeriodEnd.In(time.UTC), + }, + + SplitLineGroupID: dbLine.SplitLineGroupID, + ChildUniqueReferenceID: dbLine.ChildUniqueReferenceID, + + InvoiceAt: dbLine.InvoiceAt.In(time.UTC), + + Currency: dbLine.Currency, + + TaxConfig: lo.EmptyableToPtr(dbLine.TaxConfig), + RateCardDiscounts: lo.FromPtr(dbLine.RatecardDiscounts), + + UBPConfigID: ubpLine.ID, + FeatureKey: lo.FromPtr(ubpLine.FeatureKey), + Price: lo.FromPtr(ubpLine.Price), + }, + } + + if dbLine.SubscriptionID != nil && dbLine.SubscriptionPhaseID != nil && dbLine.SubscriptionItemID != nil { + line.Subscription = &billing.SubscriptionReference{ + SubscriptionID: *dbLine.SubscriptionID, + PhaseID: *dbLine.SubscriptionPhaseID, + ItemID: *dbLine.SubscriptionItemID, + } + if dbLine.SubscriptionBillingPeriodFrom != nil && + dbLine.SubscriptionBillingPeriodTo != nil { + line.Subscription.BillingPeriod = timeutil.ClosedPeriod{ + From: dbLine.SubscriptionBillingPeriodFrom.In(time.UTC), + To: dbLine.SubscriptionBillingPeriodTo.In(time.UTC), + } + } + } + + cloned, err := line.WithoutDBState() + if err != nil { + return billing.GatheringLine{}, fmt.Errorf("cloning line: %w", err) + } + + line.DBState = lo.ToPtr(cloned) + + return line, nil +} diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index 7fe5da3c8e..05a99dd206 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -289,11 +289,6 @@ func (a *adapter) CreateInvoice(ctx context.Context, input billing.CreateInvoice return billing.CreateInvoiceAdapterRespone{}, fmt.Errorf("clone workflow config: %w", err) } - workflowConfig := mapWorkflowConfigToDB(input.Profile.WorkflowConfig, clonedWorkflowConfig.ID) - - // Force cloning of the workflow - workflowConfig.ID = "" - currentSchemaLevel, err := tx.GetInvoiceDefaultSchemaLevel(ctx) if err != nil { return billing.CreateInvoiceAdapterRespone{}, fmt.Errorf("get invoice write schema level: %w", err) diff --git a/openmeter/billing/adapter/profile.go b/openmeter/billing/adapter/profile.go index ac6956c911..bc2b0a937d 100644 --- a/openmeter/billing/adapter/profile.go +++ b/openmeter/billing/adapter/profile.go @@ -431,22 +431,6 @@ func mapProfileFromDB(dbProfile *db.BillingProfile) (*billing.AdapterGetProfileR }, nil } -func mapWorkflowConfigToDB(wc billing.WorkflowConfig, id string) *db.BillingWorkflowConfig { - return &db.BillingWorkflowConfig{ - ID: id, - - CollectionAlignment: wc.Collection.Alignment, - AnchoredAlignmentDetail: wc.Collection.AnchoredAlignmentDetail, - LineCollectionPeriod: wc.Collection.Interval.ISOString(), - InvoiceAutoAdvance: wc.Invoicing.AutoAdvance, - InvoiceDraftPeriod: wc.Invoicing.DraftPeriod.ISOString(), - InvoiceDueAfter: wc.Invoicing.DueAfter.ISOString(), - InvoiceCollectionMethod: wc.Payment.CollectionMethod, - TaxEnabled: wc.Tax.Enabled, - TaxEnforced: wc.Tax.Enforced, - } -} - func mapWorkflowConfigFromDB(dbWC *db.BillingWorkflowConfig) (billing.WorkflowConfig, error) { collectionInterval, err := dbWC.LineCollectionPeriod.Parse() if err != nil { diff --git a/openmeter/billing/derived.gen.go b/openmeter/billing/derived.gen.go index 7229322b79..790be53f89 100644 --- a/openmeter/billing/derived.gen.go +++ b/openmeter/billing/derived.gen.go @@ -4,7 +4,6 @@ package billing import ( models "github.com/openmeterio/openmeter/pkg/models" - timeutil "github.com/openmeterio/openmeter/pkg/timeutil" ) // deriveEqualDetailedLineBase returns whether this and that are equal. @@ -76,7 +75,7 @@ func deriveEqualLineBase(this, that *StandardLineBase) bool { return (this == nil && that == nil) || this != nil && that != nil && deriveEqual(&this.ManagedResource, &that.ManagedResource) && - deriveEqual_2(this.Metadata, that.Metadata) && + this.Metadata.Equal(that.Metadata) && this.Annotations.Equal(that.Annotations) && this.ManagedBy == that.ManagedBy && this.InvoiceID == that.InvoiceID && @@ -87,9 +86,9 @@ func deriveEqualLineBase(this, that *StandardLineBase) bool { ((this.SplitLineGroupID == nil && that.SplitLineGroupID == nil) || (this.SplitLineGroupID != nil && that.SplitLineGroupID != nil && *(this.SplitLineGroupID) == *(that.SplitLineGroupID))) && ((this.ChildUniqueReferenceID == nil && that.ChildUniqueReferenceID == nil) || (this.ChildUniqueReferenceID != nil && that.ChildUniqueReferenceID != nil && *(this.ChildUniqueReferenceID) == *(that.ChildUniqueReferenceID))) && this.TaxConfig.Equal(that.TaxConfig) && - deriveEqual_3(&this.RateCardDiscounts, &that.RateCardDiscounts) && + deriveEqual_2(&this.RateCardDiscounts, &that.RateCardDiscounts) && this.ExternalIDs.Equal(that.ExternalIDs) && - deriveEqual_4(this.Subscription, that.Subscription) && + deriveEqual_3(this.Subscription, that.Subscription) && deriveEqual_(&this.Totals, &that.Totals) } @@ -140,47 +139,19 @@ func deriveEqual_1(this, that *DiscountReason) bool { } // deriveEqual_2 returns whether this and that are equal. -func deriveEqual_2(this, that map[string]string) bool { - if this == nil || that == nil { - return this == nil && that == nil - } - if len(this) != len(that) { - return false - } - for k, v := range this { - thatv, ok := that[k] - if !ok { - return false - } - if !(v == thatv) { - return false - } - } - return true -} - -// deriveEqual_3 returns whether this and that are equal. -func deriveEqual_3(this, that *Discounts) bool { +func deriveEqual_2(this, that *Discounts) bool { return (this == nil && that == nil) || this != nil && that != nil && ((this.Percentage == nil && that.Percentage == nil) || (this.Percentage != nil && that.Percentage != nil && (*(this.Percentage)).Equal(*(that.Percentage)))) && ((this.Usage == nil && that.Usage == nil) || (this.Usage != nil && that.Usage != nil && (*(this.Usage)).Equal(*(that.Usage)))) } -// deriveEqual_4 returns whether this and that are equal. -func deriveEqual_4(this, that *SubscriptionReference) bool { +// deriveEqual_3 returns whether this and that are equal. +func deriveEqual_3(this, that *SubscriptionReference) bool { return (this == nil && that == nil) || this != nil && that != nil && this.SubscriptionID == that.SubscriptionID && this.PhaseID == that.PhaseID && this.ItemID == that.ItemID && - deriveEqual_5(&this.BillingPeriod, &that.BillingPeriod) -} - -// deriveEqual_5 returns whether this and that are equal. -func deriveEqual_5(this, that *timeutil.ClosedPeriod) bool { - return (this == nil && that == nil) || - this != nil && that != nil && - this.From.Equal(that.From) && - this.To.Equal(that.To) + this.BillingPeriod.Equal(that.BillingPeriod) } diff --git a/openmeter/billing/eventsgathering.go b/openmeter/billing/eventsgathering.go new file mode 100644 index 0000000000..1bb1c09568 --- /dev/null +++ b/openmeter/billing/eventsgathering.go @@ -0,0 +1,30 @@ +package billing + +import "github.com/openmeterio/openmeter/openmeter/event/metadata" + +type GatheringInvoiceCreatedEvent struct { + Invoice GatheringInvoice `json:"gatheringInvoice"` +} + +func (e GatheringInvoiceCreatedEvent) Validate() error { + return e.Invoice.Validate() +} + +func NewGatheringInvoiceCreatedEvent(invoice GatheringInvoice) GatheringInvoiceCreatedEvent { + return GatheringInvoiceCreatedEvent{Invoice: invoice} +} + +func (e GatheringInvoiceCreatedEvent) EventName() string { + return metadata.GetEventName(metadata.EventType{ + Subsystem: EventSubsystem, + Name: "gathering.invoice.created", + Version: "v1", + }) +} + +func (e GatheringInvoiceCreatedEvent) EventMetadata() metadata.EventMetadata { + return metadata.EventMetadata{ + Source: metadata.ComposeResourcePath(e.Invoice.Namespace, metadata.EntityGatheringInvoice, e.Invoice.ID), + Subject: metadata.ComposeResourcePath(e.Invoice.Namespace, metadata.EntityCustomer, e.Invoice.CustomerID), + } +} diff --git a/openmeter/billing/gatheringinvoice.go b/openmeter/billing/gatheringinvoice.go index f2444216dd..6947599f90 100644 --- a/openmeter/billing/gatheringinvoice.go +++ b/openmeter/billing/gatheringinvoice.go @@ -3,16 +3,421 @@ package billing import ( "errors" "fmt" + "slices" + "time" + "github.com/samber/lo" + "github.com/samber/mo" + + "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/sortx" + timeutil "github.com/openmeterio/openmeter/pkg/timeutil" +) + +type GatheringInvoiceBase struct { + models.ManagedResource + + Metadata models.Metadata `json:"metadata"` + + Number string `json:"number"` + CustomerID string `json:"customerID"` + Currency currencyx.Code `json:"currency"` + ServicePeriod timeutil.ClosedPeriod `json:"servicePeriod"` + + NextCollectionAt time.Time `json:"nextCollectionAt"` + + SchemaLevel int `json:"schemaLevel"` +} + +func (g GatheringInvoiceBase) Validate() error { + var errs []error + + if g.Name == "" { + errs = append(errs, errors.New("name is required")) + } + + if g.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + + if err := g.Currency.Validate(); err != nil { + errs = append(errs, err) + } + + if err := g.ServicePeriod.Validate(); err != nil { + errs = append(errs, err) + } + + if g.SchemaLevel == 0 { + errs = append(errs, errors.New("schema level is required")) + } + + return errors.Join(errs...) +} + +type GatheringInvoice struct { + GatheringInvoiceBase `json:",inline"` + + // Entities external to the invoice entity + Lines GatheringInvoiceLines `json:"lines,omitempty"` + + // TODO[later]: implement this once we have a lineservice capable of operating on + // these lines too. + AvailableActions *GatheringInvoiceAvailableActions `json:"availableActions,omitempty"` +} + +func (g GatheringInvoice) WithoutDBState() (GatheringInvoice, error) { + clone, err := g.Clone() + if err != nil { + return GatheringInvoice{}, fmt.Errorf("cloning invoice: %w", err) + } + + clone.Lines, err = clone.Lines.MapWithErr(func(l GatheringLine) (GatheringLine, error) { + return l.WithoutDBState() + }) + if err != nil { + return GatheringInvoice{}, fmt.Errorf("cloning lines: %w", err) + } + + return clone, nil +} + +func (g GatheringInvoice) InvoiceID() InvoiceID { + return InvoiceID{ + Namespace: g.Namespace, + ID: g.ID, + } +} + +func (g GatheringInvoice) Validate() error { + var errs []error + + if err := g.GatheringInvoiceBase.Validate(); err != nil { + errs = append(errs, err) + } + + if err := g.Lines.Validate(); err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) +} + +func (g *GatheringInvoice) SortLines() { + if !g.Lines.IsPresent() { + return + } + + g.Lines.Sort() +} + +func (g GatheringInvoice) Clone() (GatheringInvoice, error) { + clone := g + + clone.Metadata = g.Metadata.Clone() + + clonedLines, err := clone.Lines.MapWithErr(func(l GatheringLine) (GatheringLine, error) { + return l.Clone() + }) + if err != nil { + return GatheringInvoice{}, fmt.Errorf("cloning lines: %w", err) + } + + clone.Lines = clonedLines + + return clone, nil +} + +type GatheringInvoiceExpand string + +func (e GatheringInvoiceExpand) Validate() error { + if slices.Contains(GatheringInvoiceExpandValues, e) { + return nil + } + + return fmt.Errorf("invalid gathering invoice expand: %s", e) +} + +const ( + GatheringInvoiceExpandLines GatheringInvoiceExpand = "lines" + GatheringInvoiceExpandAvailableActions GatheringInvoiceExpand = "availableActions" ) +var GatheringInvoiceExpandValues = []GatheringInvoiceExpand{ + GatheringInvoiceExpandLines, + GatheringInvoiceExpandAvailableActions, +} + +type GatheringInvoiceExpands []GatheringInvoiceExpand + +func (e GatheringInvoiceExpands) Validate() error { + for _, expand := range e { + if err := expand.Validate(); err != nil { + return err + } + } + return nil +} + +func (e GatheringInvoiceExpands) Has(expand GatheringInvoiceExpand) bool { + return slices.Contains(e, expand) +} + +func (e GatheringInvoiceExpands) With(expand GatheringInvoiceExpand) GatheringInvoiceExpands { + return append(e, expand) +} + +type GatheringInvoiceAvailableActions struct { + CanBeInvoiced bool `json:"canBeInvoiced"` +} + +type GatheringLines []GatheringLine + +type GatheringInvoiceLines struct { + mo.Option[GatheringLines] +} + +func (l GatheringInvoiceLines) Validate() error { + if l.IsAbsent() { + return nil + } + + return errors.Join( + lo.Map(l.OrEmpty(), func(l GatheringLine, _ int) error { + err := l.Validate() + if err != nil { + return fmt.Errorf("line[%s]: %w", l.ID, err) + } + return nil + })..., + ) +} + +func (l *GatheringInvoiceLines) Sort() { + if l.IsAbsent() { + return + } + + lines := l.OrEmpty() + slices.SortFunc(lines, func(a, b GatheringLine) int { + return a.CreatedAt.Compare(b.CreatedAt) + }) + + l.Option = mo.Some(lines) +} + +func (l GatheringInvoiceLines) NonDeletedLineCount() int { + return lo.CountBy(l.OrEmpty(), func(l GatheringLine) bool { + return l.DeletedAt == nil + }) +} + +func (l GatheringInvoiceLines) Map(fn func(GatheringLine) GatheringLine) GatheringInvoiceLines { + res, _ := l.MapWithErr(func(gl GatheringLine) (GatheringLine, error) { + return fn(gl), nil + }) + + return res +} + +func (l GatheringInvoiceLines) MapWithErr(fn func(GatheringLine) (GatheringLine, error)) (GatheringInvoiceLines, error) { + if l.IsAbsent() { + return l, nil + } + + out, err := slicesx.MapWithErr(l.OrEmpty(), fn) + if err != nil { + return l, err + } + + return GatheringInvoiceLines{ + Option: mo.Some(GatheringLines(out)), + }, nil +} + +func (l *GatheringInvoiceLines) Append(lines ...GatheringLine) { + l.Option = mo.Some(append(l.OrEmpty(), lines...)) +} + +func NewGatheringInvoiceLines(children []GatheringLine) GatheringInvoiceLines { + return GatheringInvoiceLines{ + Option: mo.Some(GatheringLines(children)), + } +} + +type GatheringLineBase struct { + models.ManagedResource + + Metadata models.Metadata `json:"metadata"` + Annotations models.Annotations `json:"annotations"` + ManagedBy InvoiceLineManagedBy `json:"managedBy"` + InvoiceID string `json:"invoiceID"` + + Currency currencyx.Code `json:"currency"` + ServicePeriod timeutil.ClosedPeriod `json:"period"` + InvoiceAt time.Time `json:"invoiceAt"` + Price productcatalog.Price `json:"price"` + FeatureKey string `json:"featureKey"` + + TaxConfig *productcatalog.TaxConfig `json:"taxOverrides,omitempty"` + RateCardDiscounts Discounts `json:"rateCardDiscounts,omitempty"` + + ChildUniqueReferenceID *string `json:"childUniqueReferenceID,omitempty"` + Subscription *SubscriptionReference `json:"subscription,omitempty"` + SplitLineGroupID *string `json:"splitLineGroupID,omitempty"` + + // TODO: Remove once we have dedicated db field for gathering invoice lines + UBPConfigID string `json:"ubpConfigID"` +} + +func (i GatheringLineBase) Validate() error { + var errs []error + + if i.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + + if i.Name == "" { + errs = append(errs, errors.New("name is required")) + } + + if err := i.ServicePeriod.Validate(); err != nil { + errs = append(errs, fmt.Errorf("service period: %w", err)) + } + + if i.InvoiceAt.IsZero() { + errs = append(errs, errors.New("invoice at is required")) + } + + if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } + + if !slices.Contains(InvoiceLineManagedBy("").Values(), string(i.ManagedBy)) { + errs = append(errs, fmt.Errorf("invalid managed by %s", i.ManagedBy)) + } + + if i.Subscription != nil { + if err := i.Subscription.Validate(); err != nil { + errs = append(errs, fmt.Errorf("subscription: %w", err)) + } + } + + if i.TaxConfig != nil { + if err := i.TaxConfig.Validate(); err != nil { + errs = append(errs, fmt.Errorf("tax config: %w", err)) + } + } + + if err := i.Price.Validate(); err != nil { + errs = append(errs, fmt.Errorf("price: %w", err)) + } + + if err := i.RateCardDiscounts.ValidateForPrice(&i.Price); err != nil { + errs = append(errs, fmt.Errorf("rate card discounts: %w", err)) + } + + if i.ChildUniqueReferenceID != nil && *i.ChildUniqueReferenceID == "" { + errs = append(errs, errors.New("child unique reference id is required")) + } + + if i.Price.Type() != productcatalog.FlatPriceType && i.FeatureKey == "" { + errs = append(errs, errors.New("feature key is required for non-flat prices")) + } + + return errors.Join(errs...) +} + +func (i *GatheringLineBase) NormalizeValues() error { + i.ServicePeriod = i.ServicePeriod.Truncate(streaming.MinimumWindowSizeDuration) + i.InvoiceAt = i.InvoiceAt.Truncate(streaming.MinimumWindowSizeDuration) + + if err := setDefaultPaymentTermForFlatPrice(&i.Price); err != nil { + return fmt.Errorf("setting default payment term for flat price: %w", err) + } + + return nil +} + +func (i GatheringLineBase) Clone() (GatheringLineBase, error) { + var err error + + out := i + + out.Annotations, err = i.Annotations.Clone() + if err != nil { + return GatheringLineBase{}, fmt.Errorf("cloning annotations: %w", err) + } + + out.Metadata = i.Metadata.Clone() + + if i.TaxConfig != nil { + out.TaxConfig = &productcatalog.TaxConfig{} + *out.TaxConfig = *i.TaxConfig + } + + if i.Subscription != nil { + out.Subscription = &SubscriptionReference{} + *out.Subscription = *i.Subscription + } + + return out, nil +} + +// TODO: rename to GatheringLine +type GatheringLine struct { + GatheringLineBase `json:",inline"` + + DBState *GatheringLine `json:"-"` +} + +func (g GatheringLine) Clone() (GatheringLine, error) { + base, err := g.GatheringLineBase.Clone() + if err != nil { + return GatheringLine{}, fmt.Errorf("cloning line base: %w", err) + } + + return GatheringLine{ + GatheringLineBase: base, + DBState: g.DBState, + }, nil +} + +func (g GatheringLine) WithoutDBState() (GatheringLine, error) { + clone, err := g.Clone() + if err != nil { + return GatheringLine{}, fmt.Errorf("cloning line: %w", err) + } + + clone.DBState = nil + return clone, nil +} + +func (g GatheringLine) WithNormalizedValues() (GatheringLine, error) { + clone, err := g.Clone() + if err != nil { + return GatheringLine{}, fmt.Errorf("cloning line: %w", err) + } + + if err := clone.GatheringLineBase.NormalizeValues(); err != nil { + return GatheringLine{}, fmt.Errorf("normalizing line values: %w", err) + } + + return clone, nil +} + type CreatePendingInvoiceLinesInput struct { Customer customer.CustomerID `json:"customer"` Currency currencyx.Code `json:"currency"` - Lines []*StandardLine `json:"lines"` + Lines []GatheringLine `json:"lines"` } func (c CreatePendingInvoiceLinesInput) Validate() error { @@ -38,14 +443,6 @@ func (c CreatePendingInvoiceLinesInput) Validate() error { errs = append(errs, fmt.Errorf("line.%d: invoice ID is not allowed for pending lines", id)) } - if len(line.DetailedLines) > 0 { - errs = append(errs, fmt.Errorf("line.%d: detailed lines are not allowed for pending lines", id)) - } - - if line.ParentLineID != nil { - errs = append(errs, fmt.Errorf("line.%d: parent line ID is not allowed for pending lines", id)) - } - if line.SplitLineGroupID != nil { errs = append(errs, fmt.Errorf("line.%d: split line group ID is not allowed for pending lines", id)) } @@ -55,7 +452,145 @@ func (c CreatePendingInvoiceLinesInput) Validate() error { } type CreatePendingInvoiceLinesResult struct { - Lines []*StandardLine - Invoice StandardInvoice + Lines []GatheringLine + Invoice GatheringInvoice IsInvoiceNew bool } + +type CreateGatheringInvoiceAdapterInput struct { + Namespace string + Number string + Currency currencyx.Code + Metadata map[string]string + + Description *string + NextCollectionAt *time.Time + + // TODO[later]: This should be just a CustomerID once we have split the invoices table + Customer customer.Customer + MergedProfile Profile +} + +func (c CreateGatheringInvoiceAdapterInput) Validate() error { + var errs []error + + if c.Namespace == "" { + errs = append(errs, errors.New("namespace is required")) + } + + if err := c.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } + + if c.Number == "" { + errs = append(errs, errors.New("number is required")) + } + + if err := c.Customer.Validate(); err != nil { + errs = append(errs, fmt.Errorf("customer: %w", err)) + } + + if err := c.MergedProfile.Validate(); err != nil { + errs = append(errs, fmt.Errorf("merged profile: %w", err)) + } + + return errors.Join(errs...) +} + +type DeleteGatheringInvoiceAdapterInput = InvoiceID + +type UpdateGatheringInvoiceAdapterInput = GatheringInvoice + +type ListGatheringInvoicesInput struct { + pagination.Page + + Namespaces []string + Customers []string + Currencies []currencyx.Code + OrderBy api.InvoiceOrderBy + Order sortx.Order + IncludeDeleted bool + Expand GatheringInvoiceExpands +} + +func (i ListGatheringInvoicesInput) Validate() error { + var errs []error + + if err := i.Page.Validate(); err != nil { + errs = append(errs, fmt.Errorf("page: %w", err)) + } + + if len(i.Namespaces) == 0 { + errs = append(errs, errors.New("namespaces is required")) + } + + for _, expand := range i.Expand { + if err := expand.Validate(); err != nil { + errs = append(errs, fmt.Errorf("expand: %w", err)) + } + } + + return errors.Join(errs...) +} + +func NewFlatFeeGatheringLine(input NewFlatFeeLineInput, opts ...usageBasedLineOption) GatheringLine { + ubpOptions := usageBasedLineOptions{} + + for _, opt := range opts { + opt(&ubpOptions) + } + + return GatheringLine{ + GatheringLineBase: GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: input.Namespace, + ID: input.ID, + CreatedAt: input.CreatedAt, + UpdatedAt: input.UpdatedAt, + Name: input.Name, + Description: input.Description, + }), + ServicePeriod: timeutil.ClosedPeriod{ + From: input.Period.Start, + To: input.Period.End, + }, + InvoiceAt: input.InvoiceAt, + InvoiceID: input.InvoiceID, + + Metadata: input.Metadata, + Annotations: input.Annotations, + + ManagedBy: lo.CoalesceOrEmpty(input.ManagedBy, SystemManagedLine), + + Currency: input.Currency, + RateCardDiscounts: input.RateCardDiscounts, + Price: lo.FromPtr( + productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: input.PerUnitAmount, + PaymentTerm: input.PaymentTerm, + }), + ), + FeatureKey: ubpOptions.featureKey, + }, + } +} + +type GetGatheringInvoiceByIdInput struct { + Invoice InvoiceID + Expand GatheringInvoiceExpands +} + +func (i GetGatheringInvoiceByIdInput) Validate() error { + var errs []error + + if err := i.Invoice.Validate(); err != nil { + errs = append(errs, fmt.Errorf("invoice: %w", err)) + } + + for _, expand := range i.Expand { + if err := expand.Validate(); err != nil { + errs = append(errs, fmt.Errorf("expand: %w", err)) + } + } + return errors.Join(errs...) +} diff --git a/openmeter/billing/httpdriver/gatheringinvoice.go b/openmeter/billing/httpdriver/gatheringinvoice.go new file mode 100644 index 0000000000..0b8ef4bd8b --- /dev/null +++ b/openmeter/billing/httpdriver/gatheringinvoice.go @@ -0,0 +1,151 @@ +package httpdriver + +import ( + "fmt" + + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/api" + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/customer" + productcataloghttp "github.com/openmeterio/openmeter/openmeter/productcatalog/http" + "github.com/openmeterio/openmeter/pkg/convert" + "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +func MapGatheringInvoiceToAPI(invoice billing.GatheringInvoice, customer *customer.Customer, profile billing.Profile) (api.Invoice, error) { + var err error + + invoiceCustomer := billing.InvoiceCustomer{ + Key: customer.Key, + Name: customer.Name, + } + + statusDetails := api.InvoiceStatusDetails{ + Failed: false, + Immutable: false, + ExtendedStatus: string(billing.StandardInvoiceStatusGathering), + } + + if invoice.AvailableActions != nil && invoice.AvailableActions.CanBeInvoiced { + statusDetails.AvailableActions = api.InvoiceAvailableActions{ + Invoice: &api.InvoiceAvailableActionInvoiceDetails{}, + } + } + + // Sort the lines to make the response more consistent (internally we don't care about the order) + invoice.SortLines() + + out := api.Invoice{ + Id: invoice.ID, + + CreatedAt: invoice.CreatedAt, + UpdatedAt: invoice.UpdatedAt, + DeletedAt: invoice.DeletedAt, + CollectionAt: lo.ToPtr(invoice.NextCollectionAt), + Period: mapServicePeriodToAPI(invoice.ServicePeriod), + + Currency: string(invoice.Currency), + Customer: mapInvoiceCustomerToAPI(invoiceCustomer), + + Number: invoice.Number, + Description: invoice.Description, + Metadata: convert.MapToPointer(invoice.Metadata), + + Status: api.InvoiceStatus(billing.StandardInvoiceStatusGathering), + StatusDetails: statusDetails, + Supplier: api.BillingParty{}, + Totals: api.InvoiceTotals{}, + Type: api.InvoiceType(billing.StandardInvoiceStatusCategoryGathering), + } + + workflowConfig, err := mapWorkflowConfigSettingsToAPI(profile.WorkflowConfig) + if err != nil { + return api.Invoice{}, fmt.Errorf("failed to map workflow config to API: %w", err) + } + + out.Workflow = api.InvoiceWorkflowSettings{ + SourceBillingProfileId: profile.ID, + Workflow: workflowConfig, + } + + outLines, err := slicesx.MapWithErr(invoice.Lines.OrEmpty(), func(line billing.GatheringLine) (api.InvoiceLine, error) { + mappedLine, err := mapGatheringInvoiceLineToAPI(line) + if err != nil { + return api.InvoiceLine{}, fmt.Errorf("failed to map billing line[%s] to API: %w", line.ID, err) + } + + return mappedLine, nil + }) + if err != nil { + return api.Invoice{}, err + } + + if len(outLines) > 0 { + out.Lines = &outLines + } + + return out, nil +} + +func mapServicePeriodToAPI(p timeutil.ClosedPeriod) *api.Period { + if lo.IsEmpty(p) { + return nil + } + + return &api.Period{ + From: p.From, + To: p.To, + } +} + +func mapGatheringInvoiceLineToAPI(line billing.GatheringLine) (api.InvoiceLine, error) { + price, err := productcataloghttp.FromRateCardUsageBasedPrice(line.Price) + if err != nil { + return api.InvoiceLine{}, fmt.Errorf("failed to map price: %w", err) + } + + invoiceLine := api.InvoiceLine{ + Type: api.InvoiceLineTypeUsageBased, + Id: line.ID, + + CreatedAt: line.CreatedAt, + DeletedAt: line.DeletedAt, + UpdatedAt: line.UpdatedAt, + InvoiceAt: line.InvoiceAt, + + Currency: string(line.Currency), + Status: api.InvoiceLineStatusValid, + + Description: line.Description, + Name: line.Name, + ManagedBy: api.InvoiceLineManagedBy(line.ManagedBy), + + Invoice: &api.InvoiceReference{ + Id: line.InvoiceID, + }, + + Metadata: convert.MapToPointer(line.Metadata), + Period: api.Period{ + From: line.ServicePeriod.From, + To: line.ServicePeriod.To, + }, + + TaxConfig: mapTaxConfigToAPI(line.TaxConfig), + + FeatureKey: lo.EmptyableToPtr(line.FeatureKey), + + Price: lo.ToPtr(price), + + RateCard: &api.InvoiceUsageBasedRateCard{ + TaxConfig: mapTaxConfigToAPI(line.TaxConfig), + Price: lo.ToPtr(price), + FeatureKey: lo.EmptyableToPtr(line.FeatureKey), + }, + + Subscription: mapSubscriptionReferencesToAPI(line.Subscription), + } + + return invoiceLine, nil +} diff --git a/openmeter/billing/httpdriver/invoice_test.go b/openmeter/billing/httpdriver/invoice_test.go index 34b6e67bb1..d6163cbe3a 100644 --- a/openmeter/billing/httpdriver/invoice_test.go +++ b/openmeter/billing/httpdriver/invoice_test.go @@ -55,8 +55,8 @@ func (s *InvoicingTestSuite) TestGatheringInvoiceSerialization() { res, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: cust.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine( + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine( billing.NewFlatFeeLineInput{ Namespace: namespace, Period: billing.Period{Start: now, End: now.Add(time.Hour * 24)}, diff --git a/openmeter/billing/httpdriver/invoiceline.go b/openmeter/billing/httpdriver/invoiceline.go index 320bd5307f..b61b9d1923 100644 --- a/openmeter/billing/httpdriver/invoiceline.go +++ b/openmeter/billing/httpdriver/invoiceline.go @@ -24,6 +24,7 @@ import ( "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/set" "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/timeutil" ) var _ InvoiceLineHandler = (*handler)(nil) @@ -64,8 +65,8 @@ func (h *handler) CreatePendingLine() CreatePendingLineHandler { } } - lineEntities, err := slicesx.MapWithErr(req.Lines, func(line api.InvoicePendingLineCreate) (*billing.StandardLine, error) { - return mapCreateLineToEntity(line, ns) + lineEntities, err := slicesx.MapWithErr(req.Lines, func(line api.InvoicePendingLineCreate) (billing.GatheringLine, error) { + return mapCreateGatheringLineToEntity(line, ns) }) if err != nil { return CreatePendingLineRequest{}, billing.ValidationError{ @@ -88,16 +89,32 @@ func (h *handler) CreatePendingLine() CreatePendingLineHandler { return CreatePendingLineResponse{}, fmt.Errorf("failed to create invoice lines: %w", err) } + if res == nil { + return CreatePendingLineResponse{}, fmt.Errorf("create pending invoice lines result is nil") + } + out := CreatePendingLineResponse{ IsInvoiceNew: res.IsInvoiceNew, } - out.Invoice, err = MapInvoiceToAPI(res.Invoice) + // TODO: For the V3 api let's not return the invoice + mergedProfile, err := h.service.GetCustomerOverride(ctx, billing.GetCustomerOverrideInput{ + Customer: request.Customer, + Expand: billing.CustomerOverrideExpand{ + Customer: true, + Apps: true, + }, + }) + if err != nil { + return CreatePendingLineResponse{}, fmt.Errorf("failed to get customer override: %w", err) + } + + out.Invoice, err = MapGatheringInvoiceToAPI(res.Invoice, mergedProfile.Customer, mergedProfile.MergedProfile) if err != nil { return CreatePendingLineResponse{}, fmt.Errorf("failed to map invoice: %w", err) } - out.Lines, err = slicesx.MapWithErr(res.Lines, mapInvoiceLineToAPI) + out.Lines, err = slicesx.MapWithErr(res.Lines, mapGatheringInvoiceLineToAPI) if err != nil { return CreatePendingLineResponse{}, fmt.Errorf("failed to map lines: %w", err) } @@ -151,6 +168,46 @@ func mapCreateLineToEntity(line api.InvoicePendingLineCreate, ns string) (*billi }, nil } +func mapCreateGatheringLineToEntity(line api.InvoicePendingLineCreate, ns string) (billing.GatheringLine, error) { + rateCardParsed, err := mapAndValidateInvoiceLineRateCardDeprecatedFields(invoiceLineRateCardItems{ + RateCard: line.RateCard, + Price: line.Price, + TaxConfig: line.TaxConfig, + FeatureKey: line.FeatureKey, + }) + if err != nil { + return billing.GatheringLine{}, fmt.Errorf("failed to map usage based line: %w", err) + } + + if rateCardParsed.Price == nil { + return billing.GatheringLine{}, fmt.Errorf("price is nil [line=%s]", line.Name) + } + + return billing.GatheringLine{ + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: ns, + Name: line.Name, + Description: line.Description, + }), + + Metadata: lo.FromPtrOr(line.Metadata, map[string]string{}), + ManagedBy: billing.ManuallyManagedLine, + + ServicePeriod: timeutil.ClosedPeriod{ + From: line.Period.From, + To: line.Period.To, + }, + + InvoiceAt: line.InvoiceAt, + TaxConfig: rateCardParsed.TaxConfig, + RateCardDiscounts: rateCardParsed.Discounts, + Price: lo.FromPtr(rateCardParsed.Price), + FeatureKey: rateCardParsed.FeatureKey, + }, + }, nil +} + func mapTaxConfigToEntity(tc *api.TaxConfig) *productcatalog.TaxConfig { if tc == nil { return nil @@ -636,7 +693,7 @@ func mergeLineFromInvoiceLineReplaceUpdate(existing *billing.StandardLine, line } } - existing.Metadata = lo.FromPtrOr(line.Metadata, existing.Metadata) + existing.Metadata = lo.FromPtrOr(line.Metadata, api.Metadata(existing.Metadata)) existing.Name = line.Name existing.Description = line.Description diff --git a/openmeter/billing/invoiceline.go b/openmeter/billing/invoiceline.go index f8535cca5c..9ae4cc14c8 100644 --- a/openmeter/billing/invoiceline.go +++ b/openmeter/billing/invoiceline.go @@ -5,6 +5,7 @@ import ( "time" "github.com/openmeterio/openmeter/pkg/models" + timeutil "github.com/openmeterio/openmeter/pkg/timeutil" ) type LineID models.NamespacedID @@ -79,6 +80,13 @@ func (p Period) Duration() time.Duration { return p.End.Sub(p.Start) } +func (p Period) ToClosedPeriod() timeutil.ClosedPeriod { + return timeutil.ClosedPeriod{ + From: p.Start, + To: p.End, + } +} + type GetLinesForSubscriptionInput struct { Namespace string SubscriptionID string diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index c3ec521c4f..c1746bea20 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -5,6 +5,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/app" "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/pkg/pagination" ) type Service interface { @@ -13,6 +14,7 @@ type Service interface { InvoiceLineService SplitLineGroupService InvoiceService + GatheringInvoiceService SequenceService LockableService @@ -43,8 +45,6 @@ type CustomerOverrideService interface { } type InvoiceLineService interface { - // CreatePendingInvoiceLines creates pending invoice lines for a customer, if the lines are zero valued, the response is nil - CreatePendingInvoiceLines(ctx context.Context, input CreatePendingInvoiceLinesInput) (*CreatePendingInvoiceLinesResult, error) GetLinesForSubscription(ctx context.Context, input GetLinesForSubscriptionInput) ([]LineOrHierarchy, error) // SnapshotLineQuantity returns an updated line with the quantity snapshoted from meters // the invoice is used as contextual information to the call. @@ -84,6 +84,13 @@ type InvoiceService interface { RecalculateGatheringInvoices(ctx context.Context, input RecalculateGatheringInvoicesInput) error } +type GatheringInvoiceService interface { + // CreatePendingInvoiceLines creates pending invoice lines for a customer, if the lines are zero valued, the response is nil + CreatePendingInvoiceLines(ctx context.Context, input CreatePendingInvoiceLinesInput) (*CreatePendingInvoiceLinesResult, error) + + ListGatheringInvoices(ctx context.Context, input ListGatheringInvoicesInput) (pagination.Result[GatheringInvoice], error) +} + type SequenceService interface { GenerateInvoiceSequenceNumber(ctx context.Context, in SequenceGenerationInput, def SequenceDefinition) (string, error) } diff --git a/openmeter/billing/service/gatheringinvoice.go b/openmeter/billing/service/gatheringinvoice.go new file mode 100644 index 0000000000..7bbdf6a697 --- /dev/null +++ b/openmeter/billing/service/gatheringinvoice.go @@ -0,0 +1,21 @@ +package billingservice + +import ( + "context" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/pkg/framework/transaction" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +var _ billing.GatheringInvoiceService = (*Service)(nil) + +func (s *Service) ListGatheringInvoices(ctx context.Context, input billing.ListGatheringInvoicesInput) (pagination.Result[billing.GatheringInvoice], error) { + if err := input.Validate(); err != nil { + return pagination.Result[billing.GatheringInvoice]{}, err + } + + return transaction.Run(ctx, s.adapter, func(ctx context.Context) (pagination.Result[billing.GatheringInvoice], error) { + return s.adapter.ListGatheringInvoices(ctx, input) + }) +} diff --git a/openmeter/billing/service/gatheringinvoicependinglines.go b/openmeter/billing/service/gatheringinvoicependinglines.go index ceb48eec80..e3b1663792 100644 --- a/openmeter/billing/service/gatheringinvoicependinglines.go +++ b/openmeter/billing/service/gatheringinvoicependinglines.go @@ -869,7 +869,7 @@ func (s *Service) moveLinesToInvoice(ctx context.Context, in moveLinesToInvoiceI // - the invoice is updated to the database func (s *Service) updateGatheringInvoice(ctx context.Context, invoice billing.StandardInvoice) (billing.StandardInvoice, error) { // Let's update the invoice's state - if err := s.invoiceCalculator.CalculateGatheringInvoice(&invoice); err != nil { + if err := s.invoiceCalculator.CalculateLegacyGatheringInvoice(&invoice); err != nil { return billing.StandardInvoice{}, fmt.Errorf("calculating gathering invoice: %w", err) } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index 75674d5f9a..cde6324989 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -574,7 +574,7 @@ func (s *Service) UpdateInvoice(ctx context.Context, input billing.UpdateInvoice return billing.StandardInvoice{}, fmt.Errorf("normalizing lines: %w", err) } - if err := s.invoiceCalculator.CalculateGatheringInvoice(&invoice); err != nil { + if err := s.invoiceCalculator.CalculateLegacyGatheringInvoice(&invoice); err != nil { return billing.StandardInvoice{}, fmt.Errorf("calculating invoice[%s]: %w", invoice.ID, err) } @@ -893,7 +893,7 @@ func (s *Service) RecalculateGatheringInvoices(ctx context.Context, input billin for _, invoice := range gatheringInvoices.Items { var err error - if err = s.invoiceCalculator.CalculateGatheringInvoice(&invoice); err != nil { + if err = s.invoiceCalculator.CalculateLegacyGatheringInvoice(&invoice); err != nil { return fmt.Errorf("calculating gathering invoice: %w", err) } diff --git a/openmeter/billing/service/invoicecalc/calculator.go b/openmeter/billing/service/invoicecalc/calculator.go index ccd39411f4..bb91199342 100644 --- a/openmeter/billing/service/invoicecalc/calculator.go +++ b/openmeter/billing/service/invoicecalc/calculator.go @@ -7,7 +7,8 @@ import ( ) type invoiceCalculatorsByType struct { - GatheringInvoice []Calculation + LegacyGatheringInvoice []Calculation + GatheringInvoice []GatheringCalculation GatheringInvoiceWithLiveData []Calculation Invoice []Calculation } @@ -19,32 +20,39 @@ var InvoiceCalculations = invoiceCalculatorsByType{ WithNoDependencies(CalculateDueAt), WithNoDependencies(UpsertDiscountCorrelationIDs), RecalculateDetailedLinesAndTotals, - WithNoDependencies(CalculateInvoicePeriod), + WithNoDependencies(CalculateStandardInvoiceServicePeriod), WithNoDependencies(SnapshotTaxConfigIntoLines), }, - GatheringInvoice: []Calculation{ + LegacyGatheringInvoice: []Calculation{ WithNoDependencies(UpsertDiscountCorrelationIDs), - WithNoDependencies(GatheringInvoiceCollectionAt), - WithNoDependencies(CalculateInvoicePeriod), + WithNoDependencies(LegacyGatheringInvoiceCollectionAt), + WithNoDependencies(CalculateStandardInvoiceServicePeriod), + }, + GatheringInvoice: []GatheringCalculation{ + UpsertGatheringInvoiceDiscountCorrelationIDs, + GatheringInvoiceCollectionAt, + CalculateGatheringInvoiceServicePeriod, }, // Calculations that should be running on a gathering invoice to populate line items GatheringInvoiceWithLiveData: []Calculation{ WithNoDependencies(UpsertDiscountCorrelationIDs), - WithNoDependencies(GatheringInvoiceCollectionAt), + WithNoDependencies(LegacyGatheringInvoiceCollectionAt), RecalculateDetailedLinesAndTotals, - WithNoDependencies(CalculateInvoicePeriod), + WithNoDependencies(CalculateStandardInvoiceServicePeriod), WithNoDependencies(SnapshotTaxConfigIntoLines), WithNoDependencies(FillGatheringDetailedLineMeta), }, } type ( - Calculation func(*billing.StandardInvoice, CalculatorDependencies) error + Calculation func(*billing.StandardInvoice, CalculatorDependencies) error + GatheringCalculation func(*billing.GatheringInvoice) error ) type Calculator interface { Calculate(*billing.StandardInvoice, CalculatorDependencies) error - CalculateGatheringInvoice(*billing.StandardInvoice) error + CalculateLegacyGatheringInvoice(*billing.StandardInvoice) error + CalculateGatheringInvoice(*billing.GatheringInvoice) error CalculateGatheringInvoiceWithLiveData(*billing.StandardInvoice, CalculatorDependencies) error } @@ -78,12 +86,12 @@ func (c *calculator) applyCalculations(invoice *billing.StandardInvoice, calcula billing.ValidationComponentOpenMeter) } -func (c *calculator) CalculateGatheringInvoice(invoice *billing.StandardInvoice) error { +func (c *calculator) CalculateLegacyGatheringInvoice(invoice *billing.StandardInvoice) error { if invoice.Status != billing.StandardInvoiceStatusGathering { return errors.New("invoice is not a gathering invoice") } - return c.applyCalculations(invoice, InvoiceCalculations.GatheringInvoice, CalculatorDependencies{}) + return c.applyCalculations(invoice, InvoiceCalculations.LegacyGatheringInvoice, CalculatorDependencies{}) } func (c *calculator) CalculateGatheringInvoiceWithLiveData(invoice *billing.StandardInvoice, deps CalculatorDependencies) error { @@ -94,6 +102,20 @@ func (c *calculator) CalculateGatheringInvoiceWithLiveData(invoice *billing.Stan return c.applyCalculations(invoice, InvoiceCalculations.GatheringInvoiceWithLiveData, deps) } +func (c *calculator) CalculateGatheringInvoice(invoice *billing.GatheringInvoice) error { + var errs []error + + // Note: GatheringInvoice has no ValidationIssues, so we should just return the error as is + for _, calc := range InvoiceCalculations.GatheringInvoice { + err := calc(invoice) + if err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + func WithNoDependencies(cb func(inv *billing.StandardInvoice) error) Calculation { return func(inv *billing.StandardInvoice, _ CalculatorDependencies) error { return cb(inv) diff --git a/openmeter/billing/service/invoicecalc/collectionat.go b/openmeter/billing/service/invoicecalc/collectionat.go index 45cdc3d6d8..693065507c 100644 --- a/openmeter/billing/service/invoicecalc/collectionat.go +++ b/openmeter/billing/service/invoicecalc/collectionat.go @@ -10,7 +10,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" ) -func GatheringInvoiceCollectionAt(i *billing.StandardInvoice) error { +func LegacyGatheringInvoiceCollectionAt(i *billing.StandardInvoice) error { i.CollectionAt = nil if !i.Lines.IsPresent() { @@ -31,6 +31,27 @@ func GatheringInvoiceCollectionAt(i *billing.StandardInvoice) error { return nil } +func GatheringInvoiceCollectionAt(i *billing.GatheringInvoice) error { + i.NextCollectionAt = time.Time{} + + if !i.Lines.IsPresent() { + return errors.New("lines must be expanded") + } + + for _, line := range i.Lines.OrEmpty() { + if i.NextCollectionAt.IsZero() { + i.NextCollectionAt = line.InvoiceAt + continue + } + + if i.NextCollectionAt.After(line.InvoiceAt) { + i.NextCollectionAt = line.InvoiceAt + } + } + + return nil +} + func StandardInvoiceCollectionAt(i *billing.StandardInvoice) error { if !i.Lines.IsPresent() { return errors.New("lines must be expanded") diff --git a/openmeter/billing/service/invoicecalc/discounts.go b/openmeter/billing/service/invoicecalc/discounts.go index 67b1027e6f..ff6302a3b5 100644 --- a/openmeter/billing/service/invoicecalc/discounts.go +++ b/openmeter/billing/service/invoicecalc/discounts.go @@ -22,6 +22,26 @@ func UpsertDiscountCorrelationIDs(invoice *billing.StandardInvoice) error { return nil } +func UpsertGatheringInvoiceDiscountCorrelationIDs(invoice *billing.GatheringInvoice) error { + lines, err := invoice.Lines.MapWithErr(func(line billing.GatheringLine) (billing.GatheringLine, error) { + updatedDiscounts, err := ensureDiscountCorrelationIDs(line.RateCardDiscounts) + if err != nil { + return billing.GatheringLine{}, err + } + + line.RateCardDiscounts = updatedDiscounts + + return line, nil + }) + if err != nil { + return err + } + + invoice.Lines = lines + + return nil +} + func ensureDiscountCorrelationIDs(discounts billing.Discounts) (billing.Discounts, error) { if discounts.Percentage != nil { corrID, err := generateDiscountCorrelationID(discounts.Percentage.CorrelationID) diff --git a/openmeter/billing/service/invoicecalc/mock.go b/openmeter/billing/service/invoicecalc/mock.go index 32afbd6bee..0739b198d3 100644 --- a/openmeter/billing/service/invoicecalc/mock.go +++ b/openmeter/billing/service/invoicecalc/mock.go @@ -19,11 +19,14 @@ type mockCalculator struct { calculateResult mo.Option[error] calculateResultCalled bool - calculateGatheringInvoiceResult mo.Option[error] - calculateGatheringInvoiceResultCalled bool + calculateLegacyGatheringInvoiceResult mo.Option[error] + calculateLegacyGatheringInvoiceResultCalled bool calculateGatheringInvoiceWithLiveDataResult mo.Option[error] calculateGatheringInvoiceWithLiveDataResultCalled bool + + calculateGatheringInvoiceResult mo.Option[error] + calculateGatheringInvoiceResultCalled bool } func (m *mockCalculator) Calculate(i *billing.StandardInvoice, deps CalculatorDependencies) error { @@ -41,10 +44,10 @@ func (m *mockCalculator) Calculate(i *billing.StandardInvoice, deps CalculatorDe billing.ValidationComponentOpenMeter) } -func (m *mockCalculator) CalculateGatheringInvoice(i *billing.StandardInvoice) error { - m.calculateGatheringInvoiceResultCalled = true +func (m *mockCalculator) CalculateLegacyGatheringInvoice(i *billing.StandardInvoice) error { + m.calculateLegacyGatheringInvoiceResultCalled = true - res := m.calculateGatheringInvoiceResult.MustGet() + res := m.calculateLegacyGatheringInvoiceResult.MustGet() // This simulates the same behavior as the calculate method for the original // implementation. This way the mock can be used to inject calculation errors @@ -71,18 +74,30 @@ func (m *mockCalculator) CalculateGatheringInvoiceWithLiveData(i *billing.Standa billing.ValidationComponentOpenMeter) } +func (m *mockCalculator) CalculateGatheringInvoice(i *billing.GatheringInvoice) error { + m.calculateGatheringInvoiceResultCalled = true + + res := m.calculateGatheringInvoiceResult.MustGet() + + return res +} + func (m *mockCalculator) OnCalculate(err error) { m.calculateResult = mo.Some(err) } -func (m *mockCalculator) OnCalculateGatheringInvoice(err error) { - m.calculateGatheringInvoiceResult = mo.Some(err) +func (m *mockCalculator) OnCalculateLegacyGatheringInvoice(err error) { + m.calculateLegacyGatheringInvoiceResult = mo.Some(err) } func (m *mockCalculator) OnCalculateGatheringInvoiceWithLiveData(err error) { m.calculateGatheringInvoiceWithLiveDataResult = mo.Some(err) } +func (m *mockCalculator) OnCalculateGatheringInvoice(err error) { + m.calculateGatheringInvoiceResult = mo.Some(err) +} + func (m *mockCalculator) AssertExpectations(t *testing.T) { t.Helper() @@ -90,6 +105,10 @@ func (m *mockCalculator) AssertExpectations(t *testing.T) { t.Errorf("expected Calculate to be called") } + if m.calculateLegacyGatheringInvoiceResult.IsPresent() && !m.calculateLegacyGatheringInvoiceResultCalled { + t.Errorf("expected CalculateLegacyGatheringInvoice to be called") + } + if m.calculateGatheringInvoiceResult.IsPresent() && !m.calculateGatheringInvoiceResultCalled { t.Errorf("expected CalculateGatheringInvoice to be called") } @@ -107,6 +126,9 @@ func (m *mockCalculator) Reset(t *testing.T) { m.calculateResult = mo.None[error]() m.calculateResultCalled = false + m.calculateLegacyGatheringInvoiceResult = mo.None[error]() + m.calculateLegacyGatheringInvoiceResultCalled = false + m.calculateGatheringInvoiceResult = mo.None[error]() m.calculateGatheringInvoiceResultCalled = false @@ -133,7 +155,20 @@ func (m *MockableInvoiceCalculator) Calculate(i *billing.StandardInvoice, deps C return outErr } -func (m *MockableInvoiceCalculator) CalculateGatheringInvoice(i *billing.StandardInvoice) error { +func (m *MockableInvoiceCalculator) CalculateLegacyGatheringInvoice(i *billing.StandardInvoice) error { + outErr := m.upstream.CalculateLegacyGatheringInvoice(i) + + if m.mock != nil { + err := m.mock.CalculateLegacyGatheringInvoice(i) + if err != nil { + outErr = errors.Join(outErr, err) + } + } + + return outErr +} + +func (m *MockableInvoiceCalculator) CalculateGatheringInvoice(i *billing.GatheringInvoice) error { outErr := m.upstream.CalculateGatheringInvoice(i) if m.mock != nil { diff --git a/openmeter/billing/service/invoicecalc/period.go b/openmeter/billing/service/invoicecalc/period.go index 82df1235bf..44c5157ee9 100644 --- a/openmeter/billing/service/invoicecalc/period.go +++ b/openmeter/billing/service/invoicecalc/period.go @@ -1,9 +1,14 @@ package invoicecalc -import "github.com/openmeterio/openmeter/openmeter/billing" +import ( + "github.com/samber/lo" -// CalculateInvoicePeriod calculates the period of the invoice based on the lines. -func CalculateInvoicePeriod(invoice *billing.StandardInvoice) error { + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +// CalculateStandardInvoiceServicePeriod calculates the period of the invoice based on the lines. +func CalculateStandardInvoiceServicePeriod(invoice *billing.StandardInvoice) error { var period *billing.Period for _, line := range invoice.Lines.OrEmpty() { @@ -32,3 +37,30 @@ func CalculateInvoicePeriod(invoice *billing.StandardInvoice) error { return nil } + +func CalculateGatheringInvoiceServicePeriod(invoice *billing.GatheringInvoice) error { + var period timeutil.ClosedPeriod + + for _, line := range invoice.Lines.OrEmpty() { + if line.DeletedAt != nil { + continue + } + + if lo.IsEmpty(period) { + period = line.ServicePeriod + continue + } + + if line.ServicePeriod.From.Before(period.From) { + period.From = line.ServicePeriod.From + } + + if line.ServicePeriod.To.After(period.To) { + period.To = line.ServicePeriod.To + } + } + + invoice.ServicePeriod = period + + return nil +} diff --git a/openmeter/billing/service/stdinvoiceline.go b/openmeter/billing/service/stdinvoiceline.go index 8dc9d116c8..74c14f3b53 100644 --- a/openmeter/billing/service/stdinvoiceline.go +++ b/openmeter/billing/service/stdinvoiceline.go @@ -21,6 +21,7 @@ import ( var _ billing.InvoiceLineService = (*Service)(nil) +// TODO[later]: Move this to gatheringinvoice.go func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.CreatePendingInvoiceLinesInput) (*billing.CreatePendingInvoiceLinesResult, error) { for i := range input.Lines { input.Lines[i].Namespace = input.Customer.Namespace @@ -44,8 +45,8 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C if !maxPeriodEnd.IsZero() { var errs []error for _, line := range input.Lines { - if line.Period.End.After(maxPeriodEnd) { - errs = append(errs, fmt.Errorf("line[%s]: line period end[%s] is after customer deleted at[%s]", line.ID, line.Period.End, maxPeriodEnd)) + if line.ServicePeriod.To.After(maxPeriodEnd) { + errs = append(errs, fmt.Errorf("line[%s]: line period end[%s] is after customer deleted at[%s]", line.ID, line.ServicePeriod.To, maxPeriodEnd)) } } @@ -58,7 +59,9 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C return transcationForInvoiceManipulation(ctx, s, input.Customer, func(ctx context.Context) (*billing.CreatePendingInvoiceLinesResult, error) { if len(input.Lines) == 0 { - return nil, nil + return nil, billing.ValidationError{ + Err: fmt.Errorf("no lines provided"), + } } // let's resolve the customer's settings @@ -78,12 +81,9 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C return nil, fmt.Errorf("upserting gathering invoice: %w", err) } - if gatheringInvoiceUpsertResult.Invoice == nil { - return nil, fmt.Errorf("gathering invoice is nil") - } - gatheringInvoice := *gatheringInvoiceUpsertResult.Invoice + gatheringInvoice := gatheringInvoiceUpsertResult.Invoice - linesToCreate, err := slicesx.MapWithErr(input.Lines, func(l *billing.StandardLine) (*billing.StandardLine, error) { + linesToCreate, err := slicesx.MapWithErr(input.Lines, func(l billing.GatheringLine) (billing.GatheringLine, error) { l.Namespace = input.Customer.Namespace l.Currency = input.Currency @@ -94,11 +94,11 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C normalizedLine, err := l.WithNormalizedValues() if err != nil { - return nil, fmt.Errorf("normalizing line[%s]: %w", l.ID, err) + return billing.GatheringLine{}, fmt.Errorf("normalizing line[%s]: %w", l.ID, err) } if err := normalizedLine.Validate(); err != nil { - return nil, fmt.Errorf("validating line[%s]: %w", l.ID, err) + return billing.GatheringLine{}, fmt.Errorf("validating line[%s]: %w", l.ID, err) } return normalizedLine, nil @@ -118,22 +118,17 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C return nil, fmt.Errorf("calculating invoice[%s]: %w", gatheringInvoiceID, err) } - gatheringInvoice, err = s.adapter.UpdateInvoice(ctx, gatheringInvoice) + err = s.adapter.UpdateGatheringInvoice(ctx, gatheringInvoice) if err != nil { return nil, fmt.Errorf("failed to update invoice[%s]: %w", gatheringInvoiceID, err) } - gatheringInvoice, err = s.resolveWorkflowApps(ctx, gatheringInvoice) - if err != nil { - return nil, fmt.Errorf("error resolving workflow apps for invoice [%s]: %w", gatheringInvoiceID, err) - } - // Let's resolve the created lines from the final invoice - invoiceLinesByID := lo.SliceToMap(gatheringInvoice.Lines.OrEmpty(), func(l *billing.StandardLine) (string, *billing.StandardLine) { + invoiceLinesByID := lo.SliceToMap(gatheringInvoice.Lines.OrEmpty(), func(l billing.GatheringLine) (string, billing.GatheringLine) { return l.ID, l }) - finalLines := []*billing.StandardLine{} + finalLines := []billing.GatheringLine{} for _, line := range linesToCreate { if line, ok := invoiceLinesByID[line.ID]; ok { finalLines = append(finalLines, line) @@ -142,13 +137,10 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C // Publish system event for newly created invoices if gatheringInvoiceUpsertResult.IsInvoiceNew { - event, err := billing.NewStandardInvoiceCreatedEvent(gatheringInvoice) - if err != nil { - return nil, fmt.Errorf("creating event: %w", err) - } + event := billing.NewGatheringInvoiceCreatedEvent(gatheringInvoice) if err := s.publisher.Publish(ctx, event); err != nil { - return nil, fmt.Errorf("publishing invoice[%s] created event: %w", gatheringInvoiceID, err) + return nil, fmt.Errorf("publishing gathering invoice[%s] created event: %w", gatheringInvoiceID, err) } } @@ -161,27 +153,24 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C } type upsertGatheringInvoiceForCurrencyResponse struct { - Invoice *billing.StandardInvoice + Invoice billing.GatheringInvoice IsInvoiceNew bool } func (s *Service) upsertGatheringInvoiceForCurrency(ctx context.Context, currency currencyx.Code, customerProfile billing.CustomerOverrideWithDetails) (*upsertGatheringInvoiceForCurrencyResponse, error) { // We would want to stage a pending invoice Line - pendingInvoiceList, err := s.adapter.ListInvoices(ctx, billing.ListInvoicesInput{ + pendingInvoiceList, err := s.adapter.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ Page: pagination.Page{ PageNumber: 1, PageSize: 10, }, - Customers: []string{customerProfile.Customer.ID}, - Namespaces: []string{customerProfile.Customer.Namespace}, - ExtendedStatuses: []billing.StandardInvoiceStatus{billing.StandardInvoiceStatusGathering}, - Currencies: []currencyx.Code{currency}, - OrderBy: api.InvoiceOrderByCreatedAt, - Order: sortx.OrderAsc, - IncludeDeleted: true, - Expand: billing.InvoiceExpand{ - Lines: true, - }, + Customers: []string{customerProfile.Customer.ID}, + Namespaces: []string{customerProfile.Customer.Namespace}, + Currencies: []currencyx.Code{currency}, + OrderBy: api.InvoiceOrderByCreatedAt, + Order: sortx.OrderAsc, + IncludeDeleted: true, + Expand: []billing.GatheringInvoiceExpand{billing.GatheringInvoiceExpandLines}, }) if err != nil { return nil, fmt.Errorf("fetching gathering invoices: %w", err) @@ -200,21 +189,19 @@ func (s *Service) upsertGatheringInvoiceForCurrency(ctx context.Context, currenc } // Create a new invoice - invoice, err := s.adapter.CreateInvoice(ctx, billing.CreateInvoiceAdapterInput{ - Namespace: customerProfile.Customer.Namespace, - Customer: lo.FromPtr(customerProfile.Customer), - Profile: customerProfile.MergedProfile, - Number: invoiceNumber, - Currency: currency, - Status: billing.StandardInvoiceStatusGathering, - Type: billing.InvoiceTypeStandard, + invoice, err := s.adapter.CreateGatheringInvoice(ctx, billing.CreateGatheringInvoiceAdapterInput{ + Namespace: customerProfile.Customer.Namespace, + Customer: lo.FromPtr(customerProfile.Customer), + Number: invoiceNumber, + Currency: currency, + MergedProfile: customerProfile.MergedProfile, }) if err != nil { return nil, fmt.Errorf("creating invoice: %w", err) } return &upsertGatheringInvoiceForCurrencyResponse{ - Invoice: &invoice, + Invoice: invoice, IsInvoiceNew: true, }, nil } @@ -226,7 +213,7 @@ func (s *Service) upsertGatheringInvoiceForCurrency(ctx context.Context, currenc // If the invoice was deleted, but has non-deleted lines, we need to delete those lines to prevent // them from reappearing in the recreated gathering invoice. if invoice.Lines.NonDeletedLineCount() > 0 { - invoice.Lines = invoice.Lines.Map(func(l *billing.StandardLine) *billing.StandardLine { + invoice.Lines = invoice.Lines.Map(func(l billing.GatheringLine) billing.GatheringLine { if l.DeletedAt == nil { l.DeletedAt = lo.ToPtr(clock.Now()) } @@ -236,14 +223,26 @@ func (s *Service) upsertGatheringInvoiceForCurrency(ctx context.Context, currenc invoiceID := invoice.ID - invoice, err = s.adapter.UpdateInvoice(ctx, invoice) + err = s.adapter.UpdateGatheringInvoice(ctx, invoice) if err != nil { return nil, fmt.Errorf("restoring deleted invoice[id=%s]: %w", invoiceID, err) } + + // We need to refetch the invoice to get all included lines + invoice, err = s.adapter.GetGatheringInvoiceById(ctx, billing.GetGatheringInvoiceByIdInput{ + Invoice: billing.InvoiceID{ + ID: invoiceID, + Namespace: customerProfile.Customer.Namespace, + }, + Expand: billing.GatheringInvoiceExpands{billing.GatheringInvoiceExpandLines}, + }) + if err != nil { + return nil, fmt.Errorf("refetching invoice: %w", err) + } } return &upsertGatheringInvoiceForCurrencyResponse{ - Invoice: &invoice, + Invoice: invoice, }, nil } diff --git a/openmeter/billing/stdinvoiceline.go b/openmeter/billing/stdinvoiceline.go index d79c56b149..44631e121a 100644 --- a/openmeter/billing/stdinvoiceline.go +++ b/openmeter/billing/stdinvoiceline.go @@ -22,7 +22,7 @@ import ( type StandardLineBase struct { models.ManagedResource - Metadata map[string]string `json:"metadata,omitempty"` + Metadata models.Metadata `json:"metadata,omitempty"` Annotations models.Annotations `json:"annotations,omitempty"` ManagedBy InvoiceLineManagedBy `json:"managedBy"` @@ -186,6 +186,47 @@ func (i StandardLine) LineID() LineID { } } +// ToGatheringLineBase converts the standard line to a gathering line base. +// This is temporary until the full gathering invoice functionality is split. +func (i StandardLine) ToGatheringLineBase() (GatheringLineBase, error) { + if i.UsageBased == nil { + return GatheringLineBase{}, errors.New("usage based line is required") + } + + if i.UsageBased.Price == nil { + return GatheringLineBase{}, errors.New("usage based line price is required") + } + + clonedMetadata := i.Metadata.Clone() + + clonedAnnotations, err := i.Annotations.Clone() + if err != nil { + return GatheringLineBase{}, fmt.Errorf("cloning annotations: %w", err) + } + + return GatheringLineBase{ + ManagedResource: i.ManagedResource, + Metadata: clonedMetadata, + Annotations: clonedAnnotations, + ManagedBy: i.ManagedBy, + InvoiceID: i.InvoiceID, + Currency: i.Currency, + ServicePeriod: timeutil.ClosedPeriod{ + From: i.Period.Start, + To: i.Period.End, + }, + InvoiceAt: i.InvoiceAt, + Price: lo.FromPtr(i.UsageBased.Price), + FeatureKey: i.UsageBased.FeatureKey, + TaxConfig: i.TaxConfig, + RateCardDiscounts: i.RateCardDiscounts, + ChildUniqueReferenceID: i.ChildUniqueReferenceID, + Subscription: i.Subscription, + SplitLineGroupID: i.SplitLineGroupID, + UBPConfigID: i.UsageBased.ConfigID, + }, nil +} + type StandardLineEditFunction func(*StandardLine) // CloneWithoutDependencies returns a clone of the line without any external dependencies. Could be used @@ -372,22 +413,31 @@ func (i StandardLine) WithNormalizedValues() (*StandardLine, error) { out.Period = out.Period.Truncate(streaming.MinimumWindowSizeDuration) out.InvoiceAt = out.InvoiceAt.Truncate(streaming.MinimumWindowSizeDuration) - if out.UsageBased.Price.Type() == productcatalog.FlatPriceType { - // Let's apply the default inAdvance payment term for flat prices - flatPrice, err := out.UsageBased.Price.AsFlat() - if err != nil { - return nil, fmt.Errorf("converting price to flat price: %w", err) - } - - if flatPrice.PaymentTerm == "" { - flatPrice.PaymentTerm = productcatalog.InAdvancePaymentTerm - out.UsageBased.Price = productcatalog.NewPriceFrom(flatPrice) - } + if err := setDefaultPaymentTermForFlatPrice(out.UsageBased.Price); err != nil { + return nil, fmt.Errorf("setting default payment term for flat price: %w", err) } return out, nil } +func setDefaultPaymentTermForFlatPrice(price *productcatalog.Price) error { + if price.Type() != productcatalog.FlatPriceType { + return nil + } + + flatPrice, err := price.AsFlat() + if err != nil { + return err + } + + if flatPrice.PaymentTerm == "" { + flatPrice.PaymentTerm = productcatalog.InAdvancePaymentTerm + *price = lo.FromPtr(productcatalog.NewPriceFrom(flatPrice)) + } + + return nil +} + // DissacociateChildren removes the Children both from the DBState and the current line, so that the // line can be safely persisted/managed without the children. // diff --git a/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go b/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go index 32c36dbc5a..b9190565da 100644 --- a/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go +++ b/openmeter/billing/worker/subscriptionsync/service/invoiceupdate.go @@ -14,6 +14,7 @@ import ( "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/slicesx" ) type InvoiceUpdater struct { @@ -156,10 +157,22 @@ func (u *InvoiceUpdater) provisionUpcomingLines(ctx context.Context, customerID }) for currency, lines := range linesByCurrency { - _, err := u.billingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ + gatheringLines, err := slicesx.MapWithErr(lines, func(l *billing.StandardLine) (billing.GatheringLine, error) { + base, err := l.ToGatheringLineBase() + if err != nil { + return billing.GatheringLine{}, err + } + + return billing.GatheringLine{GatheringLineBase: base}, nil + }) + if err != nil { + return fmt.Errorf("converting lines to gathering lines: %w", err) + } + + _, err = u.billingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customerID, Currency: currency, - Lines: lines, + Lines: gatheringLines, }) if err != nil { return fmt.Errorf("creating pending invoice lines: %w", err) diff --git a/openmeter/billing/worker/subscriptionsync/service/sync_test.go b/openmeter/billing/worker/subscriptionsync/service/sync_test.go index 700a624cfc..c47d11dd5a 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync_test.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync_test.go @@ -432,34 +432,32 @@ func (s *SubscriptionHandlerTestSuite) TestUncollectableCollection() { apiRequestsTotalFeature := s.SetupApiRequestsTotalFeature(ctx, namespace) defer apiRequestsTotalFeature.Cleanup() - lineServicePeriod := billing.Period{ - Start: lo.Must(time.Parse(time.RFC3339, "2025-01-01T00:00:00Z")), - End: lo.Must(time.Parse(time.RFC3339, "2025-01-02T00:00:00Z")), + lineServicePeriod := timeutil.ClosedPeriod{ + From: lo.Must(time.Parse(time.RFC3339, "2025-01-01T00:00:00Z")), + To: lo.Must(time.Parse(time.RFC3339, "2025-01-02T00:00:00Z")), } - clock.SetTime(lineServicePeriod.Start) + clock.SetTime(lineServicePeriod.From) defer clock.ResetTime() pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - unit", }), - Period: lineServicePeriod, - InvoiceAt: lineServicePeriod.End, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom( + ServicePeriod: lineServicePeriod, + InvoiceAt: lineServicePeriod.To, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom( productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(1), }, - ), + )), }, }, }, diff --git a/openmeter/event/metadata/resourcepath.go b/openmeter/event/metadata/resourcepath.go index 6327120a73..0d758460de 100644 --- a/openmeter/event/metadata/resourcepath.go +++ b/openmeter/event/metadata/resourcepath.go @@ -12,6 +12,7 @@ const ( EntitySubscriptionAddon = "subscriptionAddon" EntityInvoice = "invoice" EntityCustomer = "customer" + EntityGatheringInvoice = "gatheringInvoice" EntitySubjectKey = "subjectKey" EntityGrant = "grant" EntityApp = "app" diff --git a/openmeter/server/server_test.go b/openmeter/server/server_test.go index 694257368c..e70bd897a7 100644 --- a/openmeter/server/server_test.go +++ b/openmeter/server/server_test.go @@ -1488,10 +1488,6 @@ func (n NoopBillingService) GetInvoiceByID(ctx context.Context, input billing.Ge return billing.StandardInvoice{}, nil } -func (n NoopBillingService) InvoicePendingLines(ctx context.Context, input billing.InvoicePendingLinesInput) ([]billing.StandardInvoice, error) { - return []billing.StandardInvoice{}, nil -} - func (n NoopBillingService) AdvanceInvoice(ctx context.Context, input billing.AdvanceInvoiceInput) (billing.StandardInvoice, error) { return billing.StandardInvoice{}, nil } @@ -1528,6 +1524,15 @@ func (n NoopBillingService) RecalculateGatheringInvoices(ctx context.Context, in return nil } +// GatheringInvoiceService methods +func (n NoopBillingService) InvoicePendingLines(ctx context.Context, input billing.InvoicePendingLinesInput) ([]billing.StandardInvoice, error) { + return []billing.StandardInvoice{}, nil +} + +func (n NoopBillingService) ListGatheringInvoices(ctx context.Context, input billing.ListGatheringInvoicesInput) (pagination.Result[billing.GatheringInvoice], error) { + return pagination.Result[billing.GatheringInvoice]{}, nil +} + // SequenceService methods func (n NoopBillingService) GenerateInvoiceSequenceNumber(ctx context.Context, in billing.SequenceGenerationInput, def billing.SequenceDefinition) (string, error) { return "", nil diff --git a/pkg/models/metadata.go b/pkg/models/metadata.go index aa7c6e4488..22606cdea8 100644 --- a/pkg/models/metadata.go +++ b/pkg/models/metadata.go @@ -34,6 +34,10 @@ func (m Metadata) Merge(d Metadata) Metadata { return r } +func (m Metadata) Clone() Metadata { + return maps.Clone(m) +} + func NewMetadata[T ~map[string]string](m T) Metadata { return Metadata(m) } diff --git a/pkg/timeutil/closedperiod.go b/pkg/timeutil/closedperiod.go index 1702c40671..a78e26dc5a 100644 --- a/pkg/timeutil/closedperiod.go +++ b/pkg/timeutil/closedperiod.go @@ -98,3 +98,14 @@ func (p ClosedPeriod) Validate() error { return nil } + +func (p ClosedPeriod) Truncate(resolution time.Duration) ClosedPeriod { + return ClosedPeriod{ + From: p.From.Truncate(resolution), + To: p.To.Truncate(resolution), + } +} + +func (p ClosedPeriod) Equal(other ClosedPeriod) bool { + return p.From.Equal(other.From) && p.To.Equal(other.To) +} diff --git a/test/app/custominvoicing/invocing_test.go b/test/app/custominvoicing/invocing_test.go index 75762ab13d..07c3f61fc4 100644 --- a/test/app/custominvoicing/invocing_test.go +++ b/test/app/custominvoicing/invocing_test.go @@ -24,6 +24,7 @@ import ( "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/datetime" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" billingtest "github.com/openmeterio/openmeter/test/billing" ) @@ -166,8 +167,8 @@ func (s *CustomInvoicingTestSuite) TestInvoicingFlowHooksEnabled() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.HUF), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: issueAt, @@ -179,17 +180,15 @@ func (s *CustomInvoicingTestSuite) TestInvoicingFlowHooksEnabled() { PaymentTerm: productcatalog.InAdvancePaymentTerm, }), { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "Test item - HUF", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, InvoiceAt: issueAt, ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -204,7 +203,7 @@ func (s *CustomInvoicingTestSuite) TestInvoicingFlowHooksEnabled() { }, }, }, - }), + })), FeatureKey: "test", }, }, @@ -333,8 +332,8 @@ func (s *CustomInvoicingTestSuite) TestInvoicingFlowPaymentStatusOnly() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.HUF), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: issueAt, diff --git a/test/app/stripe/invoice_test.go b/test/app/stripe/invoice_test.go index 61677ab83d..d570f26619 100644 --- a/test/app/stripe/invoice_test.go +++ b/test/app/stripe/invoice_test.go @@ -36,6 +36,7 @@ import ( "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/datetime" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" billingtest "github.com/openmeterio/openmeter/test/billing" ) @@ -300,76 +301,68 @@ func (s *StripeInvoiceTestSuite) TestComplexInvoice() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { // Covered case: Discount caused by maximum amount - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.flatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.flatPerUnit.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), Commitments: productcatalog.Commitments{ MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(2000)), }, }), + ), }, }, { // Covered case: Very small per unit amount, high quantity, rounding to two decimal places - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - AI Usecase", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.aiFlatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.aiFlatPerUnit.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(0.00000075), - }), + })), }, }, { // Covered case: Flat line represented as UBP item - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per any usage", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.flatPerUsage.Key, - Price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.flatPerUsage.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ Amount: alpacadecimal.NewFromFloat(100), PaymentTerm: productcatalog.InArrearsPaymentTerm, - }), - Quantity: lo.ToPtr(alpacadecimal.NewFromFloat(1)), + })), }, }, { // Covered case: Multiple lines per item, tier boundary is fractional - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered graduated", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredGraduated.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredGraduated.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -390,22 +383,20 @@ func (s *StripeInvoiceTestSuite) TestComplexInvoice() { }, }, }, - }), + })), }, }, { // Covered case: minimum amount charges - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered volume", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredVolume.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredVolume.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.VolumeTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -429,7 +420,7 @@ func (s *StripeInvoiceTestSuite) TestComplexInvoice() { Commitments: productcatalog.Commitments{ MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(3000)), }, - }), + })), }, }, }, @@ -1109,21 +1100,19 @@ func (s *StripeInvoiceTestSuite) TestEmptyInvoiceGenerationZeroUsage() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: flatPerUnitFeature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: flatPerUnitFeature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(0), - }), + })), }, }, }, @@ -1293,8 +1282,8 @@ func (s *StripeInvoiceTestSuite) TestSendInvoice() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: periodStart, Name: "Flat fee", diff --git a/test/billing/collection_test.go b/test/billing/collection_test.go index ffa98b832c..cc8bde1ef1 100644 --- a/test/billing/collection_test.go +++ b/test/billing/collection_test.go @@ -17,6 +17,7 @@ import ( "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/datetime" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" ) type CollectionTestSuite struct { @@ -79,33 +80,29 @@ func (s *CollectionTestSuite) TestCollectionFlow() { res, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)})), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - volume", }), - Period: billing.Period{Start: periodStart, End: period2End}, - InvoiceAt: period2End, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: period2End}, + InvoiceAt: period2End, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.VolumeTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -117,7 +114,7 @@ func (s *CollectionTestSuite) TestCollectionFlow() { UnitPrice: &productcatalog.PriceTierUnitPrice{Amount: alpacadecimal.NewFromFloat(0.5)}, }, }, - }), + })), }, }, }, @@ -128,8 +125,8 @@ func (s *CollectionTestSuite) TestCollectionFlow() { gatheringInvoiceID = res.Invoice.InvoiceID() // Validate collection_at calculation - s.NotNil(res.Invoice.CollectionAt) - s.Equal(periodEnd, *res.Invoice.CollectionAt, "collection_at should be the min of the invoice_at of the lines") + s.NotNil(res.Invoice.NextCollectionAt) + s.Equal(periodEnd, res.Invoice.NextCollectionAt, "collection_at should be the min of the invoice_at of the lines") }) // Given a gatherting invoice exists @@ -234,12 +231,12 @@ func (s *CollectionTestSuite) TestCollectionFlowWithFlatFeeOnly() { tcs := []struct { name string namespace string - line *billing.StandardLine + line billing.GatheringLine }{ { name: "flat fee only", namespace: "ns-collection-flow-flat-fee", - line: billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + line: billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: periodStart, Name: "Flat fee", @@ -251,7 +248,7 @@ func (s *CollectionTestSuite) TestCollectionFlowWithFlatFeeOnly() { { name: "ubp flat fee only", namespace: "ns-collection-flow-ubp-flat-fee", - line: billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + line: billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: periodStart, Name: "Flat fee", @@ -283,11 +280,11 @@ func (s *CollectionTestSuite) TestCollectionFlowWithFlatFeeOnly() { pendingLineResult, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{tc.line}, + Lines: []billing.GatheringLine{tc.line}, }) s.NoError(err) s.Len(pendingLineResult.Lines, 1) - s.NotNil(pendingLineResult.Invoice.CollectionAt) + s.NotNil(pendingLineResult.Invoice.NextCollectionAt) // When clock.SetTime(periodStart.Add(time.Hour * 1)) @@ -337,26 +334,22 @@ func (s *CollectionTestSuite) TestCollectionFlowWithFlatFeeEditing() { pendingLineResult, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ - ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ - Name: "UBP - unit", - }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{Name: "UBP - unit"}), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: *productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), }, }, }, }) s.NoError(err) s.Len(pendingLineResult.Lines, 1) - s.NotNil(pendingLineResult.Invoice.CollectionAt) + s.NotNil(pendingLineResult.Invoice.NextCollectionAt) clock.SetTime(periodEnd.Add(time.Hour * 1)) invoices, err := s.BillingService.InvoicePendingLines(ctx, billing.InvoicePendingLinesInput{ @@ -448,17 +441,15 @@ func (s *CollectionTestSuite) TestAnchoredAlignment_SetsCollectionAtToNextAnchor _, err = s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{Name: "UBP - unit"}), - Period: billing.Period{Start: periodStart, End: periodEnd}, + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, InvoiceAt: periodEnd, ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: s.SetupApiRequestsTotalFeature(ctx, namespace).Feature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), + FeatureKey: s.SetupApiRequestsTotalFeature(ctx, namespace).Feature.Key, + Price: *productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), }, }, }, @@ -516,26 +507,22 @@ func (s *CollectionTestSuite) TestCollectionFlowWithUBPEditingExtendingCollectio pendingLineResult, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ - ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ - Name: "UBP - unit", - }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{Name: "UBP - unit"}), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: *productcatalog.NewPriceFrom(productcatalog.UnitPrice{Amount: alpacadecimal.NewFromFloat(1)}), }, }, }, }) s.NoError(err) s.Len(pendingLineResult.Lines, 1) - s.NotNil(pendingLineResult.Invoice.CollectionAt) + s.NotNil(pendingLineResult.Invoice.NextCollectionAt) clock.SetTime(periodEnd.Add(time.Hour * 1)) invoices, err := s.BillingService.InvoicePendingLines(ctx, billing.InvoicePendingLinesInput{ diff --git a/test/billing/discount_test.go b/test/billing/discount_test.go index 50fcec8469..e6f97d022e 100644 --- a/test/billing/discount_test.go +++ b/test/billing/discount_test.go @@ -18,6 +18,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" ) type DiscountsTestSuite struct { @@ -87,14 +88,14 @@ func (s *DiscountsTestSuite) TestCorrelationIDHandling() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Namespace: namespace, Name: "Test item1", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, InvoiceAt: periodEnd, @@ -108,12 +109,10 @@ func (s *DiscountsTestSuite) TestCorrelationIDHandling() { }, }, }, - }, - UsageBased: &billing.UsageBasedLine{ FeatureKey: featureFlatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), - }), + })), }, }, }, @@ -284,14 +283,14 @@ func (s *DiscountsTestSuite) TestUnitDiscountProgressiveBilling() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Namespace: namespace, Name: "Test item1", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, InvoiceAt: periodEnd, @@ -305,12 +304,10 @@ func (s *DiscountsTestSuite) TestUnitDiscountProgressiveBilling() { }, }, }, - }, - UsageBased: &billing.UsageBasedLine{ FeatureKey: featureFlatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), - }), + })), }, }, }, diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 0d15f7b4f1..5f3e7e2800 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -33,6 +33,7 @@ import ( "github.com/openmeterio/openmeter/pkg/datetime" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" + "github.com/openmeterio/openmeter/pkg/timeutil" ) type InvoicingTestSuite struct { @@ -118,10 +119,10 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { // Given we have a default profile for the namespace - billingProfile := s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID()) + _ = s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID()) - var items []*billing.StandardLine - var HUFItem *billing.StandardLine + var items []billing.GatheringLine + var HUFItem billing.GatheringLine s.T().Run("CreateInvoiceItems", func(t *testing.T) { // When we create invoice items @@ -130,8 +131,8 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Namespace: namespace, Period: billing.Period{Start: periodStart, End: periodEnd}, @@ -164,8 +165,8 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.HUF), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: issueAt, @@ -177,17 +178,15 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { PaymentTerm: productcatalog.InAdvancePaymentTerm, }), { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "Test item - HUF", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, InvoiceAt: issueAt, ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -202,7 +201,7 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { }, }, }, - }), + })), FeatureKey: "test", }, }, @@ -211,28 +210,27 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { // Then we should have the items created require.NoError(s.T(), err) - items = []*billing.StandardLine{usdItem, res.Lines[0], res.Lines[1]} + items = []billing.GatheringLine{usdItem, res.Lines[0], res.Lines[1]} // Then we should have an usd invoice automatically created - usdInvoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ + usdInvoices, err := s.BillingService.ListGatheringInvoices(ctx, billing.ListGatheringInvoicesInput{ Page: pagination.Page{ PageNumber: 1, PageSize: 10, }, - Namespaces: []string{namespace}, - Customers: []string{customerEntity.ID}, - Expand: billing.InvoiceExpandAll, - ExtendedStatuses: []billing.StandardInvoiceStatus{billing.StandardInvoiceStatusGathering}, - Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, + Namespaces: []string{namespace}, + Customers: []string{customerEntity.ID}, + Expand: []billing.GatheringInvoiceExpand{billing.GatheringInvoiceExpandLines}, + Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, }) require.NoError(s.T(), err) require.Len(s.T(), usdInvoices.Items, 1) usdInvoice := usdInvoices.Items[0] usdInvoiceLine := usdInvoice.Lines.MustGet()[0] - expectedUSDLine := &billing.StandardLine{ - StandardLineBase: billing.StandardLineBase{ + expectedUSDLine := billing.GatheringLine{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ ID: items[0].ID, Namespace: namespace, @@ -241,7 +239,7 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { UpdatedAt: usdInvoiceLine.UpdatedAt.In(time.UTC), }), - Period: billing.Period{Start: periodStart.Truncate(time.Microsecond), End: periodEnd.Truncate(time.Microsecond)}, + ServicePeriod: timeutil.ClosedPeriod{From: periodStart.Truncate(time.Microsecond), To: periodEnd.Truncate(time.Microsecond)}, InvoiceID: usdInvoice.ID, InvoiceAt: issueAt.In(time.UTC), @@ -256,78 +254,48 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { "string_key": "value", "float_key": 1.0, }, - }, - UsageBased: &billing.UsageBasedLine{ - ConfigID: usdInvoiceLine.UsageBased.ConfigID, - Price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + UBPConfigID: usdInvoiceLine.UBPConfigID, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ Amount: alpacadecimal.NewFromFloat(100), PaymentTerm: productcatalog.InAdvancePaymentTerm, - }), + })), }, } // Let's make sure that the workflow config is cloned - expectedInvoice := billing.StandardInvoice{ - StandardInvoiceBase: billing.StandardInvoiceBase{ - Namespace: namespace, - ID: usdInvoice.ID, + expectedInvoice := billing.GatheringInvoice{ + GatheringInvoiceBase: billing.GatheringInvoiceBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: namespace, + ID: usdInvoice.ID, + Name: "GATHER-TECU-USD-1", + CreatedAt: usdInvoice.CreatedAt, + UpdatedAt: usdInvoice.UpdatedAt, + }), - Type: billing.InvoiceTypeStandard, - Number: "GATHER-TECU-USD-1", - Currency: currencyx.Code(currency.USD), - Status: billing.StandardInvoiceStatusGathering, - Period: &billing.Period{Start: periodStart.Truncate(time.Second), End: periodEnd.Truncate(time.Second)}, - - CreatedAt: usdInvoice.CreatedAt, - UpdatedAt: usdInvoice.UpdatedAt, - - Workflow: billing.InvoiceWorkflow{ - Config: billing.WorkflowConfig{ - Collection: billingProfile.WorkflowConfig.Collection, - Invoicing: billingProfile.WorkflowConfig.Invoicing, - Payment: billingProfile.WorkflowConfig.Payment, - Tax: billingProfile.WorkflowConfig.Tax, - }, - SourceBillingProfileID: billingProfile.ID, - AppReferences: *billingProfile.AppReferences, - Apps: billingProfile.Apps, - }, + Number: "GATHER-TECU-USD-1", + Currency: currencyx.Code(currency.USD), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart.Truncate(time.Second), To: periodEnd.Truncate(time.Second)}, // The customer snapshot - Customer: billing.InvoiceCustomer{ - // Usage attribution fields - Key: customerEntity.Key, - CustomerID: customerEntity.ID, - UsageAttribution: &streaming.CustomerUsageAttribution{ - ID: customerEntity.ID, - Key: customerEntity.Key, - SubjectKeys: customerEntity.UsageAttribution.SubjectKeys, - }, - - // Other fields - Name: customerEntity.Name, - BillingAddress: customerEntity.BillingAddress, - }, - Supplier: billingProfile.Supplier, + CustomerID: customerEntity.ID, SchemaLevel: billingadapter.DefaultInvoiceWriteSchemaLevel, }, - Lines: billing.NewStandardInvoiceLines([]*billing.StandardLine{expectedUSDLine}), - - ExpandedFields: billing.InvoiceExpandAll, + Lines: billing.NewGatheringInvoiceLines([]billing.GatheringLine{expectedUSDLine}), } s.NoError(invoicecalc.GatheringInvoiceCollectionAt(&expectedInvoice)) ExpectJSONEqual(s.T(), - expectedInvoice.RemoveMetaForCompare(), - usdInvoice.RemoveMetaForCompare()) + lo.Must(expectedInvoice.WithoutDBState()), + lo.Must(usdInvoice.WithoutDBState())) require.Len(s.T(), items, 3) // Validate that the create returns the expected items items[0].CreatedAt = expectedUSDLine.CreatedAt items[0].UpdatedAt = expectedUSDLine.UpdatedAt - require.Equal(s.T(), items[0].RemoveMetaForCompare(), expectedUSDLine.RemoveMetaForCompare()) + require.Equal(s.T(), lo.Must(items[0].WithoutDBState()), lo.Must(expectedUSDLine.WithoutDBState())) require.NotEmpty(s.T(), items[1].ID) HUFItem = items[1] @@ -465,8 +433,8 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Namespace: namespace, Period: billing.Period{Start: periodStart, End: periodEnd}, @@ -483,7 +451,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { PerUnitAmount: alpacadecimal.NewFromFloat(100), PaymentTerm: productcatalog.InAdvancePaymentTerm, }), - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Namespace: namespace, Period: billing.Period{Start: periodStart, End: periodEnd}, @@ -641,8 +609,8 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Name: "Test item1", Namespace: namespace, Period: billing.Period{Start: periodStart, End: periodEnd}, @@ -1043,6 +1011,7 @@ func (s *InvoicingTestSuite) TestInvoicingFlowErrorHandling() { // Given that the app will return a validation error mockApp.OnValidateStandardInvoice(billing.NewValidationError("test1", "validation error")) calcMock.OnCalculate(nil) + calcMock.OnCalculateLegacyGatheringInvoice(nil) calcMock.OnCalculateGatheringInvoice(nil) // When we create a draft invoice @@ -1275,6 +1244,7 @@ func (s *InvoicingTestSuite) TestInvoicingFlowErrorHandling() { mockApp.OnFinalizeStandardInvoice(nil) calcMock.OnCalculate(nil) calcMock.OnCalculateGatheringInvoice(nil) + calcMock.OnCalculateLegacyGatheringInvoice(nil) // When we create a draft invoice invoice := s.CreateDraftInvoice(s.T(), ctx, DraftInvoiceInput{ @@ -1526,55 +1496,48 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.flatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.flatPerUnit.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), Commitments: productcatalog.Commitments{ MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(2000)), }, - }), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per any usage", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - Price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ Amount: alpacadecimal.NewFromFloat(100), PaymentTerm: productcatalog.InArrearsPaymentTerm, - }), - Quantity: lo.ToPtr(alpacadecimal.NewFromFloat(1)), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered graduated", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredGraduated.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredGraduated.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -1595,21 +1558,19 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { }, }, }, - }), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered volume", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredVolume.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredVolume.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.VolumeTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -1633,7 +1594,7 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { Commitments: productcatalog.Commitments{ MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(3000)), }, - }), + })), }, }, }, @@ -1651,11 +1612,8 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { // The flat fee line should not be truncated require.Equal(s.T(), - billing.Period{ - Start: truncatedPeriodStart, - End: truncatedPeriodEnd, - }, - lines.flatFee.Period, + timeutil.ClosedPeriod{From: truncatedPeriodStart, To: truncatedPeriodEnd}, + lines.flatFee.ServicePeriod, "period should not be truncated", ) require.Equal(s.T(), @@ -1665,13 +1623,10 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { ) // The pending invoice items should be truncated to 1 min resolution (start => up to next, end down to previous) - for _, line := range []*billing.StandardLine{lines.flatPerUnit, lines.tieredGraduated, lines.tieredVolume} { + for _, line := range []billing.GatheringLine{lines.flatPerUnit, lines.tieredGraduated, lines.tieredVolume} { require.Equal(s.T(), - billing.Period{ - Start: testutils.GetRFC3339Time(s.T(), "2024-09-02T12:13:14Z"), - End: testutils.GetRFC3339Time(s.T(), "2024-09-03T12:13:14Z"), - }, - line.Period, + timeutil.ClosedPeriod{From: testutils.GetRFC3339Time(s.T(), "2024-09-02T12:13:14Z"), To: testutils.GetRFC3339Time(s.T(), "2024-09-03T12:13:14Z")}, + line.ServicePeriod, "period should be truncated to 1 min resolution", ) @@ -2045,7 +2000,7 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { s.NotNil(tieredGraduated.SplitLineHierarchy) tieredGraduatedHierarchy := tieredGraduated.SplitLineHierarchy - require.True(s.T(), tieredGraduatedHierarchy.Group.ServicePeriod.Equal(lines.tieredGraduated.Period)) + require.True(s.T(), tieredGraduatedHierarchy.Group.ServicePeriod.ToClosedPeriod().Equal(lines.tieredGraduated.ServicePeriod)) require.Len(s.T(), tieredGraduatedHierarchy.Lines, 3, "there should be to child lines [id=%s]", tieredGraduatedHierarchy.Group.ID) require.True(s.T(), tieredGraduatedHierarchy.Lines[0].Line.Period.Equal(billing.Period{ Start: truncatedPeriodStart, @@ -2236,15 +2191,15 @@ func (s *InvoicingTestSuite) TestUBPProgressiveInvoicing() { for _, line := range []*billing.StandardLine{flatPerUnit, tieredGraduated} { require.True(s.T(), expectedPeriod.Equal(line.Period), "period should be changed for the line items") } - require.True(s.T(), tieredVolume.Period.Equal(lines.tieredVolume.Period), "period should be unchanged for the tiered volume line") - require.True(s.T(), flatFee.Period.Equal(lines.flatFee.Period), "period should be unchanged for the flat line") + require.True(s.T(), tieredVolume.Period.ToClosedPeriod().Equal(lines.tieredVolume.ServicePeriod), "period should be unchanged for the tiered volume line") + require.True(s.T(), flatFee.Period.ToClosedPeriod().Equal(lines.flatFee.ServicePeriod), "period should be unchanged for the flat line") // Let's validate the output of the split itself: no new split should have occurred s.sortedSplitLineGroupChildren(tieredGraduated) tieredGraduatedHierarchy := tieredGraduated.SplitLineHierarchy s.NotNil(tieredGraduatedHierarchy) - require.True(s.T(), tieredGraduatedHierarchy.Group.ServicePeriod.Equal(lines.tieredGraduated.Period)) + require.True(s.T(), tieredGraduatedHierarchy.Group.ServicePeriod.ToClosedPeriod().Equal(lines.tieredGraduated.ServicePeriod)) require.Len(s.T(), tieredGraduatedHierarchy.Lines, 3, "there should be to child lines [id=%s]", tieredGraduatedHierarchy.Group.ID) require.True(s.T(), tieredGraduatedHierarchy.Lines[0].Line.Period.Equal(billing.Period{ Start: truncatedPeriodStart, @@ -2411,26 +2366,24 @@ func (s *InvoicingTestSuite) TestUBPGraduatingFlatFeeTier1() { // Given we have a default profile for the namespace s.ProvisionBillingProfile(ctx, namespace, sandboxApp.GetID(), WithProgressiveBilling()) - var pendingLine *billing.StandardLine + var pendingLine billing.GatheringLine s.Run("create pending invoice items", func() { // When we create pending invoice items pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered graduated", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredGraduated.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredGraduated.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -2457,7 +2410,7 @@ func (s *InvoicingTestSuite) TestUBPGraduatingFlatFeeTier1() { }, }, }, - }), + })), }, }, }, @@ -2732,55 +2685,48 @@ func (s *InvoicingTestSuite) TestUBPNonProgressiveInvoicing() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.flatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.flatPerUnit.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), Commitments: productcatalog.Commitments{ MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(2000)), }, - }), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per any usage", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - Price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ Amount: alpacadecimal.NewFromFloat(100), PaymentTerm: productcatalog.InArrearsPaymentTerm, - }), - Quantity: lo.ToPtr(alpacadecimal.NewFromFloat(1)), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered graduated", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredGraduated.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredGraduated.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -2801,21 +2747,19 @@ func (s *InvoicingTestSuite) TestUBPNonProgressiveInvoicing() { }, }, }, - }), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - Tiered volume", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: features.tieredVolume.Key, - Price: productcatalog.NewPriceFrom(productcatalog.TieredPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: features.tieredVolume.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.TieredPrice{ Mode: productcatalog.VolumeTieredPrice, Tiers: []productcatalog.PriceTier{ { @@ -2839,7 +2783,7 @@ func (s *InvoicingTestSuite) TestUBPNonProgressiveInvoicing() { Commitments: productcatalog.Commitments{ MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(3000)), }, - }), + })), }, }, }, @@ -2856,13 +2800,10 @@ func (s *InvoicingTestSuite) TestUBPNonProgressiveInvoicing() { } // The pending invoice items should be truncated to 1 min resolution (start => up to next, end down to previous) - for _, line := range []*billing.StandardLine{lines.flatPerUnit, lines.tieredGraduated, lines.tieredVolume, lines.flatFee} { + for _, line := range []billing.GatheringLine{lines.flatPerUnit, lines.tieredGraduated, lines.tieredVolume, lines.flatFee} { require.Equal(s.T(), - billing.Period{ - Start: truncatedPeriodStart, - End: truncatedPeriodEnd, - }, - line.Period, + timeutil.ClosedPeriod{From: truncatedPeriodStart, To: truncatedPeriodEnd}, + line.ServicePeriod, "period should be truncated to 1 min resolution", ) @@ -3072,10 +3013,10 @@ func (s *InvoicingTestSuite) sortedSplitLineGroupChildren(line *billing.Standard } type ubpPendingLines struct { - flatPerUnit *billing.StandardLine - flatFee *billing.StandardLine - tieredGraduated *billing.StandardLine - tieredVolume *billing.StandardLine + flatPerUnit billing.GatheringLine + flatFee billing.GatheringLine + tieredGraduated billing.GatheringLine + tieredVolume billing.GatheringLine } type ubpFeatures struct { @@ -3243,24 +3184,22 @@ func (s *InvoicingTestSuite) TestGatheringInvoiceRecalculation() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: flatPerUnitFeature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: flatPerUnitFeature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), Commitments: productcatalog.Commitments{ MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(2000)), }, - }), + })), }, }, }, @@ -3410,21 +3349,19 @@ func (s *InvoicingTestSuite) TestEmptyInvoiceGenerationZeroUsage() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: flatPerUnitFeature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: flatPerUnitFeature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(0), - }), + })), }, }, }, @@ -3530,21 +3467,19 @@ func (s *InvoicingTestSuite) TestEmptyInvoiceGenerationZeroPrice() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - FLAT per unit", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: flatPerUnitFeature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: flatPerUnitFeature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(0), - }), + })), }, }, }, @@ -3686,19 +3621,17 @@ func (s *InvoicingTestSuite) TestProgressiveBillLate() { pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - volume", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom( + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom( productcatalog.TieredPrice{ Mode: productcatalog.VolumeTieredPrice, Tiers: []productcatalog.PriceTier{ @@ -3715,8 +3648,7 @@ func (s *InvoicingTestSuite) TestProgressiveBillLate() { }, }, }, - }, - ), + })), }, }, }, @@ -3781,19 +3713,17 @@ func (s *InvoicingTestSuite) TestProgressiveBillingOverride() { pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - volume", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom( + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom( productcatalog.TieredPrice{ Mode: productcatalog.VolumeTieredPrice, Tiers: []productcatalog.PriceTier{ @@ -3811,25 +3741,23 @@ func (s *InvoicingTestSuite) TestProgressiveBillingOverride() { }, }, }, - ), + )), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - unit", }), - Period: billing.Period{Start: periodStart, End: periodStart.Add(24 * time.Hour)}, - InvoiceAt: periodStart.Add(24 * time.Hour), - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom( + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodStart.Add(24 * time.Hour)}, + InvoiceAt: periodStart.Add(24 * time.Hour), + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom( productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(1), }, - ), + )), }, }, }, @@ -3893,19 +3821,17 @@ func (s *InvoicingTestSuite) TestSortLines() { pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Name: "UBP - volume", }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: apiRequestsTotalFeature.Feature.Key, - Price: productcatalog.NewPriceFrom( + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: apiRequestsTotalFeature.Feature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom( productcatalog.TieredPrice{ Mode: productcatalog.GraduatedTieredPrice, Tiers: []productcatalog.PriceTier{ @@ -3932,7 +3858,7 @@ func (s *InvoicingTestSuite) TestSortLines() { MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100000)), }, }, - ), + )), }, }, }, @@ -4022,8 +3948,8 @@ func (s *InvoicingTestSuite) TestGatheringInvoicePeriodPersisting() { pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: periodStart, Name: "Flat fee", @@ -4054,8 +3980,8 @@ func (s *InvoicingTestSuite) TestGatheringInvoicePeriodPersisting() { pendingLines, err = s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: newPeriodStart, End: newPeriodEnd}, InvoiceAt: newPeriodStart, Name: "Flat fee", @@ -4133,8 +4059,8 @@ func (s *InvoicingTestSuite) TestCreatePendingInvoiceLinesForDeletedCustomers() pendingLines, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: periodStart, End: periodEnd}, InvoiceAt: periodStart, Name: "Flat fee", @@ -4170,12 +4096,11 @@ func (s *InvoicingTestSuite) TestCreatePendingInvoiceLinesForDeletedCustomers() pendingLines, err = s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ - Period: billing.Period{Start: clock.Now(), End: clock.Now().Add(time.Hour * 24)}, - InvoiceAt: clock.Now(), - Name: "Flat fee", - + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ + Period: billing.Period{Start: clock.Now(), End: clock.Now().Add(time.Hour * 24)}, + InvoiceAt: clock.Now(), + Name: "Flat fee", PerUnitAmount: alpacadecimal.NewFromFloat(10), PaymentTerm: productcatalog.InAdvancePaymentTerm, }), @@ -4275,24 +4200,22 @@ func (s *InvoicingTestSuite) TestSnapshotQuantityInvalidDatabaseState() { billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.ManagedResource{ NamespacedModel: models.NamespacedModel{ Namespace: namespace, }, Name: "UBP - snapshot", }, - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - }, - UsageBased: &billing.UsageBasedLine{ - FeatureKey: snapshotFeature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + FeatureKey: snapshotFeature.Key, + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(1), - }), + })), }, }, }, diff --git a/test/billing/schemamigration_test.go b/test/billing/schemamigration_test.go index aa0db6083a..28b72134a9 100644 --- a/test/billing/schemamigration_test.go +++ b/test/billing/schemamigration_test.go @@ -22,6 +22,7 @@ import ( "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" ) type SchemaMigrationTestSuite struct { @@ -113,17 +114,17 @@ func (s *SchemaMigrationTestSuite) TestSchemaLevel1Migration() { _, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: customerEntity.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Namespace: namespace, Name: lineNameDeletedDetailed, }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - Currency: currencyx.Code(currency.USD), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + Currency: currencyx.Code(currency.USD), RateCardDiscounts: billing.Discounts{ Percentage: &billing.PercentageDiscount{ PercentageDiscount: productcatalog.PercentageDiscount{ @@ -131,24 +132,22 @@ func (s *SchemaMigrationTestSuite) TestSchemaLevel1Migration() { }, }, }, - }, - UsageBased: &billing.UsageBasedLine{ FeatureKey: featureFlatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), - }), + })), }, }, { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Namespace: namespace, Name: lineNameActiveDetailed, }), - Period: billing.Period{Start: periodStart, End: periodEnd}, - InvoiceAt: periodEnd, - ManagedBy: billing.ManuallyManagedLine, - Currency: currencyx.Code(currency.USD), + ServicePeriod: timeutil.ClosedPeriod{From: periodStart, To: periodEnd}, + InvoiceAt: periodEnd, + ManagedBy: billing.ManuallyManagedLine, + Currency: currencyx.Code(currency.USD), RateCardDiscounts: billing.Discounts{ Percentage: &billing.PercentageDiscount{ PercentageDiscount: productcatalog.PercentageDiscount{ @@ -156,12 +155,10 @@ func (s *SchemaMigrationTestSuite) TestSchemaLevel1Migration() { }, }, }, - }, - UsageBased: &billing.UsageBasedLine{ FeatureKey: featureFlatPerUnit.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), - }), + })), }, }, }, diff --git a/test/billing/suite.go b/test/billing/suite.go index d4ca7a74e0..fbd496b5a6 100644 --- a/test/billing/suite.go +++ b/test/billing/suite.go @@ -383,8 +383,8 @@ func (s *BaseSuite) CreateGatheringInvoice(t *testing.T, ctx context.Context, in billing.CreatePendingInvoiceLinesInput{ Customer: in.Customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine( + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine( billing.NewFlatFeeLineInput{ Namespace: namespace, Period: billing.Period{Start: periodStart, End: periodEnd}, @@ -399,7 +399,7 @@ func (s *BaseSuite) CreateGatheringInvoice(t *testing.T, ctx context.Context, in PaymentTerm: productcatalog.InArrearsPaymentTerm, }, ), - billing.NewFlatFeeLine( + billing.NewFlatFeeGatheringLine( billing.NewFlatFeeLineInput{ Namespace: namespace, Period: billing.Period{Start: periodStart, End: periodEnd}, diff --git a/test/billing/tax_test.go b/test/billing/tax_test.go index ce109ed840..7cfd908327 100644 --- a/test/billing/tax_test.go +++ b/test/billing/tax_test.go @@ -19,6 +19,7 @@ import ( "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" ) type InvoicingTaxTestSuite struct { @@ -203,14 +204,14 @@ func (s *InvoicingTaxTestSuite) TestLineSplittingRetainsTaxConfig() { billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ { - StandardLineBase: billing.StandardLineBase{ + GatheringLineBase: billing.GatheringLineBase{ ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ Namespace: namespace, Name: "Test item - USD", }), - Period: billing.Period{Start: now, End: now.Add(time.Hour * 24)}, + ServicePeriod: timeutil.ClosedPeriod{From: now, To: now.Add(time.Hour * 24)}, InvoiceAt: now.Add(time.Hour * 24), ManagedBy: billing.ManuallyManagedLine, @@ -220,15 +221,13 @@ func (s *InvoicingTaxTestSuite) TestLineSplittingRetainsTaxConfig() { Metadata: map[string]string{ "key": "value", }, - }, - UsageBased: &billing.UsageBasedLine{ FeatureKey: flatPerUnitFeature.Key, - Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Price: lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.UnitPrice{ Amount: alpacadecimal.NewFromFloat(100), Commitments: productcatalog.Commitments{ MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(2000)), }, - }), + })), }, }, }, @@ -272,8 +271,8 @@ func (s *InvoicingTaxTestSuite) generateDraftInvoice(ctx context.Context, custom billing.CreatePendingInvoiceLinesInput{ Customer: customer.GetID(), Currency: currencyx.Code(currency.USD), - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{Start: now, End: now.Add(time.Hour * 24)}, InvoiceAt: now, diff --git a/test/billing/ubpflatfee_test.go b/test/billing/ubpflatfee_test.go index 5e3b0db962..0665d47165 100644 --- a/test/billing/ubpflatfee_test.go +++ b/test/billing/ubpflatfee_test.go @@ -46,10 +46,11 @@ func (s *UBPFlatFeeLineTestSuite) TestPendingLineCreation() { } s.Run("should create a pending line", func() { - lineIn := billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + lineIn := billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: period, InvoiceAt: period.End, + Currency: "USD", Name: "test in arrears", PerUnitAmount: alpacadecimal.NewFromInt(100), PaymentTerm: productcatalog.InArrearsPaymentTerm, @@ -58,7 +59,7 @@ func (s *UBPFlatFeeLineTestSuite) TestPendingLineCreation() { res, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: cust.GetID(), Currency: "USD", - Lines: []*billing.StandardLine{ + Lines: []billing.GatheringLine{ lineIn, }, }) @@ -67,18 +68,20 @@ func (s *UBPFlatFeeLineTestSuite) TestPendingLineCreation() { s.NotNil(res) line := res.Lines[0] - expected := lineIn.Clone() + expected, err := lineIn.Clone() + s.NoError(err) // Let's add fields coming from the line creation + expected.Namespace = cust.Namespace expected.InvoiceID = res.Invoice.ID expected.ID = line.ID expected.CreatedAt = line.CreatedAt expected.UpdatedAt = line.UpdatedAt - expected.UsageBased.ConfigID = line.UsageBased.ConfigID + expected.UBPConfigID = line.UBPConfigID ExpectJSONEqual(s.T(), - expected.RemoveCircularReferences().RemoveMetaForCompare(), - line.RemoveCircularReferences().RemoveMetaForCompare()) + lo.Must(expected.WithoutDBState()), + lo.Must(line.WithoutDBState())) }) // Given the line on gathering invoice is created @@ -156,8 +159,8 @@ func (s *UBPFlatFeeLineTestSuite) TestPercentageDiscount() { _, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: cust.GetID(), Currency: "USD", - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: period, InvoiceAt: period.End, @@ -235,8 +238,8 @@ func (s *UBPFlatFeeLineTestSuite) TestValidations() { _, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: cust.GetID(), Currency: "USD", - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: period, InvoiceAt: period.End, @@ -261,8 +264,8 @@ func (s *UBPFlatFeeLineTestSuite) TestValidations() { _, err := s.BillingService.CreatePendingInvoiceLines(ctx, billing.CreatePendingInvoiceLinesInput{ Customer: cust.GetID(), Currency: "USD", - Lines: []*billing.StandardLine{ - billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + Lines: []billing.GatheringLine{ + billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ Period: billing.Period{ Start: period.Start, End: period.Start, diff --git a/test/customer/subject.go b/test/customer/subject.go index bdaecd8b1d..1d8d6984ac 100644 --- a/test/customer/subject.go +++ b/test/customer/subject.go @@ -375,7 +375,7 @@ func (s *CustomerHandlerTestSuite) TestMultiSubjectIntegrationFlow(ctx context.C t.Cleanup(clock.ResetTime) periodStart := now - pendingLine := billing.NewFlatFeeLine(billing.NewFlatFeeLineInput{ + pendingLine := billing.NewFlatFeeGatheringLine(billing.NewFlatFeeLineInput{ ID: ulid.Make().String(), CreatedAt: now, UpdatedAt: now, @@ -393,7 +393,7 @@ func (s *CustomerHandlerTestSuite) TestMultiSubjectIntegrationFlow(ctx context.C ID: createdCustomer.ID, }, Currency: currencyx.Code("USD"), - Lines: []*billing.StandardLine{pendingLine}, + Lines: []billing.GatheringLine{pendingLine}, }) require.NoError(t, err, "creating pending invoice lines should succeed") require.NotNil(t, result)