Skip to content

Commit 45e9455

Browse files
authored
refact pkg/acquisition: split s3.go (#4035)
1 parent 71cb6b1 commit 45e9455

File tree

7 files changed

+386
-340
lines changed

7 files changed

+386
-340
lines changed

.golangci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ linters:
485485

486486
- linters:
487487
- containedctx
488-
path: pkg/acquisition/modules/s3/s3.go
488+
path: pkg/acquisition/modules/s3/source.go
489489
text: found a struct that contains a context.Context field
490490

491491
# migrate over time
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
package s3acquisition
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"errors"
7+
"fmt"
8+
"net/url"
9+
"strconv"
10+
"strings"
11+
12+
"github.com/aws/aws-sdk-go-v2/aws"
13+
"github.com/aws/aws-sdk-go-v2/config"
14+
"github.com/aws/aws-sdk-go-v2/service/s3"
15+
"github.com/aws/aws-sdk-go-v2/service/sqs"
16+
yaml "github.com/goccy/go-yaml"
17+
log "github.com/sirupsen/logrus"
18+
19+
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
20+
"github.com/crowdsecurity/crowdsec/pkg/metrics"
21+
)
22+
23+
type Configuration struct {
24+
configuration.DataSourceCommonCfg `yaml:",inline"`
25+
AwsProfile *string `yaml:"aws_profile"`
26+
AwsRegion string `yaml:"aws_region"`
27+
AwsEndpoint string `yaml:"aws_endpoint"`
28+
BucketName string `yaml:"bucket_name"`
29+
Prefix string `yaml:"prefix"`
30+
Key string `yaml:"-"` // Only for DSN acquisition
31+
PollingMethod string `yaml:"polling_method"`
32+
PollingInterval int `yaml:"polling_interval"`
33+
SQSName string `yaml:"sqs_name"`
34+
SQSFormat string `yaml:"sqs_format"`
35+
MaxBufferSize int `yaml:"max_buffer_size"`
36+
}
37+
38+
39+
// Accepted values for Configuration.PollingMethod.
const (
	// PollMethodList is the polling method used when none is configured.
	PollMethodList = "list"

	// PollMethodSQS selects SQS-based polling; it requires sqs_name to be set.
	PollMethodSQS = "sqs"
)
43+
44+
func (s *Source) newS3Client(ctx context.Context) (*s3.Client, error) {
45+
var loadOpts []func(*config.LoadOptions) error
46+
if s.Config.AwsProfile != nil && *s.Config.AwsProfile != "" {
47+
loadOpts = append(loadOpts, config.WithSharedConfigProfile(*s.Config.AwsProfile))
48+
}
49+
50+
region := s.Config.AwsRegion
51+
if region == "" {
52+
region = "us-east-1"
53+
}
54+
55+
loadOpts = append(loadOpts, config.WithRegion(region))
56+
57+
if c := defaultCreds(); c != nil {
58+
loadOpts = append(loadOpts, config.WithCredentialsProvider(c))
59+
}
60+
61+
cfg, err := config.LoadDefaultConfig(ctx, loadOpts...)
62+
if err != nil {
63+
return nil, fmt.Errorf("failed to load aws config: %w", err)
64+
}
65+
66+
var clientOpts []func(*s3.Options)
67+
if s.Config.AwsEndpoint != "" {
68+
clientOpts = append(clientOpts, func(o *s3.Options) {
69+
o.BaseEndpoint = aws.String(s.Config.AwsEndpoint)
70+
})
71+
}
72+
73+
return s3.NewFromConfig(cfg, clientOpts...), nil
74+
}
75+
76+
func (s *Source) newSQSClient(ctx context.Context) (*sqs.Client, error) {
77+
var loadOpts []func(*config.LoadOptions) error
78+
if s.Config.AwsProfile != nil && *s.Config.AwsProfile != "" {
79+
loadOpts = append(loadOpts, config.WithSharedConfigProfile(*s.Config.AwsProfile))
80+
}
81+
82+
region := s.Config.AwsRegion
83+
if region == "" {
84+
region = "us-east-1"
85+
}
86+
87+
loadOpts = append(loadOpts, config.WithRegion(region))
88+
89+
if c := defaultCreds(); c != nil {
90+
loadOpts = append(loadOpts, config.WithCredentialsProvider(c))
91+
}
92+
93+
cfg, err := config.LoadDefaultConfig(ctx, loadOpts...)
94+
if err != nil {
95+
return nil, fmt.Errorf("failed to load aws config: %w", err)
96+
}
97+
98+
var clientOpts []func(*sqs.Options)
99+
if s.Config.AwsEndpoint != "" {
100+
clientOpts = append(clientOpts, func(o *sqs.Options) { o.BaseEndpoint = aws.String(s.Config.AwsEndpoint) })
101+
}
102+
103+
return sqs.NewFromConfig(cfg, clientOpts...), nil
104+
}
105+
106+
func (s *Source) UnmarshalConfig(yamlConfig []byte) error {
107+
s.Config = Configuration{}
108+
109+
err := yaml.UnmarshalWithOptions(yamlConfig, &s.Config, yaml.Strict())
110+
if err != nil {
111+
return fmt.Errorf("cannot parse S3Acquisition configuration: %s", yaml.FormatError(err, false, false))
112+
}
113+
114+
if s.Config.Mode == "" {
115+
s.Config.Mode = configuration.TAIL_MODE
116+
}
117+
118+
if s.Config.PollingMethod == "" {
119+
s.Config.PollingMethod = PollMethodList
120+
}
121+
122+
if s.Config.PollingInterval == 0 {
123+
s.Config.PollingInterval = 60
124+
}
125+
126+
if s.Config.MaxBufferSize == 0 {
127+
s.Config.MaxBufferSize = bufio.MaxScanTokenSize
128+
}
129+
130+
if s.Config.PollingMethod != PollMethodList && s.Config.PollingMethod != PollMethodSQS {
131+
return fmt.Errorf("invalid polling method %s", s.Config.PollingMethod)
132+
}
133+
134+
if s.Config.BucketName != "" && s.Config.SQSName != "" {
135+
return errors.New("bucket_name and sqs_name are mutually exclusive")
136+
}
137+
138+
if s.Config.PollingMethod == PollMethodSQS && s.Config.SQSName == "" {
139+
return errors.New("sqs_name is required when using sqs polling method")
140+
}
141+
142+
if s.Config.BucketName == "" && s.Config.PollingMethod == PollMethodList {
143+
return errors.New("bucket_name is required")
144+
}
145+
146+
if s.Config.SQSFormat != "" && s.Config.SQSFormat != SQSFormatEventBridge && s.Config.SQSFormat != SQSFormatS3Notification && s.Config.SQSFormat != SQSFormatSNS {
147+
return fmt.Errorf("invalid sqs_format %s, must be empty, %s, %s or %s", s.Config.SQSFormat, SQSFormatEventBridge, SQSFormatS3Notification, SQSFormatSNS)
148+
}
149+
150+
return nil
151+
}
152+
153+
func (s *Source) Configure(ctx context.Context, yamlConfig []byte, logger *log.Entry, _ metrics.AcquisitionMetricsLevel) error {
154+
err := s.UnmarshalConfig(yamlConfig)
155+
if err != nil {
156+
return err
157+
}
158+
159+
if s.Config.SQSName != "" {
160+
s.logger = logger.WithFields(log.Fields{
161+
"queue": s.Config.SQSName,
162+
})
163+
} else {
164+
s.logger = logger.WithFields(log.Fields{
165+
"bucket": s.Config.BucketName,
166+
"prefix": s.Config.Prefix,
167+
})
168+
}
169+
170+
if !s.Config.UseTimeMachine {
171+
s.logger.Warning("use_time_machine is not set to true in the datasource configuration. This will likely lead to false positives as S3 logs are not processed in real time.")
172+
}
173+
174+
if s.Config.PollingMethod == PollMethodList {
175+
s.logger.Warning("Polling method is set to list. This is not recommended as it will not scale well. Consider using SQS instead.")
176+
}
177+
178+
client, err := s.newS3Client(ctx)
179+
if err != nil {
180+
return err
181+
}
182+
183+
s.s3Client = client
184+
185+
if s.Config.PollingMethod == PollMethodSQS {
186+
sqsClient, err := s.newSQSClient(ctx)
187+
if err != nil {
188+
return err
189+
}
190+
191+
s.sqsClient = sqsClient
192+
}
193+
194+
return nil
195+
}
196+
197+
func (s *Source) ConfigureByDSN(ctx context.Context, dsn string, labels map[string]string, logger *log.Entry, uuid string) error {
198+
if !strings.HasPrefix(dsn, "s3://") {
199+
return fmt.Errorf("invalid DSN %s for S3 source, must start with s3://", dsn)
200+
}
201+
202+
s.Config = Configuration{}
203+
s.logger = logger.WithFields(log.Fields{
204+
"bucket": s.Config.BucketName,
205+
"prefix": s.Config.Prefix,
206+
})
207+
208+
dsn = strings.TrimPrefix(dsn, "s3://")
209+
args := strings.Split(dsn, "?")
210+
211+
if args[0] == "" {
212+
return errors.New("empty s3:// DSN")
213+
}
214+
215+
if len(args) == 2 && args[1] != "" {
216+
params, err := url.ParseQuery(args[1])
217+
if err != nil {
218+
return fmt.Errorf("could not parse s3 args: %w", err)
219+
}
220+
221+
for key, value := range params {
222+
switch key {
223+
case "log_level":
224+
if len(value) != 1 {
225+
return errors.New("expected zero or one value for 'log_level'")
226+
}
227+
228+
lvl, err := log.ParseLevel(value[0])
229+
if err != nil {
230+
return fmt.Errorf("unknown level %s: %w", value[0], err)
231+
}
232+
233+
s.logger.Logger.SetLevel(lvl)
234+
case "max_buffer_size":
235+
if len(value) != 1 {
236+
return errors.New("expected zero or one value for 'max_buffer_size'")
237+
}
238+
239+
maxBufferSize, err := strconv.Atoi(value[0])
240+
if err != nil {
241+
return fmt.Errorf("invalid value for 'max_buffer_size': %w", err)
242+
}
243+
244+
s.logger.Debugf("Setting max buffer size to %d", maxBufferSize)
245+
s.Config.MaxBufferSize = maxBufferSize
246+
default:
247+
return fmt.Errorf("unknown parameter %s", key)
248+
}
249+
}
250+
}
251+
252+
s.Config.Labels = labels
253+
s.Config.Mode = configuration.CAT_MODE
254+
s.Config.UniqueId = uuid
255+
256+
pathParts := strings.Split(args[0], "/")
257+
s.logger.Debugf("pathParts: %v", pathParts)
258+
259+
// FIXME: handle s3://bucket/
260+
if len(pathParts) == 1 {
261+
s.Config.BucketName = pathParts[0]
262+
s.Config.Prefix = ""
263+
} else if len(pathParts) > 1 {
264+
s.Config.BucketName = pathParts[0]
265+
if args[0][len(args[0])-1] == '/' {
266+
s.Config.Prefix = strings.Join(pathParts[1:], "/")
267+
} else {
268+
s.Config.Key = strings.Join(pathParts[1:], "/")
269+
}
270+
} else {
271+
return fmt.Errorf("invalid DSN %s for S3 source", dsn)
272+
}
273+
274+
client, err := s.newS3Client(ctx)
275+
if err != nil {
276+
return err
277+
}
278+
279+
s.s3Client = client
280+
281+
return nil
282+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package s3acquisition
2+
3+
import (
4+
"github.com/prometheus/client_golang/prometheus"
5+
6+
"github.com/crowdsecurity/crowdsec/pkg/metrics"
7+
)
8+
9+
func (*Source) GetMetrics() []prometheus.Collector {
10+
return []prometheus.Collector{
11+
metrics.S3DataSourceLinesRead,
12+
metrics.S3DataSourceObjectsRead,
13+
metrics.S3DataSourceSQSMessagesReceived,
14+
}
15+
}
16+
17+
func (*Source) GetAggregMetrics() []prometheus.Collector {
18+
return []prometheus.Collector{
19+
metrics.S3DataSourceLinesRead,
20+
metrics.S3DataSourceObjectsRead,
21+
metrics.S3DataSourceSQSMessagesReceived,
22+
}
23+
}

0 commit comments

Comments
 (0)