Skip to content

Commit

Permalink
feat: use resource reference and field behavior annotations
Browse files Browse the repository at this point in the history
This was a regression from the previous implementation - brings things
up to par.
  • Loading branch information
odsod committed Jul 27, 2022
1 parent 136dcc1 commit 54ef1e0
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 33 deletions.
176 changes: 154 additions & 22 deletions aipcli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stoewer/go-strcase"
"go.einride.tech/aip/reflect/aipreflect"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)

// NewServiceCommand initializes a new *cobra.Command for the provided gRPC service.
Expand Down Expand Up @@ -44,7 +47,7 @@ func NewMethodCommand(
}
fromFile := cmd.Flags().StringP("from-file", "f", "", "path to a JSON file containing the request payload")
_ = cmd.MarkFlagFilename("from-file", "json")
setFlags(comments, cmd.Flags(), nil, in.ProtoReflect().Descriptor(), in.ProtoReflect)
setFlags(comments, cmd, nil, in.ProtoReflect().Descriptor(), in.ProtoReflect)
cmd.RunE = func(cmd *cobra.Command, args []string) error {
if cmd.Flags().Changed("from-file") {
data, err := os.ReadFile(*fromFile)
Expand All @@ -55,7 +58,7 @@ func NewMethodCommand(
return err
}
}
conn, err := Dial(cmd.Context())
conn, err := dial(cmd.Context())
if err != nil {
return err
}
Expand Down Expand Up @@ -91,7 +94,7 @@ func methodURI(method protoreflect.MethodDescriptor) string {

func setFlags(
comments map[protoreflect.FullName]string,
flags *pflag.FlagSet,
cmd *cobra.Command,
parentFields []protoreflect.FieldDescriptor,
msg protoreflect.MessageDescriptor,
mutable func() protoreflect.Message,
Expand All @@ -105,37 +108,48 @@ func setFlags(
if field.IsList() {
// TODO: Implement support for repeated durations.
} else {
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: durationValue{mutable: mutable, field: field},
addFlag(cmd, field, parentFields, comments[field.FullName()], durationValue{
mutable: mutable,
field: field,
})
}
case "google.protobuf.Timestamp":
if field.IsList() {
// TODO: Implement support for repeated timestamps.
} else {
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: timestampValue{mutable: mutable, field: field},
addFlag(cmd, field, parentFields, comments[field.FullName()], timestampValue{
mutable: mutable,
field: field,
})
}
case "google.protobuf.FieldMask":
if field.IsList() {
// Repeated field masks is intentionally not supported.
} else {
flags.AddFlag(&pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Value: fieldMaskValue{mutable: mutable, field: field},
addFlag(cmd, field, parentFields, comments[field.FullName()], fieldMaskValue{
mutable: mutable,
field: field,
})
}
default:
if field.Cardinality() != protoreflect.Repeated {
switch {
case field.IsMap():
switch {
case field.MapKey().Kind() == protoreflect.StringKind &&
field.MapValue().Kind() == protoreflect.StringKind:
addFlag(cmd, field, parentFields, comments[field.FullName()], mapStringStringValue{
mutable: mutable,
field: field,
})
default:
// TODO: Implement support for more map types.
}
case field.IsList():
// Repeated nested messages not supported.
default:
setFlags(
comments,
flags,
cmd,
append(parentFields, field),
field.Message(),
func() protoreflect.Message {
Expand All @@ -144,16 +158,25 @@ func setFlags(
)
}
}
case protoreflect.EnumKind:
if field.IsList() {
// TODO: Implement support for repeated enums.
} else {
addFlag(cmd, field, parentFields, comments[field.FullName()], enumValue{
mutable: mutable,
field: field,
})
}
case protoreflect.StringKind, protoreflect.BoolKind, protoreflect.BytesKind, protoreflect.DoubleKind,
protoreflect.FloatKind, protoreflect.Int64Kind, protoreflect.Int32Kind:
setPrimitiveFlag(comments, flags, parentFields, mutable, field)
setPrimitiveFlag(comments, cmd, parentFields, mutable, field)
}
}
}

func setPrimitiveFlag(
comments map[protoreflect.FullName]string,
flags *pflag.FlagSet,
cmd *cobra.Command,
parentFields []protoreflect.FieldDescriptor,
mutable func() protoreflect.Message,
field protoreflect.FieldDescriptor,
Expand Down Expand Up @@ -224,11 +247,120 @@ func setPrimitiveFlag(
default:
panic(fmt.Errorf("unhandled primitive kind: %v", field.Kind())) // shouldn't happen
}
flags.AddFlag(&pflag.Flag{
addFlag(cmd, field, parentFields, comments[field.FullName()], value)
}

func addFlag(
cmd *cobra.Command,
field protoreflect.FieldDescriptor,
parentFields []protoreflect.FieldDescriptor,
comment string,
value pflag.Value,
) {
flag := &pflag.Flag{
Name: flagName(field, parentFields),
Usage: flagUsage(comments[field.FullName()]),
Usage: trimComment(comment),
Value: value,
})
}
cmd.Flags().AddFlag(flag)
maybeMarkHidden(cmd, flag, field)
maybeMarkRequired(cmd, flag, field)
maybeRegisterResourceReferenceCompletionFunction(cmd, flag, field)
maybeRegisterResourceNameCompletionFunction(cmd, flag, field)
}

func maybeMarkHidden(
cmd *cobra.Command,
flag *pflag.Flag,
field protoreflect.FieldDescriptor,
) {
if fieldBehaviors, ok := proto.GetExtension(
field.Options(),
annotations.E_FieldBehavior,
).([]annotations.FieldBehavior); ok {
for _, fieldBehavior := range fieldBehaviors {
if fieldBehavior == annotations.FieldBehavior_OUTPUT_ONLY {
_ = cmd.Flags().MarkHidden(flag.Name)
}
}
}
}

func maybeMarkRequired(
cmd *cobra.Command,
flag *pflag.Flag,
field protoreflect.FieldDescriptor,
) {
if fieldBehaviors, ok := proto.GetExtension(
field.Options(),
annotations.E_FieldBehavior,
).([]annotations.FieldBehavior); ok {
for _, fieldBehavior := range fieldBehaviors {
if fieldBehavior == annotations.FieldBehavior_REQUIRED {
_ = cmd.MarkFlagRequired(flag.Name)
}
}
}
}

func maybeRegisterResourceReferenceCompletionFunction(
cmd *cobra.Command,
flag *pflag.Flag,
field protoreflect.FieldDescriptor,
) {
if field.Kind() == protoreflect.StringKind {
if resourceReference, ok := proto.GetExtension(
field.Options(),
annotations.E_ResourceReference,
).(*annotations.ResourceReference); ok && resourceReference.GetType() != "" {
completionFunc := resourceNameCompletionFunc
if field.IsList() {
completionFunc = resourceNameListCompletionFunc
}
aipreflect.RangeResourceDescriptorsInPackage(
protoregistry.GlobalFiles,
field.ParentFile().Package(),
func(resource *annotations.ResourceDescriptor) bool {
if resource.GetType() == resourceReference.GetType() && len(resource.GetPattern()) > 0 {
_ = cmd.RegisterFlagCompletionFunc(
flag.Name,
completionFunc(resource.GetPattern()...),
)
return false
}
return true
},
)
}
}
}

func maybeRegisterResourceNameCompletionFunction(
cmd *cobra.Command,
flag *pflag.Flag,
field protoreflect.FieldDescriptor,
) {
if !field.IsList() && field.Name() == "name" {
if resourceDescriptor, ok := proto.GetExtension(
field.Parent().Options(),
annotations.E_Resource,
).(*annotations.ResourceDescriptor); ok && resourceDescriptor.GetType() != "" {
aipreflect.RangeResourceDescriptorsInPackage(
protoregistry.GlobalFiles,
field.ParentFile().Package(),
func(resource *annotations.ResourceDescriptor) bool {
if resource.GetType() == resourceDescriptor.GetType() && len(resource.GetPattern()) > 0 {
_ = cmd.RegisterFlagCompletionFunc(
flag.Name,
resourceNameCompletionFunc(resource.GetPattern()...),
)
return false
}
return true
},
)
}
}
}

func trimComment(comment string) string {
Expand Down
12 changes: 7 additions & 5 deletions aipcli/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,27 @@ import (

type CompletionFunc func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective)

func ResourceNameCompletionFunc(patterns ...string) CompletionFunc {
func resourceNameCompletionFunc(patterns ...string) CompletionFunc {
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
result := make([]string, 0, len(patterns))
for _, pattern := range patterns {
if completion, ok := CompleteResourceName(toComplete, pattern); ok {
result = cobra.AppendActiveHelp(result, fmt.Sprintf("pattern: %s", pattern))
if completion, ok := completeResourceName(toComplete, pattern); ok {
result = append(result, fmt.Sprintf("%s\t%s", completion, pattern))
}
}
return result, cobra.ShellCompDirectiveNoSpace
}
}

func ResourceNameListCompletionFunc(patterns ...string) CompletionFunc {
func resourceNameListCompletionFunc(patterns ...string) CompletionFunc {
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
toCompleteElements := strings.Split(toComplete, ",")
lastToCompleteElement := toCompleteElements[len(toCompleteElements)-1]
result := make([]string, 0, len(patterns))
for _, pattern := range patterns {
if elementCompletion, ok := CompleteResourceName(lastToCompleteElement, pattern); ok {
result = cobra.AppendActiveHelp(result, fmt.Sprintf("pattern: %s", pattern))
if elementCompletion, ok := completeResourceName(lastToCompleteElement, pattern); ok {
var completion string
if len(toCompleteElements) > 1 {
completion = strings.Join(
Expand All @@ -45,7 +47,7 @@ func ResourceNameListCompletionFunc(patterns ...string) CompletionFunc {
}
}

func CompleteResourceName(toComplete, pattern string) (string, bool) {
func completeResourceName(toComplete, pattern string) (string, bool) {
toCompleteSegments := strings.Split(toComplete, "/")
patternSegments := strings.Split(pattern, "/")
if len(toCompleteSegments) > len(patternSegments) {
Expand Down
2 changes: 1 addition & 1 deletion aipcli/completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestCompleteResourceName(t *testing.T) {
},
} {
t.Run(tt.name, func(t *testing.T) {
actual, ok := CompleteResourceName(tt.toComplete, tt.pattern)
actual, ok := completeResourceName(tt.toComplete, tt.pattern)
assert.Equal(t, tt.ok, ok)
assert.Equal(t, tt.completion, actual)
})
Expand Down
2 changes: 1 addition & 1 deletion aipcli/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"google.golang.org/grpc/credentials/oauth"
)

func Dial(ctx context.Context) (*grpc.ClientConn, error) {
func dial(ctx context.Context) (*grpc.ClientConn, error) {
config := ConfigFromContext(ctx)
if config.Runtime.Insecure {
return dialInsecure(ctx, config)
Expand Down
56 changes: 52 additions & 4 deletions aipcli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ func flagName(field protoreflect.FieldDescriptor, parentFields []protoreflect.Fi
return strings.ReplaceAll(result.String(), "_", "-")
}

func flagUsage(comment string) string {
return trimComment(comment)
}

func newPrimitiveValue[T any](
mutable func() protoreflect.Message,
field protoreflect.FieldDescriptor,
Expand Down Expand Up @@ -228,3 +224,55 @@ func (v fieldMaskValue) Set(s string) error {
v.mutable().Set(v.field, protoreflect.ValueOf(fieldMask.ProtoReflect()))
return nil
}

type enumValue struct {
mutable func() protoreflect.Message
field protoreflect.FieldDescriptor
}

func (v enumValue) String() string {
return ""
}

func (v enumValue) Type() string {
return "enum[" + string(v.field.Enum().Name()) + "]"
}

func (v enumValue) Set(s string) error {
value := v.field.Enum().Values().ByName(protoreflect.Name(s))
if value == nil {
return fmt.Errorf("no such value for %v: %v", v.field.Enum().Name(), s)
}
v.mutable().Set(v.field, protoreflect.ValueOfEnum(value.Number()))
return nil
}

type mapStringStringValue struct {
mutable func() protoreflect.Message
field protoreflect.FieldDescriptor
}

func (v mapStringStringValue) String() string {
return ""
}

func (v mapStringStringValue) Type() string {
return "map<string, string>"
}

func (v mapStringStringValue) Set(s string) error {
pairs := strings.Split(s, ",")
if len(pairs) == 0 {
return nil
}
value := v.mutable().NewField(v.field)
for _, pair := range pairs {
keyValue := strings.SplitN(pair, "=", 2)
if len(keyValue) != 2 {
return fmt.Errorf("invalid map pair: %s", pair)
}
value.Map().Set(protoreflect.ValueOfString(keyValue[0]).MapKey(), protoreflect.ValueOfString(keyValue[1]))
}
v.mutable().Set(v.field, value)
return nil
}

0 comments on commit 54ef1e0

Please sign in to comment.