Skip to content

Commit 41ff448

Browse files
committed
add input field render facet with jinja templates
1 parent 1ffa920 commit 41ff448

21 files changed

+442
-31
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/dspy-rs/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ rig-core = { git = "https://github.com/0xPlaygrounds/rig", rev="e7849df" }
4343
enum_dispatch = "0.3.13"
4444
tracing = "0.1.44"
4545
tracing-subscriber = { version = "0.3.22", features = ["env-filter", "fmt"] }
46+
minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["builtins", "serde"] }
4647

4748
[package.metadata.cargo-machete]
4849
ignored = ["rig-core"]

crates/dspy-rs/src/adapter/chat.rs

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use anyhow::Result;
22
use indexmap::IndexMap;
3+
use minijinja::UndefinedBehavior;
34
use regex::Regex;
45
use rig::tool::ToolDyn;
56
use serde_json::{Value, json};
@@ -17,8 +18,8 @@ use crate::serde_utils::get_iter_from_value;
1718
use crate::utils::cache::CacheEntry;
1819
use crate::{
1920
BamlValue, Cache, Chat, ConstraintLevel, ConstraintResult, Example, FieldMeta, Flag,
20-
JsonishError, LM, Message, MetaSignature, OutputFormatContent, ParseError, Prediction,
21-
RenderOptions, Signature, TypeIR,
21+
InputRenderSpec, JsonishError, LM, Message, MetaSignature, OutputFormatContent, ParseError,
22+
Prediction, RenderOptions, Signature, TypeIR,
2223
};
2324

2425
#[derive(Default, Clone)]
@@ -526,15 +527,19 @@ impl ChatAdapter {
526527
return String::new();
527528
};
528529
let input_output_format = <S::Input as BamlTypeTrait>::baml_output_format();
530+
let input_json = build_input_context_value(fields, S::input_fields(), input_output_format);
531+
let vars = Value::Object(serde_json::Map::new());
529532

530533
let mut result = String::new();
531534
for field_spec in S::input_fields() {
532535
if let Some(value) = fields.get(field_spec.rust_name) {
533536
result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.name));
534-
result.push_str(&format_baml_value_for_prompt_typed(
537+
result.push_str(&render_input_field(
538+
field_spec,
535539
value,
540+
&input_json,
536541
input_output_format,
537-
field_spec.format,
542+
&vars,
538543
));
539544
result.push_str("\n\n");
540545
}
@@ -880,23 +885,92 @@ fn format_baml_value_for_prompt(value: &BamlValue) -> String {
880885
}
881886
}
882887

883-
fn format_baml_value_for_prompt_typed(
888+
fn render_input_field(
889+
field_spec: &crate::FieldSpec,
884890
value: &BamlValue,
891+
input: &Value,
885892
output_format: &OutputFormatContent,
886-
format: Option<&str>,
893+
vars: &Value,
887894
) -> String {
888-
let format = match format {
889-
Some(format) => format,
890-
None => {
891-
if let BamlValue::String(s) = value {
892-
return s.clone();
893-
}
894-
"json"
895+
match field_spec.input_render {
896+
InputRenderSpec::Default => match value {
897+
BamlValue::String(s) => s.clone(),
898+
_ => crate::bamltype::internal_baml_jinja::format_baml_value(
899+
value,
900+
output_format,
901+
"json",
902+
)
903+
.unwrap_or_else(|_| "<error>".to_string()),
904+
},
905+
InputRenderSpec::Format(format) => {
906+
crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, format)
907+
.unwrap_or_else(|_| "<error>".to_string())
908+
}
909+
InputRenderSpec::Jinja(template) => {
910+
render_input_field_jinja(template, field_spec, value, input, output_format, vars)
911+
.unwrap_or_else(|_| "<error>".to_string())
895912
}
913+
}
914+
}
915+
916+
fn render_input_field_jinja(
917+
template: &str,
918+
field_spec: &crate::FieldSpec,
919+
value: &BamlValue,
920+
input: &Value,
921+
output_format: &OutputFormatContent,
922+
vars: &Value,
923+
) -> Result<String, minijinja::Error> {
924+
let mut env = minijinja::Environment::new();
925+
env.set_undefined_behavior(UndefinedBehavior::Strict);
926+
env.add_template("__input_field__", template)?;
927+
let template = env.get_template("__input_field__")?;
928+
929+
let this = baml_value_to_render_json(value, output_format);
930+
let field = json!({
931+
"name": field_spec.name,
932+
"rust_name": field_spec.rust_name,
933+
"type": (field_spec.type_ir)().diagnostic_repr().to_string(),
934+
});
935+
let context = json!({
936+
"this": this,
937+
"input": input,
938+
"field": field,
939+
"vars": vars,
940+
});
941+
942+
template.render(minijinja::Value::from_serialize(context))
943+
}
944+
945+
fn build_input_context_value(
946+
fields: &crate::bamltype::baml_types::BamlMap<String, BamlValue>,
947+
field_specs: &[crate::FieldSpec],
948+
output_format: &OutputFormatContent,
949+
) -> Value {
950+
let mut map = serde_json::Map::new();
951+
952+
for field_spec in field_specs {
953+
let Some(value) = fields.get(field_spec.rust_name) else {
954+
continue;
955+
};
956+
let value_json = baml_value_to_render_json(value, output_format);
957+
map.insert(field_spec.rust_name.to_string(), value_json.clone());
958+
if field_spec.name != field_spec.rust_name {
959+
map.entry(field_spec.name.to_string()).or_insert(value_json);
960+
}
961+
}
962+
963+
Value::Object(map)
964+
}
965+
966+
fn baml_value_to_render_json(value: &BamlValue, output_format: &OutputFormatContent) -> Value {
967+
let Ok(rendered_json) =
968+
crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, "json")
969+
else {
970+
return serde_json::to_value(value).unwrap_or(Value::Null);
896971
};
897972

898-
crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, format)
899-
.unwrap_or_else(|_| "<error>".to_string())
973+
serde_json::from_str(&rendered_json).unwrap_or(Value::Null)
900974
}
901975

902976
fn collect_flags_recursive(value: &BamlValueWithFlags, flags: &mut Vec<Flag>) {

crates/dspy-rs/src/core/signature.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@ use crate::{Example, OutputFormatContent, TypeIR};
22
use anyhow::Result;
33
use serde_json::Value;
44

5+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
6+
pub enum InputRenderSpec {
7+
Default,
8+
Format(&'static str),
9+
Jinja(&'static str),
10+
}
11+
512
#[derive(Debug, Clone, Copy)]
613
pub struct FieldSpec {
714
pub name: &'static str,
815
pub rust_name: &'static str,
916
pub description: &'static str,
1017
pub type_ir: fn() -> TypeIR,
1118
pub constraints: &'static [ConstraintSpec],
12-
pub format: Option<&'static str>,
19+
pub input_render: InputRenderSpec,
1320
}
1421

1522
#[derive(Debug, Clone, Copy)]

crates/dspy-rs/src/predictors/predict.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use tracing::{debug, trace};
1010
use crate::adapter::Adapter;
1111
use crate::bamltype::baml_types::BamlMap;
1212
use crate::bamltype::compat::{BamlValueConvert, ToBamlValue};
13-
use crate::core::{FieldSpec, MetaSignature, Module, Optimizable, Signature};
13+
use crate::core::{FieldSpec, InputRenderSpec, MetaSignature, Module, Optimizable, Signature};
1414
use crate::{
1515
BamlValue, CallResult, Chat, ChatAdapter, Example, GLOBAL_SETTINGS, LM, LmError, LmUsage,
1616
PredictError, Prediction,
@@ -258,8 +258,14 @@ fn field_specs_to_value(fields: &[FieldSpec], field_type: &'static str) -> Value
258258
meta.insert("desc".to_string(), json!(field.description));
259259
meta.insert("schema".to_string(), json!(""));
260260
meta.insert("__dsrs_field_type".to_string(), json!(field_type));
261-
if let Some(format) = field.format {
262-
meta.insert("format".to_string(), json!(format));
261+
match field.input_render {
262+
InputRenderSpec::Default => {}
263+
InputRenderSpec::Format(format) => {
264+
meta.insert("format".to_string(), json!(format));
265+
}
266+
InputRenderSpec::Jinja(template) => {
267+
meta.insert("render".to_string(), json!({ "jinja": template }));
268+
}
263269
}
264270
result.insert(field.rust_name.to_string(), Value::Object(meta));
265271
}

crates/dspy-rs/tests/test_input_format.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,34 @@ struct DefaultFormatSig {
6262
answer: String,
6363
}
6464

65+
#[derive(Signature, Clone, Debug)]
66+
/// Render a context field using Jinja.
67+
struct RenderJinjaSig {
68+
#[input]
69+
question: String,
70+
71+
#[input]
72+
#[alias("ctx")]
73+
#[render(
74+
jinja = "{{ this.text }} | {{ input.question }} | {{ input.ctx.text }} | {{ input.context.text }} | {{ field.name }} | {{ field.rust_name }}"
75+
)]
76+
context: Document,
77+
78+
#[output]
79+
answer: String,
80+
}
81+
82+
#[derive(Signature, Clone, Debug)]
83+
/// Render with strict undefined vars.
84+
struct RenderJinjaStrictSig {
85+
#[input]
86+
#[render(jinja = "{{ missing_var }}")]
87+
question: String,
88+
89+
#[output]
90+
answer: String,
91+
}
92+
6593
fn extract_field(message: &str, field_name: &str) -> String {
6694
let start_marker = format!("[[ ## {field_name} ## ]]");
6795
let start_pos = message
@@ -184,3 +212,35 @@ fn typed_input_default_non_string_is_json() {
184212
.expect("expected array with object");
185213
assert_eq!(first.get("text").and_then(|v| v.as_str()), Some("Hello"));
186214
}
215+
216+
#[test]
217+
fn typed_input_render_jinja_uses_context_values() {
218+
let adapter = ChatAdapter;
219+
let input = RenderJinjaSigInput {
220+
question: "Question".to_string(),
221+
context: Document {
222+
text: "Hello".to_string(),
223+
},
224+
};
225+
226+
let message = adapter.format_user_message_typed::<RenderJinjaSig>(&input);
227+
let context_value = extract_field(&message, "ctx");
228+
229+
assert_eq!(
230+
context_value,
231+
"Hello | Question | Hello | Hello | ctx | context"
232+
);
233+
}
234+
235+
#[test]
236+
fn typed_input_render_jinja_strict_undefined_returns_error_sentinel() {
237+
let adapter = ChatAdapter;
238+
let input = RenderJinjaStrictSigInput {
239+
question: "Question".to_string(),
240+
};
241+
242+
let message = adapter.format_user_message_typed::<RenderJinjaStrictSig>(&input);
243+
let question_value = extract_field(&message, "question");
244+
245+
assert_eq!(question_value, "<error>");
246+
}

crates/dsrs-macros/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ quote = "1"
1919
proc-macro2 = "1"
2020
proc-macro-crate = "3.2"
2121
serde_json = { version = "1.0.143", features = ["preserve_order"] }
22+
minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["serde"] }
2223

2324
[dev-dependencies]
2425
dspy-rs = { path = "../dspy-rs" }

0 commit comments

Comments
 (0)