Skip to content

Commit 0fd41d9

Browse files
committed
fix(entproto): allow specifying import path
Basically a merge of: ent#616
1 parent a1d02d8 commit 0fd41d9

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

entproto/adapter.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ var (
5757
)
5858

5959
// LoadAdapter takes a *gen.Graph and parses it into protobuf file descriptors
60-
func LoadAdapter(graph *gen.Graph) (*Adapter, error) {
60+
func LoadAdapter(graph *gen.Graph, goPkg string) (*Adapter, error) {
6161
a := &Adapter{
6262
graph: graph,
6363
descriptors: make(map[string]*desc.FileDescriptor),
6464
schemaProtoFiles: make(map[string]string),
6565
errors: make(map[string]error),
66+
goPkg: goPkg,
6667
}
6768
if err := a.parse(); err != nil {
6869
return nil, err
@@ -76,6 +77,7 @@ type Adapter struct {
7677
descriptors map[string]*desc.FileDescriptor
7778
schemaProtoFiles map[string]string
7879
errors map[string]error
80+
goPkg string
7981
}
8082

8183
// AllFileDescriptors returns a file descriptor per proto package for each package that contains
@@ -204,10 +206,15 @@ func (a *Adapter) parse() error {
204206
}
205207

206208
func (a *Adapter) goPackageName(protoPkgName string) string {
207-
// TODO(rotemtam): make this configurable from an annotation
208-
entBase := a.graph.Config.Package
209-
slashed := strings.ReplaceAll(protoPkgName, ".", "/")
210-
return path.Join(entBase, "proto", slashed)
209+
// TODO(kdevo): maybe better to make this configurable from an annotation
210+
slashedProtoPkg := strings.ReplaceAll(protoPkgName, ".", "/")
211+
if a.goPkg == "" {
212+
entBase := a.graph.Config.Package
213+
return path.Join(entBase, "proto", slashedProtoPkg)
214+
} else {
215+
slashed := strings.ReplaceAll(protoPkgName, ".", "/")
216+
return path.Join(a.goPkg, slashed)
217+
}
211218
}
212219

213220
// GetFileDescriptor returns the proto file descriptor containing the transformed proto message descriptor for

entproto/extension.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ type Extension struct {
5858
entc.DefaultExtension
5959
protoDir string
6060
skipGenFile bool
61+
goPkg string
6162
}
6263

6364
// WithProtoDir sets the directory where the generated .proto files will be written.
@@ -67,6 +68,13 @@ func WithProtoDir(dir string) ExtensionOption {
6768
}
6869
}
6970

71+
// WithProtoDir sets the directory where the generated .proto files will be written.
72+
func WithGoPkg(pkg string) ExtensionOption {
73+
return func(e *Extension) {
74+
e.goPkg = pkg
75+
}
76+
}
77+
7078
// SkipGenFile skips the generation of a generate.go file next to each .proto file.
7179
func SkipGenFile() ExtensionOption {
7280
return func(e *Extension) {
@@ -124,10 +132,11 @@ func (e *Extension) generate(g *gen.Graph) error {
124132
if e.protoDir != "" {
125133
entProtoDir = e.protoDir
126134
}
127-
adapter, err := LoadAdapter(g)
135+
adapter, err := LoadAdapter(g, e.goPkg)
128136
if err != nil {
129137
return fmt.Errorf("entproto: failed parsing ent graph: %w", err)
130138
}
139+
131140
var errs error
132141
for _, schema := range g.Schemas {
133142
name := schema.Name
@@ -197,7 +206,7 @@ func (e *Extension) generate(g *gen.Graph) error {
197206
if err != nil {
198207
return err
199208
}
200-
contents := protocGenerateGo(fd, toSchema, toEnt, g.Config.Package)
209+
contents := e.protocGenerateGo(fd, toSchema, toEnt, g.Config.Package)
201210
if err := os.WriteFile(genGoPath, []byte(contents), 0600); err != nil {
202211
return fmt.Errorf("entproto: failed generating generate.go file for %q: %w", protoFilePath, err)
203212
}
@@ -216,7 +225,7 @@ func fileExists(fpath string) bool {
216225
return true
217226
}
218227

219-
func protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir, entPath, entPackage string) string {
228+
func (e *Extension) protocGenerateGo(fd *desc.FileDescriptor, toSchemaDir, entPath, entPackage string) string {
220229
levelsUp := len(strings.Split(fd.GetPackage(), "."))
221230
toProtoBase := ""
222231
for i := 0; i < levelsUp; i++ {

0 commit comments

Comments
 (0)