Skip to content

Commit 35df6bd

Browse files
committed
Handle multiple action candidates
1 parent 4c2908d commit 35df6bd

1 file changed

Lines changed: 110 additions & 135 deletions

File tree

proxy.go

Lines changed: 110 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ type ServiceHttp struct {
188188
}
189189

190190
type ServiceStructure struct {
191+
Required []string `json:"required"`
191192
Shape string `json:"shape"`
192193
Type string `json:"type"`
193194
Member *ServiceStructure `json:"member"`
@@ -268,6 +269,14 @@ func flatten(top bool, flatMap map[string][]string, nested interface{}, prefix s
268269
return nil
269270
}
270271

272+
type ActionCandidate struct {
273+
Path string
274+
Action string
275+
URIParams map[string]string
276+
Params map[string][]string
277+
Operation ServiceOperation
278+
}
279+
271280
func handleAWSRequest(req *http.Request, body []byte, respCode int) {
272281
host := req.Host
273282
uri := req.RequestURI
@@ -297,107 +306,7 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
297306
params := make(map[string][]string)
298307
action := "*"
299308

300-
if serviceDef.Metadata.Protocol == "rest-json" {
301-
// URL param schema
302-
urlobj, err := url.ParseRequestURI(uri)
303-
if err != nil {
304-
return
305-
}
306-
vals := urlobj.Query()
307-
308-
// path part
309-
longestPath := ""
310-
311-
OperationLoop:
312-
for operationName, operation := range serviceDef.Operations {
313-
path := urlobj.Path
314-
if operation.Http.RequestURI == "" || operation.Http.RequestURI[0] != '/' {
315-
operation.Http.RequestURI = "/" + operation.Http.RequestURI
316-
}
317-
318-
if strings.Contains(operation.Http.RequestURI, "?") {
319-
path += "?"
320-
321-
operationurlobj, err := url.ParseRequestURI(operation.Http.RequestURI)
322-
if err != nil {
323-
continue
324-
}
325-
326-
operationquery := operationurlobj.Query()
327-
for operationquerykey, operationqueryvalue := range operationquery {
328-
if _, ok := vals[operationquerykey]; ok {
329-
if operationqueryvalue[0] == "" {
330-
path += operationquerykey + "&"
331-
} else if len(vals[operationquerykey]) > 0 {
332-
path += operationquerykey + "=" + vals[operationquerykey][0] + "&"
333-
} else {
334-
continue OperationLoop
335-
}
336-
} else {
337-
continue OperationLoop
338-
}
339-
}
340-
341-
if path[len(path)-1] == '&' {
342-
path = path[:len(path)-1]
343-
}
344-
}
345-
346-
templateMatches := regexp.MustCompile(`{([^}]+?)\+?}`).FindAllStringSubmatch(operation.Http.RequestURI, -1)
347-
regexStr := regexp.MustCompile(`\\{([^}]+?\\\+)\\}`).ReplaceAllString(regexp.QuoteMeta(operation.Http.RequestURI), `([^?]+)`) // {Key+}
348-
regexStr = fmt.Sprintf("^%s$", regexp.MustCompile(`\\{(.+?)\\}`).ReplaceAllString(regexStr, `([^/?]+?)`)) // {Bucket}
349-
pathMatchSuccess := regexp.MustCompile(regexStr).Match([]byte(path))
350-
351-
if operation.Http.Method == "" {
352-
operation.Http.Method = "POST"
353-
}
354-
355-
if operation.Http.Method == req.Method && pathMatchSuccess {
356-
if len(path) > len(longestPath) {
357-
longestPath = path
358-
} else {
359-
continue
360-
}
361-
362-
action = operationName
363-
pathMatches := regexp.MustCompile(regexStr).FindAllStringSubmatch(path, -1)
364-
365-
if len(pathMatches) > 0 && len(pathMatches) > 0 && len(templateMatches) == len(pathMatches[0])-1 {
366-
for i := 0; i < len(templateMatches); i++ {
367-
uriparams[templateMatches[i][1]] = pathMatches[0][1:][i]
368-
}
369-
}
370-
}
371-
}
372-
373-
// query part
374-
for k, v := range vals {
375-
normalizedK := regexp.MustCompile(`\.member\.[0-9]+`).ReplaceAllString(k, "[]")
376-
normalizedK = regexp.MustCompile(`\.[0-9]+`).ReplaceAllString(normalizedK, "[]")
377-
378-
resolvedPropertyName := resolvePropertyName(serviceDef.Operations[action].Input, normalizedK, "", "", serviceDef.Shapes)
379-
if resolvedPropertyName != "" {
380-
normalizedK = resolvedPropertyName
381-
}
382-
383-
if len(params[normalizedK]) > 0 {
384-
params[normalizedK] = append(params[normalizedK], v...)
385-
} else {
386-
params[normalizedK] = v
387-
}
388-
}
389-
390-
// body part
391-
if len(body) > 0 {
392-
var bodyJSON interface{}
393-
err := json.Unmarshal(body, &bodyJSON)
394-
if err != nil {
395-
return
396-
}
397-
398-
flatten(true, params, bodyJSON, "")
399-
}
400-
} else if serviceDef.Metadata.Protocol == "json" {
309+
if serviceDef.Metadata.Protocol == "json" {
401310
// JSON schema
402311
var bodyJSON interface{}
403312
err := json.Unmarshal(body, &bodyJSON)
@@ -444,18 +353,18 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
444353
}
445354
}
446355
}
447-
} else if serviceDef.Metadata.Protocol == "rest-xml" {
356+
} else if serviceDef.Metadata.Protocol == "rest-json" || serviceDef.Metadata.Protocol == "rest-xml" {
448357
// URL param schema
449358
urlobj, err := url.ParseRequestURI(uri)
450359
if err != nil {
451360
return
452361
}
453362
vals := urlobj.Query()
454363

455-
// path part
456-
longestPath := ""
364+
actionCandidates := []ActionCandidate{}
457365

458-
OperationLoop2:
366+
// path part
367+
OperationLoop:
459368
for operationName, operation := range serviceDef.Operations {
460369
path := urlobj.Path
461370
if serviceDef.Metadata.EndpointPrefix == "s3" && strings.HasPrefix(operation.Http.RequestURI, "/{Bucket}") && endpointUriPrefix != "" { // https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html#VirtualHostingSpecifyBucket
@@ -485,10 +394,10 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
485394
} else if len(vals[operationquerykey]) > 0 {
486395
path += operationquerykey + "=" + vals[operationquerykey][0] + "&"
487396
} else {
488-
continue OperationLoop2
397+
continue OperationLoop
489398
}
490399
} else {
491-
continue OperationLoop2
400+
continue OperationLoop
492401
}
493402
}
494403

@@ -507,50 +416,116 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
507416
}
508417

509418
if operation.Http.Method == req.Method && pathMatchSuccess {
510-
if len(path) > len(longestPath) {
511-
longestPath = path
512-
} else {
513-
continue
514-
}
515-
516419
action = operationName
420+
uriparams = map[string]string{}
421+
517422
pathMatches := regexp.MustCompile(regexStr).FindAllStringSubmatch(path, -1)
518423

519424
if len(pathMatches) > 0 && len(pathMatches) > 0 && len(templateMatches) == len(pathMatches[0])-1 {
520425
for i := 0; i < len(templateMatches); i++ {
521426
uriparams[templateMatches[i][1]] = pathMatches[0][1:][i]
522427
}
523428
}
524-
}
525-
}
526429

527-
// query part
528-
for k, v := range vals {
529-
normalizedK := regexp.MustCompile(`\.member\.[0-9]+`).ReplaceAllString(k, "[]")
530-
normalizedK = regexp.MustCompile(`\.[0-9]+`).ReplaceAllString(normalizedK, "[]")
430+
// query part
431+
for k, v := range vals {
432+
normalizedK := regexp.MustCompile(`\.member\.[0-9]+`).ReplaceAllString(k, "[]")
433+
normalizedK = regexp.MustCompile(`\.[0-9]+`).ReplaceAllString(normalizedK, "[]")
531434

532-
resolvedPropertyName := resolvePropertyName(serviceDef.Operations[action].Input, normalizedK, "", "", serviceDef.Shapes)
533-
if resolvedPropertyName != "" {
534-
normalizedK = resolvedPropertyName
535-
}
435+
resolvedPropertyName := resolvePropertyName(serviceDef.Operations[action].Input, normalizedK, "", "", serviceDef.Shapes)
436+
if resolvedPropertyName != "" {
437+
normalizedK = resolvedPropertyName
438+
} else {
439+
// continue // Skipping just in case
440+
}
536441

537-
if len(params[normalizedK]) > 0 {
538-
params[normalizedK] = append(params[normalizedK], v...)
539-
} else {
540-
params[normalizedK] = v
442+
if len(params[normalizedK]) > 0 {
443+
params[normalizedK] = append(params[normalizedK], v...)
444+
} else {
445+
params[normalizedK] = v
446+
}
447+
}
448+
449+
// header part
450+
for k, v := range req.Header {
451+
resolvedPropertyName := resolvePropertyName(serviceDef.Operations[action].Input, k, "", "", serviceDef.Shapes)
452+
if resolvedPropertyName != "" {
453+
k = resolvedPropertyName
454+
} else {
455+
continue
456+
}
457+
458+
if len(params[k]) > 0 {
459+
params[k] = append(params[k], v...)
460+
} else {
461+
params[k] = v
462+
}
463+
}
464+
465+
// body part
466+
if len(body) > 0 {
467+
if serviceDef.Metadata.Protocol == "rest-json" {
468+
var bodyJSON interface{}
469+
err := json.Unmarshal(body, &bodyJSON)
470+
if err != nil {
471+
return
472+
}
473+
474+
flatten(true, params, bodyJSON, "")
475+
} else {
476+
var bodyXML interface{}
477+
err := xml.Unmarshal(body, &bodyXML)
478+
if err != nil {
479+
return
480+
}
481+
482+
flatten(true, params, bodyXML, "")
483+
}
484+
}
485+
486+
actionCandidates = append(actionCandidates, ActionCandidate{
487+
Path: path,
488+
Action: action,
489+
Params: params,
490+
URIParams: uriparams,
491+
Operation: operation,
492+
})
541493
}
542494
}
543495

544-
// body part
545-
if len(body) > 0 {
546-
var bodyXML interface{}
547-
err := xml.Unmarshal(body, &bodyXML)
548-
if err != nil {
549-
return
496+
// select candidate
497+
var selectedActionCandidate ActionCandidate
498+
ActionCandidateLoop:
499+
for _, actionCandidate := range actionCandidates {
500+
for _, requiredParam := range actionCandidate.Operation.Input.Required { // check input requirements
501+
if _, ok := actionCandidate.Params[requiredParam]; ok {
502+
continue
503+
}
504+
if _, ok := actionCandidate.URIParams[requiredParam]; ok {
505+
continue
506+
}
507+
continue ActionCandidateLoop // requirements not met
508+
}
509+
if selectedActionCandidate.Action == "" { // first one
510+
selectedActionCandidate = actionCandidate
511+
continue
512+
}
513+
if len(actionCandidate.Path) > len(selectedActionCandidate.Path) { // longer path wins
514+
selectedActionCandidate = actionCandidate
515+
continue
516+
}
517+
if len(actionCandidate.Operation.Input.Required) > len(selectedActionCandidate.Operation.Input.Required) { // more requirements wins
518+
selectedActionCandidate = actionCandidate
519+
continue
550520
}
551-
552-
flatten(true, params, bodyXML, "")
553521
}
522+
action = selectedActionCandidate.Action
523+
params = selectedActionCandidate.Params
524+
uriparams = selectedActionCandidate.URIParams
525+
}
526+
527+
if action == "" {
528+
return
554529
}
555530

556531
region := "us-east-1"

0 commit comments

Comments
 (0)