@@ -4,25 +4,25 @@ import type { z } from "zod";
44import { zodToJsonSchema } from 'zod-to-json-schema' ;
55import { parseUntilJson } from "./parseUntilJson" ;
66
7- export class ChatPromptTemplate {
7+ export class ChatPromptTemplate < T > {
88 protected template = "" ;
99 protected llm : OpenAI | null = null ;
10- protected variables : string [ ] | null = null ;
10+ protected variables : ( keyof T ) [ ] | null = null ;
1111 protected invokeFn : ( ...params : any ) => string = ( ) => { return "" }
1212
1313 constructor ( params : {
1414 template : string ;
15- inputVariables : string [ ] ;
15+ inputVariables : ( keyof T ) [ ] ;
1616 templateFormat ?: "mustache" ;
1717 } ) {
1818 this . template = params . template ;
1919 this . variables = params . inputVariables ;
2020 }
2121
22- public static fromTemplate ( template : string ) : ChatPromptTemplate {
22+ public static fromTemplate < T extends z . infer < z . ZodObject > > ( template : string ) : ChatPromptTemplate < T > {
2323 // Extract mustache-style variables from the template
2424 const variableRegex = / { { \s * ( [ a - z A - Z 0 - 9 _ ] + ) \s * } } / g;
25- const variables : string [ ] = [ ] ;
25+ const variables : ( keyof T ) [ ] = [ ] ;
2626 let match : RegExpExecArray | null = variableRegex . exec ( template ) ;
2727 while ( match !== null ) {
2828 variables . push ( match [ 1 ] ?? "" ) ;
@@ -51,20 +51,20 @@ export class ChatPromptTemplate {
5151 }
5252 }
5353
54- public format ( params : Record < string , any > ) {
54+ public format < T extends z . infer < z . ZodObject > > ( params : T ) {
5555 const paramsInVariables = Object . fromEntries (
5656 Object . keys ( params )
57- . filter ( ( key ) => ( this . variables ?? [ ] ) . includes ( key ) )
57+ . filter ( ( key ) => ( this . variables ?? [ ] ) . includes ( key as never ) )
5858 . map ( ( key ) => [ key , params [ key ] ] ) ,
5959 ) ;
6060 const finalTemplate = Mustache . render ( this . template , paramsInVariables ) ;
6161 return finalTemplate ;
6262 }
6363
64- public async invoke ( model : string , params : Record < string , any > ) {
64+ public async invoke < T extends z . infer < z . ZodObject > > ( model : string , params : T ) {
6565 let finalTemplate = this . format ( params ) ;
6666
67- if ( ! ! this . invokeFn ) {
67+ if ( this . invokeFn ) {
6868 finalTemplate = this . invokeFn ( finalTemplate ) ;
6969 }
7070
@@ -108,4 +108,4 @@ export function getValidatedOutput<T extends z.ZodTypeAny>(schema: T, data: stri
108108 const parsedData = parseUntilJson ( data ) ;
109109
110110 return schema . safeParse ( parsedData ) ;
111- }
111+ }
0 commit comments