Skip to content

Commit 4c232b5

Browse files
refactor: improve testability of provider selection and add unit tests in cmd/exporter (#840)
Co-authored-by: Leandro López <inkel.ar@gmail.com>
1 parent b6711dc commit 4c232b5

2 files changed

Lines changed: 215 additions & 3 deletions

File tree

cmd/exporter/exporter.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,22 @@ func createPromRegistryHandler(csp provider.Provider, region string) (http.Handl
208208
}
209209

210210
func selectProvider(ctx context.Context, cfg *config.Config) (provider.Provider, error) {
211+
return selectProviderWith(ctx, cfg,
212+
func(ctx context.Context, cfg *aws.Config) (provider.Provider, error) { return aws.New(ctx, cfg) },
213+
func(ctx context.Context, cfg *azure.Config) (provider.Provider, error) { return azure.New(ctx, cfg) },
214+
func(ctx context.Context, cfg *google.Config) (provider.Provider, error) { return google.New(ctx, cfg) },
215+
)
216+
}
217+
218+
type newProviderFunc[T any] func(context.Context, T) (provider.Provider, error)
219+
220+
func selectProviderWith(
221+
ctx context.Context,
222+
cfg *config.Config,
223+
newAWS newProviderFunc[*aws.Config],
224+
newAzure newProviderFunc[*azure.Config],
225+
newGCP newProviderFunc[*google.Config],
226+
) (provider.Provider, error) {
211227
// Set collector timeout with 1 minute default
212228
collectorTimeout := cfg.Collector.Timeout
213229
if collectorTimeout == 0 {
@@ -216,14 +232,14 @@ func selectProvider(ctx context.Context, cfg *config.Config) (provider.Provider,
216232

217233
switch cfg.Provider {
218234
case "azure":
219-
return azure.New(ctx, &azure.Config{
235+
return newAzure(ctx, &azure.Config{
220236
Logger: cfg.Logger,
221237
SubscriptionId: cfg.Providers.Azure.SubscriptionId,
222238
Services: cfg.Providers.Azure.Services,
223239
CollectorTimeout: collectorTimeout,
224240
})
225241
case "aws":
226-
return aws.New(ctx, &aws.Config{
242+
return newAWS(ctx, &aws.Config{
227243
Logger: cfg.Logger,
228244
Region: cfg.Providers.AWS.Region,
229245
Profile: cfg.Providers.AWS.Profile,
@@ -235,7 +251,7 @@ func selectProvider(ctx context.Context, cfg *config.Config) (provider.Provider,
235251
})
236252

237253
case "gcp":
238-
return google.New(ctx, &google.Config{
254+
return newGCP(ctx, &google.Config{
239255
Logger: cfg.Logger,
240256
ProjectId: cfg.ProjectID,
241257
Region: cfg.Providers.GCP.Region,

cmd/exporter/exporter_test.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,145 @@
11
package main
22

33
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
"net/http/httptest"
48
"testing"
9+
"time"
510

611
"github.com/grafana/cloudcost-exporter/cmd/exporter/config"
12+
"github.com/grafana/cloudcost-exporter/pkg/aws"
13+
"github.com/grafana/cloudcost-exporter/pkg/azure"
14+
"github.com/grafana/cloudcost-exporter/pkg/google"
15+
"github.com/grafana/cloudcost-exporter/pkg/provider"
16+
mock_provider "github.com/grafana/cloudcost-exporter/pkg/provider/mocks"
17+
"go.uber.org/mock/gomock"
718
)
819

20+
func Test_selectProvider(t *testing.T) {
21+
tests := map[string]struct {
22+
providerName string
23+
collectorTimeout time.Duration
24+
awsCalled bool
25+
azureCalled bool
26+
gcpCalled bool
27+
constructorErr error
28+
wantErr bool
29+
wantCollectorTimeout time.Duration // if non-zero, asserts the timeout forwarded to the constructor
30+
}{
31+
"aws provider": {
32+
providerName: "aws",
33+
awsCalled: true,
34+
},
35+
"azure provider": {
36+
providerName: "azure",
37+
azureCalled: true,
38+
},
39+
"gcp provider": {
40+
providerName: "gcp",
41+
gcpCalled: true,
42+
},
43+
"aws constructor error is propagated": {
44+
providerName: "aws",
45+
constructorErr: errors.New("constructor failed"),
46+
wantErr: true,
47+
},
48+
"azure constructor error is propagated": {
49+
providerName: "azure",
50+
constructorErr: errors.New("constructor failed"),
51+
wantErr: true,
52+
},
53+
"gcp constructor error is propagated": {
54+
providerName: "gcp",
55+
constructorErr: errors.New("constructor failed"),
56+
wantErr: true,
57+
},
58+
"unknown provider returns error": {
59+
providerName: "unknown",
60+
wantErr: true,
61+
},
62+
"zero timeout defaults to one minute": {
63+
providerName: "aws",
64+
collectorTimeout: 0,
65+
awsCalled: true,
66+
wantCollectorTimeout: time.Minute,
67+
},
68+
"explicit timeout is passed through": {
69+
providerName: "aws",
70+
collectorTimeout: 5 * time.Minute,
71+
awsCalled: true,
72+
wantCollectorTimeout: 5 * time.Minute,
73+
},
74+
}
75+
76+
for name, tc := range tests {
77+
t.Run(name, func(t *testing.T) {
78+
ctrl := gomock.NewController(t)
79+
defer ctrl.Finish()
80+
81+
var awsCalled, azureCalled, gcpCalled bool
82+
var capturedTimeout time.Duration
83+
84+
mockProv := mock_provider.NewMockProvider(ctrl)
85+
86+
stubAWS := func(_ context.Context, cfg *aws.Config) (provider.Provider, error) {
87+
awsCalled = true
88+
capturedTimeout = cfg.CollectorTimeout
89+
if tc.constructorErr != nil {
90+
return nil, tc.constructorErr
91+
}
92+
return mockProv, nil
93+
}
94+
stubAzure := func(_ context.Context, cfg *azure.Config) (provider.Provider, error) {
95+
azureCalled = true
96+
capturedTimeout = cfg.CollectorTimeout
97+
if tc.constructorErr != nil {
98+
return nil, tc.constructorErr
99+
}
100+
return mockProv, nil
101+
}
102+
stubGCP := func(_ context.Context, cfg *google.Config) (provider.Provider, error) {
103+
gcpCalled = true
104+
capturedTimeout = cfg.CollectorTimeout
105+
if tc.constructorErr != nil {
106+
return nil, tc.constructorErr
107+
}
108+
return mockProv, nil
109+
}
110+
111+
cfg := &config.Config{Provider: tc.providerName}
112+
cfg.Collector.Timeout = tc.collectorTimeout
113+
got, err := selectProviderWith(context.Background(), cfg, stubAWS, stubAzure, stubGCP)
114+
115+
if tc.wantErr {
116+
if err == nil {
117+
t.Error("expected error, got nil")
118+
}
119+
return
120+
}
121+
if err != nil {
122+
t.Fatalf("unexpected error: %v", err)
123+
}
124+
if got == nil {
125+
t.Fatal("expected non-nil provider")
126+
}
127+
if awsCalled != tc.awsCalled {
128+
t.Errorf("awsCalled = %v, want %v", awsCalled, tc.awsCalled)
129+
}
130+
if azureCalled != tc.azureCalled {
131+
t.Errorf("azureCalled = %v, want %v", azureCalled, tc.azureCalled)
132+
}
133+
if gcpCalled != tc.gcpCalled {
134+
t.Errorf("gcpCalled = %v, want %v", gcpCalled, tc.gcpCalled)
135+
}
136+
if tc.wantCollectorTimeout != 0 && capturedTimeout != tc.wantCollectorTimeout {
137+
t.Errorf("collectorTimeout = %v, want %v", capturedTimeout, tc.wantCollectorTimeout)
138+
}
139+
})
140+
}
141+
}
142+
9143
func Test_regionFromConfig(t *testing.T) {
10144
tests := map[string]struct {
11145
provider string
@@ -51,3 +185,65 @@ func Test_regionFromConfig(t *testing.T) {
51185
})
52186
}
53187
}
188+
189+
func Test_createPromRegistryHandler(t *testing.T) {
190+
tests := map[string]struct {
191+
setupMock func(m *mock_provider.MockProvider)
192+
wantErr bool
193+
wantHTTPStatus int
194+
}{
195+
"returns error when RegisterCollectors fails": {
196+
setupMock: func(m *mock_provider.MockProvider) {
197+
m.EXPECT().Describe(gomock.Any()).AnyTimes()
198+
m.EXPECT().RegisterCollectors(gomock.Any()).Return(errors.New("collector registration failed"))
199+
},
200+
wantErr: true,
201+
},
202+
"returns working handler on success": {
203+
setupMock: func(m *mock_provider.MockProvider) {
204+
m.EXPECT().Describe(gomock.Any()).AnyTimes()
205+
m.EXPECT().RegisterCollectors(gomock.Any()).Return(nil)
206+
m.EXPECT().Collect(gomock.Any()).AnyTimes()
207+
},
208+
wantErr: false,
209+
wantHTTPStatus: http.StatusOK,
210+
},
211+
}
212+
213+
for name, tc := range tests {
214+
t.Run(name, func(t *testing.T) {
215+
ctrl := gomock.NewController(t)
216+
defer ctrl.Finish()
217+
218+
mockProv := mock_provider.NewMockProvider(ctrl)
219+
tc.setupMock(mockProv)
220+
221+
handler, err := createPromRegistryHandler(mockProv, "us-east-1")
222+
223+
if tc.wantErr {
224+
if err == nil {
225+
t.Error("expected error, got nil")
226+
}
227+
if handler != nil {
228+
t.Error("expected nil handler on error")
229+
}
230+
return
231+
}
232+
233+
if err != nil {
234+
t.Fatalf("unexpected error: %v", err)
235+
}
236+
if handler == nil {
237+
t.Fatal("expected non-nil handler")
238+
}
239+
240+
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
241+
rec := httptest.NewRecorder()
242+
handler.ServeHTTP(rec, req)
243+
244+
if rec.Code != tc.wantHTTPStatus {
245+
t.Errorf("expected HTTP status %d, got %d", tc.wantHTTPStatus, rec.Code)
246+
}
247+
})
248+
}
249+
}

0 commit comments

Comments
 (0)