Skip to content

Commit 4a878b4

Browse files
authored
Merge pull request #726 from alphagov/2996-add-coherence-metric
Add Coherence metric
2 parents 8d4833a + 7ab136b commit 4a878b4

13 files changed

Lines changed: 374 additions & 112 deletions

lib/auto_evaluation.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module AutoEvaluation
2+
ScoreResult = Data.define(
3+
:score,
4+
:reason,
5+
:success,
6+
:llm_responses,
7+
:metrics,
8+
)
9+
end

lib/auto_evaluation/answer_relevancy.rb

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
11
class AutoEvaluation::AnswerRelevancy
2-
Result = Data.define(
3-
:score,
4-
:reason,
5-
:success,
6-
:llm_responses,
7-
:metrics,
8-
)
9-
102
THRESHOLD = 0.5
113

124
def self.call(...) = new(...).call
@@ -55,7 +47,7 @@ def call
5547
question_message:, verdicts:, score:,
5648
)
5749

58-
Result.new(
50+
AutoEvaluation::ScoreResult.new(
5951
score:,
6052
reason:,
6153
success: score >= THRESHOLD,
@@ -78,7 +70,7 @@ def calculate_score(verdicts)
7870
end
7971

8072
def build_maximum_score_result(reason:, llm_responses:, metrics:)
81-
Result.new(
73+
AutoEvaluation::ScoreResult.new(
8274
score: 1.0,
8375
reason:,
8476
success: true,

lib/auto_evaluation/bedrock_openai_oss_invoke.rb

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AutoEvaluation
22
class BedrockOpenAIOssInvoke
3+
class InvalidToolCallSchemaError < StandardError; end
34
Result = Data.define(
45
:evaluation_data,
56
:llm_response,
@@ -33,12 +34,14 @@ def call
3334
}.to_json,
3435
)
3536
parsed_response = JSON.parse(response.body.read)
36-
parsed_structured_output = JSON.parse(
37+
parsed_tool_output = JSON.parse(
3738
parsed_response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"],
3839
)
3940

41+
validate_tool_output_against_schema(parsed_tool_output)
42+
4043
Result.new(
41-
evaluation_data: parsed_structured_output,
44+
evaluation_data: parsed_tool_output,
4245
llm_response: parsed_response,
4346
metrics: build_metrics(start_time, parsed_response),
4447
)
@@ -57,5 +60,12 @@ def build_metrics(start_time, response)
5760
model: response["model"],
5861
}
5962
end
63+
64+
def validate_tool_output_against_schema(tool_output)
65+
schema = tools.dig(0, "function", "parameters")
66+
JSON::Validator.validate!(schema, tool_output)
67+
rescue JSON::Schema::ValidationError => e
68+
raise InvalidToolCallSchemaError, "Tool call response does not match schema: #{e.message}"
69+
end
6070
end
6171
end

lib/auto_evaluation/coherence.rb

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
module AutoEvaluation
2+
class Coherence
3+
THRESHOLD = 0.75
4+
5+
def self.call(...) = new(...).call
6+
7+
def initialize(question_message:, answer_message:)
8+
@question_message = question_message
9+
@answer_message = answer_message
10+
end
11+
12+
def call
13+
result = BedrockOpenAIOssInvoke.call(user_prompt, tools)
14+
score = normalise_rubric_score(result.evaluation_data.fetch("score"))
15+
16+
AutoEvaluation::ScoreResult.new(
17+
score:,
18+
reason: result.evaluation_data.fetch("reason").strip,
19+
success: score >= THRESHOLD,
20+
llm_responses: { coherence: result.llm_response },
21+
metrics: { coherence: result.metrics },
22+
)
23+
end
24+
25+
private
26+
27+
attr_reader :question_message, :answer_message
28+
29+
def llm_prompts
30+
Prompts.config.coherence
31+
end
32+
33+
def user_prompt
34+
sprintf(
35+
llm_prompts.fetch(:user_prompt),
36+
answer: answer_message,
37+
question: question_message,
38+
)
39+
end
40+
41+
def tools
42+
[llm_prompts.fetch(:tool_spec)]
43+
end
44+
45+
def normalise_rubric_score(rubric_score)
46+
min_rubric_score = llm_prompts.fetch(:config).fetch(:min_rubric_score)
47+
max_rubric_score = llm_prompts.fetch(:config).fetch(:max_rubric_score)
48+
49+
(rubric_score.to_d - min_rubric_score) / (max_rubric_score - min_rubric_score)
50+
end
51+
end
52+
end
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
module AutoEvaluation
2+
class EvaluateAnswerFromQuestionMessage
3+
class TaskFailedError < StandardError; end
4+
5+
def self.call(...) = new(...).call
6+
7+
def initialize(evaluation_class:, question_message:)
8+
@evaluation_class = evaluation_class
9+
@question_message = question_message
10+
end
11+
12+
def call
13+
question = Question.new(message: question_message, conversation: Conversation.new)
14+
answer = AnswerComposition::PipelineRunner.call(question:, pipeline: [
15+
AnswerComposition::Pipeline::SearchResultFetcher,
16+
AnswerComposition::Pipeline::Claude::StructuredAnswerComposer,
17+
])
18+
19+
if answer.status =~ /^error/
20+
error_message = "Answer has an error status: #{answer.status} " \
21+
"and error message: #{answer.error_message}"
22+
raise TaskFailedError, error_message
23+
end
24+
25+
evaluation_class.call(
26+
question_message:,
27+
answer_message: answer.message,
28+
)
29+
end
30+
31+
private
32+
33+
attr_reader :evaluation_class, :question_message
34+
end
35+
end

lib/tasks/evaluation.rake

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -177,24 +177,31 @@ namespace :evaluation do
177177
task generate_answer_relevancy_evaluation: :environment do
178178
raise "Requires an INPUT env var" if ENV["INPUT"].blank?
179179

180-
question = Question.new(message: ENV["INPUT"], conversation: Conversation.new)
181-
182-
answer = AnswerComposition::PipelineRunner.call(question:, pipeline: [
183-
AnswerComposition::Pipeline::Claude::QuestionRouter,
184-
AnswerComposition::Pipeline::SearchResultFetcher,
185-
AnswerComposition::Pipeline::Claude::StructuredAnswerComposer,
186-
])
187-
188-
if answer.status =~ /^error/
189-
warn "Warning: answer has an error status: #{answer.status}"
190-
abort(answer.error_message)
180+
begin
181+
result = AutoEvaluation::EvaluateAnswerFromQuestionMessage.call(
182+
evaluation_class: AutoEvaluation::AnswerRelevancy,
183+
question_message: ENV["INPUT"],
184+
)
185+
186+
puts result.to_json
187+
rescue AutoEvaluation::EvaluateAnswerFromQuestionMessage::TaskFailedError => e
188+
abort e.message
191189
end
190+
end
192191

193-
result = AutoEvaluation::AnswerRelevancy.call(
194-
question_message: answer.rephrased_question || question.message,
195-
answer_message: answer.message,
196-
)
192+
desc "Run answer coherence evaluation for a user input"
193+
task generate_coherence_evaluation: :environment do
194+
raise "Requires an INPUT env var" if ENV["INPUT"].blank?
197195

198-
puts(result.to_json)
196+
begin
197+
result = AutoEvaluation::EvaluateAnswerFromQuestionMessage.call(
198+
evaluation_class: AutoEvaluation::Coherence,
199+
question_message: ENV["INPUT"],
200+
)
201+
202+
puts result.to_json
203+
rescue AutoEvaluation::EvaluateAnswerFromQuestionMessage::TaskFailedError => e
204+
abort e.message
205+
end
199206
end
200207
end
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
FactoryBot.define do
2+
factory :auto_evaluation_score_result, class: "AutoEvaluation::ScoreResult" do
3+
skip_create
4+
5+
score { 0.85.to_d }
6+
reason { "Most statements are relevant." }
7+
success { true }
8+
llm_responses { {} }
9+
metrics { {} }
10+
11+
initialize_with { new(**attributes) }
12+
end
13+
end

spec/lib/auto_evaluation/answer_relevancy/verdicts_generator_spec.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
let(:statements) { ["Statement one.", "Statement two."] }
55
let(:verdicts) do
66
[
7-
{ "verdict" => "Yes" },
8-
{ "verdict" => "No", "reason" => "The statement is irrelevant." },
7+
{ "verdict" => "yes" },
8+
{ "verdict" => "no", "reason" => "The statement is irrelevant." },
99
]
1010
end
1111
let(:verdicts_json) do

spec/lib/auto_evaluation/answer_relevancy_spec.rb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
let(:verdicts) do
2525
[
26-
{ "verdict" => "Yes" },
27-
{ "verdict" => "No", "reason" => "The statement is irrelevant." },
26+
{ "verdict" => "yes" },
27+
{ "verdict" => "no", "reason" => "The statement is irrelevant." },
2828
]
2929
end
3030
let(:verdicts_json) { { verdicts: }.to_json }
@@ -90,7 +90,7 @@
9090
reason: shared_expected_metrics_attributes,
9191
}
9292
expect(result)
93-
.to be_a(described_class::Result)
93+
.to be_a(AutoEvaluation::ScoreResult)
9494
.and have_attributes(
9595
score: 0.5,
9696
reason:,
@@ -104,7 +104,7 @@
104104
let(:verdicts) do
105105
[
106106
{ "verdict" => "idk", "reason" => "Cannot determine relevance." },
107-
{ "verdict" => "No", "reason" => "The statement is irrelevant." },
107+
{ "verdict" => "no", "reason" => "The statement is irrelevant." },
108108
]
109109
end
110110

@@ -130,7 +130,7 @@
130130
)
131131

132132
expect(result)
133-
.to be_a(described_class::Result)
133+
.to be_a(AutoEvaluation::ScoreResult)
134134
.and have_attributes(
135135
score: 1.0,
136136
reason: "No statements were extracted from the answer.",
@@ -154,7 +154,7 @@
154154
)
155155

156156
expect(result)
157-
.to be_a(described_class::Result)
157+
.to be_a(AutoEvaluation::ScoreResult)
158158
.and have_attributes(
159159
score: 1.0,
160160
reason: "No verdicts were generated for the extracted statements.",
@@ -172,7 +172,7 @@
172172
end
173173

174174
context "when verdicts are generated and none have a 'no' verdict" do
175-
let(:verdicts_json) { { verdicts: [{ "verdict" => "Yes" }, { "verdict" => "Yes" }] }.to_json }
175+
let(:verdicts_json) { { verdicts: [{ "verdict" => "yes" }, { "verdict" => "yes" }] }.to_json }
176176

177177
it "returns a result object with the expected attributes" do
178178
allow(Clock).to receive(:monotonic_time).and_return(200.0, 202.0, 204.0, 206.0)
@@ -183,7 +183,7 @@
183183
)
184184

185185
expect(result)
186-
.to be_a(described_class::Result)
186+
.to be_a(AutoEvaluation::ScoreResult)
187187
.and have_attributes(
188188
score: 1.0,
189189
reason: "The response fully addressed the input with no irrelevant statements.",

spec/lib/auto_evaluation/bedrock_openai_oss_invoke_spec.rb

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
let(:tools) do
55
[
66
{
7-
type: "function",
8-
function: {
9-
name: "test_schema",
10-
description: "A test JSON schema",
11-
schema: {
12-
type: "object",
13-
properties: {
14-
response: { type: "string" },
7+
"type" => "function",
8+
"function" => {
9+
"name" => "test_schema",
10+
"description" => "A test JSON schema",
11+
"parameters" => {
12+
"type" => "object",
13+
"properties" => {
14+
"response" => { "type" => "string" },
1515
},
16-
required: %w[response],
16+
"required" => %w[response],
1717
},
18-
strict: true,
18+
"strict" => true,
1919
},
2020
},
2121
]
@@ -55,5 +55,20 @@
5555
},
5656
)
5757
end
58+
59+
it "raises an error if the response does not conform to the schema" do
60+
bedrock_invoke_model_openai_oss_tool_call(
61+
user_message,
62+
tools,
63+
{ "invalid_key" => "This does not conform to the schema." }.to_json,
64+
)
65+
66+
expect {
67+
described_class.call(user_message, tools)
68+
}.to raise_error(
69+
described_class::InvalidToolCallSchemaError,
70+
/The property '#\/' did not contain a required property of 'response'/,
71+
)
72+
end
5873
end
5974
end

0 commit comments

Comments
 (0)