Skip to content

Commit 736ef7c

Browse files
fix: msk collector fanout by region (#900)
1 parent 477e63b commit 736ef7c

File tree

2 files changed

+93
-17
lines changed

2 files changed

+93
-17
lines changed

pkg/aws/msk/msk.go

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"log/slog"
77
"strings"
8+
"sync"
89
"time"
910

1011
"github.com/aws/aws-sdk-go-v2/aws"
@@ -147,13 +148,8 @@ func New(ctx context.Context, config *Config) (*Collector, error) {
147148
func (c *Collector) Collect(ctx context.Context, ch chan<- prometheus.Metric) error {
148149
snapshot := c.pricingStore.Snapshot()
149150

151+
var wg sync.WaitGroup
150152
for _, region := range c.regions {
151-
select {
152-
case <-ctx.Done():
153-
return ctx.Err()
154-
default:
155-
}
156-
157153
if region.RegionName == nil || *region.RegionName == "" {
158154
c.logger.Warn("skipping region with empty name")
159155
continue
@@ -166,18 +162,21 @@ func (c *Collector) Collect(ctx context.Context, ch chan<- prometheus.Metric) er
166162
continue
167163
}
168164

169-
clusters, err := regionClient.ListMSKClusters(ctx)
170-
if err != nil {
171-
c.logger.Error("error listing MSK clusters", "region", regionName, "error", err)
172-
continue
173-
}
174-
175-
for _, cluster := range clusters {
176-
c.collectCluster(ch, snapshot, regionName, cluster)
177-
}
165+
wg.Add(1)
166+
go func() {
167+
defer wg.Done()
168+
clusters, err := regionClient.ListMSKClusters(ctx)
169+
if err != nil {
170+
c.logger.Error("error listing MSK clusters", "region", regionName, "error", err)
171+
return
172+
}
173+
for _, cluster := range clusters {
174+
c.collectCluster(ch, snapshot, regionName, cluster)
175+
}
176+
}()
178177
}
179-
180-
return nil
178+
wg.Wait()
179+
return ctx.Err()
181180
}
182181

183182
func (c *Collector) collectCluster(ch chan<- prometheus.Metric, snapshot pricingstore.Snapshot, region string, cluster msktypes.Cluster) {

pkg/aws/msk/msk_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package msk
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"log/slog"
@@ -285,6 +286,82 @@ func TestCollectorCollectContinuesWhenRegionListingFails(t *testing.T) {
285286
assert.Len(t, results, 2)
286287
}
287288

289+
func TestCollectorCollectReturnsContextErrWhenContextCancelled(t *testing.T) {
290+
ctrl := gomock.NewController(t)
291+
defer ctrl.Finish()
292+
293+
regionClient := mockclient.NewMockClient(ctrl)
294+
pricingClient := mockclient.NewMockClient(ctrl)
295+
296+
regionClient.EXPECT().
297+
ListMSKClusters(gomock.Any()).
298+
Return(nil, context.Canceled).
299+
AnyTimes()
300+
expectPricingLoad(pricingClient, "us-east-1", "USE1", "0.2100000000", "0.1000000000")
301+
302+
collector, err := New(t.Context(), &Config{
303+
Regions: []ec2types.Region{{RegionName: aws.String("us-east-1")}},
304+
RegionMap: map[string]client.Client{"us-east-1": regionClient},
305+
Client: pricingClient,
306+
Logger: testLogger(),
307+
AccountID: "123456789012",
308+
})
309+
require.NoError(t, err)
310+
311+
ctx, cancel := context.WithCancel(context.Background())
312+
cancel()
313+
314+
ch := make(chan prometheus.Metric, 10)
315+
err = collector.Collect(ctx, ch)
316+
close(ch)
317+
318+
assert.ErrorIs(t, err, context.Canceled)
319+
}
320+
321+
func TestCollectorCollectContinuesOnContextDeadlineExceeded(t *testing.T) {
322+
ctrl := gomock.NewController(t)
323+
defer ctrl.Finish()
324+
325+
failingClient := mockclient.NewMockClient(ctrl)
326+
healthyClient := mockclient.NewMockClient(ctrl)
327+
pricingClient := mockclient.NewMockClient(ctrl)
328+
cluster := newProvisionedCluster(
329+
"test-cluster",
330+
"arn:aws:kafka:us-west-2:123456789012:cluster/test-cluster",
331+
"kafka.m5.large", 3, 100,
332+
)
333+
334+
failingClient.EXPECT().
335+
ListMSKClusters(gomock.Any()).
336+
Return(nil, context.DeadlineExceeded).
337+
Times(1)
338+
healthyClient.EXPECT().
339+
ListMSKClusters(gomock.Any()).
340+
Return([]msktypes.Cluster{cluster}, nil).
341+
Times(1)
342+
expectPricingLoad(pricingClient, "us-east-1", "USE1", "0.2100000000", "0.1000000000")
343+
expectPricingLoad(pricingClient, "us-west-2", "USW2", "0.2100000000", "0.1000000000")
344+
345+
collector, err := New(t.Context(), &Config{
346+
Regions: []ec2types.Region{
347+
{RegionName: aws.String("us-east-1")},
348+
{RegionName: aws.String("us-west-2")},
349+
},
350+
RegionMap: map[string]client.Client{
351+
"us-east-1": failingClient,
352+
"us-west-2": healthyClient,
353+
},
354+
Client: pricingClient,
355+
Logger: testLogger(),
356+
AccountID: "123456789012",
357+
})
358+
require.NoError(t, err)
359+
360+
results, err := collectMetricResults(t, collector)
361+
require.NoError(t, err)
362+
assert.Len(t, results, 2) // metrics from us-west-2 still collected
363+
}
364+
288365
func newProvisionedCluster(name, arn, instanceType string, brokerCount, volumeSizeGiB int32) msktypes.Cluster {
289366
return msktypes.Cluster{
290367
ClusterArn: aws.String(arn),

0 commit comments

Comments
 (0)