Skip to content

Commit 32c47bb

Browse files
authored
Code generation for protobuf visitors (#119)
1 parent 5c55f00 commit 32c47bb

14 files changed

Lines changed: 3645 additions & 7 deletions

File tree

CONTRIBUTING.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Start s2s-proxy
6565
./bins/s2s-proxy start --config develop/config/cluster-a-tcp-inbound-proxy.yaml
6666
```
6767

68-
Run
68+
Run
6969
```
7070
temporal operator --address 127.0.0.1:6233 cluster describe
7171
```
@@ -86,6 +86,27 @@ Start proxies
8686
./bins/s2s-proxy start --config develop/config/cluster-b-mux-server-proxy.yaml
8787
```
8888

89+
## Code Generation
90+
91+
### gRPC client generation
92+
93+
Run `make generate-rpcwrappers` to re-generate the clients
94+
95+
This uses the `cmd/tools/genrpcwrappers` tool to generate the frontend and admin clients.
96+
97+
### Invalid UTF-8 Repair Functions
98+
99+
The proxy supports automatically repairing invalid UTF-8 strings in Temporal protobuf messages.
100+
101+
Background: Invalid UTF-8 strings could appear in protobuf messages in older versions of Temporal (<=1.22) which used [gogo/protobuf](https://github.com/gogo/protobuf). Newer versions of Temporal use `google.golang.org/protobuf` which validates UTF-8 strings during serialization. This means any messages with invalid UTF-8 strings cannot be processed by s2s-proxy or by newer Temporal server versions. To fix this, s2s-proxy automatically repairs invalid UTF-8 strings by rewriting messages as they pass through the proxy. It does this by including a copy of the old gogo-based protos; when an invalid UTF-8 error is seen during protobuf deserialization, it uses the gogo-based protos to deserialize the message and repair the invalid UTF-8 string.
102+
103+
Code generation is used to generate functions to handle all possible cases:
104+
105+
* `make generate-rpcwrappers` re-generates type conversion functions
106+
* `make genvisitor` re-generates the invalid UTF-8 repair functions
107+
108+
Both of these code generation tools are based on protobuf reflection.
109+
89110
## License
90111

91112
MIT License, please see [LICENSE](LICENSE) for details.

Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ lint:
4747
bench:
4848
@go test -run '^$$' -benchmem -bench=. ./... $(BENCH_ARG)
4949

50+
.PHONY: genvisitor
# Extra flags passed to the generator, e.g. `make genvisitor GENVISITOR_FLAGS="-debug -dump-tree"`.
GENVISITOR_FLAGS ?= # -debug -dump-tree
# Regenerate the invalid UTF-8 repair functions, then format the result.
genvisitor:
	go run ./cmd/tools/genvisitor/ $(GENVISITOR_FLAGS) > proto/compat/repair_utf8_gen.go
	go fmt proto/compat/repair_utf8_gen.go
	# Use $(MAKE) for the recursive call so flags and the jobserver are inherited.
	$(MAKE) fmt
56+
5057
# Mocks
5158
clean-mocks:
5259
@find . -name '*_mock.go' -delete

cmd/tools/genrpcwrappers/extra.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ import (
108108
"github.com/temporalio/s2s-proxy/common"
109109
)
110110
111-
// {{.ServiceNameTitle}}ConvertTo122 accepts a protobuf type and returns
111+
// {{.ServiceName}}ConvertTo122 accepts a protobuf type and returns
112112
// the corresponding gogo-based protobuf type from Temporal v1.22.
113-
func {{.ServiceNameTitle}}ConvertTo122(vAny any) (common.Marshaler, bool) {
113+
func {{.ServiceName}}ConvertTo122(vAny any) (common.Marshaler, bool) {
114114
switch vAny.(type) {
115115
`)
116116

cmd/tools/genvisitor/emitter.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package main
2+
3+
import (
	"fmt"
	"io"
	"sort"

	"go.temporal.io/server/common/log"
	"go.temporal.io/server/common/log/tag"
	"google.golang.org/protobuf/reflect/protoreflect"
)
11+
12+
const (
13+
// CurrentVersion means to emit code for the current version of protos.
14+
CurrentVersion Mode = iota
15+
16+
// Gogo122Version means to emit code for the older gogo-based protos
17+
// from Temporal v1.22. The way this works is to walk type hierarchies
18+
// for current version of protos, but convert the current protos back
19+
// to the corresponding gogo-based types/packages.
20+
Gogo122Version
21+
)
22+
23+
// Values NewEmitter starts with until they are overridden via the setters.
const (
	// defaultFuncSignature is the signature line of the generated function.
	defaultFuncSignature = "func VisitMessage(vAny any)"
	// defaultPackageName is the package clause of the generated file.
	defaultPackageName = "main_test"
)
27+
28+
type (
29+
Emitter struct {
30+
logger log.Logger
31+
mode Mode
32+
packageName string
33+
funcSignature string
34+
funcTrailer string
35+
handlers []*Handler
36+
imports map[string]struct{}
37+
extraImports map[string]struct{}
38+
root *Tree
39+
inScopeVars map[string]struct{}
40+
}
41+
42+
// Handler matches a field in the type hierarchy to a function that generates code.
43+
Handler struct {
44+
// Include returns whether to include this path during code generation.
45+
Include func(VisitType, VisitPath) bool
46+
// Invocation returns a snippet of generated code.
47+
// It is passed a variable name that can be used for code generation.
48+
Invocation func(string) string
49+
}
50+
51+
Mode int
52+
)
53+
54+
func NewEmitter(logger log.Logger, mode Mode) *Emitter {
55+
return &Emitter{
56+
logger: logger,
57+
mode: mode,
58+
packageName: defaultPackageName,
59+
funcSignature: defaultFuncSignature,
60+
imports: make(map[string]struct{}),
61+
extraImports: make(map[string]struct{}),
62+
root: NewTree(),
63+
inScopeVars: map[string]struct{}{},
64+
}
65+
}
66+
67+
// SetPackageName overrides the package clause written to the generated file.
func (e *Emitter) SetPackageName(name string) {
	e.packageName = name
}
68+
// SetFunctionSignature overrides the signature line of the generated function.
func (e *Emitter) SetFunctionSignature(sig string) {
	e.funcSignature = sig
}
69+
// SetFunctionTrailer sets the text written just before the closing brace of
// the generated function.
func (e *Emitter) SetFunctionTrailer(trailer string) {
	e.funcTrailer = trailer
}
70+
71+
func (e *Emitter) AddHandler(include func(vt VisitType, path VisitPath) bool, invocation func(string) string) {
72+
e.handlers = append(e.handlers, &Handler{
73+
Include: include,
74+
Invocation: invocation,
75+
})
76+
}
77+
78+
func (e *Emitter) AddImport(s string) {
79+
e.extraImports[s] = struct{}{}
80+
}
81+
82+
func (e *Emitter) Visit(mt protoreflect.MessageType) {
83+
Visit(mt.Descriptor(), e.visit)
84+
}
85+
86+
func (e *Emitter) visit(obj VisitType, path VisitPath) bool {
87+
if e.mode == Gogo122Version && shouldIgnoreTypeIfDoesntExistIn122(obj.Descriptor) {
88+
return false
89+
}
90+
91+
e.logger.Debug("Emitter.visit",
92+
tag.NewStringTag("obj", string(obj.FullName())),
93+
tag.NewStringTag("path", path.String()),
94+
)
95+
for _, handler := range e.handlers {
96+
if handler.Include(obj, path) {
97+
pathCopy := make(VisitPath, len(path))
98+
copy(pathCopy, path) // path is reused during the visitor / changes as it goes.
99+
e.root.Insert(pathCopy, handler)
100+
e.discoverImports(pathCopy)
101+
}
102+
}
103+
return true
104+
}
105+
106+
func (e *Emitter) discoverImports(path VisitPath) {
107+
for _, obj := range path {
108+
e.imports[obj.GoImportPath()] = struct{}{}
109+
}
110+
}
111+
112+
func (e *Emitter) Generate(out io.Writer) {
113+
e.genPreamble(out)
114+
115+
writef(out, "%s {\n", e.funcSignature)
116+
writeln(out, "switch root := vAny.(type) {")
117+
for _, typ := range e.root.SortedTypes() {
118+
writef(out, "case *%s:\n", typ.GoQualifiedName())
119+
if child := e.root.Children[typ.GoName()]; child != nil {
120+
e.emit(out, "root", child)
121+
}
122+
}
123+
writeln(out, "}")
124+
writeln(out, e.funcTrailer)
125+
writeln(out, "}")
126+
}
127+
128+
func (e *Emitter) genPreamble(out io.Writer) {
129+
writeln(out, `// Code generated by cmd/tools/genvisitor. DO NOT EDIT.`)
130+
writef(out, "package %s\n", e.packageName)
131+
writeln(out, "import (")
132+
133+
for imp := range e.imports {
134+
alias := getImportAlias(imp)
135+
if e.mode == Gogo122Version {
136+
imp = replaceWith122Import(imp)
137+
}
138+
writef(out, "%s \"%s\"\n", alias, imp)
139+
}
140+
141+
for imp := range e.extraImports {
142+
writef(out, "\"%s\"\n", imp)
143+
144+
}
145+
writeln(out, `)`)
146+
}
147+
148+
func (e *Emitter) emit(out io.Writer, parentVar string, node *Tree) {
149+
if node == nil {
150+
return
151+
}
152+
153+
for _, vt := range node.SortedTypes() {
154+
switch desc := vt.Descriptor.(type) {
155+
case protoreflect.FieldDescriptor:
156+
if desc.IsMap() {
157+
varName, freeVar := e.makeVar("val")
158+
defer freeVar()
159+
writef(out, "for _, %s := range %s.%s {\n", varName, parentVar, vt.GoGetter())
160+
e.emit(out, varName, node.Children[vt.GoName()])
161+
writeln(out, "}")
162+
} else if desc.IsList() {
163+
varName, freeVar := e.makeVar("item")
164+
defer freeVar()
165+
writef(out, "for _, %s := range %s.%s {\n", varName, parentVar, vt.GoGetter())
166+
e.emit(out, varName, node.Children[vt.GoName()])
167+
writeln(out, "}")
168+
} else {
169+
varName, freeVar := e.makeVar("y")
170+
defer freeVar()
171+
writef(out, "%s := %s.%s\n", varName, parentVar, vt.GoGetter())
172+
e.emit(out, varName, node.Children[vt.GoName()])
173+
}
174+
case protoreflect.OneofDescriptor:
175+
writef(out, "switch oneof := %s.%s.(type) {\n", parentVar, vt.GoGetter())
176+
e.emitOneOfCases(out, "oneof", vt, node.Children[vt.GoName()])
177+
writeln(out, "}")
178+
default:
179+
e.emit(out, parentVar, node.Children[vt.GoName()])
180+
}
181+
}
182+
183+
for _, handler := range node.Handlers {
184+
writeln(out, handler.Invocation(parentVar))
185+
}
186+
}
187+
188+
func (e *Emitter) emitOneOfCases(out io.Writer, parentVar string, oneof VisitType, node *Tree) {
189+
for _, vt := range node.SortedTypes() {
190+
writef(out, "case *%s.%s:\n", oneof.GoPackageName(), getOneofWrapperType(oneof, vt))
191+
varName, freeVar := e.makeVar("x")
192+
name := vt.GoName()
193+
writef(out, "%s := %s.%s\n", varName, parentVar, name)
194+
e.emit(out, varName, node.Children[vt.GoName()])
195+
freeVar()
196+
}
197+
}
198+
199+
func (e *Emitter) makeVar(name string) (string, func()) {
200+
i := 0
201+
for {
202+
i++
203+
name := fmt.Sprintf("%s%d", name, i)
204+
if _, ok := e.inScopeVars[name]; !ok {
205+
e.inScopeVars[name] = struct{}{}
206+
return name, func() { e.freeVar(name) }
207+
}
208+
}
209+
}
210+
211+
func (e *Emitter) freeVar(name string) {
212+
delete(e.inScopeVars, name)
213+
}
214+
215+
// Return the "wrapper" Golang interface for `oneof` fields.
216+
//
217+
// Protobuf `oneof` fields are generated as an interface:
218+
//
219+
// type ReplicationTask struct {
220+
// Attributes isReplicationTask_Attributes `protobuf_oneof:"attributes"`
221+
// ...
222+
// }
223+
//
224+
// The interface is implemented by "wrapper" types which seemingly do not appear
225+
// in the protobuf reflection registry, so we do not enounter these "wrapper"
226+
// type names while visiting the protobuf type hierachy.
227+
//
228+
// type ReplicationTask_SyncVersionedTransitionTaskAttributes struct {
229+
// SyncVersionedTransitionTaskAttributes *SyncVersionedTransitionTaskAttributes
230+
// }
231+
//
232+
// This returns the implementing type, e.g. "ReplicationTask_SyncVersionedTransitionTaskAttributes",
233+
// given the interface field (e.g. `Attributes`) and the wrapped field (e.g. `SyncVersionedTransitionTaskAttributes`)
234+
func getOneofWrapperType(oneof, typ VisitType) string {
235+
return string(oneof.Parent().Name()) + "_" + snakeToPascalCase(typ.Name())
236+
}

cmd/tools/genvisitor/main.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"os"
7+
"strings"
8+
9+
// Import to populate the protoregistry
10+
_ "go.temporal.io/api/workflowservice/v1"
11+
_ "go.temporal.io/server/api/adminservice/v1"
12+
"go.temporal.io/server/common/log"
13+
"go.temporal.io/server/common/log/tag"
14+
"google.golang.org/protobuf/reflect/protoreflect"
15+
"google.golang.org/protobuf/reflect/protoregistry"
16+
)
17+
18+
func main() {
19+
debugFlag := flag.Bool("debug", false, "enable debug logs to stderr")
20+
dumpTree := flag.Bool("dump-tree", false, "print the tree of matched paths in the type hierarchy to stderr")
21+
flag.Parse()
22+
23+
var logger log.Logger
24+
if *debugFlag {
25+
logger = log.NewCLILogger()
26+
} else {
27+
logger = log.NewNoopLogger()
28+
}
29+
30+
emitter := NewEmitter(logger, Gogo122Version)
31+
emitter.SetPackageName("compat")
32+
emitter.SetFunctionSignature(
33+
`func repairInvalidUTF8(vAny any) (ret bool, retErr error)`,
34+
)
35+
emitter.SetFunctionTrailer("return")
36+
emitter.AddHandler(
37+
// Match any type called "Failure"
38+
func(vt VisitType, path VisitPath) bool {
39+
// Match Failure types
40+
if vt.GoTypeName() != "Failure" {
41+
logger.Debug("ignore non Failure field", tag.NewAnyTag("path", path.String()))
42+
return false
43+
}
44+
// Skip nested "Cause" field in Failure types. The repairInvalidUTF8InFailure handler function
45+
// will descend into these.
46+
if strings.Contains(path.String(), "/Cause") {
47+
logger.Debug("ignore failure Cause", tag.NewAnyTag("path", path.String()))
48+
return false
49+
}
50+
return true
51+
},
52+
// Generate code to handle the Failure field
53+
func(varName string) string {
54+
return fmt.Sprintf(`if changed, err := repairInvalidUTF8InFailure(%s); err != nil || changed {
55+
ret = ret || changed
56+
if err != nil {
57+
retErr = err
58+
}
59+
}`, varName)
60+
},
61+
)
62+
63+
// We traverse the current version of protobuf types (not the gogo-based protos)
64+
// because protoreflect only works with the current version of protobuf types.
65+
// The emitter can translate back to gogo-based types, if it is configured with
66+
// Mode=Gogo122Version.
67+
protoregistry.GlobalTypes.RangeMessages(func(mt protoreflect.MessageType) bool {
68+
emitter.Visit(mt)
69+
return true
70+
})
71+
if *dumpTree {
72+
emitter.root.Dump(os.Stderr)
73+
}
74+
emitter.Generate(os.Stdout)
75+
}

0 commit comments

Comments
 (0)