Skip to content

Commit b489dd0

Browse files
authored
gapic: avoid generating duplicate iterators (#155)
1 parent 8a10cc2 commit b489dd0

File tree

4 files changed

+50
-24
lines changed

4 files changed

+50
-24
lines changed

internal/gengapic/gengapic.go

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,19 @@ type generator struct {
181181
serviceConfig *serviceConfig
182182

183183
grpcConf *conf.ServiceConfig
184+
185+
// Auxiliary types to be generated in the package
186+
aux *auxTypes
184187
}
185188

186189
func (g *generator) init(files []*descriptor.FileDescriptorProto) {
187190
g.descInfo = pbinfo.Of(files)
188191

189192
g.comments = map[proto.Message]string{}
190193
g.imports = map[pbinfo.ImportSpec]bool{}
194+
g.aux = &auxTypes{
195+
iters: map[string]*iterType{},
196+
}
191197

192198
for _, f := range files {
193199
for _, loc := range f.GetSourceCodeInfo().GetLocation() {
@@ -299,27 +305,36 @@ func (g *generator) gen(serv *descriptor.ServiceDescriptorProto, pkgName string)
299305
return err
300306
}
301307

302-
aux := auxTypes{
303-
iters: map[string]iterType{},
304-
}
308+
// clear LRO types between services
309+
g.aux.lros = []*descriptor.MethodDescriptorProto{}
310+
305311
for _, m := range serv.Method {
306312
g.methodDoc(m)
307-
if err := g.genMethod(servName, serv, m, &aux); err != nil {
313+
if err := g.genMethod(servName, serv, m); err != nil {
308314
return errors.E(err, "method: %s", m.GetName())
309315
}
310316
}
311317

312-
sort.Slice(aux.lros, func(i, j int) bool {
313-
return aux.lros[i].GetName() < aux.lros[j].GetName()
318+
sort.Slice(g.aux.lros, func(i, j int) bool {
319+
return g.aux.lros[i].GetName() < g.aux.lros[j].GetName()
314320
})
315-
for _, m := range aux.lros {
321+
for _, m := range g.aux.lros {
316322
if err := g.lroType(servName, serv, m); err != nil {
317323
return err
318324
}
319325
}
320326

321-
var iters []iterType
322-
for _, iter := range aux.iters {
327+
var iters []*iterType
328+
for _, iter := range g.aux.iters {
329+
// skip iterators that have already been generated in this package
330+
//
331+
// TODO(ndietz): investigate generating auxiliary types in a
332+
// separate file in the same package to avoid keeping this state
333+
if iter.generated {
334+
continue
335+
}
336+
337+
iter.generated = true
323338
iters = append(iters, iter)
324339
}
325340
sort.Slice(iters, func(i, j int) bool {
@@ -340,14 +355,14 @@ type auxTypes struct {
340355
// "List" of iterator types. We use these to generate FooIterator returned by paging methods.
341356
// Since multiple methods can page over the same type, we dedupe by the name of the iterator,
342357
// which is in turn determined by the element type name.
343-
iters map[string]iterType
358+
iters map[string]*iterType
344359
}
345360

346361
// genMethod generates a single method from a client. m must be a method declared in serv.
347362
// If the generated method requires an auxillary type, it is added to aux.
348-
func (g *generator) genMethod(servName string, serv *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto, aux *auxTypes) error {
363+
func (g *generator) genMethod(servName string, serv *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) error {
349364
if m.GetOutputType() == lroType {
350-
aux.lros = append(aux.lros, m)
365+
g.aux.lros = append(g.aux.lros, m)
351366
return g.lroCall(servName, m)
352367
}
353368

@@ -362,7 +377,7 @@ func (g *generator) genMethod(servName string, serv *descriptor.ServiceDescripto
362377
if err != nil {
363378
return err
364379
}
365-
aux.iters[iter.iterTypeName] = iter
380+
366381
return g.pagingCall(servName, m, pf, iter)
367382
}
368383

internal/gengapic/gengapic_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,22 +253,22 @@ methods:
253253
proto.SetExtension(m.Options, longrunning.E_OperationInfo, lroType)
254254
}
255255

256-
aux := auxTypes{
257-
iters: map[string]iterType{},
256+
g.aux = &auxTypes{
257+
iters: map[string]*iterType{},
258258
}
259-
if err := g.genMethod("Foo", serv, m, &aux); err != nil {
259+
if err := g.genMethod("Foo", serv, m); err != nil {
260260
t.Error(err)
261261
continue
262262
}
263263

264-
for _, m := range aux.lros {
264+
for _, m := range g.aux.lros {
265265
if err := g.lroType("MyService", serv, m); err != nil {
266266
t.Error(err)
267267
continue methods
268268
}
269269
}
270270

271-
for _, iter := range aux.iters {
271+
for _, iter := range g.aux.iters {
272272
g.pagingIter(iter)
273273
}
274274

internal/gengapic/paging.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ type iterType struct {
3030
// If the elem type is a message, elemImports contains pbinfo.ImportSpec for the type.
3131
// Otherwise, len(elemImports)==0.
3232
elemImports []pbinfo.ImportSpec
33+
34+
generated bool
3335
}
3436

3537
// iterTypeOf deduces iterType from a field to be iterated over.
3638
// elemField should be the "resource" of a paginating RPC.
37-
func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (iterType, error) {
39+
func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (*iterType, error) {
3840
var pt iterType
3941

4042
switch t := *elemField.Type; {
@@ -43,7 +45,7 @@ func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (iter
4345

4446
imp, err := g.descInfo.ImportSpec(eType)
4547
if err != nil {
46-
return iterType{}, err
48+
return &iterType{}, err
4749
}
4850

4951
pt.elemTypeName = fmt.Sprintf("*%s.%s", imp.Name, eType.GetName())
@@ -66,7 +68,13 @@ func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (iter
6668
pt.elemTypeName = pType
6769
pt.iterTypeName = upperFirst(pt.elemTypeName) + "Iterator"
6870
}
69-
return pt, nil
71+
72+
if iter, ok := g.aux.iters[pt.iterTypeName]; ok {
73+
return iter, nil
74+
}
75+
g.aux.iters[pt.iterTypeName] = &pt
76+
77+
return &pt, nil
7078
}
7179

7280
// TODO(pongad): this will probably need to read from annotations later.
@@ -126,7 +134,7 @@ func (g *generator) pagingField(m *descriptor.MethodDescriptorProto) (*descripto
126134
return elemFields[0], nil
127135
}
128136

129-
func (g *generator) pagingCall(servName string, m *descriptor.MethodDescriptorProto, elemField *descriptor.FieldDescriptorProto, pt iterType) error {
137+
func (g *generator) pagingCall(servName string, m *descriptor.MethodDescriptorProto, elemField *descriptor.FieldDescriptorProto, pt *iterType) error {
130138
inType := g.descInfo.Type[*m.InputType]
131139
outType := g.descInfo.Type[*m.OutputType]
132140

@@ -202,7 +210,7 @@ func (g *generator) pagingCall(servName string, m *descriptor.MethodDescriptorPr
202210
return nil
203211
}
204212

205-
func (g *generator) pagingIter(pt iterType) {
213+
func (g *generator) pagingIter(pt *iterType) {
206214
p := g.printf
207215

208216
p("// %s manages a stream of %s.", pt.iterTypeName, pt.elemTypeName)

internal/gengapic/paging_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ func TestIterTypeOf(t *testing.T) {
151151
Name: proto.String("Foo"),
152152
}
153153
g := &generator{
154+
aux: &auxTypes{
155+
iters: map[string]*iterType{},
156+
},
154157
descInfo: pbinfo.Info{
155158
Type: map[string]pbinfo.ProtoType{
156159
msgType.GetName(): msgType,
@@ -204,7 +207,7 @@ func TestIterTypeOf(t *testing.T) {
204207
got, err := g.iterTypeOf(tst.field)
205208
if err != nil {
206209
t.Error(err)
207-
} else if diff := cmp.Diff(tst.want, got, cmp.AllowUnexported(got)); diff != "" {
210+
} else if diff := cmp.Diff(tst.want, *got, cmp.AllowUnexported(*got)); diff != "" {
208211
t.Errorf("%d: (got=-, want=+):\n%s", i, diff)
209212
}
210213
}

0 commit comments

Comments
 (0)