From 1003cd3a439fd1b6d67698c0936e878af29ffd47 Mon Sep 17 00:00:00 2001 From: hoshi Date: Wed, 29 Apr 2026 15:44:03 +0800 Subject: [PATCH 1/2] fix(goctl): preserve PostgreSQL partial unique index WHERE clause in model generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #3841 pg_get_expr(C.INDPRED, C.INDRELID) now retrieves the partial index predicate from pg_index, propagated through all layers (PostgreIndex → DbIndex → Column → model.Table → parser.Table → Key) and appended to generated FindOneByXxx SQL WHERE clauses. Cache keys are disambiguated using the index name as suffix for partial unique indexes sharing the same column set but different predicates. --- tools/goctl/model/sql/gen/findonebyfield.go | 7 ++ tools/goctl/model/sql/gen/keys.go | 33 ++++++-- tools/goctl/model/sql/gen/keys_test.go | 78 ++++++++++++++++++- .../goctl/model/sql/model/infoschemamodel.go | 12 ++- .../goctl/model/sql/model/postgresqlmodel.go | 5 +- tools/goctl/model/sql/parser/parser.go | 17 ++-- 6 files changed, 136 insertions(+), 16 deletions(-) diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index 1581509bfac7..c56d696fd2ae 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -28,6 +28,13 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e for _, key := range table.UniqueCacheKey { in, paramJoinString, originalFieldString := convertJoin(key, postgreSql) + // Append partial index predicate (WHERE clause) for PostgreSQL partial unique indexes. + // pg_get_expr wraps the predicate in outer parentheses, so the result is correctly grouped. + // e.g. "sku = $1 and ((status = 1) AND (deleted_at IS NULL))" + if len(key.Predicate) > 0 { + originalFieldString = originalFieldString + " and " + key.Predicate + } + output, err := t.Execute(map[string]any{ "upperStartCamelObject": camelTableName, "upperField": key.FieldNameJoin.Camel().With("").Source(), diff --git a/tools/goctl/model/sql/gen/keys.go b/tools/goctl/model/sql/gen/keys.go index 8e4920079c1c..ec359e0b214b 100644 --- a/tools/goctl/model/sql/gen/keys.go +++ b/tools/goctl/model/sql/gen/keys.go @@ -32,6 +32,8 @@ type Key struct { FieldNameJoin Join // Fields describes the fields of table Fields []*parser.Field + // Predicate describes the partial index WHERE clause, e.g. "status = 1 AND deleted_at IS NULL" + Predicate string } // Join describes an alias of string slice @@ -41,8 +43,18 @@ func genCacheKeys(prefix string, table parser.Table) (Key, []Key) { var primaryKey Key primaryKey = genCacheKey(prefix, table.Db, table.Name, []*parser.Field{&table.PrimaryKey.Field}) uniqueKey := make([]Key, 0, len(table.UniqueIndex)) - for _, each := range table.UniqueIndex { - uniqueKey = append(uniqueKey, genCacheKey(prefix, table.Db, table.Name, each)) + for indexName, each := range table.UniqueIndex { + var suffix string + if table.UniqueIndexPredicate != nil { + if pred := table.UniqueIndexPredicate[indexName]; pred != "" { + suffix = stringx.From(indexName).ToCamel() + } + } + key := genCacheKey(prefix, table.Db, table.Name, each, suffix) + if table.UniqueIndexPredicate != nil { + key.Predicate = table.UniqueIndexPredicate[indexName] + } + uniqueKey = append(uniqueKey, key) } sort.Slice(uniqueKey, func(i, j int) bool { return uniqueKey[i].VarLeft < uniqueKey[j].VarLeft @@ -51,7 +63,7 @@ func genCacheKeys(prefix string, table parser.Table) (Key, []Key) { return primaryKey, uniqueKey } -func genCacheKey(prefix string, db, table stringx.String, in []*parser.Field) Key { +func genCacheKey(prefix string, db, table stringx.String, in []*parser.Field, suffix ...string) Key { var ( varLeftJoin, varRightJoin, fieldNameJoin Join varLeft, varRight, varExpression string @@ -84,10 +96,21 @@ func genCacheKey(prefix string, db, table stringx.String, in []*parser.Field) Ke keyLeftJoin = append(keyLeftJoin, "key") varLeft = util.SafeString(varLeftJoin.Camel().With("").Untitle()) - varRight = fmt.Sprintf(`"%s"`, varRightJoin.Camel().Untitle().With(":").Source()+":") + keyLeft = util.SafeString(keyLeftJoin.Camel().With("").Untitle()) + + // Append uniqueness suffix to disambiguate cache keys for partial indexes + // sharing the same column set but with different WHERE predicates. + // Applied to both Go variable names (varLeft/keyLeft) and Redis key prefix value (varRight). + suffixPart := "" + if len(suffix) > 0 && suffix[0] != "" { + suffixPart = suffix[0] + ":" + varLeft += suffix[0] + keyLeft += suffix[0] + } + + varRight = fmt.Sprintf(`"%s"`, varRightJoin.Camel().Untitle().With(":").Source()+":"+suffixPart) varExpression = fmt.Sprintf(`%s = %s`, varLeft, varRight) - keyLeft = util.SafeString(keyLeftJoin.Camel().With("").Untitle()) keyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, keyRightJoin.With(", ").Source()) dataKeyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, dataRightJoin.With(", ").Source()) keyExpression = fmt.Sprintf("%s := %s", keyLeft, keyRight) diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go index fc86674d6aaa..09c671c495fa 100644 --- a/tools/goctl/model/sql/gen/keys_test.go +++ b/tools/goctl/model/sql/gen/keys_test.go @@ -177,6 +177,81 @@ func TestGenCacheKeys(t *testing.T) { }) }()) }) + t.Run("partial unique index with different predicates", func(t *testing.T) { + emailField := &parser.Field{ + Name: stringx.From("email"), + DataType: "string", + Comment: "邮箱", + SeqInIndex: 1, + } + _, uniqueKeys := genCacheKeys("cache", parser.Table{ + Name: stringx.From("user"), + Db: stringx.From("go_zero"), + PrimaryKey: parser.Primary{ + Field: *primaryField, + AutoIncrement: true, + }, + UniqueIndex: map[string][]*parser.Field{ + "idx_active_email": {emailField}, + "idx_deleted_email": {emailField}, + }, + UniqueIndexPredicate: map[string]string{ + "idx_active_email": "status = 1", + "idx_deleted_email": "deleted_at IS NULL", + }, + Fields: []*parser.Field{ + primaryField, + emailField, + }, + }) + + // Same column, different predicates → different cache key variable names, + // disambiguated by the CamelCase index name suffix. + // Both Go variable names and Redis key prefix values must differ. + assert.Equal(t, 2, len(uniqueKeys)) + assert.NotEqual(t, uniqueKeys[0].VarLeft, uniqueKeys[1].VarLeft, + "partial indexes with different predicates must have different cache keys") + assert.NotEqual(t, uniqueKeys[0].KeyLeft, uniqueKeys[1].KeyLeft) + assert.NotEqual(t, uniqueKeys[0].VarRight, uniqueKeys[1].VarRight, + "partial indexes with different predicates must have different Redis key prefixes") + assert.Contains(t, uniqueKeys[0].VarLeft, "IdxActiveEmail") + assert.Contains(t, uniqueKeys[0].KeyLeft, "IdxActiveEmail") + assert.Contains(t, uniqueKeys[0].VarRight, "IdxActiveEmail") + assert.Contains(t, uniqueKeys[1].VarLeft, "IdxDeletedEmail") + assert.Contains(t, uniqueKeys[1].KeyLeft, "IdxDeletedEmail") + assert.Contains(t, uniqueKeys[1].VarRight, "IdxDeletedEmail") + assert.NotEmpty(t, uniqueKeys[0].Predicate) + assert.NotEmpty(t, uniqueKeys[1].Predicate) + }) + t.Run("partial unique index backward compatible", func(t *testing.T) { + emailField := &parser.Field{ + Name: stringx.From("email"), + DataType: "string", + Comment: "邮箱", + SeqInIndex: 1, + } + _, uniqueKeys := genCacheKeys("cache", parser.Table{ + Name: stringx.From("user"), + Db: stringx.From("go_zero"), + PrimaryKey: parser.Primary{ + Field: *primaryField, + AutoIncrement: true, + }, + UniqueIndex: map[string][]*parser.Field{ + "email_unique": {emailField}, + }, + UniqueIndexPredicate: map[string]string{}, + Fields: []*parser.Field{ + primaryField, + emailField, + }, + }) + + // No predicate → no suffix → same as original cache key name + assert.Equal(t, 1, len(uniqueKeys)) + assert.Equal(t, "cacheGoZeroUserEmailPrefix", uniqueKeys[0].VarLeft) + assert.Equal(t, "goZeroUserEmailKey", uniqueKeys[0].KeyLeft) + }) } func cacheKeyEqual(k1, k2 Key) bool { @@ -202,5 +277,6 @@ func cacheKeyEqual(k1, k2 Key) bool { k1.KeyRight == k2.KeyRight && k1.DataKeyRight == k2.DataKeyRight && k1.DataKeyExpression == k2.DataKeyExpression && - k1.KeyExpression == k2.KeyExpression + k1.KeyExpression == k2.KeyExpression && + k1.Predicate == k2.Predicate } diff --git a/tools/goctl/model/sql/model/infoschemamodel.go b/tools/goctl/model/sql/model/infoschemamodel.go index 7271bdf8b7a3..0cc922d2a8e3 100644 --- a/tools/goctl/model/sql/model/infoschemamodel.go +++ b/tools/goctl/model/sql/model/infoschemamodel.go @@ -39,6 +39,7 @@ type ( IndexName string `db:"INDEX_NAME"` NonUnique int `db:"NON_UNIQUE"` SeqInIndex int `db:"SEQ_IN_INDEX"` + Predicate string // partial index WHERE clause, e.g. "status = 1 AND deleted_at IS NULL" } // ColumnData describes the columns of table @@ -54,9 +55,10 @@ type ( Table string Columns []*Column // Primary key not included - UniqueIndex map[string][]*Column - PrimaryKey *Column - NormalIndex map[string][]*Column + UniqueIndex map[string][]*Column + PrimaryKey *Column + NormalIndex map[string][]*Column + UniqueIndexPredicate map[string]string } // IndexType describes an alias of string @@ -150,6 +152,7 @@ func (c *ColumnData) Convert() (*Table, error) { table.Columns = c.Columns table.UniqueIndex = map[string][]*Column{} table.NormalIndex = map[string][]*Column{} + table.UniqueIndexPredicate = map[string]string{} m := make(map[string][]*Column) for _, each := range c.Columns { @@ -178,6 +181,9 @@ func (c *ColumnData) Convert() (*Table, error) { if one.Index != nil { if one.Index.NonUnique == 0 { table.UniqueIndex[indexName] = columns + if one.Index.Predicate != "" { + table.UniqueIndexPredicate[indexName] = one.Index.Predicate + } } else { table.NormalIndex[indexName] = columns } diff --git a/tools/goctl/model/sql/model/postgresqlmodel.go b/tools/goctl/model/sql/model/postgresqlmodel.go index b7bbcb5354c6..a3ea5cdf32e2 100644 --- a/tools/goctl/model/sql/model/postgresqlmodel.go +++ b/tools/goctl/model/sql/model/postgresqlmodel.go @@ -41,6 +41,7 @@ type PostgreIndex struct { IsPrimary sql.NullBool `db:"is_primary"` ColumnName sql.NullString `db:"column_name"` IndexSort sql.NullInt32 `db:"index_sort"` + Predicate sql.NullString `db:"predicate"` } // NewPostgreSqlModel creates an instance and return @@ -202,6 +203,7 @@ func (m *PostgreSqlModel) getIndex(schema, table string) (map[string][]*DbIndex, IndexName: e.IndexName.String, NonUnique: nonUnique, SeqInIndex: int(e.IndexSort.Int32), + Predicate: e.Predicate.String, }) } @@ -215,7 +217,8 @@ func (m *PostgreSqlModel) FindIndex(schema, table string) ([]*PostgreIndex, erro C.INDISUNIQUE AS is_unique, C.INDISPRIMARY AS is_primary, G.ATTNAME AS column_name, - G.attnum AS index_sort + G.attnum AS index_sort, + pg_get_expr(C.INDPRED, C.INDRELID) AS predicate from PG_AM B left join PG_CLASS F on B.OID = F.RELAM diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 81beae142b2a..468faa147d6b 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -20,12 +20,13 @@ const timeImport = "time.Time" type ( // Table describes a mysql table Table struct { - Name stringx.String - Db stringx.String - PrimaryKey Primary - UniqueIndex map[string][]*Field - Fields []*Field - ContainsPQ bool + Name stringx.String + Db stringx.String + PrimaryKey Primary + UniqueIndex map[string][]*Field + UniqueIndexPredicate map[string]string + Fields []*Field + ContainsPQ bool } // Primary describes a primary key @@ -277,6 +278,7 @@ func ConvertDataType(table *model.Table, strict bool) (*Table, error) { var reply Table reply.ContainsPQ = containsPQ reply.UniqueIndex = map[string][]*Field{} + reply.UniqueIndexPredicate = map[string]string{} reply.Name = stringx.From(table.Table) reply.Db = stringx.From(table.Db) seqInIndex := 0 @@ -344,6 +346,9 @@ func ConvertDataType(table *model.Table, strict bool) (*Table, error) { uniqueIndexSet.Add(uniqueKey) reply.UniqueIndex[indexName] = list + if predicate, ok := table.UniqueIndexPredicate[indexName]; ok && predicate != "" { + reply.UniqueIndexPredicate[indexName] = predicate + } } return &reply, nil From cfc8fef680d0718193e33f648388d6bab5b7bf63 Mon Sep 17 00:00:00 2001 From: hoshi Date: Thu, 30 Apr 2026 09:25:28 +0800 Subject: [PATCH 2/2] fix(goctl): add trust assumption comment and integration test for predicate rendering --- tools/goctl/model/sql/gen/findonebyfield.go | 2 + tools/goctl/model/sql/gen/gen_test.go | 80 ++++++++++++++++++++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index c56d696fd2ae..1bdfc40eee88 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -29,6 +29,8 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e in, paramJoinString, originalFieldString := convertJoin(key, postgreSql) // Append partial index predicate (WHERE clause) for PostgreSQL partial unique indexes. + // The predicate originates from pg_get_expr(C.INDPRED, C.INDRELID) on the PostgreSQL + // system catalog — it is schema metadata, not user input, so there is no SQL injection surface. // pg_get_expr wraps the predicate in outer parentheses, so the result is correctly grouped. // e.g. "sku = $1 and ((status = 1) AND (deleted_at IS NULL))" if len(key.Predicate) > 0 { diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index 5caaf5af3aa7..edc2041e8f97 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -14,10 +14,11 @@ import ( "github.com/stretchr/testify/require" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/builder" - "github.com/zeromicro/go-zero/core/stringx" + corstringx "github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser" "github.com/zeromicro/go-zero/tools/goctl/util/pathx" + "github.com/zeromicro/go-zero/tools/goctl/util/stringx" ) //go:embed testdata/user.sql @@ -158,8 +159,8 @@ func TestFields(t *testing.T) { var ( studentFieldNames = builder.RawFieldNames(&Student{}) studentRows = strings.Join(studentFieldNames, ",") - studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",") - studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" + studentRowsExpectAutoSet = strings.Join(corstringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",") + studentRowsWithPlaceHolder = strings.Join(corstringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" ) assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames) @@ -198,3 +199,76 @@ func Test_genPublicModel(t *testing.T) { assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n")) assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {")) } + +func TestGenFindOneByFieldWithPartialIndex(t *testing.T) { + primaryField := &parser.Field{ + Name: stringx.From("id"), + DataType: "int64", + Comment: "主键", + } + emailField := &parser.Field{ + Name: stringx.From("email"), + DataType: "string", + Comment: "邮箱", + } + parsedTable := parser.Table{ + Name: stringx.From("user"), + Db: stringx.From("go_zero"), + PrimaryKey: parser.Primary{ + Field: *primaryField, + }, + UniqueIndex: map[string][]*parser.Field{ + "idx_active_email": {emailField}, + }, + UniqueIndexPredicate: map[string]string{ + "idx_active_email": "(status = 1) AND (deleted_at IS NULL)", + }, + Fields: []*parser.Field{ + primaryField, + emailField, + }, + } + + primaryKey, uniqueKeys := genCacheKeys("cache", parsedTable) + + table := Table{ + Table: parsedTable, + PrimaryCacheKey: primaryKey, + UniqueCacheKey: uniqueKeys, + ContainsUniqueCacheKey: len(uniqueKeys) > 0, + } + + result, err := genFindOneByField(table, false, true) + assert.NoError(t, err) + + assert.Contains(t, result.findOneMethod, "where email = $1 and (status = 1) AND (deleted_at IS NULL)") + assert.Contains(t, result.findOneMethod, "FindOneByEmail") + + // Verify non-partial indexes do NOT append predicate + parsedTable2 := parser.Table{ + Name: stringx.From("user"), + Db: stringx.From("go_zero"), + PrimaryKey: parser.Primary{ + Field: *primaryField, + }, + UniqueIndex: map[string][]*parser.Field{ + "email_unique": {emailField}, + }, + UniqueIndexPredicate: map[string]string{}, + Fields: []*parser.Field{ + primaryField, + emailField, + }, + } + _, uniqueKeys2 := genCacheKeys("cache", parsedTable2) + table2 := Table{ + Table: parsedTable2, + PrimaryCacheKey: primaryKey, + UniqueCacheKey: uniqueKeys2, + ContainsUniqueCacheKey: len(uniqueKeys2) > 0, + } + result2, err := genFindOneByField(table2, false, true) + assert.NoError(t, err) + assert.Contains(t, result2.findOneMethod, "where email = $1") + assert.NotContains(t, result2.findOneMethod, " and (") +}