Skip to content

Commit 136dcc1

Browse files
committed
feat: simplify implementation by using runtime reflection
Drastically less generated code required when using the runtime descriptors to bind request fields to flags dynamically. No longer requires the gRPC generated code. As an added bonus, timestamp flags now support CEL expressions when prefixed by `=`, making it easier to express e.g. `=now()-duration("2h")`.
1 parent a8af073 commit 136dcc1

34 files changed

+712
-3100
lines changed

.sage/proto.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"go.einride.tech/sage/sg"
77
"go.einride.tech/sage/sgtool"
88
"go.einride.tech/sage/tools/sgbuf"
9-
"go.einride.tech/sage/tools/sgprotocgengogrpc"
109
)
1110

1211
type Proto sg.Namespace
@@ -57,7 +56,7 @@ func (Proto) ProtocGenGoAIPCLI(ctx context.Context) error {
5756
}
5857

5958
func (Proto) BufGenerateExample(ctx context.Context) error {
60-
sg.Deps(ctx, Proto.ProtocGenGo, sgprotocgengogrpc.PrepareCommand, Proto.ProtocGenGoAIPCLI)
59+
sg.Deps(ctx, Proto.ProtocGenGo, Proto.ProtocGenGoAIPCLI)
6160
sg.Logger(ctx).Println("generating example proto stubs...")
6261
cmd := sgbuf.Command(
6362
ctx,

README.md

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
AIP CLI
2-
=======
1+
AIP CLI Go
2+
==========
33

4-
Generate command line interfaces to your AIP gRPC services.
4+
Generate command line interfaces to your [AIP](https://aip.dev) gRPC services.
55

66
How to
77
------
@@ -34,11 +34,6 @@ plugins:
3434
out: cmd/examplectl
3535
opt: module=go.einride.tech/aip-cli/cmd/examplectl
3636

37-
# The CLI generator also requires the stubs generated by protoc-gen-go-grpc.
38-
- name: go-grpc
39-
out: cmd/examplectl
40-
opt: module=go.einride.tech/aip-cli/cmd/examplectl
41-
4237
# The CLI generator optionally generates a root command and a main file
4338
# to the root of the output module.
4439
- name: go-aip-cli

aipcli/command.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
package aipcli
2+
3+
import (
4+
"encoding/base64"
5+
"fmt"
6+
"os"
7+
"strconv"
8+
"strings"
9+
"unicode"
10+
"unicode/utf8"
11+
12+
"github.com/spf13/cobra"
13+
"github.com/spf13/pflag"
14+
"github.com/stoewer/go-strcase"
15+
"google.golang.org/protobuf/encoding/protojson"
16+
"google.golang.org/protobuf/proto"
17+
"google.golang.org/protobuf/reflect/protoreflect"
18+
)
19+
20+
// NewServiceCommand initializes a new *cobra.Command for the provided gRPC service.
21+
func NewServiceCommand(
22+
service protoreflect.ServiceDescriptor,
23+
comments map[protoreflect.FullName]string,
24+
) *cobra.Command {
25+
cmd := &cobra.Command{
26+
Use: serviceUse(service),
27+
Short: initialUpperCase(trimComment(comments[service.FullName()])),
28+
Long: comments[service.FullName()],
29+
}
30+
return cmd
31+
}
32+
33+
// NewMethodCommand initializes a new *cobra.Command for the provided gRPC method.
34+
func NewMethodCommand(
35+
method protoreflect.MethodDescriptor,
36+
in proto.Message,
37+
out proto.Message,
38+
comments map[protoreflect.FullName]string,
39+
) *cobra.Command {
40+
cmd := &cobra.Command{
41+
Use: methodUse(method),
42+
Short: initialUpperCase(trimComment(comments[method.FullName()])),
43+
Long: comments[method.FullName()],
44+
}
45+
fromFile := cmd.Flags().StringP("from-file", "f", "", "path to a JSON file containing the request payload")
46+
_ = cmd.MarkFlagFilename("from-file", "json")
47+
setFlags(comments, cmd.Flags(), nil, in.ProtoReflect().Descriptor(), in.ProtoReflect)
48+
cmd.RunE = func(cmd *cobra.Command, args []string) error {
49+
if cmd.Flags().Changed("from-file") {
50+
data, err := os.ReadFile(*fromFile)
51+
if err != nil {
52+
return err
53+
}
54+
if err := protojson.Unmarshal(data, in); err != nil {
55+
return err
56+
}
57+
}
58+
conn, err := Dial(cmd.Context())
59+
if err != nil {
60+
return err
61+
}
62+
LogRequest(cmd.Context(), in)
63+
if err := conn.Invoke(cmd.Context(), methodURI(method), in, out); err != nil {
64+
LogError(cmd.Context(), err)
65+
os.Exit(1)
66+
}
67+
LogResponse(cmd.Context(), out)
68+
return nil
69+
}
70+
return cmd
71+
}
72+
73+
func serviceUse(service protoreflect.ServiceDescriptor) string {
74+
result := string(service.Name())
75+
result = strings.TrimSuffix(result, "Service")
76+
result = strcase.KebabCase(result)
77+
return result
78+
}
79+
80+
func methodUse(method protoreflect.MethodDescriptor) string {
81+
result := string(method.Name())
82+
result = strcase.KebabCase(result)
83+
return result
84+
}
85+
86+
func methodURI(method protoreflect.MethodDescriptor) string {
87+
return "/" +
88+
string(method.Parent().(protoreflect.ServiceDescriptor).FullName()) +
89+
"/" + string(method.Name())
90+
}
91+
92+
func setFlags(
93+
comments map[protoreflect.FullName]string,
94+
flags *pflag.FlagSet,
95+
parentFields []protoreflect.FieldDescriptor,
96+
msg protoreflect.MessageDescriptor,
97+
mutable func() protoreflect.Message,
98+
) {
99+
for i := 0; i < msg.Fields().Len(); i++ {
100+
field := msg.Fields().Get(i)
101+
switch field.Kind() {
102+
case protoreflect.MessageKind:
103+
switch field.Message().FullName() {
104+
case "google.protobuf.Duration":
105+
if field.IsList() {
106+
// TODO: Implement support for repeated durations.
107+
} else {
108+
flags.AddFlag(&pflag.Flag{
109+
Name: flagName(field, parentFields),
110+
Usage: flagUsage(comments[field.FullName()]),
111+
Value: durationValue{mutable: mutable, field: field},
112+
})
113+
}
114+
case "google.protobuf.Timestamp":
115+
if field.IsList() {
116+
// TODO: Implement support for repeated timestamps.
117+
} else {
118+
flags.AddFlag(&pflag.Flag{
119+
Name: flagName(field, parentFields),
120+
Usage: flagUsage(comments[field.FullName()]),
121+
Value: timestampValue{mutable: mutable, field: field},
122+
})
123+
}
124+
case "google.protobuf.FieldMask":
125+
if field.IsList() {
126+
// Repeated field masks is intentionally not supported.
127+
} else {
128+
flags.AddFlag(&pflag.Flag{
129+
Name: flagName(field, parentFields),
130+
Usage: flagUsage(comments[field.FullName()]),
131+
Value: fieldMaskValue{mutable: mutable, field: field},
132+
})
133+
}
134+
default:
135+
if field.Cardinality() != protoreflect.Repeated {
136+
setFlags(
137+
comments,
138+
flags,
139+
append(parentFields, field),
140+
field.Message(),
141+
func() protoreflect.Message {
142+
return mutable().Mutable(field).Message()
143+
},
144+
)
145+
}
146+
}
147+
case protoreflect.StringKind, protoreflect.BoolKind, protoreflect.BytesKind, protoreflect.DoubleKind,
148+
protoreflect.FloatKind, protoreflect.Int64Kind, protoreflect.Int32Kind:
149+
setPrimitiveFlag(comments, flags, parentFields, mutable, field)
150+
}
151+
}
152+
}
153+
154+
func setPrimitiveFlag(
155+
comments map[protoreflect.FullName]string,
156+
flags *pflag.FlagSet,
157+
parentFields []protoreflect.FieldDescriptor,
158+
mutable func() protoreflect.Message,
159+
field protoreflect.FieldDescriptor,
160+
) {
161+
var value pflag.Value
162+
switch field.Kind() {
163+
case protoreflect.BoolKind:
164+
if field.IsList() {
165+
value = newPrimitiveListValue[bool](mutable, field, protoreflect.ValueOfBool, strconv.ParseBool)
166+
} else {
167+
value = newPrimitiveValue[bool](mutable, field, protoreflect.ValueOfBool, strconv.ParseBool)
168+
}
169+
case protoreflect.StringKind:
170+
parser := func(s string) (string, error) {
171+
return s, nil
172+
}
173+
if field.IsList() {
174+
value = newPrimitiveListValue[string](mutable, field, protoreflect.ValueOfString, parser)
175+
} else {
176+
value = newPrimitiveValue[string](mutable, field, protoreflect.ValueOfString, parser)
177+
}
178+
case protoreflect.BytesKind:
179+
value = newPrimitiveValue[[]byte](mutable, field, protoreflect.ValueOfBytes, base64.URLEncoding.DecodeString)
180+
case protoreflect.DoubleKind:
181+
parser := func(s string) (float64, error) {
182+
return strconv.ParseFloat(s, 64)
183+
}
184+
if field.IsList() {
185+
value = newPrimitiveListValue[float64](mutable, field, protoreflect.ValueOfFloat64, parser)
186+
} else {
187+
value = newPrimitiveValue[float64](mutable, field, protoreflect.ValueOfFloat64, parser)
188+
}
189+
case protoreflect.FloatKind:
190+
parser := func(s string) (float32, error) {
191+
d, err := strconv.ParseFloat(s, 32)
192+
if err != nil {
193+
return 0, err
194+
}
195+
return float32(d), nil
196+
}
197+
if field.IsList() {
198+
value = newPrimitiveListValue[float32](mutable, field, protoreflect.ValueOfFloat32, parser)
199+
} else {
200+
value = newPrimitiveValue[float32](mutable, field, protoreflect.ValueOfFloat32, parser)
201+
}
202+
case protoreflect.Int64Kind:
203+
parser := func(s string) (int64, error) {
204+
return strconv.ParseInt(s, 10, 64)
205+
}
206+
if field.IsList() {
207+
value = newPrimitiveListValue[int64](mutable, field, protoreflect.ValueOfInt64, parser)
208+
} else {
209+
value = newPrimitiveValue[int64](mutable, field, protoreflect.ValueOfInt64, parser)
210+
}
211+
case protoreflect.Int32Kind:
212+
parser := func(s string) (int32, error) {
213+
i64, err := strconv.ParseInt(s, 10, 32)
214+
if err != nil {
215+
return 0, err
216+
}
217+
return int32(i64), nil
218+
}
219+
if field.IsList() {
220+
value = newPrimitiveListValue[int32](mutable, field, protoreflect.ValueOfInt32, parser)
221+
} else {
222+
value = newPrimitiveValue[int32](mutable, field, protoreflect.ValueOfInt32, parser)
223+
}
224+
default:
225+
panic(fmt.Errorf("unhandled primitive kind: %v", field.Kind())) // shouldn't happen
226+
}
227+
flags.AddFlag(&pflag.Flag{
228+
Name: flagName(field, parentFields),
229+
Usage: flagUsage(comments[field.FullName()]),
230+
Value: value,
231+
})
232+
}
233+
234+
func trimComment(comment string) string {
235+
result := comment
236+
// Clean up comment line breaks and whitespace.
237+
result = strings.ReplaceAll(result, "//", "")
238+
result = strings.ReplaceAll(result, "\n", " ")
239+
result = strings.TrimSpace(result)
240+
result = strings.ReplaceAll(result, " ", " ")
241+
result = strings.ReplaceAll(result, " ", " ")
242+
// Cut out first sentence.
243+
if i := strings.IndexByte(result, '.'); i != -1 {
244+
result = result[:i]
245+
}
246+
// Trim manually documented field behavior.
247+
result = strings.TrimPrefix(result, "REQUIRED: ")
248+
result = strings.TrimPrefix(result, "Required: ")
249+
result = strings.ToLower(result)
250+
return result
251+
}
252+
253+
func initialUpperCase(s string) string {
254+
r, size := utf8.DecodeRuneInString(s)
255+
if size == utf8.RuneError {
256+
return s
257+
}
258+
return string(unicode.ToUpper(r)) + s[size:]
259+
}

0 commit comments

Comments
 (0)