Skip to content

Commit c0b16b2

Browse files
authored
fix: read full response body and close connections to avoid file desc… (#209)
* fix: read full response body and close connections to avoid file descriptor leakage * close the connection even if there's an error * cr feedback * lintfix * other order * same defered function * nilcheck response body * remove debug * revert go mod
1 parent 39ee88f commit c0b16b2

File tree

1 file changed

+62
-34
lines changed

1 file changed

+62
-34
lines changed

internal/extension/extension.go

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"encoding/base64"
1515
"encoding/json"
1616
"fmt"
17+
"io"
1718
"net/http"
1819
"os"
1920
"reflect"
@@ -97,7 +98,14 @@ func (em *ExtensionManager) checkAgentRunning() {
9798
// Tell the extension not to create an execution span if universal instrumentation is disabled
9899
if !em.isUniversalInstrumentation {
99100
req, _ := http.NewRequest(http.MethodGet, em.helloRoute, nil)
100-
if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 {
101+
response, err := em.httpClient.Do(req)
102+
if response != nil && response.Body != nil {
103+
defer func() {
104+
_, _ = io.Copy(io.Discard, response.Body)
105+
response.Body.Close()
106+
}()
107+
}
108+
if err == nil && response.StatusCode == 200 {
101109
logger.Debug("Hit the extension /hello route")
102110
} else {
103111
logger.Debug("Will use the API since the Serverless Agent was detected but the hello route was unreachable")
@@ -110,7 +118,14 @@ func (em *ExtensionManager) checkAgentRunning() {
110118
func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, eventPayload json.RawMessage) context.Context {
111119
body := bytes.NewBuffer(eventPayload)
112120
req, _ := http.NewRequest(http.MethodPost, em.startInvocationUrl, body)
113-
if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 {
121+
response, err := em.httpClient.Do(req)
122+
if response != nil && response.Body != nil {
123+
defer func() {
124+
_, _ = io.Copy(io.Discard, response.Body)
125+
response.Body.Close()
126+
}()
127+
}
128+
if err == nil && response.StatusCode == 200 {
114129
// Propagate dd-trace context from the extension response if found in the response headers
115130
traceId := response.Header.Get(string(DdTraceId))
116131
if traceId != "" {
@@ -179,7 +194,7 @@ func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functi
179194
logger.Error(fmt.Errorf("could not get sampling priority from spanContext.SamplingPriority()"))
180195
}
181196
} else {
182-
if priority, ok := getSamplingPriority(functionExecutionSpan) ; ok {
197+
if priority, ok := getSamplingPriority(functionExecutionSpan); ok {
183198
req.Header.Set(string(DdSamplingPriority), fmt.Sprint(priority))
184199
} else {
185200
logger.Error(fmt.Errorf("could not get sampling priority from getSamplingPriority()"))
@@ -188,6 +203,12 @@ func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functi
188203
}
189204

190205
resp, err := em.httpClient.Do(req)
206+
if resp != nil && resp.Body != nil {
207+
defer func() {
208+
_, _ = io.Copy(io.Discard, resp.Body)
209+
resp.Body.Close()
210+
}()
211+
}
191212
if err != nil || resp.StatusCode != 200 {
192213
logger.Error(fmt.Errorf("could not send end invocation payload to the extension: %v", err))
193214
}
@@ -236,7 +257,14 @@ func (em *ExtensionManager) IsExtensionRunning() bool {
236257

237258
func (em *ExtensionManager) Flush() error {
238259
req, _ := http.NewRequest(http.MethodGet, em.flushRoute, nil)
239-
if response, err := em.httpClient.Do(req); err != nil {
260+
response, err := em.httpClient.Do(req)
261+
if response != nil && response.Body != nil {
262+
defer func() {
263+
_, _ = io.Copy(io.Discard, response.Body)
264+
response.Body.Close()
265+
}()
266+
}
267+
if err != nil {
240268
err := fmt.Errorf("was not able to reach the Agent to flush: %s", err)
241269
logger.Error(err)
242270
return err
@@ -252,33 +280,33 @@ func (em *ExtensionManager) Flush() error {
252280
// But for dd-trace-go v1.74.x, reflection is needed to access the SamplingPriority method because
253281
// the method hidden in the v2 SpanContextV2Adapter struct.
254282
func getSamplingPriority(span ddtrace.Span) (int, bool) {
255-
// Get the span context
256-
ctx := span.Context()
257-
258-
// Use reflection to access the underlying v2 SpanContext
259-
ctxValue := reflect.ValueOf(ctx)
260-
if ctxValue.Type().String() != "internal.SpanContextV2Adapter" {
261-
return 0, false
262-
}
263-
264-
// Get the Ctx field (the underlying v2.SpanContext)
265-
ctxField := ctxValue.FieldByName("Ctx")
266-
if !ctxField.IsValid() {
267-
return 0, false
268-
}
269-
270-
// Call SamplingPriority() on the underlying v2 SpanContext
271-
method := ctxField.MethodByName("SamplingPriority")
272-
if !method.IsValid() {
273-
return 0, false
274-
}
275-
276-
results := method.Call([]reflect.Value{})
277-
if len(results) != 2 {
278-
return 0, false
279-
}
280-
281-
priority := int(results[0].Int())
282-
ok := results[1].Bool()
283-
return priority, ok
284-
}
283+
// Get the span context
284+
ctx := span.Context()
285+
286+
// Use reflection to access the underlying v2 SpanContext
287+
ctxValue := reflect.ValueOf(ctx)
288+
if ctxValue.Type().String() != "internal.SpanContextV2Adapter" {
289+
return 0, false
290+
}
291+
292+
// Get the Ctx field (the underlying v2.SpanContext)
293+
ctxField := ctxValue.FieldByName("Ctx")
294+
if !ctxField.IsValid() {
295+
return 0, false
296+
}
297+
298+
// Call SamplingPriority() on the underlying v2 SpanContext
299+
method := ctxField.MethodByName("SamplingPriority")
300+
if !method.IsValid() {
301+
return 0, false
302+
}
303+
304+
results := method.Call([]reflect.Value{})
305+
if len(results) != 2 {
306+
return 0, false
307+
}
308+
309+
priority := int(results[0].Int())
310+
ok := results[1].Bool()
311+
return priority, ok
312+
}

0 commit comments

Comments
 (0)