Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions pkg/claims/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
package claims

import (
"errors"
"fmt"
"iter"
"sort"
"strings"

"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/printer"

"github.com/ucan-wg/go-ucan/pkg/policy/literal"
"github.com/ucan-wg/go-ucan/pkg/secretbox"
)

var ErrNotFound = errors.New("key not found in claims")

var ErrNotEncryptable = errors.New("value of this type cannot be encrypted")

// Claims is a container for claims key-value pairs in an attestation token.
// This also serves as a way to construct the underlying IPLD data with minimum allocations
// and transformations, while hiding the IPLD complexity from the caller.
type Claims struct {
// This type must be compatible with the IPLD type represented by the IPLD
// schema { String : Any }.

Keys []string
Values map[string]ipld.Node
}

// NewClaims constructs a new Claims.
func NewClaims() *Claims {
return &Claims{Values: map[string]ipld.Node{}}
}

// GetBool retrieves a value as a bool.
// Returns ErrNotFound if the given key is missing.
// Returns datamodel.ErrWrongKind if the value has the wrong type.
func (m *Claims) GetBool(key string) (bool, error) {
v, ok := m.Values[key]
if !ok {
return false, ErrNotFound
}
return v.AsBool()
}

// GetString retrieves a value as a string.
// Returns ErrNotFound if the given key is missing.
// Returns datamodel.ErrWrongKind if the value has the wrong type.
func (m *Claims) GetString(key string) (string, error) {
v, ok := m.Values[key]
if !ok {
return "", ErrNotFound
}
return v.AsString()
}

// GetEncryptedString decorates GetString and decrypt its output with the given symmetric encryption key.
func (m *Claims) GetEncryptedString(key string, encryptionKey []byte) (string, error) {
v, err := m.GetBytes(key)
if err != nil {
return "", err
}

decrypted, err := secretbox.DecryptStringWithKey(v, encryptionKey)
if err != nil {
return "", err
}

return string(decrypted), nil
}

// GetInt64 retrieves a value as an int64.
// Returns ErrNotFound if the given key is missing.
// Returns datamodel.ErrWrongKind if the value has the wrong type.
func (m *Claims) GetInt64(key string) (int64, error) {
v, ok := m.Values[key]
if !ok {
return 0, ErrNotFound
}
return v.AsInt()
}

// GetFloat64 retrieves a value as a float64.
// Returns ErrNotFound if the given key is missing.
// Returns datamodel.ErrWrongKind if the value has the wrong type.
func (m *Claims) GetFloat64(key string) (float64, error) {
v, ok := m.Values[key]
if !ok {
return 0, ErrNotFound
}
return v.AsFloat()
}

// GetBytes retrieves a value as a []byte.
// Returns ErrNotFound if the given key is missing.
// Returns datamodel.ErrWrongKind if the value has the wrong type.
func (m *Claims) GetBytes(key string) ([]byte, error) {
v, ok := m.Values[key]
if !ok {
return nil, ErrNotFound
}
return v.AsBytes()
}

// GetEncryptedBytes decorates GetBytes and decrypt its output with the given symmetric encryption key.
func (m *Claims) GetEncryptedBytes(key string, encryptionKey []byte) ([]byte, error) {
v, err := m.GetBytes(key)
if err != nil {
return nil, err
}

decrypted, err := secretbox.DecryptStringWithKey(v, encryptionKey)
if err != nil {
return nil, err
}

return decrypted, nil
}

// GetNode retrieves a value as a raw IPLD node.
// Returns ErrNotFound if the given key is missing.
func (m *Claims) GetNode(key string) (ipld.Node, error) {
v, ok := m.Values[key]
if !ok {
return nil, ErrNotFound
}
return v, nil
}

// Add adds a key/value pair in the claims set.
// Accepted types for val are any CBOR compatible type, or directly IPLD values.
func (m *Claims) Add(key string, val any) error {
if _, ok := m.Values[key]; ok {
return fmt.Errorf("duplicate key %q", key)
}

node, err := literal.Any(val)
if err != nil {
return err
}

m.Keys = append(m.Keys, key)
m.Values[key] = node

return nil
}

// AddEncrypted adds a key/value pair in the claims set.
// The value is encrypted with the given encryptionKey.
// Accepted types for the value are: string, []byte.
// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead.
func (m *Claims) AddEncrypted(key string, val any, encryptionKey []byte) error {
var encrypted []byte
var err error

switch val := val.(type) {
case string:
encrypted, err = secretbox.EncryptWithKey([]byte(val), encryptionKey)
if err != nil {
return err
}
case []byte:
encrypted, err = secretbox.EncryptWithKey(val, encryptionKey)
if err != nil {
return err
}
default:
return ErrNotEncryptable
}

return m.Add(key, encrypted)
}

type Iterator interface {
Iter() iter.Seq2[string, ipld.Node]
}

// Include merges the provided claims into the existing one.
//
// If duplicate keys are encountered, the new value is silently dropped
// without causing an error.
func (m *Claims) Include(other Iterator) {
for key, value := range other.Iter() {
if _, ok := m.Values[key]; ok {
// don't overwrite
continue
}
m.Values[key] = value
m.Keys = append(m.Keys, key)
}
}

// Len returns the number of key/values.
func (m *Claims) Len() int {
return len(m.Values)
}

// Iter iterates over the claims key/values
func (m *Claims) Iter() iter.Seq2[string, ipld.Node] {
return func(yield func(string, ipld.Node) bool) {
for _, key := range m.Keys {
if !yield(key, m.Values[key]) {
return
}
}
}
}

// Equals tells if two Claims hold the same key/values.
func (m *Claims) Equals(other *Claims) bool {
if len(m.Keys) != len(other.Keys) {
return false
}
if len(m.Values) != len(other.Values) {
return false
}
for _, key := range m.Keys {
if !ipld.DeepEqual(m.Values[key], other.Values[key]) {
return false
}
}
return true
}

func (m *Claims) String() string {
sort.Strings(m.Keys)

buf := strings.Builder{}
buf.WriteString("{")

for key, node := range m.Values {
buf.WriteString("\n\t")
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(strings.ReplaceAll(printer.Sprint(node), "\n", "\n\t"))
buf.WriteString(",")
}

if len(m.Values) > 0 {
buf.WriteString("\n")
}
buf.WriteString("}")

return buf.String()
}

// ReadOnly returns a read-only version of Claims.
func (m *Claims) ReadOnly() ReadOnly {
return ReadOnly{claims: m}
}

// Clone makes a deep copy.
func (m *Claims) Clone() *Claims {
res := &Claims{
Keys: make([]string, len(m.Keys)),
Values: make(map[string]ipld.Node, len(m.Values)),
}
copy(res.Keys, m.Keys)
for k, v := range m.Values {
res.Values[k] = v
}
return res
}
130 changes: 130 additions & 0 deletions pkg/claims/claims_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package claims_test

import (
"crypto/rand"
"maps"
"testing"

"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/stretchr/testify/require"

"github.com/ucan-wg/go-ucan/pkg/claims"
)

func TestClaims_Add(t *testing.T) {
t.Parallel()

type Unsupported struct{}

t.Run("error if not primitive or Node", func(t *testing.T) {
t.Parallel()

err := (&claims.Claims{}).Add("invalid", &Unsupported{})
require.Error(t, err)
})

t.Run("encrypted claims", func(t *testing.T) {
t.Parallel()

key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)

m := claims.NewClaims()

// string encryption
err = m.AddEncrypted("secret", "hello world", key)
require.NoError(t, err)

_, err = m.GetString("secret")
require.Error(t, err) // the ciphertext is saved as []byte instead of string

decrypted, err := m.GetEncryptedString("secret", key)
require.NoError(t, err)
require.Equal(t, "hello world", decrypted)

// bytes encryption
originalBytes := make([]byte, 128)
_, err = rand.Read(originalBytes)
require.NoError(t, err)
err = m.AddEncrypted("secret-bytes", originalBytes, key)
require.NoError(t, err)

encryptedBytes, err := m.GetBytes("secret-bytes")
require.NoError(t, err)
require.NotEqual(t, originalBytes, encryptedBytes)

decryptedBytes, err := m.GetEncryptedBytes("secret-bytes", key)
require.NoError(t, err)
require.Equal(t, originalBytes, decryptedBytes)

// error cases
t.Run("error on unsupported type", func(t *testing.T) {
err := m.AddEncrypted("invalid", 123, key)
require.ErrorIs(t, err, claims.ErrNotEncryptable)
})

t.Run("error on invalid key size", func(t *testing.T) {
err := m.AddEncrypted("invalid", "test", []byte("short-key"))
require.Error(t, err)
require.Contains(t, err.Error(), "invalid key size")
})

t.Run("error on nil key", func(t *testing.T) {
err := m.AddEncrypted("invalid", "test", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key is required")
})
})
}

func TestIterCloneEquals(t *testing.T) {
m := claims.NewClaims()

require.NoError(t, m.Add("foo", "bar"))
require.NoError(t, m.Add("baz", 1234))

expected := map[string]ipld.Node{
"foo": basicnode.NewString("bar"),
"baz": basicnode.NewInt(1234),
}

// claims -> iter
require.Equal(t, expected, maps.Collect(m.Iter()))

// readonly -> iter
ro := m.ReadOnly()
require.Equal(t, expected, maps.Collect(ro.Iter()))

// claims -> clone -> iter
clone := m.Clone()
require.Equal(t, expected, maps.Collect(clone.Iter()))

// readonly -> WriteableClone -> iter
wclone := ro.WriteableClone()
require.Equal(t, expected, maps.Collect(wclone.Iter()))

require.True(t, m.Equals(wclone))
require.True(t, ro.Equals(wclone.ReadOnly()))
}

func TestInclude(t *testing.T) {
m1 := claims.NewClaims()

require.NoError(t, m1.Add("samekey", "bar"))
require.NoError(t, m1.Add("baz", 1234))

m2 := claims.NewClaims()

require.NoError(t, m2.Add("samekey", "othervalue")) // check no overwrite
require.NoError(t, m2.Add("otherkey", 1234))

m1.Include(m2)

require.Equal(t, map[string]ipld.Node{
"samekey": basicnode.NewString("bar"),
"baz": basicnode.NewInt(1234),
"otherkey": basicnode.NewInt(1234),
}, maps.Collect(m1.Iter()))
}
Loading
Loading