11use anyhow:: Result ;
22use indexmap:: IndexMap ;
3+ use minijinja:: UndefinedBehavior ;
34use regex:: Regex ;
45use rig:: tool:: ToolDyn ;
56use serde_json:: { Value , json} ;
@@ -17,8 +18,8 @@ use crate::serde_utils::get_iter_from_value;
1718use crate :: utils:: cache:: CacheEntry ;
1819use crate :: {
1920 BamlValue , Cache , Chat , ConstraintLevel , ConstraintResult , Example , FieldMeta , Flag ,
20- JsonishError , LM , Message , MetaSignature , OutputFormatContent , ParseError , Prediction ,
21- RenderOptions , Signature , TypeIR ,
21+ InputRenderSpec , JsonishError , LM , Message , MetaSignature , OutputFormatContent , ParseError ,
22+ Prediction , RenderOptions , Signature , TypeIR ,
2223} ;
2324
2425#[ derive( Default , Clone ) ]
@@ -526,15 +527,19 @@ impl ChatAdapter {
526527 return String :: new ( ) ;
527528 } ;
528529 let input_output_format = <S :: Input as BamlTypeTrait >:: baml_output_format ( ) ;
530+ let input_json = build_input_context_value ( fields, S :: input_fields ( ) , input_output_format) ;
531+ let vars = Value :: Object ( serde_json:: Map :: new ( ) ) ;
529532
530533 let mut result = String :: new ( ) ;
531534 for field_spec in S :: input_fields ( ) {
532535 if let Some ( value) = fields. get ( field_spec. rust_name ) {
533536 result. push_str ( & format ! ( "[[ ## {} ## ]]\n " , field_spec. name) ) ;
534- result. push_str ( & format_baml_value_for_prompt_typed (
537+ result. push_str ( & render_input_field (
538+ field_spec,
535539 value,
540+ & input_json,
536541 input_output_format,
537- field_spec . format ,
542+ & vars ,
538543 ) ) ;
539544 result. push_str ( "\n \n " ) ;
540545 }
@@ -880,23 +885,92 @@ fn format_baml_value_for_prompt(value: &BamlValue) -> String {
880885 }
881886}
882887
883- fn format_baml_value_for_prompt_typed (
888+ fn render_input_field (
889+ field_spec : & crate :: FieldSpec ,
884890 value : & BamlValue ,
891+ input : & Value ,
885892 output_format : & OutputFormatContent ,
886- format : Option < & str > ,
893+ vars : & Value ,
887894) -> String {
888- let format = match format {
889- Some ( format) => format,
890- None => {
891- if let BamlValue :: String ( s) = value {
892- return s. clone ( ) ;
893- }
894- "json"
895+ match field_spec. input_render {
896+ InputRenderSpec :: Default => match value {
897+ BamlValue :: String ( s) => s. clone ( ) ,
898+ _ => crate :: bamltype:: internal_baml_jinja:: format_baml_value (
899+ value,
900+ output_format,
901+ "json" ,
902+ )
903+ . unwrap_or_else ( |_| "<error>" . to_string ( ) ) ,
904+ } ,
905+ InputRenderSpec :: Format ( format) => {
906+ crate :: bamltype:: internal_baml_jinja:: format_baml_value ( value, output_format, format)
907+ . unwrap_or_else ( |_| "<error>" . to_string ( ) )
908+ }
909+ InputRenderSpec :: Jinja ( template) => {
910+ render_input_field_jinja ( template, field_spec, value, input, output_format, vars)
911+ . unwrap_or_else ( |_| "<error>" . to_string ( ) )
895912 }
913+ }
914+ }
915+
916+ fn render_input_field_jinja (
917+ template : & str ,
918+ field_spec : & crate :: FieldSpec ,
919+ value : & BamlValue ,
920+ input : & Value ,
921+ output_format : & OutputFormatContent ,
922+ vars : & Value ,
923+ ) -> Result < String , minijinja:: Error > {
924+ let mut env = minijinja:: Environment :: new ( ) ;
925+ env. set_undefined_behavior ( UndefinedBehavior :: Strict ) ;
926+ env. add_template ( "__input_field__" , template) ?;
927+ let template = env. get_template ( "__input_field__" ) ?;
928+
929+ let this = baml_value_to_render_json ( value, output_format) ;
930+ let field = json ! ( {
931+ "name" : field_spec. name,
932+ "rust_name" : field_spec. rust_name,
933+ "type" : ( field_spec. type_ir) ( ) . diagnostic_repr( ) . to_string( ) ,
934+ } ) ;
935+ let context = json ! ( {
936+ "this" : this,
937+ "input" : input,
938+ "field" : field,
939+ "vars" : vars,
940+ } ) ;
941+
942+ template. render ( minijinja:: Value :: from_serialize ( context) )
943+ }
944+
945+ fn build_input_context_value (
946+ fields : & crate :: bamltype:: baml_types:: BamlMap < String , BamlValue > ,
947+ field_specs : & [ crate :: FieldSpec ] ,
948+ output_format : & OutputFormatContent ,
949+ ) -> Value {
950+ let mut map = serde_json:: Map :: new ( ) ;
951+
952+ for field_spec in field_specs {
953+ let Some ( value) = fields. get ( field_spec. rust_name ) else {
954+ continue ;
955+ } ;
956+ let value_json = baml_value_to_render_json ( value, output_format) ;
957+ map. insert ( field_spec. rust_name . to_string ( ) , value_json. clone ( ) ) ;
958+ if field_spec. name != field_spec. rust_name {
959+ map. entry ( field_spec. name . to_string ( ) ) . or_insert ( value_json) ;
960+ }
961+ }
962+
963+ Value :: Object ( map)
964+ }
965+
966+ fn baml_value_to_render_json ( value : & BamlValue , output_format : & OutputFormatContent ) -> Value {
967+ let Ok ( rendered_json) =
968+ crate :: bamltype:: internal_baml_jinja:: format_baml_value ( value, output_format, "json" )
969+ else {
970+ return serde_json:: to_value ( value) . unwrap_or ( Value :: Null ) ;
896971 } ;
897972
898- crate :: bamltype:: internal_baml_jinja:: format_baml_value ( value, output_format, format)
899- . unwrap_or_else ( |_| "<error>" . to_string ( ) )
973+ serde_json:: from_str ( & rendered_json) . unwrap_or ( Value :: Null )
900974}
901975
902976fn collect_flags_recursive ( value : & BamlValueWithFlags , flags : & mut Vec < Flag > ) {
0 commit comments