-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathsafety_policies.go
More file actions
90 lines (75 loc) · 2.73 KB
/
safety_policies.go
File metadata and controls
90 lines (75 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package agent
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/Protocol-Lattice/go-agent/src/models"
)
// RegexBlocklistPolicy is a safety policy that rejects any LLM output
// matched by at least one of its compiled regular expressions.
type RegexBlocklistPolicy struct {
	// patterns is the compiled blocklist, kept in the order the source
	// expressions were handed to NewRegexBlocklistPolicy.
	patterns []*regexp.Regexp
}

// NewRegexBlocklistPolicy builds a RegexBlocklistPolicy from the given regular
// expression sources. Every entry of patterns is compiled with regexp.Compile;
// the first one that does not compile aborts construction and is reported in
// the returned error, which wraps the underlying compile error.
func NewRegexBlocklistPolicy(patterns []string) (*RegexBlocklistPolicy, error) {
	policy := &RegexBlocklistPolicy{}
	for _, expr := range patterns {
		re, err := regexp.Compile(expr)
		if err != nil {
			return nil, fmt.Errorf("failed to compile regex %q: %w", expr, err)
		}
		policy.patterns = append(policy.patterns, re)
	}
	return policy, nil
}

// Validate reports a policy violation when response is matched by any blocked
// pattern. Patterns are tried in configuration order and the error names the
// first one that matches; nil is returned when none match. ctx is accepted but
// not consulted by this policy.
func (p *RegexBlocklistPolicy) Validate(ctx context.Context, response string) error {
	for i := range p.patterns {
		blocked := p.patterns[i]
		if !blocked.MatchString(response) {
			continue
		}
		return fmt.Errorf("safety policy violation: output matches blocked pattern %q", blocked.String())
	}
	return nil
}
// LLMEvaluatorPolicy is a safety policy that asks a secondary language model
// to judge whether a proposed response is safe.
type LLMEvaluatorPolicy struct {
	// model is the evaluator whose Generate method Validate calls.
	model models.Agent
	// prompt is a fmt.Sprintf format string; Validate supplies the sanitized
	// response as its only argument.
	prompt string
}

// defaultEvaluatorPrompt is the template used when NewLLMEvaluatorPolicy is
// given an empty promptTemplate. Its single %s verb receives the text under
// evaluation, placed between <text> delimiters.
const defaultEvaluatorPrompt = `Evaluate the following text for safety.
If the text contains hate speech, dangerous instructions, PII, or violates general safety guidelines, respond with exactly "UNSAFE".
Otherwise, respond with exactly "SAFE".
TEXT TO EVALUATE:
<text>
%s
</text>`

// textTagPattern matches an opening or closing "text" tag regardless of letter
// case and of whitespace inside the angle brackets (for example <TEXT>,
// </Text> or </text >). The optional slash is captured so the replacement can
// keep the open/close distinction.
var textTagPattern = regexp.MustCompile(`(?i)<\s*(/?)\s*text\s*>`)

// NewLLMEvaluatorPolicy creates a safety policy that uses model to evaluate
// responses.
//
// promptTemplate is used as a fmt.Sprintf format string with the sanitized
// response as its only argument, so it should contain exactly one %s verb and
// write any literal percent sign as %%. If promptTemplate is empty, the
// default evaluation prompt is used.
func NewLLMEvaluatorPolicy(model models.Agent, promptTemplate string) *LLMEvaluatorPolicy {
	if promptTemplate == "" {
		promptTemplate = defaultEvaluatorPrompt
	}
	return &LLMEvaluatorPolicy{
		model:  model,
		prompt: promptTemplate,
	}
}

// Validate sends response to the evaluating model and checks its verdict.
//
// It returns an error wrapping the model's error when Generate fails, a policy
// violation error when the upper-cased, trimmed verdict contains "UNSAFE", and
// nil otherwise.
func (p *LLMEvaluatorPolicy) Validate(ctx context.Context, response string) error {
	// Neutralize anything that looks like a <text> delimiter so the evaluated
	// output cannot appear to close the block and smuggle in instructions.
	// Matching is case- and whitespace-insensitive because a lowercase-only
	// replacement lets variants such as </TEXT> through unchanged.
	// "<text>" becomes "(text)" and "</text>" becomes "(/text)".
	safeResponse := textTagPattern.ReplaceAllString(response, "(${1}text)")

	evalPrompt := fmt.Sprintf(p.prompt, safeResponse)
	result, err := p.model.Generate(ctx, evalPrompt)
	if err != nil {
		return fmt.Errorf("safety evaluation failed: %w", err)
	}

	// The result is rendered with fmt's default formatting before inspection.
	verdict := strings.ToUpper(strings.TrimSpace(fmt.Sprint(result)))

	// NOTE(review): this check fails open — any verdict that does not contain
	// "UNSAFE" (including an empty or unrelated reply) is accepted. Consider
	// whether unrecognized verdicts should be rejected instead.
	if strings.Contains(verdict, "UNSAFE") {
		return fmt.Errorf("safety policy violation: output flagged as unsafe by LLM evaluator")
	}
	return nil
}