Skip to content

Commit 408f099

Browse files
authored
Improve query cancellation handling in tracebyidsharding (#3326)
* Improve query cancelation handling in tracebyidsharding * Move subCancel to only on error condition * Collapse defer * Drop status handling
1 parent 86ca897 commit 408f099

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

modules/frontend/tracebyidsharding.go

+24-25
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ func (s shardQuery) RoundTrip(r *http.Request) (*http.Response, error) {
6565

6666
// context propagation
6767
r = r.WithContext(ctx)
68-
reqs, err := s.buildShardedRequests(r)
68+
subCtx, subCancel := context.WithCancel(ctx)
69+
defer subCancel()
70+
71+
reqs, err := s.buildShardedRequests(subCtx, r)
6972
if err != nil {
7073
return nil, err
7174
}
@@ -75,14 +78,18 @@ func (s shardQuery) RoundTrip(r *http.Request) (*http.Response, error) {
7578
if s.cfg.ConcurrentShards > 0 {
7679
concurrentShards = uint(s.cfg.ConcurrentShards)
7780
}
78-
wg := boundedwaitgroup.New(concurrentShards)
79-
mtx := sync.Mutex{}
8081

81-
var overallError error
82+
var (
83+
overallError error
84+
85+
mtx = sync.Mutex{}
86+
statusCode = http.StatusNotFound
87+
statusMsg = "trace not found"
88+
wg = boundedwaitgroup.New(concurrentShards)
89+
)
90+
8291
combiner := trace.NewCombiner(s.o.MaxBytesPerTrace(userID))
8392
_, _ = combiner.Consume(&tempopb.Trace{}) // The query path returns a non-nil result even if no inputs (which is different than other paths which return nil for no inputs)
84-
statusCode := http.StatusNotFound
85-
statusMsg := "trace not found"
8693

8794
for _, req := range reqs {
8895
wg.Add(1)
@@ -97,20 +104,16 @@ func (s shardQuery) RoundTrip(r *http.Request) (*http.Response, error) {
97104
overallError = rtErr
98105
}
99106

100-
if shouldQuit(r.Context(), statusCode, overallError) {
101-
return
102-
}
103-
104-
// check http error
105-
if rtErr != nil {
106-
_ = level.Error(s.logger).Log("msg", "error querying proxy target", "url", innerR.RequestURI, "err", rtErr)
107-
overallError = rtErr
107+
// Check the context of the worker request
108+
if shouldQuit(innerR.Context(), statusCode, overallError) {
108109
return
109110
}
110111

111-
// if the status code is anything but happy, save the error and pass it down the line
112+
// if the status code is anything but happy, save the error and pass it
113+
// down the line
112114
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound {
113-
// todo: if we cancel the parent context here will it shortcircuit the other queries and fail fast?
115+
defer subCancel()
116+
114117
statusCode = resp.StatusCode
115118
bytesMsg, readErr := io.ReadAll(resp.Body)
116119
if readErr != nil {
@@ -129,7 +132,7 @@ func (s shardQuery) RoundTrip(r *http.Request) (*http.Response, error) {
129132
}
130133

131134
// marshal into a trace to combine.
132-
// todo: better define responsibilities between middleware. the parent middleware in frontend.go actually sets the header
135+
// TODO: better define responsibilities between middleware. the parent middleware in frontend.go actually sets the header
133136
// which forces the body here to be a proto encoded tempopb.Trace{}
134137
traceResp := &tempopb.TraceByIDResponse{}
135138
rtErr = proto.Unmarshal(buff, traceResp)
@@ -202,9 +205,8 @@ func (s shardQuery) RoundTrip(r *http.Request) (*http.Response, error) {
202205

203206
// buildShardedRequests returns a slice of requests sharded on the precalculated
204207
// block boundaries
205-
func (s *shardQuery) buildShardedRequests(parent *http.Request) ([]*http.Request, error) {
206-
ctx := parent.Context()
207-
userID, err := user.ExtractOrgID(ctx)
208+
func (s *shardQuery) buildShardedRequests(ctx context.Context, parent *http.Request) ([]*http.Request, error) {
209+
userID, err := user.ExtractOrgID(parent.Context())
208210
if err != nil {
209211
return nil, err
210212
}
@@ -237,6 +239,7 @@ func shouldQuit(ctx context.Context, statusCode int, err error) bool {
237239
if err != nil {
238240
return true
239241
}
242+
240243
if ctx.Err() != nil {
241244
return true
242245
}
@@ -245,9 +248,5 @@ func shouldQuit(ctx context.Context, statusCode int, err error) bool {
245248
return true
246249
}
247250

248-
if statusCode/100 == 5 { // bail on any 5xx's
249-
return true
250-
}
251-
252-
return false
251+
return statusCode/100 == 5 // bail on any 5xx's
253252
}

modules/frontend/tracebyidsharding_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func TestBuildShardedRequests(t *testing.T) {
3939
ctx := user.InjectOrgID(context.Background(), "blerg")
4040
req := httptest.NewRequest("GET", "/", nil).WithContext(ctx)
4141

42-
shardedReqs, err := sharder.buildShardedRequests(req)
42+
shardedReqs, err := sharder.buildShardedRequests(ctx, req)
4343
require.NoError(t, err)
4444
require.Len(t, shardedReqs, queryShards)
4545

0 commit comments

Comments
 (0)