Skip to content

Commit c416f52

Browse files
Merge pull request #62 from hyperledger/dbsqltest
Expose the mock provider for other packages to use in tests
2 parents 28ad317 + d7f3de7 commit c416f52

12 files changed

+703
-233
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"Infof",
2626
"Kaleido",
2727
"microservices",
28+
"Nillable",
2829
"Nowarn",
2930
"openapi",
3031
"passwordfile",

pkg/dbsql/crud.go

Lines changed: 215 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"database/sql"
2222
"fmt"
23+
"reflect"
2324
"strings"
2425

2526
sq "github.com/Masterminds/squirrel"
@@ -29,6 +30,12 @@ import (
2930
"github.com/hyperledger/firefly-common/pkg/log"
3031
)
3132

33+
const (
34+
ColumnID = "id"
35+
ColumnCreated = "created"
36+
ColumnUpdated = "updated"
37+
)
38+
3239
type ChangeEventType int
3340

3441
const (
@@ -45,17 +52,58 @@ const (
4552
UpsertOptimizationExisting
4653
)
4754

55+
type GetOption int
56+
57+
const (
58+
FailIfNotFound GetOption = iota
59+
)
60+
4861
type PostCompletionHook func()
4962

50-
type WithID interface {
63+
type Resource interface {
5164
GetID() *fftypes.UUID
65+
SetCreated(*fftypes.FFTime)
66+
SetUpdated(*fftypes.FFTime)
5267
}
5368

54-
type CrudBase[T WithID] struct {
55-
DB *Database
56-
Table string
57-
Columns []string
58-
FilterFieldMap map[string]string
69+
type ResourceBase struct {
70+
ID *fftypes.UUID `ffstruct:"ResourceBase" json:"id"`
71+
Created *fftypes.FFTime `ffstruct:"ResourceBase" json:"created"`
72+
Updated *fftypes.FFTime `ffstruct:"ResourceBase" json:"updated"`
73+
}
74+
75+
func (r *ResourceBase) GetID() *fftypes.UUID {
76+
return r.ID
77+
}
78+
79+
func (r *ResourceBase) SetCreated(t *fftypes.FFTime) {
80+
r.Created = t
81+
}
82+
83+
func (r *ResourceBase) SetUpdated(t *fftypes.FFTime) {
84+
r.Updated = t
85+
}
86+
87+
type CRUD[T Resource] interface {
88+
Validate()
89+
Upsert(ctx context.Context, inst T, optimization UpsertOptimization, hooks ...PostCompletionHook) (created bool, err error)
90+
InsertMany(ctx context.Context, instances []T, allowPartialSuccess bool, hooks ...PostCompletionHook) (err error)
91+
Insert(ctx context.Context, inst T, hooks ...PostCompletionHook) (err error)
92+
Replace(ctx context.Context, inst T, hooks ...PostCompletionHook) (err error)
93+
GetByID(ctx context.Context, id *fftypes.UUID, getOpts ...GetOption) (inst T, err error)
94+
GetMany(ctx context.Context, filter ffapi.Filter) (instances []T, fr *ffapi.FilterResult, err error)
95+
Update(ctx context.Context, id *fftypes.UUID, update ffapi.Update, hooks ...PostCompletionHook) (err error)
96+
UpdateSparse(ctx context.Context, sparseUpdate T, hooks ...PostCompletionHook) (err error)
97+
UpdateMany(ctx context.Context, filter ffapi.Filter, update ffapi.Update, hooks ...PostCompletionHook) (err error)
98+
Delete(ctx context.Context, id *fftypes.UUID, hooks ...PostCompletionHook) (err error)
99+
}
100+
101+
type CrudBase[T Resource] struct {
102+
DB *Database
103+
Table string
104+
Columns []string
105+
FilterFieldMap map[string]string
106+
ImmutableColumns []string
59107

60108
NilValue func() T // nil value typed to T
61109
NewInstance func() T
@@ -69,6 +117,57 @@ type CrudBase[T WithID] struct {
69117
ReadQueryModifier func(sq.SelectBuilder) sq.SelectBuilder
70118
}
71119

120+
// Validate checks things that must be true about a CRUD collection using this framework.
121+
// Intended for use in the unit tests of microservices (will exercise all the functions of the CrudBase):
122+
// - the mandatory columns exist - id/created/updated
123+
// - no column has the same name as the sequence column for the DB
124+
// - a unique pointer is returned for each field column
125+
// - the immutable columns exist
126+
// - the other functions return valid data
127+
func (c *CrudBase[T]) Validate() {
128+
inst := c.NewInstance()
129+
if isNil(inst) {
130+
panic("NewInstance() value must not be nil")
131+
}
132+
ptrs := map[string]interface{}{}
133+
fieldMap := map[string]bool{
134+
// Mandatory column checks
135+
ColumnID: false,
136+
ColumnCreated: false,
137+
ColumnUpdated: false,
138+
}
139+
for _, col := range c.Columns {
140+
if ok, set := fieldMap[col]; ok && set {
141+
panic(fmt.Sprintf("%s is a duplicated column", col))
142+
}
143+
144+
fieldMap[col] = true
145+
if col == c.DB.sequenceColumn {
146+
panic(fmt.Sprintf("cannot have column named '%s'", c.DB.sequenceColumn))
147+
}
148+
fieldPtr := c.GetFieldPtr(inst, col)
149+
ptrVal := reflect.ValueOf(fieldPtr)
150+
if ptrVal.Kind() != reflect.Ptr || !isNil(ptrVal.Elem().Interface()) {
151+
panic(fmt.Sprintf("field %s does not seem to be a pointer type - prevents null-check for PATCH semantics functioning", col))
152+
}
153+
ptrs[col] = fieldPtr
154+
}
155+
for col, set := range fieldMap {
156+
if !set {
157+
panic(fmt.Sprintf("mandatory column '%s' must be included in column list", col))
158+
}
159+
}
160+
if !isNil(c.NilValue()) {
161+
panic("NilValue() value must be nil")
162+
}
163+
if isNil(c.ScopedFilter()) {
164+
panic("ScopedFilter() value must not be nil")
165+
}
166+
if !isNil(c.GetFieldPtr(inst, fftypes.NewUUID().String())) {
167+
panic("GetFieldPtr() must return nil for unknown column")
168+
}
169+
}
170+
72171
func (c *CrudBase[T]) idFilter(id *fftypes.UUID) sq.Eq {
73172
filter := c.ScopedFilter()
74173
if c.ReadTableAlias != "" {
@@ -79,11 +178,27 @@ func (c *CrudBase[T]) idFilter(id *fftypes.UUID) sq.Eq {
79178
return filter
80179
}
81180

82-
func (c *CrudBase[T]) attemptReplace(ctx context.Context, tx *TXWrapper, inst T) (int64, error) {
83-
update := sq.Update(c.Table)
181+
func (c *CrudBase[T]) buildUpdateList(_ context.Context, update sq.UpdateBuilder, inst T, includeNil bool) sq.UpdateBuilder {
182+
colLoop:
84183
for _, col := range c.Columns {
85-
update = update.Set(col, c.GetFieldPtr(inst, col))
184+
for _, immutable := range append(c.ImmutableColumns, ColumnID, ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) {
185+
if col == immutable {
186+
continue colLoop
187+
}
188+
}
189+
value := c.getFieldValue(inst, col)
190+
if includeNil || !isNil(value) {
191+
update = update.Set(col, value)
192+
}
86193
}
194+
update = update.Set(ColumnUpdated, fftypes.Now())
195+
return update
196+
}
197+
198+
func (c *CrudBase[T]) updateFromInstance(ctx context.Context, tx *TXWrapper, inst T, includeNil bool) (int64, error) {
199+
update := sq.Update(c.Table)
200+
inst.SetUpdated(fftypes.Now())
201+
update = c.buildUpdateList(ctx, update, inst, includeNil)
87202
update = update.Where(c.idFilter(inst.GetID()))
88203
return c.DB.UpdateTx(ctx, c.Table, tx,
89204
update,
@@ -94,11 +209,19 @@ func (c *CrudBase[T]) attemptReplace(ctx context.Context, tx *TXWrapper, inst T)
94209
})
95210
}
96211

212+
func (c *CrudBase[T]) getFieldValue(inst T, col string) interface{} {
213+
// Validate() will have checked this is safe for microservices (as long as they use that at build time in their UTs)
214+
return reflect.ValueOf(c.GetFieldPtr(inst, col)).Elem().Interface()
215+
}
216+
97217
func (c *CrudBase[T]) attemptInsert(ctx context.Context, tx *TXWrapper, inst T, requestConflictEmptyResult bool) (err error) {
218+
now := fftypes.Now()
219+
inst.SetCreated(now)
220+
inst.SetUpdated(now)
98221
insert := sq.Insert(c.Table).Columns(c.Columns...)
99222
values := make([]interface{}, len(c.Columns))
100223
for i, col := range c.Columns {
101-
values[i] = c.GetFieldPtr(inst, col)
224+
values[i] = c.getFieldValue(inst, col)
102225
}
103226
insert = insert.Values(values...)
104227
_, err = c.DB.InsertTxExt(ctx, c.Table, tx, insert,
@@ -110,10 +233,10 @@ func (c *CrudBase[T]) attemptInsert(ctx context.Context, tx *TXWrapper, inst T,
110233
return err
111234
}
112235

113-
func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOptimization, hooks ...PostCompletionHook) (err error) {
236+
func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOptimization, hooks ...PostCompletionHook) (created bool, err error) {
114237
ctx, tx, autoCommit, err := c.DB.BeginOrUseTx(ctx)
115238
if err != nil {
116-
return err
239+
return false, err
117240
}
118241
defer c.DB.RollbackTx(ctx, tx, autoCommit)
119242

@@ -124,8 +247,9 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt
124247
if optimization == UpsertOptimizationNew {
125248
opErr := c.attemptInsert(ctx, tx, inst, true /* we want a failure here we can progress past */)
126249
optimized = opErr == nil
250+
created = optimized
127251
} else if optimization == UpsertOptimizationExisting {
128-
rowsAffected, opErr := c.attemptReplace(ctx, tx, inst)
252+
rowsAffected, opErr := c.updateFromInstance(ctx, tx, inst, true /* full replace */)
129253
optimized = opErr == nil && rowsAffected == 1
130254
}
131255

@@ -137,20 +261,21 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt
137261
Where(c.idFilter(inst.GetID())),
138262
)
139263
if err != nil {
140-
return err
264+
return false, err
141265
}
142266
existing := msgRows.Next()
143267
msgRows.Close()
144268

145269
if existing {
146270
// Replace the existing one
147-
if _, err = c.attemptReplace(ctx, tx, inst); err != nil {
148-
return err
271+
if _, err = c.updateFromInstance(ctx, tx, inst, true /* full replace */); err != nil {
272+
return false, err
149273
}
150274
} else {
151275
// Get a useful error out of an insert attempt
276+
created = true
152277
if err = c.attemptInsert(ctx, tx, inst, false); err != nil {
153-
return err
278+
return false, err
154279
}
155280
}
156281
}
@@ -159,7 +284,7 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt
159284
tx.AddPostCommitHook(hook)
160285
}
161286

162-
return c.DB.CommitTx(ctx, tx, autoCommit)
287+
return created, c.DB.CommitTx(ctx, tx, autoCommit)
163288
}
164289

165290
func (c *CrudBase[T]) InsertMany(ctx context.Context, instances []T, allowPartialSuccess bool, hooks ...PostCompletionHook) (err error) {
@@ -174,7 +299,7 @@ func (c *CrudBase[T]) InsertMany(ctx context.Context, instances []T, allowPartia
174299
for _, inst := range instances {
175300
values := make([]interface{}, len(c.Columns))
176301
for i, col := range c.Columns {
177-
values[i] = c.GetFieldPtr(inst, col)
302+
values[i] = c.getFieldValue(inst, col)
178303
}
179304
insert = insert.Values(values...)
180305
}
@@ -234,7 +359,7 @@ func (c *CrudBase[T]) Replace(ctx context.Context, inst T, hooks ...PostCompleti
234359
}
235360
defer c.DB.RollbackTx(ctx, tx, autoCommit)
236361

237-
rowsAffected, err := c.attemptReplace(ctx, tx, inst)
362+
rowsAffected, err := c.updateFromInstance(ctx, tx, inst, true /* full replace */)
238363
if err != nil {
239364
return err
240365
} else if rowsAffected < 1 {
@@ -282,14 +407,26 @@ func (c *CrudBase[T]) getReadCols() (tableFrom string, cols, readCols []string)
282407
return tableFrom, cols, readCols
283408
}
284409

285-
func (c *CrudBase[T]) GetByID(ctx context.Context, id *fftypes.UUID) (inst T, err error) {
410+
func (c *CrudBase[T]) GetByID(ctx context.Context, id *fftypes.UUID, getOpts ...GetOption) (inst T, err error) {
411+
412+
failNotFound := false
413+
for _, o := range getOpts {
414+
switch o {
415+
case FailIfNotFound:
416+
failNotFound = true
417+
default:
418+
return c.NilValue(), i18n.NewError(ctx, i18n.MsgDBUnknownGetOption, o)
419+
}
420+
}
421+
286422
tableFrom, cols, readCols := c.getReadCols()
287423
query := sq.Select(readCols...).
288424
From(tableFrom).
289425
Where(c.idFilter(id))
290426
if c.ReadQueryModifier != nil {
291427
query = c.ReadQueryModifier(query)
292428
}
429+
293430
rows, _, err := c.DB.Query(ctx, c.Table, query)
294431
if err != nil {
295432
return c.NilValue(), err
@@ -298,6 +435,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id *fftypes.UUID) (inst T, er
298435

299436
if !rows.Next() {
300437
log.L(ctx).Debugf("%s '%s' not found", c.Table, id)
438+
if failNotFound {
439+
return c.NilValue(), i18n.NewError(ctx, i18n.Msg404NoResult)
440+
}
301441
return c.NilValue(), nil
302442
}
303443

@@ -344,6 +484,59 @@ func (c *CrudBase[T]) Update(ctx context.Context, id *fftypes.UUID, update ffapi
344484
}, update, true, hooks...)
345485
}
346486

487+
// Thanks to the testify/assert package for this
488+
func isNil(o interface{}) bool {
489+
if o == nil {
490+
return true
491+
}
492+
493+
containsKind := func(kinds []reflect.Kind, kind reflect.Kind) bool {
494+
for i := 0; i < len(kinds); i++ {
495+
if kind == kinds[i] {
496+
return true
497+
}
498+
}
499+
return false
500+
}
501+
502+
value := reflect.ValueOf(o)
503+
kind := value.Kind()
504+
isNillableKind := containsKind(
505+
[]reflect.Kind{
506+
reflect.Chan, reflect.Func,
507+
reflect.Interface, reflect.Map,
508+
reflect.Ptr, reflect.Slice},
509+
kind)
510+
511+
if isNillableKind && value.IsNil() {
512+
return true
513+
}
514+
515+
return false
516+
}
517+
518+
func (c *CrudBase[T]) UpdateSparse(ctx context.Context, sparseUpdate T, hooks ...PostCompletionHook) (err error) {
519+
ctx, tx, autoCommit, err := c.DB.BeginOrUseTx(ctx)
520+
if err != nil {
521+
return err
522+
}
523+
defer c.DB.RollbackTx(ctx, tx, autoCommit)
524+
525+
updateCount, err := c.updateFromInstance(ctx, tx, sparseUpdate, false /* only non-nil fields */)
526+
if err != nil {
527+
return err
528+
}
529+
if updateCount < 1 {
530+
return i18n.NewError(ctx, i18n.MsgDBNoRowsAffected)
531+
}
532+
533+
for _, hook := range hooks {
534+
tx.AddPostCommitHook(hook)
535+
}
536+
537+
return c.DB.CommitTx(ctx, tx, autoCommit)
538+
}
539+
347540
func (c *CrudBase[T]) UpdateMany(ctx context.Context, filter ffapi.Filter, update ffapi.Update, hooks ...PostCompletionHook) (err error) {
348541
return c.attemptUpdate(ctx, func(query sq.UpdateBuilder) (sq.UpdateBuilder, error) {
349542
return c.DB.FilterUpdate(ctx, query, filter, c.FilterFieldMap)
@@ -361,6 +554,7 @@ func (c *CrudBase[T]) attemptUpdate(ctx context.Context, filterFn func(sq.Update
361554
if err == nil {
362555
query, err = filterFn(query)
363556
}
557+
query = query.Set(ColumnUpdated, fftypes.Now())
364558
if err != nil {
365559
return err
366560
}

0 commit comments

Comments
 (0)