Skip to content

Commit b09fee4

Browse files
committed
[FIX] Fixed types
1 parent 4f32040 commit b09fee4

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

src/utils/chat.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,25 @@ import type { z } from "zod";
44
import { zodToJsonSchema } from 'zod-to-json-schema';
55
import { parseUntilJson } from "./parseUntilJson";
66

7-
export class ChatPromptTemplate {
7+
export class ChatPromptTemplate<T> {
88
protected template = "";
99
protected llm: OpenAI | null = null;
10-
protected variables: string[] | null = null;
10+
protected variables: (keyof T)[] | null = null;
1111
protected invokeFn: (...params: any) => string = () => { return "" }
1212

1313
constructor(params: {
1414
template: string;
15-
inputVariables: string[];
15+
inputVariables: (keyof T)[];
1616
templateFormat?: "mustache";
1717
}) {
1818
this.template = params.template;
1919
this.variables = params.inputVariables;
2020
}
2121

22-
public static fromTemplate(template: string): ChatPromptTemplate {
22+
public static fromTemplate<T extends z.infer<z.ZodObject>>(template: string): ChatPromptTemplate<T> {
2323
// Extract mustache-style variables from the template
2424
const variableRegex = /{{\s*([a-zA-Z0-9_]+)\s*}}/g;
25-
const variables: string[] = [];
25+
const variables: (keyof T)[] = [];
2626
let match: RegExpExecArray | null = variableRegex.exec(template);
2727
while (match !== null) {
2828
variables.push(match[1] ?? "");
@@ -51,20 +51,20 @@ export class ChatPromptTemplate {
5151
}
5252
}
5353

54-
public format(params: Record<string, any>) {
54+
public format<T extends z.infer<z.ZodObject>>(params: T) {
5555
const paramsInVariables = Object.fromEntries(
5656
Object.keys(params)
57-
.filter((key) => (this.variables ?? []).includes(key))
57+
.filter((key) => (this.variables ?? []).includes(key as never))
5858
.map((key) => [key, params[key]]),
5959
);
6060
const finalTemplate = Mustache.render(this.template, paramsInVariables);
6161
return finalTemplate;
6262
}
6363

64-
public async invoke(model: string, params: Record<string, any>) {
64+
public async invoke<T extends z.infer<z.ZodObject>>(model: string, params: T) {
6565
let finalTemplate = this.format(params);
6666

67-
if (!!this.invokeFn) {
67+
if (this.invokeFn) {
6868
finalTemplate = this.invokeFn(finalTemplate);
6969
}
7070

@@ -108,4 +108,4 @@ export function getValidatedOutput<T extends z.ZodTypeAny>(schema: T, data: stri
108108
const parsedData = parseUntilJson(data);
109109

110110
return schema.safeParse(parsedData);
111-
}
111+
}

0 commit comments

Comments
 (0)