Skip to content

Commit 865bc5b

Browse files
authored
[THOG-248] fix broken gitlab tests (#437)
* Fix broken gitlab test. * Close the chunks channel so the test no longer blocks indefinitely. * Range over chunksCh so that receiving a nil chunk (e.g. once the channel is closed) no longer causes an invalid memory address error. Update warnings and move clone output information back. * Remove commented out code. * Remove .Run() because .CombinedOutput() already calls .Run() itself. * Update test to include count check. * Address PR comments. * Fix merge issue.
1 parent 3e0e1da commit 865bc5b

File tree

3 files changed

+72
-101
lines changed

3 files changed

+72
-101
lines changed

pkg/sources/git/git.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
100100

101101
s.conn = &conn
102102

103-
s.git = NewGit(s.Type(), s.jobId, s.sourceId, s.name, s.verify, runtime.NumCPU(),
103+
if concurrency == 0 {
104+
concurrency = runtime.NumCPU()
105+
}
106+
107+
s.git = NewGit(s.Type(), s.jobId, s.sourceId, s.name, s.verify, concurrency,
104108
func(file, email, commit, repository, timestamp string, line int64) *source_metadatapb.MetaData {
105109
return &source_metadatapb.MetaData{
106110
Data: &source_metadatapb.MetaData_Git{
@@ -216,11 +220,11 @@ func CloneRepo(userInfo *url.Userinfo, gitUrl string) (clonePath string, repo *g
216220
cloneURL.User = userInfo
217221
cloneCmd := exec.Command("git", "clone", cloneURL.String(), clonePath)
218222

219-
//cloneCmd := exec.Command("date")
220223
output, err := cloneCmd.CombinedOutput()
221224
if err != nil {
222225
err = errors.WrapPrefix(err, "error running 'git clone'", 0)
223226
}
227+
224228
if cloneCmd.ProcessState == nil {
225229
return "", nil, errors.New("clone command exited with no output")
226230
}
@@ -399,7 +403,7 @@ func (s *Git) ScanUnstaged(repo *git.Repository, scanOptions *ScanOptions, chunk
399403
return nil
400404
}
401405

402-
func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error {
406+
func (s *Git) ScanRepo(_ context.Context, repo *git.Repository, repoPath string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error {
403407
start := time.Now().UnixNano()
404408
if err := s.ScanCommits(repo, repoPath, scanOptions, chunksChan); err != nil {
405409
return err

pkg/sources/gitlab/gitlab.go

+53-93
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sync"
1111

1212
"github.com/go-errors/errors"
13+
gogit "github.com/go-git/go-git/v5"
1314
log "github.com/sirupsen/logrus"
1415
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
1516
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
@@ -41,7 +42,7 @@ type Source struct {
4142
jobSem *semaphore.Weighted
4243
}
4344

44-
// Ensure the Source satisfies the interface at compile time
45+
// Ensure the Source satisfies the interface at compile time.
4546
var _ sources.Source = (*Source)(nil)
4647

4748
// Type returns the type of source.
@@ -98,7 +99,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
9899
}
99100

100101
if len(s.url) == 0 {
101-
//assuming not custom gitlab url
102+
// Assuming not custom gitlab url.
102103
s.url = "https://gitlab.com/"
103104
}
104105

@@ -123,7 +124,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
123124
}
124125

125126
func (s *Source) newClient() (*gitlab.Client, error) {
126-
// initialize a new api instance
127+
// Initialize a new api instance.
127128
switch s.authMethod {
128129
case "OAUTH":
129130
apiClient, err := gitlab.NewOAuthClient(s.token, gitlab.WithBaseURL(s.url))
@@ -158,7 +159,7 @@ func (s *Source) newClient() (*gitlab.Client, error) {
158159
}
159160

160161
func (s *Source) getAllProjects(apiClient *gitlab.Client) ([]*gitlab.Project, error) {
161-
// projects without repo will get user projects, groups projects, and subgroup projects.
162+
// Projects without repo will get user projects, groups projects, and subgroup projects.
162163
user, _, err := apiClient.Users.CurrentUser()
163164

164165
if err != nil {
@@ -187,7 +188,7 @@ func (s *Source) getAllProjects(apiClient *gitlab.Client) ([]*gitlab.Project, er
187188
var groups []*gitlab.Group
188189

189190
listGroupsOptions := gitlab.ListGroupsOptions{
190-
AllAvailable: gitlab.Bool(false), // This actually grabs public groups on public GitLab if set to true
191+
AllAvailable: gitlab.Bool(false), // This actually grabs public groups on public GitLab if set to true.
191192
TopLevelOnly: gitlab.Bool(false),
192193
Owned: gitlab.Bool(false),
193194
}
@@ -262,125 +263,84 @@ func (s *Source) getRepos() ([]*url.URL, []error) {
262263

263264
func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk, repos []*url.URL) []error {
264265
wg := sync.WaitGroup{}
265-
errs := []error{}
266+
var errs []error
266267
var errsMut sync.Mutex
267-
if s.authMethod == "UNAUTHENTICATED" {
268-
for i, u := range repos {
269-
if common.IsDone(ctx) {
270-
// We are returning nil instead of the scanErrors slice here because
271-
// we don't want to mark this scan as errored if we cancelled it.
272-
return nil
268+
269+
for i, u := range repos {
270+
if common.IsDone(ctx) {
271+
// We are returning nil instead of the scanErrors slice here because
272+
// we don't want to mark this scan as errored if we cancelled it.
273+
return nil
274+
}
275+
if err := s.jobSem.Acquire(ctx, 1); err != nil {
276+
log.WithError(err).Debug("could not acquire semaphore")
277+
continue
278+
}
279+
wg.Add(1)
280+
go func(ctx context.Context, repoURL *url.URL, i int) {
281+
defer s.jobSem.Release(1)
282+
defer wg.Done()
283+
if len(repoURL.String()) == 0 {
284+
return
273285
}
274-
if err := s.jobSem.Acquire(ctx, 1); err != nil {
275-
log.WithError(err).Debug("could not acquire semaphore")
276-
continue
286+
s.SetProgressComplete(i, len(repos), fmt.Sprintf("Repo: %s", repoURL), "")
287+
288+
var path string
289+
var repo *gogit.Repository
290+
var err error
291+
if s.authMethod == "UNAUTHENTICATED" {
292+
path, repo, err = git.CloneRepoUsingUnauthenticated(repoURL.String())
293+
} else {
294+
path, repo, err = git.CloneRepoUsingToken(s.token, repoURL.String(), s.user)
277295
}
278-
wg.Add(1)
279-
go func(ctx context.Context, repoURL *url.URL, i int) {
280-
defer s.jobSem.Release(1)
281-
defer wg.Done()
282-
if len(repoURL.String()) == 0 {
283-
return
284-
}
285-
s.SetProgressComplete(i, len(repos), fmt.Sprintf("Repo: %s", repoURL), "")
286-
287-
path, repo, err := git.CloneRepoUsingUnauthenticated(repoURL.String())
288-
defer os.RemoveAll(path)
289-
if err != nil {
290-
errsMut.Lock()
291-
errs = append(errs, err)
292-
errsMut.Unlock()
293-
return
294-
}
295-
log.Debugf("Starting to scan repo %d/%d: %s", i+1, len(repos), repoURL.String())
296-
err = s.git.ScanRepo(ctx, repo, path, git.NewScanOptions(), chunksChan)
297-
if err != nil {
298-
errsMut.Lock()
299-
errs = append(errs, err)
300-
errsMut.Unlock()
301-
return
302-
}
303-
log.Debugf("Completed scanning repo %d/%d: %s", i+1, len(repos), repoURL.String())
304-
}(ctx, u, i)
305-
}
306-
307-
} else {
308-
for i, u := range repos {
309-
if common.IsDone(ctx) {
310-
// We are returning nil instead of the scanErrors slice here because
311-
// we don't want to mark this scan as errored if we cancelled it.
312-
return nil
296+
defer os.RemoveAll(path)
297+
if err != nil {
298+
errsMut.Lock()
299+
errs = append(errs, err)
300+
errsMut.Unlock()
301+
return
313302
}
314-
if err := s.jobSem.Acquire(ctx, 1); err != nil {
315-
log.WithError(err).Debug("could not acquire semaphore")
316-
continue
303+
log.Debugf("Starting to scan repo %d/%d: %s", i+1, len(repos), repoURL.String())
304+
err = s.git.ScanRepo(ctx, repo, path, git.NewScanOptions(), chunksChan)
305+
if err != nil {
306+
errsMut.Lock()
307+
errs = append(errs, err)
308+
errsMut.Unlock()
309+
return
317310
}
318-
wg.Add(1)
319-
go func(ctx context.Context, repoURL *url.URL, i int) {
320-
defer s.jobSem.Release(1)
321-
defer wg.Done()
322-
if len(repoURL.String()) == 0 {
323-
return
324-
}
325-
s.SetProgressComplete(i, len(repos), fmt.Sprintf("Repo: %s", repoURL), "")
326-
327-
// If a username is not provided we need to use a default one in order to clone a private repo.
328-
// Not setting "placeholder" as s.user on purpose in case any downstream services rely on a "" value for s.user.
329-
user := s.user
330-
if user == "" {
331-
user = "placeholder"
332-
}
333-
path, repo, err := git.CloneRepoUsingToken(s.token, repoURL.String(), user)
334-
defer os.RemoveAll(path)
335-
if err != nil {
336-
errsMut.Lock()
337-
errs = append(errs, err)
338-
errsMut.Unlock()
339-
return
340-
}
341-
log.Debugf("Starting to scan repo %d/%d: %s", i+1, len(repos), repoURL.String())
342-
err = s.git.ScanRepo(ctx, repo, path, git.NewScanOptions(), chunksChan)
343-
if err != nil {
344-
errsMut.Lock()
345-
errs = append(errs, err)
346-
errsMut.Unlock()
347-
return
348-
}
349-
log.Debugf("Completed scanning repo %d/%d: %s", i+1, len(repos), repoURL.String())
350-
}(ctx, u, i)
351-
}
311+
log.Debugf("Completed scanning repo %d/%d: %s", i+1, len(repos), repoURL.String())
312+
}(ctx, u, i)
352313
}
353-
354314
wg.Wait()
355315

356316
return errs
357317
}
358318

359319
// Chunks emits chunks of bytes over a channel.
360320
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
361-
// start client
321+
// Start client.
362322
apiClient, err := s.newClient()
363323
if err != nil {
364324
return errors.New(err)
365325
}
366-
// get repo within target
326+
// Get repo within target.
367327
repos, errs := s.getRepos()
368328
for _, repoErr := range errs {
369329
log.WithError(repoErr).Warn("error getting repo")
370330
}
371331

372-
// End early if we had errors getting specified repos but none were validated
332+
// End early if we had errors getting specified repos but none were validated.
373333
if len(errs) > 0 && len(repos) == 0 {
374334
return errors.New("All specified repos had validation issues, ending scan")
375335
}
376336

377-
// get all repos if not specified
337+
// Get all repos if not specified.
378338
if repos == nil {
379339
projects, err := s.getAllProjects(apiClient)
380340
if err != nil {
381341
return errors.New(err)
382342
}
383-
// turn projects into URLs for Git cloner
343+
// Turn projects into URLs for Git cloner.
384344
for _, prj := range projects {
385345
u, err := url.Parse(prj.HTTPURLToRepo)
386346
if err != nil {

pkg/sources/gitlab/gitlab_test.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,25 @@ func TestSource_Scan(t *testing.T) {
9292
}
9393
chunksCh := make(chan *sources.Chunk, 1)
9494
go func() {
95+
defer close(chunksCh)
9596
err = s.Chunks(ctx, chunksCh)
9697
if (err != nil) != tt.wantErr {
9798
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
9899
return
99100
}
100101
}()
101-
gotChunk := <-chunksCh
102+
var chunkCnt int
102103
// Commits don't come in a deterministic order, so remove metadata comparison
103-
gotChunk.Data = nil
104-
gotChunk.SourceMetadata = nil
105-
if diff := pretty.Compare(gotChunk, tt.wantChunk); diff != "" {
106-
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
104+
for gotChunk := range chunksCh {
105+
chunkCnt++
106+
gotChunk.Data = nil
107+
gotChunk.SourceMetadata = nil
108+
if diff := pretty.Compare(gotChunk, tt.wantChunk); diff != "" {
109+
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
110+
}
111+
}
112+
if chunkCnt < 1 {
113+
t.Errorf("0 chunks scanned.")
107114
}
108115
})
109116
}

0 commit comments

Comments
 (0)