Skip to content

Commit f69ef9c

Browse files
committed
Replace teeserver binary with mock_wsd for mocking WorkloadAttestationService.
1 parent 91df863 commit f69ef9c

File tree

7 files changed

+296
-307
lines changed

7 files changed

+296
-307
lines changed

cmd/mock_wsd/README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# mock_wsd
2+
3+
`mock_wsd` is a standalone binary that mocks the `WorkloadAttestationService` specifically for the `GetKeyEndorsement` endpoint.
4+
5+
The binary exposes UDS located at `/run/workload_attestation.sock`. It serves RESTful JSON API that answers requests formatted as `GetKeyEndorsementRequest` and returns `GetKeyEndorsementResponse`.
6+
7+
Internally, it spins up the `teeserver` logic and accesses the `AttestationAgent` directly to fetch a standalone VMAttestation, wraps it in the `GetKeyEndorsementResponse` struct, and sets the label to `WORKLOAD_ATTESTATION`.
8+
9+
> **Note on Attestation Evidence**
10+
> By default, `mock_wsd` is configured with `launchSpec.Experiments.EnableAttestationEvidence = false`, flip this value to `true` before compiling.
11+
12+
### 1. Build the binary
13+
14+
From the root of `go-tpm-tools`:
15+
16+
```bash
17+
go build -o mock_wsd_bin ./cmd/mock_wsd
18+
```
19+
20+
### 2. Start the `mock_wsd`
21+
22+
(Requires root privileges to listen on `/run/workload_attestation.sock` and access `/dev/tpmrm0`)
23+
24+
```bash
25+
sudo ./mock_wsd_bin
26+
```
27+
28+
### 3. Query `mock_wsd`
29+
30+
Once the server is running, you can query `mock_wsd` from another terminal. You can write the output to a JSON file (e.g., `evidence.json`) as follows:
31+
```bash
32+
curl --unix-socket /run/workload_attestation.sock \
33+
-H "Content-Type: application/json" \
34+
-d '{"challenge": "Y2hhbGxlbmdl", "key_handle": {"handle": "some_handle"}}' \
35+
http://localhost/v1/workload/attestation/key_endorsement | jq . > evidence.json
36+
```
37+
38+
This will save the JSON API response (containing the nested `KeyEndorsement` -> `VmProtectedKeyEndorsement` -> `KeyAttestation` -> `VMAttestation`) into `evidence.json`.

cmd/mock_wsd/main.go

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"log"
9+
"net"
10+
"net/http"
11+
"os"
12+
"os/signal"
13+
"strings"
14+
"syscall"
15+
16+
clogging "cloud.google.com/go/logging"
17+
"github.com/google/go-tpm-tools/client"
18+
"github.com/google/go-tpm-tools/launcher/agent"
19+
"github.com/google/go-tpm-tools/launcher/spec"
20+
"github.com/google/go-tpm-tools/launcher/teeserver"
21+
"github.com/google/go-tpm-tools/launcher/teeserver/models"
22+
"github.com/google/go-tpm-tools/verifier/util"
23+
)
24+
25+
const (
26+
teeserverSocketPath = "/run/container_launcher/teeserver.sock"
27+
mockWsdSocketPath = "/run/workload_attestation.sock"
28+
)
29+
30+
type cmdLogger struct {
31+
*log.Logger
32+
}
33+
34+
func (l *cmdLogger) Log(severity clogging.Severity, msg string, args ...any) {
35+
l.Printf("%v: %s %v\n", severity, msg, args)
36+
}
37+
38+
func (l *cmdLogger) Info(msg string, args ...any) {
39+
l.Printf("INFO: %s %v\n", msg, args)
40+
}
41+
42+
func (l *cmdLogger) Warn(msg string, args ...any) {
43+
l.Printf("WARN: %s %v\n", msg, args)
44+
}
45+
46+
func (l *cmdLogger) Error(msg string, args ...any) {
47+
l.Printf("ERROR: %s %v\n", msg, args)
48+
}
49+
50+
func (l *cmdLogger) SerialConsoleFile() *os.File {
51+
return nil
52+
}
53+
54+
func (l *cmdLogger) Close() {
55+
}
56+
57+
type KeyHandle struct {
58+
Handle string `json:"handle"`
59+
}
60+
61+
type GetKeyEndorsementRequest struct {
62+
Challenge []byte `json:"challenge"`
63+
KeyHandle KeyHandle `json:"key_handle"`
64+
}
65+
66+
type GetKeyEndorsementResponse struct {
67+
Endorsement KeyEndorsement `json:"endorsement"`
68+
}
69+
70+
type KeyEndorsement struct {
71+
VmProtectedKeyEndorsement VmProtectedKeyEndorsement `json:"vm_protected_key_endorsement"`
72+
}
73+
74+
type VmProtectedKeyEndorsement struct {
75+
BindingKeyAttestation *KeyAttestation `json:"binding_key_attestation,omitempty"`
76+
ProtectedKeyAttestation *KeyAttestation `json:"protected_key_attestation,omitempty"`
77+
}
78+
79+
type KeyAttestation struct {
80+
Attestation *models.VMAttestation `json:"attestation"`
81+
}
82+
83+
func main() {
84+
if os.Getuid() != 0 {
85+
log.Println("Warning: mock_wsd usually requires root privileges to create sockets in /run")
86+
}
87+
88+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
89+
defer stop()
90+
91+
logger := &cmdLogger{log.New(os.Stdout, "mock_wsd ", log.LstdFlags)}
92+
93+
// 1. Initialize Attestation Agent for teeserver
94+
launchSpec := spec.LaunchSpec{}
95+
launchSpec.Experiments.EnableAttestationEvidence = true
96+
97+
vTPM, err := os.OpenFile("/dev/tpmrm0", os.O_RDWR, 0)
98+
if err != nil {
99+
logger.Error(fmt.Sprintf("Failed to open vTPM: %v", err))
100+
} else {
101+
defer vTPM.Close()
102+
}
103+
var tpmCloser io.ReadWriteCloser
104+
if vTPM != nil {
105+
tpmCloser = vTPM
106+
}
107+
108+
var akFetcher util.TpmKeyFetcher
109+
if tpmCloser != nil {
110+
akFetcher = client.GceAttestationKeyECC
111+
} else {
112+
akFetcher = func(_ io.ReadWriter) (*client.Key, error) {
113+
return nil, fmt.Errorf("no vTPM available")
114+
}
115+
}
116+
117+
attestAgent, err := agent.CreateAttestationAgent(
118+
tpmCloser, akFetcher, nil, nil, nil, launchSpec, logger,
119+
)
120+
if err != nil {
121+
logger.Error(fmt.Sprintf("failed to create attestation agent: %v", err))
122+
os.Exit(1)
123+
}
124+
defer attestAgent.Close()
125+
126+
// 2. Start teeserver
127+
if err := os.MkdirAll("/run/container_launcher", 0755); err != nil { logger.Error(fmt.Sprintf("failed to create directory /run/container_launcher: %v", err)); os.Exit(1) }
128+
if err := os.RemoveAll(teeserverSocketPath); err != nil {
129+
logger.Error(fmt.Sprintf("Failed to remove existing socket %s: %v", teeserverSocketPath, err))
130+
}
131+
132+
clients := teeserver.AttestClients{GCA: nil, ITA: nil}
133+
tServer, err := teeserver.New(ctx, teeserverSocketPath, attestAgent, logger, launchSpec, clients)
134+
if err != nil {
135+
logger.Error(fmt.Sprintf("failed to create tee server: %v", err))
136+
os.Exit(1)
137+
}
138+
139+
if err := os.Chmod(teeserverSocketPath, 0777); err != nil {
140+
logger.Warn(fmt.Sprintf("failed to chmod socket %s: %v", teeserverSocketPath, err))
141+
}
142+
143+
logger.Info("Starting TEE Server", "socket", teeserverSocketPath)
144+
errChan := make(chan error, 2)
145+
go func() {
146+
errChan <- tServer.Serve()
147+
}()
148+
149+
// 3. Start mock_wsd server
150+
mux := http.NewServeMux()
151+
// Pass the attestAgent directly to avoid the loopback HTTP call
152+
mux.HandleFunc("/v1/workload/attestation/key_endorsement", func(w http.ResponseWriter, r *http.Request) {
153+
handleGetKeyEndorsement(w, r, attestAgent, logger)
154+
})
155+
156+
if err := os.RemoveAll(mockWsdSocketPath); err != nil {
157+
logger.Error(fmt.Sprintf("Failed to remove existing socket %s: %v", mockWsdSocketPath, err))
158+
}
159+
160+
listener, err := net.Listen("unix", mockWsdSocketPath)
161+
if err != nil {
162+
logger.Error(fmt.Sprintf("Failed to listen on %s: %v", mockWsdSocketPath, err))
163+
os.Exit(1)
164+
}
165+
defer listener.Close()
166+
167+
if err := os.Chmod(mockWsdSocketPath, 0777); err != nil {
168+
logger.Warn(fmt.Sprintf("failed to chmod socket %s: %v", mockWsdSocketPath, err))
169+
}
170+
171+
logger.Info("Starting mock_wsd server attached to TEE Server", "socket", mockWsdSocketPath)
172+
go func() {
173+
errChan <- http.Serve(listener, mux)
174+
}()
175+
176+
// 4. Wait for termination
177+
select {
178+
case err := <-errChan:
179+
if err != nil {
180+
logger.Error(fmt.Sprintf("server error: %v", err))
181+
os.Exit(1)
182+
}
183+
case <-ctx.Done():
184+
logger.Info("Shutting down servers")
185+
if err := tServer.Shutdown(context.Background()); err != nil {
186+
if !strings.Contains(err.Error(), "use of closed network connection") {
187+
logger.Error(fmt.Sprintf("failed to shutdown tee server: %v", err))
188+
}
189+
}
190+
}
191+
}
192+
193+
func handleGetKeyEndorsement(w http.ResponseWriter, r *http.Request, attestAgent agent.AttestationAgent, logger *cmdLogger) {
194+
if r.Method != http.MethodPost {
195+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
196+
return
197+
}
198+
199+
var req GetKeyEndorsementRequest
200+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
201+
http.Error(w, fmt.Sprintf("Failed to decode request: %v", err), http.StatusBadRequest)
202+
return
203+
}
204+
205+
if len(req.Challenge) == 0 {
206+
http.Error(w, "challenge is required", http.StatusBadRequest)
207+
return
208+
}
209+
210+
// Call the generic AttestationEvidence function directly from the agent
211+
attestation, err := attestAgent.AttestationEvidence(r.Context(), req.Challenge, nil)
212+
if err != nil {
213+
logger.Error(fmt.Sprintf("Error getting evidence from agent: %v", err))
214+
http.Error(w, fmt.Sprintf("Internal error: %v", err), http.StatusInternalServerError)
215+
return
216+
}
217+
218+
// Construct the response
219+
// For the mock, we only include the attestation in BindingKeyAttestation
220+
resp := GetKeyEndorsementResponse{
221+
Endorsement: KeyEndorsement{
222+
VmProtectedKeyEndorsement: VmProtectedKeyEndorsement{
223+
BindingKeyAttestation: &KeyAttestation{
224+
Attestation: attestation,
225+
},
226+
},
227+
},
228+
}
229+
230+
w.Header().Set("Content-Type", "application/json")
231+
if err := json.NewEncoder(w).Encode(resp); err != nil {
232+
logger.Error(fmt.Sprintf("Error encoding response: %v", err))
233+
}
234+
}

cmd/teeserver/README.md

Lines changed: 0 additions & 28 deletions
This file was deleted.

cmd/teeserver/evidence.go

Lines changed: 0 additions & 93 deletions
This file was deleted.

0 commit comments

Comments
 (0)