Skip to content

Commit

Permalink
feat: simplify implementation by using runtime reflection
Browse files Browse the repository at this point in the history
Drastically less generated code required when using the runtime
descriptors to bind request fields to flags dynamically.

No longer requires the gRPC generated code.

As an added bonus, timestamp flags now support CEL expressions when
prefixed by `=`, making it easier to express e.g.
`=now()-duration("2h")`.
  • Loading branch information
odsod committed Jul 27, 2022
1 parent a8af073 commit 136dcc1
Show file tree
Hide file tree
Showing 34 changed files with 712 additions and 3,100 deletions.
3 changes: 1 addition & 2 deletions .sage/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"go.einride.tech/sage/sg"
"go.einride.tech/sage/sgtool"
"go.einride.tech/sage/tools/sgbuf"
"go.einride.tech/sage/tools/sgprotocgengogrpc"
)

type Proto sg.Namespace
Expand Down Expand Up @@ -57,7 +56,7 @@ func (Proto) ProtocGenGoAIPCLI(ctx context.Context) error {
}

func (Proto) BufGenerateExample(ctx context.Context) error {
sg.Deps(ctx, Proto.ProtocGenGo, sgprotocgengogrpc.PrepareCommand, Proto.ProtocGenGoAIPCLI)
sg.Deps(ctx, Proto.ProtocGenGo, Proto.ProtocGenGoAIPCLI)
sg.Logger(ctx).Println("generating example proto stubs...")
cmd := sgbuf.Command(
ctx,
Expand Down
11 changes: 3 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
AIP CLI
=======
AIP CLI Go
==========

Generate command line interfaces to your AIP gRPC services.
Generate command line interfaces to your [AIP](https://aip.dev) gRPC services.

How to
------
Expand Down Expand Up @@ -34,11 +34,6 @@ plugins:
out: cmd/examplectl
opt: module=go.einride.tech/aip-cli/cmd/examplectl

# The CLI generator also requires the stubs generated by protoc-gen-go-grpc.
- name: go-grpc
out: cmd/examplectl
opt: module=go.einride.tech/aip-cli/cmd/examplectl

# The CLI generator optionally generates a root command and a main file
# to the root of the output module.
- name: go-aip-cli
Expand Down
259 changes: 259 additions & 0 deletions aipcli/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
package aipcli

import (
"encoding/base64"
"fmt"
"os"
"strconv"
"strings"
"unicode"
"unicode/utf8"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// NewServiceCommand initializes a new *cobra.Command for the provided gRPC service.
func NewServiceCommand(
service protoreflect.ServiceDescriptor,
comments map[protoreflect.FullName]string,
) *cobra.Command {
cmd := &cobra.Command{
Use: serviceUse(service),
Short: initialUpperCase(trimComment(comments[service.FullName()])),
Long: comments[service.FullName()],
}
return cmd
}

// NewMethodCommand initializes a new *cobra.Command for the provided gRPC method.
func NewMethodCommand(
method protoreflect.MethodDescriptor,
in proto.Message,
out proto.Message,
comments map[protoreflect.FullName]string,
) *cobra.Command {
cmd := &cobra.Command{
Use: methodUse(method),
Short: initialUpperCase(trimComment(comments[method.FullName()])),
Long: comments[method.FullName()],
}
fromFile := cmd.Flags().StringP("from-file", "f", "", "path to a JSON file containing the request payload")
_ = cmd.MarkFlagFilename("from-file", "json")
setFlags(comments, cmd.Flags(), nil, in.ProtoReflect().Descriptor(), in.ProtoReflect)
cmd.RunE = func(cmd *cobra.Command, args []string) error {
if cmd.Flags().Changed("from-file") {
data, err := os.ReadFile(*fromFile)
if err != nil {
return err
}
if err := protojson.Unmarshal(data, in); err != nil {
return err
}
}
conn, err := Dial(cmd.Context())
if err != nil {
return err
}
LogRequest(cmd.Context(), in)
if err := conn.Invoke(cmd.Context(), methodURI(method), in, out); err != nil {
LogError(cmd.Context(), err)
os.Exit(1)
}
LogResponse(cmd.Context(), out)
return nil
}
return cmd
}

func serviceUse(service protoreflect.ServiceDescriptor) string {
result := string(service.Name())
result = strings.TrimSuffix(result, "Service")
result = strcase.KebabCase(result)
return result
}

func methodUse(method protoreflect.MethodDescriptor) string {
result := string(method.Name())
result = strcase.KebabCase(result)
return result
}

func methodURI(method protoreflect.MethodDescriptor) string {
return "/" +
string(method.Parent().(protoreflect.ServiceDescriptor).FullName()) +
"/" + string(method.Name())
}

func setFlags(
comments map[protoreflect.FullName]string,
flags *pflag.FlagSet,
parentFields []protoreflect.FieldDescriptor,
msg protoreflect.MessageDescriptor,
mutable func() protoreflect.Message,
) {
for i := 0; i < msg.Fields().Len(); i++ {
field := msg.Fields().Get(i)
switch field.Kind() {
case protoreflect.MessageKind:
switch field.Message().FullName() {
case "google.protobuf.Duration":
if field.IsList() {
// TODO: Implement support for repeated durations.
} else {
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: durationValue{mutable: mutable, field: field},
})
}
case "google.protobuf.Timestamp":
if field.IsList() {
// TODO: Implement support for repeated timestamps.
} else {
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: timestampValue{mutable: mutable, field: field},
})
}
case "google.protobuf.FieldMask":
if field.IsList() {
// Repeated field masks is intentionally not supported.
} else {
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: fieldMaskValue{mutable: mutable, field: field},
})
}
default:
if field.Cardinality() != protoreflect.Repeated {
setFlags(
comments,
flags,
append(parentFields, field),
field.Message(),
func() protoreflect.Message {
return mutable().Mutable(field).Message()
},
)
}
}
case protoreflect.StringKind, protoreflect.BoolKind, protoreflect.BytesKind, protoreflect.DoubleKind,
protoreflect.FloatKind, protoreflect.Int64Kind, protoreflect.Int32Kind:
setPrimitiveFlag(comments, flags, parentFields, mutable, field)
}
}
}

func setPrimitiveFlag(
comments map[protoreflect.FullName]string,
flags *pflag.FlagSet,
parentFields []protoreflect.FieldDescriptor,
mutable func() protoreflect.Message,
field protoreflect.FieldDescriptor,
) {
var value pflag.Value
switch field.Kind() {
case protoreflect.BoolKind:
if field.IsList() {
value = newPrimitiveListValue[bool](mutable, field, protoreflect.ValueOfBool, strconv.ParseBool)
} else {
value = newPrimitiveValue[bool](mutable, field, protoreflect.ValueOfBool, strconv.ParseBool)
}
case protoreflect.StringKind:
parser := func(s string) (string, error) {
return s, nil
}
if field.IsList() {
value = newPrimitiveListValue[string](mutable, field, protoreflect.ValueOfString, parser)
} else {
value = newPrimitiveValue[string](mutable, field, protoreflect.ValueOfString, parser)
}
case protoreflect.BytesKind:
value = newPrimitiveValue[[]byte](mutable, field, protoreflect.ValueOfBytes, base64.URLEncoding.DecodeString)
case protoreflect.DoubleKind:
parser := func(s string) (float64, error) {
return strconv.ParseFloat(s, 64)
}
if field.IsList() {
value = newPrimitiveListValue[float64](mutable, field, protoreflect.ValueOfFloat64, parser)
} else {
value = newPrimitiveValue[float64](mutable, field, protoreflect.ValueOfFloat64, parser)
}
case protoreflect.FloatKind:
parser := func(s string) (float32, error) {
d, err := strconv.ParseFloat(s, 32)
if err != nil {
return 0, err
}
return float32(d), nil
}
if field.IsList() {
value = newPrimitiveListValue[float32](mutable, field, protoreflect.ValueOfFloat32, parser)
} else {
value = newPrimitiveValue[float32](mutable, field, protoreflect.ValueOfFloat32, parser)
}
case protoreflect.Int64Kind:
parser := func(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
if field.IsList() {
value = newPrimitiveListValue[int64](mutable, field, protoreflect.ValueOfInt64, parser)
} else {
value = newPrimitiveValue[int64](mutable, field, protoreflect.ValueOfInt64, parser)
}
case protoreflect.Int32Kind:
parser := func(s string) (int32, error) {
i64, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return 0, err
}
return int32(i64), nil
}
if field.IsList() {
value = newPrimitiveListValue[int32](mutable, field, protoreflect.ValueOfInt32, parser)
} else {
value = newPrimitiveValue[int32](mutable, field, protoreflect.ValueOfInt32, parser)
}
default:
panic(fmt.Errorf("unhandled primitive kind: %v", field.Kind())) // shouldn't happen
}
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: value,
})
}

func trimComment(comment string) string {
result := comment
// Clean up comment line breaks and whitespace.
result = strings.ReplaceAll(result, "//", "")
result = strings.ReplaceAll(result, "\n", " ")
result = strings.TrimSpace(result)
result = strings.ReplaceAll(result, " ", " ")
result = strings.ReplaceAll(result, " ", " ")
// Cut out first sentence.
if i := strings.IndexByte(result, '.'); i != -1 {
result = result[:i]
}
// Trim manually documented field behavior.
result = strings.TrimPrefix(result, "REQUIRED: ")
result = strings.TrimPrefix(result, "Required: ")
result = strings.ToLower(result)
return result
}

func initialUpperCase(s string) string {
r, size := utf8.DecodeRuneInString(s)
if size == utf8.RuneError {
return s
}
return string(unicode.ToUpper(r)) + s[size:]
}
Loading

0 comments on commit 136dcc1

Please sign in to comment.