Skip to content

Commit f0b2b99

Browse files
committed
add input field render facet with jinja templates
1 parent 1ffa920 commit f0b2b99

27 files changed

+539
-34
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/bamltype/tests/ui/non_string_literal_attr.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ error: expected string literal; hint: wrap the value in quotes
22
--> tests/ui/non_string_literal_attr.rs:4:8
33
|
44
4 | #[baml(name = 123)]
5-
| ^^^^
5+
| ^^^^^^^^^^

crates/dspy-rs/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ rig-core = { git = "https://github.com/0xPlaygrounds/rig", rev="e7849df" }
4343
enum_dispatch = "0.3.13"
4444
tracing = "0.1.44"
4545
tracing-subscriber = { version = "0.3.22", features = ["env-filter", "fmt"] }
46+
minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["builtins", "serde"] }
4647

4748
[package.metadata.cargo-machete]
4849
ignored = ["rig-core"]

crates/dspy-rs/src/adapter/chat.rs

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use anyhow::Result;
22
use indexmap::IndexMap;
3+
use minijinja::UndefinedBehavior;
34
use regex::Regex;
45
use rig::tool::ToolDyn;
56
use serde_json::{Value, json};
@@ -17,8 +18,8 @@ use crate::serde_utils::get_iter_from_value;
1718
use crate::utils::cache::CacheEntry;
1819
use crate::{
1920
BamlValue, Cache, Chat, ConstraintLevel, ConstraintResult, Example, FieldMeta, Flag,
20-
JsonishError, LM, Message, MetaSignature, OutputFormatContent, ParseError, Prediction,
21-
RenderOptions, Signature, TypeIR,
21+
InputRenderSpec, JsonishError, LM, Message, MetaSignature, OutputFormatContent, ParseError,
22+
Prediction, RenderOptions, Signature, TypeIR,
2223
};
2324

2425
#[derive(Default, Clone)]
@@ -526,15 +527,19 @@ impl ChatAdapter {
526527
return String::new();
527528
};
528529
let input_output_format = <S::Input as BamlTypeTrait>::baml_output_format();
530+
let input_json = build_input_context_value(fields, S::input_fields(), input_output_format);
531+
let vars = Value::Object(serde_json::Map::new());
529532

530533
let mut result = String::new();
531534
for field_spec in S::input_fields() {
532535
if let Some(value) = fields.get(field_spec.rust_name) {
533536
result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.name));
534-
result.push_str(&format_baml_value_for_prompt_typed(
537+
result.push_str(&render_input_field(
538+
field_spec,
535539
value,
540+
&input_json,
536541
input_output_format,
537-
field_spec.format,
542+
&vars,
538543
));
539544
result.push_str("\n\n");
540545
}
@@ -880,23 +885,92 @@ fn format_baml_value_for_prompt(value: &BamlValue) -> String {
880885
}
881886
}
882887

883-
fn format_baml_value_for_prompt_typed(
888+
fn render_input_field(
889+
field_spec: &crate::FieldSpec,
884890
value: &BamlValue,
891+
input: &Value,
885892
output_format: &OutputFormatContent,
886-
format: Option<&str>,
893+
vars: &Value,
887894
) -> String {
888-
let format = match format {
889-
Some(format) => format,
890-
None => {
891-
if let BamlValue::String(s) = value {
892-
return s.clone();
893-
}
894-
"json"
895+
match field_spec.input_render {
896+
InputRenderSpec::Default => match value {
897+
BamlValue::String(s) => s.clone(),
898+
_ => crate::bamltype::internal_baml_jinja::format_baml_value(
899+
value,
900+
output_format,
901+
"json",
902+
)
903+
.unwrap_or_else(|_| "<error>".to_string()),
904+
},
905+
InputRenderSpec::Format(format) => {
906+
crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, format)
907+
.unwrap_or_else(|_| "<error>".to_string())
908+
}
909+
InputRenderSpec::Jinja(template) => {
910+
render_input_field_jinja(template, field_spec, value, input, output_format, vars)
911+
.unwrap_or_else(|_| "<error>".to_string())
895912
}
913+
}
914+
}
915+
916+
fn render_input_field_jinja(
917+
template: &str,
918+
field_spec: &crate::FieldSpec,
919+
value: &BamlValue,
920+
input: &Value,
921+
output_format: &OutputFormatContent,
922+
vars: &Value,
923+
) -> Result<String, minijinja::Error> {
924+
let mut env = minijinja::Environment::new();
925+
env.set_undefined_behavior(UndefinedBehavior::Strict);
926+
env.add_template("__input_field__", template)?;
927+
let template = env.get_template("__input_field__")?;
928+
929+
let this = baml_value_to_render_json(value, output_format);
930+
let field = json!({
931+
"name": field_spec.name,
932+
"rust_name": field_spec.rust_name,
933+
"type": (field_spec.type_ir)().diagnostic_repr().to_string(),
934+
});
935+
let context = json!({
936+
"this": this,
937+
"input": input,
938+
"field": field,
939+
"vars": vars,
940+
});
941+
942+
template.render(minijinja::Value::from_serialize(context))
943+
}
944+
945+
fn build_input_context_value(
946+
fields: &crate::bamltype::baml_types::BamlMap<String, BamlValue>,
947+
field_specs: &[crate::FieldSpec],
948+
output_format: &OutputFormatContent,
949+
) -> Value {
950+
let mut map = serde_json::Map::new();
951+
952+
for field_spec in field_specs {
953+
let Some(value) = fields.get(field_spec.rust_name) else {
954+
continue;
955+
};
956+
let value_json = baml_value_to_render_json(value, output_format);
957+
map.insert(field_spec.rust_name.to_string(), value_json.clone());
958+
if field_spec.name != field_spec.rust_name {
959+
map.entry(field_spec.name.to_string()).or_insert(value_json);
960+
}
961+
}
962+
963+
Value::Object(map)
964+
}
965+
966+
fn baml_value_to_render_json(value: &BamlValue, output_format: &OutputFormatContent) -> Value {
967+
let Ok(rendered_json) =
968+
crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, "json")
969+
else {
970+
return serde_json::to_value(value).unwrap_or(Value::Null);
896971
};
897972

898-
crate::bamltype::internal_baml_jinja::format_baml_value(value, output_format, format)
899-
.unwrap_or_else(|_| "<error>".to_string())
973+
serde_json::from_str(&rendered_json).unwrap_or(Value::Null)
900974
}
901975

902976
fn collect_flags_recursive(value: &BamlValueWithFlags, flags: &mut Vec<Flag>) {

crates/dspy-rs/src/core/signature.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@ use crate::{Example, OutputFormatContent, TypeIR};
22
use anyhow::Result;
33
use serde_json::Value;
44

5+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
6+
pub enum InputRenderSpec {
7+
Default,
8+
Format(&'static str),
9+
Jinja(&'static str),
10+
}
11+
512
#[derive(Debug, Clone, Copy)]
613
pub struct FieldSpec {
714
pub name: &'static str,
815
pub rust_name: &'static str,
916
pub description: &'static str,
1017
pub type_ir: fn() -> TypeIR,
1118
pub constraints: &'static [ConstraintSpec],
12-
pub format: Option<&'static str>,
19+
pub input_render: InputRenderSpec,
1320
}
1421

1522
#[derive(Debug, Clone, Copy)]

crates/dspy-rs/src/predictors/predict.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use tracing::{debug, trace};
1010
use crate::adapter::Adapter;
1111
use crate::bamltype::baml_types::BamlMap;
1212
use crate::bamltype::compat::{BamlValueConvert, ToBamlValue};
13-
use crate::core::{FieldSpec, MetaSignature, Module, Optimizable, Signature};
13+
use crate::core::{FieldSpec, InputRenderSpec, MetaSignature, Module, Optimizable, Signature};
1414
use crate::{
1515
BamlValue, CallResult, Chat, ChatAdapter, Example, GLOBAL_SETTINGS, LM, LmError, LmUsage,
1616
PredictError, Prediction,
@@ -258,8 +258,14 @@ fn field_specs_to_value(fields: &[FieldSpec], field_type: &'static str) -> Value
258258
meta.insert("desc".to_string(), json!(field.description));
259259
meta.insert("schema".to_string(), json!(""));
260260
meta.insert("__dsrs_field_type".to_string(), json!(field_type));
261-
if let Some(format) = field.format {
262-
meta.insert("format".to_string(), json!(format));
261+
match field.input_render {
262+
InputRenderSpec::Default => {}
263+
InputRenderSpec::Format(format) => {
264+
meta.insert("format".to_string(), json!(format));
265+
}
266+
InputRenderSpec::Jinja(template) => {
267+
meta.insert("render".to_string(), json!({ "jinja": template }));
268+
}
263269
}
264270
result.insert(field.rust_name.to_string(), Value::Object(meta));
265271
}

crates/dspy-rs/tests/test_input_format.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,48 @@ struct DefaultFormatSig {
6262
answer: String,
6363
}
6464

65+
#[derive(Signature, Clone, Debug)]
66+
/// Render a context field using Jinja.
67+
struct RenderJinjaSig {
68+
#[input]
69+
question: String,
70+
71+
#[input]
72+
#[alias("ctx")]
73+
#[render(
74+
jinja = "{{ this.text }} | {{ input.question }} | {{ input.ctx.text }} | {{ input.context.text }} | {{ field.name }} | {{ field.rust_name }}"
75+
)]
76+
context: Document,
77+
78+
#[output]
79+
answer: String,
80+
}
81+
82+
#[derive(Signature, Clone, Debug)]
83+
/// Render with strict undefined vars.
84+
struct RenderJinjaStrictSig {
85+
#[input]
86+
#[render(jinja = "{{ missing_var }}")]
87+
question: String,
88+
89+
#[output]
90+
answer: String,
91+
}
92+
93+
#[derive(Signature, Clone, Debug)]
94+
/// Render using field metadata and vars context.
95+
struct RenderJinjaFieldMetaSig {
96+
#[input]
97+
#[alias("ctx")]
98+
#[render(
99+
jinja = "{{ field.name }}|{{ field.rust_name }}|{{ field.type }}|{{ vars is defined }}"
100+
)]
101+
context: Document,
102+
103+
#[output]
104+
answer: String,
105+
}
106+
65107
fn extract_field(message: &str, field_name: &str) -> String {
66108
let start_marker = format!("[[ ## {field_name} ## ]]");
67109
let start_pos = message
@@ -184,3 +226,55 @@ fn typed_input_default_non_string_is_json() {
184226
.expect("expected array with object");
185227
assert_eq!(first.get("text").and_then(|v| v.as_str()), Some("Hello"));
186228
}
229+
230+
#[test]
231+
fn typed_input_render_jinja_uses_context_values() {
232+
let adapter = ChatAdapter;
233+
let input = RenderJinjaSigInput {
234+
question: "Question".to_string(),
235+
context: Document {
236+
text: "Hello".to_string(),
237+
},
238+
};
239+
240+
let message = adapter.format_user_message_typed::<RenderJinjaSig>(&input);
241+
let context_value = extract_field(&message, "ctx");
242+
243+
assert_eq!(
244+
context_value,
245+
"Hello | Question | Hello | Hello | ctx | context"
246+
);
247+
}
248+
249+
#[test]
250+
fn typed_input_render_jinja_strict_undefined_returns_error_sentinel() {
251+
let adapter = ChatAdapter;
252+
let input = RenderJinjaStrictSigInput {
253+
question: "Question".to_string(),
254+
};
255+
256+
let message = adapter.format_user_message_typed::<RenderJinjaStrictSig>(&input);
257+
let question_value = extract_field(&message, "question");
258+
259+
assert_eq!(question_value, "<error>");
260+
}
261+
262+
#[test]
263+
fn typed_input_render_jinja_exposes_field_metadata_and_vars() {
264+
let adapter = ChatAdapter;
265+
let input = RenderJinjaFieldMetaSigInput {
266+
context: Document {
267+
text: "Hello".to_string(),
268+
},
269+
};
270+
271+
let message = adapter.format_user_message_typed::<RenderJinjaFieldMetaSig>(&input);
272+
let context_value = extract_field(&message, "ctx");
273+
let parts: Vec<&str> = context_value.split('|').collect();
274+
275+
assert_eq!(parts.len(), 4);
276+
assert_eq!(parts[0], "ctx");
277+
assert_eq!(parts[1], "context");
278+
assert!(parts[2].contains("Document"));
279+
assert_eq!(parts[3].to_ascii_lowercase(), "true");
280+
}

crates/dsrs-macros/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ quote = "1"
1919
proc-macro2 = "1"
2020
proc-macro-crate = "3.2"
2121
serde_json = { version = "1.0.143", features = ["preserve_order"] }
22+
minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["serde"] }
2223

2324
[dev-dependencies]
2425
dspy-rs = { path = "../dspy-rs" }

0 commit comments

Comments
 (0)