Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #6 from DCSO/limit
Browse files Browse the repository at this point in the history
implement global limit for TIE fetches
  • Loading branch information
satta authored Sep 22, 2023
2 parents 1c8274d + 91fd0eb commit 2a7668f
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 15 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ collectors:
from: 1
to: 5
chunk-size: 100
# Maximum limit for returned IoCs, which will be returned sorted by
# data types, in the order specified in the "data-types" config field above
# Set to 0 to disable limiting.
limit:
total: 1000

# Threat Bus ZeroMQ connection settings
# -------------------------------------
Expand Down
61 changes: 50 additions & 11 deletions collector.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// tie-threatbus-bridge
// Copyright (c) 2020, DCSO GmbH
// Copyright (c) 2020, 2023, DCSO GmbH

package main

Expand Down Expand Up @@ -29,6 +29,9 @@ type TIECollectorConfig struct {
From int `yaml:"from"`
To int `yaml:"to"`
} `yaml:"severity"`
Limit struct {
Total uint64 `yaml:"total"`
} `yaml:"limit"`
}

type TIECollector struct {
Expand Down Expand Up @@ -95,7 +98,7 @@ func queryAllTIE(u *url.URL) url.Values {
return q
}

func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64, error) {
func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64, uint64, error) {
offset := 0
limit := Config.Collectors.TIE.ChunkSize
retryCount := 0
Expand All @@ -110,7 +113,7 @@ func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return iocCount, err
return iocCount, 0, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", Config.Collectors.TIE.Token))
q := query(req.URL)
Expand All @@ -119,12 +122,13 @@ func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64
q.Add("order_by", "seq")
req.URL.RawQuery = q.Encode()

buf := make(map[string][]IOC)
for {
log.Debugf("TIE: requesting %v", req.URL)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return iocCount, err
return iocCount, 0, err
}
defer resp.Body.Close()
var qryRes IOCQueryStruct
Expand All @@ -137,7 +141,6 @@ func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64
if err != nil {
log.Errorf("error decoding JSON: %s", err.Error())
}

for _, val := range qryRes.Iocs {
keep := true
for _, c := range val.Categories {
Expand All @@ -147,11 +150,13 @@ func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64
}
}
if keep {
if _, ok := buf[val.DataType]; !ok {
buf[val.DataType] = make([]IOC, 0)
}
iocCount++
outChan <- val
buf[val.DataType] = append(buf[val.DataType], val)
}
}

if !qryRes.HasMore {
log.Debug("no more data")
break
Expand Down Expand Up @@ -183,9 +188,43 @@ func (m *TIECollector) getIOCForQuery(query queryFunc, outChan chan IOC) (uint64
break
}
}
}

var i uint64
// process types defined in configuration first, in given order
for _, t := range Config.Collectors.TIE.DataTypes {
if vals, ok := buf[t]; ok {
for _, val := range vals {
if Config.Collectors.TIE.Limit.Total == 0 || i < Config.Collectors.TIE.Limit.Total {
outChan <- val
i++
} else {
break
}
}
}
}
// process all others
for k, v := range buf {
alreadyHandled := false
for _, t := range Config.Collectors.TIE.DataTypes {
if t == k {
alreadyHandled = true
break
}
}
if !alreadyHandled {
for _, val := range v {
if Config.Collectors.TIE.Limit.Total == 0 || i < Config.Collectors.TIE.Limit.Total {
outChan <- val
i++
} else {
break
}
}
}
}
return iocCount, nil
return iocCount, i, nil
}

func (m *TIECollector) Name() string {
Expand All @@ -196,8 +235,8 @@ func (m *TIECollector) Configure() error {
return nil
}

func (m *TIECollector) Fetch(out chan IOC) (uint64, error) {
func (m *TIECollector) Fetch(out chan IOC) (uint64, uint64, error) {
log.Debug(Config)
cnt, err := m.getIOCForQuery(queryAllTIE, out)
return cnt, err
cnt, sent, err := m.getIOCForQuery(queryAllTIE, out)
return cnt, sent, err
}
149 changes: 149 additions & 0 deletions collector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// tie-threatbus-bridge
// Copyright (c) 2023, DCSO GmbH

package main

import (
"context"
"fmt"
"math/rand"
"net/http"
"sync"
"testing"

"github.com/jarcoal/httpmock"
)

func TestTIECollector_getIOCForQuery(t *testing.T) {
type args struct {
query queryFunc
outChan chan IOC
}
tests := []struct {
name string
m *TIECollector
args args
want uint64
want1 uint64
wantErr bool
}{
{
name: "test",
m: &TIECollector{},
args: struct {
query queryFunc
outChan chan IOC
}{
query: queryAllTIE,
outChan: make(chan IOC),
},
want: 1000,
want1: 900,
},
}

httpmock.Activate()
defer httpmock.DeactivateAndReset()

var domainCount, urlCount uint64
httpmock.RegisterResponder("GET", "http://testtie",
func(req *http.Request) (*http.Response, error) {
iocs := make([]IOC, 1000)
for i := 0; i < 1000; i++ {
if rand.Intn(100)%2 == 0 {
iocs[i] = IOC{
DataType: "URLVerbatim",
Value: fmt.Sprintf("http://url%d.com", i),
}
urlCount++
} else {
iocs[i] = IOC{
DataType: "DomainName",
Value: fmt.Sprintf("domain%d.net", i),
}
domainCount++
}
}
t.Logf("domains %v urls %v", domainCount, urlCount)
res := IOCQueryStruct{
HasMore: false,
Iocs: iocs,
}
resp, err := httpmock.NewJsonResponse(200, res)
if err != nil {
return httpmock.NewStringResponse(500, ""), nil
}
return resp, nil
})

Config = GlobalConfig{
Collectors: struct {
TIE TIECollectorConfig `yaml:"tie"`
}{
TIE: TIECollectorConfig{
URL: "http://testtie",
Enable: true,
DataTypes: []string{
"DomainName",
"URLVerbatim",
},
Limit: struct {
Total uint64 "yaml:\"total\""
}{
Total: 900,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &TIECollector{}

var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.TODO())
coll := make([]IOC, 0)
go func(ctx context.Context, ch chan IOC, wg *sync.WaitGroup) {
for {
select {
case <-ctx.Done():
return
case v := <-ch:
coll = append(coll, v)
wg.Done()
}
}
}(ctx, tt.args.outChan, &wg)

wg.Add(int(tt.want1))
got, got1, err := m.getIOCForQuery(tt.args.query, tt.args.outChan)
wg.Wait()
cancel()

if (err != nil) != tt.wantErr {
t.Errorf("TIECollector.getIOCForQuery() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("TIECollector.getIOCForQuery() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("TIECollector.getIOCForQuery() got1 = %v, want %v", got1, tt.want1)
}
var gotDomainCount, gotUrlCount uint64
for _, v := range coll {
switch v.DataType {
case "DomainName":
gotDomainCount++
case "URLVerbatim":
gotUrlCount++
}
}
if gotDomainCount != domainCount {
t.Errorf("TIECollector.getIOCForQuery() gotDomainCount = %v, domainCount %v", gotDomainCount, domainCount)
}
if gotUrlCount != urlCount-(tt.want-tt.want1) {
t.Errorf("TIECollector.getIOCForQuery() gotUrlCount = %v, urlCount %v", gotUrlCount, urlCount)
}
})
}
}
5 changes: 4 additions & 1 deletion collectorconfig.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// tie-threatbus-bridge
// Copyright (c) 2020, 2023, DCSO GmbH

package main

type GlobalConfig struct {
Expand All @@ -11,7 +14,7 @@ type GlobalConfig struct {
var Config GlobalConfig

type Collector interface {
Fetch(chan IOC) (uint64, error)
Fetch(chan IOC) (uint64, uint64, error)
Configure() error
Name() string
}
Expand Down
7 changes: 6 additions & 1 deletion config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ collectors:
from: 1
to: 5
chunk-size: 100
# Maximum limit for returned IoCs, which will be returned sorted by
# data types, in the order specified in the "data-types" config field above
# Set to 0 to disable limiting.
limit:
total: 1000

# Threat Bus ZeroMQ connection settings
# -------------------------------------
Expand All @@ -37,4 +42,4 @@ threatbus:
# legacy or stix2
format: legacy

logfile: /var/log/tie-threatbus-bridge.log
logfile: /var/log/tie-threatbus-bridge.log
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.14
require (
github.com/TcM1911/stix2 v0.6.1-0.20201122154655-049b8a26ae97
github.com/google/uuid v1.2.0 // indirect
github.com/jarcoal/httpmock v1.3.1 // indirect
github.com/pebbe/zmq4 v1.2.5
github.com/sirupsen/logrus v1.8.1
github.com/tent/http-link-go v0.0.0-20130702225549-ac974c61c2f9
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs=
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww=
github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pebbe/zmq4 v1.2.5 h1:ygTu6F/sMp7TIo7JN/ObpotHudy7+Rnun1LLSybyCFs=
github.com/pebbe/zmq4 v1.2.5/go.mod h1:3+LG+02U+ToKtxF9avLo17NGTVDhWtRhsdU3spikK8o=
Expand Down
5 changes: 3 additions & 2 deletions tie-threatbus-bridge.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// tie-threatbus-bridge
// Copyright (c) 2020, 2021 DCSO GmbH
// Copyright (c) 2020, 2023 DCSO GmbH

package main

Expand Down Expand Up @@ -30,13 +30,14 @@ func update(iocChan chan IOC) {
log.WithFields(log.Fields{
"domain": "status",
}).Info("update started")
count, err := tc.Fetch(iocChan)
count, sent, err := tc.Fetch(iocChan)
if err != nil {
log.Error(err)
}
log.WithFields(log.Fields{
"domain": "metrics",
"iocs-processed": count,
"iocs-sent": sent,
}).Info("update done")
}

Expand Down

0 comments on commit 2a7668f

Please sign in to comment.