Skip to content

Commit e780450

Browse files
authored
TEE Server Error code Translation (google#587)
* Attest now will pass in a default audience instead of passing an error * TestCustomToken now lets empty audience through * Moved audience check from tee server to agent * removed unused constant * tokenOpts is now passed as request * added in constant for audience * stashing testing comment * adding in changes from jessieqliu comments * changes for tee-server error * changed test TestRequestFailurePassedToCaller to return correct status message * removed comments and code from another PR * refactor based on review comments * added in fix for using empty map and additional comment and nit changes --------- Co-authored-by: Sibghat Shah <sibghat@google.com>
1 parent 9ebbadb commit e780450

File tree

2 files changed

+140
-12
lines changed

2 files changed

+140
-12
lines changed

launcher/teeserver/tee_server.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,21 @@ import (
1515
"github.com/google/go-tpm-tools/verifier"
1616
"github.com/google/go-tpm-tools/verifier/models"
1717
"github.com/google/go-tpm-tools/verifier/util"
18+
"google.golang.org/grpc/codes"
19+
"google.golang.org/grpc/status"
1820
)
1921

22+
var clientErrorCodes = map[codes.Code]struct{}{
23+
codes.InvalidArgument: {},
24+
codes.FailedPrecondition: {},
25+
codes.PermissionDenied: {},
26+
codes.Unauthenticated: {},
27+
codes.NotFound: {},
28+
codes.Aborted: {},
29+
codes.OutOfRange: {},
30+
codes.Canceled: {},
31+
}
32+
2033
// AttestClients contains clients for supported verifier services that can be used to
2134
// get attestation tokens.
2235
type AttestClients struct {
@@ -120,15 +133,13 @@ func (a *attestHandler) attest(w http.ResponseWriter, r *http.Request, client ve
120133
switch r.Method {
121134
case http.MethodGet:
122135
if err := a.attestAgent.Refresh(a.ctx); err != nil {
123-
errStr := fmt.Sprintf("failed to refresh attestation agent: %v", err)
124-
a.logAndWriteError(errStr, http.StatusInternalServerError, w)
136+
a.logAndWriteHTTPError(w, http.StatusInternalServerError, fmt.Errorf("failed to refresh attestation agent: %w", err))
125137
return
126138
}
127139

128140
token, err := a.attestAgent.AttestWithClient(a.ctx, agent.AttestAgentOpts{}, client)
129141
if err != nil {
130-
errStr := fmt.Sprintf("failed to retrieve attestation service token: %v", err)
131-
a.logAndWriteError(errStr, http.StatusInternalServerError, w)
142+
a.handleAttestError(w, err, "failed to retrieve attestation service token")
132143
return
133144
}
134145

@@ -165,7 +176,8 @@ func (a *attestHandler) attest(w http.ResponseWriter, r *http.Request, client ve
165176
TokenOptions: &tokenOptions,
166177
}, client)
167178
if err != nil {
168-
a.logAndWriteHTTPError(w, http.StatusBadRequest, err)
179+
180+
a.handleAttestError(w, err, "failed to retrieve custom attestation service token")
169181
return
170182
}
171183

@@ -203,3 +215,20 @@ func (s *TeeServer) Shutdown(ctx context.Context) error {
203215
}
204216
return nil
205217
}
218+
219+
func (a *attestHandler) handleAttestError(w http.ResponseWriter, err error, message string) {
220+
st, ok := status.FromError(err)
221+
if ok {
222+
if _, exists := clientErrorCodes[st.Code()]; exists {
223+
// User errors, like invalid arguments. Map user errors to 400 Bad Request.
224+
a.logAndWriteHTTPError(w, http.StatusBadRequest, fmt.Errorf("%s: %w", message, err))
225+
return
226+
}
227+
// Server-side or transient errors. Map user errors 500 Internal Server Error.
228+
a.logAndWriteHTTPError(w, http.StatusInternalServerError, fmt.Errorf("%s: %w", message, err))
229+
return
230+
}
231+
// If it's not a gRPC error, it's likely an internal error within the launcher.
232+
// Map user errors 500 Internal Server Error
233+
a.logAndWriteHTTPError(w, http.StatusInternalServerError, fmt.Errorf("%s: %w", message, err))
234+
}

launcher/teeserver/tee_server_test.go

Lines changed: 106 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
"github.com/google/go-tpm-tools/launcher/internal/logging"
1717
"github.com/google/go-tpm-tools/verifier"
1818
"github.com/google/go-tpm-tools/verifier/models"
19+
"google.golang.org/grpc/codes"
20+
"google.golang.org/grpc/status"
1921
)
2022

2123
// Implements verifier.Client interface so it can be used to initialize test attestHandlers
@@ -58,8 +60,6 @@ func (f fakeAttestationAgent) Close() error {
5860
func TestGetDefaultToken(t *testing.T) {
5961
testTokenContent := "test token"
6062

61-
// An empty attestHandler is fine for now as it is not being used
62-
// in the handler.
6363
ah := attestHandler{
6464
logger: logging.SimpleLogger(),
6565
clients: &AttestClients{
@@ -83,8 +83,40 @@ func TestGetDefaultToken(t *testing.T) {
8383
if w.Code != http.StatusOK {
8484
t.Errorf("got return code: %d, want: %d", w.Code, http.StatusOK)
8585
}
86-
if string(data) != testTokenContent {
87-
t.Errorf("got content: %v, want: %s", testTokenContent, string(data))
86+
if diff := cmp.Diff(testTokenContent, string(data)); diff != "" {
87+
t.Errorf("getToken() response body mismatch (-want +got):\n%s", diff)
88+
}
89+
}
90+
91+
func TestGetDefaultTokenServerError(t *testing.T) {
92+
// An empty attestHandler is fine for now as it is not being used
93+
// in the handler.
94+
ah := attestHandler{
95+
logger: logging.SimpleLogger(),
96+
clients: &AttestClients{
97+
GCA: &fakeVerifierClient{},
98+
},
99+
attestAgent: fakeAttestationAgent{
100+
attestWithClientFunc: func(context.Context, agent.AttestAgentOpts, verifier.Client) ([]byte, error) {
101+
return nil, errors.New("internal server error from agent")
102+
},
103+
}}
104+
105+
req := httptest.NewRequest(http.MethodGet, "/v1/token", nil)
106+
w := httptest.NewRecorder()
107+
108+
ah.getToken(w, req)
109+
data, err := io.ReadAll(w.Result().Body)
110+
if err != nil {
111+
t.Error(err)
112+
}
113+
114+
if w.Code != http.StatusInternalServerError {
115+
t.Errorf("got return code: %d, want: %d", w.Code, http.StatusInternalServerError)
116+
}
117+
expectedError := "failed to retrieve attestation service token: internal server error from agent"
118+
if diff := cmp.Diff(expectedError, string(data)); diff != "" {
119+
t.Errorf("getToken() response body mismatch (-want +got):\n%s", diff)
88120
}
89121
}
90122

@@ -118,7 +150,7 @@ func TestCustomToken(t *testing.T) {
118150
attestWithClientFunc: func(context.Context, agent.AttestAgentOpts, verifier.Client) ([]byte, error) {
119151
return nil, errors.New("Error")
120152
},
121-
want: http.StatusBadRequest,
153+
want: http.StatusInternalServerError,
122154
},
123155
{
124156
testName: "TestTokenTypeRequired",
@@ -167,8 +199,6 @@ func TestCustomToken(t *testing.T) {
167199
}
168200

169201
for i, test := range tests {
170-
// An empty attestHandler is fine for now as it is not being used
171-
// in the handler.
172202
ah := attestHandler{
173203
logger: logging.SimpleLogger(),
174204
clients: &AttestClients{
@@ -330,3 +360,72 @@ func TestCustomTokenDataParsedSuccessfully(t *testing.T) {
330360
}
331361
}
332362
}
363+
364+
func TestCustomHandleAttestError(t *testing.T) {
365+
body := `{
366+
"audience": "audience",
367+
"nonces": ["thisIsAcustomNonce"],
368+
"token_type": "OIDC"
369+
}`
370+
371+
testcases := []struct {
372+
name string
373+
err error
374+
wantStatusCode int
375+
}{
376+
{
377+
name: "FailedPrecondition error",
378+
err: status.New(codes.FailedPrecondition, "bad state").Err(),
379+
wantStatusCode: http.StatusBadRequest,
380+
},
381+
{
382+
name: "PermissionDenied error",
383+
err: status.New(codes.PermissionDenied, "denied").Err(),
384+
wantStatusCode: http.StatusBadRequest,
385+
},
386+
{
387+
name: "Internal error",
388+
err: status.New(codes.Internal, "internal server error").Err(),
389+
wantStatusCode: http.StatusInternalServerError,
390+
},
391+
{
392+
name: "Unavailable error",
393+
err: status.New(codes.Unavailable, "service unavailable").Err(),
394+
wantStatusCode: http.StatusInternalServerError,
395+
},
396+
{
397+
name: "non-gRPC error",
398+
err: errors.New("a generic error"),
399+
wantStatusCode: http.StatusInternalServerError,
400+
},
401+
}
402+
for _, tc := range testcases {
403+
t.Run(tc.name, func(t *testing.T) {
404+
ah := attestHandler{
405+
logger: logging.SimpleLogger(),
406+
clients: &AttestClients{
407+
GCA: &fakeVerifierClient{},
408+
},
409+
attestAgent: fakeAttestationAgent{
410+
attestWithClientFunc: func(context.Context, agent.AttestAgentOpts, verifier.Client) ([]byte, error) {
411+
return nil, tc.err
412+
},
413+
},
414+
}
415+
416+
req := httptest.NewRequest(http.MethodPost, "/v1/token", strings.NewReader(body))
417+
w := httptest.NewRecorder()
418+
419+
ah.getToken(w, req)
420+
421+
if w.Code != tc.wantStatusCode {
422+
t.Errorf("got status code %d, want %d", w.Code, tc.wantStatusCode)
423+
}
424+
425+
_, err := io.ReadAll(w.Result().Body)
426+
if err != nil {
427+
t.Errorf("failed to read response body: %v", err)
428+
}
429+
})
430+
}
431+
}

0 commit comments

Comments
 (0)