Skip to content

Commit f298cfd

Browse files
authored
Merge pull request #509 from smallstep/herman/attestation-client-request-id
Add `X-Request-Id` and `User-Agent` headers to attestation requests
2 parents 2e4dcbb + 6bebabe commit f298cfd

File tree

4 files changed

+85
-2
lines changed

4 files changed

+85
-2
lines changed

tpm/attestation/client.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/x509"
88
"encoding/json"
99
"fmt"
10+
"io"
1011
"net/http"
1112
"net/url"
1213
"os"
@@ -214,7 +215,7 @@ func (ac *Client) attest(ctx context.Context, info *tpm.Info, ek *tpm.EK, attest
214215
}
215216

216217
attestURL := ac.baseURL.JoinPath("attest").String()
217-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, attestURL, bytes.NewReader(body))
218+
req, err := newRequest(ctx, http.MethodPost, attestURL, bytes.NewReader(body))
218219
if err != nil {
219220
return nil, fmt.Errorf("failed creating POST http request for %q: %w", attestURL, err)
220221
}
@@ -258,7 +259,7 @@ func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, e
258259
}
259260

260261
secretURL := ac.baseURL.JoinPath("secret").String()
261-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, secretURL, bytes.NewReader(body))
262+
req, err := newRequest(ctx, http.MethodPost, secretURL, bytes.NewReader(body))
262263
if err != nil {
263264
return nil, fmt.Errorf("failed creating POST http request for %q: %w", secretURL, err)
264265
}
@@ -280,3 +281,14 @@ func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, e
280281

281282
return &secretResp, nil
282283
}
284+
285+
func newRequest(ctx context.Context, method, requestURL string, body io.Reader) (*http.Request, error) {
286+
req, err := http.NewRequestWithContext(ctx, method, requestURL, body)
287+
if err != nil {
288+
return nil, err
289+
}
290+
enforceRequestID(req)
291+
setUserAgent(req)
292+
293+
return req, nil
294+
}

tpm/attestation/client_simulator_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ func mustParseURL(t *testing.T, urlString string) *url.URL {
7878

7979
func TestClient_Attest(t *testing.T) {
8080
ctx := context.Background()
81+
ctx = NewRequestIDContext(ctx, "requestID")
8182
instance := newSimulatedTPM(t)
8283
ak, err := instance.CreateAK(ctx, "ak1")
8384
require.NoError(t, err)
@@ -140,6 +141,9 @@ func TestClient_Attest(t *testing.T) {
140141
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141142
switch r.URL.Path {
142143
case "/attest":
144+
assert.Equal(t, "step-attestation-http-client/1.0", r.Header.Get("User-Agent"))
145+
assert.Equal(t, "requestID", r.Header.Get("X-Request-Id"))
146+
143147
var ar attestationRequest
144148
err := json.NewDecoder(r.Body).Decode(&ar)
145149
require.NoError(t, err)
@@ -165,6 +169,9 @@ func TestClient_Attest(t *testing.T) {
165169
Secret: encryptedCredentials.Secret,
166170
})
167171
case "/secret":
172+
assert.Equal(t, "step-attestation-http-client/1.0", r.Header.Get("User-Agent"))
173+
assert.Equal(t, "requestID", r.Header.Get("X-Request-Id"))
174+
168175
var sr secretRequest
169176
err := json.NewDecoder(r.Body).Decode(&sr)
170177
require.NoError(t, err)

tpm/attestation/requestid.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package attestation
2+
3+
import (
4+
"context"
5+
"net/http"
6+
7+
"go.step.sm/crypto/randutil"
8+
)
9+
10+
type requestIDContextKey struct{}
11+
12+
// NewRequestIDContext returns a new context with the given request ID added to the
13+
// context.
14+
func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
15+
return context.WithValue(ctx, requestIDContextKey{}, requestID)
16+
}
17+
18+
// RequestIDFromContext returns the request ID from the context if it exists.
19+
// and is not empty.
20+
func RequestIDFromContext(ctx context.Context) (string, bool) {
21+
v, ok := ctx.Value(requestIDContextKey{}).(string)
22+
return v, ok && v != ""
23+
}
24+
25+
// requestIDHeader is the header name used for propagating request IDs from
26+
// the attestation client to the attestation CA and back again.
27+
const requestIDHeader = "X-Request-Id"
28+
29+
// newRequestID generates a new random UUIDv4 request ID. If it fails,
30+
// the request ID will be the empty string.
31+
func newRequestID() string {
32+
requestID, err := randutil.UUIDv4()
33+
if err != nil {
34+
return ""
35+
}
36+
37+
return requestID
38+
}
39+
40+
// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's
41+
// empty, the context is searched for a request ID. If that's also empty, a new
42+
// request ID is generated.
43+
func enforceRequestID(r *http.Request) {
44+
if requestID := r.Header.Get(requestIDHeader); requestID == "" {
45+
if reqID, ok := RequestIDFromContext(r.Context()); ok {
46+
requestID = reqID
47+
} else {
48+
requestID = newRequestID()
49+
}
50+
r.Header.Set(requestIDHeader, requestID)
51+
}
52+
}

tpm/attestation/useragent.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package attestation
2+
3+
import "net/http"
4+
5+
// UserAgent is the value of the User-Agent HTTP header that will
6+
// be set in HTTP requests to the attestation CA.
7+
var UserAgent = "step-attestation-http-client/1.0"
8+
9+
// setUserAgent sets the User-Agent header in HTTP requests.
10+
func setUserAgent(r *http.Request) {
11+
r.Header.Set("User-Agent", UserAgent)
12+
}

0 commit comments

Comments
 (0)