forked from ory/kratos
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate.go
More file actions
298 lines (250 loc) · 7.28 KB
/
create.go
File metadata and controls
298 lines (250 loc) · 7.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package batch
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"time"
"github.com/jmoiron/sqlx/reflectx"
"github.com/ory/x/dbal"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/pop/v6"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlxx"
)
// insertQueryArgs holds the statement fragments that buildInsertQueryArgs
// derives for a multi-row INSERT.
type insertQueryArgs struct {
	// TableName is the dialect-quoted table name.
	TableName string
	// ColumnsDecl is the comma-separated list of dialect-quoted column names.
	ColumnsDecl string
	// Columns holds the unquoted column names, in the same order as ColumnsDecl.
	Columns []string
	// Placeholders holds one parenthesized value tuple per model, joined by ",\n".
	Placeholders string
}

// quoter quotes an SQL identifier (table or column name) for a dialect.
type quoter interface {
	Quote(key string) string
}

// TracerConnection pairs the tracer used for spans with the database
// connection that Create and CreateFromSlice insert through.
type TracerConnection struct {
	Tracer     *otelx.Tracer
	Connection *pop.Connection
}
// buildInsertQueryArgs derives the static parts of a multi-row INSERT for the
// model type T: the quoted table name, the quoted column list, the unquoted
// column names (same order, consumed by buildInsertQueryValues), and one
// parenthesized placeholder tuple per element of models.
//
// On CockroachDB, a model whose "id" field holds a zero uuid.UUID gets
// gen_random_uuid() instead of "?" in that column, so the database generates
// the key; the matching argument must then be omitted from the value list.
func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs {
	var v T
	model := pop.NewModel(v, ctx)

	cols := model.Columns().Cols
	columns := make([]string, 0, len(cols))
	for _, col := range cols {
		columns = append(columns, col.Name)
	}
	// Sorting makes the generated SQL deterministic (originally introduced so
	// that test snapshots are stable). It also lets us binary-search below.
	sort.Strings(columns)

	quotedColumns := make([]string, len(columns))
	for i, col := range columns {
		quotedColumns[i] = quoter.Quote(col)
	}

	// Only CockroachDB ever replaces a placeholder, and only the "id" column's.
	// Locate that column once instead of rescanning it for every model.
	idIdx := -1
	if dialect == dbal.DriverCockroachDB {
		if i := sort.SearchStrings(columns, "id"); i < len(columns) && columns[i] == "id" {
			idIdx = i
		}
	}

	// One VALUES tuple per model, to be filled with column values later:
	//
	//	(?, ?, ?, ?),
	//	(gen_random_uuid(), ?, ?, ?),
	//	(?, ?, ?, ?)
	placeholders := make([]string, 0, len(models))
	for _, m := range models {
		row := make([]string, len(columns))
		for i := range row {
			row[i] = "?"
		}
		if idIdx >= 0 {
			field := mapper.FieldByName(reflect.ValueOf(m), "id")
			if id, ok := field.Interface().(uuid.UUID); ok && id == uuid.Nil {
				row[idIdx] = "gen_random_uuid()"
			}
		}
		placeholders = append(placeholders, "("+strings.Join(row, ", ")+")")
	}

	return insertQueryArgs{
		TableName:    quoter.Quote(model.TableName()),
		ColumnsDecl:  strings.Join(quotedColumns, ", "),
		Columns:      columns,
		Placeholders: strings.Join(placeholders, ",\n"),
	}
}
// buildInsertQueryValues returns the flattened argument list that matches the
// placeholders produced by buildInsertQueryArgs: for each model in order, the
// value of every column in columns order.
//
// It mutates the models as a side effect:
//   - created_at is set to nowFunc() if it holds its zero value;
//   - updated_at is always set to nowFunc() (nowFunc is called once per model);
//   - an unset "id" is filled in by fillMissingID.
//
// On CockroachDB a zero uuid.UUID "id" contributes no argument, because
// buildInsertQueryArgs emitted gen_random_uuid() for that column.
func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
	values = make([]any, 0, len(models)*len(columns))
	for _, m := range models {
		m := reflect.ValueOf(m)
		now := nowFunc()

		for _, c := range columns {
			field := mapper.FieldByName(m, c)

			switch c {
			case "created_at":
				if pop.IsZeroOfUnderlyingType(field.Interface()) {
					field.Set(reflect.ValueOf(now))
				}
			case "updated_at":
				field.Set(reflect.ValueOf(now))
			case "id":
				skip, idErr := fillMissingID(field, dialect)
				if idErr != nil {
					return nil, idErr
				}
				if skip {
					// The placeholder is gen_random_uuid(), which takes no argument.
					continue
				}
			}

			values = append(values, field.Interface())

			// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
			// but we want a nil pointer instead.
			if i, ok := field.Interface().(*sqlxx.NullTime); ok && i != nil {
				if time.Time(*i).IsZero() {
					field.Set(reflect.Zero(field.Type()))
				}
			}
		}
	}
	return values, nil
}

// fillMissingID assigns a fresh UUIDv7 to an unset primary-key field and
// reports whether the column's argument must be skipped.
//
// The decision mirrors buildInsertQueryArgs exactly: only a zero uuid.UUID on
// CockroachDB is skipped, since only that case gets gen_random_uuid(). A
// string-kind ID always has a "?" placeholder, so an empty one receives the
// UUID's string form on every dialect. Any other ID type cannot be generated
// here and is reported as an error; previously it either panicked inside
// reflect or desynchronized placeholders and arguments.
func fillMissingID(field reflect.Value, dialect string) (skip bool, err error) {
	if id, ok := field.Interface().(uuid.UUID); ok {
		if id != uuid.Nil {
			return false, nil
		}
		if dialect == dbal.DriverCockroachDB {
			return true, nil
		}
		newID, err := uuid.NewV7()
		if err != nil {
			return false, err
		}
		field.Set(reflect.ValueOf(newID))
		return false, nil
	}

	if field.Kind() == reflect.String {
		if field.Len() > 0 {
			return false, nil
		}
		newID, err := uuid.NewV7()
		if err != nil {
			return false, err
		}
		field.SetString(newID.String())
		return false, nil
	}

	return false, errors.Errorf("unsupported primary key type %s", field.Type())
}
// createOptions collects the optional settings that Create applies to its
// INSERT statement.
type createOptions struct {
	// onConflict, if non-empty, is placed verbatim after the VALUES list and
	// before any RETURNING clause.
	onConflict string
}

// option configures a Create or CreateFromSlice call.
type option func(*createOptions)

// OnConflictDoNothing makes Create append "ON CONFLICT DO NOTHING" to its
// INSERT statement, so conflicting rows are skipped instead of failing the batch.
//
// NOTE(review): MySQL does not accept this clause, and Create does not
// special-case it for MySQL. Confirm it is only used on other dialects.
func OnConflictDoNothing() option {
	return func(o *createOptions) {
		o.onConflict = "ON CONFLICT DO NOTHING"
	}
}
// CreateFromSlice is a helper around Create that accepts a slice of models
// instead of a slice of model pointers.
//
// The pointers handed to Create address the elements of models directly, so
// any fields Create fills in (timestamps, primary keys) are visible in the
// caller's slice afterwards.
func CreateFromSlice[T any](ctx context.Context, p *TracerConnection, models []T, opts ...option) (err error) {
	ptrs := make([]*T, len(models))
	for i := range models {
		ptrs[i] = &models[i]
	}
	return Create(ctx, p, ptrs, opts...)
}
// Create batch-inserts the given models into the database using a single INSERT statement.
// The models are either all created or none.
//
// Before the query runs, buildInsertQueryValues mutates the models in place
// (timestamps and missing primary keys). On every dialect except MySQL the
// statement ends with "RETURNING <id field>", and the returned keys are written
// back to the models in order.
func Create[T any](ctx context.Context, p *TracerConnection, models []*T, opts ...option) (err error) {
	ctx, span := p.Tracer.Tracer().Start(ctx, "persistence.sql.batch.Create")
	defer otelx.End(span, &err)

	if len(models) == 0 {
		return nil
	}

	options := &createOptions{}
	for _, opt := range opts {
		opt(options)
	}

	var v T
	model := pop.NewModel(v, ctx)

	conn := p.Connection
	quoter, ok := conn.Dialect.(quoter)
	if !ok {
		// The assertion is on the dialect, so report the dialect's type.
		return errors.Errorf("dialect is not a quoter: %T", conn.Dialect)
	}

	queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models)
	values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
	if err != nil {
		return err
	}

	var returningClause string
	if conn.Dialect.Name() != dbal.DriverMySQL {
		// PostgreSQL, CockroachDB, SQLite support RETURNING.
		returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
	}

	query := conn.Dialect.TranslateSQL(fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES\n%s\n%s\n%s",
		queryArgs.TableName,
		queryArgs.ColumnsDecl,
		queryArgs.Placeholders,
		options.onConflict,
		returningClause,
	))

	rows, err := conn.TX.QueryContext(ctx, query, values...)
	if err != nil {
		return sqlcon.HandleError(err)
	}
	defer rows.Close()

	// Hydrate the models from the RETURNING clause. Without RETURNING (MySQL)
	// the query yields no rows and this loop does nothing.
	//
	// NOTE(review): rows are matched to models purely by position. With an
	// ON CONFLICT clause the database may skip some rows, shifting later IDs
	// onto the wrong models. Confirm callers using OnConflictDoNothing do not
	// rely on the hydrated IDs.
	for i := 0; rows.Next(); i++ {
		if i >= len(models) {
			return errors.Errorf("batch insert returned more rows than the %d models inserted", len(models))
		}
		if err := setModelID(rows, pop.NewModel(models[i], ctx)); err != nil {
			return err
		}
	}
	if err := rows.Err(); err != nil {
		return sqlcon.HandleError(err)
	}
	return nil
}
// setModelID was copy & pasted from pop. It reads a single primary-key value
// from the current row and stores it in the model's ID field.
//
// For a "UUID" primary key type the value is scanned directly into a
// uuid.UUID. Any other key is scanned into an untyped value and stored via
// assignScannedID, which reports mismatches as errors rather than panicking
// inside reflect.
func setModelID(row *sql.Rows, model *pop.Model) error {
	el := reflect.ValueOf(model.Value).Elem()
	fbn := el.FieldByName("ID")
	if !fbn.IsValid() {
		return errors.New("model does not have a field named id")
	}

	pkt, err := model.PrimaryKeyType()
	if err != nil {
		return errors.WithStack(err)
	}

	if pkt == "UUID" {
		var id uuid.UUID
		if err := row.Scan(&id); err != nil {
			return errors.WithStack(err)
		}
		fbn.Set(reflect.ValueOf(id))
		return nil
	}

	var id any
	if err := row.Scan(&id); err != nil {
		return errors.WithStack(err)
	}
	return assignScannedID(fbn, id)
}

// assignScannedID stores a driver-scanned primary-key value id in the ID field fbn.
//
// Signed integers go into any signed-integer field kind, with an overflow check.
// []byte goes into string-kind fields. Any other value is stored only when
// its type is assignable to the field. A NULL key and every other combination
// produce an error.
func assignScannedID(fbn reflect.Value, id any) error {
	v := reflect.ValueOf(id)
	switch {
	case !v.IsValid():
		return errors.New("database returned a NULL primary key")
	case fbn.CanInt() && v.CanInt():
		if fbn.OverflowInt(v.Int()) {
			return errors.Errorf("primary key %d overflows field ID of type %s", v.Int(), fbn.Type())
		}
		fbn.SetInt(v.Int())
	case v.Type().AssignableTo(fbn.Type()):
		fbn.Set(v)
	case fbn.Kind() == reflect.String && v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8:
		fbn.SetString(string(v.Bytes()))
	default:
		return errors.Errorf("cannot store primary key of type %T in field ID of type %s", id, fbn.Type())
	}
	return nil
}