forked from projectdiscovery/nuclei
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path s3.go
127 lines (113 loc) · 3.95 KB
/
s3.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package customtemplates
import (
"context"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/projectdiscovery/gologger"
nucleiConfig "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
errorutil "github.com/projectdiscovery/utils/errors"
stringsutil "github.com/projectdiscovery/utils/strings"
)
// Compile-time check that *customTemplateS3Bucket satisfies Provider.
var _ Provider = (*customTemplateS3Bucket)(nil)

// customTemplateS3Bucket is a Provider whose templates live in an AWS S3
// bucket and are mirrored onto local disk by Download/Update.
type customTemplateS3Bucket struct {
	s3Client   *s3.Client // used both to list objects and to download them
	bucketName string     // bucket name only; any "/prefix" part is split off into prefix
	prefix     string     // passed as the ListObjectsV2 Prefix filter (may be empty)
	Location   string     // NOTE(review): not read or written anywhere in this file
}
// Download mirrors every object under the configured prefix of the S3 bucket
// into <CustomS3TemplatesDirectory>/<bucketName>, using each object key as
// the file's relative path.
//
// Failures are logged rather than returned. The first error (listing a page
// or downloading an object) aborts the remaining download. Cancelling ctx
// stops further page listing and further object downloads.
func (bk *customTemplateS3Bucket) Download(ctx context.Context) {
	if ctx == nil {
		// The original implementation ignored ctx entirely, so tolerate callers
		// that still pass nil instead of panicking in ctx.Err()/NextPage.
		ctx = context.Background()
	}
	downloadPath := filepath.Join(nucleiConfig.DefaultConfig.CustomS3TemplatesDirectory, bk.bucketName)
	s3Manager := manager.NewDownloader(bk.s3Client)
	paginator := s3.NewListObjectsV2Paginator(bk.s3Client, &s3.ListObjectsV2Input{
		Bucket: &bk.bucketName,
		Prefix: &bk.prefix,
	})
	for paginator.HasMorePages() {
		// Use the caller's ctx (previously context.TODO) so cancellation and
		// deadlines apply to the listing requests.
		page, err := paginator.NextPage(ctx)
		if err != nil {
			gologger.Error().Msgf("error downloading s3 bucket %s %s", bk.bucketName, err)
			return
		}
		for _, obj := range page.Contents {
			// Stop between objects once the caller has given up.
			if err := ctx.Err(); err != nil {
				gologger.Error().Msgf("error downloading s3 bucket %s %s", bk.bucketName, err)
				return
			}
			if err := downloadToFile(s3Manager, downloadPath, bk.bucketName, aws.ToString(obj.Key)); err != nil {
				gologger.Error().Msgf("error downloading s3 bucket %s %s", bk.bucketName, err)
				return
			}
		}
	}
	gologger.Info().Msgf("AWS bucket %s was cloned successfully at %s", bk.bucketName, downloadPath)
}
// Update refreshes the local mirror of the bucket. There is no incremental
// sync: it runs a full Download again, which recreates (truncates) each file
// it fetches. Local files whose objects no longer exist are not deleted.
func (bk *customTemplateS3Bucket) Update(ctx context.Context) {
	bk.Download(ctx)
}
// NewS3Providers builds the S3 custom-template providers described by options.
//
// It returns at most one provider, and an empty (non-nil) slice when no
// bucket is configured or S3 template downloads are disabled.
// options.AwsBucketName may be "bucket" or "bucket/key/prefix". Everything
// after the first "/" is used as the object key prefix.
func NewS3Providers(options *types.Options) ([]*customTemplateS3Bucket, error) {
	if options.AwsBucketName == "" || options.AwsTemplateDisableDownload {
		return []*customTemplateS3Bucket{}, nil
	}

	client, err := getS3Client(context.TODO(), options.AwsAccessKey, options.AwsSecretKey, options.AwsRegion, options.AwsProfile)
	if err != nil {
		return nil, errorutil.NewWithErr(err).Msgf("error downloading s3 bucket %s", options.AwsBucketName)
	}

	// Split once on the first "/": bucket name, then key prefix ("" when absent).
	name, keyPrefix, _ := strings.Cut(options.AwsBucketName, "/")
	bucket := &customTemplateS3Bucket{
		s3Client:   client,
		bucketName: name,
		prefix:     keyPrefix,
	}
	return []*customTemplateS3Bucket{bucket}, nil
}
// downloadToFile downloads the S3 object bucket/key to targetDirectory/key.
//
// Missing parent directories are created with mode 0775. A key ending in "/"
// is treated as an empty-directory placeholder: only the directory is
// created. Keys that would resolve outside targetDirectory (for example ones
// containing "..") are rejected. If the download or the final close fails,
// the partially written file is removed and the error is returned.
func downloadToFile(downloader *manager.Downloader, targetDirectory, bucket, key string) error {
	file := filepath.Join(targetDirectory, key)
	// filepath.Join cleans ".." elements, so a key such as "../../x" would
	// escape targetDirectory; refuse anything that does not stay inside it.
	rel, err := filepath.Rel(targetDirectory, file)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return errorutil.New("s3 object key %q resolves outside of %s", key, targetDirectory)
	}

	// If empty dir in s3
	if stringsutil.HasSuffixI(key, "/") {
		return os.MkdirAll(file, 0775)
	}
	if err := os.MkdirAll(filepath.Dir(file), 0775); err != nil {
		return err
	}

	fd, err := os.Create(file)
	if err != nil {
		return err
	}
	_, err = downloader.Download(context.TODO(), fd, &s3.GetObjectInput{Bucket: &bucket, Key: &key})
	// Close explicitly rather than via defer: a failed close can mean the
	// written data never reached disk, and that must not be reported as success.
	if closeErr := fd.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// Best effort: don't leave a truncated template behind.
		_ = os.Remove(file)
		return err
	}
	return nil
}
// getS3Client builds an S3 client from the AWS default configuration chain.
//
// Credentials are chosen in this order:
//  1. a non-empty profile selects that shared-config profile;
//  2. otherwise, when both accessKey and secretKey are set, they are used as
//     static credentials (no session token);
//  3. otherwise, the SDK's default credential sources are used.
//
// A non-empty region is applied in every case. Previously it was honored only
// for static keys and was silently dropped when a profile or the default
// chain was used.
func getS3Client(ctx context.Context, accessKey string, secretKey string, region string, profile string) (*s3.Client, error) {
	var opts []func(*config.LoadOptions) error
	if region != "" {
		opts = append(opts, config.WithRegion(region))
	}
	switch {
	case profile != "":
		opts = append(opts, config.WithSharedConfigProfile(profile))
	case accessKey != "" && secretKey != "":
		opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")))
	}

	cfg, err := config.LoadDefaultConfig(ctx, opts...)
	if err != nil {
		return nil, err
	}
	return s3.NewFromConfig(cfg), nil
}