Skip to content

Commit 6651aa2

Browse files
authored
Fix: missing external enums (#14)
* build external enums * bump version * upgrade github action setup-protoc * remove some invalid tests * refactor
1 parent 1130134 commit 6651aa2

File tree

11 files changed

+220
-1456
lines changed

11 files changed

+220
-1456
lines changed

.github/workflows/go.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
go-version: '^1.17'
1818
- uses: arduino/setup-protoc@v1
1919
with:
20-
version: '3.17.3'
20+
version: '3.19.1'
2121
- name: Run go test
2222
run: go test -v ./...
2323
- name: Install protoc-gen-pubsub-schema

content_builder.go

Lines changed: 12 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (b *contentBuilder) build(protoFile *descriptorpb.FileDescriptorProto) (str
2929
compVersion := b.request.GetCompilerVersion()
3030
fmt.Fprintf(b.output, "// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.\n")
3131
fmt.Fprintf(b.output, "// versions:\n")
32-
fmt.Fprintf(b.output, "// protoc-gen-pubsub-schema v1.4.3\n")
32+
fmt.Fprintf(b.output, "// protoc-gen-pubsub-schema v1.4.4\n")
3333
fmt.Fprintf(b.output, "// protoc v%d.%d.%d%s\n", compVersion.GetMajor(), compVersion.GetMinor(), compVersion.GetPatch(), compVersion.GetSuffix())
3434
fmt.Fprintf(b.output, "// source: %s\n\n", protoFile.GetName())
3535
fmt.Fprintf(b.output, "syntax = \"%s\";\n", b.schemaSyntax)
@@ -39,110 +39,33 @@ func (b *contentBuilder) build(protoFile *descriptorpb.FileDescriptorProto) (str
3939
}
4040

4141
func (b *contentBuilder) buildMessages(messages []*descriptorpb.DescriptorProto, level int) {
42+
built := make(map[*descriptorpb.DescriptorProto]bool)
4243
for _, message := range messages {
43-
fmt.Fprintln(b.output)
44-
b.buildMessage(message, level)
45-
}
46-
}
47-
48-
func (b *contentBuilder) buildMessage(message *descriptorpb.DescriptorProto, level int) {
49-
fmt.Fprintf(b.output, "%smessage %s {\n", buildIndent(level), message.GetName())
50-
b.buildFields(message.GetField(), level+1)
51-
b.buildMessages(message.GetNestedType(), level+1)
52-
b.buildEnums(message.GetEnumType(), level+1)
53-
b.buildOtherTypes(message, level+1)
54-
fmt.Fprintf(b.output, "%s}\n", buildIndent(level))
55-
}
56-
57-
func (b *contentBuilder) buildFields(fields []*descriptorpb.FieldDescriptorProto, level int) {
58-
for _, field := range fields {
59-
fmt.Fprint(b.output, buildIndent(level))
60-
label := field.GetLabel()
61-
if b.schemaSyntax == "proto2" || label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
62-
fmt.Fprintf(b.output, "%s ", strings.ToLower(strings.TrimPrefix(label.String(), "LABEL_")))
63-
}
64-
fmt.Fprintf(b.output, "%s %s = %d;\n", b.getFieldType(field), field.GetName(), field.GetNumber())
65-
}
66-
}
67-
68-
func (b *contentBuilder) getFieldType(field *descriptorpb.FieldDescriptorProto) string {
69-
typeName := field.GetTypeName()
70-
switch field.GetType() {
71-
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
72-
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
73-
return wktMapping[typeName]
74-
}
75-
if b.isNestedType(typeName) {
76-
return shortName(typeName)
44+
if built[message] {
45+
continue
7746
}
78-
return pascalCase(typeName)
79-
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
80-
return shortName(typeName)
81-
default:
82-
return strings.ToLower(strings.TrimPrefix(field.GetType().String(), "TYPE_"))
47+
fmt.Fprintln(b.output)
48+
newMessageBuilder(b, message, level).build()
49+
built[message] = true
8350
}
8451
}
8552

8653
func (b *contentBuilder) buildEnums(enums []*descriptorpb.EnumDescriptorProto, level int) {
54+
built := make(map[*descriptorpb.EnumDescriptorProto]bool)
8755
for _, enum := range enums {
56+
if built[enum] {
57+
continue
58+
}
8859
fmt.Fprintln(b.output)
8960
fmt.Fprintf(b.output, "%senum %s {\n", buildIndent(level), enum.GetName())
9061
for _, value := range enum.GetValue() {
9162
fmt.Fprintf(b.output, "%s%s = %d;\n", buildIndent(level+1), value.GetName(), value.GetNumber())
9263
}
9364
fmt.Fprintf(b.output, "%s}\n", buildIndent(level))
65+
built[enum] = true
9466
}
9567
}
9668

97-
func (b *contentBuilder) buildOtherTypes(message *descriptorpb.DescriptorProto, level int) {
98-
built := make(map[string]bool)
99-
for _, field := range message.GetField() {
100-
typeName := field.GetTypeName()
101-
if field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
102-
continue
103-
}
104-
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
105-
continue
106-
}
107-
if b.isNestedType(typeName) {
108-
continue
109-
}
110-
if built[typeName] {
111-
continue
112-
}
113-
b.buildOtherType(typeName, level)
114-
built[typeName] = true
115-
}
116-
}
117-
118-
func (b *contentBuilder) buildOtherType(typeName string, level int) {
119-
message := b.messageTypes[typeName]
120-
defer func(name *string) { message.Name = name }(message.Name)
121-
*message.Name = pascalCase(typeName)
122-
fmt.Fprintln(b.output)
123-
b.buildMessage(message, level)
124-
}
125-
126-
func (b *contentBuilder) isNestedType(name string) bool {
127-
return b.messageTypes[name[:strings.LastIndexByte(name, '.')]] != nil
128-
}
129-
13069
func buildIndent(level int) string {
13170
return strings.Repeat(" ", level)
13271
}
133-
134-
func shortName(name string) string {
135-
return name[strings.LastIndexByte(name, '.')+1:]
136-
}
137-
138-
func pascalCase(name string) string {
139-
sb := new(strings.Builder)
140-
for i, c := range name {
141-
if i > 0 && name[i-1] == '.' {
142-
sb.WriteString(strings.ToUpper(string(c)))
143-
} else if c != '.' {
144-
sb.WriteRune(c)
145-
}
146-
}
147-
return sb.String()
148-
}

example/common/role.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
package example.common;
4+
5+
enum Role {
6+
OWNER = 0;
7+
EDITOR = 1;
8+
VIEWER = 2;
9+
}

example/user_add_comment.pps

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Code generated by protoc-gen-pubsub-schema. DO NOT EDIT.
22
// versions:
3-
// protoc-gen-pubsub-schema v1.4.3
4-
// protoc v3.17.3
3+
// protoc-gen-pubsub-schema v1.4.4
4+
// protoc v3.19.1
55
// source: example/user_add_comment.proto
66

77
syntax = "proto2";
@@ -15,7 +15,7 @@ message UserAddComment {
1515
message User {
1616
required string first_name = 1;
1717
optional string last_name = 2;
18-
required Role role = 3;
18+
required ExampleCommonRole role = 3;
1919
optional bytes avatar = 4;
2020
optional Location location = 5;
2121
optional GoogleProtobufTimestamp created_at = 6;
@@ -26,16 +26,16 @@ message UserAddComment {
2626
required double latitude = 2;
2727
}
2828

29-
enum Role {
30-
OWNER = 1;
31-
EDITOR = 2;
32-
VIEWER = 3;
33-
}
34-
3529
message GoogleProtobufTimestamp {
3630
optional int64 seconds = 1;
3731
optional int32 nanos = 2;
3832
}
33+
34+
enum ExampleCommonRole {
35+
OWNER = 0;
36+
EDITOR = 1;
37+
VIEWER = 2;
38+
}
3939
}
4040

4141
message ExampleCommonLabel {

example/user_add_comment.proto

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ syntax = "proto2";
33
package example;
44

55
import "example/common/label.proto";
6+
import "example/common/role.proto";
67
import "google/protobuf/timestamp.proto";
78

89
message UserAddComment {
@@ -14,7 +15,7 @@ message UserAddComment {
1415
message User {
1516
required string first_name = 1;
1617
optional string last_name = 2;
17-
required Role role = 3;
18+
required example.common.Role role = 3;
1819
optional bytes avatar = 4;
1920
optional Location location = 5;
2021
optional google.protobuf.Timestamp created_at = 6;
@@ -24,11 +25,5 @@ message UserAddComment {
2425
required double longitude = 1;
2526
required double latitude = 2;
2627
}
27-
28-
enum Role {
29-
OWNER = 1;
30-
EDITOR = 2;
31-
VIEWER = 3;
32-
}
3328
}
3429
}

message_builder.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"google.golang.org/protobuf/types/descriptorpb"
8+
)
9+
10+
type messageBuilder struct {
11+
*contentBuilder
12+
message *descriptorpb.DescriptorProto
13+
level int
14+
externalMessages []*descriptorpb.DescriptorProto
15+
externalEnums []*descriptorpb.EnumDescriptorProto
16+
}
17+
18+
func newMessageBuilder(b *contentBuilder, message *descriptorpb.DescriptorProto, level int) *messageBuilder {
19+
return &messageBuilder{b, message, level, nil, nil}
20+
}
21+
22+
func (b *messageBuilder) build() {
23+
fmt.Fprintf(b.output, "%smessage %s {\n", buildIndent(b.level), b.message.GetName())
24+
b.buildFields()
25+
b.buildMessages(b.message.GetNestedType(), b.level+1)
26+
b.buildEnums(b.message.GetEnumType(), b.level+1)
27+
fmt.Fprintf(b.output, "%s}\n", buildIndent(b.level))
28+
}
29+
30+
func (b *messageBuilder) buildFields() {
31+
for _, field := range b.message.GetField() {
32+
fmt.Fprint(b.output, buildIndent(b.level+1))
33+
label := field.GetLabel()
34+
if b.schemaSyntax == "proto2" || label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
35+
fmt.Fprintf(b.output, "%s ", strings.ToLower(strings.TrimPrefix(label.String(), "LABEL_")))
36+
}
37+
fmt.Fprintf(b.output, "%s %s = %d;\n", b.buildFieldType(field), field.GetName(), field.GetNumber())
38+
}
39+
}
40+
41+
func (b *messageBuilder) buildFieldType(field *descriptorpb.FieldDescriptorProto) string {
42+
typeName := field.GetTypeName()
43+
if b.isNestedType(field) {
44+
return getChildName(typeName)
45+
}
46+
switch field.GetType() {
47+
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
48+
if b.messageEncoding == "json" && wktMapping[typeName] != "" {
49+
return wktMapping[typeName]
50+
}
51+
internalName := pascalCase(typeName)
52+
internalMessage := b.messageTypes[field.GetTypeName()]
53+
internalMessage.Name = &internalName
54+
b.message.NestedType = append(b.message.NestedType, internalMessage)
55+
return internalName
56+
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
57+
internalName := pascalCase(typeName)
58+
internalEnum := b.enums[field.GetTypeName()]
59+
internalEnum.Name = &internalName
60+
b.message.EnumType = append(b.message.EnumType, internalEnum)
61+
return internalName
62+
default:
63+
return strings.ToLower(strings.TrimPrefix(field.GetType().String(), "TYPE_"))
64+
}
65+
}
66+
67+
func (b *messageBuilder) isNestedType(field *descriptorpb.FieldDescriptorProto) bool {
68+
return b.messageTypes[getParentName(field.GetTypeName())] == b.message
69+
}
70+
71+
func getParentName(name string) string {
72+
lastDotIndex := strings.LastIndexByte(name, '.')
73+
if lastDotIndex == -1 {
74+
return name
75+
}
76+
return name[:lastDotIndex]
77+
}
78+
79+
func getChildName(name string) string {
80+
return name[strings.LastIndexByte(name, '.')+1:]
81+
}
82+
83+
func pascalCase(name string) string {
84+
sb := new(strings.Builder)
85+
for i, c := range name {
86+
if i > 0 && name[i-1] == '.' {
87+
sb.WriteString(strings.ToUpper(string(c)))
88+
} else if c != '.' {
89+
sb.WriteRune(c)
90+
}
91+
}
92+
return sb.String()
93+
}

message_builder_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func Test_getParentName(t *testing.T) {
8+
type args struct {
9+
name string
10+
}
11+
tests := []struct {
12+
name string
13+
args args
14+
want string
15+
}{
16+
{
17+
name: "empty name",
18+
args: args{""},
19+
want: "",
20+
},
21+
{
22+
name: "normal name",
23+
args: args{".example.UserAddComment.User"},
24+
want: ".example.UserAddComment",
25+
},
26+
}
27+
for _, tt := range tests {
28+
t.Run(tt.name, func(t *testing.T) {
29+
if got := getParentName(tt.args.name); got != tt.want {
30+
t.Errorf("getParentName() = %v, want %v", got, tt.want)
31+
}
32+
})
33+
}
34+
}
35+
36+
func Test_getChildName(t *testing.T) {
37+
type args struct {
38+
name string
39+
}
40+
tests := []struct {
41+
name string
42+
args args
43+
want string
44+
}{
45+
{
46+
name: "empty name",
47+
args: args{""},
48+
want: "",
49+
},
50+
{
51+
name: "normal name",
52+
args: args{".example.UserAddComment.User"},
53+
want: "User",
54+
},
55+
}
56+
for _, tt := range tests {
57+
t.Run(tt.name, func(t *testing.T) {
58+
if got := getChildName(tt.args.name); got != tt.want {
59+
t.Errorf("getChildName() = %v, want %v", got, tt.want)
60+
}
61+
})
62+
}
63+
}

0 commit comments

Comments
 (0)