Skip to content

Commit 54ef1e0

Browse files
committed
feat: use resource reference and field behavior annotations
This was a regression from the previous implementation - brings things up to par.
1 parent 136dcc1 commit 54ef1e0

File tree

5 files changed

+215
-33
lines changed

5 files changed

+215
-33
lines changed

aipcli/command.go

Lines changed: 154 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ import (
1212
"github.com/spf13/cobra"
1313
"github.com/spf13/pflag"
1414
"github.com/stoewer/go-strcase"
15+
"go.einride.tech/aip/reflect/aipreflect"
16+
"google.golang.org/genproto/googleapis/api/annotations"
1517
"google.golang.org/protobuf/encoding/protojson"
1618
"google.golang.org/protobuf/proto"
1719
"google.golang.org/protobuf/reflect/protoreflect"
20+
"google.golang.org/protobuf/reflect/protoregistry"
1821
)
1922

2023
// NewServiceCommand initializes a new *cobra.Command for the provided gRPC service.
@@ -44,7 +47,7 @@ func NewMethodCommand(
4447
}
4548
fromFile := cmd.Flags().StringP("from-file", "f", "", "path to a JSON file containing the request payload")
4649
_ = cmd.MarkFlagFilename("from-file", "json")
47-
setFlags(comments, cmd.Flags(), nil, in.ProtoReflect().Descriptor(), in.ProtoReflect)
50+
setFlags(comments, cmd, nil, in.ProtoReflect().Descriptor(), in.ProtoReflect)
4851
cmd.RunE = func(cmd *cobra.Command, args []string) error {
4952
if cmd.Flags().Changed("from-file") {
5053
data, err := os.ReadFile(*fromFile)
@@ -55,7 +58,7 @@ func NewMethodCommand(
5558
return err
5659
}
5760
}
58-
conn, err := Dial(cmd.Context())
61+
conn, err := dial(cmd.Context())
5962
if err != nil {
6063
return err
6164
}
@@ -91,7 +94,7 @@ func methodURI(method protoreflect.MethodDescriptor) string {
9194

9295
func setFlags(
9396
comments map[protoreflect.FullName]string,
94-
flags *pflag.FlagSet,
97+
cmd *cobra.Command,
9598
parentFields []protoreflect.FieldDescriptor,
9699
msg protoreflect.MessageDescriptor,
97100
mutable func() protoreflect.Message,
@@ -105,37 +108,48 @@ func setFlags(
105108
if field.IsList() {
106109
// TODO: Implement support for repeated durations.
107110
} else {
108-
flags.AddFlag(&pflag.Flag{
109-
Name: flagName(field, parentFields),
110-
Usage: flagUsage(comments[field.FullName()]),
111-
Value: durationValue{mutable: mutable, field: field},
111+
addFlag(cmd, field, parentFields, comments[field.FullName()], durationValue{
112+
mutable: mutable,
113+
field: field,
112114
})
113115
}
114116
case "google.protobuf.Timestamp":
115117
if field.IsList() {
116118
// TODO: Implement support for repeated timestamps.
117119
} else {
118-
flags.AddFlag(&pflag.Flag{
119-
Name: flagName(field, parentFields),
120-
Usage: flagUsage(comments[field.FullName()]),
121-
Value: timestampValue{mutable: mutable, field: field},
120+
addFlag(cmd, field, parentFields, comments[field.FullName()], timestampValue{
121+
mutable: mutable,
122+
field: field,
122123
})
123124
}
124125
case "google.protobuf.FieldMask":
125126
if field.IsList() {
126127
// Repeated field masks is intentionally not supported.
127128
} else {
128-
flags.AddFlag(&pflag.Flag{
129-
Name: flagName(field, parentFields),
130-
Usage: flagUsage(comments[field.FullName()]),
131-
Value: fieldMaskValue{mutable: mutable, field: field},
129+
addFlag(cmd, field, parentFields, comments[field.FullName()], fieldMaskValue{
130+
mutable: mutable,
131+
field: field,
132132
})
133133
}
134134
default:
135-
if field.Cardinality() != protoreflect.Repeated {
135+
switch {
136+
case field.IsMap():
137+
switch {
138+
case field.MapKey().Kind() == protoreflect.StringKind &&
139+
field.MapValue().Kind() == protoreflect.StringKind:
140+
addFlag(cmd, field, parentFields, comments[field.FullName()], mapStringStringValue{
141+
mutable: mutable,
142+
field: field,
143+
})
144+
default:
145+
// TODO: Implement support for more map types.
146+
}
147+
case field.IsList():
148+
// Repeated nested messages not supported.
149+
default:
136150
setFlags(
137151
comments,
138-
flags,
152+
cmd,
139153
append(parentFields, field),
140154
field.Message(),
141155
func() protoreflect.Message {
@@ -144,16 +158,25 @@ func setFlags(
144158
)
145159
}
146160
}
161+
case protoreflect.EnumKind:
162+
if field.IsList() {
163+
// TODO: Implement support for repeated enums.
164+
} else {
165+
addFlag(cmd, field, parentFields, comments[field.FullName()], enumValue{
166+
mutable: mutable,
167+
field: field,
168+
})
169+
}
147170
case protoreflect.StringKind, protoreflect.BoolKind, protoreflect.BytesKind, protoreflect.DoubleKind,
148171
protoreflect.FloatKind, protoreflect.Int64Kind, protoreflect.Int32Kind:
149-
setPrimitiveFlag(comments, flags, parentFields, mutable, field)
172+
setPrimitiveFlag(comments, cmd, parentFields, mutable, field)
150173
}
151174
}
152175
}
153176

154177
func setPrimitiveFlag(
155178
comments map[protoreflect.FullName]string,
156-
flags *pflag.FlagSet,
179+
cmd *cobra.Command,
157180
parentFields []protoreflect.FieldDescriptor,
158181
mutable func() protoreflect.Message,
159182
field protoreflect.FieldDescriptor,
@@ -224,11 +247,120 @@ func setPrimitiveFlag(
224247
default:
225248
panic(fmt.Errorf("unhandled primitive kind: %v", field.Kind())) // shouldn't happen
226249
}
227-
flags.AddFlag(&pflag.Flag{
250+
addFlag(cmd, field, parentFields, comments[field.FullName()], value)
251+
}
252+
253+
func addFlag(
254+
cmd *cobra.Command,
255+
field protoreflect.FieldDescriptor,
256+
parentFields []protoreflect.FieldDescriptor,
257+
comment string,
258+
value pflag.Value,
259+
) {
260+
flag := &pflag.Flag{
228261
Name: flagName(field, parentFields),
229-
Usage: flagUsage(comments[field.FullName()]),
262+
Usage: trimComment(comment),
230263
Value: value,
231-
})
264+
}
265+
cmd.Flags().AddFlag(flag)
266+
maybeMarkHidden(cmd, flag, field)
267+
maybeMarkRequired(cmd, flag, field)
268+
maybeRegisterResourceReferenceCompletionFunction(cmd, flag, field)
269+
maybeRegisterResourceNameCompletionFunction(cmd, flag, field)
270+
}
271+
272+
func maybeMarkHidden(
273+
cmd *cobra.Command,
274+
flag *pflag.Flag,
275+
field protoreflect.FieldDescriptor,
276+
) {
277+
if fieldBehaviors, ok := proto.GetExtension(
278+
field.Options(),
279+
annotations.E_FieldBehavior,
280+
).([]annotations.FieldBehavior); ok {
281+
for _, fieldBehavior := range fieldBehaviors {
282+
if fieldBehavior == annotations.FieldBehavior_OUTPUT_ONLY {
283+
_ = cmd.Flags().MarkHidden(flag.Name)
284+
}
285+
}
286+
}
287+
}
288+
289+
func maybeMarkRequired(
290+
cmd *cobra.Command,
291+
flag *pflag.Flag,
292+
field protoreflect.FieldDescriptor,
293+
) {
294+
if fieldBehaviors, ok := proto.GetExtension(
295+
field.Options(),
296+
annotations.E_FieldBehavior,
297+
).([]annotations.FieldBehavior); ok {
298+
for _, fieldBehavior := range fieldBehaviors {
299+
if fieldBehavior == annotations.FieldBehavior_REQUIRED {
300+
_ = cmd.MarkFlagRequired(flag.Name)
301+
}
302+
}
303+
}
304+
}
305+
306+
func maybeRegisterResourceReferenceCompletionFunction(
307+
cmd *cobra.Command,
308+
flag *pflag.Flag,
309+
field protoreflect.FieldDescriptor,
310+
) {
311+
if field.Kind() == protoreflect.StringKind {
312+
if resourceReference, ok := proto.GetExtension(
313+
field.Options(),
314+
annotations.E_ResourceReference,
315+
).(*annotations.ResourceReference); ok && resourceReference.GetType() != "" {
316+
completionFunc := resourceNameCompletionFunc
317+
if field.IsList() {
318+
completionFunc = resourceNameListCompletionFunc
319+
}
320+
aipreflect.RangeResourceDescriptorsInPackage(
321+
protoregistry.GlobalFiles,
322+
field.ParentFile().Package(),
323+
func(resource *annotations.ResourceDescriptor) bool {
324+
if resource.GetType() == resourceReference.GetType() && len(resource.GetPattern()) > 0 {
325+
_ = cmd.RegisterFlagCompletionFunc(
326+
flag.Name,
327+
completionFunc(resource.GetPattern()...),
328+
)
329+
return false
330+
}
331+
return true
332+
},
333+
)
334+
}
335+
}
336+
}
337+
338+
func maybeRegisterResourceNameCompletionFunction(
339+
cmd *cobra.Command,
340+
flag *pflag.Flag,
341+
field protoreflect.FieldDescriptor,
342+
) {
343+
if !field.IsList() && field.Name() == "name" {
344+
if resourceDescriptor, ok := proto.GetExtension(
345+
field.Parent().Options(),
346+
annotations.E_Resource,
347+
).(*annotations.ResourceDescriptor); ok && resourceDescriptor.GetType() != "" {
348+
aipreflect.RangeResourceDescriptorsInPackage(
349+
protoregistry.GlobalFiles,
350+
field.ParentFile().Package(),
351+
func(resource *annotations.ResourceDescriptor) bool {
352+
if resource.GetType() == resourceDescriptor.GetType() && len(resource.GetPattern()) > 0 {
353+
_ = cmd.RegisterFlagCompletionFunc(
354+
flag.Name,
355+
resourceNameCompletionFunc(resource.GetPattern()...),
356+
)
357+
return false
358+
}
359+
return true
360+
},
361+
)
362+
}
363+
}
232364
}
233365

234366
func trimComment(comment string) string {

aipcli/completion.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,27 @@ import (
1010

1111
type CompletionFunc func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective)
1212

13-
func ResourceNameCompletionFunc(patterns ...string) CompletionFunc {
13+
func resourceNameCompletionFunc(patterns ...string) CompletionFunc {
1414
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
1515
result := make([]string, 0, len(patterns))
1616
for _, pattern := range patterns {
17-
if completion, ok := CompleteResourceName(toComplete, pattern); ok {
17+
result = cobra.AppendActiveHelp(result, fmt.Sprintf("pattern: %s", pattern))
18+
if completion, ok := completeResourceName(toComplete, pattern); ok {
1819
result = append(result, fmt.Sprintf("%s\t%s", completion, pattern))
1920
}
2021
}
2122
return result, cobra.ShellCompDirectiveNoSpace
2223
}
2324
}
2425

25-
func ResourceNameListCompletionFunc(patterns ...string) CompletionFunc {
26+
func resourceNameListCompletionFunc(patterns ...string) CompletionFunc {
2627
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
2728
toCompleteElements := strings.Split(toComplete, ",")
2829
lastToCompleteElement := toCompleteElements[len(toCompleteElements)-1]
2930
result := make([]string, 0, len(patterns))
3031
for _, pattern := range patterns {
31-
if elementCompletion, ok := CompleteResourceName(lastToCompleteElement, pattern); ok {
32+
result = cobra.AppendActiveHelp(result, fmt.Sprintf("pattern: %s", pattern))
33+
if elementCompletion, ok := completeResourceName(lastToCompleteElement, pattern); ok {
3234
var completion string
3335
if len(toCompleteElements) > 1 {
3436
completion = strings.Join(
@@ -45,7 +47,7 @@ func ResourceNameListCompletionFunc(patterns ...string) CompletionFunc {
4547
}
4648
}
4749

48-
func CompleteResourceName(toComplete, pattern string) (string, bool) {
50+
func completeResourceName(toComplete, pattern string) (string, bool) {
4951
toCompleteSegments := strings.Split(toComplete, "/")
5052
patternSegments := strings.Split(pattern, "/")
5153
if len(toCompleteSegments) > len(patternSegments) {

aipcli/completion_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func TestCompleteResourceName(t *testing.T) {
127127
},
128128
} {
129129
t.Run(tt.name, func(t *testing.T) {
130-
actual, ok := CompleteResourceName(tt.toComplete, tt.pattern)
130+
actual, ok := completeResourceName(tt.toComplete, tt.pattern)
131131
assert.Equal(t, tt.ok, ok)
132132
assert.Equal(t, tt.completion, actual)
133133
})

aipcli/dial.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"google.golang.org/grpc/credentials/oauth"
1515
)
1616

17-
func Dial(ctx context.Context) (*grpc.ClientConn, error) {
17+
func dial(ctx context.Context) (*grpc.ClientConn, error) {
1818
config := ConfigFromContext(ctx)
1919
if config.Runtime.Insecure {
2020
return dialInsecure(ctx, config)

aipcli/flags.go

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@ func flagName(field protoreflect.FieldDescriptor, parentFields []protoreflect.Fi
2525
return strings.ReplaceAll(result.String(), "_", "-")
2626
}
2727

28-
func flagUsage(comment string) string {
29-
return trimComment(comment)
30-
}
31-
3228
func newPrimitiveValue[T any](
3329
mutable func() protoreflect.Message,
3430
field protoreflect.FieldDescriptor,
@@ -228,3 +224,55 @@ func (v fieldMaskValue) Set(s string) error {
228224
v.mutable().Set(v.field, protoreflect.ValueOf(fieldMask.ProtoReflect()))
229225
return nil
230226
}
227+
228+
type enumValue struct {
229+
mutable func() protoreflect.Message
230+
field protoreflect.FieldDescriptor
231+
}
232+
233+
func (v enumValue) String() string {
234+
return ""
235+
}
236+
237+
func (v enumValue) Type() string {
238+
return "enum[" + string(v.field.Enum().Name()) + "]"
239+
}
240+
241+
func (v enumValue) Set(s string) error {
242+
value := v.field.Enum().Values().ByName(protoreflect.Name(s))
243+
if value == nil {
244+
return fmt.Errorf("no such value for %v: %v", v.field.Enum().Name(), s)
245+
}
246+
v.mutable().Set(v.field, protoreflect.ValueOfEnum(value.Number()))
247+
return nil
248+
}
249+
250+
type mapStringStringValue struct {
251+
mutable func() protoreflect.Message
252+
field protoreflect.FieldDescriptor
253+
}
254+
255+
func (v mapStringStringValue) String() string {
256+
return ""
257+
}
258+
259+
func (v mapStringStringValue) Type() string {
260+
return "map<string, string>"
261+
}
262+
263+
func (v mapStringStringValue) Set(s string) error {
264+
pairs := strings.Split(s, ",")
265+
if len(pairs) == 0 {
266+
return nil
267+
}
268+
value := v.mutable().NewField(v.field)
269+
for _, pair := range pairs {
270+
keyValue := strings.SplitN(pair, "=", 2)
271+
if len(keyValue) != 2 {
272+
return fmt.Errorf("invalid map pair: %s", pair)
273+
}
274+
value.Map().Set(protoreflect.ValueOfString(keyValue[0]).MapKey(), protoreflect.ValueOfString(keyValue[1]))
275+
}
276+
v.mutable().Set(v.field, value)
277+
return nil
278+
}

0 commit comments

Comments
 (0)