Skip to content

Commit 782c2e3

Browse files
authored
Merge pull request #37 from gruntwork-io/feature/add-helpers
add helpers
2 parents ed788d2 + 01cf254 commit 782c2e3

File tree

8 files changed

+444
-8
lines changed

8 files changed

+444
-8
lines changed

aws/iam.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package aws
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/aws"
5+
"github.com/aws/aws-sdk-go/aws/defaults"
6+
"github.com/gruntwork-io/gruntwork-cli/errors"
7+
"github.com/aws/aws-sdk-go/service/iam"
8+
"github.com/aws/aws-sdk-go/aws/session"
9+
)
10+
11+
type PolicyDocument struct {
12+
Version string
13+
Statement []StatementEntry
14+
}
15+
16+
type StatementEntry struct {
17+
Effect string
18+
Action []string
19+
Resource string
20+
}
21+
22+
func CreateAwsConfig(awsRegion string) (*aws.Config, error) {
23+
config := defaults.Get().Config.WithRegion(awsRegion)
24+
25+
_, err := config.Credentials.Get()
26+
if err != nil {
27+
return nil, errors.WithStackTraceAndPrefix(err, "Error finding AWS credentials (did you set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables?)")
28+
}
29+
30+
return config, nil
31+
}
32+
33+
func GetIamUserName(awsRegion string) (string, error) {
34+
35+
iamClient, err := CreateIamClient(awsRegion)
36+
if err != nil {
37+
return "", err
38+
}
39+
40+
resp, err := iamClient.GetUser(&iam.GetUserInput{})
41+
if err != nil {
42+
return "", err
43+
}
44+
45+
return *resp.User.UserName, nil
46+
}
47+
48+
func CreateIamClient(awsRegion string) (*iam.IAM, error) {
49+
awsConfig, err := CreateAwsConfig(awsRegion)
50+
if err != nil {
51+
return nil, err
52+
}
53+
54+
return iam.New(session.New(), awsConfig), nil
55+
}

aws/s3.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,84 @@ import (
1111
"github.com/aws/aws-sdk-go/aws/session"
1212
"errors"
1313
"fmt"
14+
"github.com/gruntwork-io/gruntwork-cli/logging"
15+
"bytes"
1416
)
1517

18+
func CreateS3Client(awsRegion string) (*s3.S3, error) {
19+
awsConfig, err := CreateAwsConfig(awsRegion)
20+
if err != nil {
21+
return nil, err
22+
}
23+
24+
return s3.New(session.New(), awsConfig), nil
25+
}
26+
27+
func FindS3BucketWithTag(awsRegion string, key string, value string) (string, error) {
28+
logger := logging.GetLogger("FindS3BucketWithTag")
29+
30+
s3Client, err := CreateS3Client(awsRegion)
31+
if err != nil {
32+
return "", err
33+
}
34+
35+
resp, err := s3Client.ListBuckets(&s3.ListBucketsInput{})
36+
if err != nil {
37+
return "", err
38+
}
39+
40+
for _, bucket := range resp.Buckets {
41+
tagResponse, err := s3Client.GetBucketTagging(&s3.GetBucketTaggingInput{Bucket: bucket.Name})
42+
if err != nil {
43+
44+
if !strings.Contains(err.Error(), "AuthorizationHeaderMalformed") &&
45+
!strings.Contains(err.Error(), "BucketRegionError") &&
46+
!strings.Contains(err.Error(), "NoSuchTagSet") {
47+
return "", err
48+
}
49+
50+
}
51+
52+
for _, tag := range tagResponse.TagSet {
53+
if *tag.Key == key && *tag.Value == value {
54+
logger.Debugf("Found S3 bucket %s with %s=%s", *bucket.Name, key, value)
55+
return *bucket.Name, nil
56+
}
57+
}
58+
}
59+
60+
return "", nil
61+
}
62+
63+
func GetS3ObjectContents(awsRegion string, bucket string, key string) (string, error) {
64+
logger := logging.GetLogger("GetS3ObjectContents")
65+
66+
s3Client, err := CreateS3Client(awsRegion)
67+
if err != nil {
68+
return "", err
69+
}
70+
71+
res, err := s3Client.GetObject(&s3.GetObjectInput{
72+
Bucket: &bucket,
73+
Key: &key,
74+
})
75+
76+
if err != nil {
77+
return "", err
78+
}
79+
80+
buf := new(bytes.Buffer)
81+
_, err = buf.ReadFrom(res.Body)
82+
if err != nil {
83+
return "", err
84+
}
85+
86+
contents := buf.String()
87+
logger.Debugf("Read contents from s3://%s/%s", bucket, key)
88+
89+
return contents, nil
90+
}
91+
1692
// Create an S3 bucket.
1793
func CreateS3Bucket(region string, name string) {
1894
log := log.NewLogger("CreateS3Bucket")

aws/sqs.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package aws
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/aws/session"
5+
"github.com/gruntwork-io/gruntwork-cli/logging"
6+
"github.com/google/uuid"
7+
"github.com/aws/aws-sdk-go/service/sqs"
8+
"github.com/aws/aws-sdk-go/aws"
9+
"fmt"
10+
"strconv"
11+
"strings"
12+
)
13+
14+
func CreateRandomQueue(awsRegion string, prefix string) (string, error) {
15+
logger := logging.GetLogger("CreateRandomQueue")
16+
logger.Debugf("Creating randomly named SQS queue with prefix %s", prefix)
17+
18+
sqsClient, err := CreateSqsClient(awsRegion)
19+
if err != nil {
20+
return "", err
21+
}
22+
23+
channel, err := uuid.NewUUID()
24+
if err != nil {
25+
return "", err
26+
}
27+
28+
channelName := fmt.Sprintf("%s-%s", prefix, channel.String())
29+
30+
queue, err := sqsClient.CreateQueue(&sqs.CreateQueueInput{
31+
QueueName: aws.String(channelName),
32+
})
33+
34+
if err != nil {
35+
return "", err
36+
}
37+
38+
return *queue.QueueUrl, nil;
39+
}
40+
41+
func DeleteQueue(awsRegion string, queueUrl string) (error) {
42+
logger := logging.GetLogger("DeleteQueue")
43+
logger.Debugf("Deleting SQS Queue %s", queueUrl)
44+
45+
sqsClient, err := CreateSqsClient(awsRegion)
46+
if err != nil {
47+
return err
48+
}
49+
50+
_, err = sqsClient.DeleteQueue(&sqs.DeleteQueueInput{
51+
QueueUrl:aws.String(queueUrl),
52+
})
53+
54+
if err != nil {
55+
return err
56+
}
57+
return nil
58+
}
59+
60+
func DeleteMessageFromQueue(awsRegion string, queueUrl string, receipt string) (error) {
61+
logger := logging.GetLogger("DeleteMessageFromQueue")
62+
logger.Debugf("Deleting message from queue %s (%s)", queueUrl, receipt)
63+
64+
sqsClient, err := CreateSqsClient(awsRegion)
65+
if err != nil {
66+
return err
67+
}
68+
69+
_, err = sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
70+
ReceiptHandle: &receipt,
71+
QueueUrl: &queueUrl,
72+
})
73+
if err != nil {
74+
return err
75+
}
76+
77+
return nil
78+
}
79+
80+
func SendMessageToQueue(awsRegion string, queueUrl string, message string) (error) {
81+
logger := logging.GetLogger("SendMessageToQueue")
82+
83+
sqsClient, err := CreateSqsClient(awsRegion)
84+
if err != nil {
85+
return err
86+
}
87+
88+
logger.Debugf("Sending message %s to queue %s", message, queueUrl)
89+
res, err := sqsClient.SendMessage(&sqs.SendMessageInput{
90+
MessageBody: &message,
91+
QueueUrl: &queueUrl,
92+
})
93+
94+
if err != nil {
95+
if strings.Contains(err.Error(), "AWS.SimpleQueueService.NonExistentQueue") {
96+
logger.Warn(fmt.Sprintf("Client has stopped listening on queue %s", queueUrl))
97+
return nil
98+
}
99+
return err
100+
}
101+
logger.Debugf("Message id %s sent to queue %s", res.MessageId, queueUrl)
102+
103+
return nil
104+
}
105+
106+
func CreateSqsClient(awsRegion string) (*sqs.SQS, error) {
107+
awsConfig, err := CreateAwsConfig(awsRegion)
108+
if err != nil {
109+
return nil, err
110+
}
111+
112+
return sqs.New(session.New(), awsConfig), nil
113+
}
114+
115+
type QueueMessageResponse struct {
116+
ReceiptHandle string
117+
MessageBody string
118+
Error error
119+
}
120+
121+
// Waits to receive a message from on the queueUrl. Since the API only allows us to wait a max 20 seconds for a new
122+
// message to arrive, we must loop TIMEOUT/20 number of times to be able to wait for a total of TIMEOUT seconds
123+
func WaitForQueueMessage(awsRegion string, queueUrl string, timeout int) (QueueMessageResponse) {
124+
logger := logging.GetLogger("WaitForQueueMessage")
125+
126+
sqsClient, err := CreateSqsClient(awsRegion)
127+
if err != nil {
128+
return QueueMessageResponse{Error:err}
129+
}
130+
131+
cycles := timeout;
132+
cycleLength := 1;
133+
134+
if timeout >= 20 {
135+
cycleLength = 20
136+
cycles = timeout / cycleLength
137+
}
138+
139+
for i := 0; i < cycles; i++ {
140+
logger.Debugf("Waiting for message on %s (%ss)", queueUrl, strconv.Itoa(i * cycleLength))
141+
result, err := sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{
142+
QueueUrl: aws.String(queueUrl),
143+
AttributeNames: aws.StringSlice([]string{
144+
"SentTimestamp",
145+
}),
146+
MaxNumberOfMessages: aws.Int64(1),
147+
MessageAttributeNames: aws.StringSlice([]string{
148+
"All",
149+
}),
150+
WaitTimeSeconds: aws.Int64(int64(cycleLength)),
151+
})
152+
153+
if err != nil {
154+
return QueueMessageResponse{Error:err}
155+
}
156+
157+
if len(result.Messages) > 0 {
158+
logger.Debugf("Message %s received on %s", *result.Messages[0].MessageId, queueUrl)
159+
return QueueMessageResponse{ReceiptHandle:*result.Messages[0].ReceiptHandle, MessageBody:*result.Messages[0].Body}
160+
}
161+
}
162+
163+
return QueueMessageResponse{Error:fmt.Errorf("Failed to receive messages on %s within %s seconds", queueUrl, strconv.Itoa(timeout))}
164+
}

glide.lock

Lines changed: 24 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

resources/base_resources.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package resources
2+
3+
import (
4+
"testing"
5+
"github.com/gruntwork-io/terratest"
6+
)
7+
8+
func CreateBaseRandomResourceCollection(t *testing.T, excludedRegions ...string) *terratest.RandomResourceCollection {
9+
resourceCollectionOptions := terratest.NewRandomResourceCollectionOptions()
10+
11+
if (excludedRegions!=nil) {
12+
resourceCollectionOptions.ForbiddenRegions = excludedRegions
13+
}
14+
15+
randomResourceCollection, err := terratest.CreateRandomResourceCollection(resourceCollectionOptions)
16+
if err != nil {
17+
t.Fatalf("Failed to create random resource collection: %s\n", err.Error())
18+
}
19+
20+
return randomResourceCollection
21+
}

0 commit comments

Comments
 (0)