Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions launcher/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type AttestationAgent interface {
MeasureEvent(gecel.Content) error
Attest(context.Context, AttestAgentOpts) ([]byte, error)
AttestWithClient(ctx context.Context, opts AttestAgentOpts, client verifier.Client) ([]byte, error)
AttestationEvidence(ctx context.Context, challenge []byte, extraData []byte) (*attestationpb.VmAttestation, error)
AttestationEvidence(ctx context.Context, challenge []byte, extraData []byte, opts AttestAgentOpts) (*attestationpb.VmAttestation, error)
Refresh(context.Context) error
Close() error
}
Expand All @@ -74,6 +74,8 @@ type attestRoot interface {
ComputeNonce(challenge []byte, extraData []byte) []byte
// AddDeviceROTs adds detected device RoTs(root of trust).
AddDeviceROTs([]DeviceROT)
// AttestDeviceROTs fetches a list of runtime device attestation report.
AttestDeviceROTs(nonce []byte) ([]any, error)
}

// DeviceROT defines an interface for all attached devices to collect attestation.
Expand All @@ -86,6 +88,12 @@ type DeviceROT interface {
// VerifyAttestation API
type AttestAgentOpts struct {
TokenOptions *models.TokenOptions
*DeviceReportOpts
}

// DeviceReportOpts contains options for runtime device attestations.
type DeviceReportOpts struct {
EnableRuntimeGPUAttestation bool
}

type agent struct {
Expand Down Expand Up @@ -297,7 +305,7 @@ func (a *agent) AttestWithClient(ctx context.Context, opts AttestAgentOpts, clie
}

// AttestationEvidence returns the attestation evidence (TPM or TDX).
func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraData []byte) (*attestationpb.VmAttestation, error) {
func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraData []byte, opts AttestAgentOpts) (*attestationpb.VmAttestation, error) {
if !a.launchSpec.Experiments.EnableAttestationEvidence {
return nil, fmt.Errorf("attestation evidence is disabled")
}
Expand All @@ -313,12 +321,10 @@ func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraDa
if err != nil {
return nil, fmt.Errorf("failed to attest: %v", err)
}

var cosCel bytes.Buffer
if err := a.avRot.GetCEL().EncodeCEL(&cosCel); err != nil {
return nil, err
}

attestation := &attestationpb.VmAttestation{
Label: []byte(labels.WorkloadAttestation),
Challenge: challenge,
Expand All @@ -343,9 +349,41 @@ func (a *agent) AttestationEvidence(_ context.Context, challenge []byte, extraDa
default:
return nil, fmt.Errorf("unknown attestation type: %T", v)
}

deviceReports, err := a.attestDeviceROTs(finalNonce, opts)
if err != nil {
return nil, err
}
attestation.DeviceReports = deviceReports

return attestation, nil
}

func (a *agent) attestDeviceROTs(nonce []byte, opts AttestAgentOpts) ([]*attestationpb.DeviceAttestationReport, error) {
if opts.DeviceReportOpts == nil {
return nil, nil
}
deviceROTs, err := a.avRot.AttestDeviceROTs(nonce)
if err != nil {
return nil, err
}

var deviceReports []*attestationpb.DeviceAttestationReport
for _, dr := range deviceROTs {
switch v := dr.(type) {
case *attestationpb.NvidiaAttestationReport:
if opts.DeviceReportOpts.EnableRuntimeGPUAttestation {
deviceReports = append(deviceReports, &attestationpb.DeviceAttestationReport{
Report: &attestationpb.DeviceAttestationReport_NvidiaReport{
NvidiaReport: v,
},
})
}
}
}
return deviceReports, nil
}

func (a *agent) verify(ctx context.Context, req verifier.VerifyAttestationRequest, client verifier.Client) (*verifier.VerifyAttestationResponse, error) {
if a.launchSpec.Experiments.EnableVerifyCS {
return client.VerifyConfidentialSpace(ctx, req)
Expand Down Expand Up @@ -423,6 +461,13 @@ func (t *tpmAttestRoot) AddDeviceROTs(deviceROTs []DeviceROT) {
t.deviceROTs = append(t.deviceROTs, deviceROTs...)
}

func (t *tpmAttestRoot) AttestDeviceROTs(nonce []byte) ([]any, error) {
t.tpmMu.Lock()
defer t.tpmMu.Unlock()

return doAttestDeviceROTs(t.deviceROTs, nonce)
}

type tdxAttestRoot struct {
tdxMu sync.Mutex
qp *tg.LinuxConfigFsQuoteProvider
Expand Down Expand Up @@ -461,28 +506,20 @@ func (t *tdxAttestRoot) Attest(nonce []byte) (any, error) {
return nil, err
}

var nvAtt *attestationpb.NvidiaAttestationReport
for _, deviceRoT := range t.deviceROTs {
att, err := deviceRoT.Attest(nonce)
if err != nil {
return nil, err
}
switch v := att.(type) {
case *attestationpb.NvidiaAttestationReport:
nvAtt = v
default:
return nil, fmt.Errorf("unknown device attestation type: %T", v)
}
}

return &verifier.TDCCELAttestation{
CcelAcpiTable: ccelTable,
CcelData: ccelData,
TdQuote: rawQuote,
NvidiaAttestation: nvAtt,
CcelAcpiTable: ccelTable,
CcelData: ccelData,
TdQuote: rawQuote,
}, nil
}

func (t *tdxAttestRoot) AttestDeviceROTs(nonce []byte) ([]any, error) {
t.tdxMu.Lock()
defer t.tdxMu.Unlock()

return doAttestDeviceROTs(t.deviceROTs, nonce)
}

func (t *tdxAttestRoot) ComputeNonce(challenge []byte, extraData []byte) []byte {
challengeData := challenge
if extraData != nil {
Expand Down Expand Up @@ -596,3 +633,15 @@ func convertToTPMQuote(v *pb.Attestation) *attestationpb.TpmQuote {
},
}
}

func doAttestDeviceROTs(deviceROTs []DeviceROT, nonce []byte) ([]any, error) {
var deviceReports []any
for _, deviceROT := range deviceROTs {
deviceReport, err := deviceROT.Attest(nonce)
if err != nil {
return nil, err
}
deviceReports = append(deviceReports, deviceReport)
}
return deviceReports, nil
}
Loading
Loading