@@ -24,7 +24,12 @@ type templateGenerator struct {
24
24
}
25
25
26
26
func (t * templateGenerator ) GenerateFile (file * protogen.File , plugin * protogen.Plugin , args any ) error {
27
- fileName , content , err := t .Generate (path .Base (file .GeneratedFilenamePrefix ), args )
27
+ importToPkg := make (map [protogen.GoImportPath ]protogen.GoPackageName )
28
+ for _ , f := range plugin .Files {
29
+ importToPkg [f .GoImportPath ] = f .GoPackageName
30
+ }
31
+
32
+ fileName , content , err := t .Generate (path .Base (file .GeneratedFilenamePrefix ), args , importToPkg )
28
33
if err != nil {
29
34
return err
30
35
}
@@ -34,13 +39,13 @@ func (t *templateGenerator) GenerateFile(file *protogen.File, plugin *protogen.P
34
39
return nil
35
40
}
36
41
37
- func (t * templateGenerator ) Generate (baseFile , args any ) (string , string , error ) {
38
- fileName , err := runTemplate (t .Name + "_fileName" , t .FileNameTemplate , baseFile , t .Partials )
42
+ func (t * templateGenerator ) Generate (baseFile , args any , importToPkg map [protogen. GoImportPath ]protogen. GoPackageName ) (string , string , error ) {
43
+ fileName , err := runTemplate (t .Name + "_fileName" , t .FileNameTemplate , baseFile , t .Partials , importToPkg )
39
44
if err != nil {
40
45
return "" , "" , err
41
46
}
42
47
43
- file , err := runTemplate (t .Name , t .Template , args , t .Partials )
48
+ file , err := runTemplate (t .Name , t .Template , args , t .Partials , importToPkg )
44
49
if err != nil {
45
50
return fileName , "" , err
46
51
}
@@ -57,7 +62,7 @@ func (t *templateGenerator) Generate(baseFile, args any) (string, string, error)
57
62
return fileName , prettyFile , err
58
63
}
59
64
60
- func runTemplate (name , tmplText string , args any , partials map [string ]string ) (string , error ) {
65
+ func runTemplate (name , tmplText string , args any , partials map [string ]string , importToPkg map [protogen. GoImportPath ]protogen. GoPackageName ) (string , error ) {
61
66
buf := & bytes.Buffer {}
62
67
imports := map [string ]bool {}
63
68
templ := template .New (name ).Funcs (template.FuncMap {
@@ -85,11 +90,18 @@ func runTemplate(name, tmplText string, args any, partials map[string]string) (s
85
90
return m , nil
86
91
},
87
92
"isTrigger" : func (m * protogen.Method ) bool { return m .Desc .IsStreamingServer () },
88
- "addImport" : func (name protogen.GoImportPath , ignore string ) string {
89
- if ignore != name .String () {
90
- imports [name .String ()] = true
93
+ "addImport" : func (importPath protogen.GoImportPath , ignore string ) string {
94
+ importName := importPath .String ()
95
+ if ignore == importName {
96
+ return ""
91
97
}
92
98
99
+ // add package name alias if path is mismatched with the package name
100
+ if ! isDirNamePackageName (importPath , importToPkg ) {
101
+ importName = fmt .Sprintf ("%s %s" , importToPkg [importPath ], importName )
102
+ }
103
+
104
+ imports [importName ] = true
93
105
return ""
94
106
},
95
107
"allimports" : func () []string {
@@ -105,10 +117,14 @@ func runTemplate(name, tmplText string, args any, partials map[string]string) (s
105
117
return ident .GoName
106
118
}
107
119
108
- // remove quotes
109
- importPath = importPath [1 : len (importPath )- 1 ]
110
- parts := strings .Split (importPath , "/" )
111
- return fmt .Sprintf ("%s.%s" , parts [len (parts )- 1 ], ident .GoName )
120
+ packageName := path .Base (strings .Trim (importPath , `"` ))
121
+
122
+ // use package name alias if package is mismatched with the package name
123
+ if ! isDirNamePackageName (ident .GoImportPath , importToPkg ) {
124
+ packageName = string (importToPkg [ident .GoImportPath ])
125
+ }
126
+
127
+ return fmt .Sprintf ("%s.%s" , packageName , ident .GoName )
112
128
},
113
129
"CapabilityId" : func (s * protogen.Service ) (string , error ) {
114
130
// TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-797 ID should be allowed to require a parameter.
@@ -164,6 +180,12 @@ func runTemplate(name, tmplText string, args any, partials map[string]string) (s
164
180
return buf .String (), err
165
181
}
166
182
183
+ func isDirNamePackageName (importPath protogen.GoImportPath , importToPkg map [protogen.GoImportPath ]protogen.GoPackageName ) bool {
184
+ packageName := importToPkg [importPath ]
185
+ dirName := path .Base (strings .Trim (importPath .String (), `"` ))
186
+ return dirName == string (packageName )
187
+ }
188
+
167
189
func getCapabilityMetadata (service * protogen.Service ) (* pb.CapabilityMetadata , error ) {
168
190
opts := service .Desc .Options ().(* descriptorpb.ServiceOptions )
169
191
if proto .HasExtension (opts , pb .E_Capability ) {
0 commit comments