Skip to content

Commit 187743b

Browse files
authored
fix(create_table): avoid creating unintended foreign keys (uptrace#1130)
* fix(create_table): avoid creating unintended foreign keys Foreign keys should only be created for has-one and belongs-to relations iff: - None of the referencing columns is a primary keys - The table is an m2m 'junction' table an all referencing columns are primary keys The m2m edge case is covered by TestDatabaseInspector_Inspect
1 parent c915415 commit 187743b

File tree

4 files changed

+110
-9
lines changed

4 files changed

+110
-9
lines changed

internal/dbtest/db_test.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/uptrace/bun/driver/pgdriver"
2525
"github.com/uptrace/bun/driver/sqliteshim"
2626
"github.com/uptrace/bun/extra/bundebug"
27+
"github.com/uptrace/bun/migrate/sqlschema"
2728
"github.com/uptrace/bun/extra/bunexp"
2829
"github.com/uptrace/bun/schema"
2930

@@ -300,6 +301,7 @@ func TestDB(t *testing.T) {
300301
{testRunInTxAndSavepoint},
301302
{testDriverValuerReturnsItself},
302303
{testNoPanicWhenReturningNullColumns},
304+
{testNoForeignKeyForPrimaryKey},
303305
}
304306

305307
testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
@@ -1831,6 +1833,59 @@ func testNoPanicWhenReturningNullColumns(t *testing.T, db *bun.DB) {
18311833
})
18321834
}
18331835

1836+
func testNoForeignKeyForPrimaryKey(t *testing.T, db *bun.DB) {
1837+
inspect := inspectDbOrSkip(t, db)
1838+
1839+
for _, tt := range []struct {
1840+
name string
1841+
model interface{}
1842+
dontWant sqlschema.ForeignKey
1843+
}{
1844+
{name: "has-one relation", model: (*struct {
1845+
bun.BaseModel `bun:"table:users"`
1846+
ID string `bun:",pk"`
1847+
1848+
Profile *struct {
1849+
bun.BaseModel `bun:"table:profiles"`
1850+
ID string `bun:",pk"`
1851+
UserID string
1852+
} `bun:"rel:has-one,join:id=user_id"`
1853+
})(nil), dontWant: sqlschema.ForeignKey{
1854+
From: sqlschema.NewColumnReference("users", "id"),
1855+
To: sqlschema.NewColumnReference("profiles", "user_id"),
1856+
}},
1857+
1858+
{name: "belongs-to relation", model: (*struct {
1859+
bun.BaseModel `bun:"table:profiles"`
1860+
ID string `bun:",pk"`
1861+
1862+
User *struct {
1863+
bun.BaseModel `bun:"table:users"`
1864+
ID string `bun:",pk"`
1865+
ProfileID string
1866+
} `bun:"rel:belongs-to,join:id=profile_id"`
1867+
})(nil), dontWant: sqlschema.ForeignKey{
1868+
From: sqlschema.NewColumnReference("profiles", "id"),
1869+
To: sqlschema.NewColumnReference("users", "profile_id"),
1870+
}},
1871+
} {
1872+
t.Run(tt.name, func(t *testing.T) {
1873+
ctx := context.Background()
1874+
mustDropTableOnCleanup(t, ctx, db, tt.model)
1875+
1876+
_, err := db.NewCreateTable().Model(tt.model).WithForeignKeys().Exec(ctx)
1877+
require.NoError(t, err, "create table")
1878+
1879+
state := inspect(ctx)
1880+
require.NotContainsf(t, state.ForeignKeys, tt.dontWant,
1881+
"%s.%s -> %s.%s is not inteded",
1882+
tt.dontWant.From.TableName, tt.dontWant.From.Column,
1883+
tt.dontWant.To.TableName, tt.dontWant.To.Column,
1884+
)
1885+
})
1886+
}
1887+
}
1888+
18341889
func mustResetModel(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) {
18351890
err := db.ResetModel(ctx, models...)
18361891
require.NoError(tb, err, "must reset model")
@@ -1864,7 +1919,6 @@ func TestConnResolver(t *testing.T) {
18641919
})
18651920

18661921
resolver := bunexp.NewReadWriteConnResolver(
1867-
//bunexp.WithDBReplica(rwdb),
18681922
bunexp.WithDBReplica(rodb, bunexp.DBReplicaReadOnly),
18691923
)
18701924

schema/field.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
)
1111

1212
type Field struct {
13+
Table *Table // Contains this field
1314
StructField reflect.StructField
1415
IsPtr bool
1516

schema/relation.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ const (
1313
)
1414

1515
type Relation struct {
16+
Type int
17+
Field *Field // Has the bun tag defining this relation.
18+
1619
// Base and Join can be explained with this query:
1720
//
1821
// SELECT * FROM base_table JOIN join_table
19-
20-
Type int
21-
Field *Field
2222
JoinTable *Table
2323
BasePKs []*Field
2424
JoinPKs []*Field
@@ -34,10 +34,49 @@ type Relation struct {
3434
M2MJoinPKs []*Field
3535
}
3636

37-
// References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation.
38-
// For other relations, the constraint is created in either the referencing table (1:N, 'has-many' relations) or a mapping table (N:N, 'm2m' relations).
37+
// References returns true if the table which defines this Relation
38+
// needs to declare a foreign key constraint, as is the case
39+
// for 'has-one' and 'belongs-to' relations. For other relations,
40+
// the constraint is created either in the referencing table (1:N, 'has-many' relations)
41+
// or the junction table (N:N, 'm2m' relations).
42+
//
43+
// Usage of `rel:` tag does not always imply creation of foreign keys (when WithForeignKeys() is not set)
44+
// and can be used exclusively for joining tables at query time. For example:
45+
//
46+
// type User struct {
47+
// ID int64 `bun:",pk"`
48+
// Profile *Profile `bun:",rel:has-one,join:id=user_id"`
49+
// }
50+
//
51+
// Creating a FK users.id -> profiles.user_id would be confusing and incorrect,
52+
// so for such cases References() returns false. One notable exception to this rule
53+
// is when a Relation is defined in a junction table, in which case it is perfectly
54+
// fine for its primary keys to reference other tables. Consider:
55+
//
56+
// // UsersToGroups maps users to groups they follow.
57+
// type UsersToGroups struct {
58+
// UserID string `bun:"user_id,pk"` // Needs FK to users.id
59+
// GroupID string `bun:"group_id,pk"` // Needs FK to groups.id
60+
//
61+
// User *User `bun:"rel:belongs-to,join:user_id=id"`
62+
// Group *Group `bun:"rel:belongs-to,join:group_id=id"`
63+
// }
64+
//
65+
// Here BooksToReaders has a composite primary key, composed of other primary keys.
3966
func (r *Relation) References() bool {
40-
return r.Type == HasOneRelation || r.Type == BelongsToRelation
67+
allPK := true
68+
nonePK := true
69+
for _, f := range r.BasePKs {
70+
allPK = allPK && f.IsPK
71+
nonePK = nonePK && !f.IsPK
72+
}
73+
74+
// Erring on the side of caution, only create foreign keys
75+
// if the referencing columns are part of a composite PK
76+
// in the junction table of the m2m relationship.
77+
effectsM2M := r.Field.Table.IsM2MTable && allPK
78+
79+
return (r.Type == HasOneRelation || r.Type == BelongsToRelation) && (effectsM2M || nonePK)
4180
}
4281

4382
func (r *Relation) String() string {

schema/table.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ type Table struct {
6262
FieldMap map[string]*Field
6363
StructMap map[string]*structField
6464

65-
Relations map[string]*Relation
66-
Unique map[string][]*Field
65+
IsM2MTable bool // If true, this table is the "junction table" of an m2m relation.
66+
Relations map[string]*Relation
67+
Unique map[string][]*Field
6768

6869
SoftDeleteField *Field
6970
UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
@@ -516,6 +517,7 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
516517
}
517518

518519
field := &Field{
520+
Table: t,
519521
StructField: sf,
520522
IsPtr: sf.Type.Kind() == reflect.Ptr,
521523

@@ -895,6 +897,7 @@ func (t *Table) m2mRelation(field *Field) *Relation {
895897
JoinTable: joinTable,
896898
M2MTable: m2mTable,
897899
}
900+
m2mTable.markM2M()
898901

899902
if field.Tag.HasOption("join_on") {
900903
rel.Condition = field.Tag.Options["join_on"]
@@ -940,6 +943,10 @@ func (t *Table) m2mRelation(field *Field) *Relation {
940943
return rel
941944
}
942945

946+
func (t *Table) markM2M() {
947+
t.IsM2MTable = true
948+
}
949+
943950
//------------------------------------------------------------------------------
944951

945952
func (t *Table) Dialect() Dialect { return t.dialect }

0 commit comments

Comments
 (0)