Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions sidecar/internal/crypto/hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/binary"
"errors"
"os"
)
Expand Down Expand Up @@ -55,9 +56,54 @@ func (s *Signer) Sign(msg []byte) []byte {
return mac.Sum(nil)
}

// SignHex computes HMAC-SHA256 over msg and returns the hex-encoded MAC.
func (s *Signer) SignHex(msg []byte) string {
return hex.EncodeToString(s.Sign(msg))
}

// Verify returns true iff mac matches HMAC-SHA256 over msg.
// Uses hmac.Equal for constant-time comparison to prevent timing attacks.
func (s *Signer) Verify(msg, mac []byte) bool {
expected := s.Sign(msg)
return hmac.Equal(expected, mac)
}

// VerifyHex returns true iff macHex decodes and matches the expected HMAC.
func (s *Signer) VerifyHex(msg []byte, macHex string) bool {
decoded, err := hex.DecodeString(macHex)
if err != nil {
return false
}
return s.Verify(msg, decoded)
}

// ProvenanceMessage returns the canonical byte sequence covered by the
// provenance HMAC. Each field is length-prefixed to avoid ambiguity.
func ProvenanceMessage(hookType, provenance, sessionID, executionID, nonce string, expiresAtUnix int64, payload []byte) []byte {
fields := [][]byte{
[]byte(hookType),
[]byte(provenance),
[]byte(sessionID),
[]byte(executionID),
[]byte(nonce),
}

total := 8 + len(payload)
for _, field := range fields {
total += 4 + len(field)
}

buf := make([]byte, 0, total)
for _, field := range fields {
var lenBuf [4]byte
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(field)))
buf = append(buf, lenBuf[:]...)
buf = append(buf, field...)
}

var expBuf [8]byte
binary.BigEndian.PutUint64(expBuf[:], uint64(expiresAtUnix))
buf = append(buf, expBuf[:]...)
buf = append(buf, payload...)
return buf
}
17 changes: 17 additions & 0 deletions sidecar/internal/crypto/hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,20 @@ func TestNewSignerFromEnv_InvalidHex(t *testing.T) {
t.Error("expected error for invalid hex value")
}
}

func TestVerifyHex_ValidMAC(t *testing.T) {
s := testSigner(t)
msg := []byte("verify me")
macHex := s.SignHex(msg)
if !s.VerifyHex(msg, macHex) {
t.Error("VerifyHex returned false for a valid MAC")
}
}

func TestProvenanceMessage_Deterministic(t *testing.T) {
msg1 := ProvenanceMessage("on_prompt", "user", "s1", "e1", "n1", 1700000000, []byte(`"hello"`))
msg2 := ProvenanceMessage("on_prompt", "user", "s1", "e1", "n1", 1700000000, []byte(`"hello"`))
if string(msg1) != string(msg2) {
t.Error("ProvenanceMessage should be deterministic")
}
}
11 changes: 8 additions & 3 deletions sidecar/internal/crypto/nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,21 @@ func NewNonceStore(ttl time.Duration) *NonceStore {
// (or has already expired). Returns true if the nonce is still active —
// indicating a replay attempt. The check and record are atomic.
func (ns *NonceStore) Seen(nonce []byte) bool {
key := string(nonce)
return ns.SeenString(string(nonce))
}

// SeenString is the string variant of Seen and is used by the validate stage
// for provenance nonces carried inside the JSON payload.
func (ns *NonceStore) SeenString(nonce string) bool {
now := time.Now()

ns.mu.Lock()
defer ns.mu.Unlock()

if exp, ok := ns.m[key]; ok && now.Before(exp) {
if exp, ok := ns.m[nonce]; ok && now.Before(exp) {
return true // replay
}
ns.m[key] = now.Add(ns.ttl)
ns.m[nonce] = now.Add(ns.ttl)
return false
}

Expand Down
12 changes: 12 additions & 0 deletions sidecar/internal/crypto/nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,15 @@ func TestNonceStore_Concurrent(t *testing.T) {
}
wg.Wait()
}

func TestNonceStore_SeenString(t *testing.T) {
ns := NewNonceStore(5 * time.Minute)
defer ns.Stop()

if ns.SeenString("string-nonce") {
t.Fatal("SeenString returned true on first use")
}
if !ns.SeenString("string-nonce") {
t.Fatal("SeenString returned false on replay")
}
}
114 changes: 114 additions & 0 deletions sidecar/internal/pipeline/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,117 @@
// - JSON schema validation of the RiskContext payload
// Invalid frames are rejected before any payload parsing.
package pipeline

import (
"encoding/json"
"errors"
"strings"
"time"

"github.com/acf-sdk/sidecar/internal/crypto"
"github.com/acf-sdk/sidecar/pkg/riskcontext"
)

const (
ProvenanceFlagReplay = "replay"
ProvenanceFlagExpired = "expired"
ProvenanceFlagMismatch = "mismatch"
)

var errSchemaInvalid = errors.New("pipeline: invalid risk context schema")

// ValidationConfig contains the external dependencies for the validate stage.
type ValidationConfig struct {
Signer *crypto.Signer
NonceStore *crypto.NonceStore
ExpectedExecutionID string
Now func() time.Time
}

// ValidatePayload parses and validates an inbound RiskContext JSON payload.
// It writes a normalized provenance signal into the returned context and
// returns BLOCK when provenance validation fails.
func ValidatePayload(payload []byte, cfg ValidationConfig) (*riskcontext.RiskContext, byte, error) {
ctx := &riskcontext.RiskContext{}
if err := json.Unmarshal(payload, ctx); err != nil {
return nil, 0, err
}
if err := validateSchema(ctx); err != nil {
ctx.ProvenanceTrust = 0
ctx.ProvenanceFlags = []string{ProvenanceFlagMismatch}
return ctx, 0x02, err
}
if err := validateProvenance(ctx, cfg); err != nil {
return ctx, 0x02, err
}
ctx.ProvenanceTrust = 1.0
ctx.ProvenanceFlags = nil
return ctx, 0x00, nil
}

func validateSchema(ctx *riskcontext.RiskContext) error {
if strings.TrimSpace(ctx.HookType) == "" || strings.TrimSpace(ctx.Provenance) == "" {
return errSchemaInvalid
}
if ctx.Payload == nil {
return errSchemaInvalid
}
return nil
}

func validateProvenance(ctx *riskcontext.RiskContext, cfg ValidationConfig) error {
flags := make([]string, 0, 3)
markFailure := func(flag string, err error) error {
ctx.ProvenanceTrust = 0
ctx.ProvenanceFlags = appendUnique(flags, flag)
return err
}

if cfg.Signer == nil || cfg.NonceStore == nil {
return markFailure(ProvenanceFlagMismatch, errSchemaInvalid)
}
if strings.TrimSpace(ctx.ExecutionID) == "" || strings.TrimSpace(ctx.ProvenanceNonce) == "" || strings.TrimSpace(ctx.ProvenanceHMAC) == "" || ctx.ExpiresAtUnix == 0 {
return markFailure(ProvenanceFlagMismatch, errSchemaInvalid)
}
if cfg.ExpectedExecutionID != "" && ctx.ExecutionID != cfg.ExpectedExecutionID {
return markFailure(ProvenanceFlagMismatch, errors.New("pipeline: execution_id mismatch"))
}

now := time.Now
if cfg.Now != nil {
now = cfg.Now
}
if now().After(time.Unix(ctx.ExpiresAtUnix, 0)) {
return markFailure(ProvenanceFlagExpired, errors.New("pipeline: provenance expired"))
}

payloadBytes, err := json.Marshal(ctx.Payload)
if err != nil {
return markFailure(ProvenanceFlagMismatch, err)
}
msg := crypto.ProvenanceMessage(
ctx.HookType,
ctx.Provenance,
ctx.SessionID,
ctx.ExecutionID,
ctx.ProvenanceNonce,
ctx.ExpiresAtUnix,
payloadBytes,
)
if !cfg.Signer.VerifyHex(msg, ctx.ProvenanceHMAC) {
return markFailure(ProvenanceFlagMismatch, errors.New("pipeline: provenance HMAC mismatch"))
}
if cfg.NonceStore.SeenString(ctx.ProvenanceNonce) {
return markFailure(ProvenanceFlagReplay, errors.New("pipeline: provenance nonce replay detected"))
}
return nil
}

func appendUnique(flags []string, flag string) []string {
for _, existing := range flags {
if existing == flag {
return flags
}
}
return append(flags, flag)
}
Loading