Skip to content

Commit 56e9593

Browse files
authored
feat(shared): add tokens package for secure token generation and validation (#1331)
* feat(shared): add tokens package for secure token generation and validation Adds shared/pkg/tokens with GenerateToken, HashToken, ValidateTokenHash, TokenWithExpiry, IsExpired, and TimeUntilExpiry. Uses crypto/rand and SHA256 for secure single-use token workflows (invitations, password resets). * fix(shared/tokens): use constant-time comparison in ValidateTokenHash Replace string equality with crypto/subtle.ConstantTimeCompare to prevent timing side-channel attacks during token validation. --------- Co-authored-by: Ben Coombs <bjcoombs@users.noreply.github.com>
1 parent cc9960e commit 56e9593

4 files changed

Lines changed: 303 additions & 0 deletions

File tree

shared/pkg/tokens/expiry.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package tokens
2+
3+
import "time"
4+
5+
// TTL constants for different token types.
6+
const (
7+
InvitationTokenTTL = 72 * time.Hour
8+
PasswordResetTokenTTL = 1 * time.Hour
9+
)
10+
11+
// TokenWithExpiry pairs a stored token hash with its expiry time.
12+
type TokenWithExpiry struct {
13+
Hash string
14+
ExpiresAt time.Time
15+
}
16+
17+
// IsExpired returns true if the token's expiry time is at or before the current time.
18+
func IsExpired(t TokenWithExpiry) bool {
19+
return !time.Now().Before(t.ExpiresAt)
20+
}
21+
22+
// TimeUntilExpiry returns the duration remaining until the token expires.
23+
// A negative or zero duration indicates the token has already expired.
24+
func TimeUntilExpiry(t TokenWithExpiry) time.Duration {
25+
return time.Until(t.ExpiresAt)
26+
}

shared/pkg/tokens/expiry_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Package tokens_test provides tests for the tokens package.
2+
package tokens_test
3+
4+
import (
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
10+
"github.com/meridianhub/meridian/shared/pkg/tokens"
11+
)
12+
13+
func TestIsExpired(t *testing.T) {
14+
t.Run("returns false for token with future expiry", func(t *testing.T) {
15+
tok := tokens.TokenWithExpiry{
16+
Hash: "somehash",
17+
ExpiresAt: time.Now().Add(1 * time.Hour),
18+
}
19+
assert.False(t, tokens.IsExpired(tok))
20+
})
21+
22+
t.Run("returns true for token with past expiry", func(t *testing.T) {
23+
tok := tokens.TokenWithExpiry{
24+
Hash: "somehash",
25+
ExpiresAt: time.Now().Add(-1 * time.Second),
26+
}
27+
assert.True(t, tokens.IsExpired(tok))
28+
})
29+
30+
t.Run("returns true for token expiring exactly now (boundary)", func(t *testing.T) {
31+
tok := tokens.TokenWithExpiry{
32+
Hash: "somehash",
33+
ExpiresAt: time.Now(),
34+
}
35+
// At-or-past expiry is expired
36+
assert.True(t, tokens.IsExpired(tok))
37+
})
38+
}
39+
40+
func TestTimeUntilExpiry(t *testing.T) {
41+
t.Run("returns positive duration for future token", func(t *testing.T) {
42+
tok := tokens.TokenWithExpiry{
43+
Hash: "somehash",
44+
ExpiresAt: time.Now().Add(1 * time.Hour),
45+
}
46+
d := tokens.TimeUntilExpiry(tok)
47+
assert.Greater(t, d, time.Duration(0))
48+
})
49+
50+
t.Run("returns zero or negative for expired token", func(t *testing.T) {
51+
tok := tokens.TokenWithExpiry{
52+
Hash: "somehash",
53+
ExpiresAt: time.Now().Add(-1 * time.Second),
54+
}
55+
d := tokens.TimeUntilExpiry(tok)
56+
assert.LessOrEqual(t, d, time.Duration(0))
57+
})
58+
59+
t.Run("is approximately correct", func(t *testing.T) {
60+
future := time.Now().Add(72 * time.Hour)
61+
tok := tokens.TokenWithExpiry{
62+
Hash: "somehash",
63+
ExpiresAt: future,
64+
}
65+
d := tokens.TimeUntilExpiry(tok)
66+
// Should be within 1 second of 72 hours
67+
assert.InDelta(t, (72 * time.Hour).Seconds(), d.Seconds(), 1.0)
68+
})
69+
}
70+
71+
func TestConstants(t *testing.T) {
72+
t.Run("InvitationTokenLength is 32", func(t *testing.T) {
73+
assert.Equal(t, 32, tokens.InvitationTokenLength)
74+
})
75+
76+
t.Run("PasswordResetTokenLength is 32", func(t *testing.T) {
77+
assert.Equal(t, 32, tokens.PasswordResetTokenLength)
78+
})
79+
80+
t.Run("InvitationTokenTTL is 72 hours", func(t *testing.T) {
81+
assert.Equal(t, 72*time.Hour, tokens.InvitationTokenTTL)
82+
})
83+
84+
t.Run("PasswordResetTokenTTL is 1 hour", func(t *testing.T) {
85+
assert.Equal(t, 1*time.Hour, tokens.PasswordResetTokenTTL)
86+
})
87+
}

shared/pkg/tokens/token.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Package tokens provides secure token generation, hashing, and validation utilities
2+
// for single-use tokens such as invitations and password resets.
3+
package tokens
4+
5+
import (
6+
"crypto/rand"
7+
"crypto/sha256"
8+
"crypto/subtle"
9+
"encoding/base64"
10+
"encoding/hex"
11+
"errors"
12+
)
13+
14+
// ErrInvalidLength is returned when a non-positive token length is provided.
15+
var ErrInvalidLength = errors.New("token length must be positive")
16+
17+
// Token length constants for different use cases.
18+
const (
19+
InvitationTokenLength = 32
20+
PasswordResetTokenLength = 32
21+
)
22+
23+
// GenerateToken generates a cryptographically secure random token of the given byte length.
24+
// Returns the URL-safe base64 plaintext (for delivery to the user) and its SHA256 hex hash
25+
// (for storage). The plaintext must never be stored — only the hash is persisted.
26+
func GenerateToken(length int) (plaintext, hash string, err error) {
27+
if length <= 0 {
28+
return "", "", ErrInvalidLength
29+
}
30+
31+
b := make([]byte, length)
32+
if _, err = rand.Read(b); err != nil {
33+
return "", "", err
34+
}
35+
36+
plaintext = base64.RawURLEncoding.EncodeToString(b)
37+
hash = HashToken(plaintext)
38+
return plaintext, hash, nil
39+
}
40+
41+
// HashToken returns the SHA256 hex digest of the given plaintext token.
42+
// Use this to produce a storable hash; never store the plaintext.
43+
func HashToken(plaintext string) string {
44+
sum := sha256.Sum256([]byte(plaintext))
45+
return hex.EncodeToString(sum[:])
46+
}
47+
48+
// ValidateTokenHash returns true if the SHA256 hash of plaintext matches the stored hash.
49+
// Uses constant-time comparison to prevent timing attacks.
50+
func ValidateTokenHash(plaintext, hash string) bool {
51+
if plaintext == "" || hash == "" {
52+
return false
53+
}
54+
computed := HashToken(plaintext)
55+
return subtle.ConstantTimeCompare([]byte(computed), []byte(hash)) == 1
56+
}

shared/pkg/tokens/token_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Package tokens_test provides tests for the tokens package.
2+
package tokens_test
3+
4+
import (
5+
"strings"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/meridianhub/meridian/shared/pkg/tokens"
12+
)
13+
14+
func TestGenerateToken(t *testing.T) {
15+
t.Run("returns plaintext and hash", func(t *testing.T) {
16+
plaintext, hash, err := tokens.GenerateToken(tokens.InvitationTokenLength)
17+
require.NoError(t, err)
18+
assert.NotEmpty(t, plaintext)
19+
assert.NotEmpty(t, hash)
20+
})
21+
22+
t.Run("plaintext length matches requested bytes encoded as base64url", func(t *testing.T) {
23+
plaintext, _, err := tokens.GenerateToken(tokens.InvitationTokenLength)
24+
require.NoError(t, err)
25+
// base64url RawEncoding: ceil(n * 4/3), no padding
26+
// 32 bytes -> 43 chars
27+
assert.GreaterOrEqual(t, len(plaintext), 40)
28+
})
29+
30+
t.Run("plaintext contains only URL-safe characters", func(t *testing.T) {
31+
plaintext, _, err := tokens.GenerateToken(tokens.InvitationTokenLength)
32+
require.NoError(t, err)
33+
for _, c := range plaintext {
34+
assert.True(t, isURLSafeChar(c), "unexpected char: %c", c)
35+
}
36+
})
37+
38+
t.Run("hash matches HashToken of plaintext", func(t *testing.T) {
39+
plaintext, hash, err := tokens.GenerateToken(tokens.InvitationTokenLength)
40+
require.NoError(t, err)
41+
expected := tokens.HashToken(plaintext)
42+
assert.Equal(t, expected, hash)
43+
})
44+
45+
t.Run("returns error for zero length", func(t *testing.T) {
46+
_, _, err := tokens.GenerateToken(0)
47+
assert.Error(t, err)
48+
})
49+
50+
t.Run("returns error for negative length", func(t *testing.T) {
51+
_, _, err := tokens.GenerateToken(-1)
52+
assert.Error(t, err)
53+
})
54+
55+
t.Run("generates unique tokens", func(t *testing.T) {
56+
const count = 10000
57+
seen := make(map[string]struct{}, count)
58+
for i := 0; i < count; i++ {
59+
plaintext, _, err := tokens.GenerateToken(tokens.InvitationTokenLength)
60+
require.NoError(t, err)
61+
_, exists := seen[plaintext]
62+
assert.False(t, exists, "duplicate token at iteration %d", i)
63+
seen[plaintext] = struct{}{}
64+
}
65+
})
66+
}
67+
68+
func TestHashToken(t *testing.T) {
69+
t.Run("returns non-empty hex string", func(t *testing.T) {
70+
hash := tokens.HashToken("some-plaintext-token")
71+
assert.NotEmpty(t, hash)
72+
assert.True(t, isHex(hash), "hash should be hex: %s", hash)
73+
})
74+
75+
t.Run("SHA256 produces 64-char hex string", func(t *testing.T) {
76+
hash := tokens.HashToken("some-plaintext-token")
77+
assert.Len(t, hash, 64)
78+
})
79+
80+
t.Run("is deterministic", func(t *testing.T) {
81+
plaintext := "deterministic-token-value"
82+
hash1 := tokens.HashToken(plaintext)
83+
hash2 := tokens.HashToken(plaintext)
84+
assert.Equal(t, hash1, hash2)
85+
})
86+
87+
t.Run("different inputs produce different hashes", func(t *testing.T) {
88+
hash1 := tokens.HashToken("token-a")
89+
hash2 := tokens.HashToken("token-b")
90+
assert.NotEqual(t, hash1, hash2)
91+
})
92+
}
93+
94+
func TestValidateTokenHash(t *testing.T) {
95+
t.Run("returns true for matching plaintext and hash", func(t *testing.T) {
96+
plaintext, hash, err := tokens.GenerateToken(tokens.InvitationTokenLength)
97+
require.NoError(t, err)
98+
assert.True(t, tokens.ValidateTokenHash(plaintext, hash))
99+
})
100+
101+
t.Run("returns false for wrong plaintext", func(t *testing.T) {
102+
_, hash, err := tokens.GenerateToken(tokens.InvitationTokenLength)
103+
require.NoError(t, err)
104+
assert.False(t, tokens.ValidateTokenHash("wrong-plaintext", hash))
105+
})
106+
107+
t.Run("returns false for empty plaintext", func(t *testing.T) {
108+
_, hash, err := tokens.GenerateToken(tokens.InvitationTokenLength)
109+
require.NoError(t, err)
110+
assert.False(t, tokens.ValidateTokenHash("", hash))
111+
})
112+
113+
t.Run("returns false for empty hash", func(t *testing.T) {
114+
plaintext, _, err := tokens.GenerateToken(tokens.InvitationTokenLength)
115+
require.NoError(t, err)
116+
assert.False(t, tokens.ValidateTokenHash(plaintext, ""))
117+
})
118+
}
119+
120+
func isURLSafeChar(c rune) bool {
121+
return (c >= 'A' && c <= 'Z') ||
122+
(c >= 'a' && c <= 'z') ||
123+
(c >= '0' && c <= '9') ||
124+
c == '-' || c == '_'
125+
}
126+
127+
func isHex(s string) bool {
128+
for _, c := range strings.ToLower(s) {
129+
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
130+
return false
131+
}
132+
}
133+
return true
134+
}

0 commit comments

Comments
 (0)