Skip to content

Commit 6a3bdd6

Browse files
authored
Add alignment strategy (#11)
1 parent 55f4c20 commit 6a3bdd6

File tree

20 files changed

+891
-482
lines changed

20 files changed

+891
-482
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async fn main() {
112112
let executor = StructuredExecutorBuilder::new()
113113
.with_lm(&lm)
114114
.with_preamble("You are a geography expert who helps users understand the capital city of countries around the world.")
115-
.with_options(&variants!(ResponseVariants))
115+
.with_options(Box::new(variants!(ResponseVariants)))
116116
.try_build()
117117
.unwrap();
118118
let response = executor

core/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ orch_response = { path = "../response", version = "0.0.12" }
1313
orch_response_derive = { path = "../response_derive", version = "0.0.12" }
1414
async-gen = "0.2.3"
1515
dotenv = "0.15.0"
16-
dyn-clone = "1.0.17"
1716
openai-api-rs = "5.0.2"
1817
reqwest = { version = "0.12.5", features = ["blocking"] }
1918
serde = { version = "1.0.164", features = ["derive"] }
@@ -22,3 +21,5 @@ thiserror = "1.0.63"
2221
tokio = { version = "1.28.2", features = ["rt", "macros"] }
2322
tokio-stream = "0.1.15"
2423
async-trait = "0.1.81"
24+
dyn-clone = "1.0.17"
25+
async-recursion = "1.1.1"

core/examples/alignment.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#![allow(dead_code)]
2+
3+
use orch::alignment::AlignmentStrategyBuilder;
4+
use orch::execution::*;
5+
use orch::lm::*;
6+
use orch::response::*;
7+
8+
#[derive(Variants, serde::Deserialize)]
9+
pub enum ResponseVariants {
10+
Answer(AnswerResponseVariant),
11+
Fail(FailResponseVariant),
12+
}
13+
14+
#[derive(Variant, serde::Deserialize)]
15+
#[variant(
16+
variant = "Answer",
17+
scenario = "You know the answer",
18+
description = "Result of the calculation"
19+
)]
20+
pub struct AnswerResponseVariant {
21+
#[schema(description = "Result of the calculation", example = "42")]
22+
pub result: String,
23+
}
24+
25+
#[derive(Variant, serde::Deserialize)]
26+
#[variant(
27+
variant = "Fail",
28+
scenario = "You don't know the answer",
29+
description = "Reason why the answer is not known"
30+
)]
31+
pub struct FailResponseVariant {
32+
#[schema(
33+
description = "Reason why the answer is not known",
34+
example = "The phrase is not a mathematical related expression"
35+
)]
36+
pub reason: String,
37+
}
38+
39+
#[tokio::main]
40+
async fn main() {
41+
// We use a large foundational model for the main task.
42+
let ollama_large = OllamaBuilder::new()
43+
.with_model(ollama_model::LLAMA3_8B.to_string())
44+
.try_build()
45+
.unwrap();
46+
47+
// We use a separate model instance for the correction (ideally a smaller one; this example reuses LLAMA3_8B).
48+
let ollama_corrector = OllamaBuilder::new()
49+
.with_model(ollama_model::LLAMA3_8B.to_string())
50+
.try_build()
51+
.unwrap();
52+
53+
// We define an alignment strategy that uses the correction model.
54+
let alignment_strategy = AlignmentStrategyBuilder::new()
55+
.with_retries(2)
56+
.with_lm(Box::new(ollama_corrector))
57+
.try_build()
58+
.unwrap();
59+
60+
let executor = StructuredExecutorBuilder::new()
61+
.with_lm(&ollama_large)
62+
.with_preamble("
63+
You are a mathematician who helps users understand the result of mathematical expressions.
64+
You will receive a mathematical expression, and you will need to provide the result of that expression.
65+
")
66+
.with_options(Box::new(variants!(ResponseVariants)))
67+
.with_alignment(alignment_strategy)
68+
.try_build()
69+
.unwrap();
70+
let response = executor.execute("2 + 2").await.expect("Execution failed");
71+
72+
match response.content {
73+
ResponseVariants::Answer(answer) => {
74+
println!("Result: {}", answer.result);
75+
}
76+
ResponseVariants::Fail(fail) => {
77+
println!("Model failed to generate a response: {}", fail.reason);
78+
}
79+
}
80+
}

core/examples/structured_data_generation_blog.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async fn main() {
9292
9393
Be very specific and refer to specific sentences, paragraph and sections of the blog post.
9494
")
95-
.with_options(&variants!(ResponseVariants))
95+
.with_options(Box::new(variants!(ResponseVariants)))
9696
.try_build()
9797
.unwrap();
9898
let response = executor.execute(&prompt).await.expect("Execution failed");

core/examples/structured_data_generation_capital.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async fn main() {
7777
You are a geography expert who helps users understand the capital city of countries around the world.
7878
You will receive a country name, and you will need to provide the capital city of that country.
7979
")
80-
.with_options(&variants!(ResponseVariants))
80+
.with_options(Box::new(variants!(ResponseVariants)))
8181
.try_build()
8282
.unwrap();
8383
let response = executor.execute(prompt).await.expect("Execution failed");

core/src/alignment/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//! A module containing all logic related to alignment.
2+
//! Alignment, in this context, means "aligning" the model's output with the desired output.
3+
//! This takes the form of a so-called [`AlignmentStrategy`], which is a trait that defines how to align the model's output.
4+
//!
5+
//! This concept has similarities to traditional "resilience" techniques and libraries, such as .NET's [Polly](https://github.com/App-vNext/Polly),
6+
//! which I personally like a lot.
7+
8+
mod strategy;
9+
mod strategy_builder;
10+
11+
pub use strategy::*;
12+
pub use strategy_builder::*;

0 commit comments

Comments
 (0)