Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tools/goctl/model/sql/gen/findonebyfield.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e
for _, key := range table.UniqueCacheKey {
in, paramJoinString, originalFieldString := convertJoin(key, postgreSql)

// Append partial index predicate (WHERE clause) for PostgreSQL partial unique indexes.
// The predicate originates from pg_get_expr(C.INDPRED, C.INDRELID) on the PostgreSQL
// system catalog — it is schema metadata, not user input, so there is no SQL injection surface.
// pg_get_expr wraps the predicate in outer parentheses, so the result is correctly grouped.
// e.g. "sku = $1 and ((status = 1) AND (deleted_at IS NULL))"
if len(key.Predicate) > 0 {
originalFieldString = originalFieldString + " and " + key.Predicate
}

output, err := t.Execute(map[string]any{
"upperStartCamelObject": camelTableName,
"upperField": key.FieldNameJoin.Camel().With("").Source(),
Expand Down
80 changes: 77 additions & 3 deletions tools/goctl/model/sql/gen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import (
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/builder"
"github.com/zeromicro/go-zero/core/stringx"
corstringx "github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
)

//go:embed testdata/user.sql
Expand Down Expand Up @@ -158,8 +159,8 @@ func TestFields(t *testing.T) {
var (
studentFieldNames = builder.RawFieldNames(&Student{})
studentRows = strings.Join(studentFieldNames, ",")
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
studentRowsExpectAutoSet = strings.Join(corstringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
studentRowsWithPlaceHolder = strings.Join(corstringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
)

assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
Expand Down Expand Up @@ -198,3 +199,76 @@ func Test_genPublicModel(t *testing.T) {
assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n"))
assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {"))
}

func TestGenFindOneByFieldWithPartialIndex(t *testing.T) {
primaryField := &parser.Field{
Name: stringx.From("id"),
DataType: "int64",
Comment: "主键",
}
emailField := &parser.Field{
Name: stringx.From("email"),
DataType: "string",
Comment: "邮箱",
}
parsedTable := parser.Table{
Name: stringx.From("user"),
Db: stringx.From("go_zero"),
PrimaryKey: parser.Primary{
Field: *primaryField,
},
UniqueIndex: map[string][]*parser.Field{
"idx_active_email": {emailField},
},
UniqueIndexPredicate: map[string]string{
"idx_active_email": "(status = 1) AND (deleted_at IS NULL)",
},
Fields: []*parser.Field{
primaryField,
emailField,
},
}

primaryKey, uniqueKeys := genCacheKeys("cache", parsedTable)

table := Table{
Table: parsedTable,
PrimaryCacheKey: primaryKey,
UniqueCacheKey: uniqueKeys,
ContainsUniqueCacheKey: len(uniqueKeys) > 0,
}

result, err := genFindOneByField(table, false, true)
assert.NoError(t, err)

assert.Contains(t, result.findOneMethod, "where email = $1 and (status = 1) AND (deleted_at IS NULL)")
assert.Contains(t, result.findOneMethod, "FindOneByEmail")

// Verify non-partial indexes do NOT append predicate
parsedTable2 := parser.Table{
Name: stringx.From("user"),
Db: stringx.From("go_zero"),
PrimaryKey: parser.Primary{
Field: *primaryField,
},
UniqueIndex: map[string][]*parser.Field{
"email_unique": {emailField},
},
UniqueIndexPredicate: map[string]string{},
Fields: []*parser.Field{
primaryField,
emailField,
},
}
_, uniqueKeys2 := genCacheKeys("cache", parsedTable2)
table2 := Table{
Table: parsedTable2,
PrimaryCacheKey: primaryKey,
UniqueCacheKey: uniqueKeys2,
ContainsUniqueCacheKey: len(uniqueKeys2) > 0,
}
result2, err := genFindOneByField(table2, false, true)
assert.NoError(t, err)
assert.Contains(t, result2.findOneMethod, "where email = $1")
assert.NotContains(t, result2.findOneMethod, " and (")
}
33 changes: 28 additions & 5 deletions tools/goctl/model/sql/gen/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type Key struct {
FieldNameJoin Join
// Fields describes the fields of table
Fields []*parser.Field
// Predicate describes the partial index WHERE clause, e.g. "status = 1 AND deleted_at IS NULL"
Predicate string
}

// Join describes an alias of string slice
Expand All @@ -41,8 +43,18 @@ func genCacheKeys(prefix string, table parser.Table) (Key, []Key) {
var primaryKey Key
primaryKey = genCacheKey(prefix, table.Db, table.Name, []*parser.Field{&table.PrimaryKey.Field})
uniqueKey := make([]Key, 0, len(table.UniqueIndex))
for _, each := range table.UniqueIndex {
uniqueKey = append(uniqueKey, genCacheKey(prefix, table.Db, table.Name, each))
for indexName, each := range table.UniqueIndex {
var suffix string
if table.UniqueIndexPredicate != nil {
if pred := table.UniqueIndexPredicate[indexName]; pred != "" {
suffix = stringx.From(indexName).ToCamel()
}
}
key := genCacheKey(prefix, table.Db, table.Name, each, suffix)
if table.UniqueIndexPredicate != nil {
key.Predicate = table.UniqueIndexPredicate[indexName]
}
uniqueKey = append(uniqueKey, key)
}
sort.Slice(uniqueKey, func(i, j int) bool {
return uniqueKey[i].VarLeft < uniqueKey[j].VarLeft
Expand All @@ -51,7 +63,7 @@ func genCacheKeys(prefix string, table parser.Table) (Key, []Key) {
return primaryKey, uniqueKey
}

func genCacheKey(prefix string, db, table stringx.String, in []*parser.Field) Key {
func genCacheKey(prefix string, db, table stringx.String, in []*parser.Field, suffix ...string) Key {
var (
varLeftJoin, varRightJoin, fieldNameJoin Join
varLeft, varRight, varExpression string
Expand Down Expand Up @@ -84,10 +96,21 @@ func genCacheKey(prefix string, db, table stringx.String, in []*parser.Field) Ke
keyLeftJoin = append(keyLeftJoin, "key")

varLeft = util.SafeString(varLeftJoin.Camel().With("").Untitle())
varRight = fmt.Sprintf(`"%s"`, varRightJoin.Camel().Untitle().With(":").Source()+":")
keyLeft = util.SafeString(keyLeftJoin.Camel().With("").Untitle())

// Append uniqueness suffix to disambiguate cache keys for partial indexes
// sharing the same column set but with different WHERE predicates.
// Applied to both Go variable names (varLeft/keyLeft) and Redis key prefix value (varRight).
suffixPart := ""
if len(suffix) > 0 && suffix[0] != "" {
suffixPart = suffix[0] + ":"
varLeft += suffix[0]
keyLeft += suffix[0]
}

varRight = fmt.Sprintf(`"%s"`, varRightJoin.Camel().Untitle().With(":").Source()+":"+suffixPart)
varExpression = fmt.Sprintf(`%s = %s`, varLeft, varRight)

keyLeft = util.SafeString(keyLeftJoin.Camel().With("").Untitle())
keyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, keyRightJoin.With(", ").Source())
dataKeyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, dataRightJoin.With(", ").Source())
keyExpression = fmt.Sprintf("%s := %s", keyLeft, keyRight)
Expand Down
78 changes: 77 additions & 1 deletion tools/goctl/model/sql/gen/keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,81 @@ func TestGenCacheKeys(t *testing.T) {
})
}())
})
t.Run("partial unique index with different predicates", func(t *testing.T) {
emailField := &parser.Field{
Name: stringx.From("email"),
DataType: "string",
Comment: "邮箱",
SeqInIndex: 1,
}
_, uniqueKeys := genCacheKeys("cache", parser.Table{
Name: stringx.From("user"),
Db: stringx.From("go_zero"),
PrimaryKey: parser.Primary{
Field: *primaryField,
AutoIncrement: true,
},
UniqueIndex: map[string][]*parser.Field{
"idx_active_email": {emailField},
"idx_deleted_email": {emailField},
},
UniqueIndexPredicate: map[string]string{
"idx_active_email": "status = 1",
"idx_deleted_email": "deleted_at IS NULL",
},
Fields: []*parser.Field{
primaryField,
emailField,
},
})

// Same column, different predicates → different cache key variable names,
// disambiguated by the CamelCase index name suffix.
// Both Go variable names and Redis key prefix values must differ.
assert.Equal(t, 2, len(uniqueKeys))
assert.NotEqual(t, uniqueKeys[0].VarLeft, uniqueKeys[1].VarLeft,
"partial indexes with different predicates must have different cache keys")
assert.NotEqual(t, uniqueKeys[0].KeyLeft, uniqueKeys[1].KeyLeft)
assert.NotEqual(t, uniqueKeys[0].VarRight, uniqueKeys[1].VarRight,
"partial indexes with different predicates must have different Redis key prefixes")
assert.Contains(t, uniqueKeys[0].VarLeft, "IdxActiveEmail")
assert.Contains(t, uniqueKeys[0].KeyLeft, "IdxActiveEmail")
assert.Contains(t, uniqueKeys[0].VarRight, "IdxActiveEmail")
assert.Contains(t, uniqueKeys[1].VarLeft, "IdxDeletedEmail")
assert.Contains(t, uniqueKeys[1].KeyLeft, "IdxDeletedEmail")
assert.Contains(t, uniqueKeys[1].VarRight, "IdxDeletedEmail")
assert.NotEmpty(t, uniqueKeys[0].Predicate)
assert.NotEmpty(t, uniqueKeys[1].Predicate)
})
t.Run("partial unique index backward compatible", func(t *testing.T) {
emailField := &parser.Field{
Name: stringx.From("email"),
DataType: "string",
Comment: "邮箱",
SeqInIndex: 1,
}
_, uniqueKeys := genCacheKeys("cache", parser.Table{
Name: stringx.From("user"),
Db: stringx.From("go_zero"),
PrimaryKey: parser.Primary{
Field: *primaryField,
AutoIncrement: true,
},
UniqueIndex: map[string][]*parser.Field{
"email_unique": {emailField},
},
UniqueIndexPredicate: map[string]string{},
Fields: []*parser.Field{
primaryField,
emailField,
},
})

// No predicate → no suffix → same as original cache key name
assert.Equal(t, 1, len(uniqueKeys))
assert.Equal(t, "cacheGoZeroUserEmailPrefix", uniqueKeys[0].VarLeft)
assert.Equal(t, "goZeroUserEmailKey", uniqueKeys[0].KeyLeft)
})
}

func cacheKeyEqual(k1, k2 Key) bool {
Expand All @@ -202,5 +277,6 @@ func cacheKeyEqual(k1, k2 Key) bool {
k1.KeyRight == k2.KeyRight &&
k1.DataKeyRight == k2.DataKeyRight &&
k1.DataKeyExpression == k2.DataKeyExpression &&
k1.KeyExpression == k2.KeyExpression
k1.KeyExpression == k2.KeyExpression &&
k1.Predicate == k2.Predicate
}
12 changes: 9 additions & 3 deletions tools/goctl/model/sql/model/infoschemamodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type (
IndexName string `db:"INDEX_NAME"`
NonUnique int `db:"NON_UNIQUE"`
SeqInIndex int `db:"SEQ_IN_INDEX"`
Predicate string // partial index WHERE clause, e.g. "status = 1 AND deleted_at IS NULL"
}

// ColumnData describes the columns of table
Expand All @@ -54,9 +55,10 @@ type (
Table string
Columns []*Column
// Primary key not included
UniqueIndex map[string][]*Column
PrimaryKey *Column
NormalIndex map[string][]*Column
UniqueIndex map[string][]*Column
PrimaryKey *Column
NormalIndex map[string][]*Column
UniqueIndexPredicate map[string]string
}

// IndexType describes an alias of string
Expand Down Expand Up @@ -150,6 +152,7 @@ func (c *ColumnData) Convert() (*Table, error) {
table.Columns = c.Columns
table.UniqueIndex = map[string][]*Column{}
table.NormalIndex = map[string][]*Column{}
table.UniqueIndexPredicate = map[string]string{}

m := make(map[string][]*Column)
for _, each := range c.Columns {
Expand Down Expand Up @@ -178,6 +181,9 @@ func (c *ColumnData) Convert() (*Table, error) {
if one.Index != nil {
if one.Index.NonUnique == 0 {
table.UniqueIndex[indexName] = columns
if one.Index.Predicate != "" {
table.UniqueIndexPredicate[indexName] = one.Index.Predicate
}
} else {
table.NormalIndex[indexName] = columns
}
Expand Down
5 changes: 4 additions & 1 deletion tools/goctl/model/sql/model/postgresqlmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type PostgreIndex struct {
IsPrimary sql.NullBool `db:"is_primary"`
ColumnName sql.NullString `db:"column_name"`
IndexSort sql.NullInt32 `db:"index_sort"`
Predicate sql.NullString `db:"predicate"`
}

// NewPostgreSqlModel creates an instance and return
Expand Down Expand Up @@ -202,6 +203,7 @@ func (m *PostgreSqlModel) getIndex(schema, table string) (map[string][]*DbIndex,
IndexName: e.IndexName.String,
NonUnique: nonUnique,
SeqInIndex: int(e.IndexSort.Int32),
Predicate: e.Predicate.String,
})
}

Expand All @@ -215,7 +217,8 @@ func (m *PostgreSqlModel) FindIndex(schema, table string) ([]*PostgreIndex, erro
C.INDISUNIQUE AS is_unique,
C.INDISPRIMARY AS is_primary,
G.ATTNAME AS column_name,
G.attnum AS index_sort
G.attnum AS index_sort,
pg_get_expr(C.INDPRED, C.INDRELID) AS predicate
from PG_AM B
left join PG_CLASS F on
B.OID = F.RELAM
Expand Down
17 changes: 11 additions & 6 deletions tools/goctl/model/sql/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ const timeImport = "time.Time"
type (
// Table describes a mysql table
Table struct {
Name stringx.String
Db stringx.String
PrimaryKey Primary
UniqueIndex map[string][]*Field
Fields []*Field
ContainsPQ bool
Name stringx.String
Db stringx.String
PrimaryKey Primary
UniqueIndex map[string][]*Field
UniqueIndexPredicate map[string]string
Fields []*Field
ContainsPQ bool
}

// Primary describes a primary key
Expand Down Expand Up @@ -277,6 +278,7 @@ func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
var reply Table
reply.ContainsPQ = containsPQ
reply.UniqueIndex = map[string][]*Field{}
reply.UniqueIndexPredicate = map[string]string{}
reply.Name = stringx.From(table.Table)
reply.Db = stringx.From(table.Db)
seqInIndex := 0
Expand Down Expand Up @@ -344,6 +346,9 @@ func ConvertDataType(table *model.Table, strict bool) (*Table, error) {

uniqueIndexSet.Add(uniqueKey)
reply.UniqueIndex[indexName] = list
if predicate, ok := table.UniqueIndexPredicate[indexName]; ok && predicate != "" {
reply.UniqueIndexPredicate[indexName] = predicate
}
}

return &reply, nil
Expand Down
Loading