diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..fb4f930 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,55 @@ +name: CI + +on: + push: + branches: [ main, nsm-modernize-with-tests ] + pull_request: + branches: [ main ] + +permissions: + contents: read + pull-requests: read + +jobs: + test: + name: Test + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + check-latest: true + + - name: Verify dependencies + run: go mod verify + + - name: Build + run: go build -v ./... + + - name: Vet + run: go vet ./... + + - name: Test with coverage + run: make test coverage + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + check-latest: true + + - name: Run lint via Makefile + run: make lint + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c56b772 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Code coverage profiles and other test artifacts +*.out +coverage.* +*.coverprofile +profile.cov + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + +# Editor/IDE +# .idea/ +# .vscode/ + +# AI +.claude diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..3bdf28f --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,44 @@ +run: + timeout: 5m + tests: true + modules-download-mode: readonly + +linters: + enable: + - bodyclose + - errcheck + - gofmt + - goimports + - gosec + - gosimple + - govet + - ineffassign + - misspell + - revive + - staticcheck + - unconvert + - unused + - gocyclo + - goconst + - unparam + +linters-settings: + gocyclo: + min-complexity: 15 + goconst: + min-len: 2 + min-occurrences: 2 + misspell: + locale: US + gosec: + excludes: + - G304 # File path provided by user input - acceptable for NSM device access + +issues: + exclude-rules: + - path: _test\.go + linters: + - gosec + - path: example/ + linters: + - gosec \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c59484c --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +.PHONY: all test lint coverage clean build fmt + +all: fmt lint test coverage + +# Build all packages +build: + go build -v ./... + +# Run tests with race detection and coverage +test: + go test -race -coverprofile=coverage.out -covermode=atomic ./... + +# Format code +fmt: + go fmt ./... + go run golang.org/x/tools/cmd/goimports@latest -w . + +# Run linting (depends on fmt) +lint: fmt + go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.61.0 run --timeout=5m + +# Generate and display coverage report +coverage: test + go tool cover -func=coverage.out + @echo "" + @COVERAGE=$$(go tool cover -func=coverage.out | grep total | awk '{print $$3}' | sed 's/%//'); \ + echo "Total coverage: $${COVERAGE}%"; \ + if [ "$$(echo "$${COVERAGE} < 60.0" | bc -l)" -eq 1 ]; then \ + echo "❌ Coverage $${COVERAGE}% is below minimum 60%"; \ + exit 1; \ + else \ + echo "✅ Coverage $${COVERAGE}% meets minimum threshold"; \ + fi + +# Generate HTML coverage report +coverage-html: test + go tool cover -html=coverage.out -o coverage.html + @echo "Coverage report generated: coverage.html" + +# Run go vet +vet: + go vet ./... + +# Verify dependencies +verify: + go mod verify + go mod tidy -diff + +# Clean build artifacts +clean: + go clean -testcache + rm -f coverage.out coverage.html + +# Run all checks (CI simulation) +ci: verify build vet fmt lint test coverage + @echo "✅ All CI checks passed" \ No newline at end of file diff --git a/README.md b/README.md index aeff182..4ce5939 100644 --- a/README.md +++ b/README.md @@ -70,11 +70,10 @@ import ( func generateBigPrime() (*big.Int, error) { sess, err := nsm.OpenDefaultSession() - defer sess.Close() - - if nil != err { + if err != nil { return nil, err } + defer sess.Close() return rand.Prime(sess, 2048) } @@ -94,26 +93,25 @@ import ( func attest(nonce, userData, publicKey []byte) ([]byte, error) { sess, err := nsm.OpenDefaultSession() - defer sess.Close() - - if nil != err { + if err != nil { return nil, err } + defer sess.Close() res, err := sess.Send(&request.Attestation{ Nonce: nonce, UserData: userData, PublicKey: publicKey, }) - if nil != err { + if err != nil { return nil, err } - if "" != res.Error { + if res.Error != "" { return nil, errors.New(string(res.Error)) } - if nil == res.Attestation || nil == res.Attestation.Document { + if res.Attestation == nil || res.Attestation.Document == nil { return nil, errors.New("NSM device did not return an attestation") } diff --git a/example/attestation/main.go b/example/attestation/main.go index 8262eaf..aa77fe1 100644 --- a/example/attestation/main.go +++ b/example/attestation/main.go @@ -4,33 +4,33 @@ import ( "encoding/base64" "errors" "fmt" + "time" + "github.com/hf/nsm" "github.com/hf/nsm/request" - "time" ) func attest(nonce, userData, publicKey []byte) ([]byte, error) { sess, err := nsm.OpenDefaultSession() - defer sess.Close() - - if nil != err { + if err != nil { return nil, err } + defer sess.Close() res, err := sess.Send(&request.Attestation{ Nonce: nonce, UserData: userData, PublicKey: publicKey, }) - if nil != err { + if err != nil { return nil, err } - if "" != res.Error { + if res.Error != "" { return nil, errors.New(string(res.Error)) } - if nil == res.Attestation || nil == res.Attestation.Document { + if res.Attestation == nil || res.Attestation.Document == nil { return nil, errors.New("NSM device did not return an attestation") } diff --git a/go.mod b/go.mod index 34051ea..e4d7621 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,7 @@ module github.com/hf/nsm -go 1.15 +go 1.23 -require ( - github.com/fxamacker/cbor/v2 v2.2.0 - golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 // indirect - golang.org/x/tools v0.0.0-20210105210202-9ed45478a130 // indirect -) +require github.com/fxamacker/cbor/v2 v2.7.0 + +require github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index 80ca857..9a9c7a0 100644 --- a/go.sum +++ b/go.sum @@ -1,31 +1,4 @@ -github.com/fxamacker/cbor/v2 v2.2.0 h1:6eXqdDDe588rSYAi1HfZKbx6YYQO4mxQ9eC6xYpU/JQ= -github.com/fxamacker/cbor/v2 v2.2.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI= -golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7 h1:EBZoQjiKKPaLbPrbpssUfuHtwM6KV/vb4U85g/cigFY= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20210105210202-9ed45478a130 h1:8qSBr5nyKsEgkP918Pu5FFDZpTtLIjXSo6mrtdVOFfk= -golang.org/x/tools v0.0.0-20210105210202-9ed45478a130/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/ioc/ioc_test.go b/ioc/ioc_test.go new file mode 100644 index 0000000..fbf8d9b --- /dev/null +++ b/ioc/ioc_test.go @@ -0,0 +1,27 @@ +package ioc + +import ( + "testing" +) + +// TestCommand tests the IOCTL command calculation +func TestCommand(t *testing.T) { + // Test that Command returns a non-zero value for valid inputs + result := Command(READ|WRITE, 0x0A, 0, 16) + if result == 0 { + t.Error("Command should return non-zero value") + } + + // Test that different inputs produce different results + result1 := Command(READ, 0x0A, 0, 16) + result2 := Command(WRITE, 0x0A, 0, 16) + if result1 == result2 { + t.Error("Different directions should produce different commands") + } + + // Test with actual NSM values used in the main package + nsmCommand := Command(READ|WRITE, 0x0A, 0, 16) + if nsmCommand == 0 { + t.Error("NSM command should be non-zero") + } +} diff --git a/nsm.go b/nsm.go index 1f1cfdc..10af63a 100644 --- a/nsm.go +++ b/nsm.go @@ -20,6 +20,15 @@ const ( maxRequestSize = 0x1000 maxResponseSize = 0x3000 ioctlMagic = 0x0A + + // Error messages + errGetRandomNoBytes = "GetRandom response did not include random bytes" + errGetRandomFailedFmt = "GetRandom failed with error code %v" +) + +var ( + // ErrSessionClosed is returned when the session is in a closed state. + ErrSessionClosed = errors.New("Session is closed") ) // FileDescriptor is a generic file descriptor interface that can be closed. @@ -71,18 +80,13 @@ type ErrorGetRandomFailed struct { // Error returns the formatted string. func (err *ErrorGetRandomFailed) Error() string { - if "" != err.ErrorCode { - return fmt.Sprintf("GetRandom failed with error code %v", err.ErrorCode) + if err.ErrorCode != "" { + return fmt.Sprintf(errGetRandomFailedFmt, err.ErrorCode) } - return "GetRandom response did not include random bytes" + return errGetRandomNoBytes } -var ( - // ErrSessionClosed is returned when the session is in a closed state. - ErrSessionClosed error = errors.New("Session is closed") -) - // A Session is used to interact with the NSM. type Session struct { fd FileDescriptor @@ -97,6 +101,14 @@ type ioctlMessage struct { } func send(options Options, fd uintptr, req []byte, res []byte) ([]byte, error) { + // Validate slices to prevent panic on empty slices + if len(req) == 0 { + return nil, errors.New("request buffer is empty") + } + if len(res) == 0 { + return nil, errors.New("response buffer is empty") + } + iovecReq := syscall.Iovec{ Base: &req[0], } @@ -112,6 +124,9 @@ func send(options Options, fd uintptr, req []byte, res []byte) ([]byte, error) { Response: iovecRes, } + // IOCTL calls to /dev/nsm are synchronous and block until the NSM device + // responds. Each call performs a context switch to the Nitro hypervisor. + // Reference: https://github.com/aws/aws-nitro-enclaves-nsm-api _, _, err := options.Syscall( syscall.SYS_IOCTL, fd, @@ -119,34 +134,46 @@ func send(options Options, fd uintptr, req []byte, res []byte) ([]byte, error) { uintptr(unsafe.Pointer(&msg)), ) - if 0 != err { + if err != 0 { return nil, &ErrorIoctlFailed{ Errno: err, } } + // Validate response length to prevent buffer overrun + if msg.Response.Len > uint64(len(res)) { + return nil, fmt.Errorf("response length %d exceeds buffer size %d", msg.Response.Len, len(res)) + } + return res[:msg.Response.Len], nil } // OpenSession opens a new session with the provided options. func OpenSession(opts Options) (*Session, error) { - session := &Session{ - options: opts, + // Set defaults if not provided + if opts.Open == nil { + opts.Open = DefaultOptions.Open + } + if opts.Syscall == nil { + opts.Syscall = DefaultOptions.Syscall } fd, err := opts.Open() - if nil != err { - return session, err + if err != nil { + return nil, err } - session.fd = fd + session := &Session{ + options: opts, + fd: fd, + } session.reqpool = &sync.Pool{ - New: func() interface{} { + New: func() any { return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) }, } session.respool = &sync.Pool{ - New: func() interface{} { + New: func() any { return make([]byte, maxResponseSize) }, } @@ -162,35 +189,65 @@ func OpenDefaultSession() (*Session, error) { // Close this session. It is not thread safe to Close while other threads are // Read-ing or Send-ing. func (sess *Session) Close() error { - if nil == sess.fd { + if sess == nil || sess.fd == nil { return nil } - err := sess.fd.Close() - sess.fd = nil - sess.reqpool = nil - sess.respool = nil + var err error + // Always clear the session state to prevent reuse, even on panic + defer func() { + sess.fd = nil + sess.reqpool = nil + sess.respool = nil + }() + + // Close the file descriptor + err = sess.fd.Close() return err } -// Send an NSM request to the device and await its response. It safe to call -// this from multiple threads that are Read-ing or Send-ing, but not Close-ing. -// Each Send and Read call reserves at most 16KB of memory, so having multiple -// parallel sends or reads might lead to increased memory usage. +// Send an NSM request to the device and await its response. +// IOCTL operations are synchronous and expensive - each call blocks and +// context-switches to the Nitro hypervisor. Use sparingly. +// Safe to call from multiple goroutines, but not while Close-ing. +// Each call reserves up to 16KB of memory. func (sess *Session) Send(req request.Request) (response.Response, error) { - reqb := sess.reqpool.Get().(*bytes.Buffer) + if req == nil { + return response.Response{}, fmt.Errorf("request cannot be nil") + } + + if sess == nil || sess.fd == nil || sess.reqpool == nil || sess.respool == nil { + return response.Response{}, ErrSessionClosed + } + + reqbRaw := sess.reqpool.Get() + reqb, ok := reqbRaw.(*bytes.Buffer) + if !ok { + sess.reqpool.Put(reqbRaw) + return response.Response{}, fmt.Errorf("pool returned unexpected type %T", reqbRaw) + } defer sess.reqpool.Put(reqb) reqb.Reset() encoder := cbor.NewEncoder(reqb) err := encoder.Encode(req.Encoded()) - if nil != err { - return response.Response{}, err + if err != nil { + return response.Response{}, fmt.Errorf("failed to encode request: %w", err) } - resb := sess.respool.Get().([]byte) - defer sess.respool.Put(resb) + // Validate encoded request size + if reqb.Len() > maxRequestSize { + return response.Response{}, fmt.Errorf("encoded request size %d exceeds maximum %d", reqb.Len(), maxRequestSize) + } + + resbRaw := sess.respool.Get() + resb, ok := resbRaw.([]byte) + if !ok { + sess.respool.Put(resbRaw) + return response.Response{}, fmt.Errorf("pool returned unexpected type %T", resbRaw) + } + defer sess.respool.Put(resbRaw) return sess.sendMarshaled(reqb, resb) } @@ -198,31 +255,48 @@ func (sess *Session) Send(req request.Request) (response.Response, error) { func (sess *Session) sendMarshaled(reqb *bytes.Buffer, resb []byte) (response.Response, error) { res := response.Response{} - if nil == sess.fd { - return res, errors.New("Session is closed") + if sess == nil || sess.fd == nil { + return res, ErrSessionClosed } resb, err := send(sess.options, sess.fd.Fd(), reqb.Bytes(), resb) - if nil != err { + if err != nil { return res, err } + // Validate response data before unmarshaling + if len(resb) == 0 { + return res, fmt.Errorf("empty response from NSM device") + } + if len(resb) > maxResponseSize { + return res, fmt.Errorf("response size %d exceeds maximum %d", len(resb), maxResponseSize) + } + err = cbor.Unmarshal(resb, &res) - if nil != err { - return res, err + if err != nil { + return res, fmt.Errorf("failed to unmarshal CBOR response: %w", err) } return res, nil } -// Read entropy from the NSM device. It is safe to call this from multiple -// threads that are Read-ing or Send-ing, but not Close-ing. This method will -// always attempt to fill the whole slice with entropy thus blocking until that -// occurs. If reading fails, it is probably an irrecoverable error. Each Send -// and Read call reserves at most 16KB of memory, so having multiple parallel -// sends or reads might lead to increased memory usage. +// Read entropy from the NSM device. This method blocks until the entire slice +// is filled with cryptographically secure random bytes from the NSM. +// Each GetRandom request is a synchronous IOCTL that context-switches to the +// Nitro hypervisor, making it expensive. Consider using returned entropy to +// seed a DRBG rather than calling repeatedly. +// Safe to call from multiple goroutines, but not while Close-ing. func (sess *Session) Read(into []byte) (int, error) { - reqb := sess.reqpool.Get().(*bytes.Buffer) + if sess == nil || sess.fd == nil || sess.reqpool == nil || sess.respool == nil { + return 0, ErrSessionClosed + } + + reqbRaw := sess.reqpool.Get() + reqb, ok := reqbRaw.(*bytes.Buffer) + if !ok { + sess.reqpool.Put(reqbRaw) + return 0, fmt.Errorf("pool returned unexpected type %T", reqbRaw) + } defer sess.reqpool.Put(reqb) getRandom := request.GetRandom{} @@ -230,27 +304,38 @@ func (sess *Session) Read(into []byte) (int, error) { reqb.Reset() encoder := cbor.NewEncoder(reqb) err := encoder.Encode(getRandom.Encoded()) - if nil != err { + if err != nil { return 0, err } - resb := sess.respool.Get().([]byte) - defer sess.respool.Put(resb) + resbRaw := sess.respool.Get() + resb, ok := resbRaw.([]byte) + if !ok { + sess.respool.Put(resbRaw) + return 0, fmt.Errorf("pool returned unexpected type %T", resbRaw) + } + defer sess.respool.Put(resbRaw) - for i := 0; i < len(into); i += 0 { + for i := 0; i < len(into); { res, err := sess.sendMarshaled(reqb, resb) - if nil != err { + if err != nil { return i, err } - if "" != res.Error || nil == res.GetRandom || nil == res.GetRandom.Random || 0 == len(res.GetRandom.Random) { + if res.Error != "" || res.GetRandom == nil || res.GetRandom.Random == nil || len(res.GetRandom.Random) == 0 { return i, &ErrorGetRandomFailed{ ErrorCode: res.Error, } } - i += copy(into[i:], res.GetRandom.Random) + copied := copy(into[i:], res.GetRandom.Random) + if copied == 0 { + return i, &ErrorGetRandomFailed{ + ErrorCode: errGetRandomNoBytes, + } + } + i += copied } return len(into), nil diff --git a/nsm_test.go b/nsm_test.go new file mode 100644 index 0000000..4b90a3a --- /dev/null +++ b/nsm_test.go @@ -0,0 +1,467 @@ +package nsm + +import ( + "bytes" + "errors" + "os" + "sync" + "syscall" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/hf/nsm/request" +) + +// mockFileDescriptor implements FileDescriptor for testing +type mockFileDescriptor struct { + fd uintptr + closed bool + closeErr error +} + +func (m *mockFileDescriptor) Fd() uintptr { + return m.fd +} + +func (m *mockFileDescriptor) Close() error { + if m.closed { + return os.ErrClosed + } + m.closed = true + return m.closeErr +} + +// TestSendEmptyBufferValidation tests the fix for potential panics on empty slices +func TestSendEmptyBufferValidation(t *testing.T) { + opts := Options{ + Open: func() (FileDescriptor, error) { + return &mockFileDescriptor{fd: 1}, nil + }, + Syscall: func(_, _, _, _ uintptr) (r1, r2 uintptr, err syscall.Errno) { + return 0, 0, 0 + }, + } + + tests := []struct { + name string + req []byte + res []byte + wantErr string + }{ + { + name: "empty request buffer should error", + req: []byte{}, + res: make([]byte, 100), + wantErr: "request buffer is empty", + }, + { + name: "empty response buffer should error", + req: make([]byte, 100), + res: []byte{}, + wantErr: "response buffer is empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := send(opts, 1, tt.req, tt.res) + if err == nil { + t.Fatal("expected error but got nil") + } + if err.Error() != tt.wantErr { + t.Errorf("got error %q, want %q", err.Error(), tt.wantErr) + } + }) + } +} + +// TestOpenSessionErrorHandling tests that OpenSession returns nil on error +func TestOpenSessionErrorHandling(t *testing.T) { + t.Run("open failure returns nil session", func(t *testing.T) { + expectedErr := errors.New("failed to open device") + opts := Options{ + Open: func() (FileDescriptor, error) { + return nil, expectedErr + }, + } + + sess, err := OpenSession(opts) + if sess != nil { + t.Error("expected nil session on error, got non-nil") + } + if err != expectedErr { + t.Errorf("got error %v, want %v", err, expectedErr) + } + }) + + t.Run("successful open returns session", func(t *testing.T) { + opts := Options{ + Open: func() (FileDescriptor, error) { + return &mockFileDescriptor{fd: 1}, nil + }, + } + + sess, err := OpenSession(opts) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if sess == nil { + t.Error("expected non-nil session") + } else if sess.fd == nil { + t.Error("expected non-nil file descriptor") + } + }) + + t.Run("default session uses default options", func(t *testing.T) { + // Mock the default open to avoid accessing /dev/nsm + originalOpen := DefaultOptions.Open + defer func() { DefaultOptions.Open = originalOpen }() + + DefaultOptions.Open = func() (FileDescriptor, error) { + return &mockFileDescriptor{fd: 1}, nil + } + + sess, err := OpenDefaultSession() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if sess == nil { + t.Error("expected non-nil session") + } + }) +} + +// TestSessionCloseNilSafety tests Close with nil session and nil fd +func TestSessionCloseNilSafety(t *testing.T) { + t.Run("close on nil session", func(t *testing.T) { + var sess *Session + err := sess.Close() + if err != nil { + t.Errorf("expected nil error for nil session, got %v", err) + } + }) + + t.Run("close on session with nil fd", func(t *testing.T) { + sess := &Session{} + err := sess.Close() + if err != nil { + t.Errorf("expected nil error for nil fd, got %v", err) + } + }) + + t.Run("close sets fields to nil", func(t *testing.T) { + fd := &mockFileDescriptor{fd: 1} + sess := &Session{ + fd: fd, + reqpool: &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) + }, + }, + respool: &sync.Pool{ + New: func() any { + return make([]byte, maxResponseSize) + }, + }, + } + + err := sess.Close() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // Verify cleanup happened + if sess.fd != nil { + t.Error("fd should be nil after close") + } + if sess.reqpool != nil { + t.Error("reqpool should be nil after close") + } + if sess.respool != nil { + t.Error("respool should be nil after close") + } + if !fd.closed { + t.Error("underlying fd should be closed") + } + }) +} + +// TestSendNilValidation tests nil request validation +func TestSendNilValidation(t *testing.T) { + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + reqpool: &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) + }, + }, + } + + _, err := sess.Send(nil) + if err == nil { + t.Error("expected error for nil request") + } + expected := "request cannot be nil" + if err.Error() != expected { + t.Errorf("got error %q, want %q", err.Error(), expected) + } +} + +// TestSendClosedSession tests behavior on closed session +func TestSendClosedSession(t *testing.T) { + t.Run("send on session with nil fd", func(t *testing.T) { + sess := &Session{} // nil fd + _, err := sess.Send(&request.DescribeNSM{}) + if err != ErrSessionClosed { + t.Errorf("got error %v, want %v", err, ErrSessionClosed) + } + }) + + t.Run("send on closed session", func(t *testing.T) { + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + // pools are nil - simulates closed session + } + _, err := sess.Send(&request.DescribeNSM{}) + if err != ErrSessionClosed { + t.Errorf("got error %v, want %v", err, ErrSessionClosed) + } + }) +} + +// TestPoolTypeSafety tests safe type assertions for pools +func TestPoolTypeSafety(t *testing.T) { + // Create a pool that returns wrong type + badReqPool := &sync.Pool{ + New: func() any { + return "wrong type" // Should be *bytes.Buffer + }, + } + + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + reqpool: badReqPool, + respool: &sync.Pool{ + New: func() any { + return make([]byte, maxResponseSize) + }, + }, + } + + _, err := sess.Send(&request.DescribeNSM{}) + if err == nil { + t.Error("expected error for wrong pool type") + } + expected := "pool returned unexpected type string" + if err.Error() != expected { + t.Errorf("got error %q, want %q", err.Error(), expected) + } + + // Test wrong response pool type + badResPool := &sync.Pool{ + New: func() any { + return make([]int, 10) // Should be []byte + }, + } + + sess2 := &Session{ + fd: &mockFileDescriptor{fd: 1}, + reqpool: &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) + }, + }, + respool: badResPool, + } + + _, err = sess2.Send(&request.DescribeNSM{}) + if err == nil { + t.Error("expected error for wrong response pool type") + } + expected = "pool returned unexpected type []int" + if err.Error() != expected { + t.Errorf("got error %q, want %q", err.Error(), expected) + } +} + +// TestErrorTypes tests custom error type formatting +func TestErrorTypes(t *testing.T) { + t.Run("ErrorIoctlFailed", func(t *testing.T) { + err := &ErrorIoctlFailed{Errno: syscall.EINVAL} + expected := "ioctl failed on device with errno invalid argument" + if err.Error() != expected { + t.Errorf("got %q, want %q", err.Error(), expected) + } + }) + + t.Run("ErrorGetRandomFailed with error code", func(t *testing.T) { + err := &ErrorGetRandomFailed{ErrorCode: "InvalidRequest"} + expected := "GetRandom failed with error code InvalidRequest" + if err.Error() != expected { + t.Errorf("got %q, want %q", err.Error(), expected) + } + }) + + t.Run("ErrorGetRandomFailed without error code", func(t *testing.T) { + err := &ErrorGetRandomFailed{} + expected := "GetRandom response did not include random bytes" + if err.Error() != expected { + t.Errorf("got %q, want %q", err.Error(), expected) + } + }) +} + +// TestReadClosedSession tests Read method on closed session +func TestReadClosedSession(t *testing.T) { + sess := &Session{ + reqpool: &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) + }, + }, + respool: &sync.Pool{ + New: func() any { + return make([]byte, maxResponseSize) + }, + }, + } + + buf := make([]byte, 32) + n, err := sess.Read(buf) + if err != ErrSessionClosed { + t.Errorf("got error %v, want %v", err, ErrSessionClosed) + } + if n != 0 { + t.Errorf("expected 0 bytes read, got %d", n) + } +} + +// TestRequestSizeValidation tests request size limits +func TestRequestSizeValidation(t *testing.T) { + // Create a mock request that encodes to a large size + largeReq := &request.Attestation{ + UserData: make([]byte, maxRequestSize), // This will exceed limit when encoded + } + + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + reqpool: &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) + }, + }, + respool: &sync.Pool{ + New: func() any { + return make([]byte, maxResponseSize) + }, + }, + } + + _, err := sess.Send(largeReq) + if err == nil { + t.Error("expected error for oversized request") + } + if err.Error() == "" { + t.Error("error message should not be empty") + } + // The exact error depends on CBOR encoding, but should contain size information + t.Logf("Got expected error for large request: %v", err) +} + +// TestSendMarshaledValidation tests sendMarshaled function +func TestSendMarshaledValidation(t *testing.T) { + t.Run("empty response validation", func(t *testing.T) { + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + options: Options{ + Syscall: func(_, _, _, _ uintptr) (r1, r2 uintptr, err syscall.Errno) { + return 0, 0, 0 // Empty response + }, + }, + } + + reqb := bytes.NewBufferString("test") + resb := make([]byte, 0) // Empty response buffer from send() + + _, err := sess.sendMarshaled(reqb, resb) + if err == nil { + t.Error("expected error for empty response") + } + // The actual error depends on which validation catches it first + }) + + t.Run("CBOR unmarshal error", func(t *testing.T) { + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + options: Options{ + Syscall: func(_, _, _, _ uintptr) (r1, r2 uintptr, err syscall.Errno) { + return 0, 0, 0 + }, + }, + } + + reqb := bytes.NewBufferString("test") + resb := []byte{0xFF, 0xFF, 0xFF} // Invalid CBOR + + _, err := sess.sendMarshaled(reqb, resb) + if err == nil { + t.Error("expected error for invalid CBOR") + } + if err.Error() == "" { + t.Error("error message should not be empty") + } + t.Logf("Got expected CBOR error: %v", err) + }) +} + +// TestSendSuccessPath tests successful send operation +func TestSendSuccessPath(t *testing.T) { + // Create a valid CBOR response for DescribeNSM + validResponse := map[string]interface{}{ + "DescribeNSM": map[string]interface{}{ + "version_major": uint16(1), + "version_minor": uint16(0), + "module_id": "test-module", + }, + } + + validCBOR, err := cbor.Marshal(validResponse) + if err != nil { + t.Fatalf("failed to create test response: %v", err) + } + + sess := &Session{ + fd: &mockFileDescriptor{fd: 1}, + options: Options{ + Syscall: func(_, _, _, _ uintptr) (r1, r2 uintptr, err syscall.Errno) { + // Simulate successful syscall that copies response data + return 0, 0, 0 + }, + }, + reqpool: &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, maxRequestSize)) + }, + }, + respool: &sync.Pool{ + New: func() any { + // Return buffer with valid CBOR data + buf := make([]byte, maxResponseSize) + copy(buf, validCBOR) + return buf + }, + }, + } + + // Test sendMarshaled directly with valid CBOR + reqb := bytes.NewBufferString("test") + + response, err := sess.sendMarshaled(reqb, validCBOR) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if response.DescribeNSM == nil { + t.Error("expected DescribeNSM response") + } +} diff --git a/request/request.go b/request/request.go index d97f866..f2eed75 100644 --- a/request/request.go +++ b/request/request.go @@ -2,6 +2,12 @@ // payload. package request +const ( + // Request type names + nameDescribeNSM = "DescribeNSM" + nameGetRandom = "GetRandom" +) + // A Request interface. type Request interface { // Returns the Go-encoded form of the request, according to Rust's cbor @@ -69,7 +75,7 @@ type DescribeNSM struct { // Encoded returns the Go-encoded form of the request, according to Rust's cbor // serde. func (r *DescribeNSM) Encoded() interface{} { - return "DescribeNSM" + return nameDescribeNSM } // An Attestation request. @@ -94,5 +100,5 @@ type GetRandom struct { // Encoded returns the Go-encoded form of the request, according to Rust's cbor // serde. func (r *GetRandom) Encoded() interface{} { - return "GetRandom" + return nameGetRandom } diff --git a/request/request_test.go b/request/request_test.go new file mode 100644 index 0000000..2f15d57 --- /dev/null +++ b/request/request_test.go @@ -0,0 +1,79 @@ +package request + +import ( + "fmt" + "reflect" + "testing" +) + +// TestRequestEncoding tests all request types implement proper encoding +func TestRequestEncoding(t *testing.T) { + t.Run("DescribeNSM returns string", func(t *testing.T) { + req := &DescribeNSM{} + encoded := req.Encoded() + expected := "DescribeNSM" + if encoded != expected { + t.Errorf("got %v, want %v", encoded, expected) + } + }) + + t.Run("GetRandom returns string", func(t *testing.T) { + req := &GetRandom{} + encoded := req.Encoded() + expected := "GetRandom" + if encoded != expected { + t.Errorf("got %v, want %v", encoded, expected) + } + }) + + t.Run("DescribePCR returns map", func(t *testing.T) { + req := &DescribePCR{Index: 5} + encoded := req.Encoded() + expectedMap, ok := encoded.(map[string]*DescribePCR) + if !ok { + t.Errorf("expected map[string]*DescribePCR, got %T", encoded) + return + } + if expectedMap["DescribePCR"].Index != 5 { + t.Errorf("got Index %d, want 5", expectedMap["DescribePCR"].Index) + } + }) + + t.Run("Attestation returns map", func(t *testing.T) { + req := &Attestation{ + Nonce: []byte{1, 2}, + UserData: []byte{3, 4}, + PublicKey: []byte{5, 6}, + } + encoded := req.Encoded() + expectedMap, ok := encoded.(map[string]*Attestation) + if !ok { + t.Errorf("expected map[string]*Attestation, got %T", encoded) + return + } + att := expectedMap["Attestation"] + if !reflect.DeepEqual(att.Nonce, []byte{1, 2}) { + t.Errorf("got Nonce %v, want [1 2]", att.Nonce) + } + }) + + // Test that all request types can be encoded without panicking + requests := []Request{ + &DescribeNSM{}, + &DescribePCR{Index: 0}, + &ExtendPCR{Index: 0, Data: []byte{}}, + &LockPCR{Index: 0}, + &LockPCRs{Range: 0}, + &GetRandom{}, + &Attestation{}, + } + + for i, req := range requests { + t.Run(fmt.Sprintf("request_%d_encodes", i), func(t *testing.T) { + encoded := req.Encoded() + if encoded == nil { + t.Errorf("%T.Encoded() returned nil", req) + } + }) + } +} diff --git a/response/response.go b/response/response.go index 317c33b..056945f 100644 --- a/response/response.go +++ b/response/response.go @@ -3,6 +3,7 @@ package response import ( "fmt" + "github.com/fxamacker/cbor/v2" ) @@ -57,7 +58,7 @@ type DescribeNSM struct { VersionPatch uint16 `cbor:"version_patch" json:"version_patch,omitempty"` ModuleID string `cbor:"module_id" json:"module_id,omitempty"` MaxPCRs uint16 `cbor:"max_pcrs" json:"max_pcrs,omitempty"` - LockedPCRs []uint16 `cbor:"locked_pcrs" json:"digest,omitempty"` + LockedPCRs []uint16 `cbor:"locked_pcrs" json:"locked_pcrs,omitempty"` Digest Digest `cbor:"digest" json:"digest,omitempty"` } @@ -98,16 +99,16 @@ type mapResponse struct { // UnmarshalCBOR function to correctly unmarshal the CBOR encoding according to // Rust's cbor serde implementation. func (res *Response) UnmarshalCBOR(data []byte) error { - // One might try to question the sanity behind this decoding function. - // Please enjoy this: https://github.com/pyfisch/cbor/blob/2f2d0253e2d30e5ba7812cf0b149838b0c95530d/src/ser.rs#L83-L117 + // Handle CBOR encoding compatibility with Rust's serde implementation + // Reference: https://github.com/pyfisch/cbor/blob/2f2d0253e2d30e5ba7812cf0b149838b0c95530d/src/ser.rs#L83-L117 possiblyString := "" err := cbor.Unmarshal(data, &possiblyString) - if nil != err { + if err != nil { possiblyMap := mapResponse{} - err := cbor.Unmarshal(data, &possiblyMap) - if nil != err { - return err + mapErr := cbor.Unmarshal(data, &possiblyMap) + if mapErr != nil { + return fmt.Errorf("failed to unmarshal response as string (%v) or map (%v)", err, mapErr) } res.DescribePCR = possiblyMap.DescribePCR diff --git a/response/response_test.go b/response/response_test.go new file mode 100644 index 0000000..34d5f47 --- /dev/null +++ b/response/response_test.go @@ -0,0 +1,145 @@ +package response + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" +) + +// TestUnmarshalCBORStringResponses tests string-based responses +func TestUnmarshalCBORStringResponses(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + check func(*Response) bool + }{ + { + name: "LockPCR response", + input: "LockPCR", + wantErr: false, + check: func(r *Response) bool { + return r.LockPCR != nil && r.LockPCRs == nil + }, + }, + { + name: "LockPCRs response", + input: "LockPCRs", + wantErr: false, + check: func(r *Response) bool { + return r.LockPCRs != nil && r.LockPCR == nil + }, + }, + { + name: "unknown string response", + input: "UnknownResponse", + wantErr: true, + check: func(_ *Response) bool { return true }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := cbor.Marshal(tt.input) + if err != nil { + t.Fatalf("failed to marshal test input: %v", err) + } + + var res Response + err = res.UnmarshalCBOR(data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && !tt.check(&res) { + t.Errorf("UnmarshalCBOR() result check failed") + } + }) + } +} + +// TestUnmarshalCBORMapResponses tests map-based responses +func TestUnmarshalCBORMapResponses(t *testing.T) { + t.Run("DescribeNSM response", func(t *testing.T) { + input := mapResponse{ + DescribeNSM: &DescribeNSM{ + VersionMajor: 1, + VersionMinor: 0, + ModuleID: "test-module", + MaxPCRs: 16, + Digest: Digest("SHA256"), + }, + } + + data, err := cbor.Marshal(input) + if err != nil { + t.Fatalf("failed to marshal test input: %v", err) + } + + var res Response + err = res.UnmarshalCBOR(data) + if err != nil { + t.Errorf("UnmarshalCBOR() error = %v", err) + return + } + + if res.DescribeNSM == nil { + t.Error("expected DescribeNSM to be set") + return + } + + if res.DescribeNSM.ModuleID != "test-module" { + t.Errorf("got ModuleID %q, want %q", res.DescribeNSM.ModuleID, "test-module") + } + }) + + t.Run("Error response", func(t *testing.T) { + input := mapResponse{ + Error: "InvalidRequest", + } + + data, err := cbor.Marshal(input) + if err != nil { + t.Fatalf("failed to marshal test input: %v", err) + } + + var res Response + err = res.UnmarshalCBOR(data) + if err != nil { + t.Errorf("UnmarshalCBOR() error = %v", err) + return + } + + if res.Error != "InvalidRequest" { + t.Errorf("got Error %q, want %q", res.Error, "InvalidRequest") + } + }) +} + +// TestUnmarshalCBORInvalid tests handling of invalid CBOR data +func TestUnmarshalCBORInvalid(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + { + name: "invalid CBOR", + data: []byte{0xFF, 0xFF, 0xFF}, + }, + { + name: "empty data", + data: []byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var res Response + err := res.UnmarshalCBOR(tt.data) + if err == nil { + t.Error("expected error for invalid CBOR data") + } + }) + } +} diff --git a/tools.go b/tools.go new file mode 100644 index 0000000..41172d6 --- /dev/null +++ b/tools.go @@ -0,0 +1,7 @@ +//go:build tools + +package tools + +import ( + _ "github.com/golangci/golangci-lint/cmd/golangci-lint" +)