Skip to content

Commit a3bef02

Browse files
authored
Merge pull request #74 from BloodHoundAD/BED-4273--azurehound-backoff-retry
BED-4273 azurehound backoff retry
2 parents 6889fa8 + 05619cf commit a3bef02

File tree

4 files changed

+171
-35
lines changed

4 files changed

+171
-35
lines changed

client/rest/client.go

+7-19
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"encoding/json"
2626
"fmt"
2727
"io"
28-
"math"
2928
"net/http"
3029
"net/url"
3130
"strconv"
@@ -219,23 +218,9 @@ func (s *restClient) Send(req *http.Request) (*http.Response, error) {
219218
return s.send(req)
220219
}
221220

222-
func copyBody(req *http.Request) ([]byte, error) {
223-
var (
224-
body []byte
225-
err error
226-
)
227-
if req.Body != nil {
228-
body, err = io.ReadAll(req.Body)
229-
if body != nil {
230-
req.Body = io.NopCloser(bytes.NewBuffer(body))
231-
}
232-
}
233-
return body, err
234-
}
235-
236221
func (s *restClient) send(req *http.Request) (*http.Response, error) {
237222
// copy the bytes in case we need to retry the request
238-
if body, err := copyBody(req); err != nil {
223+
if body, err := CopyBody(req); err != nil {
239224
return nil, err
240225
} else {
241226
var (
@@ -254,7 +239,11 @@ func (s *restClient) send(req *http.Request) (*http.Response, error) {
254239

255240
// Try the request
256241
if res, err = s.http.Do(req); err != nil {
257-
// client error
242+
if IsClosedConnectionErr(err) {
243+
fmt.Printf("remote host force closed connection while requesting %s; attempt %d/%d; trying again\n", req.URL, retry+1, maxRetries)
244+
ExponentialBackoff(retry)
245+
continue
246+
}
258247
return nil, err
259248
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
260249
// Error response code handling
@@ -270,8 +259,7 @@ func (s *restClient) send(req *http.Request) (*http.Response, error) {
270259
}
271260
} else if res.StatusCode >= http.StatusInternalServerError {
272261
// Wait the time calculated by the 5 second exponential backoff
273-
backoff := math.Pow(5, float64(retry+1))
274-
time.Sleep(time.Second * time.Duration(backoff))
262+
ExponentialBackoff(retry)
275263
continue
276264
} else {
277265
// Not a status code that warrants a retry

client/rest/client_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (C) 2024 Specter Ops, Inc.
2+
//
3+
// This file is part of AzureHound.
4+
//
5+
// AzureHound is free software: you can redistribute it and/or modify
6+
// it under the terms of the GNU General Public License as published by
7+
// the Free Software Foundation, either version 3 of the License, or
8+
// (at your option) any later version.
9+
//
10+
// AzureHound is distributed in the hope that it will be useful,
11+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
// GNU General Public License for more details.
14+
//
15+
// You should have received a copy of the GNU General Public License
16+
// along with this program. If not, see <https://www.gnu.org/licenses/>.
17+
18+
package rest
19+
20+
import (
21+
"net/http"
22+
"net/http/httptest"
23+
24+
"testing"
25+
26+
"github.com/bloodhoundad/azurehound/v2/client/config"
27+
)
28+
29+
func TestClosedConnection(t *testing.T) {
30+
var testServer *httptest.Server
31+
attempt := 0
32+
var mockHandler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
33+
attempt++
34+
testServer.CloseClientConnections()
35+
}
36+
37+
testServer = httptest.NewServer(mockHandler)
38+
defer testServer.Close()
39+
40+
defaultConfig := config.Config{
41+
Username: "azurehound",
42+
Password: "we_collect",
43+
Authority: testServer.URL,
44+
}
45+
46+
if client, err := NewRestClient(testServer.URL, defaultConfig); err != nil {
47+
t.Fatalf("error initializing rest client %v", err)
48+
} else {
49+
requestCompleted := false
50+
51+
// make request in separate goroutine so its not blocking after we validated the retry
52+
go func() {
53+
client.Authenticate() // Authenticate() because it uses the internal client.send method.
54+
// CloseClientConnections should block the request from completing, however if it completes then the test fails.
55+
requestCompleted = true
56+
}()
57+
58+
// block until attempt is > 2 or request succeeds
59+
for attempt <= 2 {
60+
if attempt > 1 || requestCompleted {
61+
break
62+
}
63+
}
64+
65+
if requestCompleted {
66+
t.Fatalf("expected an attempted retry but the request completed")
67+
}
68+
}
69+
}

client/rest/utils.go

+30
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
package rest
1919

2020
import (
21+
"bytes"
2122
"crypto/sha1"
2223
"crypto/x509"
2324
"encoding/base64"
2425
"encoding/json"
2526
"encoding/pem"
2627
"fmt"
2728
"io"
29+
"math"
30+
"net/http"
2831
"strings"
2932
"time"
3033

@@ -120,3 +123,30 @@ func x5t(certificate string) (string, error) {
120123
return base64.StdEncoding.EncodeToString(checksum[:]), nil
121124
}
122125
}
126+
127+
func IsClosedConnectionErr(err error) bool {
128+
var closedConnectionMsg = "An existing connection was forcibly closed by the remote host."
129+
closedFromClient := strings.Contains(err.Error(), closedConnectionMsg)
130+
// Mocking http.Do would require a larger refactor, so closedFromTestCase is used to cover testing only.
131+
closedFromTestCase := strings.HasSuffix(err.Error(), ": EOF")
132+
return closedFromClient || closedFromTestCase
133+
}
134+
135+
func ExponentialBackoff(retry int) {
136+
backoff := math.Pow(5, float64(retry+1))
137+
time.Sleep(time.Second * time.Duration(backoff))
138+
}
139+
140+
func CopyBody(req *http.Request) ([]byte, error) {
141+
var (
142+
body []byte
143+
err error
144+
)
145+
if req.Body != nil {
146+
body, err = io.ReadAll(req.Body)
147+
if body != nil {
148+
req.Body = io.NopCloser(bytes.NewBuffer(body))
149+
}
150+
}
151+
return body, err
152+
}

cmd/start.go

+65-16
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"errors"
2626
"fmt"
2727
"io"
28-
"math"
2928
"net/http"
3029
"net/url"
3130
"os"
@@ -215,16 +214,30 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch
215214
} else {
216215
req.Header.Set("User-Agent", constants.UserAgent())
217216
req.Header.Set("Accept", "application/json")
218-
req.Header.Set("Prefer", "wait=60")
219217
req.Header.Set("Content-Encoding", "gzip")
220218
for retry := 0; retry < maxRetries; retry++ {
221219
// No retries on regular err cases, only on HTTP 504 Gateway Timeout and HTTP 503 Service Unavailable
222220
if response, err := bheClient.Do(req); err != nil {
221+
if rest.IsClosedConnectionErr(err) {
222+
// try again on force closed connection
223+
log.Error(err, fmt.Sprintf("remote host force closed connection while requesting %s; attempt %d/%d; trying again", req.URL, retry+1, maxRetries))
224+
rest.ExponentialBackoff(retry)
225+
226+
if retry == maxRetries-1 {
227+
log.Error(ErrExceededRetryLimit, "")
228+
hasErrors = true
229+
}
230+
231+
continue
232+
}
223233
log.Error(err, unrecoverableErrMsg)
224234
return true
225-
} else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable {
226-
backoff := math.Pow(5, float64(retry+1))
227-
time.Sleep(time.Second * time.Duration(backoff))
235+
} else if response.StatusCode == http.StatusGatewayTimeout || response.StatusCode == http.StatusServiceUnavailable || response.StatusCode == http.StatusBadGateway {
236+
serverError := fmt.Errorf("received server error %d while requesting %v; attempt %d/%d; trying again", response.StatusCode, endpoint, retry+1, maxRetries)
237+
log.Error(serverError, "")
238+
239+
rest.ExponentialBackoff(retry)
240+
228241
if retry == maxRetries-1 {
229242
log.Error(ErrExceededRetryLimit, "")
230243
hasErrors = true
@@ -256,19 +269,55 @@ func ingest(ctx context.Context, bheUrl url.URL, bheClient *http.Client, in <-ch
256269

257270
// TODO: create/use a proper bloodhound client
258271
func do(bheClient *http.Client, req *http.Request) (*http.Response, error) {
259-
if res, err := bheClient.Do(req); err != nil {
260-
return nil, fmt.Errorf("failed to request %v: %w", req.URL, err)
261-
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
262-
var body json.RawMessage
263-
defer res.Body.Close()
264-
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
265-
return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status)
266-
} else {
267-
return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body)
268-
}
272+
var (
273+
res *http.Response
274+
maxRetries = 3
275+
)
276+
277+
// copy the bytes in case we need to retry the request
278+
if body, err := rest.CopyBody(req); err != nil {
279+
return nil, err
269280
} else {
270-
return res, nil
281+
for retry := 0; retry < maxRetries; retry++ {
282+
// Reusing http.Request requires rewinding the request body
283+
// back to a working state
284+
if body != nil && retry > 0 {
285+
req.Body = io.NopCloser(bytes.NewBuffer(body))
286+
}
287+
288+
if res, err = bheClient.Do(req); err != nil {
289+
if rest.IsClosedConnectionErr(err) {
290+
// try again on force closed connections
291+
log.Error(err, fmt.Sprintf("remote host force closed connection while requesting %s; attempt %d/%d; trying again", req.URL, retry+1, maxRetries))
292+
rest.ExponentialBackoff(retry)
293+
continue
294+
}
295+
// normal client error, dont attempt again
296+
return nil, err
297+
} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
298+
if res.StatusCode >= http.StatusInternalServerError {
299+
// Internal server error, backoff and try again.
300+
serverError := fmt.Errorf("received server error %d while requesting %v", res.StatusCode, req.URL)
301+
log.Error(serverError, fmt.Sprintf("attempt %d/%d; trying again", retry+1, maxRetries))
302+
303+
rest.ExponentialBackoff(retry)
304+
continue
305+
}
306+
// bad request we do not need to retry
307+
var body json.RawMessage
308+
defer res.Body.Close()
309+
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
310+
return nil, fmt.Errorf("received unexpected response code from %v: %s; failure reading response body", req.URL, res.Status)
311+
} else {
312+
return nil, fmt.Errorf("received unexpected response code from %v: %s %s", req.URL, res.Status, body)
313+
}
314+
} else {
315+
return res, nil
316+
}
317+
}
271318
}
319+
320+
return nil, fmt.Errorf("unable to complete request to url=%s; attempts=%d;", req.URL, maxRetries)
272321
}
273322

274323
type basicResponse[T any] struct {

0 commit comments

Comments
 (0)