Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/models/enum/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package enum

type Enum string
3 changes: 3 additions & 0 deletions examples/models/enum/enum/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package enum

type Enum string
8 changes: 8 additions & 0 deletions examples/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"database/sql"
"time"

"gorm.io/cli/gorm/examples/models/enum"
enum2 "gorm.io/cli/gorm/examples/models/enum/enum"
"gorm.io/cli/gorm/genconfig"
"gorm.io/datatypes"
"gorm.io/gorm"
Expand Down Expand Up @@ -38,8 +40,14 @@ type User struct {
IsAdult bool `gorm:"column:is_adult"`
Profile string `gen:"json"`
AwardTypes datatypes.JSONSlice[int]
TagTypes datatypes.JSONSlice[UserTagType]
Tag UserTagType
Enum enum.Enum
Enum2 enum2.Enum
}

type UserTagType string

type Account struct {
gorm.Model
UserID sql.NullInt64
Expand Down
4 changes: 4 additions & 0 deletions examples/output/models/user.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

99 changes: 92 additions & 7 deletions internal/gen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,22 @@ func (f Field) Type() string {
}

// Check if type implements allowed interfaces
goType := strings.TrimPrefix(f.GoType, "*")

var (
goType = strings.TrimPrefix(f.GoType, "*")
pkgIdx = strings.LastIndex(goType, ".")
pkgIdx int
pkgName = f.file.Package
typName = goType
)

// Find the last '.' before any generic bracket '['
bracketIdx := strings.Index(goType, "[")
if bracketIdx != -1 {
pkgIdx = strings.LastIndex(goType[:bracketIdx], ".")
} else {
pkgIdx = strings.LastIndex(goType, ".")
}

if pkgIdx > 0 {
pkgName, typName = goType[:pkgIdx], goType[pkgIdx+1:]
}
Expand All @@ -453,21 +462,88 @@ func (f Field) Type() string {
return fmt.Sprintf("field.Number[%s]", goType)
}

// Process generic type parameters to convert full paths to short names
goType = f.processGenericType(goType)

if typ := loadNamedType(f.file.goModDir, f.file.getFullImportPath(pkgName), typName); typ != nil {
if ImplementsAllowedInterfaces(typ) { // For interface-implementing types, use generic Field
return fmt.Sprintf("field.Field[%s]", filepath.Base(goType))
if ImplementsAllowedInterfaces(typ) || IsUnderlyingComparable(typ) {
return fmt.Sprintf("field.Field[%s]", goType)
}
}

// Check if this is a relation field based on its type
if strings.HasPrefix(goType, "[]") {
elementType := filepath.Base(strings.TrimPrefix(goType, "[]"))
elementType := strings.TrimPrefix(goType, "[]")
return fmt.Sprintf("field.Slice[%s]", elementType)
} else if strings.Contains(goType, ".") {
return fmt.Sprintf("field.Struct[%s]", filepath.Base(goType))
return fmt.Sprintf("field.Struct[%s]", goType)
}

return fmt.Sprintf("field.Field[%s]", filepath.Base(goType))
return fmt.Sprintf("field.Field[%s]", goType)
}

// processGenericType converts full package paths in generic type parameters to short names
// and ensures required imports are added
// e.g., "datatypes.JSONSlice[gorm.io/cli/gorm/examples/models.UserTagType]"
//
// -> "datatypes.JSONSlice[models.UserTagType]"
func (f Field) processGenericType(goType string) string {
goType = strings.TrimSpace(goType)

// Handle pointer types by recursively processing without the pointer prefix
if strings.HasPrefix(goType, "*") {
return f.processGenericType(goType[1:])
}

// Handle slice types by recursively processing the element type
if strings.HasPrefix(goType, "[]") {
return "[]" + f.processGenericType(goType[2:])
}

// Split the type into the main identifier and the generic arguments (separated by '[')
mainPart, argsPart, hasArgs := strings.Cut(goType, "[")

// Resolve the package alias for the main type
shortMain := f.file.getImportAliasType(mainPart)

// If generic arguments exist, process them recursively
if hasArgs {
// Remove the trailing closing bracket ']' from the arguments part
cleanArgs := strings.TrimSuffix(argsPart, "]")

// Split the arguments string by comma, respecting nested brackets
// e.g., "TypeA, Map[K,V]" -> ["TypeA", "Map[K,V]"]
args := splitGenericArgs(cleanArgs)

var simplifiedArgs []string
for _, arg := range args {
simplifiedArgs = append(simplifiedArgs, f.processGenericType(arg))
}

return shortMain + "[" + strings.Join(simplifiedArgs, ", ") + "]"
}

return shortMain
}

// getImportAliasType returns the import alias type string for a raw type string
// e.g., "datatypes2 gorm.io/datatypes.JSONSlice" -> "datatypes2.JSONSlice"
func (p *File) getImportAliasType(raw string) string {
lastDot := strings.LastIndex(raw, ".")
if lastDot == -1 {
return raw
}

pathStr := raw[:lastDot]
typeName := raw[lastDot+1:]

pkgName := filepath.Base(pathStr)
imp := p.getImport(pathStr)
if imp != nil {
pkgName = imp.Name
}

return pkgName + "." + typeName
}

// Value returns the field value string with column name for template generation
Expand Down Expand Up @@ -777,6 +853,15 @@ func (p *File) getFullImportPath(shortName string) string {
return shortName
}

func (p *File) getImport(path string) *Import {
for _, i := range p.Imports {
if i.Path == path {
return &i
}
}
return nil
}

// handleAnonymousEmbedding processes anonymous embedded fields and returns true if handled
func (p *File) handleAnonymousEmbedding(field *ast.Field, pkgName string, s *Struct) bool {
// Helper function to add fields from embedded struct
Expand Down
4 changes: 4 additions & 0 deletions internal/gen/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ func TestProcessStructType(t *testing.T) {
{Name: "IsAdult", DBName: "is_adult", GoType: "bool"},
{Name: "Profile", DBName: "profile", GoType: "string", NamedGoType: "json"},
{Name: "AwardTypes", DBName: "award_types", GoType: "datatypes.JSONSlice[int]"},
{Name: "TagTypes", DBName: "tag_types", GoType: "datatypes.JSONSlice[UserTagType]"},
{Name: "Tag", DBName: "tag", GoType: "UserTagType"},
{Name: "Enum", DBName: "enum", GoType: "enum.Enum"}, // 添加
{Name: "Enum2", DBName: "enum2", GoType: "enum2.Enum"},
},
}

Expand Down
32 changes: 32 additions & 0 deletions internal/gen/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ func ImplementsAllowedInterfaces(typ types.Type) bool {
return false
}

func IsUnderlyingComparable(typ types.Type) bool {
underlying := typ.Underlying()
if _, ok := underlying.(*types.Struct); ok {
return false
}
return types.Comparable(underlying)
}

func findGoModDir(filename string) string {
cmd := exec.Command("go", "env", "GOMOD")
cmd.Dir = filepath.Dir(filename)
Expand Down Expand Up @@ -202,3 +210,27 @@ func stripGeneric(s string) string {
}
return s
}

// splitGenericArgs splits a generic type argument string into individual arguments.
func splitGenericArgs(s string) []string {
var args []string
depth := 0
start := 0
for i, char := range s {
switch char {
case '[':
depth++
case ']':
depth--
case ',':
if depth == 0 {
args = append(args, s[start:i])
start = i + 1
}
}
}
if start < len(s) {
args = append(args, s[start:])
}
return args
}