Skip to content

Commit 006e0b8

Browse files
committed
Add find_cluster example to simplify test integration
1 parent 20670dd commit 006e0b8

3 files changed

Lines changed: 267 additions & 0 deletions

File tree

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// example_test.go
2+
package main
3+
4+
import (
5+
"context"
6+
"example/internal/util"
7+
"fmt"
8+
"os"
9+
"testing"
10+
"time"
11+
)
12+
13+
// Global variables for shared resources
14+
var (
15+
testCtx context.Context
16+
cancel context.CancelFunc
17+
)
18+
19+
func TestMain(m *testing.M) {
20+
// Setup before running tests
21+
setup()
22+
23+
// Run all tests
24+
code := m.Run()
25+
26+
// Cleanup after tests complete
27+
teardown()
28+
29+
// Exit with the test status code
30+
os.Exit(code)
31+
}
32+
33+
func setup() {
34+
// Initialize context with timeout for all tests
35+
testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute)
36+
37+
output, err := util.FindClusterByTag(testCtx, "us-east-1", "Name", "go single region cluster")
38+
39+
if err != nil {
40+
fmt.Errorf("Error finding cluster by tag")
41+
}
42+
43+
// Set up any environment variables needed for tests
44+
os.Setenv("REGION", "us-east-1")
45+
os.Setenv("CLUSTER_ID", *output.Identifier)
46+
47+
// Add any other initialization code here
48+
// For example: database connections, mock services, etc.
49+
}
50+
51+
func teardown() {
52+
// Cancel the context
53+
cancel()
54+
55+
// Clean up any resources, close connections, etc.
56+
}
57+
58+
// Test for GetCluster function
59+
func TestGetCluster(t *testing.T) {
60+
// Test cases
61+
tests := []struct {
62+
name string
63+
region string
64+
identifier string
65+
wantErr bool
66+
}{
67+
{
68+
name: "Valid cluster retrieval",
69+
region: os.Getenv("REGION"),
70+
identifier: "saabucfkaxwz5vcba4yzjpqdly",
71+
wantErr: false,
72+
},
73+
// Add more test cases as needed
74+
}
75+
76+
for _, tt := range tests {
77+
t.Run(tt.name, func(t *testing.T) {
78+
_, err := GetCluster(testCtx, tt.region, tt.identifier)
79+
if (err != nil) != tt.wantErr {
80+
t.Errorf("GetCluster() error = %v, wantErr %v", err, tt.wantErr)
81+
}
82+
})
83+
}
84+
}
85+
86+
// Add more test functions for other commands in the cmd folder
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package util
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/aws/aws-sdk-go-v2/config"
7+
"github.com/aws/aws-sdk-go-v2/service/dsql"
8+
"log"
9+
)
10+
11+
// FindClusterByTag finds an Aurora cluster by a specific tag name and value
12+
func FindClusterByTag(ctx context.Context, region, tagName, tagValue string) (*dsql.GetClusterOutput, error) {
13+
if tagName == "" || tagValue == "" {
14+
return nil, fmt.Errorf("tagName and tagValue cannot be empty")
15+
}
16+
17+
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
18+
if err != nil {
19+
log.Fatalf("Failed to load AWS configuration: %v", err)
20+
}
21+
22+
// Initialize the DSQL client
23+
client := dsql.NewFromConfig(cfg)
24+
25+
clustersOutput, err := client.ListClusters(ctx, &dsql.ListClustersInput{})
26+
27+
if err != nil {
28+
log.Fatalf("Failed to list clusters: %v", err)
29+
}
30+
31+
for _, val := range clustersOutput.Clusters {
32+
fmt.Println("found cluster:" + *val.Identifier + " with tag:" + tagName + "=" + tagValue)
33+
34+
clusterOutput, err := client.GetCluster(ctx, &dsql.GetClusterInput{Identifier: val.Identifier})
35+
if err != nil {
36+
log.Fatalf("Failed to get cluster: %v", err)
37+
}
38+
39+
if clusterOutput.Tags[tagName] == tagValue {
40+
return clusterOutput, nil
41+
}
42+
}
43+
44+
return nil, fmt.Errorf("no cluster found with tag %s=%s", tagName, tagValue)
45+
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package util
2+
3+
import (
4+
"context"
5+
"github.com/aws/aws-sdk-go-v2/aws"
6+
"github.com/aws/aws-sdk-go-v2/service/dsql"
7+
"github.com/aws/aws-sdk-go-v2/service/dsql/types"
8+
"testing"
9+
)
10+
11+
// MockDSQLClient implements the necessary DSQL client methods for testing
12+
type MockDSQLClient struct {
13+
listClusters *dsql.ListClustersOutput
14+
getCluster *dsql.GetClusterOutput
15+
listErr error
16+
getErr error
17+
}
18+
19+
func (m *MockDSQLClient) ListClusters(ctx context.Context, params *dsql.ListClustersInput) (*dsql.ListClustersOutput, error) {
20+
return m.listClusters, m.listErr
21+
}
22+
23+
func (m *MockDSQLClient) GetCluster(ctx context.Context, params *dsql.GetClusterInput) (*dsql.GetClusterOutput, error) {
24+
return m.getCluster, m.getErr
25+
}
26+
27+
func TestFindClusterByTag(t *testing.T) {
28+
tests := []struct {
29+
name string
30+
region string
31+
tagName string
32+
tagValue string
33+
mockData struct {
34+
listClusters *dsql.ListClustersOutput
35+
getCluster *dsql.GetClusterOutput
36+
listErr error
37+
getErr error
38+
}
39+
wantErr bool
40+
}{
41+
{
42+
name: "Successfully find cluster",
43+
region: "us-west-2",
44+
tagName: "Environment",
45+
tagValue: "Production",
46+
mockData: struct {
47+
listClusters *dsql.ListClustersOutput
48+
getCluster *dsql.GetClusterOutput
49+
listErr error
50+
getErr error
51+
}{
52+
listClusters: &dsql.ListClustersOutput{
53+
Clusters: []types.ClusterSummary{
54+
{
55+
Identifier: aws.String("cluster-1"),
56+
},
57+
},
58+
},
59+
getCluster: &dsql.GetClusterOutput{
60+
Identifier: aws.String("cluster-1"),
61+
Tags: map[string]string{
62+
"Environment": "Production",
63+
},
64+
},
65+
},
66+
wantErr: false,
67+
},
68+
{
69+
name: "Empty tag name",
70+
region: "us-west-2",
71+
tagName: "",
72+
tagValue: "Production",
73+
mockData: struct {
74+
listClusters *dsql.ListClustersOutput
75+
getCluster *dsql.GetClusterOutput
76+
listErr error
77+
getErr error
78+
}{},
79+
wantErr: true,
80+
},
81+
{
82+
name: "Cluster not found",
83+
region: "us-west-2",
84+
tagName: "Environment",
85+
tagValue: "Staging",
86+
mockData: struct {
87+
listClusters *dsql.ListClustersOutput
88+
getCluster *dsql.GetClusterOutput
89+
listErr error
90+
getErr error
91+
}{
92+
listClusters: &dsql.ListClustersOutput{
93+
Clusters: []types.ClusterSummary{
94+
{Identifier: aws.String("cluster-1")},
95+
},
96+
},
97+
getCluster: &dsql.GetClusterOutput{
98+
Identifier: aws.String("cluster-1"),
99+
Tags: map[string]string{
100+
"Environment": "Production",
101+
},
102+
},
103+
},
104+
wantErr: true,
105+
},
106+
}
107+
108+
for _, tt := range tests {
109+
t.Run(tt.name, func(t *testing.T) {
110+
// Create mock client with test data
111+
//mockClient := &MockDSQLClient{
112+
// listClusters: tt.mockData.listClusters,
113+
// getCluster: tt.mockData.getCluster,
114+
// listErr: tt.mockData.listErr,
115+
// getErr: tt.mockData.getErr,
116+
//}
117+
118+
// Call the function being tested
119+
result, err := FindClusterByTag(context.Background(), "us-east-1", tt.tagName, tt.tagValue)
120+
121+
// Check error cases
122+
if (err != nil) != tt.wantErr {
123+
t.Errorf("FindClusterByTag() error = %v, wantErr %v", err, tt.wantErr)
124+
return
125+
}
126+
127+
// For successful cases, verify the result
128+
if !tt.wantErr && result != nil {
129+
if result.Tags[tt.tagName] != tt.tagValue {
130+
t.Errorf("FindClusterByTag() got tag value = %v, want %v",
131+
result.Tags[tt.tagName], tt.tagValue)
132+
}
133+
}
134+
})
135+
}
136+
}

0 commit comments

Comments
 (0)