diff --git a/launcher/agent/agent.go b/launcher/agent/agent.go index 7ed33705d..4f44748b9 100644 --- a/launcher/agent/agent.go +++ b/launcher/agent/agent.go @@ -57,7 +57,7 @@ type AttestationAgent interface { MeasureEvent(gecel.Content) error Attest(context.Context, AttestAgentOpts) ([]byte, error) AttestWithClient(ctx context.Context, opts AttestAgentOpts, client verifier.Client) ([]byte, error) - AttestationEvidence(ctx context.Context, challenge []byte, extraData []byte) (*attestationpb.VmAttestation, error) + AttestationEvidence(ctx context.Context, challenge []byte, extraData []byte, opts AttestAgentOpts) (*attestationpb.VmAttestation, error) Refresh(context.Context) error Close() error } @@ -74,6 +74,8 @@ type attestRoot interface { ComputeNonce(challenge []byte, extraData []byte) []byte // AddDeviceROTs adds detected device RoTs(root of trust). AddDeviceROTs([]DeviceROT) + // AttestDeviceROTs fetches a list of runtime device attestation report. + AttestDeviceROTs(nonce []byte) ([]any, error) } // DeviceROT defines an interface for all attached devices to collect attestation. @@ -86,6 +88,12 @@ type DeviceROT interface { // VerifyAttestation API type AttestAgentOpts struct { TokenOptions *models.TokenOptions + *DeviceReportOpts +} + +// DeviceReportOpts contains options for runtime device attestations. +type DeviceReportOpts struct { + EnableRuntimeGPUAttestation bool } type agent struct { @@ -297,7 +305,7 @@ func (a *agent) AttestWithClient(ctx context.Context, opts AttestAgentOpts, clie } // AttestationEvidence returns the attestation evidence (TPM or TDX). -func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraData []byte) (*attestationpb.VmAttestation, error) { +func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraData []byte, opts AttestAgentOpts) (*attestationpb.VmAttestation, error) { if !a.launchSpec.Experiments.EnableAttestationEvidence { return nil, fmt.Errorf("attestation evidence is disabled") } @@ -313,12 +321,10 @@ func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraDa if err != nil { return nil, fmt.Errorf("failed to attest: %v", err) } - var cosCel bytes.Buffer if err := a.avRot.GetCEL().EncodeCEL(&cosCel); err != nil { return nil, err } - attestation := &attestationpb.VmAttestation{ Label: []byte(labels.WorkloadAttestation), Challenge: challenge, @@ -343,9 +349,41 @@ func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraDa default: return nil, fmt.Errorf("unknown attestation type: %T", v) } + + deviceReports, err := a.attestDeviceROTs(finalNonce, opts) + if err != nil { + return nil, err + } + attestation.DeviceReports = deviceReports + return attestation, nil } +func (a *agent) attestDeviceROTs(nonce []byte, opts AttestAgentOpts) ([]*attestationpb.DeviceAttestationReport, error) { + if opts.DeviceReportOpts == nil { + return nil, nil + } + deviceROTs, err := a.avRot.AttestDeviceROTs(nonce) + if err != nil { + return nil, err + } + + var deviceReports []*attestationpb.DeviceAttestationReport + for _, dr := range deviceROTs { + switch v := dr.(type) { + case *attestationpb.NvidiaAttestationReport: + if opts.DeviceReportOpts.EnableRuntimeGPUAttestation { + deviceReports = append(deviceReports, &attestationpb.DeviceAttestationReport{ + Report: &attestationpb.DeviceAttestationReport_NvidiaReport{ + NvidiaReport: v, + }, + }) + } + } + } + return deviceReports, nil +} + func (a *agent) verify(ctx context.Context, req verifier.VerifyAttestationRequest, client verifier.Client) (*verifier.VerifyAttestationResponse, error) { if a.launchSpec.Experiments.EnableVerifyCS { return client.VerifyConfidentialSpace(ctx, req) @@ -423,6 +461,13 @@ func (t *tpmAttestRoot) AddDeviceROTs(deviceROTs []DeviceROT) { t.deviceROTs = append(t.deviceROTs, deviceROTs...) } +func (t *tpmAttestRoot) AttestDeviceROTs(nonce []byte) ([]any, error) { + t.tpmMu.Lock() + defer t.tpmMu.Unlock() + + return doAttestDeviceROTs(t.deviceROTs, nonce) +} + type tdxAttestRoot struct { tdxMu sync.Mutex qp *tg.LinuxConfigFsQuoteProvider @@ -461,28 +506,20 @@ func (t *tdxAttestRoot) Attest(nonce []byte) (any, error) { return nil, err } - var nvAtt *attestationpb.NvidiaAttestationReport - for _, deviceRoT := range t.deviceROTs { - att, err := deviceRoT.Attest(nonce) - if err != nil { - return nil, err - } - switch v := att.(type) { - case *attestationpb.NvidiaAttestationReport: - nvAtt = v - default: - return nil, fmt.Errorf("unknown device attestation type: %T", v) - } - } - return &verifier.TDCCELAttestation{ - CcelAcpiTable: ccelTable, - CcelData: ccelData, - TdQuote: rawQuote, - NvidiaAttestation: nvAtt, + CcelAcpiTable: ccelTable, + CcelData: ccelData, + TdQuote: rawQuote, }, nil } +func (t *tdxAttestRoot) AttestDeviceROTs(nonce []byte) ([]any, error) { + t.tdxMu.Lock() + defer t.tdxMu.Unlock() + + return doAttestDeviceROTs(t.deviceROTs, nonce) +} + func (t *tdxAttestRoot) ComputeNonce(challenge []byte, extraData []byte) []byte { challengeData := challenge if extraData != nil { @@ -596,3 +633,15 @@ func convertToTPMQuote(v *pb.Attestation) *attestationpb.TpmQuote { }, } } + +func doAttestDeviceROTs(deviceROTs []DeviceROT, nonce []byte) ([]any, error) { + var deviceReports []any + for _, deviceROT := range deviceROTs { + deviceReport, err := deviceROT.Attest(nonce) + if err != nil { + return nil, err + } + deviceReports = append(deviceReports, deviceReport) + } + return deviceReports, nil +} diff --git a/launcher/agent/agent_test.go b/launcher/agent/agent_test.go index 0b94eca77..2b2b0fed6 100644 --- a/launcher/agent/agent_test.go +++ b/launcher/agent/agent_test.go @@ -41,6 +41,7 @@ import ( "github.com/google/go-tpm-tools/verifier/oci" "github.com/google/go-tpm-tools/verifier/oci/cosign" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) @@ -658,7 +659,7 @@ type fakeTdxAttestRoot struct { cel gecel.CEL receivedNonce []byte tdxQuote []byte - deviceRoTS []DeviceROT + deviceROTs []DeviceROT } func (f *fakeTdxAttestRoot) Extend(c gecel.Content) error { @@ -673,23 +674,8 @@ func (f *fakeTdxAttestRoot) GetCEL() gecel.CEL { func (f *fakeTdxAttestRoot) Attest(nonce []byte) (any, error) { f.receivedNonce = nonce - var nvAtt *attestationpb.NvidiaAttestationReport - for _, deviceRoT := range f.deviceRoTS { - att, err := deviceRoT.Attest(nonce) - if err != nil { - return nil, err - } - switch v := att.(type) { - case *attestationpb.NvidiaAttestationReport: - nvAtt = v - default: - return nil, fmt.Errorf("unknown device attestation type: %T", v) - } - } - return &verifier.TDCCELAttestation{ - TdQuote: f.tdxQuote, - NvidiaAttestation: nvAtt, + TdQuote: f.tdxQuote, }, nil } @@ -704,11 +690,28 @@ func (f *fakeTdxAttestRoot) ComputeNonce(challenge []byte, extraData []byte) []b return finalNonce[:] } +func (f *fakeTdxAttestRoot) AttestDeviceROTs(nonce []byte) ([]any, error) { + var deviceReports []any + for _, deviceRoT := range f.deviceROTs { + att, err := deviceRoT.Attest(nonce) + if err != nil { + return nil, err + } + switch v := att.(type) { + case *attestationpb.NvidiaAttestationReport: + deviceReports = append(deviceReports, v) + default: + return nil, fmt.Errorf("unknown device attestation type: %T", v) + } + } + return deviceReports, nil +} + //go:embed testdata/cel.b64 var celB64 string func (f *fakeTdxAttestRoot) AddDeviceROTs(deviceRoTS []DeviceROT) { - f.deviceRoTS = append(f.deviceRoTS, deviceRoTS...) + f.deviceROTs = append(f.deviceROTs, deviceRoTS...) } type fakeGPURoT struct{} @@ -725,7 +728,7 @@ func (f *fakeGPURoT) Attest(nonce []byte) (any, error) { }, }, nil } -func TestTdxAttestRoot(t *testing.T) { +func TestTDXAttestDeviceROTs(t *testing.T) { testCases := []struct { name string tdxAttestRoot *fakeTdxAttestRoot @@ -742,7 +745,7 @@ func TestTdxAttestRoot(t *testing.T) { { name: "success tdxAttestRoot w/ GPU device", tdxAttestRoot: &fakeTdxAttestRoot{ - deviceRoTS: []DeviceROT{&fakeGPURoT{}}, + deviceROTs: []DeviceROT{&fakeGPURoT{}}, }, nonce: []byte("test-nonce"), wantGPU: true, @@ -751,7 +754,7 @@ func TestTdxAttestRoot(t *testing.T) { { name: "failed tdxAttestRoot w/ GPU device", tdxAttestRoot: &fakeTdxAttestRoot{ - deviceRoTS: []DeviceROT{&fakeGPURoT{}}, + deviceROTs: []DeviceROT{&fakeGPURoT{}}, }, nonce: []byte(""), wantPass: false, @@ -760,13 +763,16 @@ func TestTdxAttestRoot(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - attestation, err := tc.tdxAttestRoot.Attest(tc.nonce) - if gotPass := (err == nil); gotPass != tc.wantPass { - t.Errorf("tdxAttestRoot.Attest() did not return expected attestation result, got %v, want %v", gotPass, tc.wantPass) + deviceReports, err := tc.tdxAttestRoot.AttestDeviceROTs(tc.nonce) + if gotPass := err == nil; gotPass != tc.wantPass { + t.Errorf("tdxAttestRoot.AttestDeviceROTs() did not return expected attestation result, got %v, want %v", gotPass, tc.wantPass) } - if tc.wantPass && tc.wantGPU { - if att := attestation.(*verifier.TDCCELAttestation); att.NvidiaAttestation == nil { - t.Error("tdxAttestRoot.Attest() did not return expected GPU attestation, want GPU attestation, but got nil") + if tc.wantGPU { + if len(deviceReports) == 0 { + t.Fatalf("tdxAttestRoot.AttestDeviceROTs() didn't return any device reports") + } + if att := deviceReports[0].(*attestationpb.NvidiaAttestationReport); att == nil { + t.Errorf("tdxAttestRoot.AttestDeviceROTs() didn't return expected device report type, want %v, but got nil", &attestationpb.NvidiaAttestationReport{}) } } }) @@ -811,6 +817,7 @@ func TestAttestationEvidence_TDX_Success(t *testing.T) { }, }, } + attestAgent.avRot.AddDeviceROTs([]DeviceROT{&fakeGPURoT{}}) if err := measureFakeEvents(attestAgent); err != nil { t.Fatalf("failed to measure events: %v", err) @@ -818,35 +825,76 @@ func TestAttestationEvidence_TDX_Success(t *testing.T) { challenge := []byte("test-challenge") extraData := []byte("test-extra-data") - att, err := attestAgent.AttestationEvidence(ctx, challenge, extraData) - if err != nil { - t.Fatalf("AttestationEvidence failed: %v", err) - } - // Verify the nonce passed to Attest was derived from challenge+extraData. - expectedNonce := fakeRoot.ComputeNonce(challenge, extraData) - if !bytes.Equal(fakeRoot.receivedNonce, expectedNonce) { - t.Errorf("got nonce %x, want %x", fakeRoot.receivedNonce, expectedNonce) + testCases := []struct { + name string + opts AttestAgentOpts + wantGPUReport *attestationpb.NvidiaAttestationReport + }{ + { + name: "TDX attestation", + opts: AttestAgentOpts{}, + }, + { + name: "TDX attestation + runtime GPU attestation", + opts: AttestAgentOpts{ + DeviceReportOpts: &DeviceReportOpts{ + EnableRuntimeGPUAttestation: true, + }, + }, + wantGPUReport: &attestationpb.NvidiaAttestationReport{ + CcFeature: &attestationpb.NvidiaAttestationReport_Spt{ + Spt: &attestationpb.NvidiaAttestationReport_SinglePassthroughAttestation{ + GpuQuote: &attestationpb.GpuInfo{Uuid: "fake-gpu-uuid"}, + }, + }, + }, + }, } - if att.GetQuote().GetTdxCcelQuote() == nil { - t.Fatal("expected TDCCELAttestation to be populated for TDX") - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + att, err := attestAgent.AttestationEvidence(ctx, challenge, extraData, tc.opts) + if err != nil { + t.Fatalf("AttestationEvidence failed: %v", err) + } - if !bytes.Equal(att.GetQuote().GetTdxCcelQuote().GetTdQuote(), testTDXQuote) { - t.Errorf("TDQuote mismatch: got %x, want %x", att.GetQuote().GetTdxCcelQuote().GetTdQuote(), testTDXQuote) - } - if att.GetQuote().GetTpmQuote() != nil { - t.Error("expected TPMQuote to be nil for TDX attestation") - } - if !bytes.Equal(att.GetQuote().GetTdxCcelQuote().GetCelLaunchEventLog(), testCEL) { - t.Errorf("CELLaunchEventLog mismatch: got %x, want %x", att.GetQuote().GetTdxCcelQuote().GetCelLaunchEventLog(), testCEL) - } - if !bytes.Equal(att.Challenge, challenge) { - t.Errorf("challenge mismatch: got %x, want %x", att.Challenge, challenge) - } - if !bytes.Equal(att.ExtraData, extraData) { - t.Errorf("extraData mismatch: got %x, want %x", att.ExtraData, extraData) + // Verify the nonce passed to Attest was derived from challenge+extraData. + expectedNonce := fakeRoot.ComputeNonce(challenge, extraData) + if !bytes.Equal(fakeRoot.receivedNonce, expectedNonce) { + t.Errorf("got nonce %x, want %x", fakeRoot.receivedNonce, expectedNonce) + } + + if att.GetQuote().GetTdxCcelQuote() == nil { + t.Fatal("expected TDCCELAttestation to be populated for TDX") + } + + if !bytes.Equal(att.GetQuote().GetTdxCcelQuote().GetTdQuote(), testTDXQuote) { + t.Errorf("TDQuote mismatch: got %x, want %x", att.GetQuote().GetTdxCcelQuote().GetTdQuote(), testTDXQuote) + } + if att.GetQuote().GetTpmQuote() != nil { + t.Error("expected TPMQuote to be nil for TDX attestation") + } + if !bytes.Equal(att.GetQuote().GetTdxCcelQuote().GetCelLaunchEventLog(), testCEL) { + t.Errorf("CELLaunchEventLog mismatch: got %x, want %x", att.GetQuote().GetTdxCcelQuote().GetCelLaunchEventLog(), testCEL) + } + if !bytes.Equal(att.Challenge, challenge) { + t.Errorf("challenge mismatch: got %x, want %x", att.Challenge, challenge) + } + if !bytes.Equal(att.ExtraData, extraData) { + t.Errorf("extraData mismatch: got %x, want %x", att.ExtraData, extraData) + } + if tc.wantGPUReport != nil { + if len(att.DeviceReports) == 0 { + t.Fatalf("Failed to get runtime GPU attestation") + } + + gotDeviceReport := att.DeviceReports[0] + if gotGPUReport, wantGPUReport := gotDeviceReport.GetNvidiaReport(), tc.wantGPUReport; !proto.Equal(gotGPUReport, wantGPUReport) { + t.Errorf("runtime GPU attestation mismatch: got %v, want %v", gotGPUReport, wantGPUReport) + } + } + }) } } @@ -883,7 +931,7 @@ func TestAttestationEvidence_TPM_Success(t *testing.T) { challenge := []byte("test-challenge") extraData := []byte("test-extra-data") - att, err := agent.AttestationEvidence(ctx, challenge, extraData) + att, err := agent.AttestationEvidence(ctx, challenge, extraData, AttestAgentOpts{}) if err != nil { t.Fatalf("AttestationEvidence failed on TPM: %v", err) } @@ -1067,7 +1115,7 @@ func TestAttestationEvidence_ExperimentDisabled(t *testing.T) { } defer agent.Close() - _, err = agent.AttestationEvidence(ctx, []byte("challenge"), nil) + _, err = agent.AttestationEvidence(ctx, []byte("challenge"), nil, AttestAgentOpts{}) if err == nil { t.Error("expected error when EnableAttestationEvidence is disabled, got nil") } @@ -1101,7 +1149,7 @@ func TestAttestationEvidence_TDX_NilExtraData(t *testing.T) { } challenge := []byte("test-challenge") - att, err := attestatAgent.AttestationEvidence(ctx, challenge, nil) + att, err := attestatAgent.AttestationEvidence(ctx, challenge, nil, AttestAgentOpts{}) if err != nil { t.Fatalf("AttestationEvidence failed with nil extraData: %v", err) } diff --git a/launcher/container_runner_test.go b/launcher/container_runner_test.go index 871cd504d..98c8563c9 100644 --- a/launcher/container_runner_test.go +++ b/launcher/container_runner_test.go @@ -74,7 +74,7 @@ func (f *fakeAttestationAgent) AttestWithClient(_ context.Context, _ agent.Attes return nil, fmt.Errorf("unimplemented") } -func (f *fakeAttestationAgent) AttestationEvidence(_ context.Context, _ []byte, _ []byte) (*attestationpb.VmAttestation, error) { +func (f *fakeAttestationAgent) AttestationEvidence(_ context.Context, _ []byte, _ []byte, _ agent.AttestAgentOpts) (*attestationpb.VmAttestation, error) { return nil, fmt.Errorf("unimplemented") } diff --git a/launcher/teeserver/tee_server.go b/launcher/teeserver/tee_server.go index cae8e69a6..3fb822b4d 100644 --- a/launcher/teeserver/tee_server.go +++ b/launcher/teeserver/tee_server.go @@ -8,9 +8,9 @@ import ( "fmt" "net" "net/http" + "strings" attestationpb "github.com/GoogleCloudPlatform/confidential-space/server/proto/gen/attestation" - "github.com/containerd/containerd/protobuf/proto" keymanager "github.com/google/go-tpm-tools/keymanager/km_common/proto" wsd "github.com/google/go-tpm-tools/keymanager/workload_service" "github.com/google/go-tpm-tools/launcher/agent" @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" ) const ( @@ -150,7 +151,12 @@ func (a *attestHandler) getITAToken(w http.ResponseWriter, r *http.Request) { } // getAttestationEvidence retrieves the attestation evidence. +// It returns partial response with query parameter support. +// It currently supports "label", "challenge", "quote", "extraData", and "deviceReports" params. +// The default response with no query parameter will return all fields except device reports. +// If the fields param is "*", it will return all fields including device reports. func (a *attestHandler) getAttestationEvidence(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { a.logAndWriteHTTPError(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed")) return @@ -169,13 +175,25 @@ func (a *attestHandler) getAttestationEvidence(w http.ResponseWriter, r *http.Re return } - evidence, err := a.attestAgent.AttestationEvidence(a.ctx, req.Challenge, nil) + fields := r.URL.Query().Get("fields") + attestOpts := agent.AttestAgentOpts{ + DeviceReportOpts: &agent.DeviceReportOpts{ + EnableRuntimeGPUAttestation: fields == "*" || strings.Contains(fields, "deviceReports"), + }, + } + evidence, err := a.attestAgent.AttestationEvidence(a.ctx, req.Challenge, nil, attestOpts) if err != nil { a.logAndWriteHTTPError(w, http.StatusInternalServerError, err) return } - evidenceBytes, err := protojson.Marshal(evidence) + partialEvidence, err := filterVMAttestationFields(evidence, fields) + if err != nil { + a.logAndWriteHTTPError(w, http.StatusBadRequest, fmt.Errorf("invalid fields parameter: %v", err)) + return + } + + evidenceBytes, err := protojson.Marshal(partialEvidence) if err != nil { a.logAndWriteHTTPError(w, http.StatusInternalServerError, fmt.Errorf("failed to marshal evidence: %v", err)) return @@ -185,6 +203,36 @@ func (a *attestHandler) getAttestationEvidence(w http.ResponseWriter, r *http.Re w.Write(evidenceBytes) } +// filterVMAttestationFields return a partial VM Attestation based on the query parameters. +func filterVMAttestationFields(att *attestationpb.VmAttestation, fields string) (*attestationpb.VmAttestation, error) { + if fields == "" || fields == "*" { + return att, nil + } + fieldSlice := strings.Split(fields, ",") + fieldMap := make(map[string]bool) + for _, f := range fieldSlice { + fieldMap[strings.TrimSpace(f)] = true + } + + out := &attestationpb.VmAttestation{} + if fieldMap["label"] { + out.Label = att.GetLabel() + } + if fieldMap["challenge"] { + out.Challenge = att.GetChallenge() + } + if fieldMap["extraData"] { + out.ExtraData = att.GetExtraData() + } + if fieldMap["quote"] { + out.Quote = att.GetQuote() + } + if fieldMap["deviceReports"] { + out.DeviceReports = att.GetDeviceReports() + } + return out, nil +} + // getKeyEndorsement retrieves the attestation evidence with KEM and binding key claims. func (a *attestHandler) getKeyEndorsement(w http.ResponseWriter, r *http.Request) { if !a.launchSpec.Experiments.EnableKeyManager { @@ -243,13 +291,13 @@ func (a *attestHandler) getKeyEndorsement(w http.ResponseWriter, r *http.Request return } - kemEvidence, err := a.attestAgent.AttestationEvidence(a.ctx, req.Challenge, kemBytes) + kemEvidence, err := a.attestAgent.AttestationEvidence(a.ctx, req.Challenge, kemBytes, agent.AttestAgentOpts{}) if err != nil { a.logAndWriteHTTPError(w, http.StatusInternalServerError, fmt.Errorf("failed to collect attestation evidence with kem key claims")) return } - bindingEvidence, err := a.attestAgent.AttestationEvidence(a.ctx, req.Challenge, bindingBytes) + bindingEvidence, err := a.attestAgent.AttestationEvidence(a.ctx, req.Challenge, bindingBytes, agent.AttestAgentOpts{}) if err != nil { a.logAndWriteHTTPError(w, http.StatusInternalServerError, fmt.Errorf("failed to collect attestation evidence with binding key claims")) return diff --git a/launcher/teeserver/tee_server_test.go b/launcher/teeserver/tee_server_test.go index 95b40707d..6f7712393 100644 --- a/launcher/teeserver/tee_server_test.go +++ b/launcher/teeserver/tee_server_test.go @@ -24,8 +24,10 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" attestationpb "github.com/GoogleCloudPlatform/confidential-space/server/proto/gen/attestation" + "google.golang.org/protobuf/testing/protocmp" ) // Implements verifier.Client interface so it can be used to initialize test attestHandlers @@ -58,8 +60,19 @@ func (f fakeAttestationAgent) AttestWithClient(c context.Context, a agent.Attest return f.attestWithClientFunc(c, a, v) } -func (f fakeAttestationAgent) AttestationEvidence(c context.Context, nonce []byte, extraData []byte) (*attestationpb.VmAttestation, error) { - return f.attestationEvidenceFunc(c, nonce, extraData) +func (f fakeAttestationAgent) AttestationEvidence(c context.Context, nonce []byte, extraData []byte, opts agent.AttestAgentOpts) (*attestationpb.VmAttestation, error) { + attestation, err := f.attestationEvidenceFunc(c, nonce, extraData) + if err != nil { + return nil, err + } + if opts.DeviceReportOpts != nil && opts.DeviceReportOpts.EnableRuntimeGPUAttestation { + attestation.DeviceReports = append(attestation.DeviceReports, &attestationpb.DeviceAttestationReport{ + Report: &attestationpb.DeviceAttestationReport_NvidiaReport{ + NvidiaReport: &attestationpb.NvidiaAttestationReport{}, + }, + }) + } + return attestation, nil } func (f fakeAttestationAgent) MeasureEvent(c gecel.Content) error { @@ -604,22 +617,264 @@ func TestCustomHandleAttestError(t *testing.T) { } func TestAttestationEvidence(t *testing.T) { - ah := attestHandler{ - logger: logging.SimpleLogger(), - attestAgent: fakeAttestationAgent{ + testAttestation := &attestationpb.VmAttestation{ + Label: []byte("test-label"), + Challenge: []byte("test-challenge"), + ExtraData: []byte("test-extra-data"), + Quote: &attestationpb.VmAttestationQuote{ + Quote: &attestationpb.VmAttestationQuote_TdxCcelQuote{ + TdxCcelQuote: &attestationpb.TdxCcelQuote{}, + }, + }, + } + + testCases := []struct { + name string + method string + url string + body string + attestationEvidenceFunc func(context.Context, []byte, []byte) (*attestationpb.VmAttestation, error) + wantStatusCode int + wantBodyContains string + }{ + { + name: "success no fields", + method: http.MethodPost, + url: "/v1/evidence", + body: `{"challenge": "dGVzdA=="}`, + wantStatusCode: http.StatusOK, + attestationEvidenceFunc: func(_ context.Context, _ []byte, _ []byte) (*attestationpb.VmAttestation, error) { + return testAttestation, nil + }, + wantBodyContains: `{"label":"dGVzdC1sYWJlbA==","challenge":"dGVzdC1jaGFsbGVuZ2U=","extraData":"dGVzdC1leHRyYS1kYXRh","quote":{"tdxCcelQuote":{}}}`, + }, + { + name: "success with * fields", + method: http.MethodPost, + url: "/v1/evidence?fields=*", + body: `{"challenge": "dGVzdA=="}`, + wantStatusCode: http.StatusOK, + attestationEvidenceFunc: func(_ context.Context, _ []byte, _ []byte) (*attestationpb.VmAttestation, error) { + return testAttestation, nil + }, + wantBodyContains: `{"label":"dGVzdC1sYWJlbA==","challenge":"dGVzdC1jaGFsbGVuZ2U=","extraData":"dGVzdC1leHRyYS1kYXRh","quote":{"tdxCcelQuote":{}},"deviceReports":[{"nvidiaReport":{}}]}`, + }, + { + name: "success with fields", + method: http.MethodPost, + url: "/v1/evidence?fields=label,quote", + body: `{"challenge": "dGVzdA=="}`, + wantStatusCode: http.StatusOK, + attestationEvidenceFunc: func(_ context.Context, _ []byte, _ []byte) (*attestationpb.VmAttestation, error) { + return testAttestation, nil + }, + wantBodyContains: `{"label":"dGVzdC1sYWJlbA==","quote":{"tdxCcelQuote":{}}}`, + }, + { + name: "wrong method", + method: http.MethodGet, + url: "/v1/evidence", + body: "", + wantStatusCode: http.StatusMethodNotAllowed, + wantBodyContains: "method not allowed", + }, + { + name: "malformed json", + method: http.MethodPost, + url: "/v1/evidence", + body: `{"challenge": "dGVzdA=="`, + wantStatusCode: http.StatusBadRequest, + wantBodyContains: "failed to decode request", + }, + { + name: "missing challenge", + method: http.MethodPost, + url: "/v1/evidence", + body: `{}`, + wantStatusCode: http.StatusBadRequest, + wantBodyContains: "challenge is required", + }, + { + name: "attestation agent error", + method: http.MethodPost, + url: "/v1/evidence", + body: `{"challenge": "dGVzdA=="}`, + wantStatusCode: http.StatusInternalServerError, attestationEvidenceFunc: func(_ context.Context, _ []byte, _ []byte) (*attestationpb.VmAttestation, error) { - return &attestationpb.VmAttestation{}, nil + return nil, errors.New("agent error") }, + wantBodyContains: "agent error", }, } - req := httptest.NewRequest(http.MethodPost, "/v1/evidence", strings.NewReader("{\"challenge\": \"dGVzdA==\"}")) - w := httptest.NewRecorder() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + attestationFunc := tc.attestationEvidenceFunc + if attestationFunc == nil { + attestationFunc = func(_ context.Context, _ []byte, _ []byte) (*attestationpb.VmAttestation, error) { + return &attestationpb.VmAttestation{}, nil + } + } + ah := attestHandler{ + logger: logging.SimpleLogger(), + attestAgent: fakeAttestationAgent{ + attestationEvidenceFunc: attestationFunc, + }, + } - ah.getAttestationEvidence(w, req) + req := httptest.NewRequest(tc.method, tc.url, strings.NewReader(tc.body)) + w := httptest.NewRecorder() - if w.Code != http.StatusOK { - t.Errorf("got return code: %d, want: %d", w.Code, http.StatusOK) + ah.getAttestationEvidence(w, req) + + if w.Code != tc.wantStatusCode { + t.Errorf("getAttestationEvidence() got status code %d, want %d", w.Code, tc.wantStatusCode) + } + + if tc.wantStatusCode == http.StatusOK { + var gotEvidence attestationpb.VmAttestation + if err := protojson.Unmarshal(w.Body.Bytes(), &gotEvidence); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + var wantEvidence attestationpb.VmAttestation + if err := protojson.Unmarshal([]byte(tc.wantBodyContains), &wantEvidence); err != nil { + t.Fatalf("failed to unmarshal wantBodyContains: %v", err) + } + if diff := cmp.Diff(&wantEvidence, &gotEvidence, protocmp.Transform()); diff != "" { + t.Errorf("getAttestationEvidence() response body mismatch (-want +got):\n%s", diff) + } + } else { + respBody, _ := io.ReadAll(w.Body) + if !strings.Contains(string(respBody), tc.wantBodyContains) { + t.Errorf("getAttestationEvidence() response body = %q, want to contain %q", string(respBody), tc.wantBodyContains) + } + } + }) + } +} + +func TestFilterVMAttestationFields(t *testing.T) { + fullAttestation := &attestationpb.VmAttestation{ + Label: []byte("test-label"), + Challenge: []byte("test-challenge"), + ExtraData: []byte("test-extra-data"), + Quote: &attestationpb.VmAttestationQuote{ + Quote: &attestationpb.VmAttestationQuote_TpmQuote{ + TpmQuote: &attestationpb.TpmQuote{}, + }, + }, + DeviceReports: []*attestationpb.DeviceAttestationReport{ + { + Report: &attestationpb.DeviceAttestationReport_NvidiaReport{ + NvidiaReport: &attestationpb.NvidiaAttestationReport{}, + }, + }, + }, + } + + testCases := []struct { + name string + fields string + mutate func(att *attestationpb.VmAttestation) + want *attestationpb.VmAttestation + }{ + { + name: "no fields", + fields: "", + mutate: func(att *attestationpb.VmAttestation) { + att.DeviceReports = nil + }, + want: &attestationpb.VmAttestation{ + Label: fullAttestation.Label, + Challenge: fullAttestation.Challenge, + ExtraData: fullAttestation.ExtraData, + Quote: fullAttestation.Quote, + }, + }, + { + name: "single field label", + fields: "label", + want: &attestationpb.VmAttestation{ + Label: fullAttestation.Label, + }, + }, + { + name: "single field challenge", + fields: "challenge", + want: &attestationpb.VmAttestation{ + Challenge: fullAttestation.Challenge, + }, + }, + { + name: "single field extraData", + fields: "extraData", + want: &attestationpb.VmAttestation{ + ExtraData: fullAttestation.ExtraData, + }, + }, + { + name: "single field quote", + fields: "quote", + want: &attestationpb.VmAttestation{ + Quote: fullAttestation.Quote, + }, + }, + { + name: "single field deviceReports", + fields: "deviceReports", + want: &attestationpb.VmAttestation{ + DeviceReports: fullAttestation.DeviceReports, + }, + }, + { + name: "multiple fields", + fields: "label,quote", + want: &attestationpb.VmAttestation{ + Label: fullAttestation.Label, + Quote: fullAttestation.Quote, + }, + }, + { + name: "all fields", + fields: "label,challenge,extraData,quote,deviceReports", + want: fullAttestation, + }, + { + name: "fields with whitespace", + fields: " label , deviceReports ", + want: &attestationpb.VmAttestation{ + Label: fullAttestation.Label, + DeviceReports: fullAttestation.DeviceReports, + }, + }, + { + name: "all fields with *", + fields: "*", + want: fullAttestation, + }, + { + name: "unknown fields are ignored", + fields: "label,foo,bar", + want: &attestationpb.VmAttestation{ + Label: fullAttestation.Label, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + attestation := proto.Clone(fullAttestation).(*attestationpb.VmAttestation) + if tc.mutate != nil { + tc.mutate(attestation) + } + got, err := filterVMAttestationFields(attestation, tc.fields) + if err != nil { + t.Fatalf("filterVMAttestationFields() returned an unexpected error: %v", err) + } + if diff := cmp.Diff(tc.want, got, protocmp.Transform()); diff != "" { + t.Errorf("filterVMAttestationFields() returned diff (-want +got):\n%s", diff) + } + }) } } diff --git a/server/eventlog.go b/server/eventlog.go index caf93fa5a..72a456600 100644 --- a/server/eventlog.go +++ b/server/eventlog.go @@ -7,7 +7,6 @@ import ( attestationpb "github.com/GoogleCloudPlatform/confidential-space/server/proto/gen/attestation" - "github.com/containerd/containerd/protobuf/proto" gecel "github.com/google/go-eventlog/cel" "github.com/google/go-eventlog/extract" gepb "github.com/google/go-eventlog/proto/state" @@ -16,6 +15,7 @@ import ( "github.com/google/go-tpm-tools/cel" pb "github.com/google/go-tpm-tools/proto/attest" tpmpb "github.com/google/go-tpm-tools/proto/tpm" + "google.golang.org/protobuf/proto" ) // parsePCClientEventLog parses a raw event log and replays the parsed event diff --git a/server/eventlog_test.go b/server/eventlog_test.go index 5d0631762..346310368 100644 --- a/server/eventlog_test.go +++ b/server/eventlog_test.go @@ -10,7 +10,6 @@ import ( attestationpb "github.com/GoogleCloudPlatform/confidential-space/server/proto/gen/attestation" - "github.com/containerd/containerd/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-configfs-tsm/configfs/fakertmr" configfstsmrtmr "github.com/google/go-configfs-tsm/rtmr" @@ -25,6 +24,7 @@ import ( pb "github.com/google/go-tpm-tools/proto/tpm" "github.com/google/go-tpm/legacy/tpm2" "github.com/google/go-tpm/tpmutil" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) diff --git a/verifier/client.go b/verifier/client.go index 3a5b716eb..da9a43b07 100644 --- a/verifier/client.go +++ b/verifier/client.go @@ -5,7 +5,6 @@ package verifier import ( "context" - csattestpb "github.com/GoogleCloudPlatform/confidential-space/server/proto/gen/attestation" attestpb "github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/verifier/models" "google.golang.org/genproto/googleapis/rpc/status" @@ -59,7 +58,6 @@ type TDCCELAttestation struct { // still needs following two for GCE info AkCert []byte IntermediateCerts [][]byte - NvidiaAttestation *csattestpb.NvidiaAttestationReport } // VerifyAttestationResponse is the response from a successful