Skip to content

Commit 3da4b03

Browse files
feat(sensitivity): add sensitivity tagging with pattern-based auto-classification (#85)
- New pkg/sensitivity with Classifier, Level enum (None/PII/InternalIP/Credentials) - Built-in patterns: email, phone, credit card, SSN, AWS keys, OpenAI keys, GitHub tokens, Slack tokens, generic secrets - Configurable internal domain detection (.internal, .corp, .local) - StoreEntry accepts Sensitivity (explicit) and AutoClassify (pattern-based) - RecallResult includes MaxSensitivity and SensitiveChunks metadata - Sensitivity stored in SQLite, does not affect dedup or ranking - 17 classifier tests + 4 benchmarks (all <1ms) - 7 memory integration tests for sensitivity propagation Closes #82 Co-authored-by: Ona <no-reply@ona.com>
1 parent 8260f4f commit 3da4b03

5 files changed

Lines changed: 681 additions & 27 deletions

File tree

pkg/memory/sensitivity_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package memory
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/Siddhant-K-code/distill/pkg/sensitivity"
8+
)
9+
10+
func TestStore_ExplicitSensitivity(t *testing.T) {
11+
s := newTestStore(t)
12+
ctx := context.Background()
13+
14+
_, err := s.Store(ctx, StoreRequest{
15+
Entries: []StoreEntry{
16+
{Text: "Q3 pricing: customer A at $120k", Sensitivity: sensitivity.InternalIP, Embedding: makeEmbedding(0, 8)},
17+
},
18+
})
19+
if err != nil {
20+
t.Fatalf("Store: %v", err)
21+
}
22+
23+
recall, err := s.Recall(ctx, RecallRequest{
24+
Query: "pricing", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
25+
})
26+
if err != nil {
27+
t.Fatalf("Recall: %v", err)
28+
}
29+
if recall.MaxSensitivity != sensitivity.InternalIP {
30+
t.Errorf("expected MaxSensitivity=InternalIP, got %s", recall.MaxSensitivity)
31+
}
32+
if len(recall.SensitiveChunks) != 1 {
33+
t.Fatalf("expected 1 sensitive chunk, got %d", len(recall.SensitiveChunks))
34+
}
35+
if recall.SensitiveChunks[0].Sensitivity != sensitivity.InternalIP {
36+
t.Errorf("expected chunk sensitivity InternalIP, got %s", recall.SensitiveChunks[0].Sensitivity)
37+
}
38+
}
39+
40+
func TestStore_AutoClassify_Credentials(t *testing.T) {
41+
s := newTestStore(t)
42+
ctx := context.Background()
43+
44+
_, err := s.Store(ctx, StoreRequest{
45+
Entries: []StoreEntry{
46+
{Text: "API key: sk-proj-abc123def456ghi789jkl012", AutoClassify: true, Embedding: makeEmbedding(0, 8)},
47+
},
48+
})
49+
if err != nil {
50+
t.Fatalf("Store: %v", err)
51+
}
52+
53+
recall, _ := s.Recall(ctx, RecallRequest{
54+
Query: "key", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
55+
})
56+
if recall.MaxSensitivity != sensitivity.Credentials {
57+
t.Errorf("expected MaxSensitivity=Credentials, got %s", recall.MaxSensitivity)
58+
}
59+
}
60+
61+
func TestStore_AutoClassify_PII(t *testing.T) {
62+
s := newTestStore(t)
63+
ctx := context.Background()
64+
65+
_, err := s.Store(ctx, StoreRequest{
66+
Entries: []StoreEntry{
67+
{Text: "Contact alice@example.com for the report", AutoClassify: true, Embedding: makeEmbedding(0, 8)},
68+
},
69+
})
70+
if err != nil {
71+
t.Fatalf("Store: %v", err)
72+
}
73+
74+
recall, _ := s.Recall(ctx, RecallRequest{
75+
Query: "contact", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
76+
})
77+
if recall.MaxSensitivity != sensitivity.PII {
78+
t.Errorf("expected MaxSensitivity=PII, got %s", recall.MaxSensitivity)
79+
}
80+
}
81+
82+
func TestStore_AutoClassify_NoMatch(t *testing.T) {
83+
s := newTestStore(t)
84+
ctx := context.Background()
85+
86+
_, err := s.Store(ctx, StoreRequest{
87+
Entries: []StoreEntry{
88+
{Text: "The service uses REST APIs", AutoClassify: true, Embedding: makeEmbedding(0, 8)},
89+
},
90+
})
91+
if err != nil {
92+
t.Fatalf("Store: %v", err)
93+
}
94+
95+
recall, _ := s.Recall(ctx, RecallRequest{
96+
Query: "service", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
97+
})
98+
if recall.MaxSensitivity != sensitivity.None {
99+
t.Errorf("expected MaxSensitivity=None, got %s", recall.MaxSensitivity)
100+
}
101+
if len(recall.SensitiveChunks) != 0 {
102+
t.Errorf("expected 0 sensitive chunks, got %d", len(recall.SensitiveChunks))
103+
}
104+
}
105+
106+
func TestStore_AutoClassify_ExplicitOverride(t *testing.T) {
107+
s := newTestStore(t)
108+
ctx := context.Background()
109+
110+
// Explicit sensitivity is higher than what auto-classify would find
111+
_, err := s.Store(ctx, StoreRequest{
112+
Entries: []StoreEntry{
113+
{
114+
Text: "Normal text with no patterns",
115+
Sensitivity: sensitivity.Credentials,
116+
AutoClassify: true,
117+
Embedding: makeEmbedding(0, 8),
118+
},
119+
},
120+
})
121+
if err != nil {
122+
t.Fatalf("Store: %v", err)
123+
}
124+
125+
recall, _ := s.Recall(ctx, RecallRequest{
126+
Query: "text", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
127+
})
128+
// Explicit Credentials should be preserved even though auto-classify finds None
129+
if recall.MaxSensitivity != sensitivity.Credentials {
130+
t.Errorf("expected MaxSensitivity=Credentials (explicit), got %s", recall.MaxSensitivity)
131+
}
132+
}
133+
134+
func TestRecall_MaxSensitivity_MultipleEntries(t *testing.T) {
135+
s := newTestStore(t)
136+
ctx := context.Background()
137+
138+
_, err := s.Store(ctx, StoreRequest{
139+
Entries: []StoreEntry{
140+
{Text: "Normal architecture notes", Embedding: makeEmbedding(0, 8)},
141+
{Text: "Contact bob@company.com", Sensitivity: sensitivity.PII, Embedding: makeEmbedding(1.5, 8)},
142+
{Text: "AWS key AKIAIOSFODNN7EXAMPLE", Sensitivity: sensitivity.Credentials, Embedding: makeEmbedding(3.0, 8)},
143+
},
144+
})
145+
if err != nil {
146+
t.Fatalf("Store: %v", err)
147+
}
148+
149+
recall, _ := s.Recall(ctx, RecallRequest{
150+
Query: "all", QueryEmbedding: makeEmbedding(0.5, 8), MaxResults: 10,
151+
})
152+
if recall.MaxSensitivity != sensitivity.Credentials {
153+
t.Errorf("expected MaxSensitivity=Credentials, got %s", recall.MaxSensitivity)
154+
}
155+
if len(recall.SensitiveChunks) != 2 {
156+
t.Errorf("expected 2 sensitive chunks, got %d", len(recall.SensitiveChunks))
157+
}
158+
}
159+
160+
func TestRecall_SensitivityDoesNotAffectRanking(t *testing.T) {
161+
s := newTestStore(t)
162+
ctx := context.Background()
163+
164+
_, err := s.Store(ctx, StoreRequest{
165+
Entries: []StoreEntry{
166+
{Text: "Highly relevant auth info", Embedding: makeEmbedding(0, 8), Sensitivity: sensitivity.Credentials},
167+
{Text: "Less relevant payment info", Embedding: makeEmbedding(1.5, 8)},
168+
},
169+
})
170+
if err != nil {
171+
t.Fatalf("Store: %v", err)
172+
}
173+
174+
recall, _ := s.Recall(ctx, RecallRequest{
175+
Query: "auth", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
176+
})
177+
// The sensitive entry should still be ranked first (closest embedding)
178+
if len(recall.Memories) < 1 {
179+
t.Fatal("expected at least 1 memory")
180+
}
181+
if recall.Memories[0].Sensitivity != sensitivity.Credentials {
182+
t.Errorf("expected first result to have Credentials sensitivity, got %s", recall.Memories[0].Sensitivity)
183+
}
184+
}

pkg/memory/sqlite.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ import (
99
"time"
1010

1111
distillmath "github.com/Siddhant-K-code/distill/pkg/math"
12+
"github.com/Siddhant-K-code/distill/pkg/sensitivity"
1213
_ "modernc.org/sqlite"
1314
)
1415

1516
// SQLiteStore implements Store using SQLite for local persistent storage.
1617
// Uses a single connection (SetMaxOpenConns(1)) so SQLite's internal
1718
// serialization handles concurrency. No application-level mutex needed.
1819
type SQLiteStore struct {
19-
db *sql.DB
20-
cfg Config
21-
handlers []MemoryEventHandler
20+
db *sql.DB
21+
cfg Config
22+
handlers []MemoryEventHandler
23+
classifier *sensitivity.Classifier
2224
}
2325

2426
// NewSQLiteStore creates a new SQLite-backed memory store.
@@ -49,7 +51,11 @@ func NewSQLiteStore(dsn string, cfg Config) (*SQLiteStore, error) {
4951
return nil, fmt.Errorf("enable foreign keys: %w", err)
5052
}
5153

52-
s := &SQLiteStore{db: db, cfg: cfg}
54+
s := &SQLiteStore{
55+
db: db,
56+
cfg: cfg,
57+
classifier: sensitivity.New(sensitivity.DefaultConfig()),
58+
}
5359
if err := s.migrate(); err != nil {
5460
_ = db.Close()
5561
return nil, fmt.Errorf("migrate: %w", err)
@@ -68,6 +74,7 @@ func (s *SQLiteStore) migrate() error {
6874
session_id TEXT DEFAULT '',
6975
metadata TEXT DEFAULT '{}',
7076
decay_level INTEGER DEFAULT 0,
77+
sensitivity INTEGER DEFAULT 0,
7178
created_at TEXT NOT NULL,
7279
last_referenced TEXT NOT NULL,
7380
access_count INTEGER DEFAULT 0,
@@ -98,6 +105,7 @@ func (s *SQLiteStore) migrate() error {
98105
{"expired_at", "TEXT DEFAULT ''"},
99106
{"superseded_by", "TEXT DEFAULT ''"},
100107
{"expires_at", "TEXT DEFAULT ''"},
108+
{"sensitivity", "INTEGER DEFAULT 0"},
101109
} {
102110
_, _ = s.db.Exec("ALTER TABLE memories ADD COLUMN " + col.name + " " + col.def)
103111
}
@@ -149,10 +157,19 @@ func (s *SQLiteStore) Store(ctx context.Context, req StoreRequest) (*StoreResult
149157
expiresAt = entry.ExpiresAt.UTC().Format(time.RFC3339Nano)
150158
}
151159

160+
// Determine sensitivity level
161+
sens := entry.Sensitivity
162+
if entry.AutoClassify {
163+
classified := s.classifier.Classify(entry.Text)
164+
if classified.Level > sens {
165+
sens = classified.Level
166+
}
167+
}
168+
152169
_, err := s.db.ExecContext(ctx,
153-
`INSERT INTO memories (id, text, embedding, source, session_id, metadata, decay_level, created_at, last_referenced, access_count, expires_at)
154-
VALUES (?, ?, ?, ?, ?, ?, 0, ?, ?, 0, ?)`,
155-
id, entry.Text, embBlob, entry.Source, sessionID, string(metaJSON), now, now, expiresAt,
170+
`INSERT INTO memories (id, text, embedding, source, session_id, metadata, decay_level, sensitivity, created_at, last_referenced, access_count, expires_at)
171+
VALUES (?, ?, ?, ?, ?, ?, 0, ?, ?, ?, 0, ?)`,
172+
id, entry.Text, embBlob, entry.Source, sessionID, string(metaJSON), int(sens), now, now, expiresAt,
156173
)
157174
if err != nil {
158175
return nil, fmt.Errorf("insert memory: %w", err)
@@ -237,7 +254,7 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
237254
}
238255

239256
// Build query with optional tag filter and expiry exclusion
240-
query := "SELECT m.id, m.text, m.embedding, m.source, m.decay_level, m.last_referenced FROM memories m"
257+
query := "SELECT m.id, m.text, m.embedding, m.source, m.decay_level, m.sensitivity, m.last_referenced FROM memories m"
241258
var args []interface{}
242259
var conditions []string
243260

@@ -273,11 +290,12 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
273290
id, text, source, refStr string
274291
embBlob []byte
275292
decayLevel int
293+
sensitivity int
276294
}
277295
var rawRows []rawRow
278296
for rows.Next() {
279297
var r rawRow
280-
if err := rows.Scan(&r.id, &r.text, &r.embBlob, &r.source, &r.decayLevel, &r.refStr); err != nil {
298+
if err := rows.Scan(&r.id, &r.text, &r.embBlob, &r.source, &r.decayLevel, &r.sensitivity, &r.refStr); err != nil {
281299
_ = rows.Close()
282300
return nil, err
283301
}
@@ -323,6 +341,7 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
323341
Tags: tags,
324342
Relevance: relevance,
325343
DecayLevel: DecayLevel(r.decayLevel),
344+
Sensitivity: sensitivity.Level(r.sensitivity),
326345
LastReferenced: lastRef,
327346
},
328347
relevance: relevance,
@@ -360,6 +379,9 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
360379
// Entries with relevance >= 0.7 are considered stable candidates.
361380
hint := buildCacheBoundaryHint(results)
362381

382+
// Build sensitivity metadata from returned memories.
383+
maxSens, sensitiveChunks := buildSensitivityMetadata(results)
384+
363385
return &RecallResult{
364386
Memories: results,
365387
Stats: RecallStats{
@@ -368,7 +390,9 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
368390
Returned: len(results),
369391
TokenCount: tokenCount,
370392
},
371-
CacheHint: hint,
393+
CacheHint: hint,
394+
MaxSensitivity: maxSens,
395+
SensitiveChunks: sensitiveChunks,
372396
}, nil
373397
}
374398

@@ -395,6 +419,25 @@ func buildCacheBoundaryHint(memories []RecalledMemory) *CacheBoundaryHint {
395419
}
396420
}
397421

422+
// buildSensitivityMetadata derives MaxSensitivity and SensitiveChunks from
423+
// the recalled memories. Only entries with non-zero sensitivity are included.
424+
func buildSensitivityMetadata(memories []RecalledMemory) (sensitivity.Level, []SensitiveChunk) {
425+
var maxSens sensitivity.Level
426+
var chunks []SensitiveChunk
427+
for _, m := range memories {
428+
if m.Sensitivity > maxSens {
429+
maxSens = m.Sensitivity
430+
}
431+
if m.Sensitivity > sensitivity.None {
432+
chunks = append(chunks, SensitiveChunk{
433+
ChunkID: m.ID,
434+
Sensitivity: m.Sensitivity,
435+
})
436+
}
437+
}
438+
return maxSens, chunks
439+
}
440+
398441
// Forget removes memories matching the given criteria.
399442
func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) {
400443

0 commit comments

Comments
 (0)