Skip to content

Commit e5be3a2

Browse files
authored
Merge pull request #1050 from alphagov/473-spike-using-structured-response-for-guardrails
[CHAT-473][DO NOT MERGE] Use structured output for non-sonnet 4.0 guardrail requests
2 parents 2b35fa4 + 899f6fd commit e5be3a2

7 files changed

Lines changed: 293 additions & 126 deletions

File tree

lib/answer_composition/multiple_guardrail/checker.rb

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,36 @@ def initialize(input, llm_prompt_name)
2727
end
2828

2929
def call
30-
response = anthropic_bedrock_client.messages.create(
30+
shared_config = {
3131
system: [{ type: "text", text: prompt.system_prompt, cache_control: { type: "ephemeral" } }],
3232
model: BedrockModels.model_id(self.class.bedrock_model),
3333
messages: [{ role: "user", content: prompt.user_prompt(input) }],
3434
max_tokens: MAX_TOKENS,
35-
)
36-
37-
parse_response(response)
35+
}
36+
37+
if self.class.bedrock_model == :claude_sonnet_4_0
38+
response = anthropic_bedrock_client.messages.create(**shared_config)
39+
parse_response(response)
40+
else
41+
response = anthropic_bedrock_client.messages.create(**shared_config.merge(
42+
output_config: {
43+
format: json_schema,
44+
},
45+
))
46+
47+
llm_guardrail_result = JSON.parse(response.content.first.text)
48+
49+
Result.new(
50+
llm_response: response.to_h,
51+
llm_guardrail_result: llm_guardrail_result.to_s,
52+
triggered: llm_guardrail_result.present?,
53+
guardrails: to_guardrail_hash(llm_guardrail_result),
54+
llm_prompt_tokens: response[:usage][:input_tokens],
55+
llm_completion_tokens: response[:usage][:output_tokens],
56+
llm_cached_tokens: response[:usage][:cache_read_input_tokens],
57+
model: response.model,
58+
)
59+
end
3860
end
3961

4062
private
@@ -67,7 +89,8 @@ def parse_response(response)
6789

6890
parts = llm_guardrail_result.split(" | ")
6991
triggered = parts.first.chomp == "True"
70-
guardrails = to_guardrail_hash(parts.second)
92+
triggered_guardrail_numbers = parts.second.scan(/\d+/).map(&:to_i)
93+
guardrails = to_guardrail_hash(triggered_guardrail_numbers)
7194

7295
Result.new(
7396
llm_response:,
@@ -96,12 +119,22 @@ def response_pattern
96119
end
97120
end
98121

99-
def to_guardrail_hash(parts)
100-
triggered_guardrail_numbers = parts.scan(/\d+/).map(&:to_i)
101-
122+
def to_guardrail_hash(triggered_guardrail_numbers)
102123
prompt.guardrails.each_with_object({}) do |guardrail, guardrails_hash|
103124
guardrails_hash[guardrail.name.to_sym] = triggered_guardrail_numbers.include?(guardrail.key)
104125
end
105126
end
127+
128+
def json_schema
129+
guardrail_keys = prompt.guardrails.map(&:key)
130+
{
131+
"type" => "json_schema",
132+
"schema" => {
133+
"description" => "Array of triggered guardrail numbers. Returns [] if none triggered.",
134+
"type" => "array",
135+
"items" => { "type" => "integer", "enum" => guardrail_keys },
136+
},
137+
}
138+
end
106139
end
107140
end

lib/answer_composition/multiple_guardrail/prompt.rb

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ def system_prompt
1616
guardrails_content = guardrails.map { |g| "#{g.key}. #{g.content}" }
1717
.join("\n")
1818

19-
prompts.fetch(:system_prompt)
19+
system_prompt_key = if Checker.bedrock_model == :claude_sonnet_4_0
20+
:system_prompt
21+
else
22+
:system_prompt_structured
23+
end
24+
25+
prompts.fetch(system_prompt_key)
2026
.sub("{guardrails}", guardrails_content)
2127
.sub("{date}", Date.current.strftime("%A %d %B %Y"))
2228
end

0 commit comments

Comments
 (0)