Skip to content

Commit b33964e

Browse files
authored
Merge pull request #1020 from alphagov/472-combine-multiple-checker
[CHAT-472] Merge multiple checker into single class
2 parents 920b8c5 + 3d17070 commit b33964e

18 files changed

Lines changed: 276 additions & 328 deletions

lib/answer_composition/composer.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def compose_answer
4040
Pipeline::JailbreakGuardrails,
4141
Pipeline::QuestionRephraser,
4242
Pipeline::QuestionRouter,
43-
Pipeline::QuestionRoutingGuardrails.new(llm_provider: :claude),
43+
Pipeline::QuestionRoutingGuardrails,
4444
Pipeline::SearchResultFetcher,
4545
Pipeline::StructuredAnswerComposer,
46-
Pipeline::AnswerGuardrails.new(llm_provider: :claude),
46+
Pipeline::AnswerGuardrails,
4747
])
4848
else
4949
raise "Answer strategy #{answer_strategy} not configured"

lib/answer_composition/pipeline/answer_guardrails.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module AnswerComposition
22
module Pipeline
33
class AnswerGuardrails < OutputGuardrails
4-
def call(context)
4+
def call
55
start_time = Clock.monotonic_time
66
response = generate_response(context)
77

lib/answer_composition/pipeline/output_guardrails.rb

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
module AnswerComposition
22
module Pipeline
33
class OutputGuardrails
4-
def initialize(llm_provider: :claude)
5-
@llm_provider = llm_provider
4+
attr_reader :context
5+
6+
def self.call(...) = new(...).call
7+
8+
def initialize(context)
9+
@context = context
610
end
711

812
protected
913

10-
attr_reader :llm_provider
11-
1214
def build_metrics(start_time, response_or_error)
1315
{
1416
duration: Clock.monotonic_time - start_time,
@@ -20,7 +22,7 @@ def build_metrics(start_time, response_or_error)
2022
end
2123

2224
def generate_response(context)
23-
result = ::Guardrails::MultipleChecker.call(context.answer.message, guardrail_name, llm_provider)
25+
result = ::Guardrails::MultipleChecker.call(context.answer.message, guardrail_name)
2426
context.answer.assign_llm_response(guardrail_name, result.llm_response)
2527
result
2628
end

lib/answer_composition/pipeline/question_routing_guardrails.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module AnswerComposition
22
module Pipeline
33
class QuestionRoutingGuardrails < OutputGuardrails
4-
def call(context)
4+
def call
55
return if context.answer.question_routing_label == "genuine_rag"
66

77
start_time = Clock.monotonic_time

lib/guardrails/claude/multiple_checker.rb

Lines changed: 0 additions & 49 deletions
This file was deleted.

lib/guardrails/multiple_checker.rb

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ def triggered_guardrails
99
end
1010
end
1111

12+
MAX_TOKENS = 100
13+
SUPPORTED_MODELS = %i[claude_sonnet_4_0 claude_haiku_4_5].freeze
14+
DEFAULT_MODEL = :claude_sonnet_4_0
15+
1216
class ResponseError < StandardError
1317
attr_reader :llm_response, :llm_guardrail_result, :llm_prompt_tokens,
1418
:llm_completion_tokens, :llm_cached_tokens, :model
@@ -35,12 +39,10 @@ class Prompt
3539

3640
Guardrail = Data.define(:key, :name, :content)
3741

38-
def initialize(prompt_name, llm_provider = :claude)
39-
prompts = if llm_provider == :claude
40-
AnswerComposition::Pipeline::Prompts.config(prompt_name, Claude::MultipleChecker.bedrock_model)
41-
else
42-
Rails.configuration.govuk_chat_private.llm_prompts[llm_provider][prompt_name]
43-
end
42+
def initialize(prompt_name)
43+
prompts = AnswerComposition::Pipeline::Prompts.config(
44+
prompt_name, Guardrails::MultipleChecker.bedrock_model
45+
)
4446

4547
raise "No LLM prompts found for #{prompt_name}" unless prompts
4648

@@ -72,8 +74,12 @@ def guardrails
7274

7375
def self.call(...) = new(...).call
7476

75-
def self.collated_prompts(llm_prompt_name, llm_provider)
76-
prompt = Prompt.new(llm_prompt_name, llm_provider)
77+
def self.bedrock_model
78+
BedrockModels.determine_model(ENV["BEDROCK_CLAUDE_GUARDRAILS_MODEL"], DEFAULT_MODEL, SUPPORTED_MODELS).last
79+
end
80+
81+
def self.collated_prompts(llm_prompt_name)
82+
prompt = Prompt.new(llm_prompt_name)
7783

7884
<<~PROMPT
7985
# System prompt
@@ -85,38 +91,46 @@ def self.collated_prompts(llm_prompt_name, llm_provider)
8591
PROMPT
8692
end
8793

88-
def initialize(input, llm_prompt_name, llm_provider)
94+
def initialize(input, llm_prompt_name)
8995
@input = input
9096
@llm_prompt_name = llm_prompt_name
91-
@llm_provider = llm_provider
9297
end
9398

9499
def call
95-
case llm_provider
96-
when :claude
97-
response = Claude::MultipleChecker.call(input, prompt)
98-
else
99-
raise "Unexpected provider #{llm_provider}"
100-
end
101-
parse_response(**response)
100+
response = anthropic_bedrock_client.messages.create(
101+
system: [{ type: "text", text: prompt.system_prompt, cache_control: { type: "ephemeral" } }],
102+
model: BedrockModels.model_id(self.class.bedrock_model),
103+
messages: [{ role: "user", content: prompt.user_prompt(input) }],
104+
max_tokens: MAX_TOKENS,
105+
)
106+
107+
parse_response(response)
102108
end
103109

104110
private
105111

106-
def parse_response(llm_response:,
107-
llm_guardrail_result:,
108-
llm_prompt_tokens:,
109-
llm_completion_tokens:,
110-
llm_cached_tokens:,
111-
model:)
112+
def anthropic_bedrock_client
113+
@anthropic_bedrock_client ||= Anthropic::BedrockClient.new(
114+
aws_region: ENV["CLAUDE_AWS_REGION"],
115+
)
116+
end
117+
118+
def parse_response(response)
119+
llm_response = response.to_h
120+
llm_guardrail_result = response[:content][0][:text]
121+
input_tokens = response[:usage][:input_tokens]
122+
output_tokens = response[:usage][:output_tokens]
123+
cache_read_input_tokens = response[:usage][:cache_read_input_tokens]
124+
model = response[:model]
125+
112126
unless response_pattern =~ llm_guardrail_result
113127
raise ResponseError.new(
114128
"Error parsing guardrail response",
115129
llm_response,
116130
llm_guardrail_result,
117-
llm_prompt_tokens,
118-
llm_completion_tokens,
119-
llm_cached_tokens,
131+
input_tokens,
132+
output_tokens,
133+
cache_read_input_tokens,
120134
model,
121135
)
122136
end
@@ -126,19 +140,19 @@ def parse_response(llm_response:,
126140
guardrails = to_guardrail_hash(parts.second)
127141

128142
Result.new(
129-
llm_response: llm_response,
130-
llm_guardrail_result: llm_guardrail_result,
131-
triggered: triggered,
132-
guardrails: guardrails,
133-
llm_prompt_tokens: llm_prompt_tokens,
134-
llm_completion_tokens: llm_completion_tokens,
135-
llm_cached_tokens: llm_cached_tokens,
143+
llm_response:,
144+
llm_guardrail_result:,
145+
triggered:,
146+
guardrails:,
147+
llm_prompt_tokens: input_tokens,
148+
llm_completion_tokens: output_tokens,
149+
llm_cached_tokens: cache_read_input_tokens,
136150
model:,
137151
)
138152
end
139153

140154
def prompt
141-
@prompt ||= Prompt.new(llm_prompt_name, llm_provider)
155+
@prompt ||= Prompt.new(llm_prompt_name)
142156
end
143157

144158
def guardrail_numbers

lib/tasks/evaluation.rake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace :evaluation do
3232
raise "Requires an INPUT env var" if ENV["INPUT"].blank?
3333
raise "Requires a guardrail type" if args[:guardrail_type].blank?
3434

35-
response = Guardrails::MultipleChecker.call(ENV["INPUT"], args[:guardrail_type].to_sym, :claude)
35+
response = Guardrails::MultipleChecker.call(ENV["INPUT"], args[:guardrail_type].to_sym)
3636

3737
puts(response.to_json)
3838
end

lib/tasks/guardrails.rake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace "guardrails" do
77
abort("Invalid guardrail type. Valid guardrail types are #{valid_guardrail_types.to_sentence}")
88
end
99

10-
prompt = Guardrails::MultipleChecker.collated_prompts(guardrail_type, :claude)
10+
prompt = Guardrails::MultipleChecker.collated_prompts(guardrail_type)
1111
puts prompt
1212
end
1313
end

spec/factories/output_guardrail_result_factory.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
llm_prompt_tokens { 13 }
66
llm_completion_tokens { 7 }
77
llm_cached_tokens { 10 }
8-
model { BedrockModels.model_id(Guardrails::Claude::MultipleChecker::DEFAULT_MODEL) }
8+
model { BedrockModels.model_id(Guardrails::MultipleChecker::DEFAULT_MODEL) }
99

1010
llm_response do
1111
content = Anthropic::Models::TextBlock.new(

spec/lib/answer_composition/composer_spec.rb

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,14 @@ def stub_pipeline_initialize(klass, *args, **kwargs)
2424
let(:question) { create :question, answer_strategy: :claude_structured_answer }
2525

2626
it "calls PipelineRunner with the correct pipeline" do
27-
stub_pipeline_initialize(AnswerComposition::Pipeline::QuestionRoutingGuardrails, llm_provider: :claude)
28-
stub_pipeline_initialize(AnswerComposition::Pipeline::AnswerGuardrails, llm_provider: :claude)
29-
3027
expected_pipeline = [
3128
AnswerComposition::Pipeline::JailbreakGuardrails,
3229
AnswerComposition::Pipeline::QuestionRephraser,
3330
AnswerComposition::Pipeline::QuestionRouter,
34-
AnswerComposition::Pipeline::QuestionRoutingGuardrails.new(llm_provider: :claude),
31+
AnswerComposition::Pipeline::QuestionRoutingGuardrails,
3532
AnswerComposition::Pipeline::SearchResultFetcher,
3633
AnswerComposition::Pipeline::StructuredAnswerComposer,
37-
AnswerComposition::Pipeline::AnswerGuardrails.new(llm_provider: :claude),
34+
AnswerComposition::Pipeline::AnswerGuardrails,
3835
]
3936
expected_pipeline.each do |pipeline|
4037
allow(pipeline).to receive(:call) { it }

0 commit comments

Comments
 (0)