@@ -20,6 +20,7 @@ import (
20
20
"context"
21
21
"database/sql"
22
22
"fmt"
23
+ "reflect"
23
24
"strings"
24
25
25
26
sq "github.com/Masterminds/squirrel"
@@ -29,6 +30,12 @@ import (
29
30
"github.com/hyperledger/firefly-common/pkg/log"
30
31
)
31
32
33
+ const (
34
+ ColumnID = "id"
35
+ ColumnCreated = "created"
36
+ ColumnUpdated = "updated"
37
+ )
38
+
32
39
type ChangeEventType int
33
40
34
41
const (
@@ -45,17 +52,58 @@ const (
45
52
UpsertOptimizationExisting
46
53
)
47
54
55
+ type GetOption int
56
+
57
+ const (
58
+ FailIfNotFound GetOption = iota
59
+ )
60
+
48
61
type PostCompletionHook func ()
49
62
50
- type WithID interface {
63
+ type Resource interface {
51
64
GetID () * fftypes.UUID
65
+ SetCreated (* fftypes.FFTime )
66
+ SetUpdated (* fftypes.FFTime )
52
67
}
53
68
54
- type CrudBase [T WithID ] struct {
55
- DB * Database
56
- Table string
57
- Columns []string
58
- FilterFieldMap map [string ]string
69
+ type ResourceBase struct {
70
+ ID * fftypes.UUID `ffstruct:"ResourceBase" json:"id"`
71
+ Created * fftypes.FFTime `ffstruct:"ResourceBase" json:"created"`
72
+ Updated * fftypes.FFTime `ffstruct:"ResourceBase" json:"updated"`
73
+ }
74
+
75
+ func (r * ResourceBase ) GetID () * fftypes.UUID {
76
+ return r .ID
77
+ }
78
+
79
+ func (r * ResourceBase ) SetCreated (t * fftypes.FFTime ) {
80
+ r .Created = t
81
+ }
82
+
83
+ func (r * ResourceBase ) SetUpdated (t * fftypes.FFTime ) {
84
+ r .Updated = t
85
+ }
86
+
87
+ type CRUD [T Resource ] interface {
88
+ Validate ()
89
+ Upsert (ctx context.Context , inst T , optimization UpsertOptimization , hooks ... PostCompletionHook ) (created bool , err error )
90
+ InsertMany (ctx context.Context , instances []T , allowPartialSuccess bool , hooks ... PostCompletionHook ) (err error )
91
+ Insert (ctx context.Context , inst T , hooks ... PostCompletionHook ) (err error )
92
+ Replace (ctx context.Context , inst T , hooks ... PostCompletionHook ) (err error )
93
+ GetByID (ctx context.Context , id * fftypes.UUID , getOpts ... GetOption ) (inst T , err error )
94
+ GetMany (ctx context.Context , filter ffapi.Filter ) (instances []T , fr * ffapi.FilterResult , err error )
95
+ Update (ctx context.Context , id * fftypes.UUID , update ffapi.Update , hooks ... PostCompletionHook ) (err error )
96
+ UpdateSparse (ctx context.Context , sparseUpdate T , hooks ... PostCompletionHook ) (err error )
97
+ UpdateMany (ctx context.Context , filter ffapi.Filter , update ffapi.Update , hooks ... PostCompletionHook ) (err error )
98
+ Delete (ctx context.Context , id * fftypes.UUID , hooks ... PostCompletionHook ) (err error )
99
+ }
100
+
101
+ type CrudBase [T Resource ] struct {
102
+ DB * Database
103
+ Table string
104
+ Columns []string
105
+ FilterFieldMap map [string ]string
106
+ ImmutableColumns []string
59
107
60
108
NilValue func () T // nil value typed to T
61
109
NewInstance func () T
@@ -69,6 +117,57 @@ type CrudBase[T WithID] struct {
69
117
ReadQueryModifier func (sq.SelectBuilder ) sq.SelectBuilder
70
118
}
71
119
120
+ // Validate checks things that must be true about a CRUD collection using this framework.
121
+ // Intended for use in the unit tests of microservices (will exercise all the functions of the CrudBase):
122
+ // - the mandatory columns exist - id/created/updated
123
+ // - no column has the same name as the sequence column for the DB
124
+ // - a unique pointer is returned for each field column
125
+ // - the immutable columns exist
126
+ // - the other functions return valid data
127
+ func (c * CrudBase [T ]) Validate () {
128
+ inst := c .NewInstance ()
129
+ if isNil (inst ) {
130
+ panic ("NewInstance() value must not be nil" )
131
+ }
132
+ ptrs := map [string ]interface {}{}
133
+ fieldMap := map [string ]bool {
134
+ // Mandatory column checks
135
+ ColumnID : false ,
136
+ ColumnCreated : false ,
137
+ ColumnUpdated : false ,
138
+ }
139
+ for _ , col := range c .Columns {
140
+ if ok , set := fieldMap [col ]; ok && set {
141
+ panic (fmt .Sprintf ("%s is a duplicated column" , col ))
142
+ }
143
+
144
+ fieldMap [col ] = true
145
+ if col == c .DB .sequenceColumn {
146
+ panic (fmt .Sprintf ("cannot have column named '%s'" , c .DB .sequenceColumn ))
147
+ }
148
+ fieldPtr := c .GetFieldPtr (inst , col )
149
+ ptrVal := reflect .ValueOf (fieldPtr )
150
+ if ptrVal .Kind () != reflect .Ptr || ! isNil (ptrVal .Elem ().Interface ()) {
151
+ panic (fmt .Sprintf ("field %s does not seem to be a pointer type - prevents null-check for PATCH semantics functioning" , col ))
152
+ }
153
+ ptrs [col ] = fieldPtr
154
+ }
155
+ for col , set := range fieldMap {
156
+ if ! set {
157
+ panic (fmt .Sprintf ("mandatory column '%s' must be included in column list" , col ))
158
+ }
159
+ }
160
+ if ! isNil (c .NilValue ()) {
161
+ panic ("NilValue() value must be nil" )
162
+ }
163
+ if isNil (c .ScopedFilter ()) {
164
+ panic ("ScopedFilter() value must not be nil" )
165
+ }
166
+ if ! isNil (c .GetFieldPtr (inst , fftypes .NewUUID ().String ())) {
167
+ panic ("GetFieldPtr() must return nil for unknown column" )
168
+ }
169
+ }
170
+
72
171
func (c * CrudBase [T ]) idFilter (id * fftypes.UUID ) sq.Eq {
73
172
filter := c .ScopedFilter ()
74
173
if c .ReadTableAlias != "" {
@@ -79,11 +178,27 @@ func (c *CrudBase[T]) idFilter(id *fftypes.UUID) sq.Eq {
79
178
return filter
80
179
}
81
180
82
- func (c * CrudBase [T ]) attemptReplace ( ctx context.Context , tx * TXWrapper , inst T ) ( int64 , error ) {
83
- update := sq . Update ( c . Table )
181
+ func (c * CrudBase [T ]) buildUpdateList ( _ context.Context , update sq. UpdateBuilder , inst T , includeNil bool ) sq. UpdateBuilder {
182
+ colLoop:
84
183
for _ , col := range c .Columns {
85
- update = update .Set (col , c .GetFieldPtr (inst , col ))
184
+ for _ , immutable := range append (c .ImmutableColumns , ColumnID , ColumnCreated , ColumnUpdated , c .DB .sequenceColumn ) {
185
+ if col == immutable {
186
+ continue colLoop
187
+ }
188
+ }
189
+ value := c .getFieldValue (inst , col )
190
+ if includeNil || ! isNil (value ) {
191
+ update = update .Set (col , value )
192
+ }
86
193
}
194
+ update = update .Set (ColumnUpdated , fftypes .Now ())
195
+ return update
196
+ }
197
+
198
+ func (c * CrudBase [T ]) updateFromInstance (ctx context.Context , tx * TXWrapper , inst T , includeNil bool ) (int64 , error ) {
199
+ update := sq .Update (c .Table )
200
+ inst .SetUpdated (fftypes .Now ())
201
+ update = c .buildUpdateList (ctx , update , inst , includeNil )
87
202
update = update .Where (c .idFilter (inst .GetID ()))
88
203
return c .DB .UpdateTx (ctx , c .Table , tx ,
89
204
update ,
@@ -94,11 +209,19 @@ func (c *CrudBase[T]) attemptReplace(ctx context.Context, tx *TXWrapper, inst T)
94
209
})
95
210
}
96
211
212
+ func (c * CrudBase [T ]) getFieldValue (inst T , col string ) interface {} {
213
+ // Validate() will have checked this is safe for microservices (as long as they use that at build time in their UTs)
214
+ return reflect .ValueOf (c .GetFieldPtr (inst , col )).Elem ().Interface ()
215
+ }
216
+
97
217
func (c * CrudBase [T ]) attemptInsert (ctx context.Context , tx * TXWrapper , inst T , requestConflictEmptyResult bool ) (err error ) {
218
+ now := fftypes .Now ()
219
+ inst .SetCreated (now )
220
+ inst .SetUpdated (now )
98
221
insert := sq .Insert (c .Table ).Columns (c .Columns ... )
99
222
values := make ([]interface {}, len (c .Columns ))
100
223
for i , col := range c .Columns {
101
- values [i ] = c .GetFieldPtr (inst , col )
224
+ values [i ] = c .getFieldValue (inst , col )
102
225
}
103
226
insert = insert .Values (values ... )
104
227
_ , err = c .DB .InsertTxExt (ctx , c .Table , tx , insert ,
@@ -110,10 +233,10 @@ func (c *CrudBase[T]) attemptInsert(ctx context.Context, tx *TXWrapper, inst T,
110
233
return err
111
234
}
112
235
113
- func (c * CrudBase [T ]) Upsert (ctx context.Context , inst T , optimization UpsertOptimization , hooks ... PostCompletionHook ) (err error ) {
236
+ func (c * CrudBase [T ]) Upsert (ctx context.Context , inst T , optimization UpsertOptimization , hooks ... PostCompletionHook ) (created bool , err error ) {
114
237
ctx , tx , autoCommit , err := c .DB .BeginOrUseTx (ctx )
115
238
if err != nil {
116
- return err
239
+ return false , err
117
240
}
118
241
defer c .DB .RollbackTx (ctx , tx , autoCommit )
119
242
@@ -124,8 +247,9 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt
124
247
if optimization == UpsertOptimizationNew {
125
248
opErr := c .attemptInsert (ctx , tx , inst , true /* we want a failure here we can progress past */ )
126
249
optimized = opErr == nil
250
+ created = optimized
127
251
} else if optimization == UpsertOptimizationExisting {
128
- rowsAffected , opErr := c .attemptReplace (ctx , tx , inst )
252
+ rowsAffected , opErr := c .updateFromInstance (ctx , tx , inst , true /* full replace */ )
129
253
optimized = opErr == nil && rowsAffected == 1
130
254
}
131
255
@@ -137,20 +261,21 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt
137
261
Where (c .idFilter (inst .GetID ())),
138
262
)
139
263
if err != nil {
140
- return err
264
+ return false , err
141
265
}
142
266
existing := msgRows .Next ()
143
267
msgRows .Close ()
144
268
145
269
if existing {
146
270
// Replace the existing one
147
- if _ , err = c .attemptReplace (ctx , tx , inst ); err != nil {
148
- return err
271
+ if _ , err = c .updateFromInstance (ctx , tx , inst , true /* full replace */ ); err != nil {
272
+ return false , err
149
273
}
150
274
} else {
151
275
// Get a useful error out of an insert attempt
276
+ created = true
152
277
if err = c .attemptInsert (ctx , tx , inst , false ); err != nil {
153
- return err
278
+ return false , err
154
279
}
155
280
}
156
281
}
@@ -159,7 +284,7 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt
159
284
tx .AddPostCommitHook (hook )
160
285
}
161
286
162
- return c .DB .CommitTx (ctx , tx , autoCommit )
287
+ return created , c .DB .CommitTx (ctx , tx , autoCommit )
163
288
}
164
289
165
290
func (c * CrudBase [T ]) InsertMany (ctx context.Context , instances []T , allowPartialSuccess bool , hooks ... PostCompletionHook ) (err error ) {
@@ -174,7 +299,7 @@ func (c *CrudBase[T]) InsertMany(ctx context.Context, instances []T, allowPartia
174
299
for _ , inst := range instances {
175
300
values := make ([]interface {}, len (c .Columns ))
176
301
for i , col := range c .Columns {
177
- values [i ] = c .GetFieldPtr (inst , col )
302
+ values [i ] = c .getFieldValue (inst , col )
178
303
}
179
304
insert = insert .Values (values ... )
180
305
}
@@ -234,7 +359,7 @@ func (c *CrudBase[T]) Replace(ctx context.Context, inst T, hooks ...PostCompleti
234
359
}
235
360
defer c .DB .RollbackTx (ctx , tx , autoCommit )
236
361
237
- rowsAffected , err := c .attemptReplace (ctx , tx , inst )
362
+ rowsAffected , err := c .updateFromInstance (ctx , tx , inst , true /* full replace */ )
238
363
if err != nil {
239
364
return err
240
365
} else if rowsAffected < 1 {
@@ -282,14 +407,26 @@ func (c *CrudBase[T]) getReadCols() (tableFrom string, cols, readCols []string)
282
407
return tableFrom , cols , readCols
283
408
}
284
409
285
- func (c * CrudBase [T ]) GetByID (ctx context.Context , id * fftypes.UUID ) (inst T , err error ) {
410
+ func (c * CrudBase [T ]) GetByID (ctx context.Context , id * fftypes.UUID , getOpts ... GetOption ) (inst T , err error ) {
411
+
412
+ failNotFound := false
413
+ for _ , o := range getOpts {
414
+ switch o {
415
+ case FailIfNotFound :
416
+ failNotFound = true
417
+ default :
418
+ return c .NilValue (), i18n .NewError (ctx , i18n .MsgDBUnknownGetOption , o )
419
+ }
420
+ }
421
+
286
422
tableFrom , cols , readCols := c .getReadCols ()
287
423
query := sq .Select (readCols ... ).
288
424
From (tableFrom ).
289
425
Where (c .idFilter (id ))
290
426
if c .ReadQueryModifier != nil {
291
427
query = c .ReadQueryModifier (query )
292
428
}
429
+
293
430
rows , _ , err := c .DB .Query (ctx , c .Table , query )
294
431
if err != nil {
295
432
return c .NilValue (), err
@@ -298,6 +435,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id *fftypes.UUID) (inst T, er
298
435
299
436
if ! rows .Next () {
300
437
log .L (ctx ).Debugf ("%s '%s' not found" , c .Table , id )
438
+ if failNotFound {
439
+ return c .NilValue (), i18n .NewError (ctx , i18n .Msg404NoResult )
440
+ }
301
441
return c .NilValue (), nil
302
442
}
303
443
@@ -344,6 +484,59 @@ func (c *CrudBase[T]) Update(ctx context.Context, id *fftypes.UUID, update ffapi
344
484
}, update , true , hooks ... )
345
485
}
346
486
487
+ // Thanks to the testify/assert package for this
488
+ func isNil (o interface {}) bool {
489
+ if o == nil {
490
+ return true
491
+ }
492
+
493
+ containsKind := func (kinds []reflect.Kind , kind reflect.Kind ) bool {
494
+ for i := 0 ; i < len (kinds ); i ++ {
495
+ if kind == kinds [i ] {
496
+ return true
497
+ }
498
+ }
499
+ return false
500
+ }
501
+
502
+ value := reflect .ValueOf (o )
503
+ kind := value .Kind ()
504
+ isNillableKind := containsKind (
505
+ []reflect.Kind {
506
+ reflect .Chan , reflect .Func ,
507
+ reflect .Interface , reflect .Map ,
508
+ reflect .Ptr , reflect .Slice },
509
+ kind )
510
+
511
+ if isNillableKind && value .IsNil () {
512
+ return true
513
+ }
514
+
515
+ return false
516
+ }
517
+
518
+ func (c * CrudBase [T ]) UpdateSparse (ctx context.Context , sparseUpdate T , hooks ... PostCompletionHook ) (err error ) {
519
+ ctx , tx , autoCommit , err := c .DB .BeginOrUseTx (ctx )
520
+ if err != nil {
521
+ return err
522
+ }
523
+ defer c .DB .RollbackTx (ctx , tx , autoCommit )
524
+
525
+ updateCount , err := c .updateFromInstance (ctx , tx , sparseUpdate , false /* only non-nil fields */ )
526
+ if err != nil {
527
+ return err
528
+ }
529
+ if updateCount < 1 {
530
+ return i18n .NewError (ctx , i18n .MsgDBNoRowsAffected )
531
+ }
532
+
533
+ for _ , hook := range hooks {
534
+ tx .AddPostCommitHook (hook )
535
+ }
536
+
537
+ return c .DB .CommitTx (ctx , tx , autoCommit )
538
+ }
539
+
347
540
func (c * CrudBase [T ]) UpdateMany (ctx context.Context , filter ffapi.Filter , update ffapi.Update , hooks ... PostCompletionHook ) (err error ) {
348
541
return c .attemptUpdate (ctx , func (query sq.UpdateBuilder ) (sq.UpdateBuilder , error ) {
349
542
return c .DB .FilterUpdate (ctx , query , filter , c .FilterFieldMap )
@@ -361,6 +554,7 @@ func (c *CrudBase[T]) attemptUpdate(ctx context.Context, filterFn func(sq.Update
361
554
if err == nil {
362
555
query , err = filterFn (query )
363
556
}
557
+ query = query .Set (ColumnUpdated , fftypes .Now ())
364
558
if err != nil {
365
559
return err
366
560
}
0 commit comments