Skip to content

Commit 26376d4

Browse files
committed
Add mismatched package naming handling to protoc template generator
1 parent 950aef0 commit 26376d4

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

pkg/capabilities/v2/protoc/pkg/template_generator.go

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ type templateGenerator struct {
2424
}
2525

2626
func (t *templateGenerator) GenerateFile(file *protogen.File, plugin *protogen.Plugin, args any) error {
27-
fileName, content, err := t.Generate(path.Base(file.GeneratedFilenamePrefix), args)
27+
importToPkg := make(map[protogen.GoImportPath]protogen.GoPackageName)
28+
for _, f := range plugin.Files {
29+
importToPkg[f.GoImportPath] = f.GoPackageName
30+
}
31+
32+
fileName, content, err := t.Generate(path.Base(file.GeneratedFilenamePrefix), args, importToPkg)
2833
if err != nil {
2934
return err
3035
}
@@ -34,13 +39,13 @@ func (t *templateGenerator) GenerateFile(file *protogen.File, plugin *protogen.P
3439
return nil
3540
}
3641

37-
func (t *templateGenerator) Generate(baseFile, args any) (string, string, error) {
38-
fileName, err := runTemplate(t.Name+"_fileName", t.FileNameTemplate, baseFile, t.Partials)
42+
func (t *templateGenerator) Generate(baseFile, args any, importToPkg map[protogen.GoImportPath]protogen.GoPackageName) (string, string, error) {
43+
fileName, err := runTemplate(t.Name+"_fileName", t.FileNameTemplate, baseFile, t.Partials, importToPkg)
3944
if err != nil {
4045
return "", "", err
4146
}
4247

43-
file, err := runTemplate(t.Name, t.Template, args, t.Partials)
48+
file, err := runTemplate(t.Name, t.Template, args, t.Partials, importToPkg)
4449
if err != nil {
4550
return fileName, "", err
4651
}
@@ -57,7 +62,7 @@ func (t *templateGenerator) Generate(baseFile, args any) (string, string, error)
5762
return fileName, prettyFile, err
5863
}
5964

60-
func runTemplate(name, tmplText string, args any, partials map[string]string) (string, error) {
65+
func runTemplate(name, tmplText string, args any, partials map[string]string, importToPkg map[protogen.GoImportPath]protogen.GoPackageName) (string, error) {
6166
buf := &bytes.Buffer{}
6267
imports := map[string]bool{}
6368
templ := template.New(name).Funcs(template.FuncMap{
@@ -85,11 +90,18 @@ func runTemplate(name, tmplText string, args any, partials map[string]string) (s
8590
return m, nil
8691
},
8792
"isTrigger": func(m *protogen.Method) bool { return m.Desc.IsStreamingServer() },
88-
"addImport": func(name protogen.GoImportPath, ignore string) string {
89-
if ignore != name.String() {
90-
imports[name.String()] = true
93+
"addImport": func(importPath protogen.GoImportPath, ignore string) string {
94+
importName := importPath.String()
95+
if ignore == importName {
96+
return ""
9197
}
9298

99+
// add package name alias if path is mismatched with the package name
100+
if !isDirNamePackageName(importPath, importToPkg) {
101+
importName = fmt.Sprintf("%s %s", importToPkg[importPath], importName)
102+
}
103+
104+
imports[importName] = true
93105
return ""
94106
},
95107
"allimports": func() []string {
@@ -105,10 +117,14 @@ func runTemplate(name, tmplText string, args any, partials map[string]string) (s
105117
return ident.GoName
106118
}
107119

108-
// remove quotes
109-
importPath = importPath[1 : len(importPath)-1]
110-
parts := strings.Split(importPath, "/")
111-
return fmt.Sprintf("%s.%s", parts[len(parts)-1], ident.GoName)
120+
packageName := path.Base(strings.Trim(importPath, `"`))
121+
122+
// use package name alias if package is mismatched with the package name
123+
if !isDirNamePackageName(ident.GoImportPath, importToPkg) {
124+
packageName = string(importToPkg[ident.GoImportPath])
125+
}
126+
127+
return fmt.Sprintf("%s.%s", packageName, ident.GoName)
112128
},
113129
"CapabilityId": func(s *protogen.Service) (string, error) {
114130
// TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-797 ID should be allowed to require a parameter.
@@ -164,6 +180,12 @@ func runTemplate(name, tmplText string, args any, partials map[string]string) (s
164180
return buf.String(), err
165181
}
166182

183+
func isDirNamePackageName(importPath protogen.GoImportPath, importToPkg map[protogen.GoImportPath]protogen.GoPackageName) bool {
184+
packageName := importToPkg[importPath]
185+
dirName := path.Base(strings.Trim(importPath.String(), `"`))
186+
return dirName == string(packageName)
187+
}
188+
167189
func getCapabilityMetadata(service *protogen.Service) (*pb.CapabilityMetadata, error) {
168190
opts := service.Desc.Options().(*descriptorpb.ServiceOptions)
169191
if proto.HasExtension(opts, pb.E_Capability) {

0 commit comments

Comments
 (0)